本文最后更新于 2024-03-29,本文发布时间距今超过 90 天, 文章内容可能已经过时。最新内容请以官方内容为准

Tokio Async Example

This example demonstrates how to use the tokio crate to perform asynchronous operations.

Prerequisites

  • Rust
  • tokio crate
  • bytes crate
  • mini-redis crate

Installation

  1. Install Rust
  2. Add the tokio crate:
    • cargo add tokio
  3. Add the bytes crate:
    • cargo add bytes
  4. Add the mini-redis crate:
    • cargo add mini-redis

Code Snippet

For server:


// server.rs
use std::{
    collections::HashMap,
    sync::{Arc, Mutex},
};

use bytes::Bytes;
use mini_redis::{Connection, Frame};
use tokio::net::{TcpListener, TcpStream};

type Db = Arc<Mutex<HashMap<String, Bytes>>>;
// type ShardedDb = Arc<Vec<Mutex<HashMap<String, Bytes>>>>;

#[tokio::main]
async fn main() {
    let listener = TcpListener::bind("127.0.0.1:6379").await.unwrap();

    let db: Db = Arc::new(Mutex::new(HashMap::new()));

    loop {
        let (socket, _) = listener.accept().await.unwrap();

        let db = db.clone();

        // process(socket).await; // it's only for one socket at same time.
        tokio::spawn(async move {
            process(socket, db).await;
        });
    }
}

// fn get_sharded_db(num_shard: usize) -> ShardedDb {
//     let mut sharded_db = Vec::with_capacity(num_shard);
//     for _ in 0..num_shard {
//         sharded_db.push(Mutex::new(HashMap::new()));
//     }
//     Arc::new(sharded_db)
// }

async fn process(socket: TcpStream, db: Db) {
    use mini_redis::Command::{self, Get, Set};

    let mut connection = Connection::new(socket);

    while let Some(cmd) = connection.read_frame().await.unwrap() {
        let response = match Command::from_frame(cmd).unwrap() {
            Set(data) => {
                let mut db = db.lock().unwrap();
                db.insert(data.key().to_string(), data.value().clone());
                Frame::Simple("OK".to_string())
            }
            Get(data) => {
                let db = db.lock().unwrap();
                if let Some(value) = db.get(data.key()) {
                    Frame::Bulk(value.clone().into())
                } else {
                    Frame::Null
                }
            }
            cmd => panic!("Unimplemented! {:?}", cmd),
        };
        connection.write_frame(&response).await.unwrap();
    }
}

For client:


// client.rs
use bytes::Bytes;
use mini_redis::client;
use tokio::{sync::mpsc, sync::oneshot};

// provided by the requester and used bt the manager task to send
// the command response back to the requester
type Responder<T> = oneshot::Sender<mini_redis::Result<T>>;

#[derive(Debug)]
enum Command {
    GET {
        key: String,
        resp: Responder<Option<Bytes>>,
    },
    SET {
        key: String,
        value: Bytes,
        resp: Responder<()>,
    },
}

#[tokio::main]
async fn main() {
    // create a new channel with capacity of at most 32.
    println!("start of client");
    let (tx, mut rx) = mpsc::channel(32);

    let tx2 = tx.clone();

    let manager = tokio::spawn(async move {
        let mut client = client::connect("127.0.0.1:6379").await.unwrap();
        println!("start of manager");

        while let Some(cmd) = rx.recv().await {
            use Command::*;

            match cmd {
                GET { key, resp } => {
                    println!("Cmd is GET, key is {}", &key);
                    let res = client.get(&key).await;
                    let _ = resp.send(res);
                }
                SET { key, value, resp } => {
                    println!("Cmd is SET, key is {},value is {:?}", &key, &value);
                    let res = client.set(&key, value).await;
                    let _ = resp.send(res);
                }
            }
        }
    });

    // Spawn two tasks, one setting a value and other querying for key that was
    // set.
    let t1 = tokio::spawn(async move {
        // tx.send("Sending from first handle").await.unwrap();
        println!("Start of task 1");

        let (resp_tx, resp_rx) = oneshot::channel();
        let cmd = Command::GET {
            key: "foo".into(),
            resp: resp_tx,
        };

        // send the GET request
        tx.send(cmd).await.unwrap();
        println!("Task 1: Send done.");

        // Await the response
        let resp = resp_rx.await;
        println!("Task 1: GOT = {:?}", resp);
    });

    let t2 = tokio::spawn(async move {
        // tx2.send("Sending from scecond handle").await.unwrap();
        println!("Start of task 2");

        let (resp_tx, resp_rx) = oneshot::channel();

        let cmd = Command::SET {
            key: "foo".into(),
            value: "bar".into(),
            resp: resp_tx,
        };

        // send the SET request
        tx2.send(cmd).await.unwrap();
        println!("Task 2: Send done.");

        // Await the response
        let resp = resp_rx.await;
        println!("Task 2: GOT = {:?}", resp);
    });

    t1.await.unwrap();
    t2.await.unwrap();

    manager.await.unwrap();

}

How to Run

  1. Copy the code to a file named server.rs and client.rs respectively under src/bin/ folder.
  2. Open two separate terminals.
  3. Run the server: cargo run --bin server
    • The server will start listening on port 6379.
  4. Run the client: cargo run --bin client
    • The client will send two commands to the server: one to set a value and another to get the value.

Explanation

这段代码是一个简化版的 Redis 客户端和服务器的模拟实现,使用了 Rust 的 tokio 库来处理异步操作,mini_redis 库来模拟 Redis 命令和协议。以下是代码的逐步解释:

客户端 (client.rs)

  1. 导入必要的模块和类型,包括字节处理库 bytes,异步通信库 tokiompsconeshot,以及 mini_redis 用于模拟 Redis 命令。
  2. 定义 Responder<T> 类型,它是一个 oneshot::Sender,用于将命令的响应发送回请求者。
  3. 定义 Command 枚举,它包含两种类型的命令:GETSET。每种命令都包含一个键(key)和一个响应发送器(resp),用于发送结果。SET 命令还包括一个值(value)。
  4. main 函数中,创建一个容量为 32 的 mpsc 通道 (txrx)。
  5. 克隆 tx 通道,以便在不同的任务中使用。
  6. 启动一个名为 manager 的异步任务,它连接到 Redis 服务器(在本例中是本地服务器),并等待从 rx 通道接收命令。
  7. 对于接收到的每个命令,根据命令类型(GETSET)执行相应的操作,并通过响应发送器发送结果。
  8. 启动两个子任务(t1t2),一个用于设置键值对(SET 命令),另一个用于查询键(GET 命令)。
  9. 等待两个子任务和 manager 任务完成。

服务器 (server.rs)

  1. 导入必要的模块和类型,包括同步原语 ArcMutex,字节处理库 bytes,异步通信库 tokio,以及 mini_redis 用于处理 Redis 命令。
  2. 定义 DbShardedDb 类型,用于模拟数据库存储。
  3. main 函数中,创建一个 TcpListener 来监听本地端口 6379 上的连接。
  4. 创建一个 Db 实例,用于存储键值对。
  5. 循环接受连接,并为每个连接创建一个新的任务来处理它。
  6. process 函数中,创建一个新的 TcpStream 连接,并根据接收到的命令(SetGet)执行相应的操作。对于 Set 命令,它将键值对存储在数据库中。对于 Get 命令,它从数据库中检索值,并将其作为响应发送回客户端。
  7. 使用 mini_redis::Command 枚举来匹配接收到的命令,并根据命令类型生成响应。

客户端和服务器的交互

  1. 客户端启动并创建一个消息通道 (mpsc)。
  2. 客户端启动一个 manager 任务,它等待接收命令并将其发送到 Redis 服务器。
  3. 客户端启动两个子任务,一个用于发送 SET 命令,另一个用于发送 GET 命令。
  4. SET 子任务通过 mpsc 通道发送一个 SET 命令到 manager 任务。
  5. manager 任务接收到 SET 命令,并通过 Redis 客户端将其发送到 Redis 服务器。
  6. Redis 服务器接收到 SET 命令,将其添加到数据库中,并发送一个确认响应回 manager 任务。
  7. manager 任务接收到响应,并通过 oneshot 通道将其发送回 SET 子任务。
  8. GET 子任务通过 mpsc 通道发送一个 GET 命令到 manager 任务。
  9. manager 任务接收到 GET 命令,并通过 Redis 客户端将其发送到 Redis 服务器。
  10. Redis 服务器检索键的值,并通过 manager 任务将响应发送回客户端。
  11. manager 任务接收到响应,并通过 oneshot 通道将其发送回 GET 子任务。
  12. 客户端接收到响应,并打印结果。

这个模拟实现了基本的客户端 - 服务器通信模式,其中客户端发送命令到服务器,服务器处理命令并返回响应。
这个例子使用了异步编程模型来处理多个并发连接和命令。

详细解释

当客户端和服务器运行时,以下是代码的逐步执行和交互逻辑的详细解释:

  1. 服务器端 server.rs 开始运行:

    • 创建一个 TCP 监听器 TcpListener,绑定到本地地址 127.0.0.1:6379,准备接受客户端连接请求。
    • 创建一个数据库 Db,用于存储键值对。
  2. 服务器进入无限循环,等待客户端连接:

    • 当有新的客户端连接请求到达时,通过 listener.accept().await 接受连接,返回一个套接字 socket
    • 克隆数据库 Db,以便在每个任务中都有独立的数据库副本。
  3. 为每个客户端连接创建一个任务:

    • 使用 tokio::spawn 创建一个异步任务,将客户端的套接字 socket 和数据库 Db 传递给 process 函数进行处理。
  4. 在客户端 client.rs 中,开始运行:

    • 创建一个多生产者单消费者通道 mpsc::channel,用于客户端和管理任务之间的通信。
    • 克隆通道的发送端 tx,以便在后续的任务中使用。
  5. 创建管理任务 manager

    • 通过 tokio::spawn 创建一个异步任务,该任务会连接到服务器的 127.0.0.1:6379 地址,并在接收到客户端命令时进行处理。
  6. 任务 t1 开始运行:

    • 创建一个响应者(Responder),用于接收来自管理任务的命令响应。
    • 创建一个 GET 命令,其中包含键 "foo" 和响应者 resp
    • 通过通道的发送端 tx,将 GET 命令发送给管理任务。
  7. 管理任务接收到 GET 命令:

    • 通过 rx.recv().await 从通道的接收端接收到 GET 命令。
    • 根据命令类型为 GET,执行相应的操作:
      • 连接到服务器的 127.0.0.1:6379 地址。
      • 发送 GET 命令到服务器,请求获取键 "foo" 的值。
      • 将获取到的值作为响应发送给客户端的响应者 resp
  8. 任务 t1 继续运行:

    • 等待从响应者 resp_rx 接收到响应。
    • 打印接收到的响应。
  9. 任务 t2 开始运行:

    • 创建一个响应者(Responder),用于接收来自管理任务的命令响应。
    • 创建一个 SET 命令,其中包含键 "foo"、值 "bar" 和响应者 resp
    • 通过通道的发送端 tx2,将 SET 命令发送给管理任务。
  10. 管理任务接收到 SET 命令:

    • 通过 rx.recv().await 从通道的接收端接收到 SET 命令。
    • 根据命令类型为 SET,执行相应的操作:
      • 连接到服务器的 127.0.0.1:6379 地址。
      • 发送 SET 命令到服务器,将键 "foo" 的值设置为 "bar"
      • 将操作结果作为响应发送给客户端的响应者 resp
  11. 任务 t2 继续运行:

    • 等待从响应者 resp_rx 接收到响应。
    • 打印接收到的响应。
  12. 管理任务继续运行:

    • 通过 rx.recv().await 从通道的接收端接收到 None 值,表示通道已关闭,不再有新的命令到达。
  13. 服务器端的 process 函数处理客户端连接:

    • 通过 connection.read_frame().await 从客户端连接中读取命令。
    • 根据命令类型执行相应的操作:
      • 如果是 SET 命令,将键值对存储到数据库 Db 中,并回复客户端一个 “OK” 的响应。
      • 如果是 GET 命令,从数据库 Db 中获取相应的值,并回复客户端对应的值。
      • 如果收到的命令是其他未实现的命令,则抛出一个错误。
    • 通过 connection.write_frame(&response).await 将响应发送给客户端。

Improvements

  1. 在 Server 端,每次 loop 时,都会 clone 一个 Db 实例,感觉有点多余,可以优化一下。
    • 优化 Server 端的代码,减少 clone 操作。
      • 使用 Arc<Mutex< ShardedDb >> 对数据库进行分片
      • 然后固定数据库分片数量,再通过 hash 函数,将 key 映射到分片上

Arc(原子引用计数)

An explanation of the relationship between Arc and clone:
Arc stands for “atomic reference counting”. It is a type in Rust’s standard library that provides thread-safe reference counting, allowing multiple ownership of a value across threads. It is often used to share data between multiple tasks or threads in a concurrent program.

clone() is a method available on many types in Rust, including Arc. It creates a new instance of the same type with the same data, effectively producing a clone of the original value. The clone() method for Arc creates a new Arc instance that shares ownership of the same underlying data as the original Arc. It increases the reference count of the shared data, ensuring that it remains alive as long as there is at least one active reference to it.

In the context of Arc, calling clone() on an Arc instance does not create a full deep copy of the underlying data. Instead, it increments the reference count and returns a new Arc that points to the same shared data. This means that the original Arc and the cloned Arc refer to the same data in memory, and any changes made to the data will be visible through both Arc instances.

Here are some references to the official Rust documentation for Arc and clone():


// server.rs

use std::{
    collections::HashMap,
    hash::{DefaultHasher, Hash, Hasher},
    sync::{Arc, Mutex},
};

use bytes::Bytes;
use mini_redis::{Connection, Frame};
use tokio::net::{TcpListener, TcpStream};

type Db = Arc<Mutex<HashMap<String, Bytes>>>;
type ShardedDb = Arc<Vec<Mutex<HashMap<String, Bytes>>>>;
const NUM_SHARDS: usize = 16;

#[tokio::main]
async fn main() {
    let listener = TcpListener::bind("127.0.0.1:6379").await.unwrap();

    // let db: Db = Arc::new(Mutex::new(HashMap::new()));
    let db: ShardedDb = get_sharded_db(NUM_SHARDS);

    loop {
        let (socket, _) = listener.accept().await.unwrap();

        // ##############################################################################
        // ## If not clone it, the ownership cannot be shared. Compiler will complain. ##
        // ##############################################################################
        // Clone the Arc
        // The clone() method for Arc creates a new Arc instance that shares ownership of the same underlying data as the original Arc.
        // It increases the reference count of the shared data,
        // ensuring that it remains alive as long as there is at least one active reference to it.
        
        let db = db.clone();

        tokio::spawn(async move {
            process(socket, db).await;
        });
    }
}

/**
 * 创建并返回一个分片数据库的共享实例。
 * 
 * `num_shards` 参数指定数据库的分片数量。
 * 每个分片都是一个互斥锁保护的哈希映射,用于存储数据。
 * 
 * 返回值是一个指向分片数据库的 Arc (原子引用计数) 指针,
 * 允许多个线程安全地共享和访问分片数据库。
 * 
 * @param num_shards 分片的数量。
 * @return 返回一个 `ShardedDb`,它是分片数据库的共享实例。
 */
fn get_sharded_db(num_shards: usize) -> ShardedDb {
    // 初始化一个具有指定容量的空 Vec,用于存储分片。
    let mut sharded_db = Vec::with_capacity(num_shards);
    
    // 为每个分片创建一个互斥锁保护的哈希映射,并添加到 sharded_db 中。
    for _ in 0..num_shards {
        sharded_db.push(Mutex::new(HashMap::new()));
    }
    
    // 将 sharded_db 包装在 Arc 中,以便于跨线程共享。
    Arc::new(sharded_db)
}

/**
 * 对给定的字符串键进行哈希处理并返回哈希值。
 * 
 * @param key 指向要哈希的字符串的引用。
 * @return 返回一个 usize 类型的哈希值。
 */
fn hash_key(key: &str) -> usize {
    // 创建一个默认的哈希器实例
    let mut hasher = DefaultHasher::default();
    // 使用字符串 key 对哈希器进行哈希处理
    key.hash(&mut hasher);
    // 获取并转换哈希结果为 usize 类型
    hasher.finish() as usize
}

async fn process(socket: TcpStream, db: ShardedDb) {
    use mini_redis::Command::{self, Get, Set};

    let mut connection = Connection::new(socket);

    while let Some(cmd) = connection.read_frame().await.unwrap() {
        let response = match Command::from_frame(cmd).unwrap() {
            Set(data) => {
                // let mut db = db.lock().unwrap();
                let shard_index = hash_key(data.key()) % db.len();
                let mut db = db[shard_index].lock().unwrap();
                db.insert(data.key().to_string(), data.value().clone());
                println!(
                    "Process done: SET (key: {}, value: {:?})",
                    data.key(),
                    data.value()
                );
                Frame::Simple("OK".to_string())
            }
            Get(data) => {
                // let db = db.lock().unwrap();
                let shard_index = hash_key(data.key()) % db.len();
                let db = db[shard_index].lock().unwrap();
                if let Some(value) = db.get(data.key()) {
                    println!(
                        "Process done: GET (key: {},value: {:?})",
                        data.key(),
                        value.clone()
                    );
                    Frame::Bulk(value.clone().into())
                } else {
                    println!("Process done: GET (key: {},value: NUll)", data.key());
                    Frame::Null
                }
            }
            cmd => panic!("Unimplemented! {:?}", cmd),
        };
        connection.write_frame(&response).await.unwrap();
    }
}

References

  1. tokio
  2. Rust 异步编程
  3. mpsc