DEV Community

BC


Day30: a Redis-protocol memcache server with tokio - 100DayOfRust

I will write a memcache-style key-value server with tokio that speaks the Redis protocol. This server supports only 2 Redis commands:

  • get <key>
  • set <key> <value>

The goal is that, once the server is running, we can connect to it with the standard redis-cli and use the get and set commands.

Basic code structure

We use tokio, an asynchronous runtime for Rust, to build the server:

In Cargo.toml:

[dependencies]
tokio = { version = "0.2", features = ["full"] }

In main.rs:

use std::collections::HashMap;
use std::error::Error;

use tokio::net::{TcpListener, TcpStream};
use tokio::prelude::*;
use tokio::sync::Mutex;
use std::sync::Arc;


#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let addr = "127.0.0.1:7777";
    let mut listener = TcpListener::bind(addr).await?;

    println!("Listen on {}", addr);

    let dict = Arc::new(Mutex::new(HashMap::new()));

    loop {
        let (mut sock, _) = listener.accept().await?;
        let dict = dict.clone();
        tokio::spawn(async move {
            // todo
        });
    }
}

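The dict variable will hold the key-value pairs. We wrap it in a Mutex and an Arc because it is shared by all connection tasks, which tokio may run on multiple threads: Arc gives every task a handle to the same map (dict.clone() copies only the handle, not the map), and Mutex ensures that only one task accesses the map at a time.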
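Because tokio's default runtime with the "full" feature set can run spawned tasks on several worker threads, every task needs its own handle to the same map. The same Arc<Mutex<...>> pattern can be sketched with plain std threads; this is a minimal sketch assuming std::sync::Mutex and std::thread in place of the tokio types, and fill_shared_map is a hypothetical helper:

```rust
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use std::thread;

/// Hypothetical helper: spawn `n` threads that each insert one pair into a
/// shared map, mirroring how each tokio task receives its own clone of `dict`.
fn fill_shared_map(n: usize) -> HashMap<String, String> {
    let dict = Arc::new(Mutex::new(HashMap::new()));
    let mut handles = Vec::new();
    for i in 0..n {
        // Cloning the Arc copies the pointer; all threads share one map.
        let dict = Arc::clone(&dict);
        handles.push(thread::spawn(move || {
            // lock() blocks until no other thread holds the guard.
            let mut map = dict.lock().unwrap();
            map.insert(format!("key{}", i), format!("value{}", i));
        }));
    }
    for h in handles {
        h.join().unwrap();
    }
    // Every clone has been dropped, so the map can be moved out again.
    Arc::try_unwrap(dict).unwrap().into_inner().unwrap()
}
```

One difference from the article's code: it uses tokio::sync::Mutex, whose lock() is awaited instead of blocking the thread. That matters there because the guard in handle_get and handle_set is still alive across the .await on write_all.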

Redis Protocol

Redis uses \r\n (CRLF) as the separator in its protocol. Covering the protocol in detail would take another post, so here we only introduce the basics:

For Simple Strings the first byte of the reply is "+"
For Errors the first byte of the reply is "-"
For Integers the first byte of the reply is ":"
For Bulk Strings the first byte of the reply is "$"
For Arrays the first byte of the reply is "*"
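To make the five prefixes concrete, here is a minimal sketch of one encoder per type. The encode_* names are hypothetical and not part of the server; they only show how each reply is framed:

```rust
// Hypothetical helpers: each returns the bytes of one RESP value.

/// Simple string, e.g. "+OK\r\n".
fn encode_simple(s: &str) -> Vec<u8> {
    format!("+{}\r\n", s).into_bytes()
}

/// Error, e.g. "-ERR syntax error\r\n".
fn encode_error(msg: &str) -> Vec<u8> {
    format!("-{}\r\n", msg).into_bytes()
}

/// Integer, e.g. ":42\r\n".
fn encode_integer(n: i64) -> Vec<u8> {
    format!(":{}\r\n", n).into_bytes()
}

/// Bulk string: byte length, then the data. `None` encodes nil ("$-1\r\n").
fn encode_bulk(s: Option<&str>) -> Vec<u8> {
    match s {
        Some(s) => format!("${}\r\n{}\r\n", s.len(), s).into_bytes(),
        None => b"$-1\r\n".to_vec(),
    }
}

/// Array: element count, then the already-encoded elements back to back.
fn encode_array(items: &[Vec<u8>]) -> Vec<u8> {
    let mut out = format!("*{}\r\n", items.len()).into_bytes();
    for item in items {
        out.extend_from_slice(item);
    }
    out
}
```

For instance, encode_array(&[encode_bulk(Some("get")), encode_bulk(Some("hello"))]) produces *2\r\n$3\r\nget\r\n$5\r\nhello\r\n: a client request is simply an array of bulk strings.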

For example, for the get <key> command, redis-cli sends the following (assuming our key is "hello"), with each line terminated by \r\n:

*2
$3
get
$5
hello

The leading *2 indicates an array with 2 elements. Each element is a bulk string whose length comes first: $3 means the first element ("get") is 3 bytes long, and $5 means the second ("hello") is 5 bytes long.
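The same framing can be decoded without a socket. A minimal sketch, not part of the server, of a hypothetical parse_command that decodes one complete request already held in memory:

```rust
/// Hypothetical helper: decode one complete RESP array of bulk strings held
/// in memory. Returns None if the input is malformed or truncated.
fn parse_command(input: &[u8]) -> Option<Vec<String>> {
    // Read "<prefix><number>\r\n" at *pos, check the prefix, advance past it.
    fn read_len(input: &[u8], pos: &mut usize, prefix: u8) -> Option<usize> {
        if *input.get(*pos)? != prefix {
            return None;
        }
        let rest = &input[*pos + 1..];
        let end = rest.windows(2).position(|w| w == b"\r\n")?;
        let n = std::str::from_utf8(&rest[..end]).ok()?.parse::<usize>().ok()?;
        *pos += 1 + end + 2;
        Some(n)
    }

    let mut pos = 0;
    let count = read_len(input, &mut pos, b'*')?;
    let mut args = Vec::new();
    for _ in 0..count {
        let len = read_len(input, &mut pos, b'$')?;
        // Exactly `len` bytes of payload, which must be followed by \r\n.
        let body = input.get(pos..pos + len)?;
        if input.get(pos + len..pos + len + 2)? != b"\r\n" {
            return None;
        }
        args.push(String::from_utf8(body.to_vec()).ok()?);
        pos += len + 2;
    }
    Some(args)
}
```

Because every element carries its byte length, a value containing a space, such as "world overwrite", arrives as a single $15 bulk string rather than being split into two arguments.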

Let's write a few utility functions: one reads a length such as 2 or 3, and another uses that length to read the string that follows.

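// Read bytes up to and including the next LF, dropping the first `skip`
// bytes (the type prefix such as '*' or '$') and the trailing CR.
async fn read_till_crlf(stream: &mut TcpStream, skip: u8) -> Vec<u8> {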
    let mut ret: Vec<u8> = vec![];
    let mut skip_num = skip;
    loop {
        let mut buf = [0; 1];
        stream.read_exact(&mut buf).await.unwrap();
        // LF's ascii number is 10
        if skip_num == 0 && buf[0] == 10 {
            break;
        }
        if skip_num > 0 {
            skip_num -= 1;
        } else {
            ret.push(buf[0]);
        }
    }
    // pop the last CR
    ret.pop();
    ret
}


async fn read_nbytes(stream: &mut TcpStream, nbytes: usize) -> Vec<u8> {
    let mut ret: Vec<u8> = vec![0; nbytes];
    stream.read_exact(&mut ret).await.unwrap();
    ret
}


async fn get_next_len(stream: &mut TcpStream) -> usize {
    let vlen = read_till_crlf(stream, 1).await;
    let slen = String::from_utf8(vlen).unwrap();
    let len:usize = slen.parse().unwrap();
    len
}


async fn get_next_string(stream: &mut TcpStream) -> String {
    let len = get_next_len(stream).await;
    let vs = read_nbytes(stream, len).await;
    // consume the followed \r\n
    let _ = read_nbytes(stream, 2).await;
    // build string and return
    let s = String::from_utf8(vs).unwrap();
    s
}

get_next_len skips the type byte and returns the number from a line such as "*2\r\n" or "$3\r\n" (2 and 3, respectively). get_next_string reads a whole bulk string: for "$5\r\nhello\r\n", it returns "hello".
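The helpers above read from the socket one byte at a time with read_exact, which is simple but may issue a system call per byte. A minimal sketch of the same two helpers over a buffered reader, assuming std::io::BufRead rather than tokio's async traits; next_len, next_string and invalid are hypothetical names. They return io::Result instead of calling unwrap, so a closed connection or malformed input becomes an error the caller can handle:

```rust
use std::io::{self, BufRead, Read};

/// Turn a parse failure (or a message) into an InvalidData I/O error.
fn invalid<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    io::Error::new(io::ErrorKind::InvalidData, e)
}

/// Like get_next_len: read a line such as "*2\r\n" or "$3\r\n", skip the
/// one-byte type prefix, and return the number.
fn next_len<R: BufRead>(reader: &mut R) -> io::Result<usize> {
    let mut line = Vec::new();
    // A single read_until call replaces the byte-by-byte loop; the \n is kept.
    reader.read_until(b'\n', &mut line)?;
    if line.len() < 3 || !line.ends_with(b"\r\n") {
        return Err(invalid("malformed length line"));
    }
    let digits = std::str::from_utf8(&line[1..line.len() - 2]).map_err(invalid)?;
    digits.parse::<usize>().map_err(invalid)
}

/// Like get_next_string: read "$5\r\nhello\r\n" and return "hello".
fn next_string<R: BufRead>(reader: &mut R) -> io::Result<String> {
    let len = next_len(reader)?;
    // Read the payload plus its trailing \r\n, then drop the \r\n.
    let mut buf = vec![0; len + 2];
    reader.read_exact(&mut buf)?;
    buf.truncate(len);
    String::from_utf8(buf).map_err(invalid)
}
```

Tokio also has a BufReader type and an AsyncBufReadExt::read_until method that could play the same role in the async server; check the documentation of the tokio version in use for the exact signatures.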

Handle unknown command and syntax error

Since we only support the get <key> and set <key> <value> commands, let's write 2 more utility functions that send error replies (a leading "-" marks an error in the protocol):

async fn handle_unknown(stream: &mut TcpStream) {
    stream.write_all(b"-Unknown command\r\n").await.unwrap();
}

async fn handle_syntax_err(stream: &mut TcpStream) {
    stream.write_all(b"-ERR syntax error\r\n").await.unwrap();
}

Handle GET command

async fn handle_get(stream: &mut TcpStream,
                    dict: &Arc<Mutex<HashMap<String, String>>>) {
    let key = get_next_string(stream).await;
    let map = dict.lock().await;
    let s = match map.get(key.as_str()) {
        Some(v) => {
            format!("${}\r\n{}\r\n", v.len(), v)
        },
        None => {
            "$-1\r\n".to_owned()
        },
    };
    stream.write_all(s.as_bytes()).await.unwrap();
}

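First we read the key. If it is in our dictionary, we return its value; otherwise we return the nil reply $-1\r\n, which redis-cli displays as (nil). Remember that the reply must also follow the Redis protocol: the value is sent as a bulk string, $<length>\r\n<value>\r\n.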
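The reply logic in handle_get does not depend on the socket, so it can be pulled out into a plain function. A minimal sketch with a hypothetical get_reply that mirrors the match in handle_get. Note that the $ length is a byte count, and String::len also counts bytes, so non-ASCII values are framed correctly:

```rust
use std::collections::HashMap;

/// Hypothetical pure version of the match in handle_get: a bulk string if
/// the key exists, otherwise the nil bulk string "$-1\r\n".
fn get_reply(map: &HashMap<String, String>, key: &str) -> String {
    match map.get(key) {
        // String::len() counts UTF-8 bytes, which is what the $ prefix needs.
        Some(v) => format!("${}\r\n{}\r\n", v.len(), v),
        None => "$-1\r\n".to_owned(),
    }
}
```

For example, a value of "héllo" is sent as $6, because "é" takes two bytes in UTF-8.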

Handle SET command

async fn handle_set(stream: &mut TcpStream,
                    dict: &Arc<Mutex<HashMap<String, String>>>) {
    let key = get_next_string(stream).await;
    let val = get_next_string(stream).await;
    let mut map = dict.lock().await;
    map.insert(key, val);
    stream.write_all(b"+OK\r\n").await.unwrap();
}

For the set command, we read the key and then the value, and insert them into our dictionary. If the key already exists, its value is simply overwritten. We then reply with the simple string +OK, which redis-cli displays as OK.
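The overwrite behaviour comes directly from HashMap::insert, which replaces an existing value and returns the old one. A minimal sketch with a hypothetical set_reply that mirrors handle_set without the socket:

```rust
use std::collections::HashMap;

/// Hypothetical pure version of handle_set: store the pair and return the
/// reply, which is +OK whether or not the key already existed.
fn set_reply(map: &mut HashMap<String, String>, key: String, val: String) -> &'static str {
    // insert returns Some(old_value) when the key already existed; SET
    // discards it, so the previous value is simply overwritten.
    let _old = map.insert(key, val);
    "+OK\r\n"
}
```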

Put it together

Full code:

use std::collections::HashMap;
use std::error::Error;

use tokio::net::{TcpListener, TcpStream};
use tokio::prelude::*;
use tokio::sync::Mutex;
use std::sync::Arc;


#[tokio::main]
async fn main() -> Result<(), Box<dyn Error>> {
    let addr = "127.0.0.1:7777";
    let mut listener = TcpListener::bind(addr).await?;

    println!("Listen on {}", addr);

    let dict = Arc::new(Mutex::new(HashMap::new()));

    loop {
        let (mut sock, _) = listener.accept().await?;
        let dict = dict.clone();
        tokio::spawn(async move {
            // get arg array length like *2, *3
            let arg_len = get_next_len(&mut sock).await;
            let cmd = get_next_string(&mut sock).await;
            let cmd = cmd.to_lowercase();
            if cmd == "get" {
                if arg_len != 2 {
                    handle_syntax_err(&mut sock).await;
                } else {
                    handle_get(&mut sock, &dict).await;
                }
            } else if cmd == "set" {
                if arg_len != 3 {
                    handle_syntax_err(&mut sock).await;
                } else {
                    handle_set(&mut sock, &dict).await;
                }
            } else {
                handle_unknown(&mut sock).await;
            }
        });
    }
}


async fn read_till_crlf(stream: &mut TcpStream, skip: u8) -> Vec<u8> {
    let mut ret: Vec<u8> = vec![];
    let mut skip_num = skip;
    loop {
        let mut buf = [0; 1];
        stream.read_exact(&mut buf).await.unwrap();
        // LF's ascii number is 10
        if skip_num == 0 && buf[0] == 10 {
            break;
        }
        if skip_num > 0 {
            skip_num -= 1;
        } else {
            ret.push(buf[0]);
        }
    }
    // pop the last CR
    ret.pop();
    ret
}


async fn read_nbytes(stream: &mut TcpStream, nbytes: usize) -> Vec<u8> {
    let mut ret: Vec<u8> = vec![0; nbytes];
    stream.read_exact(&mut ret).await.unwrap();
    ret
}


async fn get_next_len(stream: &mut TcpStream) -> usize {
    let vlen = read_till_crlf(stream, 1).await;
    let slen = String::from_utf8(vlen).unwrap();
    let len:usize = slen.parse().unwrap();
    len
}


async fn get_next_string(stream: &mut TcpStream) -> String {
    let len = get_next_len(stream).await;
    let vs = read_nbytes(stream, len).await;
    // consume the followed \r\n
    let _ = read_nbytes(stream, 2).await;
    // build string and return
    let s = String::from_utf8(vs).unwrap();
    s
}


async fn handle_get(stream: &mut TcpStream,
                    dict: &Arc<Mutex<HashMap<String, String>>>) {
    let key = get_next_string(stream).await;
    let map = dict.lock().await;
    let s = match map.get(key.as_str()) {
        Some(v) => {
            format!("${}\r\n{}\r\n", v.len(), v)
        },
        None => {
            "$-1\r\n".to_owned()
        },
    };
    stream.write_all(s.as_bytes()).await.unwrap();
}


async fn handle_set(stream: &mut TcpStream,
                    dict: &Arc<Mutex<HashMap<String, String>>>) {
    let key = get_next_string(stream).await;
    let val = get_next_string(stream).await;
    let mut map = dict.lock().await;
    map.insert(key, val);
    stream.write_all(b"+OK\r\n").await.unwrap();
}


async fn handle_unknown(stream: &mut TcpStream) {
    stream.write_all(b"-Unknown command\r\n").await.unwrap();
}

async fn handle_syntax_err(stream: &mut TcpStream) {
    stream.write_all(b"-ERR syntax error\r\n").await.unwrap();
}

Let's run it with cargo run:

$ cargo run
Listen on 127.0.0.1:7777


In another terminal window, use redis-cli to connect to it:

$ redis-cli -p 7777
127.0.0.1:7777> get hello
(nil)
127.0.0.1:7777> set hello world
OK
127.0.0.1:7777> get hello
"world"
127.0.0.1:7777> set hello world1 world2 world3
(error) ERR syntax error
127.0.0.1:7777> get hello
"world"
127.0.0.1:7777> set hello "world overwrite"
OK
127.0.0.1:7777> get hello
"world overwrite"
127.0.0.1:7777> command
(error) Unknown command
127.0.0.1:7777>

It works! Thanks to tokio, with a little over 100 lines of code we now have a multi-threaded, async memcache server. Even better, because we speak the Redis protocol, we can use existing client libraries. For example, with Python's redis package (pip install redis), we can get the value associated with "hello" that we just set:

>>> import redis
>>> r = redis.Redis(port=7777)
>>> r.get("hello")
b'world overwrite'
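In the full code, each spawned task reads exactly one command and then returns, which drops the socket and closes the connection; the redis-cli session above appears to rely on the client reconnecting before each new command. A common alternative is to keep handling commands until the client disconnects. A minimal sketch of such a loop, assuming std::io::BufRead and Write instead of tokio's async traits (serve, next_len, next_string and invalid are hypothetical names). It also reads every argument before replying, so the stream stays aligned after a syntax error:

```rust
use std::collections::HashMap;
use std::io::{self, BufRead, Read, Write};

/// Turn a parse failure (or a message) into an InvalidData I/O error.
fn invalid<E>(e: E) -> io::Error
where
    E: Into<Box<dyn std::error::Error + Send + Sync>>,
{
    io::Error::new(io::ErrorKind::InvalidData, e)
}

/// Read "<prefix><number>\r\n" (prefix not checked); Ok(None) on clean EOF.
fn next_len<R: BufRead>(reader: &mut R) -> io::Result<Option<usize>> {
    let mut line = Vec::new();
    if reader.read_until(b'\n', &mut line)? == 0 {
        return Ok(None); // the client closed the connection
    }
    if line.len() < 3 || !line.ends_with(b"\r\n") {
        return Err(invalid("malformed length line"));
    }
    let digits = std::str::from_utf8(&line[1..line.len() - 2]).map_err(invalid)?;
    digits.parse::<usize>().map(Some).map_err(invalid)
}

/// Read one bulk string; EOF in the middle of a command is an error.
fn next_string<R: BufRead>(reader: &mut R) -> io::Result<String> {
    let len = next_len(reader)?.ok_or_else(|| invalid("unexpected end of stream"))?;
    let mut buf = vec![0; len + 2]; // payload plus trailing \r\n
    reader.read_exact(&mut buf)?;
    buf.truncate(len);
    String::from_utf8(buf).map_err(invalid)
}

/// Handle commands until the client disconnects, writing one reply each.
fn serve<R: BufRead, W: Write>(
    reader: &mut R,
    writer: &mut W,
    map: &mut HashMap<String, String>,
) -> io::Result<()> {
    // Each iteration handles one "*N\r\n" array.
    while let Some(argc) = next_len(reader)? {
        let mut args = Vec::new();
        for _ in 0..argc {
            args.push(next_string(reader)?);
        }
        // Case-insensitive, like the to_lowercase() call in main.
        let cmd = args.first().map(|s| s.to_lowercase()).unwrap_or_default();
        let reply = match (cmd.as_str(), args.len()) {
            ("get", 2) => match map.get(&args[1]) {
                Some(v) => format!("${}\r\n{}\r\n", v.len(), v),
                None => "$-1\r\n".to_owned(),
            },
            ("set", 3) => {
                map.insert(args[1].clone(), args[2].clone());
                "+OK\r\n".to_owned()
            }
            // Known command with the wrong number of arguments.
            ("get", _) | ("set", _) => "-ERR syntax error\r\n".to_owned(),
            _ => "-Unknown command\r\n".to_owned(),
        };
        writer.write_all(reply.as_bytes())?;
    }
    Ok(())
}
```

With std networking, a single connection could be served by passing BufReader::new(stream.try_clone()?) as the reader and the TcpStream itself as the writer. The sketch puts no limits on argument count or length, which a real server would need.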
