diff --git a/proj-2/balancebeam/implement_notes.md b/proj-2/balancebeam/implement_notes.md index cb67b43..0e13e03 100644 --- a/proj-2/balancebeam/implement_notes.md +++ b/proj-2/balancebeam/implement_notes.md @@ -30,5 +30,8 @@ `ThreadPool` 和 `std::thread` 没啥太大区别,都封装的很好了,直接看一下 `doc.rs` 里面的例子就行了。 +### Milestone 2 + +又是一个奇怪的 task point,要求把所有的 IO 换成异步的。但是依然没有任何的难度,根据 IDE 提示加 `async` 和 `await` 就可以了。 ## 附加任务? \ No newline at end of file diff --git a/proj-2/balancebeam/src/main.rs b/proj-2/balancebeam/src/main.rs index 49fc28b..b1c71a8 100644 --- a/proj-2/balancebeam/src/main.rs +++ b/proj-2/balancebeam/src/main.rs @@ -3,7 +3,8 @@ mod response; use clap::Parser; use rand::{Rng, SeedableRng}; -use std::net::{TcpListener, TcpStream}; +// use std::net::{TcpListener, TcpStream}; +use tokio::net::{TcpListener, TcpStream}; /// Contains information parsed from the command-line invocation of balancebeam. The Clap macros /// provide a fancy way to automatically construct a command-line argument parser. @@ -49,7 +50,8 @@ struct ProxyState { upstream_addresses: Vec, } -fn main() { +#[tokio::main] +async fn main() -> std::io::Result<()> { // Initialize the logging library. You can print log messages using the `log` macros: // https://docs.rs/log/0.4.8/log/ You are welcome to continue using print! statements; this // just looks a little prettier. @@ -66,7 +68,7 @@ fn main() { } // Start listening for connections - let listener = match TcpListener::bind(&options.bind) { + let listener = match TcpListener::bind(&options.bind).await { Ok(listener) => listener, Err(err) => { log::error!("Could not bind to {}: {}", options.bind, err); @@ -84,55 +86,59 @@ fn main() { }; let astate = std::sync::Arc::new(state); // let thread_pool = threadpool::ThreadPool::new(4); - for stream in listener.incoming() { - if let Ok(stream) = stream { - // Handle the connection! - // ------ std thread ------ - let thd_state = astate.clone(); - std::thread::spawn(move || { - handle_connection(stream, &thd_state); - }); - // ------ thread pool ------ - // let thd_state = astate.clone(); - // thread_pool.execute(move || { - // handle_connection(stream, &thd_state); - // }); - // ------ single thread ------ - // handle_connection(stream, &state); - } + // for stream in listener.incoming() { + // if let Ok(stream) = stream { + // // Handle the connection! + // // ------ std thread ------ + // let thd_state = astate.clone(); + // std::thread::spawn(move || { + // handle_connection(stream, &thd_state); + // }); + // // ------ thread pool ------ + // // let thd_state = astate.clone(); + // // thread_pool.execute(move || { + // // handle_connection(stream, &thd_state); + // // }); + // // ------ single thread ------ + // // handle_connection(stream, &state); + // } + // } + loop { + let (socket, _) = listener.accept().await?; + handle_connection(socket, &astate).await; } } -fn connect_to_upstream(state: &ProxyState) -> Result { +async fn connect_to_upstream(state: &ProxyState) -> Result { let mut rng = rand::rngs::StdRng::from_entropy(); let upstream_idx = rng.gen_range(0..state.upstream_addresses.len()); let upstream_ip = &state.upstream_addresses[upstream_idx]; - TcpStream::connect(upstream_ip).or_else(|err| { + TcpStream::connect(upstream_ip).await.or_else(|err| { log::error!("Failed to connect to upstream {}: {}", upstream_ip, err); Err(err) }) // TODO: implement failover (milestone 3) } -fn send_response(client_conn: &mut TcpStream, response: &http::Response>) { +async fn send_response(client_conn: &mut TcpStream, response: &http::Response>) { let client_ip = client_conn.peer_addr().unwrap().ip().to_string(); log::info!("{} <- {}", client_ip, response::format_response_line(&response)); - if let Err(error) = response::write_to_stream(&response, client_conn) { + if let Err(error) = response::write_to_stream(&response, client_conn).await { log::warn!("Failed to send response to client: {}", error); return; } } -fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) { +async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) { let client_ip = client_conn.peer_addr().unwrap().ip().to_string(); log::info!("Connection received from {}", client_ip); // Open a connection to a random destination server - let mut upstream_conn = match connect_to_upstream(state) { + let mut upstream_conn = match connect_to_upstream(state).await { Ok(stream) => stream, Err(_error) => { let response = response::make_http_error(http::StatusCode::BAD_GATEWAY); - send_response(&mut client_conn, &response); + send_response(&mut client_conn, &response).await; return; } }; @@ -142,7 +148,7 @@ fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) { // client hangs up or we get an error. loop { // Read a request from the client - let mut request = match request::read_from_stream(&mut client_conn) { + let mut request = match request::read_from_stream(&mut client_conn).await { Ok(request) => request, // Handle case where client closed connection and is no longer sending requests Err(request::Error::IncompleteRequest(0)) => { @@ -164,7 +170,7 @@ fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) { request::Error::RequestBodyTooLarge => http::StatusCode::PAYLOAD_TOO_LARGE, request::Error::ConnectionError(_) => http::StatusCode::SERVICE_UNAVAILABLE, }); - send_response(&mut client_conn, &response); + send_response(&mut client_conn, &response).await; continue; } }; @@ -181,26 +187,26 @@ fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) { request::extend_header_value(&mut request, "x-forwarded-for", &client_ip); // Forward the request to the server - if let Err(error) = request::write_to_stream(&request, &mut upstream_conn) { + if let Err(error) = request::write_to_stream(&request, &mut upstream_conn).await { log::error!("Failed to send request to upstream {}: {}", upstream_ip, error); let response = response::make_http_error(http::StatusCode::BAD_GATEWAY); - send_response(&mut client_conn, &response); + send_response(&mut client_conn, &response).await; return; } log::debug!("Forwarded request to server"); // Read the server's response - let response = match response::read_from_stream(&mut upstream_conn, request.method()) { + let response = match response::read_from_stream(&mut upstream_conn, request.method()).await { Ok(response) => response, Err(error) => { log::error!("Error reading response from server: {:?}", error); let response = response::make_http_error(http::StatusCode::BAD_GATEWAY); - send_response(&mut client_conn, &response); + send_response(&mut client_conn, &response).await; return; } }; // Forward the response to the client - send_response(&mut client_conn, &response); + send_response(&mut client_conn, &response).await; log::debug!("Forwarded response to client"); } } diff --git a/proj-2/balancebeam/src/request.rs b/proj-2/balancebeam/src/request.rs index d07897f..65024a5 100644 --- a/proj-2/balancebeam/src/request.rs +++ b/proj-2/balancebeam/src/request.rs @@ -1,6 +1,8 @@ use std::cmp::min; -use std::io::{Read, Write}; -use std::net::TcpStream; +// use std::io::{Read, Write}; +// use std::net::TcpStream; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; +use tokio::net::TcpStream; const MAX_HEADERS_SIZE: usize = 8000; const MAX_BODY_SIZE: usize = 10000000; @@ -101,7 +103,7 @@ fn parse_request(buffer: &[u8]) -> Result>, usize) /// Returns Ok(http::Request) if a valid request is received, or Error if not. /// /// You will need to modify this function in Milestone 2. -fn read_headers(stream: &mut TcpStream) -> Result>, Error> { +async fn read_headers(stream: &mut TcpStream) -> Result>, Error> { // Try reading the headers from the request. We may not receive all the headers in one shot // (e.g. we might receive the first few bytes of a request, and then the rest follows later). // Try parsing repeatedly until we read a valid HTTP request @@ -110,7 +112,7 @@ fn read_headers(stream: &mut TcpStream) -> Result>, Error> loop { // Read bytes from the connection into the buffer, starting at position bytes_read let new_bytes = stream - .read(&mut request_buffer[bytes_read..]) + .read(&mut request_buffer[bytes_read..]).await .or_else(|err| Err(Error::ConnectionError(err)))?; if new_bytes == 0 { // We didn't manage to read a complete request @@ -137,7 +139,7 @@ fn read_headers(stream: &mut TcpStream) -> Result>, Error> /// returns Ok(()) if successful, or Err(Error) if Content-Length bytes couldn't be read. /// /// You will need to modify this function in Milestone 2. -fn read_body( +async fn read_body( stream: &mut TcpStream, request: &mut http::Request>, content_length: usize, @@ -147,7 +149,7 @@ fn read_body( // Read up to 512 bytes at a time. (If the client only sent a small body, then only allocate // space to read that body.) let mut buffer = vec![0_u8; min(512, content_length)]; - let bytes_read = stream.read(&mut buffer).or_else(|err| Err(Error::ConnectionError(err)))?; + let bytes_read = stream.read(&mut buffer).await.or_else(|err| Err(Error::ConnectionError(err)))?; // Make sure the client is still sending us bytes if bytes_read == 0 { @@ -178,15 +180,15 @@ fn read_body( /// closes the connection prematurely or sends an invalid request. /// /// You will need to modify this function in Milestone 2. -pub fn read_from_stream(stream: &mut TcpStream) -> Result>, Error> { +pub async fn read_from_stream(stream: &mut TcpStream) -> Result>, Error> { // Read headers - let mut request = read_headers(stream)?; + let mut request = read_headers(stream).await?; // Read body if the client supplied the Content-Length header (which it does for POST requests) if let Some(content_length) = get_content_length(&request)? { if content_length > MAX_BODY_SIZE { return Err(Error::RequestBodyTooLarge); } else { - read_body(stream, &mut request, content_length)?; + read_body(stream, &mut request, content_length).await?; } } Ok(request) @@ -195,20 +197,20 @@ pub fn read_from_stream(stream: &mut TcpStream) -> Result> /// This function serializes a request to bytes and writes those bytes to the provided stream. /// /// You will need to modify this function in Milestone 2. -pub fn write_to_stream( +pub async fn write_to_stream( request: &http::Request>, stream: &mut TcpStream, ) -> Result<(), std::io::Error> { - stream.write(&format_request_line(request).into_bytes())?; - stream.write(&['\r' as u8, '\n' as u8])?; // \r\n + stream.write(&format_request_line(request).into_bytes()).await?; + stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n for (header_name, header_value) in request.headers() { - stream.write(&format!("{}: ", header_name).as_bytes())?; - stream.write(header_value.as_bytes())?; - stream.write(&['\r' as u8, '\n' as u8])?; // \r\n + stream.write(&format!("{}: ", header_name).as_bytes()).await?; + stream.write(header_value.as_bytes()).await?; + stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n } - stream.write(&['\r' as u8, '\n' as u8])?; + stream.write(&['\r' as u8, '\n' as u8]).await?; if request.body().len() > 0 { - stream.write(request.body())?; + stream.write(request.body()).await?; } Ok(()) } diff --git a/proj-2/balancebeam/src/response.rs b/proj-2/balancebeam/src/response.rs index a6dedca..b74f386 100644 --- a/proj-2/balancebeam/src/response.rs +++ b/proj-2/balancebeam/src/response.rs @@ -1,5 +1,7 @@ -use std::io::{Read, Write}; -use std::net::TcpStream; +// use std::io::{Read, Write}; +// use std::net::TcpStream; +use tokio::net::TcpStream; +use tokio::io::{AsyncReadExt, AsyncWriteExt}; const MAX_HEADERS_SIZE: usize = 8000; const MAX_BODY_SIZE: usize = 10000000; @@ -80,7 +82,7 @@ fn parse_response(buffer: &[u8]) -> Result>, usiz /// Returns Ok(http::Response) if a valid response is received, or Error if not. /// /// You will need to modify this function in Milestone 2. -fn read_headers(stream: &mut TcpStream) -> Result>, Error> { +async fn read_headers(stream: &mut TcpStream) -> Result>, Error> { // Try reading the headers from the response. We may not receive all the headers in one shot // (e.g. we might receive the first few bytes of a response, and then the rest follows later). // Try parsing repeatedly until we read a valid HTTP response @@ -89,7 +91,7 @@ fn read_headers(stream: &mut TcpStream) -> Result>, Error loop { // Read bytes from the connection into the buffer, starting at position bytes_read let new_bytes = stream - .read(&mut response_buffer[bytes_read..]) + .read(&mut response_buffer[bytes_read..]).await .or_else(|err| Err(Error::ConnectionError(err)))?; if new_bytes == 0 { // We didn't manage to read a complete response @@ -114,7 +116,7 @@ fn read_headers(stream: &mut TcpStream) -> Result>, Error /// present, it reads that many bytes; otherwise, it reads bytes until the connection is closed. /// /// You will need to modify this function in Milestone 2. -fn read_body(stream: &mut TcpStream, response: &mut http::Response>) -> Result<(), Error> { +async fn read_body(stream: &mut TcpStream, response: &mut http::Response>) -> Result<(), Error> { // The response may or may not supply a Content-Length header. If it provides the header, then // we want to read that number of bytes; if it does not, we want to keep reading bytes until // the connection is closed. @@ -123,7 +125,7 @@ fn read_body(stream: &mut TcpStream, response: &mut http::Response>) -> while content_length.is_none() || response.body().len() < content_length.unwrap() { let mut buffer = [0_u8; 512]; let bytes_read = stream - .read(&mut buffer) + .read(&mut buffer).await .or_else(|err| Err(Error::ConnectionError(err)))?; if bytes_read == 0 { // The server has hung up! @@ -158,11 +160,11 @@ fn read_body(stream: &mut TcpStream, response: &mut http::Response>) -> /// closes the connection prematurely or sends an invalid response. /// /// You will need to modify this function in Milestone 2. -pub fn read_from_stream( +pub async fn read_from_stream( stream: &mut TcpStream, request_method: &http::Method, ) -> Result>, Error> { - let mut response = read_headers(stream)?; + let mut response = read_headers(stream).await?; // A response may have a body as long as it is not responding to a HEAD request and as long as // the response status code is not 1xx, 204 (no content), or 304 (not modified). if !(request_method == http::Method::HEAD @@ -170,7 +172,7 @@ pub fn read_from_stream( || response.status() == http::StatusCode::NO_CONTENT || response.status() == http::StatusCode::NOT_MODIFIED) { - read_body(stream, &mut response)?; + read_body(stream, &mut response).await?; } Ok(response) } @@ -178,20 +180,20 @@ pub fn read_from_stream( /// This function serializes a response to bytes and writes those bytes to the provided stream. /// /// You will need to modify this function in Milestone 2. -pub fn write_to_stream( +pub async fn write_to_stream( response: &http::Response>, stream: &mut TcpStream, ) -> Result<(), std::io::Error> { - stream.write(&format_response_line(response).into_bytes())?; - stream.write(&['\r' as u8, '\n' as u8])?; // \r\n + stream.write(&format_response_line(response).into_bytes()).await?; + stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n for (header_name, header_value) in response.headers() { - stream.write(&format!("{}: ", header_name).as_bytes())?; - stream.write(header_value.as_bytes())?; - stream.write(&['\r' as u8, '\n' as u8])?; // \r\n + stream.write(&format!("{}: ", header_name).as_bytes()).await?; + stream.write(header_value.as_bytes()).await?; + stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n } - stream.write(&['\r' as u8, '\n' as u8])?; + stream.write(&['\r' as u8, '\n' as u8]).await?; if response.body().len() > 0 { - stream.write(response.body())?; + stream.write(response.body()).await?; } Ok(()) }