Files
mizan/packages/mizan-rpc/adapters/django/src/mizan/channels/connection.py
Ryth Azhur b28ee72c67 Restructure repo into five-package AFI architecture
Mizan is an Application Framework Interface (AFI) with five
independent packages:

  packages/
    mizan-ast/       Language layer (source → KDL schema)
    mizan-schema/    IR layer (KDL schema definition)
    mizan-rpc/       Protocol layer (client gen + server adapters)
      adapters/django/   ← was django/
      generator/         ← was react/src/generator/
    mizan-csr/       State layer (client state engine)
      adapters/react/    ← was react/
    mizan-ssr/       Rendering layer (server-side rendering)

Each package is independent. The adapter directories contain the
framework-specific implementations. Stub packages (ast, schema, ssr)
establish the structure for future work.

264 Django tests + 33 React tests pass from new locations.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-04-06 15:41:31 -04:00

529 lines
17 KiB
Python

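"""
WebSocket consumer for mizan.channels.

Handles multiplexed channel subscriptions AND RPC calls over a single
WebSocket connection.

Protocol:
    Browser sends:
        # Channel subscriptions
        {"action": "subscribe", "channel": "chat", "params": {"room": "general"}}
        {"action": "unsubscribe", "channel": "chat", "params": {"room": "general"}}
        {"action": "message", "channel": "chat", "params": {"room": "general"}, "data": {...}}

        # RPC calls (server functions)
        {"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}

    Server sends:
        # Subscription acknowledgements
        {"subscribed": true, "channel": "chat", "params": {"room": "general"}}
        {"unsubscribed": true, "channel": "chat", "params": {"room": "general"}}

        # Channel messages
        {"channel": "chat", "params": {"room": "general"}, "type": "DjangoMessage", "data": {...}}

        # Push messages (sent when a server function calls push("topic", data))
        {"type": "push", "topic": "room:42", "data": {...}}

        # RPC responses
        {"id": "request-id", "ok": true, "data": {...}}
        {"id": "request-id", "ok": false, "error": {"code": "...", "message": "..."}}

        # Protocol errors (unknown action or channel, invalid params, ...).
        # Channel-related errors also include "channel" so the client can
        # route them.
        {"error": "..."}

Authentication:
    Supports both session (cookie) and JWT authentication:
    - Session: Handled automatically via AuthMiddlewareStack (cookies in handshake)
    - JWT: Pass the access token as a query parameter: ws://...?token=<jwt>
      e.g. ws://localhost/ws/?token=<access_token>
    A valid JWT takes precedence over the session user. An invalid JWT is
    ignored, and session auth applies.

Security:
    - Functions must be explicitly registered (no arbitrary code execution)
    - Only functions registered with websocket=True can be called over the socket
    - Pydantic validation runs BEFORE any function code
"""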
import json
import logging
from typing import Any
from urllib.parse import parse_qs
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from asgiref.sync import sync_to_async
from . import get_channel
# Module logger (named after this module); every log call in this file goes through it.
logger = logging.getLogger(__name__)
class WebSocketRequest:
"""
Minimal request adapter for WebSocket context.
Provides the interface expected by ServerFunction without full HttpRequest.
This is intentionally minimal - only expose what's needed.
Note: Some Django libraries (e.g., allauth rate limiting) check request.method.
We set method="POST" since WebSocket RPC calls are semantically similar to POST.
"""
# WebSocket RPC is semantically similar to POST (sends data, expects response)
method = "POST"
def __init__(self, scope: dict, channel_name: str = None):
self.user = scope.get("user")
self.session = scope.get("session", {})
self.channel_name = channel_name # For push subscriptions
self._scope = scope
@property
def META(self) -> dict:
"""HTTP headers from WebSocket handshake."""
headers = dict(self._scope.get("headers", []))
return {
"HTTP_" + k.decode().upper().replace("-", "_"): v.decode()
for k, v in headers.items()
}
class DjangoReactConsumer(AsyncJsonWebsocketConsumer):
    """
    Multiplexed WebSocket consumer for django_react channels.

    Manages multiple channel subscriptions over a single WebSocket connection
    and serves RPC (server function) calls on the same socket.

    Authentication:
        - Session auth via cookies (handled by AuthMiddlewareStack)
        - JWT auth via query parameter: ws://...?token=<jwt>
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Track subscriptions: {(channel_name, params_json): channel_instance}
        self._subscriptions: dict[tuple[str, str], Any] = {}

    # -- internal helpers --------------------------------------------------

    @staticmethod
    def _subscription_key(channel_name, params_dict) -> tuple:
        """
        Return the ``_subscriptions`` key for a (channel, params) pair.

        Params are serialized with ``sort_keys=True``, so identical params
        sent with a different key order map to the same subscription.
        """
        return (channel_name, json.dumps(params_dict, sort_keys=True))

    @staticmethod
    def _call_with_params(hook, params_obj):
        """
        Call a channel hook such as ``authorize`` or ``group``.

        The parsed params are passed when present. ``params_obj is None``
        means the channel has no ``Params`` model, and the hook is called with
        no argument. An identity check is used instead of truthiness, so a
        parsed params object that happens to be falsy is still passed.
        """
        if params_obj is None:
            return hook()
        return hook(params_obj)

    async def _teardown_subscription(self, instance, error_prefix: str) -> None:
        """
        Run a subscription's ``on_disconnect`` hook, then leave its groups.

        Each step is guarded separately, so group cleanup still runs when
        ``on_disconnect`` raises. Otherwise this socket would stay in the
        channel's groups and keep receiving broadcasts after unsubscribing.

        Args:
            instance: The channel instance being torn down.
            error_prefix: Prefix for logged errors, identifying the caller.
        """
        try:
            await instance.on_disconnect()
        except Exception as e:
            logger.exception("%s: %s", error_prefix, e)
        try:
            await instance._leave_all_groups()
        except Exception as e:
            logger.exception("%s: %s", error_prefix, e)

    # -- connection lifecycle ----------------------------------------------

    async def connect(self):
        """Accept the WebSocket connection, authenticating via JWT if provided."""
        # Check for JWT token in query parameters
        await self._try_jwt_auth()
        await self.accept()
        logger.debug(
            "WebSocket connected: %s, user=%s",
            self.channel_name,
            self.scope.get("user"),
        )

    async def _try_jwt_auth(self):
        """
        Attempt JWT authentication from query parameter.

        If a valid JWT token is provided via ?token=<jwt>, authenticate the user
        using JWTUser (no database query).

        Security: If JWT is provided but invalid, we log it but don't reject
        the connection - the session auth may still be valid. However, if JWT
        IS valid, it takes precedence over session auth.
        """
        query_string = self.scope.get("query_string", b"").decode()
        token_list = parse_qs(query_string).get("token", [])
        token = token_list[0] if token_list else ""
        if not token:
            return  # No JWT provided, use session auth

        # Validate JWT and create JWTUser (no DB query). Any failure, including
        # an import error, is logged at debug level and falls back to session auth.
        try:
            from mizan.client.jwt import decode_token
            from mizan.jwt.tokens import JWTUser

            payload = await sync_to_async(decode_token)(token, expected_type="access")
            if payload is None:
                logger.debug("JWT token invalid or expired")
                return  # Fall back to session auth
            # Create JWTUser from token claims - NO DATABASE QUERY
            self.scope["user"] = JWTUser(payload)
            logger.debug("JWT auth successful for user %s", payload.user_id)
        except Exception as e:
            logger.debug("JWT auth failed: %s", e)

    async def disconnect(self, close_code):
        """Clean up all subscriptions on disconnect."""
        for instance in list(self._subscriptions.values()):
            await self._teardown_subscription(
                instance, "Error during disconnect cleanup"
            )
        self._subscriptions.clear()
        logger.debug("WebSocket disconnected: %s", self.channel_name)

    # -- inbound dispatch --------------------------------------------------

    async def receive_json(self, content: Any, **kwargs):
        """
        Dispatch an incoming JSON message on its ``action`` field.

        Frames that decode to something other than a JSON object get an error
        reply instead of raising. An exception here would propagate out of
        the consumer and drop the connection, along with every subscription
        on it.
        """
        if not isinstance(content, dict):
            await self.send_json({"error": "Invalid message: expected a JSON object"})
            return

        action = content.get("action")
        if action == "subscribe":
            await self._handle_subscribe(content)
        elif action == "unsubscribe":
            await self._handle_unsubscribe(content)
        elif action == "message":
            await self._handle_message(content)
        elif action == "rpc":
            await self._handle_rpc(content)
        else:
            await self.send_json({"error": f"Unknown action: {action}"})

    async def _handle_subscribe(self, content: dict):
        """
        Handle a subscription request.

        Steps: resolve the channel class, reject duplicates, build the channel
        instance, validate params, authorize, and join the channel's group.
        Then record the subscription, run ``on_connect``, and confirm to the
        client. A failure up to and including the group join sends an error,
        and the subscription is not recorded.
        """
        channel_name = content.get("channel")
        params_dict = content.get("params", {})

        channel_class = get_channel(channel_name)
        if not channel_class:
            await self.send_json({"error": f"Unknown channel: {channel_name}"})
            return

        sub_key = self._subscription_key(channel_name, params_dict)
        if sub_key in self._subscriptions:
            await self.send_json(
                {
                    "error": f"Already subscribed to {channel_name}",
                    "channel": channel_name,
                    "params": params_dict,
                }
            )
            return

        # Create channel instance
        instance = channel_class()
        instance.user = self.scope.get("user")
        instance._channel_layer = self.channel_layer
        instance._channel_name = self.channel_name
        instance._registered_name = channel_name
        instance._params_dict = params_dict

        # Parse params
        params_obj = None
        if channel_class.Params:
            try:
                params_obj = channel_class.Params(**params_dict)
            except Exception as e:
                await self.send_json(
                    {"error": f"Invalid params: {e}", "channel": channel_name}
                )
                return

        # NOTE(review): authorize() and group() are called synchronously on the
        # event loop. If a channel implementation queries the ORM in them,
        # Django will raise SynchronousOnlyOperation. Confirm these hooks stay
        # DB-free, or wrap them with sync_to_async.
        try:
            authorized = self._call_with_params(instance.authorize, params_obj)
        except Exception as e:
            logger.exception("Authorization error for %s: %s", channel_name, e)
            await self.send_json(
                {"error": "Authorization failed", "channel": channel_name}
            )
            return

        if not authorized:
            await self.send_json({"error": "Not authorized", "channel": channel_name})
            return

        # Get group and join
        try:
            group_name = self._call_with_params(instance.group, params_obj)
            await instance._join_group(group_name)
        except Exception as e:
            logger.exception("Failed to join group for %s: %s", channel_name, e)
            await self.send_json(
                {"error": f"Failed to subscribe: {e}", "channel": channel_name}
            )
            return

        self._subscriptions[sub_key] = instance

        # on_connect failures are logged but do not undo the subscription.
        try:
            await instance.on_connect(params_obj)
        except Exception as e:
            logger.exception("on_connect error for %s: %s", channel_name, e)

        await self.send_json(
            {"subscribed": True, "channel": channel_name, "params": params_dict}
        )
        logger.debug("Subscribed to %s with params %s", channel_name, params_dict)

    async def _handle_unsubscribe(self, content: dict):
        """
        Handle an unsubscription request.

        The confirmation is sent even when no matching subscription exists.
        """
        channel_name = content.get("channel")
        params_dict = content.get("params", {})

        instance = self._subscriptions.pop(
            self._subscription_key(channel_name, params_dict), None
        )
        if instance is not None:
            await self._teardown_subscription(instance, "Error during unsubscribe")

        await self.send_json(
            {"unsubscribed": True, "channel": channel_name, "params": params_dict}
        )
        logger.debug("Unsubscribed from %s", channel_name)

    async def _handle_message(self, content: dict):
        """
        Handle an incoming message from the browser on a subscribed channel.

        The payload is validated against the channel's ``ReactMessage`` model
        and passed to ``receive(params, msg)``. A non-None return value is
        broadcast to the subscription's group.
        """
        channel_name = content.get("channel")
        params_dict = content.get("params", {})
        data = content.get("data", {})

        instance = self._subscriptions.get(
            self._subscription_key(channel_name, params_dict)
        )
        if instance is None:
            await self.send_json(
                {"error": f"Not subscribed to {channel_name}", "channel": channel_name}
            )
            return

        channel_class = instance.__class__

        # Check if channel accepts messages
        if not channel_class.ReactMessage:
            await self.send_json(
                {
                    "error": f"Channel {channel_name} does not accept messages",
                    "channel": channel_name,
                }
            )
            return

        try:
            msg = channel_class.ReactMessage(**data)
        except Exception as e:
            await self.send_json(
                {"error": f"Invalid message: {e}", "channel": channel_name}
            )
            return

        try:
            # The params already validated at subscribe time. Re-parsing
            # inside the try means an unexpected failure is reported to the
            # client instead of escaping the consumer.
            params_obj = (
                channel_class.Params(**params_dict) if channel_class.Params else None
            )
            response = instance.receive(params_obj, msg)
            # If handler returned a message, broadcast it
            if response is not None:
                group_name = self._call_with_params(instance.group, params_obj)
                await instance._broadcast(group_name, response)
        except Exception as e:
            logger.exception("Error handling message for %s: %s", channel_name, e)
            await self.send_json(
                {"error": f"Message handling failed: {e}", "channel": channel_name}
            )

    async def _handle_rpc(self, content: dict):
        """
        Handle RPC (server function) call.

        Protocol:
            Request:  {"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}
            Response: {"id": "request-id", "ok": true, "data": {...}}
                  or: {"id": "request-id", "ok": false, "error": {...}}

        Security:
            - Only functions with @client(websocket=True) are allowed
            - Pydantic validation happens BEFORE any function code runs
            - Function must be explicitly registered (no arbitrary code execution)
            - User context from WebSocket session is passed to function

        An exception that escapes ``execute_function`` is logged and answered
        with an error response for that request id. One failing call
        therefore does not drop the socket or its subscriptions.
        """
        from mizan.client.executor import execute_function, FunctionError
        from mizan.setup.registry import get_function

        request_id = content.get("id")
        fn_name = content.get("fn")
        args = content.get("args", {})

        # Validate request structure
        if not request_id:
            await self.send_json({"error": "RPC request missing 'id' field"})
            return

        if not fn_name:
            await self.send_json(
                {
                    "id": request_id,
                    "ok": False,
                    "error": {"code": "BAD_REQUEST", "message": "Missing 'fn' field"},
                }
            )
            return

        fn_class = get_function(fn_name)
        if fn_class is None:
            await self.send_json(
                {
                    "id": request_id,
                    "ok": False,
                    "error": {
                        "code": "NOT_FOUND",
                        "message": f"Function '{fn_name}' not found",
                    },
                }
            )
            return

        # Only allow functions explicitly marked with websocket=True.
        # `or {}` also covers a _meta attribute that is present but None.
        fn_meta = getattr(fn_class, "_meta", None) or {}
        if not fn_meta.get("websocket"):
            await self.send_json(
                {
                    "id": request_id,
                    "ok": False,
                    "error": {
                        "code": "FORBIDDEN",
                        "message": "This function is HTTP-only. Use POST /api/mizan/call/ instead.",
                    },
                }
            )
            return

        ws_request = WebSocketRequest(
            self.scope, channel_name=getattr(self, "channel_name", None)
        )

        # execute_function is sync (Pydantic validation happens inside it),
        # so run it in a thread via sync_to_async.
        try:
            result = await sync_to_async(execute_function, thread_sensitive=True)(
                ws_request,
                fn_name,
                args,
            )
        except Exception:
            logger.exception("Unhandled error executing RPC function %s", fn_name)
            # NOTE(review): align this code with mizan's FunctionError codes.
            await self.send_json(
                {
                    "id": request_id,
                    "ok": False,
                    "error": {
                        "code": "INTERNAL_ERROR",
                        "message": "Internal server error",
                    },
                }
            )
            return

        if isinstance(result, FunctionError):
            error = {"code": result.code.value, "message": result.message}
            if result.details:
                error["details"] = result.details
            await self.send_json({"id": request_id, "ok": False, "error": error})
        else:
            await self.send_json({"id": request_id, "ok": True, "data": result.data})

    # -- channel-layer event handlers --------------------------------------

    async def channel_message(self, event: dict):
        """
        Handle messages broadcast to a group.

        Called when channel_layer.group_send() is used.
        Includes channel name and params so the client can route the message.
        """
        await self.send_json(
            {
                "channel": event.get("channel"),
                "params": event.get("params", {}),
                "type": event.get("message_type", "message"),
                "data": event.get("data", {}),
            }
        )

    async def push_message(self, event: dict):
        """
        Handle push messages from server functions.

        Called when push("topic", data) is used from a server function.
        The client receives this to update its local state.

        Protocol:
            Server sends: {"type": "push", "topic": "room:42", "data": {...}}
        """
        await self.send_json(
            {
                "type": "push",
                "topic": event.get("topic"),
                "data": event.get("data", {}),
            }
        )