added voice chat websocket route
This commit is contained in:
@@ -1,16 +1,28 @@
|
||||
use std::{net::SocketAddr, time::Duration};
|
||||
|
||||
use axum::{
|
||||
Extension, Json, Router,
|
||||
extract::{Path, Query},
|
||||
extract::{
|
||||
ConnectInfo, Path, Query, WebSocketUpgrade,
|
||||
ws::{Message as WsMessage, WebSocket},
|
||||
},
|
||||
http::{HeaderMap, StatusCode},
|
||||
response::IntoResponse,
|
||||
routing::{get, post},
|
||||
};
|
||||
use axum_extra::{TypedHeader, headers};
|
||||
use sqlx::PgPool;
|
||||
use tokio::select;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{auth::verify_jwt, db::room_id_from_uuid, routes::rooms::is_member};
|
||||
use crate::{
|
||||
auth::{verify_jwt, verify_jwt_string},
|
||||
db::room_id_from_uuid,
|
||||
routes::{rooms::is_member, ws::WsAuthQuery},
|
||||
};
|
||||
use crate::{
|
||||
db::{user_id_from_uuid, username_from_uuid},
|
||||
realtime::Realtime,
|
||||
realtime::RealtimeMessages,
|
||||
};
|
||||
|
||||
#[derive(sqlx::FromRow, serde::Serialize, Debug)]
|
||||
@@ -51,6 +63,7 @@ pub fn routes() -> Router {
|
||||
Router::new()
|
||||
.route("/messages/{room_uuid}", get(list_messages))
|
||||
.route("/messages/{room_uuid}", post(create_message))
|
||||
.route("/ws/messages", get(message_ws_handler))
|
||||
}
|
||||
|
||||
/// Also resets `last_read_at`
|
||||
@@ -155,7 +168,7 @@ async fn list_messages(
|
||||
async fn create_message(
|
||||
Path(room_uuid): Path<Uuid>,
|
||||
Extension(db): Extension<PgPool>,
|
||||
Extension(realtime): Extension<Realtime>,
|
||||
Extension(realtime): Extension<RealtimeMessages>,
|
||||
headers: HeaderMap,
|
||||
Json(payload): Json<NewMessagePayload>,
|
||||
) -> Result<(StatusCode, Json<Message>), (StatusCode, String)> {
|
||||
@@ -223,3 +236,106 @@ async fn create_message(
|
||||
|
||||
Ok((StatusCode::CREATED, Json(message)))
|
||||
}
|
||||
|
||||
async fn message_ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
Query(query): Query<WsAuthQuery>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
Extension(realtime): Extension<RealtimeMessages>,
|
||||
Extension(db): Extension<sqlx::PgPool>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
// tracing::info!("recieved ws handshake: {}", room_uuid);
|
||||
|
||||
let claims = verify_jwt_string(&query.token)?;
|
||||
let user_uuid = claims.sub;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
delete from ws_token_
|
||||
where token = $1
|
||||
and expires_at > now()
|
||||
"#,
|
||||
)
|
||||
.bind(query.token)
|
||||
.execute(&db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to get WS token from DB: {e}");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
|
||||
})?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
|
||||
}
|
||||
|
||||
let receiver = realtime.get_sender(user_uuid).subscribe();
|
||||
|
||||
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
|
||||
user_agent.to_string()
|
||||
} else {
|
||||
String::from("Unknown browser")
|
||||
};
|
||||
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| handle_message_socket(socket, addr, receiver)))
|
||||
}
|
||||
|
||||
async fn handle_message_socket(
|
||||
mut socket: WebSocket,
|
||||
who: SocketAddr,
|
||||
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
|
||||
) {
|
||||
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
|
||||
|
||||
loop {
|
||||
select! {
|
||||
// Receive broadcast messages and send to client (any room)
|
||||
msg = receiver.recv() => {
|
||||
if let Ok(msg) = msg {
|
||||
if let Ok(json) = serde_json::to_string(&msg) {
|
||||
if socket.send(WsMessage::Text(json.into())).await.is_err() {
|
||||
tracing::error!("Failed to send message to {who}, closing connection");
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Send Ping
|
||||
_ = ping_interval.tick() => {
|
||||
if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
|
||||
tracing::error!("Failed to send ping to {who}, closing connection");
|
||||
break;
|
||||
}
|
||||
|
||||
// tracing::debug!("Ping sent to {who}");
|
||||
}
|
||||
|
||||
// Get incoming messages from client
|
||||
client_msg = socket.recv() => {
|
||||
if let Some(Ok(msg)) = client_msg {
|
||||
match msg {
|
||||
// WsMessage::Pong(_) => {
|
||||
// tracing::debug!("Received Pong from {who}");
|
||||
// }
|
||||
// WsMessage::Ping(_) => {
|
||||
// tracing::info!("Received Ping from client");
|
||||
// }
|
||||
// WsMessage::Text(_) => {}
|
||||
WsMessage::Close(_) => {
|
||||
tracing::debug!("Client disconnected");
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("Client {who} abruptly disconnected");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -2,4 +2,5 @@ pub mod friends;
|
||||
pub mod messages;
|
||||
pub mod rooms;
|
||||
pub mod users;
|
||||
pub mod voice;
|
||||
pub mod ws;
|
||||
|
||||
126
src/routes/voice.rs
Normal file
126
src/routes/voice.rs
Normal file
@@ -0,0 +1,126 @@
|
||||
use axum::{
|
||||
Extension, Router,
|
||||
extract::{
|
||||
ConnectInfo, Path, Query, WebSocketUpgrade,
|
||||
ws::{Message, WebSocket},
|
||||
},
|
||||
http::StatusCode,
|
||||
response::IntoResponse,
|
||||
routing::get,
|
||||
};
|
||||
use bytes::{BufMut, Bytes, BytesMut};
|
||||
use sqlx::PgPool;
|
||||
use std::{net::SocketAddr, time::Duration};
|
||||
use tokio::select;
|
||||
use uuid::Uuid;
|
||||
|
||||
use crate::{
|
||||
auth::verify_jwt_string,
|
||||
db::{room_id_from_uuid, user_id_from_uuid},
|
||||
realtime::RealTimeVoices,
|
||||
routes::rooms::is_member,
|
||||
routes::ws::WsAuthQuery,
|
||||
};
|
||||
|
||||
pub fn routes() -> Router {
|
||||
Router::new().route("/ws/voice/{room_uuid}", get(voice_ws_handler))
|
||||
}
|
||||
|
||||
async fn voice_ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
Path(room_uuid): Path<Uuid>,
|
||||
Query(query): Query<WsAuthQuery>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
Extension(voice_manager): Extension<RealTimeVoices>,
|
||||
Extension(db): Extension<PgPool>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
let claims = verify_jwt_string(&query.token)?;
|
||||
let user_uuid = claims.sub;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
DELETE FROM ws_token_
|
||||
WHERE token = $1
|
||||
AND expires_at > now()
|
||||
"#,
|
||||
)
|
||||
.bind(&query.token)
|
||||
.execute(&db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to get WS token from DB: {e}");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
|
||||
})?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err((StatusCode::UNAUTHORIZED, "Invalid or expired token".into()));
|
||||
}
|
||||
|
||||
let user_id = user_id_from_uuid(&db, user_uuid).await?;
|
||||
let room_id = room_id_from_uuid(&db, room_uuid).await?;
|
||||
|
||||
if !is_member(user_id, room_id, &db).await {
|
||||
return Err((StatusCode::FORBIDDEN, "Not a member of this room".into()));
|
||||
}
|
||||
|
||||
tracing::info!("User {} joining voice in room {}", user_uuid, room_uuid);
|
||||
|
||||
let tx = voice_manager.get_or_create_room(room_uuid);
|
||||
let rx = tx.subscribe();
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| handle_voice_socket(socket, addr, user_uuid, tx, rx)))
|
||||
}
|
||||
|
||||
async fn handle_voice_socket(
|
||||
mut socket: WebSocket,
|
||||
who: SocketAddr,
|
||||
my_uuid: Uuid,
|
||||
tx: tokio::sync::broadcast::Sender<(Uuid, Bytes)>,
|
||||
mut rx: tokio::sync::broadcast::Receiver<(Uuid, Bytes)>,
|
||||
) {
|
||||
let mut ping_interval = tokio::time::interval(Duration::from_secs(15));
|
||||
|
||||
loop {
|
||||
select! {
|
||||
// Receive audio from other users and send to client
|
||||
voice_packet = rx.recv() => {
|
||||
if let Ok((speaker_uuid, audio_data)) = voice_packet {
|
||||
if speaker_uuid != my_uuid {
|
||||
let mut msg = BytesMut::with_capacity(16 + audio_data.len());
|
||||
msg.put(speaker_uuid.as_bytes().as_slice());
|
||||
msg.put(audio_data);
|
||||
|
||||
if socket.send(Message::Binary(msg.freeze().into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Receive audio from alient and broadcast to room
|
||||
client_msg = socket.recv() => {
|
||||
if let Some(Ok(msg)) = client_msg {
|
||||
match msg {
|
||||
Message::Binary(data) => {
|
||||
let _ = tx.send((my_uuid, Bytes::from(data)));
|
||||
}
|
||||
Message::Close(_) => {
|
||||
tracing::debug!("Voice client {} disconnected", who);
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Keepalive
|
||||
_ = ping_interval.tick() => {
|
||||
if socket.send(Message::Ping(vec![].into())).await.is_err() {
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
118
src/routes/ws.rs
118
src/routes/ws.rs
@@ -1,17 +1,10 @@
|
||||
use axum::Json;
|
||||
use axum::extract::ws::{Message as WsMessage, WebSocket};
|
||||
use axum::extract::{ConnectInfo, Query};
|
||||
use axum::http::HeaderMap;
|
||||
use axum::routing::get;
|
||||
use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse};
|
||||
use axum_extra::{TypedHeader, headers};
|
||||
use axum::{Extension, http::StatusCode};
|
||||
use serde::Deserialize;
|
||||
use std::net::SocketAddr;
|
||||
use std::time::Duration;
|
||||
use tokio::select;
|
||||
|
||||
use crate::auth::{create_jwt, verify_jwt, verify_jwt_string};
|
||||
use crate::realtime::Realtime;
|
||||
use crate::auth::{create_jwt, verify_jwt};
|
||||
|
||||
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
|
||||
pub struct WsAuthQuery {
|
||||
@@ -19,9 +12,7 @@ pub struct WsAuthQuery {
|
||||
}
|
||||
|
||||
pub fn routes() -> axum::Router {
|
||||
axum::Router::new()
|
||||
.route("/ws/messages/issue-token", get(issue_ws_token))
|
||||
.route("/ws/messages", get(ws_handler))
|
||||
axum::Router::new().route("/ws/issue-token", get(issue_ws_token))
|
||||
}
|
||||
|
||||
pub async fn issue_ws_token(
|
||||
@@ -52,106 +43,3 @@ pub async fn issue_ws_token(
|
||||
|
||||
Ok((StatusCode::CREATED, Json(WsAuthQuery { token })))
|
||||
}
|
||||
|
||||
async fn ws_handler(
|
||||
ws: WebSocketUpgrade,
|
||||
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
||||
Query(query): Query<WsAuthQuery>,
|
||||
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
||||
Extension(realtime): Extension<Realtime>,
|
||||
Extension(db): Extension<sqlx::PgPool>,
|
||||
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
||||
// tracing::info!("recieved ws handshake: {}", room_uuid);
|
||||
|
||||
let claims = verify_jwt_string(&query.token)?;
|
||||
let user_uuid = claims.sub;
|
||||
|
||||
let result = sqlx::query(
|
||||
r#"
|
||||
delete from ws_token_
|
||||
where token = $1
|
||||
and expires_at > now()
|
||||
"#,
|
||||
)
|
||||
.bind(query.token)
|
||||
.execute(&db)
|
||||
.await
|
||||
.map_err(|e| {
|
||||
tracing::error!("Failed to get WS token from DB: {e}");
|
||||
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
|
||||
})?;
|
||||
|
||||
if result.rows_affected() == 0 {
|
||||
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
|
||||
}
|
||||
|
||||
let receiver = realtime.get_sender(user_uuid).subscribe();
|
||||
|
||||
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
|
||||
user_agent.to_string()
|
||||
} else {
|
||||
String::from("Unknown browser")
|
||||
};
|
||||
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
|
||||
|
||||
Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver)))
|
||||
}
|
||||
|
||||
async fn handle_socket(
|
||||
mut socket: WebSocket,
|
||||
who: SocketAddr,
|
||||
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
|
||||
) {
|
||||
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
|
||||
|
||||
loop {
|
||||
select! {
|
||||
// Receive broadcast messages and send to client (any room)
|
||||
msg = receiver.recv() => {
|
||||
if let Ok(msg) = msg {
|
||||
if let Ok(json) = serde_json::to_string(&msg) {
|
||||
if socket.send(WsMessage::Text(json.into())).await.is_err() {
|
||||
tracing::error!("Failed to send message to {who}, closing connection");
|
||||
break;
|
||||
}
|
||||
}
|
||||
} else {
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Send Ping
|
||||
_ = ping_interval.tick() => {
|
||||
if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
|
||||
tracing::error!("Failed to send ping to {who}, closing connection");
|
||||
break;
|
||||
}
|
||||
|
||||
// tracing::debug!("Ping sent to {who}");
|
||||
}
|
||||
|
||||
// Get incoming messages from client
|
||||
client_msg = socket.recv() => {
|
||||
if let Some(Ok(msg)) = client_msg {
|
||||
match msg {
|
||||
// WsMessage::Pong(_) => {
|
||||
// tracing::debug!("Received Pong from {who}");
|
||||
// }
|
||||
// WsMessage::Ping(_) => {
|
||||
// tracing::info!("Received Ping from client");
|
||||
// }
|
||||
// WsMessage::Text(_) => {}
|
||||
WsMessage::Close(_) => {
|
||||
tracing::debug!("Client disconnected");
|
||||
break;
|
||||
}
|
||||
_ => {}
|
||||
}
|
||||
} else {
|
||||
tracing::debug!("Client {who} abruptly disconnected");
|
||||
break;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user