refactor: clients now have a single websocket that handles all rooms the user is in

This commit is contained in:
2026-01-16 11:35:07 +01:00
parent 376353833c
commit 37e6bb25fc
8 changed files with 74 additions and 66 deletions

2
Cargo.lock generated
View File

@@ -619,7 +619,7 @@ dependencies = [
[[package]] [[package]]
name = "frangipane" name = "frangipane"
version = "1.0.0" version = "1.0.1"
dependencies = [ dependencies = [
"anyhow", "anyhow",
"argon2", "argon2",

View File

@@ -1,6 +1,6 @@
[package] [package]
name = "frangipane" name = "frangipane"
version = "1.0.0" version = "1.0.1"
edition = "2024" edition = "2024"
[dependencies] [dependencies]

View File

@@ -57,7 +57,6 @@ CREATE TABLE IF NOT EXISTS message_ (
CREATE TABLE ws_token_ ( CREATE TABLE ws_token_ (
token TEXT PRIMARY KEY, token TEXT PRIMARY KEY,
room_id INT NOT NULL,
expires_at TIMESTAMPTZ NOT NULL expires_at TIMESTAMPTZ NOT NULL
); );

View File

@@ -29,9 +29,9 @@ INSERT INTO friendship_ (user_first, user_second) VALUES
INSERT INTO friend_request_ (sender, receiver) VALUES INSERT INTO friend_request_ (sender, receiver) VALUES
(2, 1); -- Bob sent a friend request to Alice (2, 1); -- Bob sent a friend request to Alice
INSERT INTO ws_token_ (token, room_id, expires_at) VALUES INSERT INTO ws_token_ (token, expires_at) VALUES
('random_token_1', 1, '2025-12-31T23:59:59Z'), ('random_token_1', '2025-12-31T23:59:59Z'),
('random_token_2', 2, '2025-12-31T23:59:59Z'); ('random_token_2', '2025-12-31T23:59:59Z');
INSERT INTO room_invite_ (sender, receiver, room) VALUES INSERT INTO room_invite_ (sender, receiver, room) VALUES
(2, 1, 2); (2, 1, 2);

View File

@@ -67,11 +67,15 @@ pub fn verify_jwt(headers: HeaderMap) -> Result<Claims, (StatusCode, String)> {
.and_then(|s| s.strip_prefix("Bearer ")) .and_then(|s| s.strip_prefix("Bearer "))
.ok_or((StatusCode::UNAUTHORIZED, "Missing token".to_string()))?; .ok_or((StatusCode::UNAUTHORIZED, "Missing token".to_string()))?;
verify_jwt_string(&token.to_string())
}
pub fn verify_jwt_string(token: &String) -> Result<Claims, (StatusCode, String)> {
let secret = let secret =
std::env::var("FRANGIPANE_JWT_SECRET").unwrap_or_else(|_| DEFAULT_SECRET_KEY.to_string()); std::env::var("FRANGIPANE_JWT_SECRET").unwrap_or_else(|_| DEFAULT_SECRET_KEY.to_string());
decode::<Claims>( decode::<Claims>(
token, token.as_str(),
&DecodingKey::from_secret(secret.as_ref()), &DecodingKey::from_secret(secret.as_ref()),
&Validation::default(), &Validation::default(),
) )

View File

@@ -1,27 +1,35 @@
use dashmap::DashMap; use dashmap::DashMap;
use std::sync::Arc; use std::sync::Arc;
use tokio::sync::broadcast; use tokio::sync::broadcast;
use uuid::Uuid;
use crate::routes::messages::Message; use crate::routes::messages::Message;
pub type RoomId = i32;
#[derive(Clone)] #[derive(Clone)]
pub struct Realtime { pub struct Realtime {
pub rooms: Arc<DashMap<RoomId, broadcast::Sender<Message>>>, pub clients: Arc<DashMap<Uuid, broadcast::Sender<Message>>>,
} }
impl Realtime { impl Realtime {
pub fn new() -> Self { pub fn new() -> Self {
Self { Self {
rooms: Arc::new(DashMap::new()), clients: Arc::new(DashMap::new()),
} }
} }
pub fn sender_for(&self, room: RoomId) -> broadcast::Sender<Message> { /// Get or create the channel for a specific user
self.rooms pub fn get_sender(&self, user_uuid: Uuid) -> broadcast::Sender<Message> {
.entry(room) self.clients
.entry(user_uuid)
.or_insert_with(|| broadcast::channel(100).0) .or_insert_with(|| broadcast::channel(100).0)
.clone() .clone()
} }
pub fn broadcast(&self, recipient_uuids: Vec<Uuid>, message: Message) {
for user_uuid in recipient_uuids {
if let Some(sender) = self.clients.get(&user_uuid) {
let _ = sender.send(message.clone());
}
}
}
} }

View File

@@ -18,6 +18,7 @@ pub struct MessageRow {
pub uuid: Uuid, pub uuid: Uuid,
pub sender: String, pub sender: String,
pub sender_uuid: Uuid, pub sender_uuid: Uuid,
pub room_uuid: Uuid,
pub message_type: String, pub message_type: String,
pub content: String, pub content: String,
pub sent_at: chrono::NaiveDateTime, pub sent_at: chrono::NaiveDateTime,
@@ -26,6 +27,7 @@ pub struct MessageRow {
#[derive(serde::Serialize, Debug, Clone)] #[derive(serde::Serialize, Debug, Clone)]
pub struct Message { pub struct Message {
pub uuid: Uuid, pub uuid: Uuid,
pub room_uuid: Uuid,
pub sender: String, pub sender: String,
pub sender_uuid: Uuid, pub sender_uuid: Uuid,
pub message_type: String, pub message_type: String,
@@ -77,7 +79,7 @@ async fn list_messages(
m.uuid, m.uuid,
u.username AS sender, u.username AS sender,
u.uuid AS sender_uuid, u.uuid AS sender_uuid,
r.uuid AS room, r.uuid AS room_uuid,
m.message_type, m.message_type,
m.content, m.content,
m.sent_at m.sent_at
@@ -106,6 +108,7 @@ async fn list_messages(
.into_iter() .into_iter()
.map(|m| Message { .map(|m| Message {
uuid: m.uuid, uuid: m.uuid,
room_uuid: m.room_uuid,
sender: m.sender, sender: m.sender,
sender_uuid: m.sender_uuid, sender_uuid: m.sender_uuid,
message_type: m.message_type, message_type: m.message_type,
@@ -157,6 +160,7 @@ async fn create_message(
let message = Message { let message = Message {
uuid: uuid, uuid: uuid,
room_uuid,
sender: sender_name, sender: sender_name,
sender_uuid: claims.sub, sender_uuid: claims.sub,
message_type: payload.message_type, message_type: payload.message_type,
@@ -164,8 +168,28 @@ async fn create_message(
sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(), sent_at: sent_at.format("%Y-%m-%d %H:%M:%S").to_string(),
}; };
let rt_sender = realtime.sender_for(room_id); let recipients: Vec<Uuid> = sqlx::query_scalar(
let _ = rt_sender.send(message.clone()); r#"
SELECT u.uuid
FROM membership_ m
JOIN user_ u ON u.id = m.user_id
WHERE m.room = $1
"#,
)
.bind(room_id)
.fetch_all(&db)
.await
.map_err(|e| {
tracing::error!("Error fetching message recipients: {e}");
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
})?;
let rt = realtime.clone();
let msg_clone = message.clone();
tokio::spawn(async move {
rt.broadcast(recipients, msg_clone);
});
Ok((StatusCode::CREATED, Json(message))) Ok((StatusCode::CREATED, Json(message)))
} }

View File

@@ -3,23 +3,15 @@ use axum::extract::ws::{Message as WsMessage, WebSocket};
use axum::extract::{ConnectInfo, Query}; use axum::extract::{ConnectInfo, Query};
use axum::http::HeaderMap; use axum::http::HeaderMap;
use axum::routing::get; use axum::routing::get;
use axum::{ use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse};
Extension,
extract::{Path, WebSocketUpgrade},
http::StatusCode,
response::IntoResponse,
};
use axum_extra::{TypedHeader, headers}; use axum_extra::{TypedHeader, headers};
use serde::Deserialize; use serde::Deserialize;
use std::net::SocketAddr; use std::net::SocketAddr;
use std::time::Duration; use std::time::Duration;
use tokio::select; use tokio::select;
use uuid::Uuid;
use crate::auth::{create_jwt, verify_jwt}; use crate::auth::{create_jwt, verify_jwt, verify_jwt_string};
use crate::db::user_id_from_uuid; use crate::realtime::Realtime;
use crate::routes::rooms::is_member;
use crate::{db::room_id_from_uuid, realtime::Realtime};
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)] #[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
pub struct WsAuthQuery { pub struct WsAuthQuery {
@@ -28,43 +20,27 @@ pub struct WsAuthQuery {
pub fn routes() -> axum::Router { pub fn routes() -> axum::Router {
axum::Router::new() axum::Router::new()
.route("/ws/issue-token/rooms/{room_uuid}", get(issue_ws_token)) .route("/ws/messages/issue-token", get(issue_ws_token))
.route("/ws/rooms/{room_uuid}", get(ws_handler)) .route("/ws/messages", get(ws_handler))
} }
pub async fn issue_ws_token( pub async fn issue_ws_token(
Extension(db): Extension<sqlx::PgPool>, Extension(db): Extension<sqlx::PgPool>,
headers: HeaderMap, headers: HeaderMap,
Path(room_uuid): Path<Uuid>,
) -> Result<(StatusCode, Json<WsAuthQuery>), (StatusCode, String)> { ) -> Result<(StatusCode, Json<WsAuthQuery>), (StatusCode, String)> {
let claims = verify_jwt(headers)?; let claims = verify_jwt(headers)?;
let room_id = room_id_from_uuid(&db, room_uuid).await?; tracing::debug!("Recieved token issue request from user {}", claims.sub);
let user_id = user_id_from_uuid(&db, claims.sub).await?;
if !is_member(user_id, room_id, &db).await {
return Err((
StatusCode::UNAUTHORIZED,
String::from("You are not a member of this room"),
));
}
// tracing::info!(
// "recieved token issue request from user {} for room {}",
// claims.sub,
// room_uuid
// );
let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?; let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
sqlx::query( sqlx::query(
r#" r#"
insert into ws_token_ (token, room_id, expires_at) insert into ws_token_ (token, expires_at)
values ($1, $2, now() + interval '30 seconds') values ($1, now() + interval '30 seconds')
"#, "#,
) )
.bind(&token) .bind(&token)
.bind(room_id)
.execute(&db) .execute(&db)
.await .await
.map_err(|_| { .map_err(|_| {
@@ -80,46 +56,43 @@ pub async fn issue_ws_token(
async fn ws_handler( async fn ws_handler(
ws: WebSocketUpgrade, ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>, user_agent: Option<TypedHeader<headers::UserAgent>>,
Path(room_uuid): Path<Uuid>,
Query(query): Query<WsAuthQuery>, Query(query): Query<WsAuthQuery>,
ConnectInfo(addr): ConnectInfo<SocketAddr>, ConnectInfo(addr): ConnectInfo<SocketAddr>,
Extension(realtime): Extension<Realtime>, Extension(realtime): Extension<Realtime>,
Extension(db): Extension<sqlx::PgPool>, Extension(db): Extension<sqlx::PgPool>,
) -> Result<impl IntoResponse, axum::http::StatusCode> { ) -> Result<impl IntoResponse, (StatusCode, String)> {
// tracing::info!("recieved ws handshake: {}", room_uuid); // tracing::info!("recieved ws handshake: {}", room_uuid);
let room_id = room_id_from_uuid(&db, room_uuid) let claims = verify_jwt_string(&query.token)?;
.await let user_uuid = claims.sub;
.map_err(|_| StatusCode::NOT_FOUND)?;
let valid: Option<i32> = sqlx::query_scalar( let result = sqlx::query(
r#" r#"
delete from ws_token_ delete from ws_token_
where token = $1 where token = $1
and room_id = $2
and expires_at > now() and expires_at > now()
returning room_id
"#, "#,
) )
.bind(query.token) .bind(query.token)
.bind(room_id) .execute(&db)
.fetch_optional(&db)
.await .await
.map_err(|_| StatusCode::INTERNAL_SERVER_ERROR)?; .map_err(|e| {
tracing::error!("Failed to get WS token from DB: {e}");
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
})?;
if valid.is_none() { if result.rows_affected() == 0 {
return Err(StatusCode::UNAUTHORIZED); return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
} }
let sender = realtime.sender_for(room_id); let receiver = realtime.get_sender(user_uuid).subscribe();
let receiver = sender.subscribe();
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
user_agent.to_string() user_agent.to_string()
} else { } else {
String::from("Unknown browser") String::from("Unknown browser")
}; };
tracing::debug!("`{user_agent}` at {addr} connected."); tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver))) Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver)))
} }
@@ -133,7 +106,7 @@ async fn handle_socket(
loop { loop {
select! { select! {
// Receive broadcast messages and send to client // Receive broadcast messages and send to client (any room)
msg = receiver.recv() => { msg = receiver.recv() => {
if let Ok(msg) = msg { if let Ok(msg) = msg {
if let Ok(json) = serde_json::to_string(&msg) { if let Ok(json) = serde_json::to_string(&msg) {