158 lines
4.9 KiB
Rust
158 lines
4.9 KiB
Rust
use axum::Json;
|
|
use axum::extract::ws::{Message as WsMessage, WebSocket};
|
|
use axum::extract::{ConnectInfo, Query};
|
|
use axum::http::HeaderMap;
|
|
use axum::routing::get;
|
|
use axum::{Extension, extract::WebSocketUpgrade, http::StatusCode, response::IntoResponse};
|
|
use axum_extra::{TypedHeader, headers};
|
|
use serde::Deserialize;
|
|
use std::net::SocketAddr;
|
|
use std::time::Duration;
|
|
use tokio::select;
|
|
|
|
use crate::auth::{create_jwt, verify_jwt, verify_jwt_string};
|
|
use crate::realtime::Realtime;
|
|
|
|
#[derive(sqlx::FromRow, serde::Serialize, Deserialize)]
|
|
pub struct WsAuthQuery {
|
|
pub token: String,
|
|
}
|
|
|
|
pub fn routes() -> axum::Router {
|
|
axum::Router::new()
|
|
.route("/ws/messages/issue-token", get(issue_ws_token))
|
|
.route("/ws/messages", get(ws_handler))
|
|
}
|
|
|
|
pub async fn issue_ws_token(
|
|
Extension(db): Extension<sqlx::PgPool>,
|
|
headers: HeaderMap,
|
|
) -> Result<(StatusCode, Json<WsAuthQuery>), (StatusCode, String)> {
|
|
let claims = verify_jwt(headers)?;
|
|
|
|
tracing::debug!("Recieved token issue request from user {}", claims.sub);
|
|
|
|
let token = create_jwt(claims.sub).map_err(|e| (StatusCode::INTERNAL_SERVER_ERROR, e))?;
|
|
|
|
sqlx::query(
|
|
r#"
|
|
insert into ws_token_ (token, expires_at)
|
|
values ($1, now() + interval '30 seconds')
|
|
"#,
|
|
)
|
|
.bind(&token)
|
|
.execute(&db)
|
|
.await
|
|
.map_err(|_| {
|
|
(
|
|
StatusCode::INTERNAL_SERVER_ERROR,
|
|
format!("failed to provide ws token"),
|
|
)
|
|
})?;
|
|
|
|
Ok((StatusCode::CREATED, Json(WsAuthQuery { token })))
|
|
}
|
|
|
|
async fn ws_handler(
|
|
ws: WebSocketUpgrade,
|
|
user_agent: Option<TypedHeader<headers::UserAgent>>,
|
|
Query(query): Query<WsAuthQuery>,
|
|
ConnectInfo(addr): ConnectInfo<SocketAddr>,
|
|
Extension(realtime): Extension<Realtime>,
|
|
Extension(db): Extension<sqlx::PgPool>,
|
|
) -> Result<impl IntoResponse, (StatusCode, String)> {
|
|
// tracing::info!("recieved ws handshake: {}", room_uuid);
|
|
|
|
let claims = verify_jwt_string(&query.token)?;
|
|
let user_uuid = claims.sub;
|
|
|
|
let result = sqlx::query(
|
|
r#"
|
|
delete from ws_token_
|
|
where token = $1
|
|
and expires_at > now()
|
|
"#,
|
|
)
|
|
.bind(query.token)
|
|
.execute(&db)
|
|
.await
|
|
.map_err(|e| {
|
|
tracing::error!("Failed to get WS token from DB: {e}");
|
|
(StatusCode::INTERNAL_SERVER_ERROR, "DB error".into())
|
|
})?;
|
|
|
|
if result.rows_affected() == 0 {
|
|
return Err((StatusCode::UNAUTHORIZED, "Wrong token".into()));
|
|
}
|
|
|
|
let receiver = realtime.get_sender(user_uuid).subscribe();
|
|
|
|
let user_agent = if let Some(TypedHeader(user_agent)) = user_agent {
|
|
user_agent.to_string()
|
|
} else {
|
|
String::from("Unknown browser")
|
|
};
|
|
tracing::debug!("`{user_agent}` {user_uuid} at {addr} connected.");
|
|
|
|
Ok(ws.on_upgrade(move |socket| handle_socket(socket, addr, receiver)))
|
|
}
|
|
|
|
async fn handle_socket(
|
|
mut socket: WebSocket,
|
|
who: SocketAddr,
|
|
mut receiver: tokio::sync::broadcast::Receiver<crate::routes::messages::Message>,
|
|
) {
|
|
let mut ping_interval = tokio::time::interval(Duration::from_secs(30));
|
|
|
|
loop {
|
|
select! {
|
|
// Receive broadcast messages and send to client (any room)
|
|
msg = receiver.recv() => {
|
|
if let Ok(msg) = msg {
|
|
if let Ok(json) = serde_json::to_string(&msg) {
|
|
if socket.send(WsMessage::Text(json.into())).await.is_err() {
|
|
tracing::error!("Failed to send message to {who}, closing connection");
|
|
break;
|
|
}
|
|
}
|
|
} else {
|
|
break;
|
|
}
|
|
}
|
|
|
|
// Send Ping
|
|
_ = ping_interval.tick() => {
|
|
if socket.send(WsMessage::Ping(vec![].into())).await.is_err() {
|
|
tracing::error!("Failed to send ping to {who}, closing connection");
|
|
break;
|
|
}
|
|
|
|
// tracing::debug!("Ping sent to {who}");
|
|
}
|
|
|
|
// Get incoming messages from client
|
|
client_msg = socket.recv() => {
|
|
if let Some(Ok(msg)) = client_msg {
|
|
match msg {
|
|
// WsMessage::Pong(_) => {
|
|
// tracing::debug!("Received Pong from {who}");
|
|
// }
|
|
// WsMessage::Ping(_) => {
|
|
// tracing::info!("Received Ping from client");
|
|
// }
|
|
// WsMessage::Text(_) => {}
|
|
WsMessage::Close(_) => {
|
|
tracing::debug!("Client disconnected");
|
|
break;
|
|
}
|
|
_ => {}
|
|
}
|
|
} else {
|
|
tracing::debug!("Client {who} abruptly disconnected");
|
|
break;
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|