Milestone 2

This commit is contained in:
ridethepig 2023-03-10 14:22:43 +08:00
parent e374f0421f
commit 374f2690d8
4 changed files with 80 additions and 67 deletions

View File

@ -30,5 +30,8 @@
`ThreadPool``std::thread` 没啥太大区别,都封装的很好了,直接看一下 `doc.rs` 里面的例子就行了。
### Milestone 2
又是一个奇怪的 task point要求把所有的 IO 换成异步的。但是依然没有任何的难度,根据 IDE 提示加 `async``await` 就可以了。
## 附加任务?

View File

@ -3,7 +3,8 @@ mod response;
use clap::Parser;
use rand::{Rng, SeedableRng};
use std::net::{TcpListener, TcpStream};
// use std::net::{TcpListener, TcpStream};
use tokio::net::{TcpListener, TcpStream};
/// Contains information parsed from the command-line invocation of balancebeam. The Clap macros
/// provide a fancy way to automatically construct a command-line argument parser.
@ -49,7 +50,8 @@ struct ProxyState {
upstream_addresses: Vec<String>,
}
fn main() {
#[tokio::main]
async fn main() -> std::io::Result<()> {
// Initialize the logging library. You can print log messages using the `log` macros:
// https://docs.rs/log/0.4.8/log/ You are welcome to continue using print! statements; this
// just looks a little prettier.
@ -66,7 +68,7 @@ fn main() {
}
// Start listening for connections
let listener = match TcpListener::bind(&options.bind) {
let listener = match TcpListener::bind(&options.bind).await {
Ok(listener) => listener,
Err(err) => {
log::error!("Could not bind to {}: {}", options.bind, err);
@ -84,55 +86,59 @@ fn main() {
};
let astate = std::sync::Arc::new(state);
// let thread_pool = threadpool::ThreadPool::new(4);
for stream in listener.incoming() {
if let Ok(stream) = stream {
// Handle the connection!
// ------ std thread ------
let thd_state = astate.clone();
std::thread::spawn(move || {
handle_connection(stream, &thd_state);
});
// ------ thread pool ------
// let thd_state = astate.clone();
// thread_pool.execute(move || {
// handle_connection(stream, &thd_state);
// });
// ------ single thread ------
// handle_connection(stream, &state);
}
// for stream in listener.incoming() {
// if let Ok(stream) = stream {
// // Handle the connection!
// // ------ std thread ------
// let thd_state = astate.clone();
// std::thread::spawn(move || {
// handle_connection(stream, &thd_state);
// });
// // ------ thread pool ------
// // let thd_state = astate.clone();
// // thread_pool.execute(move || {
// // handle_connection(stream, &thd_state);
// // });
// // ------ single thread ------
// // handle_connection(stream, &state);
// }
// }
loop {
let (socket, _) = listener.accept().await?;
handle_connection(socket, &astate).await;
}
}
fn connect_to_upstream(state: &ProxyState) -> Result<TcpStream, std::io::Error> {
async fn connect_to_upstream(state: &ProxyState) -> Result<TcpStream, std::io::Error> {
let mut rng = rand::rngs::StdRng::from_entropy();
let upstream_idx = rng.gen_range(0..state.upstream_addresses.len());
let upstream_ip = &state.upstream_addresses[upstream_idx];
TcpStream::connect(upstream_ip).or_else(|err| {
TcpStream::connect(upstream_ip).await.or_else(|err| {
log::error!("Failed to connect to upstream {}: {}", upstream_ip, err);
Err(err)
})
// TODO: implement failover (milestone 3)
}
fn send_response(client_conn: &mut TcpStream, response: &http::Response<Vec<u8>>) {
async fn send_response(client_conn: &mut TcpStream, response: &http::Response<Vec<u8>>) {
let client_ip = client_conn.peer_addr().unwrap().ip().to_string();
log::info!("{} <- {}", client_ip, response::format_response_line(&response));
if let Err(error) = response::write_to_stream(&response, client_conn) {
if let Err(error) = response::write_to_stream(&response, client_conn).await {
log::warn!("Failed to send response to client: {}", error);
return;
}
}
fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
let client_ip = client_conn.peer_addr().unwrap().ip().to_string();
log::info!("Connection received from {}", client_ip);
// Open a connection to a random destination server
let mut upstream_conn = match connect_to_upstream(state) {
let mut upstream_conn = match connect_to_upstream(state).await {
Ok(stream) => stream,
Err(_error) => {
let response = response::make_http_error(http::StatusCode::BAD_GATEWAY);
send_response(&mut client_conn, &response);
send_response(&mut client_conn, &response).await;
return;
}
};
@ -142,7 +148,7 @@ fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
// client hangs up or we get an error.
loop {
// Read a request from the client
let mut request = match request::read_from_stream(&mut client_conn) {
let mut request = match request::read_from_stream(&mut client_conn).await {
Ok(request) => request,
// Handle case where client closed connection and is no longer sending requests
Err(request::Error::IncompleteRequest(0)) => {
@ -164,7 +170,7 @@ fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
request::Error::RequestBodyTooLarge => http::StatusCode::PAYLOAD_TOO_LARGE,
request::Error::ConnectionError(_) => http::StatusCode::SERVICE_UNAVAILABLE,
});
send_response(&mut client_conn, &response);
send_response(&mut client_conn, &response).await;
continue;
}
};
@ -181,26 +187,26 @@ fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
request::extend_header_value(&mut request, "x-forwarded-for", &client_ip);
// Forward the request to the server
if let Err(error) = request::write_to_stream(&request, &mut upstream_conn) {
if let Err(error) = request::write_to_stream(&request, &mut upstream_conn).await {
log::error!("Failed to send request to upstream {}: {}", upstream_ip, error);
let response = response::make_http_error(http::StatusCode::BAD_GATEWAY);
send_response(&mut client_conn, &response);
send_response(&mut client_conn, &response).await;
return;
}
log::debug!("Forwarded request to server");
// Read the server's response
let response = match response::read_from_stream(&mut upstream_conn, request.method()) {
let response = match response::read_from_stream(&mut upstream_conn, request.method()).await {
Ok(response) => response,
Err(error) => {
log::error!("Error reading response from server: {:?}", error);
let response = response::make_http_error(http::StatusCode::BAD_GATEWAY);
send_response(&mut client_conn, &response);
send_response(&mut client_conn, &response).await;
return;
}
};
// Forward the response to the client
send_response(&mut client_conn, &response);
send_response(&mut client_conn, &response).await;
log::debug!("Forwarded response to client");
}
}

View File

@ -1,6 +1,8 @@
use std::cmp::min;
use std::io::{Read, Write};
use std::net::TcpStream;
// use std::io::{Read, Write};
// use std::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
use tokio::net::TcpStream;
const MAX_HEADERS_SIZE: usize = 8000;
const MAX_BODY_SIZE: usize = 10000000;
@ -101,7 +103,7 @@ fn parse_request(buffer: &[u8]) -> Result<Option<(http::Request<Vec<u8>>, usize)
/// Returns Ok(http::Request) if a valid request is received, or Error if not.
///
/// You will need to modify this function in Milestone 2.
fn read_headers(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error> {
async fn read_headers(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error> {
// Try reading the headers from the request. We may not receive all the headers in one shot
// (e.g. we might receive the first few bytes of a request, and then the rest follows later).
// Try parsing repeatedly until we read a valid HTTP request
@ -110,7 +112,7 @@ fn read_headers(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error>
loop {
// Read bytes from the connection into the buffer, starting at position bytes_read
let new_bytes = stream
.read(&mut request_buffer[bytes_read..])
.read(&mut request_buffer[bytes_read..]).await
.or_else(|err| Err(Error::ConnectionError(err)))?;
if new_bytes == 0 {
// We didn't manage to read a complete request
@ -137,7 +139,7 @@ fn read_headers(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error>
/// returns Ok(()) if successful, or Err(Error) if Content-Length bytes couldn't be read.
///
/// You will need to modify this function in Milestone 2.
fn read_body(
async fn read_body(
stream: &mut TcpStream,
request: &mut http::Request<Vec<u8>>,
content_length: usize,
@ -147,7 +149,7 @@ fn read_body(
// Read up to 512 bytes at a time. (If the client only sent a small body, then only allocate
// space to read that body.)
let mut buffer = vec![0_u8; min(512, content_length)];
let bytes_read = stream.read(&mut buffer).or_else(|err| Err(Error::ConnectionError(err)))?;
let bytes_read = stream.read(&mut buffer).await.or_else(|err| Err(Error::ConnectionError(err)))?;
// Make sure the client is still sending us bytes
if bytes_read == 0 {
@ -178,15 +180,15 @@ fn read_body(
/// closes the connection prematurely or sends an invalid request.
///
/// You will need to modify this function in Milestone 2.
pub fn read_from_stream(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error> {
pub async fn read_from_stream(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error> {
// Read headers
let mut request = read_headers(stream)?;
let mut request = read_headers(stream).await?;
// Read body if the client supplied the Content-Length header (which it does for POST requests)
if let Some(content_length) = get_content_length(&request)? {
if content_length > MAX_BODY_SIZE {
return Err(Error::RequestBodyTooLarge);
} else {
read_body(stream, &mut request, content_length)?;
read_body(stream, &mut request, content_length).await?;
}
}
Ok(request)
@ -195,20 +197,20 @@ pub fn read_from_stream(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>
/// This function serializes a request to bytes and writes those bytes to the provided stream.
///
/// You will need to modify this function in Milestone 2.
pub fn write_to_stream(
pub async fn write_to_stream(
request: &http::Request<Vec<u8>>,
stream: &mut TcpStream,
) -> Result<(), std::io::Error> {
stream.write(&format_request_line(request).into_bytes())?;
stream.write(&['\r' as u8, '\n' as u8])?; // \r\n
stream.write(&format_request_line(request).into_bytes()).await?;
stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n
for (header_name, header_value) in request.headers() {
stream.write(&format!("{}: ", header_name).as_bytes())?;
stream.write(header_value.as_bytes())?;
stream.write(&['\r' as u8, '\n' as u8])?; // \r\n
stream.write(&format!("{}: ", header_name).as_bytes()).await?;
stream.write(header_value.as_bytes()).await?;
stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n
}
stream.write(&['\r' as u8, '\n' as u8])?;
stream.write(&['\r' as u8, '\n' as u8]).await?;
if request.body().len() > 0 {
stream.write(request.body())?;
stream.write(request.body()).await?;
}
Ok(())
}

View File

@ -1,5 +1,7 @@
use std::io::{Read, Write};
use std::net::TcpStream;
// use std::io::{Read, Write};
// use std::net::TcpStream;
use tokio::net::TcpStream;
use tokio::io::{AsyncReadExt, AsyncWriteExt};
const MAX_HEADERS_SIZE: usize = 8000;
const MAX_BODY_SIZE: usize = 10000000;
@ -80,7 +82,7 @@ fn parse_response(buffer: &[u8]) -> Result<Option<(http::Response<Vec<u8>>, usiz
/// Returns Ok(http::Response) if a valid response is received, or Error if not.
///
/// You will need to modify this function in Milestone 2.
fn read_headers(stream: &mut TcpStream) -> Result<http::Response<Vec<u8>>, Error> {
async fn read_headers(stream: &mut TcpStream) -> Result<http::Response<Vec<u8>>, Error> {
// Try reading the headers from the response. We may not receive all the headers in one shot
// (e.g. we might receive the first few bytes of a response, and then the rest follows later).
// Try parsing repeatedly until we read a valid HTTP response
@ -89,7 +91,7 @@ fn read_headers(stream: &mut TcpStream) -> Result<http::Response<Vec<u8>>, Error
loop {
// Read bytes from the connection into the buffer, starting at position bytes_read
let new_bytes = stream
.read(&mut response_buffer[bytes_read..])
.read(&mut response_buffer[bytes_read..]).await
.or_else(|err| Err(Error::ConnectionError(err)))?;
if new_bytes == 0 {
// We didn't manage to read a complete response
@ -114,7 +116,7 @@ fn read_headers(stream: &mut TcpStream) -> Result<http::Response<Vec<u8>>, Error
/// present, it reads that many bytes; otherwise, it reads bytes until the connection is closed.
///
/// You will need to modify this function in Milestone 2.
fn read_body(stream: &mut TcpStream, response: &mut http::Response<Vec<u8>>) -> Result<(), Error> {
async fn read_body(stream: &mut TcpStream, response: &mut http::Response<Vec<u8>>) -> Result<(), Error> {
// The response may or may not supply a Content-Length header. If it provides the header, then
// we want to read that number of bytes; if it does not, we want to keep reading bytes until
// the connection is closed.
@ -123,7 +125,7 @@ fn read_body(stream: &mut TcpStream, response: &mut http::Response<Vec<u8>>) ->
while content_length.is_none() || response.body().len() < content_length.unwrap() {
let mut buffer = [0_u8; 512];
let bytes_read = stream
.read(&mut buffer)
.read(&mut buffer).await
.or_else(|err| Err(Error::ConnectionError(err)))?;
if bytes_read == 0 {
// The server has hung up!
@ -158,11 +160,11 @@ fn read_body(stream: &mut TcpStream, response: &mut http::Response<Vec<u8>>) ->
/// closes the connection prematurely or sends an invalid response.
///
/// You will need to modify this function in Milestone 2.
pub fn read_from_stream(
pub async fn read_from_stream(
stream: &mut TcpStream,
request_method: &http::Method,
) -> Result<http::Response<Vec<u8>>, Error> {
let mut response = read_headers(stream)?;
let mut response = read_headers(stream).await?;
// A response may have a body as long as it is not responding to a HEAD request and as long as
// the response status code is not 1xx, 204 (no content), or 304 (not modified).
if !(request_method == http::Method::HEAD
@ -170,7 +172,7 @@ pub fn read_from_stream(
|| response.status() == http::StatusCode::NO_CONTENT
|| response.status() == http::StatusCode::NOT_MODIFIED)
{
read_body(stream, &mut response)?;
read_body(stream, &mut response).await?;
}
Ok(response)
}
@ -178,20 +180,20 @@ pub fn read_from_stream(
/// This function serializes a response to bytes and writes those bytes to the provided stream.
///
/// You will need to modify this function in Milestone 2.
pub fn write_to_stream(
pub async fn write_to_stream(
response: &http::Response<Vec<u8>>,
stream: &mut TcpStream,
) -> Result<(), std::io::Error> {
stream.write(&format_response_line(response).into_bytes())?;
stream.write(&['\r' as u8, '\n' as u8])?; // \r\n
stream.write(&format_response_line(response).into_bytes()).await?;
stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n
for (header_name, header_value) in response.headers() {
stream.write(&format!("{}: ", header_name).as_bytes())?;
stream.write(header_value.as_bytes())?;
stream.write(&['\r' as u8, '\n' as u8])?; // \r\n
stream.write(&format!("{}: ", header_name).as_bytes()).await?;
stream.write(header_value.as_bytes()).await?;
stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n
}
stream.write(&['\r' as u8, '\n' as u8])?;
stream.write(&['\r' as u8, '\n' as u8]).await?;
if response.body().len() > 0 {
stream.write(response.body())?;
stream.write(response.body()).await?;
}
Ok(())
}