diff --git a/src/controller/middleware/anti_enumeration_middleware.rs b/src/controller/middleware/anti_enumeration_middleware.rs new file mode 100644 index 0000000..7d32f4b --- /dev/null +++ b/src/controller/middleware/anti_enumeration_middleware.rs @@ -0,0 +1,19 @@ +use axum::{extract::Request, middleware::Next, response::Response}; +use rand::RngExt; +use std::time::{Duration, Instant}; +use tokio::time::sleep; + +const MIN_DELAY_MS: u64 = 150; +const MAX_DELAY_MS: u64 = 500; + +pub async fn random_delay_middleware(request: Request, next: Next) -> Response { + let start = Instant::now(); + let target = Duration::from_millis(rand::rng().random_range(MIN_DELAY_MS..=MAX_DELAY_MS)); + + let response = next.run(request).await; + let elapsed = start.elapsed(); + if elapsed < target { + sleep(target - elapsed).await; + } + response +} diff --git a/src/controller/middleware/mod.rs b/src/controller/middleware/mod.rs new file mode 100644 index 0000000..a7523a9 --- /dev/null +++ b/src/controller/middleware/mod.rs @@ -0,0 +1,2 @@ +pub mod anti_enumeration_middleware; +pub mod rate_limiting_middleware; diff --git a/src/controller/middleware/rate_limiting_middleware.rs b/src/controller/middleware/rate_limiting_middleware.rs new file mode 100644 index 0000000..83ac2a3 --- /dev/null +++ b/src/controller/middleware/rate_limiting_middleware.rs @@ -0,0 +1,41 @@ +use crate::state::AppState; +use axum::{ + extract::{ConnectInfo, Request, State}, + http::StatusCode, + middleware::Next, + response::{IntoResponse, Response}, +}; +use std::net::SocketAddr; + +pub async fn rate_limiting_middleware( + State(state): State, + request: Request, + next: Next, +) -> Response { + let client_ip = request + .headers() + .get("x-client-ip") + .and_then(|h| h.to_str().ok()) + .map(|s| s.to_string()) + .unwrap_or_else(|| { + request + .extensions() + .get::>() + .map(|ci| ci.0.ip().to_string()) + .unwrap_or_else(|| "unknown".to_string()) + }); + + let has_tokens = { + let mut entry = state + .rate_limit + .entry(client_ip) + .or_insert_with(|| crate::state::TokenBucket::new()); + entry.value_mut().take() + }; + + if has_tokens { + next.run(request).await + } else { + StatusCode::TOO_MANY_REQUESTS.into_response() + } +} diff --git a/src/controller/mod.rs b/src/controller/mod.rs index f5aea4b..e76ac26 100644 --- a/src/controller/mod.rs +++ b/src/controller/mod.rs @@ -3,6 +3,7 @@ use tower_http::trace::TraceLayer; use crate::state::AppState; +mod middleware; pub mod model; mod v1; diff --git a/src/controller/v1/auth_controller.rs b/src/controller/v1/auth_controller.rs index bc4eaa3..ff950a2 100644 --- a/src/controller/v1/auth_controller.rs +++ b/src/controller/v1/auth_controller.rs @@ -1,8 +1,11 @@ use axum::extract::State; +use axum::middleware::{from_fn, from_fn_with_state}; use axum::{Json, Router, routing::post}; use tower_cookies::{CookieManagerLayer, Cookies}; use crate::{ + controller::middleware::anti_enumeration_middleware::random_delay_middleware, + controller::middleware::rate_limiting_middleware::rate_limiting_middleware, controller::model::auth_model::{AuthResponse, LoginRequest, RegisterRequest}, errors::AppError, service::auth_service::{login, refresh, register}, @@ -15,6 +18,8 @@ pub fn auth_router(state: AppState) -> Router { .route("/register", post(register_handler)) .route("/refresh", post(refresh_handler)) .route("/logout", post(logout_handler)) + .layer(from_fn(random_delay_middleware)) + .layer(from_fn_with_state(state.clone(), rate_limiting_middleware)) .layer(CookieManagerLayer::new()) } diff --git a/src/server.rs b/src/server.rs index afb2939..085f243 100644 --- a/src/server.rs +++ b/src/server.rs @@ -16,7 +16,10 @@ pub async fn init(cfg: &config::Config, db: PgPool) -> Result<(), AppError> { .map_err(AppError::Bind)?; tracing::info!("Server started on {}", cfg.socket_address); - axum::serve(listener, app) + axum::serve( + listener, + app.into_make_service_with_connect_info::(), + ) .with_graceful_shutdown(logging::shutdown_signal()) .await .map_err(AppError::Bind)?; diff --git a/src/state.rs b/src/state.rs index 2741bca..ed9cf6d 100644 --- a/src/state.rs +++ b/src/state.rs @@ -1,7 +1,10 @@ -use std::sync::Arc; -use std::time::Instant; use dashmap::DashMap; use sqlx::PgPool; +use std::sync::Arc; +use std::time::Instant; + +const REFILL_RATE: f64 = 1.0; +const MAX_TOKENS: f64 = 10.0; #[derive(Clone)] pub struct AppState { @@ -16,10 +19,25 @@ pub struct TokenBucket { } impl TokenBucket { - pub fn new(max_tokens: f64) -> Self { + pub fn new() -> Self { Self { - tokens: max_tokens, + tokens: MAX_TOKENS, last_refill: Instant::now(), } } + + pub fn take(&mut self) -> bool { + let now = Instant::now(); + let elapsed = now.duration_since(self.last_refill).as_secs_f64(); + + self.tokens = (self.tokens + elapsed * REFILL_RATE).min(MAX_TOKENS); + self.last_refill = now; + + if self.tokens >= 1.0 { + self.tokens -= 1.0; + true + } else { + false + } + } } diff --git a/src/utils/anti_enumeration.rs b/src/utils/anti_enumeration.rs deleted file mode 100644 index 2edf595..0000000 --- a/src/utils/anti_enumeration.rs +++ /dev/null @@ -1,21 +0,0 @@ -use std::time::{Duration, Instant}; - -use rand::RngExt; -use tokio::time::sleep; - -/// Anti-enumeration: ensures consistent response timing regardless of outcome. -/// Call at the end of request handler, before returning. -/// Range: 150-350ms (configurable) -/// -/// # Arguments -/// * `start` - The Instant when request processing began -/// * `min_ms` - Minimum delay in milliseconds (default: 150) -/// * `max_ms` - Maximum delay in milliseconds (default: 350) -pub async fn anti_enumeration_delay(start: Instant, min_ms: u64, max_ms: u64) { - let target = min_ms + rand::rng().random::() % (max_ms - min_ms); - let target_duration = Duration::from_millis(target); - - if let Some(remaining) = target_duration.checked_sub(start.elapsed()) { - sleep(remaining).await; - } -} diff --git a/src/utils/mod.rs b/src/utils/mod.rs index 4ce163b..a346033 100644 --- a/src/utils/mod.rs +++ b/src/utils/mod.rs @@ -1,4 +1,3 @@ -pub mod anti_enumeration; pub mod hash; pub mod jwt; pub mod refresh_token;