diff --git a/proj-2/balancebeam/implement_notes.md b/proj-2/balancebeam/implement_notes.md index d1e82e7..14b1c53 100644 --- a/proj-2/balancebeam/implement_notes.md +++ b/proj-2/balancebeam/implement_notes.md @@ -43,4 +43,12 @@ 感觉 RwLock 和这个东西差不多,但是可能在某些情况下的并发性能会更好一些。至于它说的用 channel,感觉挺麻烦的,所以没写。 +### Milestone 4 + +实现周期性的主动连接测试(active health check):就是我们的代理程序隔一段时间向每个 upstream 的某一个特定地址(参数里面的 `active_health_check_path`)发一个 HTTP 请求,如果返回 200 就说明没问题;如果之前有问题现在没问题,就可以继续用;出现问题则需要标记为下线。 + +由于上个 Milestone 里面用到的极简数据结构,导致实现起来有点生草。这里我是用 `tokio::spawn` 来生成了一个 task,根据文档的说明,这是一个 green thread,鬼知道是什么东西。然后一个大循环,先 sleep 一波(不能放到后面,因为他的测试比较智障,放到后面会出错);遍历所有的 upstream(不管有没有被 disable),建立 `tcpstream` ,构造空请求,然后接受响应,最后判断状态码是不是200。以上任何一步出错,都必须从可用列表中删除(注意有可能已经删除了);如果是 200,则需要将它重新加入到可用列表中(如果之前被删除了的话)。 + +一个坑:不要使用前面实现的 `connect_to_upstream`,因为它无法连接到已经被移出可用列表的 upstream(不如说它根本就不让自己选 upstream),但是我们这里需要测试每个 upstream。 + ## 附加任务? diff --git a/proj-2/balancebeam/src/main.rs b/proj-2/balancebeam/src/main.rs index 7c12b0d..1a4ee95 100644 --- a/proj-2/balancebeam/src/main.rs +++ b/proj-2/balancebeam/src/main.rs @@ -4,8 +4,8 @@ mod response; use clap::Parser; use rand::{Rng, SeedableRng}; // use std::net::{TcpListener, TcpStream}; -use tokio::net::{TcpListener, TcpStream}; use std::sync::Arc; +use tokio::{net::{TcpListener, TcpStream}, io::AsyncWriteExt}; /// Contains information parsed from the command-line invocation of balancebeam. The Clap macros /// provide a fancy way to automatically construct a command-line argument parser. @@ -53,6 +53,72 @@ struct ProxyState { upstream_status: Arc>>, } +async fn disable_upstream(state: &ProxyState, index: usize) { + let mut upstream_status = state.upstream_status.lock().await; + log::debug!("Should disable {}", index); + if let Some(index) = upstream_status.iter().position(|&r| r == index) { + log::debug!("Disable [{}]{}", index, upstream_status[index]); + upstream_status.remove(index); + } + drop(upstream_status); +} + +fn active_health_check(state: Arc) { + let _ = tokio::spawn( async move { + let active_health_check_interval = state.active_health_check_interval; + let active_health_check_path = &state.active_health_check_path; + loop { + std::thread::sleep(std::time::Duration::from_secs( + active_health_check_interval as u64, + )); + let mut count = 0; + for upstream in &state.upstream_addresses { + count += 1; + let mut stream = match TcpStream::connect(upstream).await { + Ok(stream) => stream, + Err(_) => { + log::info!("Health check connect to {} failed", upstream); + disable_upstream(&state, count - 1).await; + continue; + } + }; + let request = http::Request::builder() + .method(http::Method::GET) + .uri(active_health_check_path) + .header("Host", upstream) + .body(Vec::new()) + .unwrap(); + if let Err(error) = request::write_to_stream(&request, &mut stream).await { + log::info!("[HealthCheck] Failed to send response to client: {}", error); + disable_upstream(&state, count - 1).await; + continue; + } + let response = match response::read_from_stream(&mut stream, request.method()).await + { + Ok(response) => response, + Err(error) => { + log::info!("Error reading response from server: {:?}", error); + disable_upstream(&state, count - 1).await; + continue; + } + }; + if response.status() != 200 { + log::info!("[HealthCheck] Get bad HTTP response {} from {}", response.status(), upstream); + disable_upstream(&state, count - 1).await; + } else { + log::debug!("[HealthCheck] Get good HTTP response {} from {}", response.status(), upstream); + let mut upstream_status = state.upstream_status.lock().await; + if let None = upstream_status.iter().position(|&r| r == count - 1) { + upstream_status.push(count - 1); + } + drop(upstream_status); + } + } + } + } + ); +} + #[tokio::main] async fn main() -> std::io::Result<()> { // Initialize the logging library. You can print log messages using the `log` macros: @@ -94,6 +160,11 @@ async fn main() -> std::io::Result<()> { upstream_status: Arc::new(tokio::sync::Mutex::new(upstream_status)), }; let astate = Arc::new(state); + // let _ = tokio::spawn(async move { + // active_health_check(&thd_state).await; + // }); + active_health_check(astate.clone()); + // let thread_pool = threadpool::ThreadPool::new(4); // for stream in listener.incoming() { // if let Ok(stream) = stream { @@ -114,7 +185,10 @@ async fn main() -> std::io::Result<()> { // } loop { let (socket, _) = listener.accept().await?; - handle_connection(socket, &astate).await; + let astate = astate.clone(); + tokio::spawn(async move { + handle_connection(socket, &astate).await; + }); } } @@ -123,7 +197,10 @@ async fn connect_to_upstream(state: &ProxyState) -> Result Result { let mut upstream_state = state.upstream_status.lock().await; upstream_state.remove(upstream_idx_idx); - if upstream_state.len() == 0{ + if upstream_state.len() == 0 { return stream; } drop(upstream_state); @@ -150,7 +227,11 @@ async fn connect_to_upstream(state: &ProxyState) -> Result>) { let client_ip = client_conn.peer_addr().unwrap().ip().to_string(); - log::info!("{} <- {}", client_ip, response::format_response_line(&response)); + log::info!( + "{} <- {}", + client_ip, + response::format_response_line(&response) + ); if let Err(error) = response::write_to_stream(&response, client_conn).await { log::warn!("Failed to send response to client: {}", error); return; @@ -216,7 +297,11 @@ async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) { // Forward the request to the server if let Err(error) = request::write_to_stream(&request, &mut upstream_conn).await { - log::error!("Failed to send request to upstream {}: {}", upstream_ip, error); + log::error!( + "Failed to send request to upstream {}: {}", + upstream_ip, + error + ); let response = response::make_http_error(http::StatusCode::BAD_GATEWAY); send_response(&mut client_conn, &response).await; return; @@ -224,7 +309,8 @@ async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) { log::debug!("Forwarded request to server"); // Read the server's response - let response = match response::read_from_stream(&mut upstream_conn, request.method()).await { + let response = match response::read_from_stream(&mut upstream_conn, request.method()).await + { Ok(response) => response, Err(error) => { log::error!("Error reading response from server: {:?}", error);