added voice chat websocket route

This commit is contained in:
2026-01-25 17:13:05 +01:00
parent c81df769de
commit 30dc7475c2
10 changed files with 302 additions and 129 deletions

3
Cargo.lock generated
View File

@@ -619,12 +619,13 @@ dependencies = [
[[package]]
name = "frangipane"
version = "1.0.4"
version = "1.0.5"
dependencies = [
"anyhow",
"argon2",
"axum",
"axum-extra",
"bytes",
"chrono",
"clap",
"dashmap",

View File

@@ -1,6 +1,6 @@
[package]
name = "frangipane"
version = "1.0.4"
version = "1.0.5"
edition = "2024"
[dependencies]
@@ -8,6 +8,7 @@ anyhow = "1.0.99"
argon2 = "0.5.3"
axum = { version = "0.8.4", features = ["multipart", "ws"] }
axum-extra = { version = "0.12.5", features = ["typed-header"] }
bytes = "1.11.0"
chrono = { version = "0.4.42", features = ["serde"] }
clap = { version = "4.5.53", features = ["derive"] }
dashmap = "6.1.0"

View File

@@ -25,7 +25,7 @@
{
packages.default = pkgs.rustPlatform.buildRustPackage {
pname = "frangipane";
version = "1.0.0";
version = "1.0.5";
src = ./.;
cargoLock = {

View File

@@ -20,7 +20,10 @@ pub async fn room_id_from_uuid(db: &PgPool, room_uuid: Uuid) -> Result<i32, (Sta
.bind(room_uuid)
.fetch_one(db)
.await
.map_err(|_| (StatusCode::NOT_FOUND, "Failed to find room".into()))
.map_err(|e| {
tracing::error!("Failed to convert room uuid to room id: {e}");
(StatusCode::NOT_FOUND, "Failed to find room".into())
})
}
pub async fn username_from_uuid(

View File

@@ -92,18 +92,23 @@ async fn main() -> anyhow::Result<()> {
let governor_limiter = governor_conf.limiter().clone();
// a separate background task to clean up
let realtime = realtime::RealtimeMessages::new();
let voice_manager = realtime::RealTimeVoices::new();
let vm_clone = voice_manager.clone();
// A separate background task to clean up
let interval = Duration::from_secs(60);
std::thread::spawn(move || {
loop {
std::thread::sleep(interval);
// tracing::info!("rate limiting storage size: {}", governor_limiter.len());
governor_limiter.retain_recent();
vm_clone.retain_active();
}
});
let realtime = realtime::Realtime::new();
let data_dir = PathBuf::from(cli.data_dir);
let config = Arc::new(AppConfig {
avatar_dir: data_dir.join("avatars"),
@@ -115,10 +120,12 @@ async fn main() -> anyhow::Result<()> {
.merge(routes::users::routes())
.merge(routes::rooms::routes())
.merge(routes::messages::routes())
.merge(routes::voice::routes())
.merge(routes::friends::routes())
.merge(routes::ws::routes())
.layer(Extension(db_pool))
.layer(Extension(realtime))
.layer(Extension(voice_manager))
.layer(Extension(config))
.layer(GovernorLayer::new(governor_conf))
.layer(cors)

View File

@@ -1,3 +1,4 @@
use axum::body::Bytes;
use dashmap::DashMap;
use std::sync::Arc;
use tokio::sync::broadcast;
@@ -6,18 +7,25 @@ use uuid::Uuid;
use crate::routes::messages::Message;
#[derive(Clone)]
pub struct Realtime {
pub struct RealtimeMessages {
pub clients: Arc<DashMap<Uuid, broadcast::Sender<Message>>>,
}
impl Realtime {
type VoicePacket = (Uuid, Bytes);
#[derive(Clone)]
pub struct RealTimeVoices {
pub rooms: Arc<DashMap<Uuid, broadcast::Sender<VoicePacket>>>,
}
impl RealtimeMessages {
pub fn new() -> Self {
Self {
clients: Arc::new(DashMap::new()),
}
}
/// Get or create the channel for a specific user
/// Get or create the sender for a given user
pub fn get_sender(&self, user_uuid: Uuid) -> broadcast::Sender<Message> {
self.clients
.entry(user_uuid)
@@ -25,6 +33,7 @@ impl Realtime {
.clone()
}
/// Send a message to all the recipients
pub fn broadcast(&self, recipient_uuids: Vec<Uuid>, message: Message) {
for user_uuid in recipient_uuids {
if let Some(sender) = self.clients.get(&user_uuid) {
@@ -33,3 +42,24 @@ impl Realtime {
}
}
}
impl RealTimeVoices {
pub fn new() -> Self {
Self {
rooms: Arc::new(DashMap::new()),
}
}
/// Get or create the broadcast sender for a given room
pub fn get_or_create_room(&self, room_uuid: Uuid) -> broadcast::Sender<VoicePacket> {
self.rooms
.entry(room_uuid)
.or_insert_with(|| broadcast::channel(500).0)
.clone()
}
/// Clean up empty rooms
pub fn retain_active(&self) {
self.rooms.retain(|_, sender| sender.receiver_count() > 0);
}
}

View File

@@ -1,16 +1,28 @@
use std::{net::SocketAddr, time::Duration};
use axum::{
Extension, Json, Router,
extract::{Path, Query},
extract::{
ConnectInfo, Path, Query, WebSocketUpgrade,
ws::{Message as WsMessage, WebSocket},
},
http::{HeaderMap, StatusCode},
response::IntoResponse,
routing::{get, post},
};
use axum_extra::{TypedHeader, headers};
use sqlx::PgPool;
use tokio::select;
use uuid::Uuid;
use crate::{auth::verify_jwt, db::room_id_from_uuid, routes::rooms::is_member};
use crate::{
auth::{verify_jwt, verify_jwt_string},
db::room_id_from_uuid,
routes::{rooms::is_member, ws::WsAuthQuery},
};
use crate::{
db::{user_id_from_uuid, username_from_uuid},
realtime::Realtime,
realtime::RealtimeMessages,
};
#[derive(sqlx::FromRow, serde::Serialize, Debug)]
@@ -51,6 +63,7 @@ pub fn routes() -> Router {
Router::new()
.route("/messages/{room_uuid}", get(list_messages))
.route("/messages/{room_uuid}", post(create_message))
.route("/ws/messages", get(message_ws_handler))
}
/// Also resets `last_read_at`
@@ -155,7 +168,7 @@ async fn list_messages(
async fn create_message(
Path(room_uuid): Path<Uuid>,
Extension(db): Extension<PgPool>,
Extension(realtime): Extension<Realtime>,
Extension(realtime): Extension<RealtimeMessages>,
headers: HeaderMap,
Json(payload): Json<NewMessagePayload>,
) -> Result<(StatusCode, Json<Message>), (StatusCode, String)> {
@@ -223,3 +236,106 @@ async fn create_message(
Ok((StatusCode::CREATED, Json(message)))
}
async fn message_ws_handler(
ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Query(query): Query<WsAuthQuery>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Extension(realtime): Extension<RealtimeMessages>,
Extension(db): Extension<sqlx::PgPool>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
// tracing::info!("recieved ws handshake: {}", room_uuid);
let claims = verify_jwt_string(&query.token)?;
let user_uuid = claims.sub;
let result = sqlx::query(
r#"
delete from ws_token_
where token = $1
and expires_at > now()
"#,
)
.bind(query.token)
.execute(&db)
.await
.map_err(|e| {
tracing::error!("Failed to get WS token from DB: {e}");
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
})?;
if result.rows_affected() == 0 {
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
}
let receiver = realtime.get_sender(user_uuid).subscribe();
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
user_agent.to_string()
} else {
String::from("Unknown browser")
};
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
Ok(ws.on_upgrade(move |socket| handle_message_socket(socket, addr, receiver)))
}
async fn handle_message_socket(
mut socket: WebSocket,
who: SocketAddr,
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
) {
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
loop {
select! {
// Receive broadcast messages and send to client (any room)
msg = receiver.recv() => {
if let Ok(msg) = msg {
if let Ok(json) = serde_json::to_string(&msg) {
if socket.send(WsMessage::Text(json.into())).await.is_err() {
tracing::error!("Failed to send message to {who}, closing connection");
break;
}
}
} else {
break;
}
}
// Send Ping
_ = ping_interval.tick() => {
if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
tracing::error!("Failed to send ping to {who}, closing connection");
break;
}
// tracing::debug!("Ping sent to {who}");
}
// Get incoming messages from client
client_msg = socket.recv() => {
if let Some(Ok(msg)) = client_msg {
match msg {
// WsMessage::Pong(_) => {
// tracing::debug!("Received Pong from {who}");
// }
// WsMessage::Ping(_) => {
// tracing::info!("Received Ping from client");
// }
// WsMessage::Text(_) => {}
WsMessage::Close(_) => {
tracing::debug!("Client disconnected");
break;
}
_ => {}
}
} else {
tracing::debug!("Client {who} abruptly disconnected");
break;
}
}
}
}
}

View File

@@ -2,4 +2,5 @@ pub mod friends;
pub mod messages;
pub mod rooms;
pub mod users;
pub mod voice;
pub mod ws;

126
src/routes/voice.rs Normal file
View File

@@ -0,0 +1,126 @@
use axum::{
Extension, Router,
extract::{
ConnectInfo, Path, Query, WebSocketUpgrade,
ws::{Message, WebSocket},
},
http::StatusCode,
response::IntoResponse,
routing::get,
};
use bytes::{BufMut, Bytes, BytesMut};
use sqlx::PgPool;
use std::{net::SocketAddr, time::Duration};
use tokio::select;
use uuid::Uuid;
use crate::{
auth::verify_jwt_string,
db::{room_id_from_uuid, user_id_from_uuid},
realtime::RealTimeVoices,
routes::rooms::is_member,
routes::ws::WsAuthQuery,
};
pub fn routes() -> Router {
Router::new().route("/ws/voice/{room_uuid}", get(voice_ws_handler))
}
async fn voice_ws_handler(
ws: WebSocketUpgrade,
Path(room_uuid): Path<Uuid>,
Query(query): Query<WsAuthQuery>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Extension(voice_manager): Extension<RealTimeVoices>,
Extension(db): Extension<PgPool>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let claims = verify_jwt_string(&query.token)?;
let user_uuid = claims.sub;
let result = sqlx::query(
r#"
DELETE FROM ws_token_
WHERE token = $1
AND expires_at > now()
"#,
)
.bind(&query.token)
.execute(&db)
.await
.map_err(|e| {
tracing::error!("Failed to get WS token from DB: {e}");
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
})?;
if result.rows_affected() == 0 {
return Err((StatusCode::UNAUTHORIZED, "Invalid or expired token".into()));
}
let user_id = user_id_from_uuid(&db, user_uuid).await?;
let room_id = room_id_from_uuid(&db, room_uuid).await?;
if !is_member(user_id, room_id, &db).await {
return Err((StatusCode::FORBIDDEN, "Not a member of this room".into()));
}
tracing::info!("User {} joining voice in room {}", user_uuid, room_uuid);
let tx = voice_manager.get_or_create_room(room_uuid);
let rx = tx.subscribe();
Ok(ws.on_upgrade(move |socket| handle_voice_socket(socket, addr, user_uuid, tx, rx)))
}
async fn handle_voice_socket(
mut socket: WebSocket,
who: SocketAddr,
my_uuid: Uuid,
tx: tokio::sync::broadcast::Sender<(Uuid, Bytes)>,
mut rx: tokio::sync::broadcast::Receiver<(Uuid, Bytes)>,
) {
let mut ping_interval = tokio::time::interval(Duration::from_secs(15));
loop {
select! {
// Receive audio from other users and send to client
voice_packet = rx.recv() => {
if let Ok((speaker_uuid, audio_data)) = voice_packet {
if speaker_uuid != my_uuid {
let mut msg = BytesMut::with_capacity(16 + audio_data.len());
msg.put(speaker_uuid.as_bytes().as_slice());
msg.put(audio_data);
if socket.send(Message::Binary(msg.freeze().into())).await.is_err() {
break;
}
}
}
}
// Receive audio from alient and broadcast to room
client_msg = socket.recv() => {
if let Some(Ok(msg)) = client_msg {
match msg {
Message::Binary(data) => {
let _ = tx.send((my_uuid, Bytes::from(data)));
}
Message::Close(_) => {
tracing::debug!("Voice client {} disconnected", who);
break;
}
_ => {}
}
} else {
break;
}
}
// Keepalive
_ = ping_interval.tick() => {
if socket.send(Message::Ping(vec![].into())).await.is_err() {
break;
}
}
}
}
}

View File

@@ -1,17 +1,10 @@
use axum::Json;
use axum::extract::ws::{Message as WsMessage, WebSocket};
use axum::extract::{ConnectInfo, Query};
use axum::http::HeaderMap;
use axum::routing::get;
use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse};
use axum_extra::{TypedHeader, headers};
use axum::{Extension, http::StatusCode};
use serde::Deserialize;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::select;
use crate::auth::{create_jwt, verify_jwt, verify_jwt_string};
use crate::realtime::Realtime;
use crate::auth::{create_jwt, verify_jwt};
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
pub struct WsAuthQuery {
@@ -19,9 +12,7 @@ pub struct WsAuthQuery {
}
pub fn routes() -> axum::Router {
axum::Router::new()
.route("/ws/messages/issue-token", get(issue_ws_token))
.route("/ws/messages", get(ws_handler))
axum::Router::new().route("/ws/issue-token", get(issue_ws_token))
}
pub async fn issue_ws_token(
@@ -52,106 +43,3 @@ pub async fn issue_ws_token(
Ok((StatusCode::CREATED, Json(WsAuthQuery { token })))
}
async fn ws_handler(
ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Query(query): Query<WsAuthQuery>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Extension(realtime): Extension<Realtime>,
Extension(db): Extension<sqlx::PgPool>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
// tracing::info!("recieved ws handshake: {}", room_uuid);
let claims = verify_jwt_string(&query.token)?;
let user_uuid = claims.sub;
let result = sqlx::query(
r#"
delete from ws_token_
where token = $1
and expires_at > now()
"#,
)
.bind(query.token)
.execute(&db)
.await
.map_err(|e| {
tracing::error!("Failed to get WS token from DB: {e}");
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
})?;
if result.rows_affected() == 0 {
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
}
let receiver = realtime.get_sender(user_uuid).subscribe();
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
user_agent.to_string()
} else {
String::from("Unknown browser")
};
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver)))
}
async fn handle_socket(
mut socket: WebSocket,
who: SocketAddr,
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
) {
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
loop {
select! {
// Receive broadcast messages and send to client (any room)
msg = receiver.recv() => {
if let Ok(msg) = msg {
if let Ok(json) = serde_json::to_string(&msg) {
if socket.send(WsMessage::Text(json.into())).await.is_err() {
tracing::error!("Failed to send message to {who}, closing connection");
break;
}
}
} else {
break;
}
}
// Send Ping
_ = ping_interval.tick() => {
if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
tracing::error!("Failed to send ping to {who}, closing connection");
break;
}
// tracing::debug!("Ping sent to {who}");
}
// Get incoming messages from client
client_msg = socket.recv() => {
if let Some(Ok(msg)) = client_msg {
match msg {
// WsMessage::Pong(_) => {
// tracing::debug!("Received Pong from {who}");
// }
// WsMessage::Ping(_) => {
// tracing::info!("Received Ping from client");
// }
// WsMessage::Text(_) => {}
WsMessage::Close(_) => {
tracing::debug!("Client disconnected");
break;
}
_ => {}
}
} else {
tracing::debug!("Client {who} abruptly disconnected");
break;
}
}
}
}
}