diff --git a/proj-2/balancebeam/implement_notes.md b/proj-2/balancebeam/implement_notes.md
index 0e13e03..d1e82e7 100644
--- a/proj-2/balancebeam/implement_notes.md
+++ b/proj-2/balancebeam/implement_notes.md
@@ -34,4 +34,13 @@
 Another odd task point: it asks you to convert all the I/O to async. Still not hard at all; just add `async` and `await` wherever the IDE suggests.
 
-## Extra tasks?
\ No newline at end of file
+### Milestone 3
+
+Implement failover: if an upstream goes down, pick a different one for the connection (we start out with a pool of available upstreams), and mark the dead upstream so that later connections stop choosing it; only when every upstream is down do we return an error. The place to implement this is `main.rs::connect_to_upstream`, since this logic only runs while a connection is being established.
+
+I implemented the simplest version possible: add an `Arc<Mutex<Vec<usize>>>` to `ProxyState` that holds the indices of the upstreams that are still available.
+Selection picks from this lock-protected `Vec`, and the chosen index is used to look up the corresponding upstream's address. If the available list is empty to begin with, return an error right away; otherwise try to connect. If the connection fails, remove that index from the available list. The one thing to watch out for: the lock can be released before `TcpStream::connect` and re-acquired only when the list needs to be modified, because connecting is slow.
+
+An `RwLock` feels roughly equivalent here, though it might give better concurrency in some cases. As for the channel-based approach the handout mentions, it seemed like a hassle, so I didn't write it.
+
+## Extra tasks?
diff --git a/proj-2/balancebeam/src/main.rs b/proj-2/balancebeam/src/main.rs
index b1c71a8..7c12b0d 100644
--- a/proj-2/balancebeam/src/main.rs
+++ b/proj-2/balancebeam/src/main.rs
@@ -5,6 +5,7 @@ use clap::Parser;
 use rand::{Rng, SeedableRng};
 // use std::net::{TcpListener, TcpStream};
 use tokio::net::{TcpListener, TcpStream};
+use std::sync::Arc;
 
 /// Contains information parsed from the command-line invocation of balancebeam. The Clap macros
 /// provide a fancy way to automatically construct a command-line argument parser.
@@ -48,6 +49,8 @@ struct ProxyState {
     max_requests_per_minute: usize,
     /// Addresses of servers that we are proxying to
    upstream_addresses: Vec<String>,
+    /// Indices of the upstreams that are still alive
+    upstream_status: Arc<tokio::sync::Mutex<Vec<usize>>>,
 }
 
 #[tokio::main]
@@ -77,14 +80,20 @@ async fn main() -> std::io::Result<()> {
     };
     log::info!("Listening for requests on {}", options.bind);
 
+    let mut upstream_status = Vec::with_capacity(options.upstream.len());
+    for idx in 0..options.upstream.len() {
+        upstream_status.push(idx);
+    }
+
     // Handle incoming connections
     let state = ProxyState {
         upstream_addresses: options.upstream,
         active_health_check_interval: options.active_health_check_interval,
         active_health_check_path: options.active_health_check_path,
         max_requests_per_minute: options.max_requests_per_minute,
+        upstream_status: Arc::new(tokio::sync::Mutex::new(upstream_status)),
     };
-    let astate = std::sync::Arc::new(state);
+    let astate = Arc::new(state);
     // let thread_pool = threadpool::ThreadPool::new(4);
     // for stream in listener.incoming() {
     //     if let Ok(stream) = stream {
@@ -111,13 +120,32 @@
 async fn connect_to_upstream(state: &ProxyState) -> Result<TcpStream, std::io::Error> {
     let mut rng = rand::rngs::StdRng::from_entropy();
-    let upstream_idx = rng.gen_range(0..state.upstream_addresses.len());
-    let upstream_ip = &state.upstream_addresses[upstream_idx];
-    TcpStream::connect(upstream_ip).await.or_else(|err| {
-        log::error!("Failed to connect to upstream {}: {}", upstream_ip, err);
-        Err(err)
-    })
-    // TODO: implement failover (milestone 3)
+    loop {
+        let upstream_state = state.upstream_status.lock().await;
+        if upstream_state.is_empty() {
+            return Err(std::io::Error::new(std::io::ErrorKind::Other, "No living upstream"));
+        }
+        let upstream_idx_idx = rng.gen_range(0..upstream_state.len());
+        let upstream_idx = upstream_state[upstream_idx_idx];
+        let upstream_ip = &state.upstream_addresses[upstream_idx];
+        drop(upstream_state); // release the lock before connecting, since TCP connect is slow
+        match TcpStream::connect(upstream_ip).await {
+            Ok(stream) => return Ok(stream),
+            Err(err) => {
+                log::error!("Failed to connect to upstream {}: {}", upstream_ip, err);
+                let mut upstream_state = state.upstream_status.lock().await;
+                // The list may have changed while the lock was released, so
+                // remove the failed upstream by value rather than by position.
+                if let Some(pos) = upstream_state.iter().position(|&i| i == upstream_idx) {
+                    upstream_state.remove(pos);
+                }
+                if upstream_state.is_empty() {
+                    return Err(err);
+                }
+            }
+        }
+    }
 }
 
 async fn send_response(client_conn: &mut TcpStream, response: &http::Response<Vec<u8>>) {
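
For comparison with the `RwLock` remark in the notes, here is a minimal sketch of what that variant could look like. This is an assumption, not code from the project: it supposes `upstream_status` is switched to `Arc<tokio::sync::RwLock<Vec<usize>>>`, so that many tasks can pick an upstream concurrently under a read lock, while the write lock is only taken on the rare failure path. The `ProxyState` here is trimmed to the two fields the function touches.

```rust
// Sketch only (assumed variant, not the project's actual code): upstream_status
// is changed from a Mutex to an RwLock; everything else mirrors the diff above.
use std::sync::Arc;

use rand::{Rng, SeedableRng};
use tokio::net::TcpStream;
use tokio::sync::RwLock;

struct ProxyState {
    upstream_addresses: Vec<String>,
    upstream_status: Arc<RwLock<Vec<usize>>>,
    // ...remaining fields of the real ProxyState elided
}

async fn connect_to_upstream(state: &ProxyState) -> Result<TcpStream, std::io::Error> {
    let mut rng = rand::rngs::StdRng::from_entropy();
    loop {
        // Read lock: selection never modifies the list, so concurrent
        // connection attempts do not serialize on each other here.
        let alive = state.upstream_status.read().await;
        if alive.is_empty() {
            return Err(std::io::Error::new(
                std::io::ErrorKind::Other,
                "No living upstream",
            ));
        }
        let upstream_idx = alive[rng.gen_range(0..alive.len())];
        let upstream_ip = &state.upstream_addresses[upstream_idx];
        drop(alive); // release before the slow connect, as in the Mutex version

        match TcpStream::connect(upstream_ip).await {
            Ok(stream) => return Ok(stream),
            Err(err) => {
                log::error!("Failed to connect to upstream {}: {}", upstream_ip, err);
                // Write lock: taken only on failure. Remove by value, since the
                // list may have changed while no lock was held.
                let mut alive = state.upstream_status.write().await;
                if let Some(pos) = alive.iter().position(|&i| i == upstream_idx) {
                    alive.remove(pos);
                }
                if alive.is_empty() {
                    return Err(err);
                }
            }
        }
    }
}
```

The trade-off matches the note: under a `Mutex`, every call serializes even on the read-only selection step; under an `RwLock`, only the removal path is exclusive, which should only matter when many connections are being established at once.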