Milestone 2
This commit is contained in:
parent
e374f0421f
commit
374f2690d8
@ -30,5 +30,8 @@
|
||||
|
||||
`ThreadPool` 和 `std::thread` 没啥太大区别,都封装的很好了,直接看一下 `doc.rs` 里面的例子就行了。
|
||||
|
||||
### Milestone 2
|
||||
|
||||
又是一个奇怪的 task point,要求把所有的 IO 换成异步的。但是依然没有任何的难度,根据 IDE 提示加 `async` 和 `await` 就可以了。
|
||||
|
||||
## 附加任务?
|
||||
@ -3,7 +3,8 @@ mod response;
|
||||
|
||||
use clap::Parser;
|
||||
use rand::{Rng, SeedableRng};
|
||||
use std::net::{TcpListener, TcpStream};
|
||||
// use std::net::{TcpListener, TcpStream};
|
||||
use tokio::net::{TcpListener, TcpStream};
|
||||
|
||||
/// Contains information parsed from the command-line invocation of balancebeam. The Clap macros
|
||||
/// provide a fancy way to automatically construct a command-line argument parser.
|
||||
@ -49,7 +50,8 @@ struct ProxyState {
|
||||
upstream_addresses: Vec<String>,
|
||||
}
|
||||
|
||||
fn main() {
|
||||
#[tokio::main]
|
||||
async fn main() -> std::io::Result<()> {
|
||||
// Initialize the logging library. You can print log messages using the `log` macros:
|
||||
// https://docs.rs/log/0.4.8/log/ You are welcome to continue using print! statements; this
|
||||
// just looks a little prettier.
|
||||
@ -66,7 +68,7 @@ fn main() {
|
||||
}
|
||||
|
||||
// Start listening for connections
|
||||
let listener = match TcpListener::bind(&options.bind) {
|
||||
let listener = match TcpListener::bind(&options.bind).await {
|
||||
Ok(listener) => listener,
|
||||
Err(err) => {
|
||||
log::error!("Could not bind to {}: {}", options.bind, err);
|
||||
@ -84,55 +86,59 @@ fn main() {
|
||||
};
|
||||
let astate = std::sync::Arc::new(state);
|
||||
// let thread_pool = threadpool::ThreadPool::new(4);
|
||||
for stream in listener.incoming() {
|
||||
if let Ok(stream) = stream {
|
||||
// Handle the connection!
|
||||
// ------ std thread ------
|
||||
let thd_state = astate.clone();
|
||||
std::thread::spawn(move || {
|
||||
handle_connection(stream, &thd_state);
|
||||
});
|
||||
// ------ thread pool ------
|
||||
// let thd_state = astate.clone();
|
||||
// thread_pool.execute(move || {
|
||||
// handle_connection(stream, &thd_state);
|
||||
// });
|
||||
// ------ single thread ------
|
||||
// handle_connection(stream, &state);
|
||||
}
|
||||
// for stream in listener.incoming() {
|
||||
// if let Ok(stream) = stream {
|
||||
// // Handle the connection!
|
||||
// // ------ std thread ------
|
||||
// let thd_state = astate.clone();
|
||||
// std::thread::spawn(move || {
|
||||
// handle_connection(stream, &thd_state);
|
||||
// });
|
||||
// // ------ thread pool ------
|
||||
// // let thd_state = astate.clone();
|
||||
// // thread_pool.execute(move || {
|
||||
// // handle_connection(stream, &thd_state);
|
||||
// // });
|
||||
// // ------ single thread ------
|
||||
// // handle_connection(stream, &state);
|
||||
// }
|
||||
// }
|
||||
loop {
|
||||
let (socket, _) = listener.accept().await?;
|
||||
handle_connection(socket, &astate).await;
|
||||
}
|
||||
}
|
||||
|
||||
fn connect_to_upstream(state: &ProxyState) -> Result<TcpStream, std::io::Error> {
|
||||
async fn connect_to_upstream(state: &ProxyState) -> Result<TcpStream, std::io::Error> {
|
||||
let mut rng = rand::rngs::StdRng::from_entropy();
|
||||
let upstream_idx = rng.gen_range(0..state.upstream_addresses.len());
|
||||
let upstream_ip = &state.upstream_addresses[upstream_idx];
|
||||
TcpStream::connect(upstream_ip).or_else(|err| {
|
||||
TcpStream::connect(upstream_ip).await.or_else(|err| {
|
||||
log::error!("Failed to connect to upstream {}: {}", upstream_ip, err);
|
||||
Err(err)
|
||||
})
|
||||
// TODO: implement failover (milestone 3)
|
||||
}
|
||||
|
||||
fn send_response(client_conn: &mut TcpStream, response: &http::Response<Vec<u8>>) {
|
||||
async fn send_response(client_conn: &mut TcpStream, response: &http::Response<Vec<u8>>) {
|
||||
let client_ip = client_conn.peer_addr().unwrap().ip().to_string();
|
||||
log::info!("{} <- {}", client_ip, response::format_response_line(&response));
|
||||
if let Err(error) = response::write_to_stream(&response, client_conn) {
|
||||
if let Err(error) = response::write_to_stream(&response, client_conn).await {
|
||||
log::warn!("Failed to send response to client: {}", error);
|
||||
return;
|
||||
}
|
||||
}
|
||||
|
||||
fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
|
||||
async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
|
||||
let client_ip = client_conn.peer_addr().unwrap().ip().to_string();
|
||||
log::info!("Connection received from {}", client_ip);
|
||||
|
||||
// Open a connection to a random destination server
|
||||
let mut upstream_conn = match connect_to_upstream(state) {
|
||||
let mut upstream_conn = match connect_to_upstream(state).await {
|
||||
Ok(stream) => stream,
|
||||
Err(_error) => {
|
||||
let response = response::make_http_error(http::StatusCode::BAD_GATEWAY);
|
||||
send_response(&mut client_conn, &response);
|
||||
send_response(&mut client_conn, &response).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
@ -142,7 +148,7 @@ fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
|
||||
// client hangs up or we get an error.
|
||||
loop {
|
||||
// Read a request from the client
|
||||
let mut request = match request::read_from_stream(&mut client_conn) {
|
||||
let mut request = match request::read_from_stream(&mut client_conn).await {
|
||||
Ok(request) => request,
|
||||
// Handle case where client closed connection and is no longer sending requests
|
||||
Err(request::Error::IncompleteRequest(0)) => {
|
||||
@ -164,7 +170,7 @@ fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
|
||||
request::Error::RequestBodyTooLarge => http::StatusCode::PAYLOAD_TOO_LARGE,
|
||||
request::Error::ConnectionError(_) => http::StatusCode::SERVICE_UNAVAILABLE,
|
||||
});
|
||||
send_response(&mut client_conn, &response);
|
||||
send_response(&mut client_conn, &response).await;
|
||||
continue;
|
||||
}
|
||||
};
|
||||
@ -181,26 +187,26 @@ fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
|
||||
request::extend_header_value(&mut request, "x-forwarded-for", &client_ip);
|
||||
|
||||
// Forward the request to the server
|
||||
if let Err(error) = request::write_to_stream(&request, &mut upstream_conn) {
|
||||
if let Err(error) = request::write_to_stream(&request, &mut upstream_conn).await {
|
||||
log::error!("Failed to send request to upstream {}: {}", upstream_ip, error);
|
||||
let response = response::make_http_error(http::StatusCode::BAD_GATEWAY);
|
||||
send_response(&mut client_conn, &response);
|
||||
send_response(&mut client_conn, &response).await;
|
||||
return;
|
||||
}
|
||||
log::debug!("Forwarded request to server");
|
||||
|
||||
// Read the server's response
|
||||
let response = match response::read_from_stream(&mut upstream_conn, request.method()) {
|
||||
let response = match response::read_from_stream(&mut upstream_conn, request.method()).await {
|
||||
Ok(response) => response,
|
||||
Err(error) => {
|
||||
log::error!("Error reading response from server: {:?}", error);
|
||||
let response = response::make_http_error(http::StatusCode::BAD_GATEWAY);
|
||||
send_response(&mut client_conn, &response);
|
||||
send_response(&mut client_conn, &response).await;
|
||||
return;
|
||||
}
|
||||
};
|
||||
// Forward the response to the client
|
||||
send_response(&mut client_conn, &response);
|
||||
send_response(&mut client_conn, &response).await;
|
||||
log::debug!("Forwarded response to client");
|
||||
}
|
||||
}
|
||||
|
||||
@ -1,6 +1,8 @@
|
||||
use std::cmp::min;
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpStream;
|
||||
// use std::io::{Read, Write};
|
||||
// use std::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
use tokio::net::TcpStream;
|
||||
|
||||
const MAX_HEADERS_SIZE: usize = 8000;
|
||||
const MAX_BODY_SIZE: usize = 10000000;
|
||||
@ -101,7 +103,7 @@ fn parse_request(buffer: &[u8]) -> Result<Option<(http::Request<Vec<u8>>, usize)
|
||||
/// Returns Ok(http::Request) if a valid request is received, or Error if not.
|
||||
///
|
||||
/// You will need to modify this function in Milestone 2.
|
||||
fn read_headers(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error> {
|
||||
async fn read_headers(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error> {
|
||||
// Try reading the headers from the request. We may not receive all the headers in one shot
|
||||
// (e.g. we might receive the first few bytes of a request, and then the rest follows later).
|
||||
// Try parsing repeatedly until we read a valid HTTP request
|
||||
@ -110,7 +112,7 @@ fn read_headers(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error>
|
||||
loop {
|
||||
// Read bytes from the connection into the buffer, starting at position bytes_read
|
||||
let new_bytes = stream
|
||||
.read(&mut request_buffer[bytes_read..])
|
||||
.read(&mut request_buffer[bytes_read..]).await
|
||||
.or_else(|err| Err(Error::ConnectionError(err)))?;
|
||||
if new_bytes == 0 {
|
||||
// We didn't manage to read a complete request
|
||||
@ -137,7 +139,7 @@ fn read_headers(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error>
|
||||
/// returns Ok(()) if successful, or Err(Error) if Content-Length bytes couldn't be read.
|
||||
///
|
||||
/// You will need to modify this function in Milestone 2.
|
||||
fn read_body(
|
||||
async fn read_body(
|
||||
stream: &mut TcpStream,
|
||||
request: &mut http::Request<Vec<u8>>,
|
||||
content_length: usize,
|
||||
@ -147,7 +149,7 @@ fn read_body(
|
||||
// Read up to 512 bytes at a time. (If the client only sent a small body, then only allocate
|
||||
// space to read that body.)
|
||||
let mut buffer = vec![0_u8; min(512, content_length)];
|
||||
let bytes_read = stream.read(&mut buffer).or_else(|err| Err(Error::ConnectionError(err)))?;
|
||||
let bytes_read = stream.read(&mut buffer).await.or_else(|err| Err(Error::ConnectionError(err)))?;
|
||||
|
||||
// Make sure the client is still sending us bytes
|
||||
if bytes_read == 0 {
|
||||
@ -178,15 +180,15 @@ fn read_body(
|
||||
/// closes the connection prematurely or sends an invalid request.
|
||||
///
|
||||
/// You will need to modify this function in Milestone 2.
|
||||
pub fn read_from_stream(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error> {
|
||||
pub async fn read_from_stream(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>, Error> {
|
||||
// Read headers
|
||||
let mut request = read_headers(stream)?;
|
||||
let mut request = read_headers(stream).await?;
|
||||
// Read body if the client supplied the Content-Length header (which it does for POST requests)
|
||||
if let Some(content_length) = get_content_length(&request)? {
|
||||
if content_length > MAX_BODY_SIZE {
|
||||
return Err(Error::RequestBodyTooLarge);
|
||||
} else {
|
||||
read_body(stream, &mut request, content_length)?;
|
||||
read_body(stream, &mut request, content_length).await?;
|
||||
}
|
||||
}
|
||||
Ok(request)
|
||||
@ -195,20 +197,20 @@ pub fn read_from_stream(stream: &mut TcpStream) -> Result<http::Request<Vec<u8>>
|
||||
/// This function serializes a request to bytes and writes those bytes to the provided stream.
|
||||
///
|
||||
/// You will need to modify this function in Milestone 2.
|
||||
pub fn write_to_stream(
|
||||
pub async fn write_to_stream(
|
||||
request: &http::Request<Vec<u8>>,
|
||||
stream: &mut TcpStream,
|
||||
) -> Result<(), std::io::Error> {
|
||||
stream.write(&format_request_line(request).into_bytes())?;
|
||||
stream.write(&['\r' as u8, '\n' as u8])?; // \r\n
|
||||
stream.write(&format_request_line(request).into_bytes()).await?;
|
||||
stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n
|
||||
for (header_name, header_value) in request.headers() {
|
||||
stream.write(&format!("{}: ", header_name).as_bytes())?;
|
||||
stream.write(header_value.as_bytes())?;
|
||||
stream.write(&['\r' as u8, '\n' as u8])?; // \r\n
|
||||
stream.write(&format!("{}: ", header_name).as_bytes()).await?;
|
||||
stream.write(header_value.as_bytes()).await?;
|
||||
stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n
|
||||
}
|
||||
stream.write(&['\r' as u8, '\n' as u8])?;
|
||||
stream.write(&['\r' as u8, '\n' as u8]).await?;
|
||||
if request.body().len() > 0 {
|
||||
stream.write(request.body())?;
|
||||
stream.write(request.body()).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
@ -1,5 +1,7 @@
|
||||
use std::io::{Read, Write};
|
||||
use std::net::TcpStream;
|
||||
// use std::io::{Read, Write};
|
||||
// use std::net::TcpStream;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::io::{AsyncReadExt, AsyncWriteExt};
|
||||
|
||||
const MAX_HEADERS_SIZE: usize = 8000;
|
||||
const MAX_BODY_SIZE: usize = 10000000;
|
||||
@ -80,7 +82,7 @@ fn parse_response(buffer: &[u8]) -> Result<Option<(http::Response<Vec<u8>>, usiz
|
||||
/// Returns Ok(http::Response) if a valid response is received, or Error if not.
|
||||
///
|
||||
/// You will need to modify this function in Milestone 2.
|
||||
fn read_headers(stream: &mut TcpStream) -> Result<http::Response<Vec<u8>>, Error> {
|
||||
async fn read_headers(stream: &mut TcpStream) -> Result<http::Response<Vec<u8>>, Error> {
|
||||
// Try reading the headers from the response. We may not receive all the headers in one shot
|
||||
// (e.g. we might receive the first few bytes of a response, and then the rest follows later).
|
||||
// Try parsing repeatedly until we read a valid HTTP response
|
||||
@ -89,7 +91,7 @@ fn read_headers(stream: &mut TcpStream) -> Result<http::Response<Vec<u8>>, Error
|
||||
loop {
|
||||
// Read bytes from the connection into the buffer, starting at position bytes_read
|
||||
let new_bytes = stream
|
||||
.read(&mut response_buffer[bytes_read..])
|
||||
.read(&mut response_buffer[bytes_read..]).await
|
||||
.or_else(|err| Err(Error::ConnectionError(err)))?;
|
||||
if new_bytes == 0 {
|
||||
// We didn't manage to read a complete response
|
||||
@ -114,7 +116,7 @@ fn read_headers(stream: &mut TcpStream) -> Result<http::Response<Vec<u8>>, Error
|
||||
/// present, it reads that many bytes; otherwise, it reads bytes until the connection is closed.
|
||||
///
|
||||
/// You will need to modify this function in Milestone 2.
|
||||
fn read_body(stream: &mut TcpStream, response: &mut http::Response<Vec<u8>>) -> Result<(), Error> {
|
||||
async fn read_body(stream: &mut TcpStream, response: &mut http::Response<Vec<u8>>) -> Result<(), Error> {
|
||||
// The response may or may not supply a Content-Length header. If it provides the header, then
|
||||
// we want to read that number of bytes; if it does not, we want to keep reading bytes until
|
||||
// the connection is closed.
|
||||
@ -123,7 +125,7 @@ fn read_body(stream: &mut TcpStream, response: &mut http::Response<Vec<u8>>) ->
|
||||
while content_length.is_none() || response.body().len() < content_length.unwrap() {
|
||||
let mut buffer = [0_u8; 512];
|
||||
let bytes_read = stream
|
||||
.read(&mut buffer)
|
||||
.read(&mut buffer).await
|
||||
.or_else(|err| Err(Error::ConnectionError(err)))?;
|
||||
if bytes_read == 0 {
|
||||
// The server has hung up!
|
||||
@ -158,11 +160,11 @@ fn read_body(stream: &mut TcpStream, response: &mut http::Response<Vec<u8>>) ->
|
||||
/// closes the connection prematurely or sends an invalid response.
|
||||
///
|
||||
/// You will need to modify this function in Milestone 2.
|
||||
pub fn read_from_stream(
|
||||
pub async fn read_from_stream(
|
||||
stream: &mut TcpStream,
|
||||
request_method: &http::Method,
|
||||
) -> Result<http::Response<Vec<u8>>, Error> {
|
||||
let mut response = read_headers(stream)?;
|
||||
let mut response = read_headers(stream).await?;
|
||||
// A response may have a body as long as it is not responding to a HEAD request and as long as
|
||||
// the response status code is not 1xx, 204 (no content), or 304 (not modified).
|
||||
if !(request_method == http::Method::HEAD
|
||||
@ -170,7 +172,7 @@ pub fn read_from_stream(
|
||||
|| response.status() == http::StatusCode::NO_CONTENT
|
||||
|| response.status() == http::StatusCode::NOT_MODIFIED)
|
||||
{
|
||||
read_body(stream, &mut response)?;
|
||||
read_body(stream, &mut response).await?;
|
||||
}
|
||||
Ok(response)
|
||||
}
|
||||
@ -178,20 +180,20 @@ pub fn read_from_stream(
|
||||
/// This function serializes a response to bytes and writes those bytes to the provided stream.
|
||||
///
|
||||
/// You will need to modify this function in Milestone 2.
|
||||
pub fn write_to_stream(
|
||||
pub async fn write_to_stream(
|
||||
response: &http::Response<Vec<u8>>,
|
||||
stream: &mut TcpStream,
|
||||
) -> Result<(), std::io::Error> {
|
||||
stream.write(&format_response_line(response).into_bytes())?;
|
||||
stream.write(&['\r' as u8, '\n' as u8])?; // \r\n
|
||||
stream.write(&format_response_line(response).into_bytes()).await?;
|
||||
stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n
|
||||
for (header_name, header_value) in response.headers() {
|
||||
stream.write(&format!("{}: ", header_name).as_bytes())?;
|
||||
stream.write(header_value.as_bytes())?;
|
||||
stream.write(&['\r' as u8, '\n' as u8])?; // \r\n
|
||||
stream.write(&format!("{}: ", header_name).as_bytes()).await?;
|
||||
stream.write(header_value.as_bytes()).await?;
|
||||
stream.write(&['\r' as u8, '\n' as u8]).await?; // \r\n
|
||||
}
|
||||
stream.write(&['\r' as u8, '\n' as u8])?;
|
||||
stream.write(&['\r' as u8, '\n' as u8]).await?;
|
||||
if response.body().len() > 0 {
|
||||
stream.write(response.body())?;
|
||||
stream.write(response.body()).await?;
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
Loading…
Reference in New Issue
Block a user