Files
frangipane-backend/src/routes/ws.rs

158 lines
4.9 KiB
Rust

use axum::Json;
use axum::extract::ws::{Message as WsMessage, WebSocket};
use axum::extract::{ConnectInfo, Query};
use axum::http::HeaderMap;
use axum::routing::get;
use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse};
use axum_extra::{TypedHeader, headers};
use serde::Deserialize;
use std::net::SocketAddr;
use std::time::Duration;
use tokio::select;
use crate::auth::{create_jwt, verify_jwt, verify_jwt_string};
use crate::realtime::Realtime;
/// One-time WebSocket authentication token.
///
/// `issue_ws_token` returns it as the JSON body, and `ws_handler` reads it from the
/// `?token=` query string of the upgrade request.
#[derive(Deserialize, serde::Serialize, sqlx::FromRow)]
pub struct WsAuthQuery {
    /// JWT string that is accepted only while its `ws_token_` row is unexpired and
    /// has not yet been consumed by a WebSocket upgrade.
    pub token: String,
}
/// Builds the router for the realtime message socket.
///
/// - `GET /ws/messages/issue-token` mints a short-lived, single-use socket token.
/// - `GET /ws/messages` upgrades to a WebSocket after validating that token.
pub fn routes() -> axum::Router {
    let issue_token = get(issue_ws_token);
    let open_socket = get(ws_handler);
    axum::Router::new()
        .route("/ws/messages/issue-token", issue_token)
        .route("/ws/messages", open_socket)
}
pub async fn issue_ws_token(
Extension(db): Extension<sqlx::PgPool>,
headers: HeaderMap,
) -> Result<(StatusCode, Json<WsAuthQuery>), (StatusCode, String)> {
let claims = verify_jwt(headers)?;
tracing::debug!("Recieved token issue request from user {}", claims.sub);
let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
sqlx::query(
r#"
insert into ws_token_ (token, expires_at)
values ($1, now() + interval '30 seconds')
"#,
)
.bind(&token)
.execute(&db)
.await
.map_err(|_| {
(
StatusCode::INTERNAL_SERVER_ERROR,
format!("failed to provide ws token"),
)
})?;
Ok((StatusCode::CREATED, Json(WsAuthQuery { token })))
}
/// Upgrades the request to a WebSocket that streams the caller's realtime messages.
///
/// Authentication uses the one-time `token` query parameter from `issue_ws_token`.
/// The token must verify as a JWT, and an unexpired row for it must exist in `ws_token_`.
/// That row is deleted here, so a token can open at most one socket. The socket is then
/// subscribed to the user's broadcast channel from `Realtime::get_sender`.
///
/// # Errors
/// - The error from `verify_jwt_string` if the token does not verify.
/// - `500 "DB error"` if the `ws_token_` delete fails.
/// - `401 "Wrong token"` if no unexpired row matched (unknown, expired or already used).
async fn ws_handler(
    ws: WebSocketUpgrade,
    user_agent: Option<TypedHeader<headers::UserAgent>>,
    Query(query): Query<WsAuthQuery>,
    ConnectInfo(addr): ConnectInfo<SocketAddr>,
    Extension(realtime): Extension<Realtime>,
    Extension(db): Extension<sqlx::PgPool>,
) -> Result<impl IntoResponse, (StatusCode, String)> {
    let verified = verify_jwt_string(&query.token)?;
    let user_uuid = verified.sub;

    // Consume the token in a single statement: the delete only matches an unexpired
    // row, and the affected-row count tells us whether the token was valid.
    let consumed = sqlx::query(
        r#"
delete from ws_token_
where token = $1
and expires_at > now()
"#,
    )
    .bind(query.token)
    .execute(&db)
    .await
    .map_err(|e| {
        tracing::error!("Failed to get WS token from DB: {e}");
        (StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
    })?
    .rows_affected();

    if consumed == 0 {
        return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
    }

    let updates = realtime.get_sender(user_uuid).subscribe();
    let browser = user_agent
        .map(|TypedHeader(ua)| ua.to_string())
        .unwrap_or_else(|| String::from("Unknown browser"));
    tracing::debug!("`{browser}` {user_uuid} at {addr} connected.");

    Ok(ws.on_upgrade(move |upgraded| handle_socket(upgraded, addr, updates)))
}
/// Drives one authenticated WebSocket connection until either side goes away.
///
/// Three event sources are multiplexed with `select!`:
/// - Messages from the user's broadcast channel are serialized to JSON and sent as text
///   frames. A failed send closes the connection.
/// - A `Ping` frame is sent on every tick of a 30-second interval. A failed send closes
///   the connection.
/// - Frames from the client are read. A `Close` frame, a stream error or end of stream
///   ends the loop; all other frames are ignored.
///
/// If this receiver falls behind the broadcast buffer (`RecvError::Lagged`), the
/// overwritten messages are skipped and streaming continues. The loop only stops on
/// the broadcast side when the channel is closed.
///
/// Pong frames are not tracked. An unresponsive client is therefore only detected when
/// a send fails or its stream ends.
async fn handle_socket(
    mut socket: WebSocket,
    who: SocketAddr,
    mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
) {
    let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
    loop {
        select! {
            // Broadcast messages for this user -> client (any room).
            event = receiver.recv() => {
                match event {
                    Ok(msg) => match serde_json::to_string(&msg) {
                        Ok(json) => {
                            if socket.send(WsMessage::Text(json.into())).await.is_err() {
                                tracing::error!("Failed to send message to {who}, closing connection");
                                break;
                            }
                        }
                        // A message that cannot be serialized is a server-side bug.
                        // Surface it in the logs, but keep the connection alive.
                        Err(e) => {
                            tracing::error!("Failed to serialize message for {who}: {e}");
                        }
                    },
                    // The receiver fell behind and the channel overwrote `skipped`
                    // messages. This is recoverable: the next recv() yields the oldest
                    // message still buffered, so keep streaming instead of disconnecting.
                    Err(tokio::sync::broadcast::error::RecvError::Lagged(skipped)) => {
                        tracing::warn!("Client {who} lagged behind, {skipped} messages skipped");
                    }
                    // Every sender has been dropped; nothing more can arrive.
                    Err(tokio::sync::broadcast::error::RecvError::Closed) => {
                        tracing::debug!("Broadcast channel for {who} closed");
                        break;
                    }
                }
            }
            // Keep-alive ping.
            _ = ping_interval.tick() => {
                if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
                    tracing::error!("Failed to send ping to {who}, closing connection");
                    break;
                }
            }
            // Incoming frames from the client.
            client_msg = socket.recv() => {
                match client_msg {
                    Some(Ok(WsMessage::Close(_))) => {
                        tracing::debug!("Client disconnected");
                        break;
                    }
                    // Text, binary, ping and pong frames from the client are ignored.
                    Some(Ok(_)) => {}
                    None | Some(Err(_)) => {
                        tracing::debug!("Client {who} abruptly disconnected");
                        break;
                    }
                }
            }
        }
    }
}