rate limiting and anti enumeration
Some checks failed
Build and Push Docker Image / build-and-push (push) Has been cancelled

This commit is contained in:
Dmitri 2026-05-03 12:10:21 +02:00
parent bc7866b4fb
commit 308639e418
Signed by: kanopo
GPG Key ID: 759ADD40E3132AC7
9 changed files with 94 additions and 27 deletions

View File

@ -0,0 +1,19 @@
use axum::{extract::Request, middleware::Next, response::Response};
use rand::RngExt;
use std::time::{Duration, Instant};
use tokio::time::sleep;
const MIN_DELAY_MS: u64 = 150;
const MAX_DELAY_MS: u64 = 500;
pub async fn random_delay_middleware(request: Request, next: Next) -> Response {
let start = Instant::now();
let target = Duration::from_millis(rand::rng().random_range(MIN_DELAY_MS..=MAX_DELAY_MS));
let response = next.run(request).await;
let elapsed = start.elapsed();
if elapsed < target {
sleep(target - elapsed).await;
}
response
}

View File

@ -0,0 +1,2 @@
pub mod anti_enumeration_middleware;
pub mod rate_limiting_middleware;

View File

@ -0,0 +1,41 @@
use crate::state::AppState;
use axum::{
extract::{ConnectInfo, Request, State},
http::StatusCode,
middleware::Next,
response::{IntoResponse, Response},
};
use std::net::SocketAddr;
pub async fn rate_limiting_middleware(
State(state): State<AppState>,
request: Request,
next: Next,
) -> Response {
let client_ip = request
.headers()
.get("x-client-ip")
.and_then(|h| h.to_str().ok())
.map(|s| s.to_string())
.unwrap_or_else(|| {
request
.extensions()
.get::<ConnectInfo<SocketAddr>>()
.map(|ci| ci.0.ip().to_string())
.unwrap_or_else(|| "unknown".to_string())
});
let has_tokens = {
let mut entry = state
.rate_limit
.entry(client_ip)
.or_insert_with(|| crate::state::TokenBucket::new());
entry.value_mut().take()
};
if has_tokens {
next.run(request).await
} else {
StatusCode::TOO_MANY_REQUESTS.into_response()
}
}

View File

@ -3,6 +3,7 @@ use tower_http::trace::TraceLayer;
use crate::state::AppState; use crate::state::AppState;
mod middleware;
pub mod model; pub mod model;
mod v1; mod v1;

View File

@ -1,8 +1,11 @@
use axum::extract::State; use axum::extract::State;
use axum::middleware::{from_fn, from_fn_with_state};
use axum::{Json, Router, routing::post}; use axum::{Json, Router, routing::post};
use tower_cookies::{CookieManagerLayer, Cookies}; use tower_cookies::{CookieManagerLayer, Cookies};
use crate::{ use crate::{
controller::middleware::anti_enumeration_middleware::random_delay_middleware,
controller::middleware::rate_limiting_middleware::rate_limiting_middleware,
controller::model::auth_model::{AuthResponse, LoginRequest, RegisterRequest}, controller::model::auth_model::{AuthResponse, LoginRequest, RegisterRequest},
errors::AppError, errors::AppError,
service::auth_service::{login, refresh, register}, service::auth_service::{login, refresh, register},
@ -15,6 +18,8 @@ pub fn auth_router(state: AppState) -> Router<AppState> {
.route("/register", post(register_handler)) .route("/register", post(register_handler))
.route("/refresh", post(refresh_handler)) .route("/refresh", post(refresh_handler))
.route("/logout", post(logout_handler)) .route("/logout", post(logout_handler))
.layer(from_fn(random_delay_middleware))
.layer(from_fn_with_state(state.clone(), rate_limiting_middleware))
.layer(CookieManagerLayer::new()) .layer(CookieManagerLayer::new())
} }

View File

@ -16,7 +16,10 @@ pub async fn init(cfg: &config::Config, db: PgPool) -> Result<(), AppError> {
.map_err(AppError::Bind)?; .map_err(AppError::Bind)?;
tracing::info!("Server started on {}", cfg.socket_address); tracing::info!("Server started on {}", cfg.socket_address);
axum::serve(listener, app) axum::serve(
listener,
app.into_make_service_with_connect_info::<std::net::SocketAddr>(),
)
.with_graceful_shutdown(logging::shutdown_signal()) .with_graceful_shutdown(logging::shutdown_signal())
.await .await
.map_err(AppError::Bind)?; .map_err(AppError::Bind)?;

View File

@ -1,7 +1,10 @@
use std::sync::Arc;
use std::time::Instant;
use dashmap::DashMap; use dashmap::DashMap;
use sqlx::PgPool; use sqlx::PgPool;
use std::sync::Arc;
use std::time::Instant;
const REFILL_RATE: f64 = 1.0;
const MAX_TOKENS: f64 = 10.0;
#[derive(Clone)] #[derive(Clone)]
pub struct AppState { pub struct AppState {
@ -16,10 +19,25 @@ pub struct TokenBucket {
} }
impl TokenBucket { impl TokenBucket {
pub fn new(max_tokens: f64) -> Self { pub fn new() -> Self {
Self { Self {
tokens: max_tokens, tokens: MAX_TOKENS,
last_refill: Instant::now(), last_refill: Instant::now(),
} }
} }
pub fn take(&mut self) -> bool {
let now = Instant::now();
let elapsed = now.duration_since(self.last_refill).as_secs_f64();
self.tokens = (self.tokens + elapsed * REFILL_RATE).min(MAX_TOKENS);
self.last_refill = now;
if self.tokens >= 1.0 {
self.tokens -= 1.0;
true
} else {
false
}
}
} }

View File

@ -1,21 +0,0 @@
use std::time::{Duration, Instant};
use rand::RngExt;
use tokio::time::sleep;
/// Anti-enumeration: ensures consistent response timing regardless of outcome.
/// Call at the end of request handler, before returning.
/// Range: 150-350ms (configurable)
///
/// # Arguments
/// * `start` - The Instant when request processing began
/// * `min_ms` - Minimum delay in milliseconds (default: 150)
/// * `max_ms` - Maximum delay in milliseconds (default: 350)
pub async fn anti_enumeration_delay(start: Instant, min_ms: u64, max_ms: u64) {
let target = min_ms + rand::rng().random::<u64>() % (max_ms - min_ms);
let target_duration = Duration::from_millis(target);
if let Some(remaining) = target_duration.checked_sub(start.elapsed()) {
sleep(remaining).await;
}
}

View File

@ -1,4 +1,3 @@
pub mod anti_enumeration;
pub mod hash; pub mod hash;
pub mod jwt; pub mod jwt;
pub mod refresh_token; pub mod refresh_token;