diff --git a/Cargo.lock b/Cargo.lock index c576835..aeb67f4 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -619,12 +619,13 @@ dependencies = [ [[package]] name = "frangipane" -version = "1.0.4" +version = "1.0.5" dependencies = [ "anyhow", "argon2", "axum", "axum-extra", + "bytes", "chrono", "clap", "dashmap", diff --git a/Cargo.toml b/Cargo.toml index f44caf9..5fe63e6 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,6 +1,6 @@ [package] name = "frangipane" -version = "1.0.4" +version = "1.0.5" edition = "2024" [dependencies] @@ -8,6 +8,7 @@ anyhow = "1.0.99" argon2 = "0.5.3" axum = { version = "0.8.4", features = ["multipart", "ws"] } axum-extra = { version = "0.12.5", features = ["typed-header"] } +bytes = "1.11.0" chrono = { version = "0.4.42", features = ["serde"] } clap = { version = "4.5.53", features = ["derive"] } dashmap = "6.1.0" diff --git a/flake.nix b/flake.nix index 2926a50..6c0c43d 100644 --- a/flake.nix +++ b/flake.nix @@ -25,7 +25,7 @@ { packages.default = pkgs.rustPlatform.buildRustPackage { pname = "frangipane"; - version = "1.0.0"; + version = "1.0.5"; src = ./.; cargoLock = { diff --git a/src/db.rs b/src/db.rs index 1468e7d..c84b511 100644 --- a/src/db.rs +++ b/src/db.rs @@ -20,7 +20,10 @@ pub async fn room_id_from_uuid(db: &PgPool, room_uuid: Uuid) -> Result anyhow::Result<()> { let governor_limiter = governor_conf.limiter().clone(); - // a separate background task to clean up + let realtime = realtime::RealtimeMessages::new(); + let voice_manager = realtime::RealTimeVoices::new(); + let vm_clone = voice_manager.clone(); + + // A separate background task to clean up let interval = Duration::from_secs(60); std::thread::spawn(move || { loop { std::thread::sleep(interval); + // tracing::info!("rate limiting storage size: {}", governor_limiter.len()); governor_limiter.retain_recent(); + + vm_clone.retain_active(); } }); - let realtime = realtime::Realtime::new(); - let data_dir = PathBuf::from(cli.data_dir); let config = Arc::new(AppConfig { avatar_dir: data_dir.join("avatars"), @@ -115,10 +120,12 @@ async fn main() -> anyhow::Result<()> { .merge(routes::users::routes()) .merge(routes::rooms::routes()) .merge(routes::messages::routes()) + .merge(routes::voice::routes()) .merge(routes::friends::routes()) .merge(routes::ws::routes()) .layer(Extension(db_pool)) .layer(Extension(realtime)) + .layer(Extension(voice_manager)) .layer(Extension(config)) .layer(GovernorLayer::new(governor_conf)) .layer(cors) diff --git a/src/realtime.rs b/src/realtime.rs index 285cf80..f60eb74 100644 --- a/src/realtime.rs +++ b/src/realtime.rs @@ -1,3 +1,4 @@ +use axum::body::Bytes; use dashmap::DashMap; use std::sync::Arc; use tokio::sync::broadcast; @@ -6,18 +7,25 @@ use uuid::Uuid; use crate::routes::messages::Message; #[derive(Clone)] -pub struct Realtime { +pub struct RealtimeMessages { pub clients: Arc>>, } -impl Realtime { +type VoicePacket = (Uuid, Bytes); + +#[derive(Clone)] +pub struct RealTimeVoices { + pub rooms: Arc>>, +} + +impl RealtimeMessages { pub fn new() -> Self { Self { clients: Arc::new(DashMap::new()), } } - /// Get or create the channel for a specific user + /// Get or create the sender for a given user pub fn get_sender(&self, user_uuid: Uuid) -> broadcast::Sender { self.clients .entry(user_uuid) @@ -25,6 +33,7 @@ impl Realtime { .clone() } + /// Send a message to all the recipients pub fn broadcast(&self, recipient_uuids: Vec, message: Message) { for user_uuid in recipient_uuids { if let Some(sender) = self.clients.get(&user_uuid) { @@ -33,3 +42,24 @@ impl Realtime { } } } + +impl RealTimeVoices { + pub fn new() -> Self { + Self { + rooms: Arc::new(DashMap::new()), + } + } + + /// Get or create the broadcast sender for a given room + pub fn get_or_create_room(&self, room_uuid: Uuid) -> broadcast::Sender { + self.rooms + .entry(room_uuid) + .or_insert_with(|| broadcast::channel(500).0) + .clone() + } + + /// Clean up empty rooms + pub fn retain_active(&self) { + self.rooms.retain(|_, sender| sender.receiver_count() > 0); + } +} diff --git a/src/routes/messages.rs b/src/routes/messages.rs index dc4f843..f781810 100644 --- a/src/routes/messages.rs +++ b/src/routes/messages.rs @@ -1,16 +1,28 @@ +use std::{net::SocketAddr, time::Duration}; + use axum::{ Extension, Json, Router, - extract::{Path, Query}, + extract::{ + ConnectInfo, Path, Query, WebSocketUpgrade, + ws::{Message as WsMessage, WebSocket}, + }, http::{HeaderMap, StatusCode}, + response::IntoResponse, routing::{get, post}, }; +use axum_extra::{TypedHeader, headers}; use sqlx::PgPool; +use tokio::select; use uuid::Uuid; -use crate::{auth::verify_jwt, db::room_id_from_uuid, routes::rooms::is_member}; +use crate::{ + auth::{verify_jwt, verify_jwt_string}, + db::room_id_from_uuid, + routes::{rooms::is_member, ws::WsAuthQuery}, +}; use crate::{ db::{user_id_from_uuid, username_from_uuid}, - realtime::Realtime, + realtime::RealtimeMessages, }; #[derive(sqlx::FromRow, serde::Serialize, Debug)] @@ -51,6 +63,7 @@ pub fn routes() -> Router { Router::new() .route("/messages/{room_uuid}", get(list_messages)) .route("/messages/{room_uuid}", post(create_message)) + .route("/ws/messages", get(message_ws_handler)) } /// Also resets `last_read_at` @@ -155,7 +168,7 @@ async fn list_messages( async fn create_message( Path(room_uuid): Path, Extension(db): Extension, - Extension(realtime): Extension, + Extension(realtime): Extension, headers: HeaderMap, Json(payload): Json, ) -> Result<(StatusCode, Json), (StatusCode, String)> { @@ -223,3 +236,106 @@ async fn create_message( Ok((StatusCode::CREATED, Json(message))) } + +async fn message_ws_handler( + ws: WebSocketUpgrade, + user_agent: Option>, + Query(query): Query, + ConnectInfo(addr): ConnectInfo, + Extension(realtime): Extension, + Extension(db): Extension, +) -> Result { + // tracing::info!("recieved ws handshake: {}", room_uuid); + + let claims = verify_jwt_string(&query.token)?; + let user_uuid = claims.sub; + + let result = sqlx::query( + r#" + delete from ws_token_ + where token = $1 + and expires_at > now() + "#, + ) + .bind(query.token) + .execute(&db) + .await + .map_err(|e| { + tracing::error!("Failed to get WS token from DB: {e}"); + (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()) + })?; + + if result.rows_affected() == 0 { + return Err((StatusCode::UNAUTHORIZED, "Wrong token".into())); + } + + let receiver = realtime.get_sender(user_uuid).subscribe(); + + let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { + user_agent.to_string() + } else { + String::from("Unknown browser") + }; + tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected."); + + Ok(ws.on_upgrade(move |socket| handle_message_socket(socket, addr, receiver))) +} + +async fn handle_message_socket( + mut socket: WebSocket, + who: SocketAddr, + mut receiver: tokio::sync::broadcast::Receiver, +) { + let mut ping_interval = tokio::time::interval(Duration::from_secs(30)); + + loop { + select! { + // Receive broadcast messages and send to client (any room) + msg = receiver.recv() => { + if let Ok(msg) = msg { + if let Ok(json) = serde_json::to_string(&msg) { + if socket.send(WsMessage::Text(json.into())).await.is_err() { + tracing::error!("Failed to send message to {who}, closing connection"); + break; + } + } + } else { + break; + } + } + + // Send Ping + _ = ping_interval.tick() => { + if socket.send(WsMessage::Ping(vec![].into())).await.is_err() { + tracing::error!("Failed to send ping to {who}, closing connection"); + break; + } + + // tracing::debug!("Ping sent to {who}"); + } + + // Get incoming messages from client + client_msg = socket.recv() => { + if let Some(Ok(msg)) = client_msg { + match msg { + // WsMessage::Pong(_) => { + // tracing::debug!("Received Pong from {who}"); + // } + // WsMessage::Ping(_) => { + // tracing::info!("Received Ping from client"); + // } + // WsMessage::Text(_) => {} + WsMessage::Close(_) => { + tracing::debug!("Client disconnected"); + break; + } + _ => {} + } + } else { + tracing::debug!("Client {who} abruptly disconnected"); + break; + } + } + } + } +} diff --git a/src/routes/mod.rs b/src/routes/mod.rs index 1c5b7ca..ee546f0 100644 --- a/src/routes/mod.rs +++ b/src/routes/mod.rs @@ -2,4 +2,5 @@ pub mod friends; pub mod messages; pub mod rooms; pub mod users; +pub mod voice; pub mod ws; diff --git a/src/routes/voice.rs b/src/routes/voice.rs new file mode 100644 index 0000000..830a572 --- /dev/null +++ b/src/routes/voice.rs @@ -0,0 +1,126 @@ +use axum::{ + Extension, Router, + extract::{ + ConnectInfo, Path, Query, WebSocketUpgrade, + ws::{Message, WebSocket}, + }, + http::StatusCode, + response::IntoResponse, + routing::get, +}; +use bytes::{BufMut, Bytes, BytesMut}; +use sqlx::PgPool; +use std::{net::SocketAddr, time::Duration}; +use tokio::select; +use uuid::Uuid; + +use crate::{ + auth::verify_jwt_string, + db::{room_id_from_uuid, user_id_from_uuid}, + realtime::RealTimeVoices, + routes::rooms::is_member, + routes::ws::WsAuthQuery, +}; + +pub fn routes() -> Router { + Router::new().route("/ws/voice/{room_uuid}", get(voice_ws_handler)) +} + +async fn voice_ws_handler( + ws: WebSocketUpgrade, + Path(room_uuid): Path, + Query(query): Query, + ConnectInfo(addr): ConnectInfo, + Extension(voice_manager): Extension, + Extension(db): Extension, +) -> Result { + let claims = verify_jwt_string(&query.token)?; + let user_uuid = claims.sub; + + let result = sqlx::query( + r#" + DELETE FROM ws_token_ + WHERE token = $1 + AND expires_at > now() + "#, + ) + .bind(&query.token) + .execute(&db) + .await + .map_err(|e| { + tracing::error!("Failed to get WS token from DB: {e}"); + (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()) + })?; + + if result.rows_affected() == 0 { + return Err((StatusCode::UNAUTHORIZED, "Invalid or expired token".into())); + } + + let user_id = user_id_from_uuid(&db, user_uuid).await?; + let room_id = room_id_from_uuid(&db, room_uuid).await?; + + if !is_member(user_id, room_id, &db).await { + return Err((StatusCode::FORBIDDEN, "Not a member of this room".into())); + } + + tracing::info!("User {} joining voice in room {}", user_uuid, room_uuid); + + let tx = voice_manager.get_or_create_room(room_uuid); + let rx = tx.subscribe(); + + Ok(ws.on_upgrade(move |socket| handle_voice_socket(socket, addr, user_uuid, tx, rx))) +} + +async fn handle_voice_socket( + mut socket: WebSocket, + who: SocketAddr, + my_uuid: Uuid, + tx: tokio::sync::broadcast::Sender<(Uuid, Bytes)>, + mut rx: tokio::sync::broadcast::Receiver<(Uuid, Bytes)>, +) { + let mut ping_interval = tokio::time::interval(Duration::from_secs(15)); + + loop { + select! { + // Receive audio from other users and send to client + voice_packet = rx.recv() => { + if let Ok((speaker_uuid, audio_data)) = voice_packet { + if speaker_uuid != my_uuid { + let mut msg = BytesMut::with_capacity(16 + audio_data.len()); + msg.put(speaker_uuid.as_bytes().as_slice()); + msg.put(audio_data); + + if socket.send(Message::Binary(msg.freeze().into())).await.is_err() { + break; + } + } + } + } + + // Receive audio from alient and broadcast to room + client_msg = socket.recv() => { + if let Some(Ok(msg)) = client_msg { + match msg { + Message::Binary(data) => { + let _ = tx.send((my_uuid, Bytes::from(data))); + } + Message::Close(_) => { + tracing::debug!("Voice client {} disconnected", who); + break; + } + _ => {} + } + } else { + break; + } + } + + // Keepalive + _ = ping_interval.tick() => { + if socket.send(Message::Ping(vec![].into())).await.is_err() { + break; + } + } + } + } +} diff --git a/src/routes/ws.rs b/src/routes/ws.rs index 3a89e3d..730bb82 100644 --- a/src/routes/ws.rs +++ b/src/routes/ws.rs @@ -1,17 +1,10 @@ use axum::Json; -use axum::extract::ws::{Message as WsMessage, WebSocket}; -use axum::extract::{ConnectInfo, Query}; use axum::http::HeaderMap; use axum::routing::get; -use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse}; -use axum_extra::{TypedHeader, headers}; +use axum::{Extension, http::StatusCode}; use serde::Deserialize; -use std::net::SocketAddr; -use std::time::Duration; -use tokio::select; -use crate::auth::{create_jwt, verify_jwt, verify_jwt_string}; -use crate::realtime::Realtime; +use crate::auth::{create_jwt, verify_jwt}; #[derive(sqlx::FromRow, serde::Serialize, Deserialize)] pub struct WsAuthQuery { @@ -19,9 +12,7 @@ pub struct WsAuthQuery { } pub fn routes() -> axum::Router { - axum::Router::new() - .route("/ws/messages/issue-token", get(issue_ws_token)) - .route("/ws/messages", get(ws_handler)) + axum::Router::new().route("/ws/issue-token", get(issue_ws_token)) } pub async fn issue_ws_token( @@ -52,106 +43,3 @@ pub async fn issue_ws_token( Ok((StatusCode::CREATED, Json(WsAuthQuery { token }))) } - -async fn ws_handler( - ws: WebSocketUpgrade, - user_agent: Option>, - Query(query): Query, - ConnectInfo(addr): ConnectInfo, - Extension(realtime): Extension, - Extension(db): Extension, -) -> Result { - // tracing::info!("recieved ws handshake: {}", room_uuid); - - let claims = verify_jwt_string(&query.token)?; - let user_uuid = claims.sub; - - let result = sqlx::query( - r#" - delete from ws_token_ - where token = $1 - and expires_at > now() - "#, - ) - .bind(query.token) - .execute(&db) - .await - .map_err(|e| { - tracing::error!("Failed to get WS token from DB: {e}"); - (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into()) - })?; - - if result.rows_affected() == 0 { - return Err((StatusCode::UNAUTHORIZED, "Wrong token".into())); - } - - let receiver = realtime.get_sender(user_uuid).subscribe(); - - let user_agent = if let Some(TypedHeader(user_agent)) = user_agent { - user_agent.to_string() - } else { - String::from("Unknown browser") - }; - tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected."); - - Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver))) -} - -async fn handle_socket( - mut socket: WebSocket, - who: SocketAddr, - mut receiver: tokio::sync::broadcast::Receiver, -) { - let mut ping_interval = tokio::time::interval(Duration::from_secs(30)); - - loop { - select! { - // Receive broadcast messages and send to client (any room) - msg = receiver.recv() => { - if let Ok(msg) = msg { - if let Ok(json) = serde_json::to_string(&msg) { - if socket.send(WsMessage::Text(json.into())).await.is_err() { - tracing::error!("Failed to send message to {who}, closing connection"); - break; - } - } - } else { - break; - } - } - - // Send Ping - _ = ping_interval.tick() => { - if socket.send(WsMessage::Ping(vec![].into())).await.is_err() { - tracing::error!("Failed to send ping to {who}, closing connection"); - break; - } - - // tracing::debug!("Ping sent to {who}"); - } - - // Get incoming messages from client - client_msg = socket.recv() => { - if let Some(Ok(msg)) = client_msg { - match msg { - // WsMessage::Pong(_) => { - // tracing::debug!("Received Pong from {who}"); - // } - // WsMessage::Ping(_) => { - // tracing::info!("Received Ping from client"); - // } - // WsMessage::Text(_) => {} - WsMessage::Close(_) => { - tracing::debug!("Client disconnected"); - break; - } - _ => {} - } - } else { - tracing::debug!("Client {who} abruptly disconnected"); - break; - } - } - } - } -}