diff --git a/Cargo.lock b/Cargo.lock index 982aa58..462d7a5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -83,6 +83,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5b098575ebe77cb6d14fc7f32749631a6e44edbef6b796f89b020e99ba20d425" dependencies = [ "axum-core", + "base64", "bytes", "form_urlencoded", "futures-util", @@ -102,8 +103,10 @@ dependencies = [ "serde_json", "serde_path_to_error", "serde_urlencoded", + "sha1", "sync_wrapper", "tokio", + "tokio-tungstenite", "tower", "tower-layer", "tower-service", @@ -210,6 +213,7 @@ dependencies = [ "argon2", "axum", "chrono", + "dashmap", "jsonwebtoken", "password-hash", "serde", @@ -332,6 +336,12 @@ dependencies = [ "parking_lot_core", ] +[[package]] +name = "data-encoding" +version = "2.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2a2330da5de22e8a3cb63252ce2abb30116bf5265e89c0e01bc17015ce30a476" + [[package]] name = "der" version = "0.7.10" @@ -2266,6 +2276,18 @@ dependencies = [ "tokio", ] +[[package]] +name = "tokio-tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d25a406cddcc431a75d3d9afc6a7c0f7428d4891dd973e4d54c56b46127bf857" +dependencies = [ + "futures-util", + "log", + "tokio", + "tungstenite", +] + [[package]] name = "tokio-util" version = "0.7.17" @@ -2436,6 +2458,23 @@ version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e421abadd41a4225275504ea4d6566923418b7f05506fbc9c0fe86ba7396114b" +[[package]] +name = "tungstenite" +version = "0.28.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8628dcc84e5a09eb3d8423d6cb682965dea9133204e8fb3efee74c2a0c259442" +dependencies = [ + "bytes", + "data-encoding", + "http", + "httparse", + "log", + "rand 0.9.2", + "sha1", + "thiserror 2.0.17", + "utf-8", +] + [[package]] name = "typenum" version = "1.19.0" @@ -2487,6 +2526,12 @@ dependencies = [ "serde", ] +[[package]] +name = "utf-8" +version = "0.7.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09cc8ee72d2a9becf2f2febe0205bbed8fc6615b7cb429ad062dc7b7ddd036a9" + [[package]] name = "utf8_iter" version = "1.0.4" diff --git a/Cargo.toml b/Cargo.toml index 1fcdb4b..e80324d 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -6,8 +6,9 @@ edition = "2024" [dependencies] anyhow = "1.0.99" argon2 = "0.5.3" -axum = { version = "0.8.4", features = ["multipart"] } +axum = { version = "0.8.4", features = ["multipart", "ws"] } chrono = { version = "0.4.42", features = ["serde"] } +dashmap = "6.1.0" jsonwebtoken = "9.3.1" password-hash = "0.5.0" serde = { version = "1.0.219", features = ["derive"] } diff --git a/db/init.sql b/db/init.sql index 838808a..c3a97f4 100644 --- a/db/init.sql +++ b/db/init.sql @@ -28,6 +28,12 @@ CREATE TABLE IF NOT EXISTS message_ ( sent_at TIMESTAMP ); +CREATE TABLE ws_token_ ( + token TEXT PRIMARY KEY, + room_id INT NOT NULL, + expires_at TIMESTAMPTZ NOT NULL +); + -- Message timestamp creation CREATE OR REPLACE FUNCTION create_message_timestamp() RETURNS trigger diff --git a/src/main.rs b/src/main.rs index ae66b43..a2ee9d4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -8,6 +8,7 @@ use tower_http::cors::{Any, CorsLayer}; mod auth; mod db; +mod realtime; mod routes; #[tokio::main] @@ -24,8 +25,8 @@ async fn main() -> anyhow::Result<()> { .allow_headers([header::AUTHORIZATION, header::CONTENT_TYPE]); let governor_conf = GovernorConfigBuilder::default() - .per_second(25) - .burst_size(50) + .per_second(50) + .burst_size(200) .finish() .unwrap(); @@ -41,11 +42,15 @@ async fn main() -> anyhow::Result<()> { } }); + let realtime = realtime::Realtime::new(); + let app = Router::new() .merge(routes::users::routes()) .merge(routes::rooms::routes()) .merge(routes::messages::routes()) + .merge(routes::ws::routes()) .layer(Extension(db_pool)) + .layer(Extension(realtime)) .layer(cors) .layer(GovernorLayer::new(governor_conf)); diff --git a/src/realtime.rs b/src/realtime.rs new file mode 100644 index 0000000..14cdb17 --- /dev/null +++ b/src/realtime.rs @@ -0,0 +1,27 @@ +use dashmap::DashMap; +use std::sync::Arc; +use tokio::sync::broadcast; + +use crate::routes::messages::Message; + +pub type RoomId = i32; + +#[derive(Clone)] +pub struct Realtime { + pub rooms: Arc>>, +} + +impl Realtime { + pub fn new() -> Self { + Self { + rooms: Arc::new(DashMap::new()), + } + } + + pub fn sender_for(&self, room: RoomId) -> broadcast::Sender { + self.rooms + .entry(room) + .or_insert_with(|| broadcast::channel(100).0) + .clone() + } +} diff --git a/src/routes/messages.rs b/src/routes/messages.rs index a9c7117..f2ba177 100644 --- a/src/routes/messages.rs +++ b/src/routes/messages.rs @@ -7,8 +7,11 @@ use axum::{ use sqlx::PgPool; use uuid::Uuid; -use crate::db::{user_id_from_uuid, username_from_uuid}; use crate::{auth::verify_jwt, db::room_id_from_uuid}; +use crate::{ + db::{user_id_from_uuid, username_from_uuid}, + realtime::Realtime, +}; #[derive(sqlx::FromRow, serde::Serialize, Debug)] pub struct MessageRow { @@ -18,7 +21,7 @@ pub struct MessageRow { pub sent_at: chrono::NaiveDateTime, } -#[derive(sqlx::FromRow, serde::Serialize, Debug)] +#[derive(sqlx::FromRow, serde::Serialize, Debug, Clone)] pub struct Message { pub sender: String, pub message_type: String, @@ -105,6 +108,7 @@ async fn list_messages( async fn create_message( Path(room_uuid): Path, Extension(db): Extension, + Extension(realtime): Extension, headers: HeaderMap, Json(payload): Json, ) -> Result<(StatusCode, Json), (StatusCode, String)> { @@ -133,13 +137,15 @@ async fn create_message( let sender_name = username_from_uuid(&db, claims.sub).await?; - Ok(( - StatusCode::CREATED, - Json(Message { - sender: sender_name, - message_type: payload.message_type, - content: payload.content, - sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(), - }), - )) + let message = Message { + sender: sender_name, + message_type: payload.message_type, + content: payload.content, + sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(), + }; + + let rt_sender = realtime.sender_for(room_id); + let _ = rt_sender.send(message.clone()); + + Ok((StatusCode::CREATED, Json(message))) } diff --git a/src/routes/mod.rs b/src/routes/mod.rs index f5e9592..8283042 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -1,3 +1,4 @@ pub mod messages; pub mod rooms; pub mod users; +pub mod ws; diff --git a/src/routes/ws.rs b/src/routes/ws.rs new file mode 100644 index 0000000..422c100 --- /dev/null +++ b/src/routes/ws.rs @@ -0,0 +1,135 @@ +use axum::Json; +use axum::extract::Query; +use axum::extract::ws::{Message as WsMessage, WebSocket}; +use axum::http::HeaderMap; +use axum::routing::get; +use axum::{ + Extension, + extract::{Path, WebSocketUpgrade}, + http::StatusCode, + response::IntoResponse, +}; +use serde::Deserialize; +use uuid::Uuid; + +use crate::auth::{create_jwt, verify_jwt}; +use crate::db::user_id_from_uuid; +use crate::{db::room_id_from_uuid, realtime::Realtime}; + +#[derive(sqlx::FromRow, serde::Serialize, Deserialize)] +pub struct WsAuthQuery { + pub token: String, +} + +pub fn routes() -> axum::Router { + axum::Router::new() + .route("/ws/issue-token/rooms/{room_uuid}", get(issue_ws_token)) + .route("/ws/rooms/{room_uuid}", get(ws_handler)) +} + +pub async fn issue_ws_token( + Extension(db): Extension, + headers: HeaderMap, + Path(room_uuid): Path, +) -> Result<(StatusCode, Json), (StatusCode, String)> { + let claims = verify_jwt(headers)?; + + let room_id = room_id_from_uuid(&db, room_uuid).await?; + let user_id = user_id_from_uuid(&db, claims.sub).await?; + + let membership: Vec = + sqlx::query_scalar("SELECT user_id FROM membership_ WHERE user_id = $1 AND room = $2") + .bind(user_id) + .bind(room_id) + .fetch_all(&db) + .await + .unwrap_or_else(|_| Vec::new()); + + if membership.is_empty() { + return Err(( + StatusCode::UNAUTHORIZED, + String::from("You are not a member of this room"), + )); + } + + // tracing::info!( + // "recieved token issue request from user {} for room {}", + // claims.sub, + // room_uuid + // ); + + let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; + + sqlx::query( + r#" + insert into ws_token_ (token, room_id, expires_at) + values ($1, $2, now() + interval '30 seconds') + "#, + ) + .bind(&token) + .bind(room_id) + .execute(&db) + .await + .map_err(|e| { + tracing::error!("failed to insert ws token: {e}"); + ( + StatusCode::INTERNAL_SERVER_ERROR, + format!("failed to insert ws token: {e}"), + ) + })?; + + Ok((StatusCode::CREATED, Json(WsAuthQuery { token }))) +} + +async fn ws_handler( + ws: WebSocketUpgrade, + Path(room_uuid): Path, + Query(query): Query, + Extension(realtime): Extension, + Extension(db): Extension, +) -> Result { + tracing::info!("recieved ws handshake: {}", room_uuid); + + let room_id = room_id_from_uuid(&db, room_uuid) + .await + .map_err(|_| StatusCode::NOT_FOUND)?; + + let valid: Option = sqlx::query_scalar( + r#" + delete from ws_token_ + where token = $1 + and room_id = $2 + and expires_at > now() + returning room_id + "#, + ) + .bind(query.token) + .bind(room_id) + .fetch_optional(&db) + .await + .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + + if valid.is_none() { + return Err(StatusCode::UNAUTHORIZED); + } + + let sender = realtime.sender_for(room_id); + let receiver = sender.subscribe(); + + Ok(ws.on_upgrade(move |socket| handle_socket(socket, receiver))) +} + +async fn handle_socket( + mut socket: WebSocket, + mut receiver: tokio::sync::broadcast::Receiver, +) { + while let Ok(msg) = receiver.recv().await { + if socket + .send(WsMessage::Text(serde_json::to_string(&msg).unwrap().into())) + .await + .is_err() + { + break; + } + } +}