Milestone 5, Project 2 all done

parent 449145d796
commit 38afdcebe8

@@ -51,4 +51,19 @@

A gotcha: don't use the `connect_to_upstream` implemented earlier, because it cannot connect to an upstream that has already been removed from the available list (in fact, it never lets you pick the upstream yourself at all), whereas here we need to probe every upstream.

### Milestone 5

Implement request rate limiting: cap the number of requests each IP can make within a time window, and return `HTTP 429` once the count is exceeded. The handout mentions three algorithms: leaky bucket, fixed window, and sliding window. Leaky bucket doesn't really fit the requirements here, so it's skipped.

Both fixed window and sliding window are implemented here. The two are essentially the same idea; fixed window is a bit simpler: it just counts the requests in each time unit, drops them once the cap is reached, and resets the counter when the next unit starts. This causes a problem: if requests cluster around the moment the counter resets, the short-term request rate can double, actually exceeding the limit. Sliding window additionally takes the previous unit into account, folding the previous unit's request count into the current count at a certain ratio, which avoids the problem above. A simple implementation keeps the current unit's count and the previous unit's count, considers the window extending one unit back from the current moment, and scales the previous count by the fraction of that window that overlaps the previous unit. As a formula: $\frac{UNIT - (time.now - cur\_unit\_start)}{UNIT} \times prev\_cnt + cur\_cnt$.
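
To sanity-check the formula, here is a minimal sketch of the weighted count (the function and variable names are illustrative, not the actual balancebeam code):

```rust
// Sliding-window estimate: fold the previous unit's count into the current one,
// weighted by how much of the trailing one-unit window overlaps the previous unit.
fn sliding_window_count(prev_cnt: usize, cur_cnt: usize, secs_into_cur_unit: f64, unit: f64) -> f64 {
    let prev_ratio = (unit - secs_into_cur_unit) / unit;
    prev_cnt as f64 * prev_ratio + cur_cnt as f64
}

fn main() {
    // 60 s unit, 40 requests in the previous unit, 30 so far in this one,
    // and we are 15 s into the current unit: 45 s of the window overlap the
    // previous unit, so 3/4 of its count is folded in: 0.75 * 40 + 30 = 60.
    let estimate = sliding_window_count(40, 30, 15.0, 60.0);
    assert!((estimate - 60.0).abs() < f64::EPSILON);
    println!("estimated requests in the trailing 60 s: {estimate}");
}
```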

Concretely, a `HashMap<String, RateRecord>` maps each IP address to its request-count information. Because one IP can open multiple connections, the data has to live in the global `ProxyState`, wrapped in `Arc<Mutex<>>`. The `RateRecord` type encapsulates the data mentioned above, plus a few small helper functions that keep the handler code cleaner. None of this was particularly painful; the most painful part was discovering that `HashMap` doesn't support modifying values in place through indexing. More on the problems below.
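
For context, this is the general shape of that sharing pattern (a generic sketch with made-up names; the real fields appear in the `ProxyState` diff below, and it assumes tokio's `full` feature set):

```rust
use std::{collections::HashMap, sync::Arc};
use tokio::sync::Mutex;

// Stand-in for RateRecord; just a counter here.
#[derive(Default)]
struct Counter {
    count: usize,
}

#[tokio::main]
async fn main() {
    // One map shared by every connection task, hence Arc<Mutex<HashMap>>.
    let shared: Arc<Mutex<HashMap<String, Counter>>> = Arc::new(Mutex::new(HashMap::new()));
    let mut handles = Vec::new();
    for i in 0..4 {
        let shared = Arc::clone(&shared); // each task gets its own handle
        handles.push(tokio::spawn(async move {
            let ip = format!("10.0.0.{}", i % 2); // two fake client IPs
            let mut map = shared.lock().await; // async lock; doesn't block the runtime
            map.entry(ip).or_default().count += 1;
        }));
    }
    for handle in handles {
        handle.await.unwrap();
    }
    println!("{} distinct IPs tracked", shared.lock().await.len());
}
```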

Problems encountered:

1. `HashMap` doesn't let you modify a `value` in place through indexing; that would require the `IndexMut` trait, which `HashMap` doesn't implement. So unlike `Vec`, which does implement it and allows `map[key] = newvalue` directly, you have to write the rather awkward `(*map.get_mut(&key).unwrap()).field = new_val_for_field` (see the sketch after this list).

2. Explicit conversions between primitive types can apparently be done with a plain `as xx`; this seems to be built into the compiler, so it needs neither the standard library nor any other conversion machinery. Everything else has to go through traits like `From` and `Into` (also shown in the sketch below).

3. Finally, running `cargo test` directly can occasionally fail on the assertion `assert_eq!(total_request_count, rate_limit_threshold);`, even though testing that one case alone never fails. Judging from the log output, this is probably because the probe requests sent by the health-check thread from the earlier milestone also get counted toward the accumulated request total, which blows the assertion up. As for why this happens at all, randomness (or CPU performance) may stretch the test out long enough for the health-check thread to start running.
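
A minimal sketch of points 1 and 2 (the `Record` type and the values are made up for illustration):

```rust
use std::collections::HashMap;

struct Record {
    count: usize,
}

fn main() {
    let mut map: HashMap<String, Record> = HashMap::new();
    map.insert("127.0.0.1".to_string(), Record { count: 0 });

    // Point 1: `map["127.0.0.1"].count = 1;` does not compile, because HashMap
    // implements Index (read-only) but not IndexMut. Mutate through get_mut:
    map.get_mut("127.0.0.1").unwrap().count = 1;
    // The entry API is a tidier alternative that also handles missing keys:
    map.entry("10.0.0.1".to_string()).or_insert(Record { count: 0 }).count += 1;

    // Point 2: primitive-to-primitive casts use the built-in `as`;
    // other conversions go through the From/Into traits.
    let unit: u64 = 60;
    let unit_f64 = unit as f64; // compiler built-in, no trait involved
    let wide: u64 = u64::from(42u32); // From, for lossless conversions
    let also_wide: u64 = 42u32.into(); // Into, derived from the From impl
    println!("{unit_f64} {wide} {also_wide} {}", map["127.0.0.1"].count);
}
```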

## Optional tasks?

The assignment authors offer quite a few optional extensions, and they all look pretty nice, but I'm too lazy to write my own tests for them, so I'm skipping them, la la la.

@@ -4,8 +4,8 @@ mod response;
 use clap::Parser;
 use rand::{Rng, SeedableRng};
 // use std::net::{TcpListener, TcpStream};
-use std::sync::Arc;
-use tokio::{net::{TcpListener, TcpStream}, io::AsyncWriteExt};
+use std::{collections::HashMap, sync::Arc};
+use tokio::net::{TcpListener, TcpStream};
 
 /// Contains information parsed from the command-line invocation of balancebeam. The Clap macros
 /// provide a fancy way to automatically construct a command-line argument parser.
@@ -33,24 +33,79 @@ struct CmdOptions {
     max_requests_per_minute: usize,
 }
 
+struct RateRecord {
+    count: usize,
+    pre_count: usize,
+    max_count: usize,
+    last_time: std::time::Instant,
+}
+
+impl RateRecord {
+    const TIME_UNIT: u64 = 60;
+
+    pub fn new(max_count: usize) -> RateRecord {
+        RateRecord {
+            count: 0,
+            pre_count: 0,
+            max_count,
+            last_time: std::time::Instant::now(),
+        }
+    }
+
+    pub fn update(&mut self) {
+        // log::debug!("[Cur]{:?} [Last]{:?} [Diff]{:?} [Cnt]{}", cur_time, last_time, diff, rate_record[&client_ip].count);
+        let cur_time = std::time::Instant::now();
+        let diff_unit = cur_time.duration_since(self.last_time).as_secs() / Self::TIME_UNIT;
+        if diff_unit > 0 {
+            self.pre_count = if diff_unit == 1 {
+                self.count
+            } else {
+                0 // if diff is greater than a unit, then there should be 0 reqs in the last unit
+            };
+            self.count = 1;
+            self.last_time += std::time::Duration::from_secs(diff_unit * Self::TIME_UNIT);
+        }
+        if self.count <= self.max_count {
+            self.count += 1; // to avoid overflow if time unit is long enough
+        }
+    }
+
+    #[allow(dead_code)]
+    pub fn check(&self) -> bool {
+        return self.count > self.max_count;
+    }
+
+    #[allow(dead_code)]
+    pub fn check_sliding_window(&self) -> bool {
+        let count_f64 = self.count as f64;
+        let pre_count_f64 = self.pre_count as f64;
+        let cur_time = std::time::Instant::now();
+        let time_unit_f64 = Self::TIME_UNIT as f64;
+        let pre_ratio =
+            (time_unit_f64 - cur_time.duration_since(self.last_time).as_secs_f64()) / time_unit_f64;
+        let final_count = pre_count_f64 * pre_ratio + count_f64;
+        // log::debug!("pre_ratio:{pre_ratio}, pre_count:{pre_count_f64}, count:{count_f64} final_count:{final_count}");
+        return final_count > self.max_count as f64;
+    }
+}
+
 /// Contains information about the state of balancebeam (e.g. what servers we are currently proxying
 /// to, what servers have failed, rate limiting counts, etc.)
 ///
 /// You should add fields to this struct in later milestones.
 struct ProxyState {
     /// How frequently we check whether upstream servers are alive (Milestone 4)
-    #[allow(dead_code)]
     active_health_check_interval: usize,
     /// Where we should send requests when doing active health checks (Milestone 4)
-    #[allow(dead_code)]
     active_health_check_path: String,
     /// Maximum number of requests an individual IP can make in a minute (Milestone 5)
-    #[allow(dead_code)]
     max_requests_per_minute: usize,
     /// Addresses of servers that we are proxying to
     upstream_addresses: Vec<String>,
     /// Dead upstream marks
     upstream_status: Arc<tokio::sync::Mutex<Vec<usize>>>,
+    /// Rate limiting counter
+    rate_limit_counter: Arc<tokio::sync::Mutex<HashMap<String, RateRecord>>>,
 }
 
 async fn disable_upstream(state: &ProxyState, index: usize) {
@@ -64,7 +119,7 @@ async fn disable_upstream(state: &ProxyState, index: usize) {
 }
 
 fn active_health_check(state: Arc<ProxyState>) {
-    let _ = tokio::spawn( async move {
+    let _ = tokio::spawn(async move {
         let active_health_check_interval = state.active_health_check_interval;
         let active_health_check_path = &state.active_health_check_path;
         loop {
@@ -103,10 +158,18 @@ fn active_health_check(state: Arc<ProxyState>) {
                 }
             };
             if response.status() != 200 {
-                log::info!("[HealthCheck] Get bad HTTP response {} from {}", response.status(), upstream);
+                log::info!(
+                    "[HealthCheck] Get bad HTTP response {} from {}",
+                    response.status(),
+                    upstream
+                );
                 disable_upstream(&state, count - 1).await;
             } else {
-                log::debug!("[HealthCheck] Get good HTTP response {} from {}", response.status(), upstream);
+                log::debug!(
+                    "[HealthCheck] Get good HTTP response {} from {}",
+                    response.status(),
+                    upstream
+                );
                 let mut upstream_status = state.upstream_status.lock().await;
                 if let None = upstream_status.iter().position(|&r| r == count - 1) {
                     upstream_status.push(count - 1);
@@ -115,8 +178,7 @@ fn active_health_check(state: Arc<ProxyState>) {
                 }
             }
         }
-        }
-    );
+    });
 }
 
 #[tokio::main]
@@ -158,6 +220,7 @@ async fn main() -> std::io::Result<()> {
         active_health_check_path: options.active_health_check_path,
         max_requests_per_minute: options.max_requests_per_minute,
         upstream_status: Arc::new(tokio::sync::Mutex::new(upstream_status)),
+        rate_limit_counter: Arc::new(tokio::sync::Mutex::new(HashMap::new())),
     };
     let astate = Arc::new(state);
     // let _ = tokio::spawn(async move {
@@ -252,7 +315,6 @@ async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
         }
     };
     let upstream_ip = client_conn.peer_addr().unwrap().ip().to_string();
-
    // The client may now send us one or more requests. Keep trying to read requests until the
    // client hangs up or we get an error.
    loop {
@@ -289,6 +351,28 @@ async fn handle_connection(mut client_conn: TcpStream, state: &ProxyState) {
             upstream_ip,
             request::format_request_line(&request)
         );
+        // rate limit stuff
+        // Note: rate limit test may fail in the last assertion when running all tests.
+        // The reason is that health check requests may be counted as requests that escaped rate-limiting,
+        // this never fails in unit test
+        if state.max_requests_per_minute != 0 {
+            let mut rate_record = state.rate_limit_counter.lock().await;
+            if !rate_record.contains_key(&client_ip) {
+                rate_record.insert(
+                    client_ip.clone(),
+                    RateRecord::new(state.max_requests_per_minute),
+                );
+            }
+            (*rate_record.get_mut(&client_ip).unwrap()).update();
+            if rate_record[&client_ip].check() {
+                let response = response::make_http_error(http::StatusCode::TOO_MANY_REQUESTS);
+                log::info!("Drop request due to rate limiting");
+                drop(rate_record);
+                send_response(&mut client_conn, &response).await;
+                continue;
+            }
+            drop(rate_record);
+        }
 
         // Add X-Forwarded-For header so that the upstream server knows the client's IP address.
         // (We're the ones connecting directly to the upstream server, so without this header, the