//! WebSocket RPC transport. `@client(websocket=true)` functions declare //! `Transport::Websocket` in the IR; this routes a real Axum WebSocket handler //! that dispatches call/fetch frames through the same `mizan-core` registry //! the HTTP path uses. A call frame naming a non-websocket function is //! rejected, so the transport boundary the IR declares is enforced. //! //! Frame protocol (text JSON), mirroring the HTTP call/ctx shapes: //! → {"id": 1, "op": "call", "fn": "name", "args": {...}} //! → {"id": 2, "op": "fetch", "context": "c", "params": {...}} //! ← {"id": 1, "result": ..., "invalidate": [...], "merge"?: [...]} //! ← {"id": 2, "data": {fnName: result, ...}} //! ← {"id": N, "error": {"code": ..., "message": ...}} use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade}; use axum::extract::State; use axum::response::Response; use futures_util::StreamExt; use mizan_core::{ compute_invalidation, compute_merges, lookup_context, lookup_function, AuthRequirement, FunctionSpec, InvalidationTarget, MergeEntry, MizanError, RequestHandle, Transport, FUNCTIONS, }; use serde_json::{json, Map, Value}; use std::sync::Arc; use crate::state::MizanState; /// GET /ws/ — upgrade to a Mizan WebSocket RPC connection. pub async fn ws_handler( ws: WebSocketUpgrade, State(state): State>, ) -> Response { ws.on_upgrade(move |socket| handle_socket(socket, state)) } async fn handle_socket(mut socket: WebSocket, state: Arc) { while let Some(Ok(msg)) = socket.next().await { let text = match msg { Message::Text(t) => t, Message::Close(_) => break, Message::Ping(_) | Message::Pong(_) | Message::Binary(_) => continue, }; let reply = handle_frame(&state, &text).await; if socket .send(Message::Text(reply.to_string())) .await .is_err() { break; } } } async fn handle_frame(state: &MizanState, text: &str) -> Value { let frame: Value = match serde_json::from_str(text) { Ok(v) => v, Err(e) => return err_frame(Value::Null, &MizanError::BadRequest(format!("bad frame: {e}"))), }; let id = frame.get("id").cloned().unwrap_or(Value::Null); let op = frame.get("op").and_then(|o| o.as_str()).unwrap_or("call"); match op { "call" => match dispatch_ws_call(state, &frame).await { Ok(v) => with_id(id, v), Err(e) => err_frame(id, &e), }, "fetch" => match dispatch_ws_fetch(state, &frame).await { Ok(v) => with_id(id, json!({ "data": v })), Err(e) => err_frame(id, &e), }, other => err_frame(id, &MizanError::BadRequest(format!("unknown op {other:?}"))), } } async fn dispatch_ws_call(state: &MizanState, frame: &Value) -> Result { let fn_name = frame .get("fn") .and_then(|f| f.as_str()) .ok_or_else(|| MizanError::BadRequest("missing `fn`".into()))?; let args = frame .get("args") .and_then(|a| a.as_object()) .cloned() .unwrap_or_default(); let fn_spec = lookup_function(fn_name).ok_or_else(|| MizanError::NotFound(format!("{fn_name:?}")))?; if fn_spec.private() { return Err(MizanError::Forbidden("Function is not client-callable".into())); } // The WS transport only carries functions that opted into it. if !matches!(fn_spec.transport(), Transport::Websocket | Transport::Both) { return Err(MizanError::BadRequest(format!( "function {fn_name:?} is not exposed over the WebSocket transport" ))); } enforce_anon_guard(fn_spec)?; let req = RequestHandle::from_dyn(state.app_state.as_ref()); let result = fn_spec.dispatch(req, Value::Object(args.clone())).await?; let targets = compute_invalidation(fn_spec, &args); let invalidate: Vec = targets.iter().map(InvalidationTarget::to_json).collect(); let merges = compute_merges(fn_spec, &args, &result); let mut out = Map::new(); out.insert("result".into(), result); out.insert("invalidate".into(), Value::Array(invalidate)); if !merges.is_empty() { out.insert( "merge".into(), Value::Array(merges.iter().map(MergeEntry::to_json).collect()), ); } Ok(Value::Object(out)) } async fn dispatch_ws_fetch(state: &MizanState, frame: &Value) -> Result { let ctx = frame .get("context") .and_then(|c| c.as_str()) .ok_or_else(|| MizanError::BadRequest("missing `context`".into()))?; if lookup_context(ctx).is_none() { return Err(MizanError::NotFound(format!("context {ctx:?}"))); } let params = frame .get("params") .and_then(|p| p.as_object()) .cloned() .unwrap_or_default(); let members: Vec<&dyn FunctionSpec> = FUNCTIONS .iter() .copied() .filter(|f| f.context() == Some(ctx)) .collect(); let mut bundle = Map::new(); for fn_spec in &members { enforce_anon_guard(*fn_spec)?; let mut args = Map::new(); for ip in fn_spec.input_params() { if let Some(v) = params.get(ip.name) { args.insert(ip.name.into(), v.clone()); } } let req = RequestHandle::from_dyn(state.app_state.as_ref()); let result = fn_spec.dispatch(req, Value::Object(args)).await?; bundle.insert(fn_spec.name().to_string(), result); } Ok(Value::Object(bundle)) } /// Enforce a function's auth guard for the WS transport. The WS upgrade /// carries no per-frame identity in this baseline, so a guarded function is /// rejected over WS — the same enforce-or-reject contract the HTTP path uses, /// applied with an anonymous identity. fn enforce_anon_guard(fn_spec: &dyn FunctionSpec) -> Result<(), MizanError> { let req = AuthRequirement::from_str_opt(fn_spec.auth()); mizan_core::enforce_auth(None, &req) } fn with_id(id: Value, mut body: Value) -> Value { if let Some(obj) = body.as_object_mut() { obj.insert("id".into(), id); } body } fn err_frame(id: Value, e: &MizanError) -> Value { json!({ "id": id, "error": { "code": e.code(), "message": e.message() }, }) }