diff --git a/db/init.sql b/db/init.sql index e8bacf5..073e72b 100644 --- a/db/init.sql +++ b/db/init.sql @@ -3,6 +3,7 @@ CREATE TABLE IF NOT EXISTS user_ ( uuid UUID UNIQUE, email TEXT UNIQUE, username TEXT NOT NULL UNIQUE, + avatar_url TEXT, password_hash TEXT NOT NULL ); diff --git a/src/main.rs b/src/main.rs index cfd173a..8e9b90f 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,14 +1,15 @@ use axum::{ - Extension, Router, + Extension, + Router, http::{ Method, header::{self, CONTENT_TYPE}, }, - middleware, + // middleware, }; use axum::{body::Body, extract::Request, middleware::Next, response::Response}; use clap::Parser; -use std::{net::SocketAddr, time::Duration}; +use std::{net::SocketAddr, path::PathBuf, sync::Arc, time::Duration}; use tower_governor::{GovernorLayer, governor::GovernorConfigBuilder}; use tower_http::{ cors::{Any, CorsLayer}, @@ -22,6 +23,10 @@ mod db; mod realtime; mod routes; +pub struct AppConfig { + pub avatar_dir: PathBuf, +} + #[derive(clap::Parser, Debug)] #[command(author, version, about, long_about = None)] pub struct Cli { @@ -33,6 +38,10 @@ pub struct Cli { #[arg(short, long, default_value = "localhost:5432")] database: String, + /// Data directory path + #[arg(short = 'D', long, default_value = "/var/lib/chatapp")] + data_dir: String, + /// Verbose mode #[arg(short, long)] verbose: bool, @@ -74,6 +83,11 @@ async fn main() -> anyhow::Result<()> { let realtime = realtime::Realtime::new(); + let data_dir = PathBuf::from(cli.data_dir); + let config = Arc::new(AppConfig { + avatar_dir: data_dir.join("avatars"), + }); + let mut app = Router::new() .merge(routes::users::routes()) .merge(routes::rooms::routes()) @@ -82,17 +96,17 @@ async fn main() -> anyhow::Result<()> { .merge(routes::ws::routes()) .layer(Extension(db_pool)) .layer(Extension(realtime)) + .layer(Extension(config)) .layer(GovernorLayer::new(governor_conf)) .layer(cors); if cli.verbose { - app = app - .layer( - TraceLayer::new_for_http() - .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) - .on_response(DefaultOnResponse::new().level(Level::INFO)), - ) - .layer(middleware::from_fn(log_json_body)); + app = app.layer( + TraceLayer::new_for_http() + .make_span_with(DefaultMakeSpan::new().level(Level::INFO)) + .on_response(DefaultOnResponse::new().level(Level::INFO)), + ) + // .layer(middleware::from_fn(log_json_body)); } let port = cli.port; @@ -111,7 +125,7 @@ async fn main() -> anyhow::Result<()> { Ok(()) } -async fn log_json_body(req: Request, next: Next) -> Response { +async fn _log_json_body(req: Request, next: Next) -> Response { let (parts, body) = req.into_parts(); // Check if the content type is JSON diff --git a/src/routes/users.rs b/src/routes/users.rs index 2ff7324..44afc1d 100644 --- a/src/routes/users.rs +++ b/src/routes/users.rs @@ -1,17 +1,18 @@ use axum::{ Extension, Json, Router, - extract::Request, + extract::{Path, Request}, http::{HeaderMap, StatusCode}, middleware::Next, response::Response, routing::{get, post, put}, }; use sqlx::PgPool; -use std::env; +use std::{env, sync::Arc}; use uuid::Uuid; use validator::ValidateEmail; use crate::{ + AppConfig, auth::{create_jwt, hash_password, validate_token, verify_jwt, verify_password}, db::{user_id_from_uuid, username_from_uuid}, }; @@ -65,7 +66,9 @@ pub fn routes() -> Router { .route("/login", post(login)) .route("/register", post(register_user)) .route("/validate-token", get(validate_token)) - .route("/account", put(update_user)) + .route("/account/settings", put(update_user)) + .route("/account/upload-avatar", post(upload_avatar)) + .route("/account/get-avatar/{uuid}", get(get_avatar)) .layer(axum::middleware::from_fn(registration_guard)) } @@ -264,3 +267,75 @@ pub async fn update_user( }), )) } + +async fn upload_avatar( + headers: HeaderMap, + Extension(db): Extension, + Extension(config): Extension>, + body: axum::body::Bytes, +) -> Result { + let claims = verify_jwt(headers)?; + + let user_id = user_id_from_uuid(&db, claims.sub).await?; + tracing::info!("User ID {} is uploading {} bytes)", user_id, body.len()); + + let base_dir = &config.avatar_dir; + let file_extension = "png"; // TODO: detect MIME type + let filename = format!("{}.{}", claims.sub, file_extension); + let full_path = std::path::Path::new(&base_dir).join(&filename); + + tokio::fs::create_dir_all(&base_dir).await.map_err(|e| { + tracing::error!("Failed to create storage: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to upload file".into(), + ) + })?; + + tokio::fs::write(&full_path, body).await.map_err(|e| { + tracing::error!("Failed to save file: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Failed to upload file".into(), + ) + })?; + + sqlx::query("UPDATE user_ SET avatar_url = $1 WHERE id = $2") + .bind(filename) + .bind(user_id) + .execute(&db) + .await + .map_err(|e| { + tracing::error!("DB error: {}", e); + (StatusCode::INTERNAL_SERVER_ERROR, "DB erorr".into()) + })?; + + Ok(StatusCode::OK) +} + +// Public route +async fn get_avatar( + Path(uuid): Path, + Extension(config): Extension>, +) -> Result { + let base_dir = &config.avatar_dir; + let filename = format!("{}.png", uuid); + let full_path = std::path::Path::new(&base_dir).join(filename); + + if !full_path.exists() { + return Err((StatusCode::NOT_FOUND, "Avatar not found".into())); + } + + let file_contents = tokio::fs::read(&full_path).await.map_err(|e| { + tracing::error!("Could not read avatar file: {}", e); + ( + StatusCode::INTERNAL_SERVER_ERROR, + "Could not read file".into(), + ) + })?; + + Ok(Response::builder() + .header("Content-Type", "image/png") + .body(axum::body::Body::from(file_contents)) + .unwrap()) +}