""" WebSocket consumer for mizan.channels. Handles multiplexed channel subscriptions AND RPC calls over a single WebSocket connection. Protocol: Browser sends: # Channel subscriptions {"action": "subscribe", "channel": "chat", "params": {"room": "general"}} {"action": "unsubscribe", "channel": "chat", "params": {"room": "general"}} {"action": "message", "channel": "chat", "params": {"room": "general"}, "data": {...}} # RPC calls (server functions) {"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}} Server sends: # Channel messages {"channel": "chat", "params": {"room": "general"}, "type": "DjangoMessage", "data": {...}} # RPC responses {"id": "request-id", "ok": true, "data": {...}} {"id": "request-id", "ok": false, "error": {...}} {"error": "..."} Authentication: Supports both session (cookie) and JWT authentication: - Session: Handled automatically via AuthMiddlewareStack (cookies in handshake) - JWT: Pass token as query parameter: ws://...?token= The WebSocket URL for JWT auth would be: ws://localhost/ws/?token= Security: - Functions must be explicitly registered (no arbitrary code execution) - Pydantic validation runs BEFORE any function code """ import json import logging from typing import Any from urllib.parse import parse_qs from channels.generic.websocket import AsyncJsonWebsocketConsumer from asgiref.sync import sync_to_async from . import get_channel logger = logging.getLogger(__name__) class WebSocketRequest: """ Minimal request adapter for WebSocket context. Provides the interface expected by ServerFunction without full HttpRequest. This is intentionally minimal - only expose what's needed. Note: Some Django libraries (e.g., allauth rate limiting) check request.method. We set method="POST" since WebSocket RPC calls are semantically similar to POST. """ # WebSocket RPC is semantically similar to POST (sends data, expects response) method = "POST" def __init__(self, scope: dict, channel_name: str = None): self.user = scope.get("user") self.session = scope.get("session", {}) self.channel_name = channel_name # For push subscriptions self._scope = scope @property def META(self) -> dict: """HTTP headers from WebSocket handshake.""" headers = dict(self._scope.get("headers", [])) return { "HTTP_" + k.decode().upper().replace("-", "_"): v.decode() for k, v in headers.items() } class DjangoReactConsumer(AsyncJsonWebsocketConsumer): """ Multiplexed WebSocket consumer for django_react channels. Manages multiple channel subscriptions over a single WebSocket connection. Authentication: - Session auth via cookies (handled by AuthMiddlewareStack) - JWT auth via query parameter: ws://...?token= """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) # Track subscriptions: {(channel_name, params_json): channel_instance} self._subscriptions: dict[tuple[str, str], Any] = {} async def connect(self): """Accept the WebSocket connection, authenticating via JWT if provided.""" # Check for JWT token in query parameters await self._try_jwt_auth() await self.accept() logger.debug( f"WebSocket connected: {self.channel_name}, user={self.scope.get('user')}" ) async def _try_jwt_auth(self): """ Attempt JWT authentication from query parameter. If a valid JWT token is provided via ?token=, authenticate the user using JWTUser (no database query). Security: If JWT is provided but invalid, we log it but don't reject the connection - the session auth may still be valid. However, if JWT IS valid, it takes precedence over session auth. """ # Parse query string for token query_string = self.scope.get("query_string", b"").decode() params = parse_qs(query_string) token_list = params.get("token", []) if not token_list: return # No JWT provided, use session auth token = token_list[0] if not token: return # Validate JWT and create JWTUser (no DB query) try: from mizan.client.jwt import decode_token from mizan.jwt.tokens import JWTUser payload = await sync_to_async(decode_token)(token, expected_type="access") if payload is None: logger.debug("JWT token invalid or expired") return # Fall back to session auth # Create JWTUser from token claims - NO DATABASE QUERY self.scope["user"] = JWTUser(payload) logger.debug(f"JWT auth successful for user {payload.user_id}") except Exception as e: logger.debug(f"JWT auth failed: {e}") async def disconnect(self, close_code): """Clean up all subscriptions on disconnect.""" for key, instance in list(self._subscriptions.items()): try: await instance.on_disconnect() await instance._leave_all_groups() except Exception as e: logger.error(f"Error during disconnect cleanup: {e}") self._subscriptions.clear() logger.debug(f"WebSocket disconnected: {self.channel_name}") async def receive_json(self, content: dict): """Handle incoming JSON messages.""" action = content.get("action") if action == "subscribe": await self._handle_subscribe(content) elif action == "unsubscribe": await self._handle_unsubscribe(content) elif action == "message": await self._handle_message(content) elif action == "rpc": await self._handle_rpc(content) else: await self.send_json( { "error": f"Unknown action: {action}", } ) async def _handle_subscribe(self, content: dict): """Handle subscription request.""" channel_name = content.get("channel") params_dict = content.get("params", {}) # Get channel class channel_class = get_channel(channel_name) if not channel_class: await self.send_json( { "error": f"Unknown channel: {channel_name}", } ) return # Create subscription key params_json = json.dumps(params_dict, sort_keys=True) sub_key = (channel_name, params_json) # Check if already subscribed if sub_key in self._subscriptions: await self.send_json( { "error": f"Already subscribed to {channel_name}", "channel": channel_name, "params": params_dict, } ) return # Create channel instance instance = channel_class() instance.user = self.scope.get("user") instance._channel_layer = self.channel_layer instance._channel_name = self.channel_name instance._registered_name = channel_name instance._params_dict = params_dict # Parse params params_obj = None if channel_class.Params: try: params_obj = channel_class.Params(**params_dict) except Exception as e: await self.send_json( { "error": f"Invalid params: {e}", "channel": channel_name, } ) return # Check authorization try: if params_obj: authorized = instance.authorize(params_obj) else: authorized = instance.authorize() except Exception as e: logger.error(f"Authorization error for {channel_name}: {e}") await self.send_json( { "error": "Authorization failed", "channel": channel_name, } ) return if not authorized: await self.send_json( { "error": "Not authorized", "channel": channel_name, } ) return # Get group and join try: if params_obj: group_name = instance.group(params_obj) else: group_name = instance.group() await instance._join_group(group_name) except Exception as e: logger.error(f"Failed to join group for {channel_name}: {e}") await self.send_json( { "error": f"Failed to subscribe: {e}", "channel": channel_name, } ) return # Store subscription self._subscriptions[sub_key] = instance # Call on_connect hook try: await instance.on_connect(params_obj) except Exception as e: logger.error(f"on_connect error for {channel_name}: {e}") # Confirm subscription await self.send_json( { "subscribed": True, "channel": channel_name, "params": params_dict, } ) logger.debug(f"Subscribed to {channel_name} with params {params_dict}") async def _handle_unsubscribe(self, content: dict): """Handle unsubscription request.""" channel_name = content.get("channel") params_dict = content.get("params", {}) params_json = json.dumps(params_dict, sort_keys=True) sub_key = (channel_name, params_json) instance = self._subscriptions.pop(sub_key, None) if instance: try: await instance.on_disconnect() await instance._leave_all_groups() except Exception as e: logger.error(f"Error during unsubscribe: {e}") await self.send_json( { "unsubscribed": True, "channel": channel_name, "params": params_dict, } ) logger.debug(f"Unsubscribed from {channel_name}") async def _handle_message(self, content: dict): """Handle incoming message from browser.""" channel_name = content.get("channel") params_dict = content.get("params", {}) data = content.get("data", {}) params_json = json.dumps(params_dict, sort_keys=True) sub_key = (channel_name, params_json) instance = self._subscriptions.get(sub_key) if not instance: await self.send_json( { "error": f"Not subscribed to {channel_name}", "channel": channel_name, } ) return channel_class = instance.__class__ # Check if channel accepts messages if not channel_class.ReactMessage: await self.send_json( { "error": f"Channel {channel_name} does not accept messages", "channel": channel_name, } ) return # Parse message try: msg = channel_class.ReactMessage(**data) except Exception as e: await self.send_json( { "error": f"Invalid message: {e}", "channel": channel_name, } ) return # Parse params params_obj = None if channel_class.Params: params_obj = channel_class.Params(**params_dict) # Handle message try: response = instance.receive(params_obj, msg) # If handler returned a message, broadcast it if response is not None: if params_obj: group_name = instance.group(params_obj) else: group_name = instance.group() await instance._broadcast(group_name, response) except Exception as e: logger.error(f"Error handling message for {channel_name}: {e}") await self.send_json( { "error": f"Message handling failed: {e}", "channel": channel_name, } ) async def _handle_rpc(self, content: dict): """ Handle RPC (server function) call. Protocol: Request: {"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}} Response: {"id": "request-id", "ok": true, "data": {...}} or: {"id": "request-id", "ok": false, "error": {...}} Security: - Only functions with @client(websocket=True) are allowed - Pydantic validation happens BEFORE any function code runs - Function must be explicitly registered (no arbitrary code execution) - User context from WebSocket session is passed to function """ from mizan.client.executor import execute_function, FunctionError from mizan.setup.registry import get_function request_id = content.get("id") fn_name = content.get("fn") args = content.get("args", {}) # Validate request structure if not request_id: await self.send_json( { "error": "RPC request missing 'id' field", } ) return if not fn_name: await self.send_json( { "id": request_id, "ok": False, "error": { "code": "BAD_REQUEST", "message": "Missing 'fn' field", }, } ) return # Check if function exists and has websocket=True fn_class = get_function(fn_name) if fn_class is None: await self.send_json( { "id": request_id, "ok": False, "error": { "code": "NOT_FOUND", "message": f"Function '{fn_name}' not found", }, } ) return # Only allow functions explicitly marked with websocket=True fn_meta = getattr(fn_class, "_meta", {}) if not fn_meta.get("websocket"): await self.send_json( { "id": request_id, "ok": False, "error": { "code": "FORBIDDEN", "message": "This function is HTTP-only. Use POST /api/mizan/call/ instead.", }, } ) return # Create request adapter from WebSocket scope ws_request = WebSocketRequest( self.scope, channel_name=getattr(self, "channel_name", None) ) # Execute function (Pydantic validation happens inside execute_function) # This is sync, so we need to run it in a thread pool result = await sync_to_async(execute_function, thread_sensitive=True)( ws_request, fn_name, args, ) # Send response if isinstance(result, FunctionError): await self.send_json( { "id": request_id, "ok": False, "error": { "code": result.code.value, "message": result.message, **({"details": result.details} if result.details else {}), }, } ) else: await self.send_json( { "id": request_id, "ok": True, "data": result.data, } ) async def channel_message(self, event: dict): """ Handle messages broadcast to a group. Called when channel_layer.group_send() is used. Includes channel name and params so the client can route the message. """ await self.send_json( { "channel": event.get("channel"), "params": event.get("params", {}), "type": event.get("message_type", "message"), "data": event.get("data", {}), } ) async def push_message(self, event: dict): """ Handle push messages from server functions. Called when push("topic", data) is used from a server function. The client receives this to update its local state. Protocol: Server sends: {"type": "push", "topic": "room:42", "data": {...}} """ await self.send_json( { "type": "push", "topic": event.get("topic"), "data": event.get("data", {}), } )