Milestone 5, Project 2 all done

ridethepig 2023-03-10 23:50:26 +08:00
parent 449145d796
commit 38afdcebe8
2 changed files with 110 additions and 11 deletions


@@ -51,4 +51,19 @@
A pitfall: don't reuse the `connect_to_upstream` implemented earlier, since it cannot reach an upstream that has already been removed from the available list (or rather, it never lets you choose which upstream to connect to in the first place), while here every upstream needs to be tested.
### Milestone 5
Implement request rate limiting: cap the number of requests each IP may make within a time window, and return `HTTP 429` once the limit is exceeded. Three algorithms are mentioned: leaky bucket, fixed window, and sliding window. Leaky bucket doesn't really fit the requirements here, so it is ignored.
Both fixed window and sliding window are implemented here; the two are essentially the same thing. Fixed window is the simpler one: it just counts the requests in each time unit, rejects requests once the limit is reached, and resets the counter when the next unit starts. This causes a problem: if requests cluster around the moment the counter resets, the short-term request rate can effectively double and exceed the intended limit. Sliding window additionally takes the previous unit into account, folding a fraction of the previous unit's count into the current count, which avoids the problem above. A simple implementation records the counts of the current and previous units, considers the window stretching one unit back from the current moment, and scales the previous count by the fraction of that window overlapping the previous unit. As a formula: $\frac{UNIT - (time_{now} - cur\_unit\_start)}{UNIT} \times prev\_cnt + cur\_cnt$.
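A quick worked example (numbers purely illustrative): with $UNIT = 60$ s, $prev\_cnt = 40$, $cur\_cnt = 30$, and the current unit having started 15 s ago, the weighted count is $\frac{60 - 15}{60} \times 40 + 30 = 0.75 \times 40 + 30 = 60$; with a limit of 60 requests per minute, the very next request would push this over the limit and be rejected.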
Concretely, a `HashMap<String, RateRecord>` maps each IP address to its request-count information. Because one IP can open multiple connections, the map has to live in the global `ProxyState`, wrapped in an `Arc<Mutex<>>`. The `RateRecord` type bundles the data described above, together with a few small helper functions that keep the handler code tidy. None of that is particularly painful; the painful part is that `HashMap` doesn't support modifying a value in place through indexing. The problems are detailed below.
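A condensed sketch of this shared-state pattern (the struct and field names follow the actual code below, while `count_request` and the single-field `RateRecord` are simplified stand-ins):

```rust
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

struct RateRecord { count: usize } // stand-in; the real struct holds more fields

struct ProxyState {
    /// One RateRecord per client IP, shared by every connection task
    rate_limit_counter: Arc<Mutex<HashMap<String, RateRecord>>>,
}

async fn count_request(state: &ProxyState, client_ip: &str) {
    // The async mutex serializes updates from concurrent connections of the same IP.
    let mut map = state.rate_limit_counter.lock().await;
    if !map.contains_key(client_ip) {
        map.insert(client_ip.to_owned(), RateRecord { count: 0 });
    }
    map.get_mut(client_ip).unwrap().count += 1;
}
```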
Problems encountered:
1. `HashMap` doesn't let you modify a value directly through indexing: that would require the `IndexMut` trait, which `HashMap` does not implement (unlike `Vec`, which does), so instead of writing `map[key] = newvalue` you end up with the rather awkward `(*map.get_mut(&key).unwrap()).field = new_val_for_field` (see the first sketch after this list).
2. Explicit conversions between primitive types can be done with `as`; this seems to be built into the compiler, so it needs neither the standard library nor any other conversion machinery. Everything else has to go through traits such as `From` and `Into` (see the second sketch after this list).
3. Finally, running the full `cargo test` suite occasionally fails on the assertion `assert_eq!(total_request_count, rate_limit_threshold);`, although the test never fails when run on its own. Judging from the log output, this is probably because the requests sent by the health-check thread implemented earlier are also counted toward the accumulated request total, which blows the check. As for why that happens at all, randomness (or a slow CPU) may stretch the test out long enough for the health-check thread to start running.
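For point 1, a minimal sketch of the indexing limitation and two ways around it (the map contents here are made up):

```rust
use std::collections::HashMap;

fn main() {
    let mut map: HashMap<String, usize> = HashMap::new();
    map.insert("10.0.0.1".to_string(), 1);

    // map["10.0.0.1"] = 2; // does not compile: HashMap implements Index (read-only) but not IndexMut

    // What the handler ends up doing: grab a mutable reference, then assign through it.
    *map.get_mut("10.0.0.1").unwrap() = 2;

    // A slightly nicer alternative: the entry API inserts a default when the key
    // is missing and hands back a mutable reference in one call.
    *map.entry("10.0.0.2".to_string()).or_insert(0) += 1;

    assert_eq!(map["10.0.0.1"], 2); // reading through Index still works
    assert_eq!(map["10.0.0.2"], 1);
}
```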
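And for point 2, a small illustrative snippet of how the two conversion routes differ (values are arbitrary):

```rust
fn main() {
    // `as` is built into the language and only works between primitive types;
    // integer-to-integer casts silently truncate to the target width.
    let secs: u64 = 61;
    let secs_f = secs as f64;   // u64 -> f64, as in the sliding-window math
    let small = 300_u64 as u8;  // truncates to 44

    // From/Into are library traits and only cover lossless conversions,
    // so there is a u8 -> u64 impl but no u64 -> u8 one.
    let widened: u64 = u64::from(44_u8);
    let as_float: f64 = f64::from(44_u8);

    // Fallible narrowing goes through TryFrom instead.
    let narrowed = u8::try_from(300_u64); // Err(_)

    println!("{secs_f} {small} {widened} {as_float} {narrowed:?}");
}
```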
## Extra tasks?
The project authors offer quite a few optional extensions, and they all look pretty nice, but I'm too lazy to write my own tests for them, so I'll just skip them, la la la.


@@ -4,8 +4,8 @@ mod response;
use clap::Parser;
use rand::{Rng, SeedableRng};
// use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use tokio::{net::{TcpListener, TcpStream}, io::AsyncWriteExt};
use std::{collections::HashMap, sync::Arc};
use tokio::net::{TcpListener, TcpStream};
/// Contains information parsed from the command-line invocation of balancebeam. The Clap macros
/// provide a fancy way to automatically construct a command-line argument parser.
@@ -33,24 +33,79 @@ struct CmdOptions {
max_requests_per_minute: usize,
}
/// Per-client-IP rate-limiting state (Milestone 5)
struct RateRecord {
/// Requests counted in the current time unit
count: usize,
/// Requests counted in the previous time unit (used by the sliding-window check)
pre_count: usize,
/// Maximum number of requests allowed per time unit
max_count: usize,
/// Start of the current time unit
last_time: std::time::Instant,
}
impl RateRecord {
const TIME_UNIT: u64 = 60; // length of one rate-limiting window, in seconds
pub fn new(max_count: usize) -> RateRecord {
RateRecord {
count: 0,
pre_count: 0,
max_count,
last_time: std::time::Instant::now(),
}
}
/// Called once per incoming request: roll the window forward if a new time
/// unit has started, then count this request.
pub fn update(&mut self) {
let cur_time = std::time::Instant::now();
let diff_unit = cur_time.duration_since(self.last_time).as_secs() / Self::TIME_UNIT;
if diff_unit > 0 {
self.pre_count = if diff_unit == 1 {
self.count
} else {
0 // if the gap is more than one unit, there were no requests in the previous unit
};
// Start the new unit empty; the increment below counts the current request.
self.count = 0;
self.last_time += std::time::Duration::from_secs(diff_unit * Self::TIME_UNIT);
}
if self.count <= self.max_count {
self.count += 1; // saturate just past the limit to avoid overflow if the time unit is long
}
}
/// Fixed-window check: has the current unit already exceeded the limit?
#[allow(dead_code)]
pub fn check(&self) -> bool {
self.count > self.max_count
}
/// Sliding-window check: weight the previous unit's count by how much of the
/// one-unit lookback window still overlaps it, then add the current count.
#[allow(dead_code)]
pub fn check_sliding_window(&self) -> bool {
let count_f64 = self.count as f64;
let pre_count_f64 = self.pre_count as f64;
let cur_time = std::time::Instant::now();
let time_unit_f64 = Self::TIME_UNIT as f64;
let pre_ratio =
(time_unit_f64 - cur_time.duration_since(self.last_time).as_secs_f64()) / time_unit_f64;
let final_count = pre_count_f64 * pre_ratio + count_f64;
// log::debug!("pre_ratio:{pre_ratio}, pre_count:{pre_count_f64}, count:{count_f64} final_count:{final_count}");
final_count > self.max_count as f64
}
}
/// Contains information about the state of balancebeam (e.g. what servers we are currently proxying
/// to, what servers have failed, rate limiting counts, etc.)
///
/// You should add fields to this struct in later milestones.
struct ProxyState {
/// How frequently we check whether upstream servers are alive (Milestone 4)
#[allow(dead_code)]
active_health_check_interval: usize,
/// Where we should send requests when doing active health checks (Milestone 4)
#[allow(dead_code)]
active_health_check_path: String,
/// Maximum number of requests an individual IP can make in a minute (Milestone 5)
#[allow(dead_code)]
max_requests_per_minute: usize,
/// Addresses of servers that we are proxying to
upstream_addresses: Vec<String>,
/// Indices of upstream servers currently marked as alive
upstream_status: Arc<tokio::sync::Mutex<Vec<usize>>>,
/// Rate limiting counter
rate_limit_counter: Arc<tokio::sync::Mutex<HashMap<String, RateRecord>>>,
}
async fn disable_upstream(state: &ProxyState, index: usize) {
@@ -64,7 +119,7 @@ async fn disable_upstream(state: &ProxyState, index: usize) {
}
fn active_health_check(state: Arc<ProxyState>) {
let _ = tokio::spawn( async move {
let _ = tokio::spawn(async move {
let active_health_check_interval = state.active_health_check_interval;
let active_health_check_path = &state.active_health_check_path;
loop {
@@ -103,10 +158,18 @@ fn active_health_check(state: Arc<ProxyState>) {
}
};
if response.status() != 200 {
log::info!("[HealthCheck] Get bad HTTP response {} from {}", response.status(), upstream);
log::info!(
"[HealthCheck] Get bad HTTP response {} from {}",
response.status(),
upstream
);
disable_upstream(&state, count - 1).await;
} else {
log::debug!("[HealthCheck] Get good HTTP response {} from {}", response.status(), upstream);
log::debug!(
"[HealthCheck] Get good HTTP response {} from {}",
response.status(),
upstream
);
let mut upstream_status = state.upstream_status.lock().await;
if let None = upstream_status.iter().position(|&r| r == count - 1) {
upstream_status.push(count - 1);
@@ -115,8 +178,7 @@ fn active_health_check(state: Arc<ProxyState>) {
}
}
}
}
);
});
}
#[tokio::main]
@@ -158,6 +220,7 @@ async fn main() -> std::io::Result<()> {
active_health_check_path: options.active_health_check_path,
max_requests_per_minute: options.max_requests_per_minute,
upstream_status: Arc::new(tokio::sync::Mutex::new(upstream_status)),
rate_limit_counter: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
};
let astate = Arc::new(state);
// let _ = tokio::spawn(async move {
@@ -252,7 +315,6 @@ async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
}
};
let upstream_ip = client_conn.peer_addr().unwrap().ip().to_string();
// The client may now send us one or more requests. Keep trying to read requests until the
// client hangs up or we get an error.
loop {
@@ -289,6 +351,28 @@ async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
upstream_ip,
request::format_request_line(&request)
);
// rate limit stuff
// Note: the rate-limiting test may fail on its last assertion when the whole test
// suite runs, because requests sent by the health-check task can be counted among
// the requests that escaped rate limiting; this never happens in an isolated run.
if state.max_requests_per_minute != 0 {
let mut rate_record = state.rate_limit_counter.lock().await;
if !rate_record.contains_key(&client_ip) {
rate_record.insert(
client_ip.clone(),
RateRecord::new(state.max_requests_per_minute),
);
}
(*rate_record.get_mut(&client_ip).unwrap()).update();
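// Fixed-window check; check_sliding_window() is the smoother alternative implemented above.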
if rate_record[&client_ip].check() {
let response = response::make_http_error(http::StatusCode::TOO_MANY_REQUESTS);
log::info!("Drop request due to rate limiting");
drop(rate_record);
send_response(&mut client_conn, &response).await;
continue;
}
drop(rate_record);
}
// Add X-Forwarded-For header so that the upstream server knows the client's IP address.
// (We're the ones connecting directly to the upstream server, so without this header, the