added voice chat websocket route

This commit is contained in:
2026-01-25 17:13:05 +01:00
parent c81df769de
commit 30dc7475c2
10 changed files with 302 additions and 129 deletions

View File

@@ -1,16 +1,28 @@
use std::{net::SocketAddr, time::Duration};
use axum::{
Extension, Json, Router,
extract::{Path, Query},
extract::{
ConnectInfo, Path, Query, WebSocketUpgrade,
ws::{Message as WsMessage, WebSocket},
},
http::{HeaderMap, StatusCode},
response::IntoResponse,
routing::{get, post},
};
use axum_extra::{TypedHeader, headers};
use sqlx::PgPool;
use tokio::select;
use uuid::Uuid;
use crate::{auth::verify_jwt, db::room_id_from_uuid, routes::rooms::is_member};
use crate::{
auth::{verify_jwt, verify_jwt_string},
db::room_id_from_uuid,
routes::{rooms::is_member, ws::WsAuthQuery},
};
use crate::{
db::{user_id_from_uuid, username_from_uuid},
realtime::Realtime,
realtime::RealtimeMessages,
};
#[derive(sqlx::FromRow, serde::Serialize, Debug)]
@@ -51,6 +63,7 @@ pub fn routes() -> Router {
Router::new()
.route("/messages/{room_uuid}", get(list_messages))
.route("/messages/{room_uuid}", post(create_message))
.route("/ws/messages", get(message_ws_handler))
}
/// Also resets `last_read_at`
@@ -155,7 +168,7 @@ async fn list_messages(
async fn create_message(
Path(room_uuid): Path<Uuid>,
Extension(db): Extension<PgPool>,
Extension(realtime): Extension<Realtime>,
Extension(realtime): Extension<RealtimeMessages>,
headers: HeaderMap,
Json(payload): Json<NewMessagePayload>,
) -> Result<(StatusCode, Json<Message>), (StatusCode, String)> {
@@ -223,3 +236,106 @@ async fn create_message(
Ok((StatusCode::CREATED, Json(message)))
}
async fn message_ws_handler(
ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Query(query): Query<WsAuthQuery>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Extension(realtime): Extension<RealtimeMessages>,
Extension(db): Extension<sqlx::PgPool>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
// tracing::info!("recieved ws handshake: {}", room_uuid);
let claims = verify_jwt_string(&query.token)?;
let user_uuid = claims.sub;
let result = sqlx::query(
r#"
delete from ws_token_
where token = $1
and expires_at > now()
"#,
)
.bind(query.token)
.execute(&db)
.await
.map_err(|e| {
tracing::error!("Failed to get WS token from DB: {e}");
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
})?;
if result.rows_affected() == 0 {
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
}
let receiver = realtime.get_sender(user_uuid).subscribe();
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
user_agent.to_string()
} else {
String::from("Unknown browser")
};
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
Ok(ws.on_upgrade(move |socket| handle_message_socket(socket, addr, receiver)))
}
async fn handle_message_socket(
mut socket: WebSocket,
who: SocketAddr,
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
) {
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
loop {
select! {
// Receive broadcast messages and send to client (any room)
msg = receiver.recv() => {
if let Ok(msg) = msg {
if let Ok(json) = serde_json::to_string(&msg) {
if socket.send(WsMessage::Text(json.into())).await.is_err() {
tracing::error!("Failed to send message to {who}, closing connection");
break;
}
}
} else {
break;
}
}
// Send Ping
_ = ping_interval.tick() => {
if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
tracing::error!("Failed to send ping to {who}, closing connection");
break;
}
// tracing::debug!("Ping sent to {who}");
}
// Get incoming messages from client
client_msg = socket.recv() => {
if let Some(Ok(msg)) = client_msg {
match msg {
// WsMessage::Pong(_) => {
// tracing::debug!("Received Pong from {who}");
// }
// WsMessage::Ping(_) => {
// tracing::info!("Received Ping from client");
// }
// WsMessage::Text(_) => {}
WsMessage::Close(_) => {
tracing::debug!("Client disconnected");
break;
}
_ => {}
}
} else {
tracing::debug!("Client {who} abruptly disconnected");
break;
}
}
}
}
}

View File

@@ -2,4 +2,5 @@ pub mod friends;
pub mod messages;
pub mod rooms;
pub mod users;
pub mod voice;
pub mod ws;

126
src/routes/voice.rs Normal file
View File

@@ -0,0 +1,126 @@
use axum::{
Extension, Router,
extract::{
ConnectInfo, Path, Query, WebSocketUpgrade,
ws::{Message, WebSocket},
},
http::StatusCode,
response::IntoResponse,
routing::get,
};
use bytes::{BufMut, Bytes, BytesMut};
use sqlx::PgPool;
use std::{net::SocketAddr, time::Duration};
use tokio::select;
use uuid::Uuid;
use crate::{
auth::verify_jwt_string,
db::{room_id_from_uuid, user_id_from_uuid},
realtime::RealTimeVoices,
routes::rooms::is_member,
routes::ws::WsAuthQuery,
};
pub fn routes() -> Router {
Router::new().route("/ws/voice/{room_uuid}", get(voice_ws_handler))
}
async fn voice_ws_handler(
ws: WebSocketUpgrade,
Path(room_uuid): Path<Uuid>,
Query(query): Query<WsAuthQuery>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Extension(voice_manager): Extension<RealTimeVoices>,
Extension(db): Extension<PgPool>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
let claims = verify_jwt_string(&query.token)?;
let user_uuid = claims.sub;
let result = sqlx::query(
r#"
DELETE FROM ws_token_
WHERE token = $1
AND expires_at > now()
"#,
)
.bind(&query.token)
.execute(&db)
.await
.map_err(|e| {
tracing::error!("Failed to get WS token from DB: {e}");
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
})?;
if result.rows_affected() == 0 {
return Err((StatusCode::UNAUTHORIZED, "Invalid or expired token".into()));
}
let user_id = user_id_from_uuid(&db, user_uuid).await?;
let room_id = room_id_from_uuid(&db, room_uuid).await?;
if !is_member(user_id, room_id, &db).await {
return Err((StatusCode::FORBIDDEN, "Not a member of this room".into()));
}
tracing::info!("User {} joining voice in room {}", user_uuid, room_uuid);
let tx = voice_manager.get_or_create_room(room_uuid);
let rx = tx.subscribe();
Ok(ws.on_upgrade(move |socket| handle_voice_socket(socket, addr, user_uuid, tx, rx)))
}
async fn handle_voice_socket(
mut socket: WebSocket,
who: SocketAddr,
my_uuid: Uuid,
tx: tokio::sync::broadcast::Sender<(Uuid, Bytes)>,
mut rx: tokio::sync::broadcast::Receiver<(Uuid, Bytes)>,
) {
let mut ping_interval = tokio::time::interval(Duration::from_secs(15));
loop {
select! {
// Receive audio from other users and send to client
voice_packet = rx.recv() => {
if let Ok((speaker_uuid, audio_data)) = voice_packet {
if speaker_uuid != my_uuid {
let mut msg = BytesMut::with_capacity(16 + audio_data.len());
msg.put(speaker_uuid.as_bytes().as_slice());
msg.put(audio_data);
if socket.send(Message::Binary(msg.freeze().into())).await.is_err() {
break;
}
}
}
}
// Receive audio from alient and broadcast to room
client_msg = socket.recv() => {
if let Some(Ok(msg)) = client_msg {
match msg {
Message::Binary(data) => {
let _ = tx.send((my_uuid, Bytes::from(data)));
}
Message::Close(_) => {
tracing::debug!("Voice client {} disconnected", who);
break;
}
_ => {}
}
} else {
break;
}
}
// Keepalive
_ = ping_interval.tick() => {
if socket.send(Message::Ping(vec![].into())).await.is_err() {
break;
}
}
}
}
}

View File

@@ -1,17 +1,10 @@
use axum::Json;
use axum::extract::ws::{Message as WsMessage, WebSocket};
use axum::extract::{ConnectInfo, Query};
use axum::http::HeaderMap;
use axum::routing::get;
use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse};
use axum_extra::{TypedHeader, headers};
use axum::{Extension, http::StatusCode};
use serde::Deserialize;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::select;
use crate::auth::{create_jwt, verify_jwt, verify_jwt_string};
use crate::realtime::Realtime;
use crate::auth::{create_jwt, verify_jwt};
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
pub struct WsAuthQuery {
@@ -19,9 +12,7 @@ pub struct WsAuthQuery {
}
pub fn routes() -> axum::Router {
axum::Router::new()
.route("/ws/messages/issue-token", get(issue_ws_token))
.route("/ws/messages", get(ws_handler))
axum::Router::new().route("/ws/issue-token", get(issue_ws_token))
}
pub async fn issue_ws_token(
@@ -52,106 +43,3 @@ pub async fn issue_ws_token(
Ok((StatusCode::CREATED, Json(WsAuthQuery { token })))
}
async fn ws_handler(
ws: WebSocketUpgrade,
user_agent: Option<TypedHeader<headers::UserAgent>>,
Query(query): Query<WsAuthQuery>,
ConnectInfo(addr): ConnectInfo<SocketAddr>,
Extension(realtime): Extension<Realtime>,
Extension(db): Extension<sqlx::PgPool>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
// tracing::info!("recieved ws handshake: {}", room_uuid);
let claims = verify_jwt_string(&query.token)?;
let user_uuid = claims.sub;
let result = sqlx::query(
r#"
delete from ws_token_
where token = $1
and expires_at > now()
"#,
)
.bind(query.token)
.execute(&db)
.await
.map_err(|e| {
tracing::error!("Failed to get WS token from DB: {e}");
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
})?;
if result.rows_affected() == 0 {
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
}
let receiver = realtime.get_sender(user_uuid).subscribe();
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
user_agent.to_string()
} else {
String::from("Unknown browser")
};
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver)))
}
async fn handle_socket(
mut socket: WebSocket,
who: SocketAddr,
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
) {
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
loop {
select! {
// Receive broadcast messages and send to client (any room)
msg = receiver.recv() => {
if let Ok(msg) = msg {
if let Ok(json) = serde_json::to_string(&msg) {
if socket.send(WsMessage::Text(json.into())).await.is_err() {
tracing::error!("Failed to send message to {who}, closing connection");
break;
}
}
} else {
break;
}
}
// Send Ping
_ = ping_interval.tick() => {
if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
tracing::error!("Failed to send ping to {who}, closing connection");
break;
}
// tracing::debug!("Ping sent to {who}");
}
// Get incoming messages from client
client_msg = socket.recv() => {
if let Some(Ok(msg)) = client_msg {
match msg {
// WsMessage::Pong(_) => {
// tracing::debug!("Received Pong from {who}");
// }
// WsMessage::Ping(_) => {
// tracing::info!("Received Ping from client");
// }
// WsMessage::Text(_) => {}
WsMessage::Close(_) => {
tracing::debug!("Client disconnected");
break;
}
_ => {}
}
} else {
tracing::debug!("Client {who} abruptly disconnected");
break;
}
}
}
}
}