use axum::{ Extension, Json, Router, http::{ Method, StatusCode, header::{self, CONTENT_TYPE}, }, routing::get, // middleware, }; use axum::{body::Body, extract::Request, middleware::Next, response::Response}; use clap::Parser; use serde_json::json; use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder}; use tower_http::{ cors::{Any, CorsLayer}, trace::{DefaultMakeSpan, DefaultOnResponse, TraceLayer}, }; use tracing::Level; use tracing::info; mod auth; mod db; mod realtime; mod routes; pub struct AppConfig { pub avatar_dir: PathBuf, pub prohibit_registration: bool, } #[derive(clap::Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct Cli { /// Server port #[arg(short, long, default_value = "8080")] port: String, /// Database URL #[arg(short, long, default_value = "0.0.0.0:5432")] database: String, /// Data directory path #[arg(short = 'D', long, default_value = "/var/lib/frangipane")] data_dir: String, /// Whether to disable user registration #[arg(short, long)] no_registration: bool, /// Verbose mode #[arg(short, long)] verbose: bool, } #[tokio::main] async fn main() -> anyhow::Result<()> { let cli = Cli::parse(); tracing_subscriber::fmt() .with_max_level(tracing::Level::INFO) .init(); tracing::info!("Connecting to database..."); let db_pool = db::init_db(cli.database).await?; let cors = CorsLayer::new() .allow_origin(Any) .allow_methods([Method::GET, Method::POST]) .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE]); let governor_conf = GovernorConfigBuilder::default() .burst_size(20) .per_millisecond(250) .finish() .unwrap(); let governor_limiter = governor_conf.limiter().clone(); // a separate background task to clean up let interval = Duration::from_secs(60); std::thread::spawn(move || { loop { std::thread::sleep(interval); // tracing::info!("rate limiting storage size: {}", governor_limiter.len()); governor_limiter.retain_recent(); } }); let realtime = realtime::Realtime::new(); let data_dir = PathBuf::from(cli.data_dir); let config = Arc::new(AppConfig { avatar_dir: data_dir.join("avatars"), prohibit_registration: cli.no_registration, }); let mut app = Router::new() .route("/version", get(get_version)) .merge(routes::users::routes()) .merge(routes::rooms::routes()) .merge(routes::messages::routes()) .merge(routes::friends::routes()) .merge(routes::ws::routes()) .layer(Extension(db_pool)) .layer(Extension(realtime)) .layer(Extension(config)) .layer(GovernorLayer::new(governor_conf)) .layer(cors); if cli.verbose { app = app.layer( TraceLayer::new_for_http() .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) .on_response(DefaultOnResponse::new().level(Level::INFO)), ) // .layer(middleware::from_fn(log_json_body)); } let port = cli.port; let addr = format!("0.0.0.0:{port}"); let listener = tokio::net::TcpListener::bind(&addr).await.unwrap(); tracing::info!("Listening on {addr}"); axum::serve( listener, app.into_make_service_with_connect_info::(), ) .await .unwrap(); Ok(()) } async fn _log_json_body(req: Request, next: Next) -> Response { let (parts, body) = req.into_parts(); // Check if the content type is JSON let is_json = parts .headers .get(CONTENT_TYPE) .and_then(|v| v.to_str().ok()) .map_or(false, |v| v.contains("application/json")); let bytes = if is_json { // Read the body bytes let bytes = axum::body::to_bytes(body, usize::MAX) .await .unwrap_or_default(); // Log the body (converting to string) if let Ok(body_str) = std::str::from_utf8(&bytes) { info!("JSON Request Body: {}", body_str); } bytes } else { // If not JSON, we still need to collect it or just pass it through axum::body::to_bytes(body, usize::MAX) .await .unwrap_or_default() }; // Reconstruct the request with the bytes we read let req = Request::from_parts(parts, Body::from(bytes)); next.run(req).await } // Public route to get current version async fn get_version() -> Result, (StatusCode, String)> { const VERSION: &str = env!("CARGO_PKG_VERSION"); Ok(Json(json!({ "version": VERSION }))) }