DEV Community 👩‍💻👨‍💻

Ashish Singh

Building Redis Server in Rust: Part 2

In the previous blog post, we covered single-threaded client-server communication. In this blog post, we will cover multi-threaded client-server communication. Every new client connection will be handled in its own Tokio task, which we create with tokio::spawn. Tasks are lightweight and are scheduled across the Tokio runtime's worker threads, so connections are served concurrently (and, on the default multi-threaded runtime, in parallel).

We will also implement server shutdown using ctrl+c. Pressing ctrl+c will stop the server, and the connected clients' handlers will shut down gracefully instead of panicking. For this we need a way for the server to communicate with the client handlers. Let's use a tokio::sync::broadcast channel to broadcast messages to them. This only works if every client handler subscribes to the broadcast sender, which means each handler will hold at least two fields: the socket and the broadcast receiver.

Also, since multiple clients can connect simultaneously, they can also read and write data simultaneously. This means every task needs access to the Db. We currently have a single Db instance, and it must be shared across tasks (and therefore across threads) rather than copied. For this purpose, two things need to happen.

  1. We will wrap the HashMap<String, Bytes> in our Db in Arc<Mutex<..>> so that it can be shared and mutated across tasks. In single-threaded code Rc<RefCell<..>> could do this job, but it is not thread-safe: Rc is not Send, so the compiler will not let it move into a task created with tokio::spawn.
  2. Since several values are now associated with each client connection, we will create a Handler struct to encapsulate them together.
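To make point 1 concrete, here is a minimal, standard-library-only sketch of the Arc<Mutex<HashMap>> pattern. It assumes Vec<u8> in place of bytes::Bytes and uses std::thread instead of tokio::spawn so it runs without extra crates; SharedMap and concurrent_writes are hypothetical names. Replacing Arc with Rc here would not compile, because thread::spawn (like tokio::spawn) requires the captured values to be Send.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

// Hypothetical stand-in for the post's Db: Vec<u8> replaces bytes::Bytes
// so that the sketch only needs the standard library.
type SharedMap = Arc<Mutex<HashMap<String, Vec<u8>>>>;

/// Spawns `n` threads that each insert their own key into one shared map.
fn concurrent_writes(n: usize) -> SharedMap {
    let map: SharedMap = Arc::new(Mutex::new(HashMap::new()));
    let mut workers = Vec::new();
    for i in 0..n {
        // Each thread receives its own Arc handle to the same map.
        let map = Arc::clone(&map);
        workers.push(thread::spawn(move || {
            // lock() gives exclusive access; the guard unlocks when it is dropped.
            map.lock().unwrap().insert(format!("key{i}"), vec![i as u8]);
        }));
    }
    for worker in workers {
        worker.join().unwrap();
    }
    map
}

fn main() {
    let map = concurrent_writes(8);
    println!("{} keys written", map.lock().unwrap().len());
}
```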

Let's break down our code so that hitting ctrl+c while the server is running triggers the shutdown process. In the main function of bin/server.rs, let's capture the shutdown signal using tokio::signal.

use tokio::signal;

// in main function
let shutdown = signal::ctrl_c();

When we hit ctrl+c, this shutdown future completes. Like any future, it only makes progress while it is awaited or polled (for example, by tokio::select!).

We want to run two branches at this point: one that handles incoming client connections and one that waits for this shutdown signal. When ctrl+c is received, the shutdown branch wins and the client-handling branch is dropped. This is where tokio::select! comes into the picture. It takes any number of async branches in the form of futures (similar to promises in JavaScript), polls them concurrently, and waits for the first one to complete. It then runs that branch's handler and drops the remaining futures.

tokio::select! {
    res = server::run(&mut listener) => {
        if let Err(_err) = res {
            println!("failed to accept connection");
        }
    }
    _ = shutdown => {
        println!("inside shutdown loop");
    }
}
// execution reaches here only when the shutdown future completes or the server encounters an error.

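In the code above, both branches are polled concurrently until one of them completes. server::run loops forever and only returns if accepting a connection fails, so during normal operation only the shutdown branch can finish. When we hit ctrl+c, the shutdown branch runs and the server::run() future is dropped, so the server stops accepting new connections. Tasks that were already spawned keep running, though, which is why we also need the graceful shutdown described later.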

At this point, let's create our Handler and Listener structs. The Listener struct is the server object: it holds the TcpListener, the Db and the shutdown channels.

In addition to managing the socket and its Db clone, the Handler is responsible for:

  1. Receiving a shutdown notification when the server starts shutting down.
  2. Notifying the server when the handler goes out of scope (is dropped).

Let's group the shutdown-related fields into a Shutdown struct and move the socket into a Connection struct.

File src/handler.rs

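use bytes::BytesMut;
use std::io::ErrorKind;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
use tokio::sync::{broadcast, mpsc};

use crate::{Db, Listener};
// `Command` and `buffer_to_array` come from Part 1; import them from wherever they live in your crate.

pub struct Handler {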
    pub connection: Connection,
    pub db: Db,
    pub shutdown: Shutdown,
    // never used directly: dropping the Handler drops this sender, and once every
    // clone is gone the server's shutdown_complete_rx.recv() returns None
    _shutdown_complete: mpsc::Sender<()>,
}

pub struct Connection {
    pub stream: TcpStream,
}

pub struct Shutdown {
    shutdown: bool,
    notify: broadcast::Receiver<()>,
}

The Shutdown struct holds notify, a broadcast receiver subscribed to the server's shutdown broadcast, and a shutdown flag recording whether that signal has been received. It's time to define the Listener struct and connect the listener's broadcast to the handlers' subscriptions.

Let's create a new file src/listener.rs.

use tokio::{
    net::{TcpListener, TcpStream},
    sync::{broadcast, mpsc},
};

use crate::Db;

pub struct Listener {
    pub db: Db,
    pub listener: TcpListener,
    pub notify_shutdown: broadcast::Sender<()>,
    pub shutdown_complete_rx: mpsc::Receiver<()>,
    pub shutdown_complete_tx: mpsc::Sender<()>,
}

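In the Listener struct, notify_shutdown is the sending half of a broadcast channel. We never send an explicit message on it: dropping it closes the channel, and every subscribed receiver's recv() then returns Err(RecvError::Closed), which the handlers treat as the shutdown signal.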

The other two fields, shutdown_complete_rx and shutdown_complete_tx, are the two halves of an mpsc channel that lets the handlers signal the server. Each handler holds a clone of the sender but never sends anything on it. Imagine many clients are connected and the server decides to shut down: it awaits recv() on the receiver, and since no value is ever sent, recv() only returns (with None) once every sender has been dropped. In short, all senders must be dropped before the code after the receiver can continue.
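This "wait until every sender is gone" behaviour can be tried out without tokio: std::sync::mpsc works the same way, except that its recv() returns Err(RecvError) where tokio's returns None. In this sketch, Worker and wait_for_all are hypothetical names; each worker only holds a Sender clone, like the handler's _shutdown_complete field, and never sends on it.

```rust
use std::sync::mpsc;
use std::thread;
use std::time::{Duration, Instant};

/// Hypothetical stand-in for the post's Handler: it only owns a Sender clone,
/// like the `_shutdown_complete` field, and never sends anything on it.
struct Worker {
    _done: mpsc::Sender<()>,
}

/// Spawns `n` workers; worker `i` lives for (i + 1) * 10 ms.
/// Returns how long the receiver waited before every sender was gone.
fn wait_for_all(n: u64) -> Duration {
    let start = Instant::now();
    let (done_tx, done_rx) = mpsc::channel::<()>();
    for i in 0..n {
        let worker = Worker { _done: done_tx.clone() };
        thread::spawn(move || {
            thread::sleep(Duration::from_millis((i + 1) * 10));
            drop(worker); // dropping the worker drops its Sender clone
        });
    }
    // Drop our own sender too; otherwise recv() below would block forever.
    drop(done_tx);
    // Nothing is ever sent, so recv() returns Err(RecvError) only once
    // every Sender clone has been dropped.
    assert!(done_rx.recv().is_err());
    start.elapsed()
}

fn main() {
    println!("waited {:?}", wait_for_all(3));
}
```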

The flow of messages is as follows:

  1. The server starts shutting down and signals all handlers by dropping notify_shutdown, the broadcast::Sender<()>.
  2. Every handler sees the signal through its broadcast::Receiver<()> subscription.
  3. Each handler wraps up its work and is dropped.
  4. When a handler goes out of scope, its clone of the mpsc::Sender<()> is dropped with it (explained later).
  5. The server awaits shutdown_complete_rx, which stays pending while any sender clone is still alive.
  6. Once every handler (and the listener's own shutdown_complete_tx) has been dropped, shutdown_complete_rx.recv() returns None, so the server knows all handlers are gone.
  7. The server shuts down gracefully.
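Here is the whole flow as a small sketch using only std threads and channels. The standard library has no broadcast channel, so this assumes one mpsc channel per handler and treats dropping all of their senders as the broadcast. Unlike the post, each handler also sends its id before it is dropped, purely so the demo can report who finished. Handler, run and shutdown_flow here are hypothetical.

```rust
use std::sync::mpsc::{self, Receiver, Sender};
use std::thread;

/// Hypothetical handler: `shutdown` plays the role of the broadcast
/// receiver, `done` the role of the `_shutdown_complete` sender.
struct Handler {
    id: usize,
    shutdown: Receiver<()>,
    done: Sender<usize>,
}

impl Handler {
    fn run(self) {
        // Steps 2-3: block until the server drops our shutdown sender;
        // recv() returns Err once that sender is gone.
        let _ = self.shutdown.recv();
        // Demo only: report our id so the server can count us.
        let _ = self.done.send(self.id);
        // Step 4: `self` (and its `done` Sender) is dropped here.
    }
}

/// Runs the shutdown flow with `n` handlers and returns the ids that
/// reported in before the server finished waiting.
fn shutdown_flow(n: usize) -> Vec<usize> {
    let (done_tx, done_rx) = mpsc::channel();
    // std has no broadcast channel, so keep one Sender per handler.
    let mut notify = Vec::new();
    for id in 0..n {
        let (tx, rx) = mpsc::channel::<()>();
        notify.push(tx);
        let handler = Handler { id, shutdown: rx, done: done_tx.clone() };
        thread::spawn(move || handler.run());
    }
    // Step 1: signal every handler by dropping all shutdown senders.
    drop(notify);
    // Drop the server's own sender, as the post does with shutdown_complete_tx.
    drop(done_tx);
    // Steps 5-6: the iterator ends once every handler's Sender is dropped.
    let mut ids: Vec<usize> = done_rx.iter().collect();
    ids.sort();
    // Step 7: all handlers are gone, so the server can exit.
    ids
}

fn main() {
    println!("{:?}", shutdown_flow(4));
}
```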

Note that we are using an mpsc channel, which stands for multi-producer, single-consumer. It fits our use case well, as many handlers (producers) signal a single server (consumer).

Now let's look at how the objects are created:

  1. On server start, we create a Listener object, which initializes the Db and holds the channel senders and receiver.
  2. Then we wait for client socket connections in a loop.
  3. When a new socket connection arrives, we create a Handler object. It calls subscribe() on the broadcast::Sender to get its own receiver, and initializes its _shutdown_complete field with a clone of the mpsc::Sender.
  4. The handler object then lives exactly as long as the client connection.

Creating the Listener object

In bin/server.rs

use blog_redis::server;
use blog_redis::Listener;
use tokio::signal;
use tokio::{
    net::TcpListener,
    sync::{broadcast, mpsc},
};

#[tokio::main]
pub async fn main() -> Result<(), std::io::Error> {
    let listener = TcpListener::bind("127.0.0.1:8081").await?;
    let shutdown = signal::ctrl_c();
    let (notify_shutdown, _) = broadcast::channel(1);
    let (shutdown_complete_tx, shutdown_complete_rx) = mpsc::channel(1);

    let mut listener = Listener::new(
        listener,
        notify_shutdown,
        shutdown_complete_tx,
        shutdown_complete_rx,
    );

    tokio::select! {
        res = server::run(&mut listener) => {
            if let Err(_err) = res {
                println!("failed to accept connection");
            }
        }
        _ = shutdown => {
            println!("inside shutdown loop");
        }
    }

    // graceful shutdown code will go here. Refer Graceful shutdown section

    Ok(())
}


Let's now implement all the methods on the Handler, Shutdown and Connection structs.

impl Connection {
    fn new(stream: TcpStream) -> Connection {
        Connection { stream }
    }

    pub async fn read_buf_data(&mut self) -> Option<(Command, Vec<String>)> {
        let mut buf = BytesMut::with_capacity(1024);
        match self.stream.read_buf(&mut buf).await {
            Ok(size) => {
                if size == 0 {
                    // returning from empty buffer
                    return None;
                }
            }
            Err(err) => {
                println!("error {:?}", err);
                return None;
            }
        };
        let attrs = buffer_to_array(&mut buf);
        Some((Command::get_command(&attrs[0]), attrs))
    }
}

impl Shutdown {
    fn new(shutdown: bool, notify: broadcast::Receiver<()>) -> Shutdown {
        Shutdown { shutdown, notify }
    }

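    pub async fn listen_recv(&mut self) -> Result<(), tokio::sync::broadcast::error::RecvError> {
        // recv() completes with Ok(()) if a message is sent, or with
        // Err(RecvError::Closed) once the sender is dropped. Both mean "shut down",
        // so ignore the result and set the flag either way.
        let _ = self.notify.recv().await;
        self.shutdown = true;
        Ok(())
    }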

    pub fn is_shutdown(&self) -> bool {
        self.shutdown
    }
}

impl Handler {
    pub fn new(listener: &Listener, socket: TcpStream) -> Handler {
        Handler {
            connection: Connection::new(socket),
            db: listener.db.clone(),
            shutdown: Shutdown::new(false, listener.notify_shutdown.subscribe()),
            _shutdown_complete: listener.shutdown_complete_tx.clone(),
        }
    }

    pub async fn process_query(
        &mut self,
        command: Command,
        attrs: Vec<String>,
    ) -> Result<(), std::io::Error> {
        let connection = &mut self.connection;
        let db = &self.db;

        match command {
            Command::Get => {
                let result = db.read(&attrs);
                match result {
                    Ok(result) => {
                        connection.stream.write_all(&result).await?;
                    }
                    Err(_err) => {
                        connection.stream.write_all(b"").await?;
                    }
                }
                return Ok(());
            }
            Command::Set => {
                let resp = db.write(&attrs);
                match resp {
                    Ok(result) => {
                        connection.stream.write_all(result.as_bytes()).await?;
                    }
                    Err(_err) => {
                        connection.stream.write_all(b"").await?;
                    }
                }
                return Ok(());
            }
            Command::Invalid => {
                connection.stream.write_all(b"invalid command").await?;
                Err(std::io::Error::from(ErrorKind::InvalidData))
            }
        }
    }
}

Changing Db to use Arc<Mutex<..>>

In the previous blog post, we defined the Db entries as a HashMap<String, Bytes>. That worked fine because we only had one-to-one server-client communication. With multiple tasks, however, every handler object needs shared, mutable access to the same map. One way to achieve this is to wrap our HashMap in Arc<Mutex<..>>. Arc<T> provides thread-safe shared ownership of the value T it encapsulates: its reference count is updated atomically, but on its own it only hands out shared (read-only) references. To give each handler access, we clone the Arc, which creates a new Arc pointing to the same allocation and increments the reference count; the map itself is not copied. Mutex provides the mutability: an acquirer must lock it before touching the data, so only one acquirer can access the data at a time. Once the acquirer is done (more precisely, when its MutexGuard is dropped) the lock is released for other acquirers to do the same.
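A quick way to check these claims is Arc::ptr_eq and Arc::strong_count. The following standard-library sketch (shared_handles is a hypothetical helper, and String stands in for Bytes) shows that a cloned Arc points to the same map and that a write made through one handle is visible through the other.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

/// Returns the strong count while two handles exist, and the value
/// read through the original handle after writing through the clone.
fn shared_handles() -> (usize, Option<String>) {
    let db: Arc<Mutex<HashMap<String, String>>> = Arc::new(Mutex::new(HashMap::new()));

    // Cloning the Arc copies a pointer and increments the reference count;
    // the HashMap itself is not copied.
    let handle = Arc::clone(&db);
    assert!(Arc::ptr_eq(&db, &handle));
    let count = Arc::strong_count(&db);

    // Arc alone only hands out shared references; the Mutex is what
    // allows mutation through either handle.
    handle.lock().unwrap().insert("name".to_string(), "redis".to_string());
    let seen = db.lock().unwrap().get("name").cloned();
    (count, seen)
}

fn main() {
    let (count, seen) = shared_handles();
    println!("strong_count = {count}, seen = {seen:?}");
}
```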

File src/db.rs

use bytes::Bytes;
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

#[derive(Clone, Debug)]
pub struct Db {
    pub entries: Arc<Mutex<HashMap<String, Bytes>>>,
}

impl Db {
    pub fn new() -> Db {
        Db {
            entries: Arc::new(Mutex::new(HashMap::new())),
        }
    }

    pub fn write(&self, arr: &[String]) -> Result<&str, &'static str> {
        let key = &arr[1];
        let value = &arr[2];

        // Bytes::from needs an owned String here (its From impls for references
        // require &'static data), so clone the borrowed value into a new String
        let val = value.clone();
        let p = self.entries.lock().unwrap()
                    .insert(String::from(key), Bytes::from(val));

        match p {
            Some(_p) => Ok("r Ok"), // if the key was already present
            None => Ok("Ok"),       // if the key was not present
        }
    }

    /// Reads data from the database
    pub fn read(&self, arr: &[String]) -> Result<Bytes, &'static str> {
        let key = &arr[1];
        let query_result = self.entries.lock().unwrap();
        let res = query_result.get(key);
        match res {
            Some(value) => Ok(value.clone()),
            None => Err("no such key found"),
        }
    }
}


You might have noticed that we call the lock() method on self.entries, where self is Db. entries has type Arc<Mutex<HashMap<..>>>, and Arc has no lock method; lock is defined on Mutex. It is still reachable because Arc<T> implements the Deref trait with Target = T, so when resolving the method call the compiler dereferences the Arc to the Mutex inside it. (Mutex itself does not implement Deref; the MutexGuard returned by lock() does.) Deref is how smart pointers expose the data they point to: its deref method returns a reference to the inner data.

Deref & Deref Coercion

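Implementing Deref enables deref coercion: the compiler inserts deref calls to turn a reference to one type into a reference to the type a function or method expects. For method calls such as lock(), the compiler likewise auto-dereferences the receiver until it finds a type that has the method.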

Here the entries field of type Arc<Mutex<HashMap<..>>> is dereferenced to &Mutex<HashMap<..>>, which has the lock method.

It does this by calling the deref method, which borrows self (the Arc) and returns a reference to the inner data (&Mutex<..>). Deref coercion happens automatically when we pass a reference to a particular type's value as an argument to a function or method whose parameter type doesn't match; a sequence of deref calls converts the type we provided into the type the parameter needs.

Arc implements only Deref, not DerefMut, so on its own it gives shared (read-only) access; that is why we need the Mutex. The MutexGuard<HashMap<..>> returned by .lock().unwrap() implements both Deref and DerefMut. In read we only need a &HashMap for get, but in write, insert takes &mut self, so the guard is dereferenced mutably (through DerefMut) to &mut HashMap<..>.
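The same mechanics can be seen on a hypothetical Wrapper type that implements Deref and DerefMut (a sketch for illustration only, not part of the server): a method call auto-dereferences through it, a function argument is coerced from &Wrapper<String> all the way to &str, and a &mut self method goes through DerefMut.

```rust
use std::ops::{Deref, DerefMut};

// Hypothetical smart pointer, for illustration only.
struct Wrapper<T> {
    inner: T,
}

impl<T> Deref for Wrapper<T> {
    type Target = T;
    fn deref(&self) -> &T {
        &self.inner
    }
}

impl<T> DerefMut for Wrapper<T> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.inner
    }
}

// Expects &str; we will pass it a &Wrapper<String>.
fn shout(s: &str) -> String {
    s.to_uppercase()
}

fn demo() -> (usize, String, usize) {
    let mut w = Wrapper { inner: String::from("redis") };
    // Method call: Wrapper has no len(), so auto-deref finds String::len.
    let len = w.len();
    // Deref coercion: &Wrapper<String> -> &String -> &str.
    let loud = shout(&w);
    // push_str takes &mut self, so this goes through DerefMut.
    w.push_str("-server");
    (len, loud, w.len())
}

fn main() {
    println!("{:?}", demo());
}
```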

Notice that write calls insert in a single expression, whereas read stores the guard in a local variable and calls get on a separate line. There is a reason for this.

insert returns an Option<Bytes>, an owned value, whereas get returns an Option<&Bytes>, a reference that borrows from the map. Since we reach the map through the MutexGuard, that reference is only valid while the guard is alive. In write, the returned value is owned, so it doesn't matter when the guard is dropped. Let's look at how long the guard, and therefore the borrow, lives in each case.

The lock() method returns a LockResult<MutexGuard<HashMap<..>>>, and unwrap() gives us the MutexGuard itself (not a reference). MutexGuard implements Deref, so accessing the HashMap through it goes through the deref method of the Deref trait:

pub trait Deref {
    type Target: ?Sized;
    // The returned reference borrows `self`, so it cannot outlive it.
    // For MutexGuard<'_, HashMap<..>>, Target is HashMap<..>.
    fn deref<'a>(&'a self) -> &'a Self::Target;
}

When we call query_result.get(key), the MutexGuard is dereferenced using this deref method. The &HashMap it returns borrows the guard, so that reference, and any &Bytes obtained from it, cannot outlive the MutexGuard.

If we don't store the MutexGuard anywhere, e.g. let res = self.entries.lock().unwrap().get(key);, the guard is a temporary that is dropped at the end of that statement, which also releases the lock. The HashMap itself is not dropped, but res would hold a reference borrowed from a guard that no longer exists, so the compiler rejects this with a "temporary value dropped while borrowed" error.

However, when we store the guard in a local variable (query_result), it lives until the end of the function, so references returned by get remain valid throughout the function body.

We can use those references inside the function, but we cannot return them: query_result is dropped (and the lock released) when read returns, and the &Bytes would outlive it. That's why read returns value.clone(), an owned Bytes that doesn't borrow from the guard. Cloning Bytes is cheap, since it shares the underlying buffer instead of copying it.
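The two shapes of read can be compared side by side in this sketch (Entries, read_two_lines and read_one_line are hypothetical; Vec<u8> stands in for Bytes, so clone() here really does copy the data). The version the compiler rejects is left as a comment. A single expression also works as long as the clone happens before the expression ends, because the temporary guard lives until the end of the full expression.

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};

// Hypothetical alias; Vec<u8> stands in for bytes::Bytes.
type Entries = Arc<Mutex<HashMap<String, Vec<u8>>>>;

/// Mirrors the post's read: keep the guard in a local so the reference
/// returned by get() stays valid, then clone before returning.
fn read_two_lines(entries: &Entries, key: &str) -> Option<Vec<u8>> {
    let guard = entries.lock().unwrap(); // lives until the end of the function
    let value = guard.get(key); // Option<&Vec<u8>> borrowing from `guard`
    value.cloned() // owned copy that no longer borrows the guard
}

/// Also compiles: the temporary guard lives until the end of the full
/// expression, and cloned() produces an owned value before that.
fn read_one_line(entries: &Entries, key: &str) -> Option<Vec<u8>> {
    entries.lock().unwrap().get(key).cloned()
}

// Rejected (E0716, "temporary value dropped while borrowed"): the guard is
// dropped at the end of the `let`, while `value` still borrows from it.
//
//     let value = entries.lock().unwrap().get(key);
//     value.cloned()

fn main() {
    let entries: Entries = Arc::new(Mutex::new(HashMap::new()));
    entries.lock().unwrap().insert("k".to_string(), b"v".to_vec());
    println!("{:?} {:?}", read_two_lines(&entries, "k"), read_one_line(&entries, "k"));
}
```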

Creating the Handler object

Let's move the main server logic to a new file, src/server.rs, which holds the server's main execution loop. We will move this code out of bin/server.rs into src/server.rs.

File src/server.rs

use crate::Handler;
use crate::Listener;

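pub async fn run(listener: &Listener) -> std::io::Result<()> {
    loop {
        // `accept` is a helper on Listener that is not shown in this post;
        // it presumably wraps `TcpListener::accept` and returns the TcpStream.
        let socket = listener.accept().await?;
        // Each connection gets its own Handler: a Db clone, a broadcast
        // subscription and a clone of the shutdown_complete sender.
        let mut handler = Handler::new(listener, socket);

        // Run the connection on its own Tokio task; `handler` (and its
        // mpsc::Sender) is dropped when this task finishes.
        tokio::spawn(async move {
            if let Err(_err) = process_method(&mut handler).await {
                println!("Connection Error");
            }
        });
    }
}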

async fn process_method(handler: &mut Handler) -> Result<(), std::io::Error> {
    while !handler.shutdown.is_shutdown() {
        let result = tokio::select! {
            _ = handler.shutdown.listen_recv() => {
                return Ok(());
            },
            res = handler.connection.read_buf_data() => res,
        };

        let (cmd, vec) = match result {
            Some((cmd, vec)) => (cmd, vec),
            None => return Ok(()),
        };

        handler.process_query(cmd, vec).await?;
    }
    Ok(())
}

One important consideration in the code above is that we call handler.shutdown.listen_recv() and handler.connection.read_buf_data() inside the same tokio::select! block. Both futures are alive at the same time and each holds a mutable borrow. Since handler is a &mut Handler, one might think this violates the rule that an object can have only one mutable borrow at a time. However, the two calls borrow two different fields, shutdown and connection, which don't overlap, so the borrow checker accepts it.

Initially, to keep things simple, I had not created the Connection and Shutdown structs, and I kept running into borrow-checker errors. Once I moved the respective fields into separate structs, each future borrowed only its own field and the compiler stopped complaining.
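The borrow rule at play can be reproduced without tokio. In this hypothetical miniature of the Handler (field and method names are made up for the sketch), two mutable borrows of different fields are alive at the same time, which the borrow checker accepts; tokio::select! relies on the same thing, since it keeps both futures alive at once.

```rust
// Hypothetical miniature of the post's Handler with two independent parts.
struct Connection {
    reads: u32,
}

struct Shutdown {
    shutdown: bool,
}

struct Handler {
    connection: Connection,
    shutdown: Shutdown,
}

impl Connection {
    fn read(&mut self) -> u32 {
        self.reads += 1;
        self.reads
    }
}

impl Shutdown {
    fn trigger(&mut self) {
        self.shutdown = true;
    }
}

/// Holds two `&mut` borrows of different fields of one `&mut Handler`
/// at the same time; accepted because the borrows don't overlap.
fn use_both(handler: &mut Handler) -> (u32, bool) {
    let conn = &mut handler.connection;
    let sd = &mut handler.shutdown;
    let n = conn.read(); // `sd` is still alive here
    sd.trigger();
    (n, sd.shutdown)
}

// By contrast, two `&mut self` methods on Handler itself would each borrow the
// whole struct, so keeping both borrows alive at once fails with error E0499.

fn main() {
    let mut handler = Handler {
        connection: Connection { reads: 0 },
        shutdown: Shutdown { shutdown: false },
    };
    println!("{:?}", use_both(&mut handler));
}
```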

Graceful Shutdown

Each handler object listens for the shutdown signal from the listener's notify_shutdown through the notify receiver inside its Shutdown field.

Let's explore the process_method function inside src/server.rs. It returns in the following cases:

  1. The while loop ends because handler.shutdown.is_shutdown() is true, or process_query returns an error, which ? propagates.
  2. handler.connection.read_buf_data() returns None, i.e. the client closed the connection (a read of 0 bytes) or the read failed.
  3. handler.shutdown.listen_recv() completes inside tokio::select!.

In the third case, listen_recv() only completes when the listener's notify_shutdown sends a message or is dropped.

When we press ctrl+c, the tokio::select! in bin/server.rs takes the _ = shutdown => {} branch, and execution continues after the select! block. Now we need to tell the handlers. Instead of sending an explicit message, we drop notify_shutdown, which closes the broadcast channel; every handler's subscribed receiver notices this.

File bin/server.rs

// tokio::select! { ... }

drop(listener.notify_shutdown);

File src/handler.rs. This is where the shutdown signal is received.

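pub async fn listen_recv(&mut self) -> Result<(), tokio::sync::broadcast::error::RecvError> {
    // drop(notify_shutdown) closes the channel, so recv() completes here with
    // Err(RecvError::Closed) and unblocks the control flow. Treat that (or a
    // sent message) as the shutdown signal and set the flag either way.
    let _ = self.notify.recv().await;
    self.shutdown = true;
    Ok(())
}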

When notify sees the signal, listen_recv() completes, the select! branch in process_method returns Ok(()), and the async block passed to tokio::spawn in the run function of src/server.rs finishes. This is where the handler object goes out of scope and is dropped, and its _shutdown_complete sender is dropped with it. Nothing is actually sent on the mpsc channel; the receiver only notices that one more sender is gone. Now we need to wait until all the handler objects have been dropped.

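// tokio::select! { ... }

drop(listener.notify_shutdown);
drop(listener.shutdown_complete_tx);

// mpsc::Receiver is not a Future itself; recv() returns None
// once every Sender (the handlers' clones included) has been dropped.
let _ = listener.shutdown_complete_rx.recv().await;

// this point is only reached when all the handlers are dropped.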

In the code above, after dropping notify_shutdown we also drop the listener's own shutdown_complete_tx. This is necessary: as long as that sender is alive, shutdown_complete_rx.recv().await would wait forever, because no message is ever sent. recv() returns None only once all senders, i.e. the listener's and every handler's clone, have been dropped.

So once recv().await returns, we are guaranteed that all the handlers have been dropped.

Congratulations! You have successfully created a multi-threaded Redis server and client using Tokio.

The latest code is available in the repository on GitHub.

If you have any queries or questions, feel free to ask me in the comments section or on the Discord channel.
