Milestone 5, Project 2 all done

@@ -51,4 +51,19 @@

One pitfall: don't use the `connect_to_upstream` implemented earlier, because it can't connect to an upstream that has already been removed from the available list (more precisely, it simply never lets you choose which upstream to use), but here we need to test every upstream.

### Milestone 5

Implement request rate limiting: limit the number of requests each IP can make within a time window, and return `HTTP 429` if there are too many. Three algorithms are mentioned: leaky bucket, fixed window, and sliding window. Leaky bucket doesn't really fit the requirements here, so I ignored it.

I implemented both fixed window and sliding window here. The two are basically the same; fixed window is a bit simpler: it just counts the requests in each time unit, drops requests once the count hits the limit, and resets the counter at the start of the next unit. This leads to a problem: if requests cluster around the moment the counter resets, the short-term request rate can double and actually exceed the limit. Sliding window additionally takes the previous unit into account, folding a proportional share of the previous unit's count into the current count, which avoids the problem above. A simple way to implement it is to record the counts of the current unit and the previous unit, consider the window that extends one unit back from the current moment, and scale the previous count by the fraction of that window overlapping the previous unit. As a formula: $\frac{UNIT - (t_{now} - cur\_unit\_start)}{UNIT} \times prev\_cnt + cur\_cnt$.

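To make the folding concrete, a made-up example (the numbers are purely illustrative): suppose $UNIT = 60s$, $prev\_cnt = 80$, $cur\_cnt = 30$, and the current unit started 15 s ago. The estimated count over the sliding window is

$$\frac{60 - 15}{60} \times 80 + 30 = 0.75 \times 80 + 30 = 90,$$

so with a limit of 100 requests per minute this request would still be allowed, whereas a plain fixed window would only see the 30 requests of the current unit.
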
Concretely, a `HashMap<String, RateRecord>` maps each IP address to its request-count information. Because a single IP address opens multiple connections, the data has to live in the global `ProxyState` and be wrapped in `Arc<Mutex<>>`. I wrapped the data mentioned above in a `RateRecord` type and wrote a few small helper functions to keep the handler code a bit cleaner. None of this was particularly painful; the most painful part was discovering that `HashMap` doesn't support modifying its values in place through indexing. The problems I ran into are detailed below.

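As an aside, here is a minimal sketch of why the map has to be wrapped this way (the `counters` map and the IPs are made up for illustration, not balancebeam code): every connection is handled in its own tokio task, so each task holds a clone of the `Arc` and takes the `Mutex` before touching the shared `HashMap`.

```rust
use std::collections::HashMap;
use std::sync::Arc;
use tokio::sync::Mutex;

// Hypothetical sketch (not balancebeam code): several tasks bump per-IP counters
// in one shared HashMap, which is why it sits behind Arc<Mutex<...>>.
#[tokio::main]
async fn main() {
    let counters: Arc<Mutex<HashMap<String, usize>>> = Arc::new(Mutex::new(HashMap::new()));

    let mut handles = Vec::new();
    for ip in ["1.2.3.4", "1.2.3.4", "5.6.7.8"] {
        let counters = Arc::clone(&counters); // one Arc clone per spawned task
        handles.push(tokio::spawn(async move {
            let mut map = counters.lock().await; // async mutex, like rate_limit_counter
            *map.entry(ip.to_string()).or_insert(0) += 1;
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
    let map = counters.lock().await;
    println!("{:?}", *map); // contains "1.2.3.4" -> 2, "5.6.7.8" -> 1
}
```
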
Problems encountered:

1. `HashMap` doesn't allow modifying a value's contents through indexing: that would require the `IndexMut` trait, which `HashMap` doesn't implement, so unlike `Vec` (which does implement it) you can't write `map[key] = newvalue` directly; instead you have to use the rather awkward `(*map.get_mut(&key).unwrap()).field = new_val_for_field` (see the sketch after this list).

2. Explicit conversions between primitive types can apparently be done directly with `as xx`; this seems to be built into the compiler, so it needs neither the standard library nor any other conversion machinery. All other conversions, though, have to go through traits like `From` and `Into` (also shown in the sketch below).

3. Finally, running `cargo test` directly can occasionally fail on the assertion `assert_eq!(total_request_count, rate_limit_threshold);`, even though that test never fails when run on its own. Judging from the log output, this is probably because the check requests sent by the health-check task implemented earlier also get counted toward the cumulative request count, which breaks the assertion. As for why this happens, it's likely that randomness (or a slow CPU) stretches the test out long enough for the health-check task to start running.

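To illustrate points 1 and 2, a minimal sketch (the `Counter` type, the keys, and the numbers are made up for illustration, not balancebeam code): `get_mut` or the `entry` API are the ways to mutate a `HashMap` value in place, and `as` casts contrast with `From`/`Into` conversions.

```rust
use std::collections::HashMap;

struct Counter {
    count: usize,
}

fn main() {
    let mut map: HashMap<String, Counter> = HashMap::new();
    map.insert("1.2.3.4".to_string(), Counter { count: 0 });

    // Point 1: HashMap does not implement IndexMut, so `map["1.2.3.4"].count = 1`
    // does not compile. Mutate through get_mut()...
    map.get_mut("1.2.3.4").unwrap().count = 1;
    // ...or use the entry API, which also replaces the contains_key()/insert() dance:
    map.entry("5.6.7.8".to_string()).or_insert(Counter { count: 0 }).count += 1;

    // Point 2: `as` casts between primitive types are built into the language...
    let elapsed_secs: u64 = 90;
    let ratio = elapsed_secs as f64 / 60.0;
    // ...while other conversions go through the From/Into traits.
    let wide: u64 = u64::from(42u32);
    let also_wide: u64 = 42u32.into();

    println!("{} {} {} {} {}", map["1.2.3.4"].count, map["5.6.7.8"].count, ratio, wide, also_wide);
}
```
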
## Extra tasks?

The assignment designers provide quite a few optional extensions, and they all look pretty nice, but I'm too lazy to write my own tests for them, so I'm skipping them, la la la.

@@ -4,8 +4,8 @@ mod response;
use clap::Parser;
use rand::{Rng, SeedableRng};
// use std::net::{TcpListener, TcpStream};
use std::sync::Arc;
use tokio::{net::{TcpListener, TcpStream}, io::AsyncWriteExt};
use std::{collections::HashMap, sync::Arc};
use tokio::net::{TcpListener, TcpStream};

/// Contains information parsed from the command-line invocation of balancebeam. The Clap macros
/// provide a fancy way to automatically construct a command-line argument parser.
@@ -33,24 +33,79 @@ struct CmdOptions {
    max_requests_per_minute: usize,
}

struct RateRecord {
    count: usize,
    pre_count: usize,
    max_count: usize,
    last_time: std::time::Instant,
}

impl RateRecord {
    const TIME_UNIT: u64 = 60;

    pub fn new(max_count: usize) -> RateRecord {
        RateRecord {
            count: 0,
            pre_count: 0,
            max_count,
            last_time: std::time::Instant::now(),
        }
    }

    pub fn update(&mut self) {
        // log::debug!("[Cur]{:?} [Last]{:?} [Diff]{:?} [Cnt]{}", cur_time, last_time, diff, rate_record[&client_ip].count);
        let cur_time = std::time::Instant::now();
        let diff_unit = cur_time.duration_since(self.last_time).as_secs() / Self::TIME_UNIT;
        if diff_unit > 0 {
            self.pre_count = if diff_unit == 1 {
                self.count
            } else {
                0 // if diff is greater than a unit, then there should be 0 reqs in the last unit
            };
            self.count = 0; // reset for the new unit; the increment below counts this request
            self.last_time += std::time::Duration::from_secs(diff_unit * Self::TIME_UNIT);
        }
        if self.count <= self.max_count {
            self.count += 1; // to avoid overflow if time unit is long enough
        }
    }

    #[allow(dead_code)]
    pub fn check(&self) -> bool {
        return self.count > self.max_count;
    }

    #[allow(dead_code)]
    pub fn check_sliding_window(&self) -> bool {
        let count_f64 = self.count as f64;
        let pre_count_f64 = self.pre_count as f64;
        let cur_time = std::time::Instant::now();
        let time_unit_f64 = Self::TIME_UNIT as f64;
        let pre_ratio =
            (time_unit_f64 - cur_time.duration_since(self.last_time).as_secs_f64()) / time_unit_f64;
        let final_count = pre_count_f64 * pre_ratio + count_f64;
        // log::debug!("pre_ratio:{pre_ratio}, pre_count:{pre_count_f64}, count:{count_f64} final_count:{final_count}");
        return final_count > self.max_count as f64;
    }
}

/// Contains information about the state of balancebeam (e.g. what servers we are currently proxying
/// to, what servers have failed, rate limiting counts, etc.)
///
/// You should add fields to this struct in later milestones.
struct ProxyState {
    /// How frequently we check whether upstream servers are alive (Milestone 4)
    #[allow(dead_code)]
    active_health_check_interval: usize,
    /// Where we should send requests when doing active health checks (Milestone 4)
    #[allow(dead_code)]
    active_health_check_path: String,
    /// Maximum number of requests an individual IP can make in a minute (Milestone 5)
    #[allow(dead_code)]
    max_requests_per_minute: usize,
    /// Addresses of servers that we are proxying to
    upstream_addresses: Vec<String>,
    /// Dead upstream marks
    upstream_status: Arc<tokio::sync::Mutex<Vec<usize>>>,
    /// Rate limiting counter
    rate_limit_counter: Arc<tokio::sync::Mutex<HashMap<String, RateRecord>>>,
}

async fn disable_upstream(state: &ProxyState, index: usize) {
@@ -103,10 +158,18 @@ fn active_health_check(state: Arc<ProxyState>) {
}
};
if response.status() != 200 {
log::info!("[HealthCheck] Get bad HTTP response {} from {}", response.status(), upstream);
log::info!(
"[HealthCheck] Get bad HTTP response {} from {}",
response.status(),
upstream
);
disable_upstream(&state, count - 1).await;
} else {
log::debug!("[HealthCheck] Get good HTTP response {} from {}", response.status(), upstream);
log::debug!(
"[HealthCheck] Get good HTTP response {} from {}",
response.status(),
upstream
);
let mut upstream_status = state.upstream_status.lock().await;
if let None = upstream_status.iter().position(|&r| r == count - 1) {
upstream_status.push(count - 1);
@@ -115,8 +178,7 @@ fn active_health_check(state: Arc<ProxyState>) {
}
}
}
}
);
});
}

#[tokio::main]
@@ -158,6 +220,7 @@ async fn main() -> std::io::Result<()> {
        active_health_check_path: options.active_health_check_path,
        max_requests_per_minute: options.max_requests_per_minute,
        upstream_status: Arc::new(tokio::sync::Mutex::new(upstream_status)),
        rate_limit_counter: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
    };
    let astate = Arc::new(state);
    // let _ = tokio::spawn(async move {
@@ -252,7 +315,6 @@ async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
        }
    };
    let upstream_ip = client_conn.peer_addr().unwrap().ip().to_string();

    // The client may now send us one or more requests. Keep trying to read requests until the
    // client hangs up or we get an error.
    loop {
@@ -289,6 +351,28 @@ async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
            upstream_ip,
            request::format_request_line(&request)
        );
        // rate limit stuff
        // Note: rate limit test may fail in the last assertion when running all tests
        // The reason is that health check requests may be counted as requests that escaped rate-limiting,
        // this never fails in unit test
        if state.max_requests_per_minute != 0 {
            let mut rate_record = state.rate_limit_counter.lock().await;
            if !rate_record.contains_key(&client_ip) {
                rate_record.insert(
                    client_ip.clone(),
                    RateRecord::new(state.max_requests_per_minute),
                );
            }
            (*rate_record.get_mut(&client_ip).unwrap()).update();
            if rate_record[&client_ip].check() {
                let response = response::make_http_error(http::StatusCode::TOO_MANY_REQUESTS);
                log::info!("Drop request due to rate limiting");
                drop(rate_record);
                send_response(&mut client_conn, &response).await;
                continue;
            }
            drop(rate_record);
        }

        // Add X-Forwarded-For header so that the upstream server knows the client's IP address.
        // (We're the ones connecting directly to the upstream server, so without this header, the