From 37e6bb25fc9be1d9203d11211bf234416c28ff0c Mon Sep 17 00:00:00 2001 From: eiiko6 Date: Fri, 16 Jan 2026 11:35:07 +0100 Subject: [PATCH] refactor: clients now have a single websocket that handles all rooms the user is in --- Cargo.lock | 2 +- Cargo.toml | 2 +- db/init.sql | 1 - db/mock_data.sql | 6 ++-- src/auth.rs | 6 +++- src/realtime.rs | 22 ++++++++----- src/routes/messages.rs | 30 ++++++++++++++++-- src/routes/ws.rs | 71 +++++++++++++----------------------------- 8 files changed, 74 insertions(+), 66 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index a922e85..0644347 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -619,7 +619,7 @@ dependencies = [ [[package]] name = "frangipane" -version = "1.0.0" +version = "1.0.1" dependencies = [ "anyhow", "argon2", diff --git a/Cargo.toml b/Cargo.toml index 235fc37..3b5b476 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "frangipane" -version = "1.0.0" +version = "1.0.1" edition = "2024" [dependencies] diff --git a/db/init.sql b/db/init.sql index 073e72b..44462db 100644 --- a/db/init.sql +++ b/db/init.sql @@ -57,7 +57,6 @@ CREATE TABLE IF NOT EXISTS message_ ( CREATE TABLE ws_token_ ( token TEXT PRIMARY KEY, - room_id INT NOT NULL, expires_at TIMESTAMPTZ NOT NULL ); diff --git a/db/mock_data.sql b/db/mock_data.sql index c5f738c..dad9ab6 100644 --- a/db/mock_data.sql +++ b/db/mock_data.sql @@ -29,9 +29,9 @@ INSERT INTO friendship_ (user_first, user_second) VALUES INSERT INTO friend_request_ (sender, receiver) VALUES (2, 1); -- Bob sent a friend request to Alice -INSERT INTO ws_token_ (token, room_id, expires_at) VALUES -('random_token_1', 1, '2025-12-31T23:59:59Z'), -('random_token_2', 2, '2025-12-31T23:59:59Z'); +INSERT INTO ws_token_ (token, expires_at) VALUES +('random_token_1', '2025-12-31T23:59:59Z'), +('random_token_2', '2025-12-31T23:59:59Z'); INSERT INTO room_invite_ (sender, receiver, room) VALUES (2, 1, 2); diff --git a/src/auth.rs b/src/auth.rs index b495949..150def2 100644 --- a/src/auth.rs +++ b/src/auth.rs @@ -67,11 +67,15 @@ pub fn verify_jwt(headers: HeaderMap) -> Result { .and_then(|s| s.strip_prefix("Bearer ")) .ok_or((StatusCode::UNAUTHORIZED, "Missing token".to_string()))?; + verify_jwt_string(&token.to_string()) +} + +pub fn verify_jwt_string(token: &String) -> Result { let secret = std::env::var("FRANGIPANE_JWT_SECRET").unwrap_or_else(|_| DEFAULT_SECRET_KEY.to_string()); decode::( - token, + token.as_str(), &DecodingKey::from_secret(secret.as_ref()), &Validation::default(), ) diff --git a/src/realtime.rs b/src/realtime.rs index 14cdb17..285cf80 100644 --- a/src/realtime.rs +++ b/src/realtime.rs @@ -1,27 +1,35 @@ use dashmap::DashMap; use std::sync::Arc; use tokio::sync::broadcast; +use uuid::Uuid; use crate::routes::messages::Message; -pub type RoomId = i32; - #[derive(Clone)] pub struct Realtime { - pub rooms: Arc>>, + pub clients: Arc>>, } impl Realtime { pub fn new() -> Self { Self { - rooms: Arc::new(DashMap::new()), + clients: Arc::new(DashMap::new()), } } - pub fn sender_for(&self, room: RoomId) -> broadcast::Sender { - self.rooms - .entry(room) + /// Get or create the channel for a specific user + pub fn get_sender(&self, user_uuid: Uuid) -> broadcast::Sender { + self.clients + .entry(user_uuid) .or_insert_with(|| broadcast::channel(100).0) .clone() } + + pub fn broadcast(&self, recipient_uuids: Vec, message: Message) { + for user_uuid in recipient_uuids { + if let Some(sender) = self.clients.get(&user_uuid) { + let _ = sender.send(message.clone()); + } + } + } } diff --git a/src/routes/messages.rs b/src/routes/messages.rs index dec5427..31b5780 100644 --- a/src/routes/messages.rs +++ b/src/routes/messages.rs @@ -18,6 +18,7 @@ pub struct MessageRow { pub uuid: Uuid, pub sender: String, pub sender_uuid: Uuid, + pub room_uuid: Uuid, pub message_type: String, pub content: String, pub sent_at: chrono::NaiveDateTime, @@ -26,6 +27,7 @@ pub struct MessageRow { #[derive(serde::Serialize, Debug, Clone)] pub struct Message { pub uuid: Uuid, + pub room_uuid: Uuid, pub sender: String, pub sender_uuid: Uuid, pub message_type: String, @@ -77,7 +79,7 @@ async fn list_messages( m.uuid, u.username AS sender, u.uuid AS sender_uuid, - r.uuid AS room, + r.uuid AS room_uuid, m.message_type, m.content, m.sent_at @@ -106,6 +108,7 @@ async fn list_messages( .into_iter() .map(|m| Message { uuid: m.uuid, + room_uuid: m.room_uuid, sender: m.sender, sender_uuid: m.sender_uuid, message_type: m.message_type, @@ -157,6 +160,7 @@ async fn create_message( let message = Message { uuid: uuid, + room_uuid, sender: sender_name, sender_uuid: claims.sub, message_type: payload.message_type, @@ -164,8 +168,28 @@ async fn create_message( sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(), }; - let rt_sender = realtime.sender_for(room_id); - let _ = rt_sender.send(message.clone()); + let recipients: Vec = sqlx::query_scalar( + r#" + SELECT u.uuid + FROM membership_ m + JOIN user_ u ON u.id = m.user_id + WHERE m.room = $1 + "#, + ) + .bind(room_id) + .fetch_all(&db) + .await + .map_err(|e| { + tracing::error!("Error fetching message recipients: {e}"); + (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()) + })?; + + let rt = realtime.clone(); + let msg_clone = message.clone(); + + tokio::spawn(async move { + rt.broadcast(recipients, msg_clone); + }); Ok((StatusCode::CREATED, Json(message))) } diff --git a/src/routes/ws.rs b/src/routes/ws.rs index 667a3fb..3a89e3d 100644 --- a/src/routes/ws.rs +++ b/src/routes/ws.rs @@ -3,23 +3,15 @@ use axum::extract::ws::{Message as WsMessage, WebSocket}; use axum::extract::{ConnectInfo, Query}; use axum::http::HeaderMap; use axum::routing::get; -use axum::{ - Extension, - extract::{Path, WebSocketUpgrade}, - http::StatusCode, - response::IntoResponse, -}; +use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse}; use axum_extra::{TypedHeader, headers}; use serde::Deserialize; use std::net::SocketAddr; use std::time::Duration; use tokio::select; -use uuid::Uuid; -use crate::auth::{create_jwt, verify_jwt}; -use crate::db::user_id_from_uuid; -use crate::routes::rooms::is_member; -use crate::{db::room_id_from_uuid, realtime::Realtime}; +use crate::auth::{create_jwt, verify_jwt, verify_jwt_string}; +use crate::realtime::Realtime; #[derive(sqlx::FromRow, serde::Serialize, Deserialize)] pub struct WsAuthQuery { @@ -28,43 +20,27 @@ pub struct WsAuthQuery { pub fn routes() -> axum::Router { axum::Router::new() - .route("/ws/issue-token/rooms/{room_uuid}", get(issue_ws_token)) - .route("/ws/rooms/{room_uuid}", get(ws_handler)) + .route("/ws/messages/issue-token", get(issue_ws_token)) + .route("/ws/messages", get(ws_handler)) } pub async fn issue_ws_token( Extension(db): Extension, headers: HeaderMap, - Path(room_uuid): Path, ) -> Result<(StatusCode, Json), (StatusCode, String)> { let claims = verify_jwt(headers)?; - let room_id = room_id_from_uuid(&db, room_uuid).await?; - let user_id = user_id_from_uuid(&db, claims.sub).await?; - - if !is_member(user_id, room_id, &db).await { - return Err(( - StatusCode::UNAUTHORIZED, - String::from("You are not a member of this room"), - )); - } - - // tracing::info!( - // "recieved token issue request from user {} for room {}", - // claims.sub, - // room_uuid - // ); + tracing::debug!("Recieved token issue request from user {}", claims.sub); let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; sqlx::query( r#" - insert into ws_token_ (token, room_id, expires_at) - values ($1, $2, now() + interval '30 seconds') + insert into ws_token_ (token, expires_at) + values ($1, now() + interval '30 seconds') "#, ) .bind(&token) - .bind(room_id) .execute(&db) .await .map_err(|_| { @@ -80,46 +56,43 @@ pub async fn issue_ws_token( async fn ws_handler( ws: WebSocketUpgrade, user_agent: Option>, - Path(room_uuid): Path, Query(query): Query, ConnectInfo(addr): ConnectInfo, Extension(realtime): Extension, Extension(db): Extension, -) -> Result { +) -> Result { // tracing::info!("recieved ws handshake: {}", room_uuid); - let room_id = room_id_from_uuid(&db, room_uuid) - .await - .map_err(|_| StatusCode::NOT_FOUND)?; + let claims = verify_jwt_string(&query.token)?; + let user_uuid = claims.sub; - let valid: Option = sqlx::query_scalar( + let result = sqlx::query( r#" delete from ws_token_ where token = $1 - and room_id = $2 and expires_at > now() - returning room_id "#, ) .bind(query.token) - .bind(room_id) - .fetch_optional(&db) + .execute(&db) .await - .map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; + .map_err(|e| { + tracing::error!("Failed to get WS token from DB: {e}"); + (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()) + })?; - if valid.is_none() { - return Err(StatusCode::UNAUTHORIZED); + if result.rows_affected() == 0 { + return Err((StatusCode::UNAUTHORIZED, "Wrong token".into())); } - let sender = realtime.sender_for(room_id); - let receiver = sender.subscribe(); + let receiver = realtime.get_sender(user_uuid).subscribe(); let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { user_agent.to_string() } else { String::from("Unknown browser") }; - tracing::debug!("`{user_agent}` at {addr} connected."); + tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected."); Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver))) } @@ -133,7 +106,7 @@ async fn handle_socket( loop { select! { - // Receive broadcast messages and send to client + // Receive broadcast messages and send to client (any room) msg = receiver.recv() => { if let Ok(msg) = msg { if let Ok(json) = serde_json::to_string(&msg) {