Flatten to three packages + extract mizan-runtime

packages/
  mizan-runtime/   Framework-agnostic state engine (~150 lines)
                   Context registry, batched invalidation, fetch primitives
  mizan-django/    Django server adapter (was packages/mizan-rpc/adapters/django/)
                   Codegen moved to mizan-django/generate/
  mizan-react/     React adapter (was packages/mizan-csr/adapters/react/)

Removed premature abstractions: mizan-ast, mizan-schema, mizan-rpc,
mizan-csr, mizan-ssr stub packages. The actual architecture is three
concrete packages, not five abstract layers.

mizan-runtime implements the v1 spec: registerContext with params,
scoped invalidation via microtask batching, server-driven invalidation
from mutation responses, mizanFetch for context bundles, mizanCall for
mutations.

264 Django + 33 React tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-04-02 15:47:17 -04:00
parent b28ee72c67
commit 787f90fd12
141 changed files with 167 additions and 15 deletions

View File

@@ -0,0 +1,187 @@
"""
mizan - Django + React unified framework
Server functions are the core primitive. Everything else builds on them.
## Quick Start
### 1. urls.py - HTTP endpoint
```python
from mizan import urls as mizan_urls
urlpatterns = [
path('api/mizan/', include(mizan_urls)),
]
```
### 2. asgi.py - WebSocket support (optional)
```python
from mizan import wrap_asgi
from django.core.asgi import get_asgi_application
application = wrap_asgi(get_asgi_application())
```
### 3. Define server functions
```python
# apps/myapp/clients.py
from mizan import client
from pydantic import BaseModel
class EchoOutput(BaseModel):
message: str
# HTTP-only function (default)
@client
def echo(request, text: str) -> EchoOutput:
return EchoOutput(message=f"Echo: {text}")
# Global context (singleton, SSR-hydrated)
@client(context='global')
def current_user(request) -> UserOutput:
return UserOutput(email=request.user.email)
# WebSocket-enabled for real-time
@client(websocket=True)
def send_message(request, room_id: int, text: str) -> MessageOutput:
return MessageOutput(...)
```
### 4. Auto-discover in apps.py
```python
class MyAppConfig(AppConfig):
def ready(self):
from mizan.setup import mizan_clients
mizan_clients('apps')
```
### 5. Frontend - generate types and use
```bash
npm run schemas
```
```tsx
import { useEcho, useCurrentUser } from '@/api'
const user = useCurrentUser()
const echo = useEcho()
await echo({ text: 'hello' })
```
## What You Get
| Backend | Frontend | Transport |
|------------------------------------|-----------------------|------------|
| `@client` | `useXxx()` hook | HTTP |
| `@client(context='global')` | `useXxx()` + SSR | HTTP |
| `@client(context='local')` | `<XxxProvider>` + hook| HTTP |
| `@client(websocket=True)` | `useXxx()` hook | WebSocket |
| `@compose(...)` | `<XxxProvider>` combined | varies |
| `mizanFormMixin` | `useXxxForm()` + Zod | HTTP |
| `ReactChannel` | `useXxxChannel()` | WebSocket |
"""
# All imports at module level (sorted)
from . import channels
from . import client as client_module
from . import export
from . import forms
from . import setup
from .channels import ReactChannel
from .channels import register as register_channel
from .client import ComposedContext, GlobalContext, ReactContext, ServerFunction, client, compose
# Shape is lazy-loaded via __getattr__ because django_readers
# imports contenttypes, which can't happen during apps.populate()
from .setup import (
mizan_clients,
mizan_module,
get_channel,
get_function,
register,
register_as,
)
def __getattr__(name):
"""Lazy loading for modules that can't be imported at app load time."""
if name == "urls":
from .urls import urlpatterns as mizan_patterns
return mizan_patterns
if name == "Shape":
from .shapes import Shape
return Shape
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
def wrap_asgi(http_application):
"""
Wrap an ASGI application with mizan WebSocket support.
Usage in asgi.py:
from django.core.asgi import get_asgi_application
from mizan import wrap_asgi
application = wrap_asgi(get_asgi_application())
This adds:
- WebSocket routing at /ws/ for RPC and channels
- Authentication middleware for WebSocket connections
"""
try:
from channels.auth import AuthMiddlewareStack
from channels.routing import ProtocolTypeRouter, URLRouter
from django.urls import path
except ImportError:
raise ImportError(
"django-channels is required for WebSocket support.\n"
"Install with: pip install channels channels-redis\n"
"Add 'channels' to INSTALLED_APPS and configure CHANNEL_LAYERS."
)
from .channels.connection import DjangoReactConsumer
return ProtocolTypeRouter(
{
"http": http_application,
"websocket": AuthMiddlewareStack(
URLRouter(
[
path("ws/", DjangoReactConsumer.as_asgi()),
]
)
),
}
)
__all__ = [
# Decorators & Contexts
"client",
"compose",
"ReactContext",
"GlobalContext",
"ServerFunction",
"ComposedContext",
# Setup
"mizan_clients",
"mizan_module",
"register",
"register_as",
"get_function",
"get_channel",
# ASGI
"wrap_asgi",
# Channels
"ReactChannel",
"register_channel",
# Shapes
"Shape",
# Submodules
"client_module",
"setup",
"forms",
"channels",
"export",
]

View File

@@ -0,0 +1,91 @@
import inspect
from importlib import import_module
from inspect import isclass
from typing import Protocol, Any
from django.conf import settings
def get_members(path):
try:
module = import_module(path)
except ModuleNotFoundError:
print('Could not import module "{}"'.format(path))
return []
members = [
(name, member)
for name, member in inspect.getmembers(module)
if not isclass(member) or (member.__module__ == module.__name__)
]
return members
class DjangoAppVisitorHandler(Protocol):
def on_module(
self, app_name: str, path_parts: list[str], members: list[tuple[str, Any]]
) -> None: ...
class DjangoAppVisitor:
"""
Discovers Python modules under each Django app following conventions:
- <app>/<module>.py -> url_prefix "<renamed>/"
- <app>/<module>/**/*.py -> url_prefix "<renamed>/<subdirs...>/<module>/"
Example:
<app>/<module>/forms/nksn.py -> url_prefix "<renamed>/forms/nksn/"
module_path "<app>.module.forms.nksn"
"""
def __init__(
self,
*,
layer: str,
apps_root: str = "",
):
self.apps_root = apps_root
self.layer = layer
def visit(self, handler: DjangoAppVisitorHandler) -> None:
apps_dir = (
settings.BASE_DIR / self.apps_root if self.apps_root else settings.BASE_DIR
)
if not apps_dir.is_dir():
apps_dir = settings.BASE_DIR
module_prefix = f"{self.apps_root}." if self.apps_root else ""
for app_name in settings.INSTALLED_APPS:
if app_name.startswith(self.apps_root + "."):
app_name = app_name[(len(self.apps_root) + 1) :]
app_dir = apps_dir / app_name
if not app_dir.exists():
continue
app_module = f"{module_prefix}{app_name}"
# 1) Visit package: <app>/<module>/**/*.py
layer_dir = app_dir / self.layer
if layer_dir.is_dir():
for py_file in layer_dir.rglob("*.py"):
if py_file.name == "__init__.py":
continue
relative_path = py_file.relative_to(layer_dir).with_suffix("")
parts = list(relative_path.parts)
dotted = ".".join(parts)
handler.on_module(
app_name,
parts,
get_members(f"{app_module}.{self.layer}.{dotted}"),
)
# 2) Visit module module file: <app>/module.py
layer_file = app_dir / f"{self.layer}.py"
if layer_file.is_file():
handler.on_module(
app_name, [], get_members(f"{app_module}.{self.layer}")
)

View File

@@ -0,0 +1,543 @@
"""
mizan.channels - Real-time WebSocket communication.
Type-safe bidirectional messaging between Django and React via WebSockets.
Hooks are auto-generated with full TypeScript types.
## Basic Usage
```python
# channels.py
from pydantic import BaseModel
from mizan import channels
class ChatChannel(channels.ReactChannel):
class Params(BaseModel):
room: str
class ReactMessage(BaseModel):
text: str
class DjangoMessage(BaseModel):
user: str
text: str
timestamp: datetime
def authorize(self, params: Params) -> bool:
return self.user.is_authenticated
def group(self, params: Params) -> str:
return f'chat_{params.room}'
def receive(self, params: Params, msg: ReactMessage) -> DjangoMessage | None:
return self.DjangoMessage(
user=self.user.email,
text=msg.text,
timestamp=now(),
)
channels.register(ChatChannel, 'chat')
```
```python
# asgi.py
from mizan import channels
application = ProtocolTypeRouter({
"http": get_asgi_application(),
"websocket": channels.get_websocket_application(),
})
```
## Frontend Usage (auto-generated)
```tsx
import { useChatChannel } from '@/api/generated.channels'
function Chat({ room }) {
const chat = useChatChannel({ room })
chat.status // 'connecting' | 'connected' | 'disconnected'
chat.messages // DjangoMessage[]
chat.send({ text: 'Hello' }) // ReactMessage
}
```
## Server Push
```python
await ChatChannel.push(room='general', message=ChatChannel.DjangoMessage(...))
```
"""
import logging
from typing import TYPE_CHECKING, Any, ClassVar, Type
from pydantic import BaseModel
if TYPE_CHECKING:
from django.contrib.auth.models import AbstractBaseUser, AnonymousUser
from ninja import NinjaAPI
logger = logging.getLogger(__name__)
# =============================================================================
# Base Classes
# =============================================================================
class ReactChannel:
"""
Base class for WebSocket channels.
Define nested Pydantic classes for typed messaging:
- Params: Query parameters for subscribing (optional)
- ReactMessage: Messages from browser to server (optional)
- DjangoMessage: Messages from server to browser (optional)
Implement required methods:
- authorize(): Permission check for connection
- group(): Which group to broadcast to
Optionally implement:
- receive(): Handle incoming ReactMessage, return DjangoMessage to broadcast
- on_connect(): Called after successful connection
- on_disconnect(): Called when connection closes
"""
# Nested classes (optional, defined by subclasses)
Params: ClassVar[Type[BaseModel] | None] = None
ReactMessage: ClassVar[Type[BaseModel] | None] = None
DjangoMessage: ClassVar[Type[BaseModel] | None] = None
# Set by the framework when handling a connection
user: "AbstractBaseUser | AnonymousUser"
_channel_layer: Any = None
_channel_name: str = ""
_registered_name: ClassVar[str] = ""
_params_dict: dict = {}
_groups: set[str]
def __init__(self):
self._groups = set()
self._params_dict = {}
def authorize(self, params: BaseModel | None = None) -> bool:
"""
Permission check. Return True to allow connection, False to reject.
Override this to implement custom authorization logic.
"""
raise NotImplementedError(
f"{self.__class__.__name__} must implement authorize()"
)
def group(self, params: BaseModel | None = None) -> str:
"""
Return the group name for broadcasting.
Messages returned from receive() are broadcast to this group.
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement group()")
def receive(self, params: BaseModel | None, msg: BaseModel) -> BaseModel | None:
"""
Handle incoming ReactMessage.
Return a DjangoMessage to broadcast to the group, or None to skip.
Override this to implement message handling.
"""
return None
async def on_connect(self, params: BaseModel | None = None) -> None:
"""Called after successful connection and group join."""
pass
async def on_disconnect(self) -> None:
"""Called when the connection closes."""
pass
# -------------------------------------------------------------------------
# Internal Methods (used by the consumer)
# -------------------------------------------------------------------------
async def _join_group(self, group_name: str) -> None:
"""Join a channel layer group."""
if self._channel_layer:
await self._channel_layer.group_add(group_name, self._channel_name)
self._groups.add(group_name)
async def _leave_group(self, group_name: str) -> None:
"""Leave a channel layer group."""
if self._channel_layer and group_name in self._groups:
await self._channel_layer.group_discard(group_name, self._channel_name)
self._groups.discard(group_name)
async def _leave_all_groups(self) -> None:
"""Leave all joined groups."""
for group_name in list(self._groups):
await self._leave_group(group_name)
async def _broadcast(self, group_name: str, message: BaseModel) -> None:
"""Broadcast a message to a group."""
if self._channel_layer:
await self._channel_layer.group_send(
group_name,
{
"type": "channel.message",
"channel": self._registered_name,
"params": self._params_dict,
"data": message.model_dump(mode="json"),
"message_type": message.__class__.__name__,
},
)
# -------------------------------------------------------------------------
# Class Methods for Server Push
# -------------------------------------------------------------------------
@classmethod
async def push(cls, message: BaseModel, **params) -> None:
"""
Push a message from server code (views, tasks, signals).
Usage:
await ChatChannel.push(
room='general',
message=ChatChannel.DjangoMessage(user='system', text='Hello')
)
"""
from channels.layers import get_channel_layer
channel_layer = get_channel_layer()
if not channel_layer:
logger.warning(
f"No channel layer configured, cannot push to {cls.__name__}"
)
return
# Build params model if defined
params_obj = None
if cls.Params:
params_obj = cls.Params(**params)
# Get group name
instance = cls()
group_name = instance.group(params_obj)
# Send to group
await channel_layer.group_send(
group_name,
{
"type": "channel.message",
"channel": cls._registered_name,
"params": params,
"data": message.model_dump(mode="json"),
"message_type": message.__class__.__name__,
},
)
# =============================================================================
# Registry
# =============================================================================
_registry: dict[str, Type[ReactChannel]] = {}
def register(channel_class: Type[ReactChannel], name: str) -> None:
"""
Register a channel.
Args:
channel_class: The ReactChannel subclass to register
name: URL-friendly name (used in subscriptions)
"""
if name in _registry:
raise ValueError(f"Channel '{name}' is already registered")
channel_class._registered_name = name
# Validate the channel class
if not hasattr(channel_class, "authorize"):
raise ValueError(f"{channel_class.__name__} must implement authorize()")
if not hasattr(channel_class, "group"):
raise ValueError(f"{channel_class.__name__} must implement group()")
_registry[name] = channel_class
logger.debug(f"Registered channel: {name} -> {channel_class.__name__}")
def get_channel(name: str) -> Type[ReactChannel] | None:
"""Get a registered channel class by name."""
return _registry.get(name)
def get_registered_channels() -> dict[str, Type[ReactChannel]]:
"""Get all registered channel classes."""
return dict(_registry)
# =============================================================================
# WebSocket Consumer
# =============================================================================
def get_websocket_application():
"""
Get the WebSocket application for ASGI.
Usage in asgi.py:
from mizan import channels
application = ProtocolTypeRouter({
"http": get_asgi_application(),
"websocket": channels.get_websocket_application(),
})
"""
try:
from channels.routing import URLRouter
from channels.auth import AuthMiddlewareStack
from django.urls import path
except ImportError:
raise ImportError(
"django-channels is required for WebSocket support. "
"Install it with: pip install channels channels-redis"
)
from .connection import DjangoReactConsumer
return AuthMiddlewareStack(
URLRouter(
[
path("ws/", DjangoReactConsumer.as_asgi()),
]
)
)
# =============================================================================
# Schema Export (for TypeScript generation)
# =============================================================================
def get_channels_schema() -> dict:
"""
Get schema for all registered channels (for TypeScript generation).
Returns a dict suitable for the frontend code generator.
"""
schema = {"channels": {}}
for name, channel_class in _registry.items():
channel_schema = {
"name": name,
"params": None,
"reactMessage": None,
"djangoMessage": None,
}
# Extract Params schema
if hasattr(channel_class, "Params") and channel_class.Params:
channel_schema["params"] = channel_class.Params.model_json_schema()
# Extract ReactMessage schema
if hasattr(channel_class, "ReactMessage") and channel_class.ReactMessage:
channel_schema[
"reactMessage"
] = channel_class.ReactMessage.model_json_schema()
# Extract DjangoMessage schema
if hasattr(channel_class, "DjangoMessage") and channel_class.DjangoMessage:
channel_schema[
"djangoMessage"
] = channel_class.DjangoMessage.model_json_schema()
schema["channels"][name] = channel_schema
return schema
def _register_channel_schema_endpoint(
api: "NinjaAPI",
path: str,
operation_id: str,
summary: str,
input_cls: type | None,
output_cls: type,
) -> None:
"""Register a dummy endpoint for schema generation (avoids closure issues)."""
if input_cls is not None:
def endpoint(request, data):
pass
endpoint.__annotations__ = {"data": input_cls}
else:
def endpoint(request):
pass
api.post(path, response=output_cls, operation_id=operation_id, summary=summary)(
endpoint
)
def get_channels_openapi_schema() -> dict:
"""
Get OpenAPI schema for all registered channels.
Uses Django Ninja's schema generation for robust Pydantic→OpenAPI conversion.
This schema is consumed by openapi-typescript for type generation.
"""
from ninja import NinjaAPI
from pydantic import BaseModel
# Create temporary Ninja API for schema generation only
schema_api = NinjaAPI(
title="mizan Channels",
version="1.0.0",
description="Auto-generated schema for mizan channels",
docs_url=None,
openapi_url=None,
)
# Store dynamically created classes
schema_classes: dict[str, type] = {}
channel_metadata: list[dict] = []
for name, channel_class in _registry.items():
pascal_name = name.replace("_", " ").title().replace(" ", "")
channel_meta = {
"name": name,
"pascalName": pascal_name,
"hasParams": False,
"hasReactMessage": False,
"hasDjangoMessage": False,
}
# Register Params type
if hasattr(channel_class, "Params") and channel_class.Params:
params_name = f"{pascal_name}Params"
schema_classes[params_name] = type(params_name, (channel_class.Params,), {})
channel_meta["hasParams"] = True
channel_meta["paramsType"] = params_name
# Create dummy endpoint to include in schema
_register_channel_schema_endpoint(
api=schema_api,
path=f"/channels/{name}/params",
operation_id=f"{name}Params",
summary=f"{pascal_name} channel params",
input_cls=schema_classes[params_name],
output_cls=BaseModel,
)
# Register ReactMessage type
if hasattr(channel_class, "ReactMessage") and channel_class.ReactMessage:
react_name = f"{pascal_name}ReactMessage"
schema_classes[react_name] = type(
react_name, (channel_class.ReactMessage,), {}
)
channel_meta["hasReactMessage"] = True
channel_meta["reactMessageType"] = react_name
_register_channel_schema_endpoint(
api=schema_api,
path=f"/channels/{name}/react",
operation_id=f"{name}ReactMessage",
summary=f"{pascal_name} React→Django message",
input_cls=schema_classes[react_name],
output_cls=BaseModel,
)
# Register DjangoMessage type
if hasattr(channel_class, "DjangoMessage") and channel_class.DjangoMessage:
django_name = f"{pascal_name}DjangoMessage"
schema_classes[django_name] = type(
django_name, (channel_class.DjangoMessage,), {}
)
channel_meta["hasDjangoMessage"] = True
channel_meta["djangoMessageType"] = django_name
_register_channel_schema_endpoint(
api=schema_api,
path=f"/channels/{name}/django",
operation_id=f"{name}DjangoMessage",
summary=f"{pascal_name} Django→React message",
input_cls=None,
output_cls=schema_classes[django_name],
)
channel_metadata.append(channel_meta)
# Get OpenAPI schema from Ninja
# path_prefix="" avoids URL reverse() — this API is never mounted
schema = schema_api.get_openapi_schema(path_prefix="")
# Add channel metadata extension
schema["x-mizan-channels"] = channel_metadata
return schema
# =============================================================================
# Schema Endpoint (for TypeScript generation)
# =============================================================================
_schema_router = None
def _get_schema_router():
"""Get the Ninja router for the channels schema endpoint."""
global _schema_router
if _schema_router is None:
from ninja import Router
_schema_router = Router(tags=["channels"])
@_schema_router.get("/schema/")
def channels_schema(request):
"""Get schema for all registered channels (for TypeScript generation)."""
return get_channels_schema()
return _schema_router
def get_urls():
"""Get URL patterns for channels schema endpoint."""
from ninja import NinjaAPI
api = NinjaAPI(urls_namespace="django_react_channels")
api.add_router("/", _get_schema_router())
return api.urls
def __getattr__(name):
if name == "urls":
return get_urls()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
# =============================================================================
# Exports
# =============================================================================
__all__ = [
# URLs
"urls",
# Base class
"ReactChannel",
# Registration
"register",
"get_channel",
"get_registered_channels",
# ASGI application
"get_websocket_application",
# Schema export
"get_channels_schema",
]

View File

@@ -0,0 +1,528 @@
"""
WebSocket consumer for mizan.channels.
Handles multiplexed channel subscriptions AND RPC calls over a single WebSocket connection.
Protocol:
Browser sends:
# Channel subscriptions
{"action": "subscribe", "channel": "chat", "params": {"room": "general"}}
{"action": "unsubscribe", "channel": "chat", "params": {"room": "general"}}
{"action": "message", "channel": "chat", "params": {"room": "general"}, "data": {...}}
# RPC calls (server functions)
{"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}
Server sends:
# Channel messages
{"channel": "chat", "params": {"room": "general"}, "type": "DjangoMessage", "data": {...}}
# RPC responses
{"id": "request-id", "ok": true, "data": {...}}
{"id": "request-id", "ok": false, "error": {...}}
{"error": "..."}
Authentication:
Supports both session (cookie) and JWT authentication:
- Session: Handled automatically via AuthMiddlewareStack (cookies in handshake)
- JWT: Pass token as query parameter: ws://...?token=<jwt>
The WebSocket URL for JWT auth would be: ws://localhost/ws/?token=<access_token>
Security:
- Functions must be explicitly registered (no arbitrary code execution)
- Pydantic validation runs BEFORE any function code
"""
import json
import logging
from typing import Any
from urllib.parse import parse_qs
from channels.generic.websocket import AsyncJsonWebsocketConsumer
from asgiref.sync import sync_to_async
from . import get_channel
logger = logging.getLogger(__name__)
class WebSocketRequest:
"""
Minimal request adapter for WebSocket context.
Provides the interface expected by ServerFunction without full HttpRequest.
This is intentionally minimal - only expose what's needed.
Note: Some Django libraries (e.g., allauth rate limiting) check request.method.
We set method="POST" since WebSocket RPC calls are semantically similar to POST.
"""
# WebSocket RPC is semantically similar to POST (sends data, expects response)
method = "POST"
def __init__(self, scope: dict, channel_name: str = None):
self.user = scope.get("user")
self.session = scope.get("session", {})
self.channel_name = channel_name # For push subscriptions
self._scope = scope
@property
def META(self) -> dict:
"""HTTP headers from WebSocket handshake."""
headers = dict(self._scope.get("headers", []))
return {
"HTTP_" + k.decode().upper().replace("-", "_"): v.decode()
for k, v in headers.items()
}
class DjangoReactConsumer(AsyncJsonWebsocketConsumer):
"""
Multiplexed WebSocket consumer for django_react channels.
Manages multiple channel subscriptions over a single WebSocket connection.
Authentication:
- Session auth via cookies (handled by AuthMiddlewareStack)
- JWT auth via query parameter: ws://...?token=<jwt>
"""
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
# Track subscriptions: {(channel_name, params_json): channel_instance}
self._subscriptions: dict[tuple[str, str], Any] = {}
async def connect(self):
"""Accept the WebSocket connection, authenticating via JWT if provided."""
# Check for JWT token in query parameters
await self._try_jwt_auth()
await self.accept()
logger.debug(
f"WebSocket connected: {self.channel_name}, user={self.scope.get('user')}"
)
async def _try_jwt_auth(self):
"""
Attempt JWT authentication from query parameter.
If a valid JWT token is provided via ?token=<jwt>, authenticate the user
using JWTUser (no database query).
Security: If JWT is provided but invalid, we log it but don't reject
the connection - the session auth may still be valid. However, if JWT
IS valid, it takes precedence over session auth.
"""
# Parse query string for token
query_string = self.scope.get("query_string", b"").decode()
params = parse_qs(query_string)
token_list = params.get("token", [])
if not token_list:
return # No JWT provided, use session auth
token = token_list[0]
if not token:
return
# Validate JWT and create JWTUser (no DB query)
try:
from mizan.client.jwt import decode_token
from mizan.jwt.tokens import JWTUser
payload = await sync_to_async(decode_token)(token, expected_type="access")
if payload is None:
logger.debug("JWT token invalid or expired")
return # Fall back to session auth
# Create JWTUser from token claims - NO DATABASE QUERY
self.scope["user"] = JWTUser(payload)
logger.debug(f"JWT auth successful for user {payload.user_id}")
except Exception as e:
logger.debug(f"JWT auth failed: {e}")
async def disconnect(self, close_code):
"""Clean up all subscriptions on disconnect."""
for key, instance in list(self._subscriptions.items()):
try:
await instance.on_disconnect()
await instance._leave_all_groups()
except Exception as e:
logger.error(f"Error during disconnect cleanup: {e}")
self._subscriptions.clear()
logger.debug(f"WebSocket disconnected: {self.channel_name}")
async def receive_json(self, content: dict):
"""Handle incoming JSON messages."""
action = content.get("action")
if action == "subscribe":
await self._handle_subscribe(content)
elif action == "unsubscribe":
await self._handle_unsubscribe(content)
elif action == "message":
await self._handle_message(content)
elif action == "rpc":
await self._handle_rpc(content)
else:
await self.send_json(
{
"error": f"Unknown action: {action}",
}
)
async def _handle_subscribe(self, content: dict):
"""Handle subscription request."""
channel_name = content.get("channel")
params_dict = content.get("params", {})
# Get channel class
channel_class = get_channel(channel_name)
if not channel_class:
await self.send_json(
{
"error": f"Unknown channel: {channel_name}",
}
)
return
# Create subscription key
params_json = json.dumps(params_dict, sort_keys=True)
sub_key = (channel_name, params_json)
# Check if already subscribed
if sub_key in self._subscriptions:
await self.send_json(
{
"error": f"Already subscribed to {channel_name}",
"channel": channel_name,
"params": params_dict,
}
)
return
# Create channel instance
instance = channel_class()
instance.user = self.scope.get("user")
instance._channel_layer = self.channel_layer
instance._channel_name = self.channel_name
instance._registered_name = channel_name
instance._params_dict = params_dict
# Parse params
params_obj = None
if channel_class.Params:
try:
params_obj = channel_class.Params(**params_dict)
except Exception as e:
await self.send_json(
{
"error": f"Invalid params: {e}",
"channel": channel_name,
}
)
return
# Check authorization
try:
if params_obj:
authorized = instance.authorize(params_obj)
else:
authorized = instance.authorize()
except Exception as e:
logger.error(f"Authorization error for {channel_name}: {e}")
await self.send_json(
{
"error": "Authorization failed",
"channel": channel_name,
}
)
return
if not authorized:
await self.send_json(
{
"error": "Not authorized",
"channel": channel_name,
}
)
return
# Get group and join
try:
if params_obj:
group_name = instance.group(params_obj)
else:
group_name = instance.group()
await instance._join_group(group_name)
except Exception as e:
logger.error(f"Failed to join group for {channel_name}: {e}")
await self.send_json(
{
"error": f"Failed to subscribe: {e}",
"channel": channel_name,
}
)
return
# Store subscription
self._subscriptions[sub_key] = instance
# Call on_connect hook
try:
await instance.on_connect(params_obj)
except Exception as e:
logger.error(f"on_connect error for {channel_name}: {e}")
# Confirm subscription
await self.send_json(
{
"subscribed": True,
"channel": channel_name,
"params": params_dict,
}
)
logger.debug(f"Subscribed to {channel_name} with params {params_dict}")
async def _handle_unsubscribe(self, content: dict):
"""Handle unsubscription request."""
channel_name = content.get("channel")
params_dict = content.get("params", {})
params_json = json.dumps(params_dict, sort_keys=True)
sub_key = (channel_name, params_json)
instance = self._subscriptions.pop(sub_key, None)
if instance:
try:
await instance.on_disconnect()
await instance._leave_all_groups()
except Exception as e:
logger.error(f"Error during unsubscribe: {e}")
await self.send_json(
{
"unsubscribed": True,
"channel": channel_name,
"params": params_dict,
}
)
logger.debug(f"Unsubscribed from {channel_name}")
async def _handle_message(self, content: dict):
"""Handle incoming message from browser."""
channel_name = content.get("channel")
params_dict = content.get("params", {})
data = content.get("data", {})
params_json = json.dumps(params_dict, sort_keys=True)
sub_key = (channel_name, params_json)
instance = self._subscriptions.get(sub_key)
if not instance:
await self.send_json(
{
"error": f"Not subscribed to {channel_name}",
"channel": channel_name,
}
)
return
channel_class = instance.__class__
# Check if channel accepts messages
if not channel_class.ReactMessage:
await self.send_json(
{
"error": f"Channel {channel_name} does not accept messages",
"channel": channel_name,
}
)
return
# Parse message
try:
msg = channel_class.ReactMessage(**data)
except Exception as e:
await self.send_json(
{
"error": f"Invalid message: {e}",
"channel": channel_name,
}
)
return
# Parse params
params_obj = None
if channel_class.Params:
params_obj = channel_class.Params(**params_dict)
# Handle message
try:
response = instance.receive(params_obj, msg)
# If handler returned a message, broadcast it
if response is not None:
if params_obj:
group_name = instance.group(params_obj)
else:
group_name = instance.group()
await instance._broadcast(group_name, response)
except Exception as e:
logger.error(f"Error handling message for {channel_name}: {e}")
await self.send_json(
{
"error": f"Message handling failed: {e}",
"channel": channel_name,
}
)
async def _handle_rpc(self, content: dict):
"""
Handle RPC (server function) call.
Protocol:
Request: {"action": "rpc", "id": "request-id", "fn": "function_name", "args": {...}}
Response: {"id": "request-id", "ok": true, "data": {...}}
or: {"id": "request-id", "ok": false, "error": {...}}
Security:
- Only functions with @client(websocket=True) are allowed
- Pydantic validation happens BEFORE any function code runs
- Function must be explicitly registered (no arbitrary code execution)
- User context from WebSocket session is passed to function
"""
from mizan.client.executor import execute_function, FunctionError
from mizan.setup.registry import get_function
request_id = content.get("id")
fn_name = content.get("fn")
args = content.get("args", {})
# Validate request structure
if not request_id:
await self.send_json(
{
"error": "RPC request missing 'id' field",
}
)
return
if not fn_name:
await self.send_json(
{
"id": request_id,
"ok": False,
"error": {
"code": "BAD_REQUEST",
"message": "Missing 'fn' field",
},
}
)
return
# Check if function exists and has websocket=True
fn_class = get_function(fn_name)
if fn_class is None:
await self.send_json(
{
"id": request_id,
"ok": False,
"error": {
"code": "NOT_FOUND",
"message": f"Function '{fn_name}' not found",
},
}
)
return
# Only allow functions explicitly marked with websocket=True
fn_meta = getattr(fn_class, "_meta", {})
if not fn_meta.get("websocket"):
await self.send_json(
{
"id": request_id,
"ok": False,
"error": {
"code": "FORBIDDEN",
"message": "This function is HTTP-only. Use POST /api/mizan/call/ instead.",
},
}
)
return
# Create request adapter from WebSocket scope
ws_request = WebSocketRequest(
self.scope, channel_name=getattr(self, "channel_name", None)
)
# Execute function (Pydantic validation happens inside execute_function)
# This is sync, so we need to run it in a thread pool
result = await sync_to_async(execute_function, thread_sensitive=True)(
ws_request,
fn_name,
args,
)
# Send response
if isinstance(result, FunctionError):
await self.send_json(
{
"id": request_id,
"ok": False,
"error": {
"code": result.code.value,
"message": result.message,
**({"details": result.details} if result.details else {}),
},
}
)
else:
await self.send_json(
{
"id": request_id,
"ok": True,
"data": result.data,
}
)
async def channel_message(self, event: dict):
"""
Handle messages broadcast to a group.
Called when channel_layer.group_send() is used.
Includes channel name and params so the client can route the message.
"""
await self.send_json(
{
"channel": event.get("channel"),
"params": event.get("params", {}),
"type": event.get("message_type", "message"),
"data": event.get("data", {}),
}
)
async def push_message(self, event: dict):
"""
Handle push messages from server functions.
Called when push("topic", data) is used from a server function.
The client receives this to update its local state.
Protocol:
Server sends: {"type": "push", "topic": "room:42", "data": {...}}
"""
await self.send_json(
{
"type": "push",
"topic": event.get("topic"),
"data": event.get("data", {}),
}
)

View File

@@ -0,0 +1,153 @@
"""
mizan Push - Server-initiated messages to clients.
Simple API for pushing data to subscribed WebSocket connections.
Usage:
# In a server function - push to all subscribers
from mizan.push import push
push("room:42", {"type": "new_message", "data": {...}})
# Subscribe a connection to a topic (call during context fetch)
from mizan.push import subscribe
subscribe(request, "room:42")
"""
from typing import TYPE_CHECKING
from pydantic import BaseModel
# Lazy import to avoid import errors when channels is not installed
# (e.g., during schema generation)
if TYPE_CHECKING:
from channels.layers import BaseChannelLayer
def _get_channel_layer() -> "BaseChannelLayer | None":
"""Get channel layer, returning None if channels is not installed."""
try:
from channels.layers import get_channel_layer
return get_channel_layer()
except ImportError:
return None
def _async_to_sync(coro):
"""Wrapper for async_to_sync that handles missing channels."""
from asgiref.sync import async_to_sync
return async_to_sync(coro)
def get_topic_group_name(topic: str) -> str:
"""Convert a topic string to a valid channel layer group name."""
# Channel layer group names must be valid ASCII alphanumeric + hyphens/underscores/periods
# Replace colons with underscores
return topic.replace(":", "_")
def subscribe(request, topic: str) -> None:
"""
Subscribe this WebSocket connection to a topic.
Call this in a context or server function to register the connection
for push notifications on the given topic.
Args:
request: The Django request (must have channel_name attribute from WebSocket)
topic: Topic string, e.g., "room:42", "user:123:notifications"
"""
channel_name = getattr(request, "channel_name", None)
if not channel_name:
# HTTP request, not WebSocket - can't subscribe
return
channel_layer = _get_channel_layer()
if not channel_layer:
return
group_name = get_topic_group_name(topic)
_async_to_sync(channel_layer.group_add)(group_name, channel_name)
def unsubscribe(request, topic: str) -> None:
"""
Unsubscribe this WebSocket connection from a topic.
Args:
request: The Django request (must have channel_name attribute from WebSocket)
topic: Topic string to unsubscribe from
"""
channel_name = getattr(request, "channel_name", None)
if not channel_name:
return
channel_layer = _get_channel_layer()
if not channel_layer:
return
group_name = get_topic_group_name(topic)
_async_to_sync(channel_layer.group_discard)(group_name, channel_name)
def push(topic: str, data: dict | BaseModel) -> None:
"""
Push data to all connections subscribed to a topic.
Args:
topic: Topic string, e.g., "room:42"
data: Data to send (dict or Pydantic model)
Example:
push("room:42", {
"type": "new_message",
"message": {"id": 1, "text": "Hello", "user": "alice@example.com"}
})
"""
channel_layer = _get_channel_layer()
if not channel_layer:
import logging
logging.getLogger(__name__).warning(
"No channel layer configured, cannot push to topic '%s'", topic
)
return
# Convert Pydantic model to dict if needed
if isinstance(data, BaseModel):
data = data.model_dump()
group_name = get_topic_group_name(topic)
_async_to_sync(channel_layer.group_send)(
group_name,
{
"type": "push.message", # Maps to push_message handler in consumer
"topic": topic,
"data": data,
},
)
async def push_async(topic: str, data: dict | BaseModel) -> None:
"""Async version of push for use in async contexts."""
channel_layer = _get_channel_layer()
if not channel_layer:
return
if isinstance(data, BaseModel):
data = data.model_dump()
group_name = get_topic_group_name(topic)
await channel_layer.group_send(
group_name,
{
"type": "push.message",
"topic": topic,
"data": data,
},
)

View File

@@ -0,0 +1,66 @@
"""
mizan.client - Server function implementation.
This subpackage contains everything needed to make server functions work:
- The @client decorator
- ServerFunction base class
- Function execution logic
- JWT authentication (integral to server functions)
Usage:
from mizan.client import client, ServerFunction, compose
"""
from .function import (
# Decorator
client,
# Context markers
ReactContext,
GlobalContext,
# Base classes
ServerFunction,
ComposedContext,
# Composition
compose,
# Type aliases
ContextMode,
# Form helpers
FormValidationOutput,
FormSchemaField,
FormSchemaOutput,
create_form_functions,
)
from .executor import (
execute_function,
function_call_view,
ErrorCode,
FunctionError,
FunctionResult,
)
__all__ = [
# Decorator
"client",
# Context markers
"ReactContext",
"GlobalContext",
# Base classes
"ServerFunction",
"ComposedContext",
# Composition
"compose",
# Type aliases
"ContextMode",
# Execution
"execute_function",
"function_call_view",
"ErrorCode",
"FunctionError",
"FunctionResult",
# Form helpers
"FormValidationOutput",
"FormSchemaField",
"FormSchemaOutput",
"create_form_functions",
]

View File

@@ -0,0 +1,593 @@
"""
mizan Function Executor
Handles execution of server functions.
This is the core of the "Server Functions" feature - callable from React
without REST boilerplate.
Security model:
- All input validated against Pydantic schema BEFORE execution
- Authentication: JWT (stateless) or Session (stateful) - auto-detected
- JWT: Authorization header with Bearer token (no CSRF needed)
- Session: Cookie-based with CSRF token (via X-CSRFToken header)
- WebSocket RPC uses Origin header checking instead
- No implicit function exposure - must be explicitly registered
"""
from __future__ import annotations
import json
import logging
from dataclasses import dataclass
from enum import Enum
from functools import wraps
from typing import TYPE_CHECKING, Any, Callable
from django.http import HttpRequest, JsonResponse
from django.views.decorators.csrf import csrf_protect
from pydantic import BaseModel, ValidationError
from mizan.setup.registry import get_function, get_context_groups
if TYPE_CHECKING:
pass
logger = logging.getLogger(__name__)
class ErrorCode(str, Enum):
"""Standard error codes for function execution."""
# Client errors (4xx)
NOT_FOUND = "NOT_FOUND" # Function not registered
VALIDATION_ERROR = "VALIDATION_ERROR" # Input failed Pydantic validation
UNAUTHORIZED = "UNAUTHORIZED" # User not authenticated (when required)
FORBIDDEN = "FORBIDDEN" # User lacks permission
BAD_REQUEST = "BAD_REQUEST" # Malformed request
# Server errors (5xx)
INTERNAL_ERROR = "INTERNAL_ERROR" # Unhandled exception
NOT_IMPLEMENTED = "NOT_IMPLEMENTED" # Function exists but not implemented
@dataclass
class FunctionError:
"""Structured error response from function execution."""
code: ErrorCode
message: str
details: dict[str, Any] | None = None
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-serializable dict."""
result = {
"error": True,
"code": self.code.value,
"message": self.message,
}
if self.details:
result["details"] = self.details
return result
def to_response(self, status: int = 400) -> JsonResponse:
"""Convert to Django JsonResponse."""
return JsonResponse(self.to_dict(), status=status)
@dataclass
class FunctionResult:
"""Successful result from function execution."""
data: dict[str, Any]
def to_dict(self) -> dict[str, Any]:
"""Convert to JSON-serializable dict."""
return {
"error": False,
"data": self.data,
}
def to_response(self) -> JsonResponse:
"""Convert to Django JsonResponse."""
return JsonResponse(self.to_dict())
def _check_auth_requirement(
request: HttpRequest,
auth_requirement: str | Callable | None,
) -> FunctionError | None:
"""
Check if the request meets the auth requirement.
Args:
request: The Django HttpRequest (with user set)
auth_requirement: 'required', 'staff', 'superuser', callable, or None
Returns:
FunctionError if auth check fails, None if it passes.
Note: This uses request.user which may be a JWTUser (stateless) or
Django User (from session). Either way, no additional DB query is made
for the built-in checks. Custom callables may query DB if they choose.
"""
if auth_requirement is None:
return None
user = request.user
# Handle callable auth
if callable(auth_requirement):
try:
result = auth_requirement(request)
if result:
return None # Authorized
else:
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Access denied",
)
except PermissionError as e:
# Custom error message from the callable
return FunctionError(
code=ErrorCode.FORBIDDEN,
message=str(e) or "Access denied",
)
# Check authentication (required for all string-based auth)
if not getattr(user, "is_authenticated", False):
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Authentication required",
)
# Check staff requirement
if auth_requirement == "staff":
if not getattr(user, "is_staff", False):
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Staff access required",
)
# Check superuser requirement
elif auth_requirement == "superuser":
if not getattr(user, "is_superuser", False):
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Superuser access required",
)
return None
def execute_function(
request: HttpRequest,
fn_name: str,
input_data: dict[str, Any] | None = None,
) -> FunctionResult | FunctionError:
"""
Execute a registered server function.
Args:
request: The Django HttpRequest
fn_name: Name of the registered function
input_data: Input data to pass to the function
Returns:
FunctionResult on success, FunctionError on failure
"""
from django.conf import settings
# Look up the function by name
view_class = get_function(fn_name)
if view_class is None:
# In DEBUG mode, include the name for easier debugging
if settings.DEBUG:
message = f"Function '{fn_name}' not found"
else:
message = "Function not found"
return FunctionError(
code=ErrorCode.NOT_FOUND,
message=message,
)
# Check auth requirement BEFORE executing
meta = getattr(view_class, "_meta", {})
auth_requirement = meta.get("auth")
auth_error = _check_auth_requirement(request, auth_requirement)
if auth_error is not None:
return auth_error
# Instantiate the view with the request
view = view_class(request)
# Check if this is a form function that handles input specially
meta = getattr(view_class, "_meta", {})
is_form_multipart = meta.get("multipart", False)
# For form functions with Input=None, skip Pydantic validation
# The form itself handles validation
input_cls = view.Input
if input_cls is None and is_form_multipart:
# Form function - pass input_data directly (already parsed by view or will be)
validated_input = input_data
elif input_cls is BaseModel:
has_input = False
validated_input = None
else:
# Check if it has any fields defined
has_input = bool(input_cls.model_fields) if input_cls else False
# Validate input against Pydantic schema
try:
if input_data:
# Ensure input_data is a dict (not array or other type)
if not isinstance(input_data, dict):
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Input must be an object, not "
+ type(input_data).__name__,
)
validated_input = input_cls(**input_data)
elif has_input:
# Check if function requires input fields
input_schema = input_cls.model_json_schema()
required_fields = input_schema.get("required", [])
if required_fields:
# Format as field errors for consistency
errors = {field: ["Field required"] for field in required_fields}
return FunctionError(
code=ErrorCode.VALIDATION_ERROR,
message="Input validation failed",
details={"fields": errors},
)
validated_input = input_cls()
else:
# No input expected, create empty model
validated_input = None
except ValidationError as e:
# Convert Pydantic errors to our format
errors = {}
for error in e.errors():
field = ".".join(str(loc) for loc in error["loc"])
if field not in errors:
errors[field] = []
errors[field].append(error["msg"])
return FunctionError(
code=ErrorCode.VALIDATION_ERROR,
message="Input validation failed",
details={"fields": errors},
)
# Execute the function
try:
output = view.call(validated_input)
except NotImplementedError as e:
logger.error(f"Function {fn_name} not implemented: {e}")
return FunctionError(
code=ErrorCode.NOT_IMPLEMENTED,
message=str(e),
)
except PermissionError as e:
# Functions can raise PermissionError for auth issues
return FunctionError(
code=ErrorCode.FORBIDDEN,
message=str(e) or "Permission denied",
)
except Exception as e:
# Log the full exception for debugging
logger.exception(f"Error executing function {fn_name}")
return FunctionError(
code=ErrorCode.INTERNAL_ERROR,
message="An internal error occurred",
# Don't expose internal details in production
details={"type": type(e).__name__}
if logger.isEnabledFor(logging.DEBUG)
else None,
)
# Serialize output (handle None for Optional return types)
if output is None:
return FunctionResult(data=None)
return FunctionResult(data=output.model_dump())
def _try_jwt_auth(request: HttpRequest) -> bool:
"""
Attempt to authenticate the request using JWT.
If Authorization header contains a valid Bearer token, authenticates
the request and sets request.user to a JWTUser. Returns True if JWT
auth succeeded.
IMPORTANT: This is stateless - no database query is made. The JWTUser
object is created from the token claims. If you need the full User
object, query it explicitly in your function.
Security: If JWT is provided but invalid, we return False and do NOT
fall back to session auth. The caller should reject the request.
"""
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
if not auth_header.startswith("Bearer "):
return False
token = auth_header[7:] # Strip "Bearer "
if not token:
return False
try:
from mizan.client.jwt import decode_token
from mizan.jwt.tokens import JWTUser
payload = decode_token(token, expected_type="access")
if payload is None:
return False
# Create JWTUser from token claims - NO DATABASE QUERY
request.user = JWTUser(payload)
request._mizan_jwt_authenticated = True
return True
except Exception:
return False
def _has_jwt_header(request: HttpRequest) -> bool:
"""Check if request has a JWT Authorization header."""
auth_header = request.META.get("HTTP_AUTHORIZATION", "")
return auth_header.startswith("Bearer ")
def _csrf_protect_unless_jwt(view_func):
"""
Decorator that applies CSRF protection unless JWT auth is used.
JWT tokens are self-authenticating (the token itself proves the request
is legitimate), so CSRF protection is not needed.
Security: If JWT is provided but invalid, reject the request - do NOT
fall back to session auth. This prevents attacks where an invalid token
is sent alongside a valid session cookie.
"""
csrf_protected_view = csrf_protect(view_func)
@wraps(view_func)
def wrapper(request: HttpRequest, *args, **kwargs):
# Check if JWT header is present
has_jwt = _has_jwt_header(request)
if has_jwt:
# JWT header present - try to authenticate
if _try_jwt_auth(request):
# JWT valid - skip CSRF, proceed
return view_func(request, *args, **kwargs)
else:
# JWT invalid - reject (do NOT fall back to session)
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Invalid or expired JWT token",
).to_response(status=401)
else:
# No JWT - use session auth with CSRF
return csrf_protected_view(request, *args, **kwargs)
return wrapper
@_csrf_protect_unless_jwt
def function_call_view(request: HttpRequest) -> JsonResponse:
"""
Django view for handling function calls (HTTP fallback for WebSocket RPC).
Authentication (auto-detected):
- JWT: Authorization: Bearer <token> (stateless, no CSRF needed)
- Session: Cookie-based with X-CSRFToken header (CSRF required)
Endpoint: POST /api/mizan/call/
Request body (JSON):
{
"fn": "function_name", // Function name
"args": { ... } // Optional, depending on function
}
Request body (multipart/form-data for form submit functions):
fn: function_name
<field>: <value>
...
Response on success:
{
"error": false,
"data": { ... } // Function output
}
Response on error:
{
"error": true,
"code": "VALIDATION_ERROR",
"message": "Input validation failed",
"details": { ... }
}
"""
# Only allow POST
if request.method != "POST":
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Only POST method allowed",
).to_response(status=405)
# Check content type to determine parsing method
content_type = request.content_type or ""
is_multipart = content_type.startswith("multipart/form-data")
if is_multipart:
# Multipart form data - used by form submit functions
fn_name = request.POST.get("fn")
if not fn_name:
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Missing 'fn' field",
).to_response()
# Get form data (excluding 'fn')
input_data = {k: v for k, v in request.POST.dict().items() if k != "fn"}
# Attach parsed form data and files to request for form functions
request._mizan_form_data = input_data
request._mizan_form_files = request.FILES
else:
# JSON body - standard RPC
try:
if request.body:
body = json.loads(request.body)
else:
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Request body required",
).to_response()
except json.JSONDecodeError:
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Invalid JSON in request body",
).to_response()
# Extract function name and args
fn_name = body.get("fn")
if not fn_name:
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Missing 'fn' field",
).to_response()
input_data = body.get("args")
# Execute the function
result = execute_function(request, fn_name, input_data)
# Return appropriate response
if isinstance(result, FunctionError):
status = {
ErrorCode.NOT_FOUND: 404,
ErrorCode.VALIDATION_ERROR: 422,
ErrorCode.UNAUTHORIZED: 401,
ErrorCode.FORBIDDEN: 403,
ErrorCode.BAD_REQUEST: 400,
ErrorCode.INTERNAL_ERROR: 500,
ErrorCode.NOT_IMPLEMENTED: 501,
}.get(result.code, 400)
return result.to_response(status=status)
return result.to_response()
def execute_context(
request: HttpRequest,
context_name: str,
params: dict[str, str],
) -> FunctionResult | FunctionError:
"""
Execute all functions in a named context with merged params.
Each function receives only the params it declares in its Input schema.
If any function fails (auth, validation, execution), the entire request fails.
Args:
request: The Django HttpRequest
context_name: Name of the context (e.g., 'user', 'global')
params: Query parameters (strings — Pydantic coerces types)
Returns:
FunctionResult with bundled data, or FunctionError
"""
groups = get_context_groups()
fn_names = groups.get(context_name)
if not fn_names:
return FunctionError(
code=ErrorCode.NOT_FOUND,
message=f"Context '{context_name}' not found",
)
results = {}
for fn_name in fn_names:
view_class = get_function(fn_name)
if view_class is None:
continue
# Filter params to only those in this function's Input schema
input_cls = getattr(view_class, "Input", None)
if input_cls and input_cls is not BaseModel and input_cls.model_fields:
fn_params = {
k: v for k, v in params.items()
if k in input_cls.model_fields
}
else:
fn_params = None
result = execute_function(request, fn_name, fn_params)
if isinstance(result, FunctionError):
return result
results[fn_name] = result.data
return FunctionResult(data=results)
def _jwt_auth_only(view_func):
"""
Decorator that handles JWT auth for GET endpoints (no CSRF needed for GET).
"""
@wraps(view_func)
def wrapper(request: HttpRequest, *args, **kwargs):
has_jwt = _has_jwt_header(request)
if has_jwt:
if _try_jwt_auth(request):
return view_func(request, *args, **kwargs)
else:
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Invalid or expired JWT token",
).to_response(status=401)
# No JWT — session auth (no CSRF needed for GET)
return view_func(request, *args, **kwargs)
return wrapper
@_jwt_auth_only
def context_fetch_view(request: HttpRequest, context_name: str) -> JsonResponse:
"""
Fetch all functions in a named context in a single bundled GET request.
Endpoint: GET /api/mizan/ctx/<context_name>/?param1=val1&param2=val2
Response on success:
{
"error": false,
"data": {
"user_profile": { ... },
"user_orders": [ ... ]
}
}
"""
if request.method != "GET":
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Only GET method allowed",
).to_response(status=405)
params = dict(request.GET)
result = execute_context(request, context_name, params)
if isinstance(result, FunctionError):
status = {
ErrorCode.NOT_FOUND: 404,
ErrorCode.VALIDATION_ERROR: 422,
ErrorCode.UNAUTHORIZED: 401,
ErrorCode.FORBIDDEN: 403,
ErrorCode.BAD_REQUEST: 400,
ErrorCode.INTERNAL_ERROR: 500,
ErrorCode.NOT_IMPLEMENTED: 501,
}.get(result.code, 400)
return result.to_response(status=status)
return result.to_response()

View File

@@ -0,0 +1,816 @@
"""
mizan Server Functions - Core Primitive
Server functions are the core primitive. Everything else builds on them.
Two styles supported:
1. Function-based (recommended, Django Ninja style):
@client("update-profile")
def update_profile(request, input: UpdateProfileInput) -> UpdateProfileOutput:
return UpdateProfileOutput(success=True)
2. Class-based (for complex cases):
class UpdateProfile(ServerFunction):
def call(self, input: UpdateProfileInput) -> UpdateProfileOutput:
return UpdateProfileOutput(success=True)
register(UpdateProfile, 'update-profile')
"""
from __future__ import annotations
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import (
Any,
Callable,
ClassVar,
Generic,
Literal,
TypeVar,
Union,
get_args,
get_origin,
get_type_hints,
)
from django.http import HttpRequest
from pydantic import BaseModel
# =============================================================================
# REACT CONTEXT - Named context marker
# =============================================================================
class ReactContext:
"""
A named context that groups server functions into one provider and one fetch.
Usage:
UserContext = ReactContext('user')
@client(context=UserContext)
def user_profile(request, user_id: int) -> ProfileShape: ...
@client(context=UserContext)
def user_orders(request, user_id: int) -> list[OrderShape]: ...
@client(affects=UserContext)
def edit_profile(request, name: str) -> dict: ...
@client(affects=[UserContext, OrderContext])
def change_plan(request) -> dict: ...
"""
def __init__(self, name: str):
if not name or not isinstance(name, str):
raise ValueError("ReactContext name must be a non-empty string")
self.name = name
def __repr__(self) -> str:
return f"ReactContext({self.name!r})"
# Built-in global context (auto-mounted at root, SSR-hydrated)
GlobalContext = ReactContext("global")
# Context parameter type: a ReactContext instance, a raw string, or False
ContextMode = ReactContext | str | Literal[False]
TInput = TypeVar("TInput", bound=BaseModel)
TOutput = TypeVar("TOutput", bound=BaseModel)
# =============================================================================
# SERVER FUNCTION - The Core Primitive
# =============================================================================
class ServerFunction(ABC, Generic[TInput, TOutput]):
"""
Class-based server function (for complex cases).
For simple functions, use the @client decorator instead.
Usage:
class UpdateProfile(ServerFunction):
def call(self, input: UpdateProfileInput) -> UpdateProfileOutput:
self.user.name = input.name
self.user.save()
return UpdateProfileOutput(success=True)
register(UpdateProfile, 'update-profile')
"""
# Registration name (set by register())
name: ClassVar[str]
# Metadata for code generation
_meta: ClassVar[dict[str, Any]] = {}
# Schema classes (set automatically from type hints or explicitly)
Input: ClassVar[type[BaseModel]] = BaseModel
Output: ClassVar[type[BaseModel]] = BaseModel
def __init__(self, request: HttpRequest):
"""Initialize with the Django request."""
self.request = request
@property
def user(self):
"""Shortcut to request.user."""
return self.request.user
@abstractmethod
def call(self, input: TInput) -> TOutput:
"""
Execute the function.
Args:
input: Validated input data
Returns:
Output instance
"""
raise NotImplementedError(f"{self.__class__.__name__} must implement call()")
@classmethod
def get_schema_export(cls) -> dict[str, Any]:
"""Export schema for TypeScript generation."""
export = {
"name": getattr(cls, "name", cls.__name__),
"type": "function",
"meta": getattr(cls, "_meta", {}),
}
# Get Input/Output from class attributes
input_cls = getattr(cls, "Input", BaseModel)
output_cls = getattr(cls, "Output", BaseModel)
# Check if Input has fields
input_schema = input_cls.model_json_schema()
has_input = bool(input_schema.get("properties"))
if has_input:
export["input"] = input_schema
export["has_input"] = has_input
export["output"] = output_cls.model_json_schema()
return export
# =============================================================================
# FUNCTION DECORATOR - Django Ninja Style
# =============================================================================
class _FunctionWrapper(ServerFunction):
"""Internal wrapper that makes a plain function behave like a ServerFunction."""
# Will be set per-wrapper instance
_wrapped_fn: ClassVar[Callable]
_input_cls: ClassVar[type[BaseModel] | None]
_output_cls: ClassVar[type[BaseModel]]
_param_names: ClassVar[list[str]] = []
_is_primitive_output: ClassVar[bool] = False
def call(self, input):
"""Execute the wrapped function, unpacking input into individual args."""
if input is not None and self._param_names:
# Unpack validated model into keyword arguments
kwargs = {name: getattr(input, name) for name in self._param_names}
result = self._wrapped_fn(self.request, **kwargs)
else:
result = self._wrapped_fn(self.request)
# Wrap primitive returns in the generated output model
if self._is_primitive_output:
return self._output_cls(result=result)
return result
@classmethod
def get_schema_export(cls) -> dict[str, Any]:
"""Export schema for TypeScript generation."""
export = {
"name": getattr(cls, "name", cls.__name__),
"type": "function",
"meta": getattr(cls, "_meta", {}),
}
# Use stored schema classes
if cls._input_cls is not None:
input_schema = cls._input_cls.model_json_schema()
has_input = bool(input_schema.get("properties"))
if has_input:
export["input"] = input_schema
export["has_input"] = has_input
else:
export["has_input"] = False
export["output"] = cls._output_cls.model_json_schema()
return export
# Valid string values for auth parameter
_VALID_AUTH_STRINGS = frozenset({"required", "staff", "superuser"})
def _resolve_context(context: ContextMode) -> str | Literal[False]:
"""Resolve a context parameter to its name string."""
if context is False:
return False
if isinstance(context, ReactContext):
return context.name
if isinstance(context, str):
if not context.strip():
raise ValueError("context must be a non-empty string, ReactContext, or False.")
if context == "local":
warnings.warn(
"context='local' is deprecated. Use ReactContext('name') instead.",
DeprecationWarning,
stacklevel=3,
)
return context
raise ValueError(
f"context must be a ReactContext, a string, or False. Got {type(context).__name__}."
)
# Affects parameter type
AffectsTarget = ReactContext | str | type["ServerFunction"]
AffectsMode = AffectsTarget | list[AffectsTarget] | None
def client(
fn: Callable = None,
*,
context: ContextMode = False,
affects: AffectsMode = None,
websocket: bool = False,
auth: bool | str | Callable[[Any], bool] | None = None,
) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]:
"""
Register a function as a server function.
Type annotations define the schema - just like Django Ninja/FastAPI.
Function parameters become input fields automatically.
Args:
context: Named context for React state management.
- False (default): Not a context, just a callable function.
- ReactContext instance: groups functions into a named context.
- GlobalContext: reserved, auto-mounted at root, SSR-hydrated.
- Raw string: also accepted (e.g., 'user'), but ReactContext preferred.
affects: Declare which contexts this mutation invalidates.
- A ReactContext instance
- A list of ReactContext instances
- Also accepts strings or function references for backwards compat
Mutually exclusive with context=.
websocket: Enable WebSocket RPC transport (default: False).
auth: Authentication requirement.
- None (default): No auth required
- True or 'required': Must be authenticated
- 'staff': Must have is_staff=True
- 'superuser': Must have is_superuser=True
- callable(request) -> bool: Custom check function
Usage:
UserContext = ReactContext('user')
@client(context=GlobalContext)
def current_user(request) -> UserOutput: ...
@client(context=UserContext)
def user_profile(request, user_id: int) -> ProfileOutput: ...
@client(affects=UserContext)
def edit_profile(request, name: str) -> dict: ...
@client(affects=[UserContext, OrderContext])
def change_plan(request) -> dict: ...
Returns:
A ServerFunction class that wraps the function
"""
# Resolve context to name string
resolved_context = _resolve_context(context)
# Validate affects parameter
if affects is not None:
if resolved_context is not False:
raise ValueError(
"context= and affects= are mutually exclusive. "
"A function cannot be both a context reader and a mutation."
)
# Validate auth parameter
if auth is not None:
if isinstance(auth, str) and auth not in _VALID_AUTH_STRINGS:
raise ValueError(
f"Invalid auth value '{auth}'. "
f"Must be one of: {', '.join(sorted(_VALID_AUTH_STRINGS))}, True, or a callable."
)
def decorator(fn: Callable) -> type[ServerFunction]:
return _create_server_function(
fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth
)
# Support both @client and @client(...)
if fn is not None:
return _create_server_function(
fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth
)
return decorator
def _normalize_affects(affects: AffectsMode) -> list[dict[str, str]] | None:
"""Normalize the affects parameter into a list of target descriptors."""
if affects is None:
return None
items = affects if isinstance(affects, list) else [affects]
result = []
for item in items:
if isinstance(item, ReactContext):
result.append({"type": "context", "name": item.name})
elif isinstance(item, str):
result.append({"type": "context", "name": item})
elif isinstance(item, type) and issubclass(item, ServerFunction):
fn_meta = getattr(item, "_meta", {})
fn_ctx = fn_meta.get("context")
result.append({
"type": "function",
"name": getattr(item, "__name__", str(item)),
"context": fn_ctx or None,
})
else:
raise ValueError(
f"affects items must be ReactContext instances, context name strings, "
f"or @client function references. Got {type(item)}"
)
return result
def _create_server_function(
fn: Callable,
*,
context: str | Literal[False] = False,
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None,
websocket: bool = False,
auth: bool | str | None = None,
) -> type[ServerFunction]:
"""Internal helper that creates a ServerFunction from a decorated function."""
from pydantic import create_model
# Use function name directly
name = fn.__name__
# Extract type hints and signature
hints = get_type_hints(fn)
sig = inspect.signature(fn)
params = list(sig.parameters.items())
# Skip 'request' parameter (first param)
input_params = params[1:] if params else []
# Build input schema from function parameters
if input_params:
# Build field definitions for create_model
# Format: {field_name: (type, default) or (type, ...)}
fields = {}
for param_name, param in input_params:
param_type = hints.get(param_name, Any)
if param.default is inspect.Parameter.empty:
# Required field
fields[param_name] = (param_type, ...)
else:
# Optional field with default
fields[param_name] = (param_type, param.default)
# Create dynamic Pydantic model
input_cls = create_model(f"{fn.__name__}_Input", **fields)
else:
input_cls = None
# Get output type from return annotation
output_type = hints.get("return")
if output_type is None:
raise TypeError(f"Server function '{name}' must have a return type annotation")
# Support primitive return types by wrapping in a model with 'result' field
# Also handle Optional[X] / X | None by extracting the non-None type
import types
def is_basemodel_type(t: Any) -> bool:
"""Check if type is a BaseModel subclass, handling Optional/Union."""
if isinstance(t, type) and issubclass(t, BaseModel):
return True
# Handle Union types: typing.Union (Optional[X]) and types.UnionType (X | None)
origin = get_origin(t)
if origin is Union or isinstance(t, types.UnionType):
args = get_args(t)
# Check if any non-None arg is a BaseModel
for arg in args:
if (
arg is not type(None)
and isinstance(arg, type)
and issubclass(arg, BaseModel)
):
return True
return False
if is_basemodel_type(output_type):
output_cls = output_type
is_primitive_output = False
else:
# Create model wrapper for primitive types (int, str, list, etc.)
output_cls = create_model(f"{fn.__name__}_Output", result=(output_type, ...))
is_primitive_output = True
# Store param names for unpacking validated input
param_names = [p[0] for p in input_params]
# Create a unique wrapper class for this function
class FunctionWrapper(_FunctionWrapper):
_param_names: ClassVar[list[str]] = param_names
FunctionWrapper.__name__ = fn.__name__
FunctionWrapper.__doc__ = fn.__doc__
FunctionWrapper.__module__ = fn.__module__ # Critical for discovery
FunctionWrapper._wrapped_fn = staticmethod(fn)
FunctionWrapper._input_cls = input_cls
FunctionWrapper._output_cls = output_cls
FunctionWrapper._is_primitive_output = is_primitive_output
# Set Input/Output class attributes for compatibility
if input_cls is not None:
FunctionWrapper.Input = input_cls
FunctionWrapper.Output = output_cls
# Build metadata
meta = {}
# Context name (any non-empty string)
if context:
meta["context"] = context
# Affects: mutation invalidation targets
normalized_affects = _normalize_affects(affects)
if normalized_affects:
meta["affects"] = normalized_affects
# WebSocket: enable WebSocket transport
if websocket:
meta["websocket"] = True
# Auth requirement
if auth is not None:
if auth is True:
meta["auth"] = "required"
elif callable(auth):
meta["auth"] = auth
else:
meta["auth"] = auth
if meta:
FunctionWrapper._meta = {**FunctionWrapper._meta, **meta}
# Note: Registration happens via discovery (mizan_clients), not here.
# This allows the decorator to be used without import-time side effects.
return FunctionWrapper
# =============================================================================
# COMPOSE - Combine multiple contexts into a single provider
# =============================================================================
class ComposedContext:
"""
Marker class for composed contexts.
Stores metadata about the composition for schema export.
"""
name: str
_meta: dict[str, Any]
_children: list[type[ServerFunction] | "ComposedContext"]
_leaves: list[type[ServerFunction]]
def __init__(
self,
name: str,
children: list,
leaves: list,
on_server: bool,
websocket: bool,
):
self.name = name
self._children = children
self._leaves = leaves
self._meta = {
"compose": True,
"on_server": on_server,
"websocket": websocket,
"children": [c.name for c in children],
"leaves": [leaf.name for leaf in leaves],
}
@classmethod
def get_schema_export(cls) -> dict[str, Any]:
"""Export schema for TypeScript generation."""
return {
"name": cls.name,
"type": "compose",
"meta": cls._meta,
"children": cls._meta.get("children", []),
"leaves": cls._meta.get("leaves", []),
}
def _get_leaves(item) -> list[type[ServerFunction]]:
"""Recursively collect all leaf contexts from a context or composition."""
if isinstance(item, type) and issubclass(item, ServerFunction):
return [item]
elif isinstance(item, ComposedContext):
return item._leaves.copy()
elif hasattr(item, "_leaves"):
# Duck typing for composed contexts
return item._leaves.copy()
else:
raise TypeError(f"Expected ServerFunction or ComposedContext, got {type(item)}")
def _is_context_enabled(item) -> bool:
"""Check if an item is a context-enabled function or composition."""
if isinstance(item, ComposedContext) or hasattr(item, "_leaves"):
return True
if isinstance(item, type) and issubclass(item, ServerFunction):
meta = getattr(item, "_meta", {})
return bool(meta.get("context"))
return False
def compose(
*children,
on_server: bool = False,
websocket: bool = False,
):
"""
Compose multiple contexts into a single provider.
Args:
*children: Context functions (@client with a context name)
or other @compose functions. All must be unique after flattening.
on_server: Bundle all calls into a single server request (default: False).
- False: Frontend makes individual calls (mixed HTTP/WS OK)
- True: Single bundled call. Requires transport consistency:
all children must be HTTP-only XOR all must be websocket=True.
websocket: Transport for bundled call when on_server=True (default: False).
- False: Bundled call over HTTP. All children must be HTTP-only.
- True: Bundled call over WebSocket. All children must have websocket=True.
Usage:
@client(context='local')
def user_profile(request, user_id: int) -> ProfileOutput: ...
@client(context='local')
def user_posts(request, user_id: int) -> PostsOutput: ...
@compose(user_profile, user_posts)
def user_page():
pass
# Frontend generates:
# <UserPageProvider user_id={123}>
# <App />
# </UserPageProvider>
Nesting:
@compose(ctx_a, ctx_b)
def ab(): pass
@compose(ab, ctx_c) # Flattens to [ctx_a, ctx_b, ctx_c]
def abc(): pass
Returns:
A ComposedContext that can be used in other compositions.
"""
def decorator(fn: Callable) -> ComposedContext:
from mizan.setup.registry import register_compose
name = fn.__name__
# Validate: all children must be context-enabled
for i, child in enumerate(children):
if not _is_context_enabled(child):
child_name = getattr(
child, "name", getattr(child, "__name__", str(child))
)
raise ValueError(
f"@compose argument {i} ({child_name}) is not context-enabled. "
f"All children must have @client(context=...) or be @compose."
)
# Flatten to collect all leaves
leaves = []
for child in children:
leaves.extend(_get_leaves(child))
# Validate: no duplicate leaves (by identity)
seen = set()
for leaf in leaves:
if id(leaf) in seen:
raise ValueError(
f"Duplicate context '{leaf.name}' in @compose({name}). "
f"Each context can only appear once. Use named kwargs for reuse (future feature)."
)
seen.add(id(leaf))
# Validate transport consistency when on_server=True
if on_server:
has_websocket = [
getattr(leaf, "_meta", {}).get("websocket", False) for leaf in leaves
]
if websocket:
# All must have websocket=True
if not all(has_websocket):
non_ws = [
leaf.name for leaf, ws in zip(leaves, has_websocket) if not ws
]
raise ValueError(
f"@compose({name}, on_server=True, websocket=True) requires all children "
f"to have websocket=True. These are HTTP-only: {non_ws}"
)
else:
# All must be HTTP-only
if any(has_websocket):
ws_enabled = [
leaf.name for leaf, ws in zip(leaves, has_websocket) if ws
]
raise ValueError(
f"@compose({name}, on_server=True, websocket=False) requires all children "
f"to be HTTP-only. These have websocket=True: {ws_enabled}"
)
# Create composed context
composed = ComposedContext(
name=name,
children=list(children),
leaves=leaves,
on_server=on_server,
websocket=websocket,
)
# Make it a class-like object for consistency
composed.__name__ = name
composed.__doc__ = fn.__doc__
# Register the composition
register_compose(composed, name)
return composed
return decorator
# =============================================================================
# FORM HELPERS - Output types used by form server functions
# =============================================================================
class FormValidationOutput(BaseModel):
"""Standard output for form validation."""
valid: bool
errors: dict[str, list[str]]
class FormSchemaField(BaseModel):
"""Schema for a single form field."""
name: str
type: str
required: bool
label: str
help_text: str | None = None
choices: list[tuple[str, str]] | None = None
initial: Any = None
class FormSchemaOutput(BaseModel):
"""Standard output for form schema."""
fields: list[FormSchemaField]
def create_form_functions(
form_class: type,
name: str,
submit_handler: Callable[[HttpRequest, dict], BaseModel] | None = None,
) -> tuple[type[ServerFunction], type[ServerFunction], type[ServerFunction] | None]:
"""
Generate server functions for a Django Form.
Args:
form_class: Django Form class
name: Base name for the functions
submit_handler: Optional handler for form submission
Returns:
Tuple of (SchemaFunction, ValidateFunction, SubmitFunction or None)
Usage:
SchemaFn, ValidateFn, SubmitFn = create_form_functions(
ContactForm,
'contact',
submit_handler=lambda req, data: ContactSubmitOutput(success=True),
)
register(SchemaFn, 'contact-schema')
register(ValidateFn, 'contact-validate')
register(SubmitFn, 'contact-submit')
Or use the helper:
register_form(ContactForm, 'contact', submit_handler=...)
"""
from mizan.forms.schema_utils import build_form_schema
# Schema function - returns field definitions
class FormSchema(ServerFunction):
class Output(FormSchemaOutput):
pass
def call(self, input):
schema = build_form_schema(form_class)
fields = [
FormSchemaField(
name=field.name,
type=field.type,
required=field.required,
label=field.label or field.name,
help_text=field.help_text or None,
choices=[(c.value, c.label) for c in field.choices]
if field.choices
else None,
initial=field.initial,
)
for field in schema.fields
]
return self.Output(fields=fields)
FormSchema.__name__ = f"{name.title().replace('-', '')}Schema"
FormSchema._meta = {"form": True, "form_name": name, "form_role": "schema"}
# Validation function
class FormDataInput(BaseModel):
data: dict[str, Any]
class FormValidate(ServerFunction):
Input = FormDataInput
class Output(FormValidationOutput):
pass
def call(self, input):
form = form_class(data=input.data)
if form.is_valid():
return self.Output(valid=True, errors={})
return self.Output(valid=False, errors=dict(form.errors))
FormValidate.__name__ = f"{name.title().replace('-', '')}Validate"
FormValidate._meta = {"form": True, "form_name": name, "form_role": "validate"}
# Submit function (optional)
FormSubmit = None
if submit_handler:
class FormSubmit(ServerFunction):
Input = FormDataInput
def call(self, input):
# Validate first
form = form_class(data=input.data)
if not form.is_valid():
raise ValueError("Form validation failed")
# Call handler
return submit_handler(self.request, form.cleaned_data)
FormSubmit.__name__ = f"{name.title().replace('-', '')}Submit"
FormSubmit._meta = {"form": True, "form_name": name, "form_role": "submit"}
return FormSchema, FormValidate, FormSubmit

View File

@@ -0,0 +1,44 @@
"""
mizan.client.jwt - JWT authentication for server functions.
Provides:
- Server functions for obtaining/refreshing JWT tokens
- JWT authentication utilities for validating tokens
Server Functions:
- jwt_obtain: Convert authenticated session to JWT tokens
- jwt_refresh: Refresh tokens using a refresh token
Note: This module is purpose-built for mizan server functions.
For Django Ninja API authentication, use mizan.jwt.security directly.
"""
# Token utilities (re-exports from django_jwt_session)
from mizan.jwt.tokens import (
create_token_pair,
create_access_token,
create_refresh_token,
decode_token,
refresh_tokens,
TokenPair,
TokenPayload,
JWTUser,
)
# Settings
from mizan.jwt.settings import get_settings, JWTSettings
__all__ = [
# Token utilities
"create_token_pair",
"create_access_token",
"create_refresh_token",
"decode_token",
"refresh_tokens",
"TokenPair",
"TokenPayload",
"JWTUser",
# Settings
"get_settings",
"JWTSettings",
]

View File

@@ -0,0 +1,352 @@
"""
mizan OpenAPI Schema Generator
Generates OpenAPI 3.0 compatible schema from registered server functions.
Uses Django Ninja's battle-tested schema generation for robust Pydantic→OpenAPI conversion.
This schema is consumed by the frontend generator which uses openapi-typescript
for robust type generation.
NOTE: Schema export is only available via management command for security.
HTTP endpoint has been removed to prevent function enumeration.
Usage:
python manage.py export_mizan_schema
"""
from __future__ import annotations
import json
import re
from typing import TYPE_CHECKING, Any
# Lazy imports to avoid Django settings access at module load time
# (asgi.py imports mizan before Django is fully configured)
if TYPE_CHECKING:
from django import forms
from ninja import NinjaAPI
from mizan.setup.registry import get_registry, get_schema, get_context_groups, get_function
__all__ = ["get_schema", "generate_openapi_schema", "generate_openapi_json"]
def _extract_form_fields(form_class: type) -> list[dict[str, Any]]:
"""
Extract field definitions with constraints from a Django Form class.
Returns a list of field metadata suitable for Zod schema generation:
- name: field name
- zodType: base Zod type ("string", "number", "boolean", "array")
- required: whether field is required
- constraints: dict of Zod-compatible constraints
Constraints include:
- min/max: for string length or number range
- email/url: for format validation
- regex: for pattern validation
- choices: for enum validation
"""
try:
# Try to instantiate form to get bound fields
form = form_class()
fields_dict = form.fields
except TypeError:
# Form requires extra args - use base_fields
fields_dict = getattr(form_class, "base_fields", {})
result = []
for name, field in fields_dict.items():
field_meta = _extract_field_constraints(name, field)
result.append(field_meta)
return result
def _extract_field_constraints(name: str, field: "forms.Field") -> dict[str, Any]:
"""
Extract Zod-compatible constraints from a single Django form field.
"""
from django import forms # Lazy import
meta: dict[str, Any] = {
"name": name,
"required": field.required,
"constraints": {},
}
# Determine base Zod type
if isinstance(field, forms.BooleanField):
meta["zodType"] = "boolean"
elif isinstance(field, (forms.IntegerField, forms.FloatField, forms.DecimalField)):
meta["zodType"] = "number"
if isinstance(field, forms.IntegerField):
meta["constraints"]["int"] = True
elif isinstance(field, forms.MultipleChoiceField):
meta["zodType"] = "array"
meta["constraints"]["items"] = "string"
elif isinstance(field, forms.FileField):
meta["zodType"] = "file"
else:
# Default to string (CharField, EmailField, URLField, etc.)
meta["zodType"] = "string"
# Extract string constraints
if hasattr(field, "max_length") and field.max_length is not None:
meta["constraints"]["max"] = field.max_length
if hasattr(field, "min_length") and field.min_length is not None:
meta["constraints"]["min"] = field.min_length
# Extract number constraints
if hasattr(field, "max_value") and field.max_value is not None:
meta["constraints"]["max"] = field.max_value
if hasattr(field, "min_value") and field.min_value is not None:
meta["constraints"]["min"] = field.min_value
# Email/URL format
if isinstance(field, forms.EmailField):
meta["constraints"]["email"] = True
elif isinstance(field, forms.URLField):
meta["constraints"]["url"] = True
# Choices (for enum validation)
if hasattr(field, "choices") and field.choices:
# Extract choice values (not labels)
choices = []
for choice in field.choices:
if isinstance(choice, (list, tuple)) and len(choice) >= 1:
# Skip empty/blank choices
if choice[0] != "":
choices.append(str(choice[0]))
else:
choices.append(str(choice))
if choices:
meta["constraints"]["choices"] = choices
# Regex validators
for validator in field.validators:
if hasattr(validator, "regex"):
# RegexValidator - extract pattern
pattern = validator.regex.pattern
meta["constraints"]["regex"] = pattern
if hasattr(validator, "message"):
meta["constraints"]["regexMessage"] = validator.message
break # Only use first regex validator
return meta
def snake_to_camel(name: str) -> str:
"""Convert snake_case or dotted.name to camelCase.
Examples:
- login -> login
- login.schema -> loginSchema
- activate_totp -> activateTotp
- activate_totp.schema -> activateTotpSchema
"""
# Split on both underscores and dots
components = re.split(r"[._]", name)
return components[0] + "".join(x.title() for x in components[1:])
def _register_schema_endpoint(
api: "NinjaAPI",
path: str,
operation_id: str,
summary: str,
input_cls: type | None,
output_cls: type,
) -> None:
"""
Register a dummy endpoint on the API for schema generation.
Sets __annotations__ directly to avoid closure capture issues
and exec() security concerns.
"""
if input_cls is not None:
def endpoint(request, data):
pass
# Set annotations directly to the actual type objects (not strings)
endpoint.__annotations__ = {"data": input_cls}
else:
def endpoint(request):
pass
# Register with Ninja
api.post(path, response=output_cls, operation_id=operation_id, summary=summary)(
endpoint
)
def generate_openapi_schema() -> dict[str, Any]:
"""
Generate OpenAPI 3.0 schema for all registered mizan functions.
Uses Django Ninja's schema generation internally to ensure proper
Pydantic→OpenAPI conversion (handling $refs, nested types, etc.).
Returns a complete OpenAPI document that can be processed by openapi-typescript.
"""
from ninja import NinjaAPI # Lazy import
from pydantic import BaseModel, create_model # Lazy import
registry = get_registry()
functions = registry.get("functions", {})
# Create a temporary Ninja API for schema generation only
# This is NOT exposed as an HTTP endpoint - purely for leveraging Ninja's
# battle-tested Pydantic→OpenAPI conversion
schema_api = NinjaAPI(
title="mizan Server Functions",
version="1.0.0",
description="Auto-generated schema for mizan server functions",
docs_url=None, # No docs endpoint
openapi_url=None, # No openapi endpoint
)
function_metadata: list[dict[str, Any]] = []
# Store dynamically created classes so they persist for schema generation
schema_classes: dict[str, type] = {}
for name, fn_class in functions.items():
camel_name = snake_to_camel(name)
meta = getattr(fn_class, "_meta", {})
# Get Input/Output classes
input_cls = getattr(fn_class, "Input", None)
output_cls = getattr(fn_class, "Output", None) or BaseModel
# Check if input_cls is a valid Pydantic model with fields
has_input = (
input_cls is not None
and input_cls is not BaseModel
and hasattr(input_cls, "model_fields")
and bool(input_cls.model_fields)
)
# Determine type names for metadata
input_type_name = f"{camel_name}Input" if has_input else None
output_type_name = f"{camel_name}Output"
# Create renamed Pydantic classes for cleaner schema names
# Store them in schema_classes so they persist beyond loop scope
# Uses create_model to avoid metaclass conflicts with custom base classes
if has_input:
schema_classes[input_type_name] = create_model(
input_type_name, __base__=input_cls
)
schema_classes[output_type_name] = create_model(
output_type_name, __base__=output_cls
)
# Register endpoint using helper to avoid closure capture issues
_register_schema_endpoint(
api=schema_api,
path=f"/mizan/{name}",
operation_id=camel_name,
summary=fn_class.__doc__ or f"Call {name}",
input_cls=schema_classes.get(input_type_name),
output_cls=schema_classes[output_type_name],
)
# Collect function metadata for provider generation
fn_meta_entry: dict[str, Any] = {
"name": name,
"camelName": camel_name,
"hasInput": has_input,
"inputType": input_type_name,
"outputType": output_type_name,
"transport": "websocket" if meta.get("websocket") else "http",
"isContext": meta.get("context", False),
# Form metadata
"isForm": meta.get("form", False),
"formName": meta.get("form_name"),
"formRole": meta.get("form_role"), # "schema", "validate", "submit"
}
# Affects metadata (mutation invalidation)
if meta.get("affects"):
fn_meta_entry["affects"] = meta["affects"]
# For form schema functions, extract field definitions for Zod generation
if meta.get("form") and meta.get("form_role") == "schema":
form_class = meta.get("form_class")
if form_class is not None:
try:
fn_meta_entry["formFields"] = _extract_form_fields(form_class)
except Exception as e:
# Don't fail schema generation if field extraction fails
fn_meta_entry["formFields"] = []
fn_meta_entry["formFieldsError"] = str(e)
function_metadata.append(fn_meta_entry)
# Get the OpenAPI schema from Ninja (handles all Pydantic conversion properly)
schema = schema_api.get_openapi_schema(path_prefix="")
# Add custom extension with function metadata for provider generation
schema["x-mizan-functions"] = function_metadata
# Add x-mizan-contexts: grouped context metadata with param elevation
context_groups = get_context_groups()
if context_groups:
contexts_meta: dict[str, Any] = {}
for ctx_name, fn_names in context_groups.items():
# Analyze params across all functions in the context
param_info: dict[str, dict[str, Any]] = {}
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
for field_name, field_info in input_cls.model_fields.items():
if field_name not in param_info:
annotation = field_info.annotation
# Map Python types to JSON schema types
type_name = "string"
if annotation in (int,):
type_name = "integer"
elif annotation in (float,):
type_name = "number"
elif annotation in (bool,):
type_name = "boolean"
param_info[field_name] = {
"type": type_name,
"sharedBy": [],
}
param_info[field_name]["sharedBy"].append(fn_name)
# A param is required if ALL functions in the context declare it
for p_name, p_meta in param_info.items():
p_meta["required"] = len(p_meta["sharedBy"]) == len(fn_names)
contexts_meta[ctx_name] = {
"functions": fn_names,
"params": param_info,
}
schema["x-mizan-contexts"] = contexts_meta
# Add x-mizan metadata to each operation
for fn_meta in function_metadata:
path = f"/mizan/{fn_meta['name']}"
if path in schema.get("paths", {}):
schema["paths"][path]["post"]["x-mizan"] = {
"transport": fn_meta["transport"],
"isContext": fn_meta["isContext"],
}
return schema
def generate_openapi_json(indent: int = 2) -> str:
"""Generate OpenAPI schema as formatted JSON string."""
schema = generate_openapi_schema()
return json.dumps(schema, indent=indent)

View File

@@ -0,0 +1,632 @@
"""
mizanFormMixin - Turn Django Forms into server functions.
This mixin transforms any Django Form into mizan server functions,
preserving full Django Form functionality (validation, widgets, ModelChoiceField, etc.)
while exposing them through the unified server function API.
Usage:
from django import forms
from mizan.forms import mizanFormMixin, mizanFormMeta
class ContactForm(mizanFormMixin, forms.Form):
mizan = mizanFormMeta(
name="contact",
title="Contact Us",
submit_label="Send",
)
name = forms.CharField()
email = forms.EmailField()
message = forms.CharField(widget=forms.Textarea)
def on_submit_success(self, request):
send_email(self.cleaned_data)
return {"sent": True}
Auto-registers server functions:
- contact.schema
- contact.validate
- contact.submit
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, ClassVar
from django import forms
from django.http import HttpRequest
from pydantic import BaseModel, create_model
if TYPE_CHECKING:
from .schemas import FormValidation
def _django_field_to_python_type(field: forms.Field) -> type:
"""
Map a Django form field to a Python type for Pydantic schema generation.
This provides TypeScript with proper field types instead of generic `any`.
"""
# Handle common Django field types
if isinstance(field, forms.BooleanField):
return bool
elif isinstance(field, forms.IntegerField):
return int
elif isinstance(field, forms.FloatField):
return float
elif isinstance(field, forms.DecimalField):
return str # Decimals serialize as strings for precision
elif isinstance(field, forms.DateTimeField):
return str # ISO format string
elif isinstance(field, forms.DateField):
return str # ISO format string
elif isinstance(field, forms.TimeField):
return str # ISO format string
elif isinstance(field, forms.JSONField):
return dict | list | str | int | float | bool | None
elif isinstance(field, forms.MultipleChoiceField):
return list[str]
elif isinstance(field, forms.FileField):
return str # File path/name as string
elif isinstance(field, forms.ImageField):
return str # File path/name as string
else:
# Default to string (covers CharField, EmailField, URLField, etc.)
return str
def _create_form_input_schema(
form_class: type[forms.BaseForm],
schema_name: str,
) -> type[BaseModel]:
"""
Create a Pydantic model from Django Form fields.
This generates a typed schema for the form's input data, giving TypeScript
full LSP support (autocomplete, type checking) for form fields.
Args:
form_class: Django Form class to introspect
schema_name: Name for the generated Pydantic model (e.g., "ContactFormData")
Returns:
A Pydantic BaseModel subclass with fields matching the form
"""
# Instantiate form without data to get field definitions
try:
form = form_class()
except TypeError:
# Form requires extra args (like request) - use form_class.base_fields instead
fields_dict = getattr(form_class, "base_fields", {})
else:
fields_dict = form.fields
# Build Pydantic field definitions
pydantic_fields: dict[str, Any] = {}
for field_name, field in fields_dict.items():
python_type = _django_field_to_python_type(field)
# Optional fields (not required or has initial value)
if not field.required:
python_type = python_type | None
default = None
elif field.initial is not None:
default = field.initial
else:
default = ... # Required field
pydantic_fields[field_name] = (python_type, default)
# Create the model with a unique name
model = create_model(schema_name, **pydantic_fields)
return model
class mizanFormMeta(BaseModel):
"""
Configuration for a mizan form.
This Pydantic model provides type-safe configuration with full LSP support,
and serializes to JSON for the frontend schema.
Required:
name: API identifier (e.g., "contact" → contact.schema, contact.validate, contact.submit)
Display options:
title: Display title (default: derived from class name)
subtitle: Display subtitle
submit_label: Submit button text (default: "Submit")
Frontend behavior:
live_validation: Enable live validation as user types (default: True)
live_form_errors: Show form-level errors during live validation (default: False)
refetch_schema_on_validate: Refetch schema on each validation - useful for
dynamic choice fields (default: False)
Features:
enable_formset: Generate formset endpoints (default: False)
"""
# Required
name: str
# Display
title: str | None = None
subtitle: str | None = None
submit_label: str = "Submit"
# Frontend behavior
live_validation: bool = True
live_form_errors: bool = False
refetch_schema_on_validate: bool = False
# Features
enable_formset: bool = False
class mizanFormMixin:
"""
Mixin that exposes a Django Form as mizan server functions.
Add this mixin to any Django Form class along with a `mizan` configuration:
class ContactForm(mizanFormMixin, forms.Form):
mizan = mizanFormMeta(
name="contact",
title="Contact Us",
)
name = forms.CharField()
email = forms.EmailField()
def on_submit_success(self, request):
return {"sent": True}
This auto-registers:
- contact.schema - Get form field definitions
- contact.validate - Validate form data
- contact.submit - Submit form
Overridable methods:
get_init_kwargs(cls, request) -> dict: Extra kwargs for form instantiation
on_submit_success(self, request) -> dict | None: Handle successful submission
on_submit_failure(self, request, errors) -> None: Handle failed submission
"""
# Configuration - subclasses must define this
mizan: ClassVar[mizanFormMeta]
# Track registered forms to avoid duplicate registration
_mizan_registered: ClassVar[bool] = False
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
"""
Override to provide extra kwargs for form instantiation.
Common use: pass request or user to forms that need them.
Example:
@classmethod
def get_init_kwargs(cls, request):
return {"request": request, "user": request.user}
"""
return {}
def on_submit_success(self, request: HttpRequest) -> dict | None:
"""
Called after successful form validation and submission.
Override to handle the form submission logic.
Return a dict to include data in the response.
Example:
def on_submit_success(self, request):
self.save()
return {"id": self.instance.pk}
"""
# Default: call save() if available
if hasattr(self, "save"):
result = self.save()
# If save returns something serializable, include it
if isinstance(result, dict):
return result
return None
def on_submit_failure(self, request: HttpRequest, errors: "FormValidation") -> None:
"""
Called after form validation fails.
Override to add custom error handling, logging, etc.
"""
pass
def __init_subclass__(cls, **kwargs):
"""Auto-register when a concrete form class is defined."""
super().__init_subclass__(**kwargs)
# Only register concrete forms with mizan config defined
if _is_concrete_mizan_form(cls):
_register_form_as_server_functions(cls)
def _is_concrete_mizan_form(cls: type) -> bool:
"""
Check if a class is a concrete mizan form ready for registration.
A form is concrete if:
1. It has a `mizan` attribute that is a mizanFormMeta instance
2. It inherits from Django's BaseForm
3. It hasn't been registered yet (for this class definition)
"""
# Must have mizan config (check cls.__dict__ to avoid inheriting)
mizan_config = cls.__dict__.get("mizan")
if not isinstance(mizan_config, mizanFormMeta):
return False
# Must be a Django form
if not issubclass(cls, forms.BaseForm):
return False
# Check if already registered (handle re-imports gracefully)
if cls.__dict__.get("_mizan_registered", False):
return False
return True
def _register_form_as_server_functions(form_class: type) -> None:
"""
Register a Django Form class as mizan server functions.
Creates and registers:
- {name}.schema - Returns form field definitions
- {name}.validate - Validates form data
- {name}.submit - Validates and submits form
Each function gets a unique typed schema for better TypeScript LSP support.
"""
from .schemas import FormSchema, FormSubmitFail, FormSubmitPass, FormValidation
from .schema_utils import build_form_schema
from .validation_utils import validate_form_instance
from mizan.setup.registry import register
from mizan.client.function import ServerFunction
config: mizanFormMeta = form_class.mizan
form_name = config.name
# Mark as registered
form_class._mizan_registered = True
# Generate PascalCase name for schemas (e.g., "contact" -> "Contact")
pascal_name = "".join(
word.capitalize()
for word in form_name.replace(".", "_").replace("-", "_").split("_")
)
# NOTE: We cannot create FormDataSchema here because form fields aren't
# populated yet during __init_subclass__. We use lazy creation instead.
_form_data_schema_cache: dict[str, type[BaseModel]] = {}
def get_form_data_schema() -> type[BaseModel]:
"""Lazily create the form data schema (form fields aren't available at registration time)."""
if "schema" not in _form_data_schema_cache:
_form_data_schema_cache["schema"] = _create_form_input_schema(
form_class, f"{pascal_name}FormData"
)
return _form_data_schema_cache["schema"]
# -------------------------------------------------------------------------
# Schema Function
# -------------------------------------------------------------------------
# Schema input wraps the form data for pre-populating dynamic fields
FormSchemaInput = create_model(
f"{pascal_name}SchemaInput",
data=(dict[str, Any], {}),
)
class SchemaFunction(ServerFunction):
Input = FormSchemaInput
Output = FormSchema
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "schema",
"form_class": form_class, # Store reference for schema generation
}
def call(self, input) -> FormSchema:
init_kwargs = form_class.get_init_kwargs(self.request)
schema = build_form_schema(
form_class,
data=input.data if input else {},
**init_kwargs,
)
# Override with mizanFormMeta values
if config.title is not None:
schema.title = config.title
if config.subtitle is not None:
schema.subtitle = config.subtitle
schema.submit_label = config.submit_label
# Behavior settings are nested in schema.meta
schema.meta.live_validation = config.live_validation
schema.meta.live_form_errors = config.live_form_errors
schema.meta.refetch_schema_on_validate = config.refetch_schema_on_validate
return schema
SchemaFunction.__name__ = f"{form_name}_schema"
SchemaFunction.__qualname__ = f"{form_name}_schema"
register(SchemaFunction, f"{form_name}.schema")
# -------------------------------------------------------------------------
# Validate Function
# -------------------------------------------------------------------------
# Use generic dict input - form fields aren't available during __init_subclass__
FormValidateInput = create_model(
f"{pascal_name}ValidateInput",
data=(dict[str, Any], ...),
)
class ValidateFunction(ServerFunction):
Input = FormValidateInput
Output = FormValidation
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "validate",
}
def call(self, input) -> FormValidation:
init_kwargs = form_class.get_init_kwargs(self.request)
# Input data is already a dict
data = input.data
_, validation = validate_form_instance(
form_class,
data=data,
files=None,
**init_kwargs,
)
return validation
ValidateFunction.__name__ = f"{form_name}_validate"
ValidateFunction.__qualname__ = f"{form_name}_validate"
register(ValidateFunction, f"{form_name}.validate")
# -------------------------------------------------------------------------
# Submit Function
# -------------------------------------------------------------------------
class SubmitFunction(ServerFunction):
"""
Submit function handles both JSON and multipart/form-data.
The executor detects form functions and parses the request appropriately.
"""
# Use dict for input - form fields unknown at registration time
Input = None # Signals executor to pass raw dict
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "submit",
"multipart": True, # Signal that this function accepts multipart
}
def call(self, input) -> FormSubmitPass | FormSubmitFail:
"""Execute form submission."""
request = self.request
# Check if we have multipart data from executor
if hasattr(request, "_mizan_form_data"):
data = request._mizan_form_data
files = request._mizan_form_files
elif input is not None:
# JSON input - already a dict
data = input if isinstance(input, dict) else input.model_dump()
files = None
else:
data = {}
files = None
init_kwargs = form_class.get_init_kwargs(request)
# Create and validate form
form, validation = validate_form_instance(
form_class,
data=data,
files=files,
**init_kwargs,
)
if form.is_valid():
# Call the form's on_submit_success
result_data = form.on_submit_success(request)
return FormSubmitPass(success=True, data=result_data)
# Call the form's on_submit_failure
form.on_submit_failure(request, validation)
return FormSubmitFail(success=False, errors=validation)
SubmitFunction.__name__ = f"{form_name}_submit"
SubmitFunction.__qualname__ = f"{form_name}_submit"
SubmitFunction.Output = FormSubmitPass # For schema generation
register(SubmitFunction, f"{form_name}.submit")
# -------------------------------------------------------------------------
# Formset Functions (if enabled)
# -------------------------------------------------------------------------
if config.enable_formset:
_register_formset_functions(form_class, form_name)
def _register_formset_functions(
form_class: type,
form_name: str,
) -> None:
"""Register formset server functions for a form."""
from django.forms import formset_factory
from .schemas import (
FormsetSchema,
FormsetSubmitFail,
FormsetSubmitPass,
FormsetValidation,
)
from .schema_utils import build_form_schema
from .validation_utils import build_formset_validation
from .formset_utils import forms_to_formset_post_data
from mizan.setup.registry import register
from mizan.client.function import ServerFunction
formset_class = formset_factory(form_class)
# Generate PascalCase name for schemas
pascal_name = "".join(
word.capitalize()
for word in form_name.replace(".", "_").replace("-", "_").split("_")
)
# NOTE: We cannot create typed schemas here because form fields aren't
# populated yet during __init_subclass__. We use generic dict inputs.
# -------------------------------------------------------------------------
# Formset Schema Function
# -------------------------------------------------------------------------
FormsetSchemaInput = create_model(
f"{pascal_name}FormsetSchemaInput",
forms=(list[dict[str, Any]], []),
)
class FormsetSchemaFunction(ServerFunction):
Input = FormsetSchemaInput
Output = FormsetSchema
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "formset_schema",
}
def call(self, input) -> FormsetSchema:
init_kwargs = form_class.get_init_kwargs(self.request)
forms_data = input.forms if input else []
formset_data = forms_to_formset_post_data(forms_data)
formset = formset_class(formset_data)
return FormsetSchema(
forms=[
build_form_schema(form_class, data=fd, **init_kwargs)
for fd in forms_data
],
min_num=formset.min_num,
max_num=formset.max_num,
can_delete=formset.can_delete,
can_order=formset.can_order,
)
FormsetSchemaFunction.__name__ = f"{form_name}_formset_schema"
register(FormsetSchemaFunction, f"{form_name}.formset.schema")
# -------------------------------------------------------------------------
# Formset Validate Function
# -------------------------------------------------------------------------
# Generic dict input - form fields aren't available during __init_subclass__
FormsetValidateInput = create_model(
f"{pascal_name}FormsetValidateInput",
forms=(list[dict[str, Any]], ...),
)
class FormsetValidateFunction(ServerFunction):
Input = FormsetValidateInput
Output = FormsetValidation
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "formset_validate",
}
def call(self, input) -> FormsetValidation:
init_kwargs = form_class.get_init_kwargs(self.request)
# Input.forms is already a list of dicts
forms_data = input.forms
formset_data = forms_to_formset_post_data(forms_data)
formset = formset_class(formset_data, form_kwargs=init_kwargs)
for form in formset:
form.empty_permitted = False
return build_formset_validation(formset)
FormsetValidateFunction.__name__ = f"{form_name}_formset_validate"
register(FormsetValidateFunction, f"{form_name}.formset.validate")
# -------------------------------------------------------------------------
# Formset Submit Function
# -------------------------------------------------------------------------
# Generic dict input - form fields aren't available during __init_subclass__
FormsetSubmitInput = create_model(
f"{pascal_name}FormsetSubmitInput",
forms=(list[dict[str, Any]], ...),
)
class FormsetSubmitFunction(ServerFunction):
Input = FormsetSubmitInput
_meta: ClassVar[dict] = {
"form": True,
"form_name": form_name,
"form_role": "formset_submit",
"multipart": True,
}
def call(self, input) -> FormsetSubmitPass | FormsetSubmitFail:
request = self.request
init_kwargs = form_class.get_init_kwargs(request)
# Handle multipart vs JSON
if hasattr(request, "_mizan_form_data"):
post_data = request._mizan_form_data
files = request._mizan_form_files
elif input and hasattr(input, "forms"):
# Input.forms is already a list of dicts
forms_data = input.forms
post_data = forms_to_formset_post_data(forms_data)
files = None
else:
post_data = {}
files = None
formset = formset_class(post_data, files=files, form_kwargs=init_kwargs)
if formset.is_valid():
for form in formset.forms:
if form.cleaned_data:
form.on_submit_success(request)
return FormsetSubmitPass(success=True)
validation = build_formset_validation(formset)
# Call failure handler on each form
for form in formset.forms:
if hasattr(form, "on_submit_failure"):
form.on_submit_failure(request, validation)
return FormsetSubmitFail(success=False, errors=validation)
FormsetSubmitFunction.__name__ = f"{form_name}_formset_submit"
FormsetSubmitFunction.Output = FormsetSubmitPass
register(FormsetSubmitFunction, f"{form_name}.formset.submit")

View File

@@ -0,0 +1,16 @@
from typing import Any
def forms_to_formset_post_data(forms_data: list[dict[str, Any]]) -> dict[str, Any]:
"""
Convert a list of form dicts into Django formset-compatible POST data.
"""
formset_data: dict[str, Any] = {
"form-TOTAL_FORMS": str(len(forms_data)),
"form-INITIAL_FORMS": "0",
}
for i, form_data in enumerate(forms_data):
formset_data.update(
{f"form-{i}-{key}": value for key, value in form_data.items()}
)
return formset_data

View File

@@ -0,0 +1,187 @@
import re
from typing import Any, Optional
from django import forms
from django.forms import Field
from .schemas import FieldChoice, FieldSchema, FormMeta, FormSchema
def create_form_instance(
form_class: type[forms.BaseForm],
data: Optional[dict] = None,
files: Optional[dict] = None,
**kwargs,
) -> forms.BaseForm:
"""
Create a form instance, gracefully handling kwargs that the form doesn't accept.
Some Django forms (like allauth's) accept `request` in __init__, others don't.
This function tries with all kwargs first, then progressively removes kwargs
that cause TypeErrors until instantiation succeeds.
"""
# Common kwargs that forms may or may not accept
optional_kwargs = ['request', 'user', 'instance']
# Build init kwargs
init_kwargs = dict(kwargs)
if data is not None:
init_kwargs['data'] = data
if files is not None:
init_kwargs['files'] = files
while True:
try:
return form_class(**init_kwargs)
except TypeError as e:
error_msg = str(e)
# Check if it's an unexpected keyword argument error
if "unexpected keyword argument" not in error_msg:
raise
# Find which kwarg caused the problem and remove it
removed = False
for kwarg in optional_kwargs:
if f"'{kwarg}'" in error_msg and kwarg in init_kwargs:
init_kwargs.pop(kwarg)
removed = True
break
# If we couldn't identify/remove the problematic kwarg, re-raise
if not removed:
raise
def _get_choices(field: Field) -> Optional[list[FieldChoice]]:
"""
Extract choices from a field, handling ModelChoiceField properly.
ModelChoiceField returns ModelChoiceIteratorValue which is not JSON serializable.
"""
if not hasattr(field, "choices"):
return None
choices: list[FieldChoice] = []
for raw_value, label in field.choices:
value = getattr(
raw_value, "value", raw_value
) # ModelChoiceIteratorValue -> .value
choices.append(FieldChoice(value=str(value), label=str(label)))
return choices
def _get_initial(value: Any) -> Any:
"""Convert initial value to JSON-serializable format."""
if value is None:
return None
if hasattr(value, "isoformat"):
return value.isoformat()
if hasattr(value, "pk"):
return value.pk
if hasattr(value, "__iter__") and not isinstance(value, (str, bytes)):
return [item.pk if hasattr(item, "pk") else item for item in value]
return value
def _class_name_to_title(name: str) -> str:
"""
Convert a class name to a human-readable title.
e.g., 'LoginForm' -> 'Login', 'ResetPasswordForm' -> 'Reset Password'
"""
# Remove 'Form' suffix
name = re.sub(r"Form$", "", name)
# Insert spaces before capital letters
name = re.sub(r"([a-z])([A-Z])", r"\1 \2", name)
return name
def _class_name_to_slug(name: str) -> str:
"""
Convert a class name to a slug.
e.g., 'LoginForm' -> 'login', 'ResetPasswordForm' -> 'reset_password'
"""
# Remove 'Form' suffix
name = re.sub(r"Form$", "", name)
# Insert underscores before capital letters and lowercase
name = re.sub(r"([a-z])([A-Z])", r"\1_\2", name)
return name.lower()
def build_form_schema(
form_class: type[forms.BaseForm],
data: Optional[dict] = None,
**kwargs,
) -> FormSchema:
"""
Produce a FormSchema for the given Django form class and (optional) data.
The form class can define metadata via an inner Meta class:
class MyForm(forms.Form):
class Meta:
form_name = "my_form"
title = "My Form Title"
subtitle = "Optional description"
submit_label = "Submit"
# Frontend behavior (optional)
refetch_schema_on_validate = False # Set True for dynamic choice fields
live_validation = True # Set False to disable live validation
live_form_errors = False # Set True to show form errors live
If not provided, sensible defaults are derived from the class name.
"""
form = create_form_instance(form_class, data=data, **kwargs)
# Extract metadata from form's Meta class
form_meta = getattr(form_class, "Meta", None)
# Get form name (used as identifier)
name = getattr(form_meta, "form_name", None)
if name is None:
name = _class_name_to_slug(form_class.__name__)
# Get title (human-readable heading)
title = getattr(form_meta, "title", None)
if title is None:
title = _class_name_to_title(form_class.__name__)
# Get optional subtitle
subtitle = getattr(form_meta, "subtitle", None)
# Get submit button label
submit_label = getattr(form_meta, "submit_label", None)
if submit_label is None:
submit_label = "Submit"
# Build frontend behavior metadata
frontend_meta = FormMeta(
refetch_schema_on_validate=getattr(form_meta, "refetch_schema_on_validate", False),
live_validation=getattr(form_meta, "live_validation", True),
live_form_errors=getattr(form_meta, "live_form_errors", False),
)
return FormSchema(
name=name,
title=title,
subtitle=subtitle,
submit_label=submit_label,
fields=[
FieldSchema(
name=name,
label=str(field.label or name.replace("_", " ").title()),
type=getattr(field.widget, "input_type", "text"),
widget=field.widget.__class__.__name__,
required=field.required,
disabled=field.disabled,
help_text=str(field.help_text) if field.help_text else "",
initial=_get_initial(field.initial),
max_length=getattr(field, "max_length", None),
min_length=getattr(field, "min_length", None),
choices=_get_choices(field),
)
for name, field in form.fields.items()
],
meta=frontend_meta,
)

View File

@@ -0,0 +1,103 @@
from typing import Any, Optional
from ninja import Schema
# Form metadata schema
class FormMeta(Schema):
"""
Metadata controlling frontend form behavior.
Attributes:
refetch_schema_on_validate: If True, frontend should refetch schema on each
validation (useful for dynamic choice fields). Default False.
live_validation: If False, frontend should disable live validation entirely.
Useful for sensitive forms like login. Default True.
live_form_errors: If True, show form-level errors during live validation.
Form errors are things like "Invalid credentials" vs field errors like
"This field is required". Default False for security.
"""
refetch_schema_on_validate: bool = False
live_validation: bool = True
live_form_errors: bool = False
# Field-level schemas
class FieldChoice(Schema):
value: str
label: str
class FieldError(Schema):
message: str
code: Optional[str]
class FieldErrorList(Schema):
field: str
errors: list[FieldError]
class FieldSchema(Schema):
name: str
label: str
type: str
widget: str
required: bool
disabled: bool
help_text: str
initial: Any
max_length: Optional[int]
min_length: Optional[int]
choices: Optional[list[FieldChoice]]
# Form-level schemas
class FormSchema(Schema):
"""Schema returned by /schema endpoint with form metadata and fields."""
# Form metadata
name: str
title: str
subtitle: Optional[str]
submit_label: str
# Fields
fields: list[FieldSchema]
# Frontend behavior metadata
meta: FormMeta = FormMeta()
class FormValidation(Schema):
errors: list[FieldErrorList]
class FormSubmitPass(Schema):
success: bool
data: Optional[dict] = None
class FormSubmitFail(Schema):
success: bool
errors: FormValidation
# Formset-level schemas
class FormsetSchema(Schema):
forms: list[FormSchema]
min_num: int
max_num: int
can_delete: bool
can_order: bool
class FormsetValidation(Schema):
general: list[str]
per_form: list[FormValidation]
class FormsetSubmitPass(Schema):
success: bool
class FormsetSubmitFail(Schema):
success: bool
errors: FormsetValidation

View File

@@ -0,0 +1,72 @@
from typing import Any
from django import forms
from django.core.files.uploadedfile import UploadedFile
from django.utils.datastructures import MultiValueDict
from .schemas import (
FieldError,
FieldErrorList,
FormValidation,
FormsetValidation,
)
from .schema_utils import create_form_instance
def validate_form_instance(
form_class: type[forms.BaseForm],
data: dict,
files: MultiValueDict[str, UploadedFile] | None = None,
**kwargs: Any,
) -> tuple[forms.BaseForm, FormValidation]:
"""
Build a form instance and return (form, structured_validation_errors).
"""
form = create_form_instance(form_class, data=data, files=files, initial=data, **kwargs)
# Run validation
form.is_valid()
validation = FormValidation(
errors=[
FieldErrorList(
field=field_name,
errors=[
FieldError(
message=str(e.message) if hasattr(e, 'message') else str(e),
code=getattr(e, "code", None),
)
for e in field_errors.as_data()
],
)
for field_name, field_errors in form.errors.items()
]
)
return form, validation
def build_formset_validation(formset: forms.BaseFormSet) -> FormsetValidation:
"""
Turn a Django formset into a FormsetValidation structure.
"""
return FormsetValidation(
general=[str(e) if e else "" for e in formset.non_form_errors()],
per_form=[
FormValidation(
errors=[
FieldErrorList(
field=field_name,
errors=[
FieldError(
message=str(e.message) if hasattr(e, 'message') else str(e),
code=getattr(e, "code", None),
)
for e in field_errors.as_data()
],
)
for field_name, field_errors in form.errors.items()
]
)
for form in formset
],
)

View File

@@ -0,0 +1,25 @@
"""
mizan Allauth Integration
Backend support for django-allauth with mizan server functions.
Provides:
- Auth contexts (auth_status, user) - required by frontend allauth module
- Allauth form wrappers - expose allauth forms as server functions
Usage:
# In your app's apps.py
class MyAppConfig(AppConfig):
def ready(self):
import mizan.allauth.forms # noqa - registers forms
import mizan.allauth.contexts # noqa - registers contexts
"""
from .contexts import auth_status, user, AuthStatusOutput, UserOutput
__all__ = [
"auth_status",
"user",
"AuthStatusOutput",
"UserOutput",
]

View File

@@ -0,0 +1,118 @@
"""
Auth contexts for mizan Allauth integration.
These are the core auth primitives that the frontend allauth module depends on.
Separated into two concerns:
- auth_status: Authentication state and permission guards (fast, no DB hit with JWT)
- user: Full user profile data (may require DB query for JWT auth)
Both are registered as global contexts for SSR hydration.
"""
from django.http import HttpRequest
from pydantic import BaseModel
from mizan.client import client
# =============================================================================
# Auth Status Context
# =============================================================================
class AuthStatusOutput(BaseModel):
"""Authentication status and permission guards."""
is_authenticated: bool
user_id: int | None = None
is_staff: bool = False
is_superuser: bool = False
@client(context="global")
def auth_status(request: HttpRequest) -> AuthStatusOutput:
"""
Auth status context - provides authentication state and guards.
This works identically for both session and JWT auth. The data comes
from the request.user object (either full User or JWTUser with claims).
Frontend:
const auth = useAuthStatus()
if (auth.is_authenticated) { ... }
if (auth.is_staff) { ... }
"""
user = request.user
if not user.is_authenticated:
return AuthStatusOutput(is_authenticated=False)
return AuthStatusOutput(
is_authenticated=True,
user_id=user.id,
is_staff=user.is_staff,
is_superuser=user.is_superuser,
)
# =============================================================================
# User Profile Context
# =============================================================================
class UserOutput(BaseModel):
"""Full user profile data."""
id: int
email: str
first_name: str = ""
last_name: str = ""
@client(context="global")
def user(request: HttpRequest) -> UserOutput | None:
"""
User profile context - provides full user data.
Unlike auth_status, this may require a DB query (for JWT auth where
the user object is a minimal JWTUser with only claims).
Returns None if not authenticated.
Frontend:
const user = useUser()
if (user) {
console.log(user.email)
}
"""
req_user = request.user
if not req_user.is_authenticated:
return None
# Check if we have full user data or just JWT claims
if hasattr(req_user, "email") and req_user.email:
# Full User object (session auth)
return UserOutput(
id=req_user.id,
email=req_user.email,
first_name=getattr(req_user, "first_name", "") or "",
last_name=getattr(req_user, "last_name", "") or "",
)
# JWTUser - need to fetch from DB
from django.contrib.auth import get_user_model
User = get_user_model()
try:
db_user = User.objects.get(pk=req_user.id)
return UserOutput(
id=db_user.id,
email=db_user.email,
first_name=db_user.first_name or "",
last_name=db_user.last_name or "",
)
except User.DoesNotExist:
return None

View File

@@ -0,0 +1,408 @@
"""
Allauth forms as mizan server functions.
This module wraps allauth forms with mizanFormMixin, exposing them as
typed server functions for the React frontend.
Each form becomes three server functions:
- {name}.schema - Get form field definitions
- {name}.validate - Validate form data
- {name}.submit - Submit form
Import this module in your app's ready() to register the forms:
class MyAppConfig(AppConfig):
def ready(self):
import mizan.allauth.forms # noqa
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any
from django.http import HttpRequest
from mizan.forms import mizanFormMixin, mizanFormMeta
# Account forms
from allauth.account.forms import (
AddEmailForm,
ChangePasswordForm,
ConfirmLoginCodeForm,
LoginForm,
RequestLoginCodeForm,
ResetPasswordForm,
ResetPasswordKeyForm,
SetPasswordForm,
SignupForm,
UserTokenForm,
)
# Password reauthentication form - conditionally import
try:
from allauth.account.forms import ReauthenticateForm
HAS_REAUTH = True
except ImportError:
HAS_REAUTH = False
# MFA forms - conditionally import
try:
from allauth.mfa.base.forms import AuthenticateForm as MFAAuthenticateForm
from allauth.mfa.base.forms import ReauthenticateForm as MFAReauthenticateForm
from allauth.mfa.totp.forms import ActivateTOTPForm, DeactivateTOTPForm
from allauth.mfa.recovery_codes.forms import GenerateRecoveryCodesForm
HAS_MFA = True
except ImportError:
HAS_MFA = False
# WebAuthn forms (if available)
try:
from allauth.mfa.webauthn.forms import AuthenticateWebAuthnForm
HAS_WEBAUTHN = True
except ImportError:
HAS_WEBAUTHN = False
if TYPE_CHECKING:
from mizan.forms.schemas import FormValidation
# =============================================================================
# Account Forms
# =============================================================================
class mizanLoginForm(LoginForm, mizanFormMixin):
"""Sign in with email and password."""
mizan = mizanFormMeta(
name="login",
title="Sign In",
subtitle="Welcome back. Enter your credentials to continue.",
submit_label="Sign In",
live_validation=False, # Don't validate credentials as user types
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.login(request)
return None
class mizanSignupForm(SignupForm, mizanFormMixin):
"""Create a new account."""
mizan = mizanFormMeta(
name="signup",
title="Create Account",
subtitle="Enter your details to get started.",
submit_label="Create Account",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save(request)
return None
class mizanAddEmailForm(AddEmailForm, mizanFormMixin):
"""Add another email address to your account."""
mizan = mizanFormMeta(
name="add_email",
title="Add Email Address",
subtitle="Add another email address to your account.",
submit_label="Add Email",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanChangePasswordForm(ChangePasswordForm, mizanFormMixin):
"""Change your account password."""
mizan = mizanFormMeta(
name="change_password",
title="Change Password",
subtitle="Update your password to keep your account secure.",
submit_label="Change Password",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanSetPasswordForm(SetPasswordForm, mizanFormMixin):
"""Set a password for accounts created via social login."""
mizan = mizanFormMeta(
name="set_password",
title="Set Password",
subtitle="Create a password for your account.",
submit_label="Set Password",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanResetPasswordForm(ResetPasswordForm, mizanFormMixin):
"""Request a password reset email."""
mizan = mizanFormMeta(
name="reset_password",
title="Reset Password",
subtitle="Enter your email address and we'll send you a link to reset your password.",
submit_label="Send Reset Link",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save(request)
return None
class mizanResetPasswordKeyForm(ResetPasswordKeyForm, mizanFormMixin):
"""Set a new password using a reset key."""
mizan = mizanFormMeta(
name="reset_password_from_key",
title="Set New Password",
subtitle="Enter your new password below.",
submit_label="Reset Password",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanRequestLoginCodeForm(RequestLoginCodeForm, mizanFormMixin):
"""Request a login code via email."""
mizan = mizanFormMeta(
name="request_login_code",
title="Sign In with Code",
subtitle="Enter your email address and we'll send you a login code.",
submit_label="Send Code",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanConfirmLoginCodeForm(ConfirmLoginCodeForm, mizanFormMixin):
"""Confirm a login code."""
mizan = mizanFormMeta(
name="confirm_login_code",
title="Enter Code",
subtitle="Enter the code we sent to your email.",
submit_label="Verify Code",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanUserTokenForm(UserTokenForm, mizanFormMixin):
"""Verify an email with a token."""
mizan = mizanFormMeta(
name="user_token",
title="Verify Email",
subtitle="Enter the verification code from your email.",
submit_label="Verify",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
# Password reauthentication - conditionally define
if HAS_REAUTH:
class mizanReauthenticateForm(ReauthenticateForm, mizanFormMixin):
"""Re-authenticate with password for sensitive actions."""
mizan = mizanFormMeta(
name="reauthenticate",
title="Confirm Your Identity",
subtitle="Please enter your password to continue.",
submit_label="Confirm",
live_validation=False,
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
from allauth.account.internal.flows import reauthentication
reauthentication.reauthenticate_by_password(request)
return None
# =============================================================================
# MFA Forms
# =============================================================================
if HAS_MFA:
class mizanMFAAuthenticateForm(MFAAuthenticateForm, mizanFormMixin):
"""Authenticate with MFA during login."""
mizan = mizanFormMeta(
name="mfa_authenticate",
title="Two-Factor Authentication",
subtitle="Enter your authentication code to continue.",
submit_label="Verify",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanMFAReauthenticateForm(MFAReauthenticateForm, mizanFormMixin):
"""Re-authenticate with MFA for sensitive actions."""
mizan = mizanFormMeta(
name="mfa_reauthenticate",
title="Confirm Your Identity",
subtitle="Enter your authentication code to continue.",
submit_label="Confirm",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanActivateTOTPForm(ActivateTOTPForm, mizanFormMixin):
"""Activate TOTP authenticator."""
mizan = mizanFormMeta(
name="activate_totp",
title="Set Up Authenticator",
subtitle="Enter the code from your authenticator app to complete setup.",
submit_label="Activate",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanDeactivateTOTPForm(DeactivateTOTPForm, mizanFormMixin):
"""Deactivate TOTP authenticator."""
mizan = mizanFormMeta(
name="deactivate_totp",
title="Disable Authenticator",
subtitle="Enter your password to disable two-factor authentication.",
submit_label="Disable",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
class mizanGenerateRecoveryCodesForm(GenerateRecoveryCodesForm, mizanFormMixin):
"""Generate new recovery codes."""
mizan = mizanFormMeta(
name="generate_recovery_codes",
title="Recovery Codes",
subtitle="Generate new recovery codes for your account.",
submit_label="Generate Codes",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None
if HAS_WEBAUTHN:
class mizanAuthenticateWebAuthnForm(AuthenticateWebAuthnForm, mizanFormMixin):
"""Authenticate with WebAuthn security key."""
mizan = mizanFormMeta(
name="webauthn_authenticate",
title="Security Key",
subtitle="Use your security key to authenticate.",
submit_label="Use Security Key",
)
@classmethod
def get_init_kwargs(cls, request: HttpRequest) -> dict[str, Any]:
return {"request": request, "user": request.user}
def on_submit_success(self, request: HttpRequest) -> dict | None:
self.save()
return None

View File

@@ -0,0 +1,71 @@
"""
mizan.jwt - JWT authentication for server functions.
Provides:
- Server functions for obtaining/refreshing JWT tokens
- JWT authentication utilities for validating tokens
Server Functions:
- jwt_obtain: Convert authenticated session to JWT tokens
- jwt_refresh: Refresh tokens using a refresh token
Usage in apps.py or urls.py (to register the functions):
import mizan.jwt.functions # noqa: F401
Note: This module is purpose-built for mizan server functions.
For Django Ninja API authentication, use mizan.jwt.security directly.
"""
# Server functions (import to register with @client decorator)
from .functions import jwt_obtain, jwt_refresh
# Token utilities
from .tokens import (
create_token_pair,
create_access_token,
create_refresh_token,
decode_token,
refresh_tokens,
TokenPair,
TokenPayload,
JWTUser,
)
# Settings
from .settings import get_settings, JWTSettings
# Security (Ninja API auth) - lazy import to avoid triggering
# django-ninja's settings access at module load time.
# Use: from mizan.jwt.security import jwt_auth
def __getattr__(name):
if name in ("JWTAuth", "jwt_auth"):
from .security import JWTAuth, jwt_auth
globals()["JWTAuth"] = JWTAuth
globals()["jwt_auth"] = jwt_auth
return globals()[name]
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
__all__ = [
# Server functions
"jwt_obtain",
"jwt_refresh",
# Token utilities
"create_token_pair",
"create_access_token",
"create_refresh_token",
"decode_token",
"refresh_tokens",
"TokenPair",
"TokenPayload",
"JWTUser",
# Settings
"get_settings",
"JWTSettings",
# Security (lazy)
"JWTAuth",
"jwt_auth",
]

View File

@@ -0,0 +1,101 @@
"""
JWT Server Functions
JWT token operations exposed as mizan server functions.
Works over WebSocket RPC (primary) or HTTP fallback.
"""
from django.http import HttpRequest
from pydantic import BaseModel
from mizan.client import client
from mizan.jwt.tokens import create_token_pair, refresh_tokens
class TokenPairOutput(BaseModel):
"""JWT token pair response."""
access_token: str
refresh_token: str
expires_in: int
class JWTError(BaseModel):
"""JWT operation error."""
error: str
@client
def jwt_obtain(request: HttpRequest) -> TokenPairOutput:
"""
Obtain JWT tokens from an authenticated session.
Requires session authentication (cookie or WebSocket session).
Returns access and refresh tokens that can be used for stateless auth.
The tokens include user claims (is_staff, is_superuser) so that
subsequent JWT-authenticated requests don't need a database query.
Usage:
const { access_token, refresh_token } = await call('jwt_obtain')
// Use access_token in Authorization: Bearer header
"""
user = request.user
if not user.is_authenticated:
raise PermissionError("Authentication required")
# Get session key - for WebSocket, this comes from the scope
session = getattr(request, "session", None)
if session is None:
# WebSocket request adapter - session is a dict, not SessionBase
session_key = (
getattr(request, "_scope", {}).get("session", {}).get("_session_key")
)
if not session_key:
raise PermissionError("No session available")
else:
# HTTP request - ensure session is saved
if not session.session_key:
session.save()
session_key = session.session_key
# Include user claims in the token for stateless auth
tokens = create_token_pair(
user.pk,
session_key,
is_staff=getattr(user, "is_staff", False),
is_superuser=getattr(user, "is_superuser", False),
)
return TokenPairOutput(
access_token=tokens.access_token,
refresh_token=tokens.refresh_token,
expires_in=tokens.expires_in,
)
@client
def jwt_refresh(request: HttpRequest, refresh_token: str) -> TokenPairOutput:
"""
Refresh JWT tokens using a refresh token.
Does not require session authentication - the refresh token itself
contains the session reference and is validated against the session store.
If the original session has been destroyed (user logged out), this fails.
Usage:
const { access_token, refresh_token } = await call('jwt_refresh', { refresh_token })
"""
tokens = refresh_tokens(refresh_token)
if tokens is None:
raise PermissionError("Invalid or expired refresh token")
return TokenPairOutput(
access_token=tokens.access_token,
refresh_token=tokens.refresh_token,
expires_in=tokens.expires_in,
)

View File

@@ -0,0 +1,64 @@
"""
Django Ninja Security Classes for JWT Authentication
Provides authentication classes that can be used with Django Ninja's
auth parameter to protect API endpoints.
"""
from django.http import HttpRequest
from ninja.security import HttpBearer
from .tokens import decode_token, JWTUser
class JWTAuth(HttpBearer):
"""
JWT Bearer token authentication for Django Ninja.
Usage:
from ninja_jwt_session import jwt_auth
@api.get("/protected/", auth=jwt_auth)
def protected_endpoint(request):
return {"user_id": request.user.id}
Or globally:
api = NinjaExtraAPI(auth=[django_auth, jwt_auth])
The token must be passed in the Authorization header:
Authorization: Bearer <access_token>
IMPORTANT: This is stateless - no database query is made.
request.user is a JWTUser object with id, is_staff, is_superuser.
If you need the full User object, query it explicitly:
user = User.objects.get(pk=request.user.id)
"""
def authenticate(self, request: HttpRequest, token: str):
"""
Validate the JWT and return a JWTUser if valid.
Returns None (authentication failed) if:
- Token is invalid or expired
- Token is not an access token
Note: No database query is made. The JWTUser is created from
token claims. This is truly stateless authentication.
"""
# Decode and validate the token
payload = decode_token(token, expected_type="access")
if payload is None:
return None
# Create JWTUser from token claims - NO DATABASE QUERY
jwt_user = JWTUser(payload)
# Set request.user for compatibility with code expecting it
request.user = jwt_user
return jwt_user
# Singleton instance for convenience
jwt_auth = JWTAuth()

View File

@@ -0,0 +1,118 @@
"""
JWT Hybrid Settings
Configuration is read from Django settings with sensible defaults.
Supports both symmetric (HS256) and asymmetric (RS256) algorithms.
"""
from dataclasses import dataclass
from functools import lru_cache
from django.conf import settings as django_settings
@dataclass
class JWTSettings:
"""JWT configuration."""
# Signing keys
private_key: str # Used for signing (required)
public_key: str # Used for verification (same as private for HS256)
# Algorithm
algorithm: str # HS256, RS256, etc.
# Token lifetimes (seconds)
access_token_expires_in: int
refresh_token_expires_in: int
# Security options
validate_session: bool # Check session exists on token validation
rotate_refresh_token: bool # Issue new refresh token on refresh
@lru_cache
def get_settings() -> JWTSettings:
"""
Load JWT settings from Django settings.
Settings:
JWT_PRIVATE_KEY: Signing key (required)
JWT_PUBLIC_KEY: Verification key (defaults to private key for HS256)
JWT_ALGORITHM: Algorithm to use (default: HS256)
JWT_ACCESS_TOKEN_EXPIRES_IN: Access token lifetime (default: 300)
JWT_REFRESH_TOKEN_EXPIRES_IN: Refresh token lifetime (default: 604800)
JWT_VALIDATE_SESSION: Validate session on token use (default: True)
JWT_ROTATE_REFRESH_TOKEN: Rotate refresh tokens (default: True)
"""
private_key = getattr(django_settings, "JWT_PRIVATE_KEY", None)
if not private_key:
# Fall back to allauth setting if available (for compatibility)
headless_key = getattr(django_settings, "HEADLESS_JWT_PRIVATE_KEY", None)
if headless_key:
private_key = headless_key
if private_key is None:
raise ValueError(
"JWT_PRIVATE_KEY must be set in Django settings. "
"For HS256, use a secure random string. "
"For RS256, use a PEM-encoded RSA private key."
)
# Auto-detect algorithm based on key format if not explicitly set
algorithm = getattr(django_settings, "JWT_ALGORITHM", None)
if algorithm is None:
# Auto-detect: if key looks like PEM, use RS256; otherwise HS256
if isinstance(private_key, str) and private_key.strip().startswith("-----BEGIN"):
algorithm = "RS256"
else:
algorithm = "HS256"
# For symmetric algorithms, public key = private key
if algorithm.startswith("HS"):
public_key = private_key
else:
public_key = getattr(django_settings, "JWT_PUBLIC_KEY", None)
if public_key is None:
# Try to extract public key from private key for RSA
try:
from cryptography.hazmat.primitives import serialization
from cryptography.hazmat.primitives.serialization import load_pem_private_key
private_key_obj = load_pem_private_key(
private_key.encode() if isinstance(private_key, str) else private_key,
password=None,
)
public_key = private_key_obj.public_key().public_bytes(
encoding=serialization.Encoding.PEM,
format=serialization.PublicFormat.SubjectPublicKeyInfo,
).decode()
except Exception:
raise ValueError(
f"JWT_PUBLIC_KEY must be set for {algorithm} algorithm, "
"or JWT_PRIVATE_KEY must be a valid PEM-encoded RSA key."
)
return JWTSettings(
private_key=private_key,
public_key=public_key,
algorithm=algorithm,
access_token_expires_in=getattr(
django_settings,
"JWT_ACCESS_TOKEN_EXPIRES_IN",
getattr(django_settings, "HEADLESS_JWT_ACCESS_TOKEN_EXPIRES_IN", 300),
),
refresh_token_expires_in=getattr(
django_settings,
"JWT_REFRESH_TOKEN_EXPIRES_IN",
getattr(django_settings, "HEADLESS_JWT_REFRESH_TOKEN_EXPIRES_IN", 604800),
),
validate_session=getattr(
django_settings, "JWT_VALIDATE_SESSION", True
),
rotate_refresh_token=getattr(
django_settings, "JWT_ROTATE_REFRESH_TOKEN", True
),
)

View File

@@ -0,0 +1,245 @@
"""
JWT Token Creation and Validation
Uses PyJWT directly - no allauth dependency.
Tokens are tied to Django sessions for immediate revocation on logout.
"""
import time
from typing import NamedTuple
import jwt
from django.contrib.sessions.backends.base import SessionBase
from .settings import get_settings
class TokenPair(NamedTuple):
"""Access and refresh token pair."""
access_token: str
refresh_token: str
expires_in: int
class TokenPayload(NamedTuple):
"""Decoded token payload."""
user_id: int | str
session_key: str
token_type: str
is_staff: bool
is_superuser: bool
exp: int
iat: int
class JWTUser:
"""
Minimal user object created from JWT claims.
Used as request.user for JWT-authenticated requests.
No database query required - all data comes from the token.
If you need the full User object with all fields, query explicitly:
user = User.objects.get(pk=request.user.id)
"""
def __init__(self, payload: TokenPayload):
self.id = int(payload.user_id) if isinstance(payload.user_id, str) else payload.user_id
self.pk = self.id
self.is_staff = payload.is_staff
self.is_superuser = payload.is_superuser
self.is_authenticated = True
self.is_anonymous = False
self.is_active = True # Assumed active if they have a valid token
def __str__(self):
return f"JWTUser(id={self.id})"
def __repr__(self):
return f"JWTUser(id={self.id}, is_staff={self.is_staff}, is_superuser={self.is_superuser})"
def create_access_token(
user_id: int | str,
session_key: str,
*,
is_staff: bool = False,
is_superuser: bool = False,
) -> str:
"""
Create a short-lived access token.
The token contains:
- sub: user ID
- sid: session key (for revocation checking)
- staff: is_staff flag
- super: is_superuser flag
- type: "access"
- iat: issued at
- exp: expiration
"""
settings = get_settings()
now = int(time.time())
payload = {
"sub": str(user_id),
"sid": session_key,
"staff": is_staff,
"super": is_superuser,
"type": "access",
"iat": now,
"exp": now + settings.access_token_expires_in,
}
return jwt.encode(
payload,
settings.private_key,
algorithm=settings.algorithm,
)
def create_refresh_token(
user_id: int | str,
session_key: str,
*,
is_staff: bool = False,
is_superuser: bool = False,
) -> str:
"""
Create a longer-lived refresh token.
The token contains:
- sub: user ID
- sid: session key (for revocation checking)
- staff: is_staff flag
- super: is_superuser flag
- type: "refresh"
- iat: issued at
- exp: expiration
"""
settings = get_settings()
now = int(time.time())
payload = {
"sub": str(user_id),
"sid": session_key,
"staff": is_staff,
"super": is_superuser,
"type": "refresh",
"iat": now,
"exp": now + settings.refresh_token_expires_in,
}
return jwt.encode(
payload,
settings.private_key,
algorithm=settings.algorithm,
)
def create_token_pair(
user_id: int | str,
session_key: str,
*,
is_staff: bool = False,
is_superuser: bool = False,
) -> TokenPair:
"""Create both access and refresh tokens."""
settings = get_settings()
return TokenPair(
access_token=create_access_token(
user_id, session_key, is_staff=is_staff, is_superuser=is_superuser
),
refresh_token=create_refresh_token(
user_id, session_key, is_staff=is_staff, is_superuser=is_superuser
),
expires_in=settings.access_token_expires_in,
)
def decode_token(token: str, expected_type: str = None) -> TokenPayload | None:
"""
Decode and validate a JWT token.
Returns None if:
- Token is invalid or expired
- Token type doesn't match expected_type (if specified)
"""
settings = get_settings()
try:
payload = jwt.decode(
token,
settings.public_key,
algorithms=[settings.algorithm],
)
except jwt.PyJWTError:
return None
# Validate token type if specified
if expected_type and payload.get("type") != expected_type:
return None
return TokenPayload(
user_id=payload["sub"],
session_key=payload["sid"],
token_type=payload["type"],
is_staff=payload.get("staff", False),
is_superuser=payload.get("super", False),
exp=payload["exp"],
iat=payload["iat"],
)
def validate_session(session_key: str) -> bool:
"""
Check if a session is still valid (exists and not expired).
This is the key to immediate logout revocation - if the session
is destroyed, tokens tied to it become invalid.
"""
from importlib import import_module
from django.conf import settings as django_settings
jwt_settings = get_settings()
if not jwt_settings.validate_session:
return True
# Use the configured session engine
engine = import_module(django_settings.SESSION_ENGINE)
SessionStore = engine.SessionStore
# Try to load the session
session = SessionStore(session_key=session_key)
# Check if session exists and is not empty
# exists() is more reliable than checking load() result
return session.exists(session_key)
def refresh_tokens(refresh_token: str) -> TokenPair | None:
"""
Use a refresh token to obtain new tokens.
Returns None if:
- Refresh token is invalid or expired
- Associated session no longer exists
"""
payload = decode_token(refresh_token, expected_type="refresh")
if payload is None:
return None
# Validate the session still exists
if not validate_session(payload.session_key):
return None
# Issue new token pair with same claims
return create_token_pair(
payload.user_id,
payload.session_key,
is_staff=payload.is_staff,
is_superuser=payload.is_superuser,
)

View File

@@ -0,0 +1,35 @@
"""
Export channels schema as OpenAPI JSON for TypeScript generation.
Uses Django Ninja's schema generation for robust Pydantic→OpenAPI conversion.
The schema is consumed by openapi-typescript for type generation.
Usage:
python manage.py export_channels_schema
"""
import json
from django.core.management.base import BaseCommand
class Command(BaseCommand):
help = "Export channels schema as OpenAPI JSON for TypeScript code generation"
def add_arguments(self, parser):
parser.add_argument(
"--indent",
type=int,
default=2,
help="JSON indentation level (default: 2, use 0 for compact)",
)
def handle(self, *args, **options):
from mizan.channels import get_channels_openapi_schema
schema = get_channels_openapi_schema()
indent = options["indent"] if options["indent"] > 0 else None
output = json.dumps(schema, indent=indent)
self.stdout.write(output)

View File

@@ -0,0 +1,49 @@
"""
Export mizan Schema
Management command to export the mizan OpenAPI schema for TypeScript code generation.
The schema is consumed by openapi-typescript for robust type generation.
Usage:
python manage.py export_mizan_schema # Output to stdout
python manage.py export_mizan_schema --output schema.json # Output to file
"""
import json
from pathlib import Path
from django.core.management.base import BaseCommand
from mizan.export import generate_openapi_schema
class Command(BaseCommand):
help = "Export mizan OpenAPI schema for TypeScript code generation"
def add_arguments(self, parser):
parser.add_argument(
"--output",
"-o",
type=str,
default=None,
help="Output file path. If not specified, outputs to stdout.",
)
parser.add_argument(
"--indent",
type=int,
default=2,
help="JSON indentation level (0 for compact output)",
)
def handle(self, *args, **options):
schema = generate_openapi_schema()
indent = options["indent"] if options["indent"] > 0 else None
json_output = json.dumps(schema, indent=indent)
if options["output"]:
output_path = Path(options["output"])
output_path.parent.mkdir(parents=True, exist_ok=True)
output_path.write_text(json_output)
self.stdout.write(self.style.SUCCESS(f"Schema written to {output_path}"))
else:
self.stdout.write(json_output)

View File

@@ -0,0 +1,71 @@
"""
mizan.setup - Integration and registration utilities.
This subpackage contains everything developers need to integrate mizan:
- Registry for server functions and channels
- Auto-discovery for apps
- Configuration settings
Usage:
from mizan.setup import mizan_clients, register, get_function
"""
from .registry import (
register,
register_as,
register_form,
register_compose,
get_function,
get_channel,
get_compose,
get_view,
get_all_functions,
get_all_channels,
get_all_compositions,
get_registry,
get_schema,
get_contexts,
get_context_groups,
get_forms,
clear_registry,
)
from .discovery import (
mizan_clients,
mizan_module,
)
from .settings import (
mizanSettings,
get_settings,
clear_settings_cache,
)
__all__ = [
# Registration
"register",
"register_as",
"register_form",
"register_compose",
# Lookup
"get_function",
"get_channel",
"get_compose",
"get_view",
"get_all_functions",
"get_all_channels",
"get_all_compositions",
"get_registry",
"get_schema",
"get_contexts",
"get_context_groups",
"get_forms",
"clear_registry",
# Discovery
"mizan_clients",
"mizan_module",
# Settings
"mizanSettings",
"get_settings",
"clear_settings_cache",
]

View File

@@ -0,0 +1,90 @@
"""
mizan Auto-Discovery
Scans Django apps for server functions following the 'clients' layer convention:
- <app>/clients.py
- <app>/clients/**/*.py
Usage in urls.py:
from mizan.setup.discovery import mizan_clients
mizan_clients('apps') # Scans apps/*/clients.py
mizan_clients('mizan', 'allauth') # Scans mizan/allauth/**/*.py
This replaces manual "import to register" patterns with explicit auto-discovery.
"""
from typing import Any
from mizan._vendor.app_visitor import DjangoAppVisitor, get_members
from .registry import register, get_function
from mizan.client.function import ServerFunction
class _RegisterServerFunctions:
"""Visitor handler that registers ServerFunction subclasses."""
def on_module(
self, app_name: str, path_parts: list[str], members: list[tuple[str, Any]]
) -> None:
"""Process discovered module members."""
for name, member in members:
# Register ServerFunction subclasses
if (
isinstance(member, type)
and issubclass(member, ServerFunction)
and member is not ServerFunction
and hasattr(member, "__name__")
):
# Use the function name as registration name
fn_name = getattr(member, "name", None) or member.__name__
# Skip already registered (idempotent)
if get_function(fn_name) is member:
continue
try:
register(member, fn_name)
except ValueError:
# Already registered with different class - skip
pass
def mizan_clients(apps_root: str, layer: str = "clients") -> None:
"""
Discover and register server functions from Django apps.
Scans for the specified layer (default: 'clients') in each app:
- <app>/<layer>.py
- <app>/<layer>/**/*.py
Args:
apps_root: Root package containing Django apps (e.g., 'apps')
layer: Module name pattern to scan (default: 'clients')
Example:
# In urls.py
mizan_clients('apps') # Scans apps/*/clients.py
mizan_clients('apps', 'functions') # Scans apps/*/functions.py
"""
visitor = DjangoAppVisitor(layer=layer, apps_root=apps_root)
visitor.visit(_RegisterServerFunctions())
def mizan_module(module_path: str) -> None:
"""
Register server functions from a specific module.
Use this for library modules that don't follow the app convention.
Args:
module_path: Full module path (e.g., 'mizan.integrations.allauth')
Example:
mizan_module('mizan.integrations.allauth')
mizan_module('mizan.jwt.functions')
"""
members = get_members(module_path)
handler = _RegisterServerFunctions()
handler.on_module("", [], members)

View File

@@ -0,0 +1,333 @@
"""
mizan Registry
Central registration for server functions, channels, and compositions.
All items are identified by name.
"""
from __future__ import annotations
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
from mizan.client.function import ServerFunction, ComposedContext
from mizan.channels import ReactChannel
# Global registries - all use name as key
_functions: dict[str, type["ServerFunction"]] = {}
_channels: dict[str, type["ReactChannel"]] = {}
_compositions: dict[str, "ComposedContext"] = {}
def register(
view_class: type["ServerFunction"] | type["ReactChannel"],
name: str,
) -> type["ServerFunction"] | type["ReactChannel"]:
"""
Register a server function or channel.
Args:
view_class: ServerFunction or ReactChannel subclass
name: Registration name (used for API calls and code generation)
Returns:
The view class (allows use as part of decorator chain)
"""
from mizan.client.function import ServerFunction
from mizan.channels import ReactChannel
view_class.name = name
if issubclass(view_class, ReactChannel):
if name in _channels:
# Allow re-registration of the same class (idempotent for reloads)
if _channels[name] is not view_class:
raise ValueError(
f"Channel '{name}' already registered by {_channels[name].__name__}"
)
return view_class
_channels[name] = view_class
elif issubclass(view_class, ServerFunction):
if name in _functions:
# Allow re-registration of the same class (idempotent for reloads)
existing = _functions[name]
if existing.__name__ == view_class.__name__:
# Same function being re-registered (reload scenario)
_functions[name] = view_class
return view_class
raise ValueError(
f"Function '{name}' already registered by {existing.__name__}"
)
_functions[name] = view_class
else:
raise TypeError(f"{view_class} must be a ServerFunction or ReactChannel")
return view_class
def register_as(name: str):
"""
Decorator for registering a server function or channel.
Usage:
@register_as('update-profile')
class UpdateProfile(ServerFunction):
...
"""
def decorator(view_class):
return register(view_class, name)
return decorator
def register_form(
form_class: type,
name: str,
submit_handler: Callable | None = None,
) -> None:
"""
Register a Django Form as server functions.
Creates and registers:
- {name}.schema: Returns form field definitions
- {name}.validate: Validates form data
- {name}.submit: Submits form (if submit_handler provided)
Usage:
register_form(ContactForm, 'contact', submit_handler=handle_contact)
"""
from mizan.client.function import create_form_functions
schema_fn, validate_fn, submit_fn = create_form_functions(
form_class, name, submit_handler
)
register(schema_fn, f"{name}.schema")
register(validate_fn, f"{name}.validate")
if submit_fn:
register(submit_fn, f"{name}.submit")
def register_compose(
composed: "ComposedContext",
name: str,
) -> "ComposedContext":
"""
Register a composed context.
Args:
composed: ComposedContext instance
name: Registration name
Returns:
The composed context
"""
if name in _compositions:
existing = _compositions[name]
if existing.name == composed.name:
# Same composition being re-registered (reload scenario)
_compositions[name] = composed
return composed
raise ValueError(f"Composition '{name}' already registered by {existing.name}")
_compositions[name] = composed
return composed
def get_function(name: str) -> type["ServerFunction"] | None:
"""Get a registered server function by name."""
return _functions.get(name)
def get_channel(name: str) -> type["ReactChannel"] | None:
"""Get a registered channel by name."""
return _channels.get(name)
def get_compose(name: str) -> "ComposedContext | None":
"""Get a registered composition by name."""
return _compositions.get(name)
def get_view(name: str) -> type["ServerFunction"] | type["ReactChannel"] | None:
"""Get any registered view by name (function or channel)."""
return _functions.get(name) or _channels.get(name)
def get_all_functions() -> dict[str, type["ServerFunction"]]:
"""Get all registered functions."""
return _functions.copy()
def get_all_channels() -> dict[str, type["ReactChannel"]]:
"""Get all registered channels."""
return _channels.copy()
def get_all_compositions() -> dict[str, "ComposedContext"]:
"""Get all registered compositions."""
return _compositions.copy()
def get_registry() -> dict[str, dict[str, Any]]:
"""
Get the full registry organized by type.
Returns:
{
"functions": { name: class, ... },
"channels": { name: class, ... },
"compositions": { name: ComposedContext, ... },
}
"""
return {
"functions": _functions.copy(),
"channels": _channels.copy(),
"compositions": _compositions.copy(),
}
def get_schema() -> dict[str, Any]:
"""
Export the full schema for TypeScript generation.
Returns:
{
"functions": {
"update_profile": {
"name": "update_profile",
"type": "function",
"meta": { "context": "global", ... },
"input": { ... },
"output": { ... },
},
...
},
"channels": {
"chat": {
"name": "chat",
"type": "channel",
"params": { ... },
"django_message": { ... },
...
},
...
},
"compositions": {
"user_page": {
"name": "user_page",
"type": "compose",
"meta": { "on_server": false, ... },
"children": ["user_profile", "user_posts"],
"leaves": ["user_profile", "user_posts"],
},
...
},
}
"""
functions = {}
for name, cls in _functions.items():
schema = cls.get_schema_export()
functions[name] = schema
compositions = {}
for name, composed in _compositions.items():
compositions[name] = {
"name": composed.name,
"type": "compose",
"meta": composed._meta,
"children": composed._meta.get("children", []),
"leaves": composed._meta.get("leaves", []),
}
# Build channel schemas from our registry
# Only include keys when they have values (test expects absent keys, not None)
channels_schema = {}
for name, channel_class in _channels.items():
channel_schema: dict[str, Any] = {
"name": name,
"type": "channel",
"bidirectional": False,
}
# Extract Params schema (only if defined)
if hasattr(channel_class, "Params") and channel_class.Params:
channel_schema["params"] = channel_class.Params.model_json_schema()
# Extract ReactMessage schema (only if defined - indicates bidirectional)
if hasattr(channel_class, "ReactMessage") and channel_class.ReactMessage:
channel_schema[
"react_message"
] = channel_class.ReactMessage.model_json_schema()
channel_schema["bidirectional"] = True
# Extract DjangoMessage schema (only if defined)
if hasattr(channel_class, "DjangoMessage") and channel_class.DjangoMessage:
channel_schema[
"django_message"
] = channel_class.DjangoMessage.model_json_schema()
channels_schema[name] = channel_schema
return {
"functions": functions,
"channels": channels_schema,
"compositions": compositions,
}
def get_contexts() -> dict[str, type["ServerFunction"]]:
"""
Get all server functions marked as contexts.
These are functions with meta.context = True, used for SSR hydration.
"""
contexts = {}
for name, cls in _functions.items():
if getattr(cls, "_meta", {}).get("context"):
contexts[name] = cls
return contexts
def get_context_groups() -> dict[str, list[str]]:
"""
Group function names by their context string.
Returns:
{"global": ["current_user"], "user": ["user_profile", "user_orders"]}
"""
groups: dict[str, list[str]] = {}
for name, cls in _functions.items():
ctx = getattr(cls, "_meta", {}).get("context")
if ctx:
groups.setdefault(ctx, []).append(name)
return groups
def get_forms() -> dict[str, list[type["ServerFunction"]]]:
"""
Get all server functions that are form-related, grouped by form name.
Returns:
{
"contact": [ContactSchema, ContactValidate, ContactSubmit],
...
}
"""
forms: dict[str, list] = {}
for name, cls in _functions.items():
meta = getattr(cls, "_meta", {})
if meta.get("form"):
form_name = meta.get("form_name")
if form_name not in forms:
forms[form_name] = []
forms[form_name].append(cls)
return forms
def clear_registry() -> None:
"""Clear all registrations. Primarily for testing."""
_functions.clear()
_channels.clear()
_compositions.clear()

View File

@@ -0,0 +1,36 @@
"""
mizan Settings
Configuration is read from Django settings with sensible defaults.
"""
from dataclasses import dataclass
from functools import lru_cache
from django.conf import settings as django_settings
@dataclass
class mizanSettings:
"""mizan configuration."""
# Whether to expose function names in DEBUG mode errors
debug_expose_names: bool
@lru_cache
def get_settings() -> mizanSettings:
"""
Load mizan settings from Django settings.
Settings:
mizan_DEBUG_EXPOSE_NAMES: Show function names in errors when DEBUG=True (default: True)
"""
return mizanSettings(
debug_expose_names=getattr(django_settings, "mizan_DEBUG_EXPOSE_NAMES", True),
)
def clear_settings_cache():
"""Clear the settings cache (for testing)."""
get_settings.cache_clear()

View File

@@ -0,0 +1,3 @@
from mizan.shapes.core import Diff, NestedDiff, Shape
__all__ = ["Diff", "NestedDiff", "Shape"]

View File

@@ -0,0 +1,265 @@
from __future__ import annotations
import types
from typing import Any, ClassVar, Generic, TypeVar, Union, get_type_hints
from pydantic import BaseModel
from django_readers import pairs, specs
from django_readers import qs as readers_qs
_M = TypeVar("_M")
_S = TypeVar("_S", bound="Shape")
def _extract_shape_class(hint) -> type[Shape] | None:
origin = getattr(hint, "__origin__", None)
args = getattr(hint, "__args__", ())
# list[SomeShape]
if (
origin is list
and args
and isinstance(args[0], type)
and issubclass(args[0], Shape)
):
return args[0]
# SomeShape (bare)
if isinstance(hint, type) and issubclass(hint, Shape) and hint is not Shape:
return hint
# SomeShape | None (Union/Optional)
if origin is Union or isinstance(hint, types.UnionType):
for arg in args:
if arg is type(None):
continue
if isinstance(arg, type) and issubclass(arg, Shape) and arg is not Shape:
return arg
return None
def _resolve_model(cls) -> Any | None:
for base in cls.__bases__:
meta = getattr(base, "__pydantic_generic_metadata__", None) or {}
if meta.get("origin") is Shape and (args := meta.get("args")):
return args[0]
return None
class Shape(BaseModel, Generic[_M]):
_model: ClassVar[Any]
_nested: ClassVar[dict[str, type[Shape]]]
_field_names: ClassVar[list[str]]
_pk_field: ClassVar[str]
_spec: ClassVar[list]
_pair: ClassVar[tuple]
def __init_subclass__(cls, **kwargs):
super().__init_subclass__(**kwargs)
if not (model := _resolve_model(cls)):
return
cls._model = model
cls._nested = {}
cls._pk_field = model._meta.pk.name if model._meta.pk else "id"
hints = get_type_hints(cls, include_extras=False, localns={cls.__name__: cls}) or cls.__annotations__
field_names = []
for name, hint in hints.items():
if name.startswith("_"):
continue
if shape_cls := _extract_shape_class(hint):
cls._nested[name] = shape_cls
else:
field_names.append(name)
cls._field_names = field_names
# Set field-only spec first so self-references can find it
cls._spec = [*field_names]
# Now rebuild with nested — self-refs resolve because cls._spec exists
cls._spec = [
*field_names,
*({name: shape._spec} for name, shape in cls._nested.items()),
]
cls._pair = specs.process(cls._spec)
@classmethod
def _build_pair(cls, relation_qs: dict[str, Any]):
field_pairs = [
pairs.producer_to_projector(name, pairs.field(name))
for name in cls._field_names
]
rel_pairs = []
for name, shape_cls in cls._nested.items():
child_prepare, child_project = shape_cls._pair
prepare = (
readers_qs.pipe(relation_qs[name], child_prepare)
if name in relation_qs
else child_prepare
)
rel_pairs.append(
pairs.producer_to_projector(
name, pairs.relationship(name, (prepare, child_project))
)
)
return pairs.combine(*field_pairs, *rel_pairs)
@classmethod
def _get_pk(cls, instance) -> Any | None:
return getattr(instance, cls._pk_field, None)
@classmethod
def query(cls: type[_S], *qs_fns, **relation_qs) -> list[_S]:
prepare, project = cls._build_pair(relation_qs) if relation_qs else cls._pair
base = cls._model.objects.all()
# Accept a raw QuerySet as the first arg, or qs functions, or nothing
if qs_fns and hasattr(qs_fns[0], "query"):
base, qs_fns = qs_fns[0], qs_fns[1:]
queryset = readers_qs.pipe(prepare, *qs_fns)(base)
return [cls.model_validate(project(obj)) for obj in queryset]
@classmethod
def diff_many(cls: type[_S], items: list[_S]) -> list[tuple[_S, Diff]]:
pk_field = cls._pk_field
pk_map: dict[Any, _S] = {}
new_items: list[_S] = []
for item in items:
pk = cls._get_pk(item)
if pk is not None:
pk_map[pk] = item
else:
new_items.append(item)
# Single query for all existing items
current_map: dict[Any, _S] = {}
if pk_map:
current_items = cls.query(
cls._model.objects.filter(**{f"{pk_field}__in": pk_map.keys()})
)
current_map = {cls._get_pk(c): c for c in current_items}
results: list[tuple[_S, Diff]] = []
for item in new_items:
results.append((item, cls._diff_one(item, None)))
for pk, item in pk_map.items():
current = current_map.get(pk)
if current is None:
raise cls._model.DoesNotExist(
f"{cls._model.__name__} with {pk_field}={pk} does not exist"
)
results.append((item, cls._diff_one(item, current)))
return results
@classmethod
def _diff_one(cls, incoming: _S, current: _S | None) -> Diff:
pk_field = cls._pk_field
changed = (
{
k: getattr(incoming, k)
for k in cls._field_names
if k != pk_field and getattr(incoming, k) != getattr(current, k)
}
if current
else {k: getattr(incoming, k) for k in cls._field_names if k != pk_field}
)
nested = {}
for name, shape_cls in cls._nested.items():
incoming_items = getattr(incoming, name, None) or []
current_items = getattr(current, name, None) or [] if current else []
if not isinstance(incoming_items, list):
incoming_items = [incoming_items]
if not isinstance(current_items, list):
current_items = [current_items]
child_pk = shape_cls._pk_field
current_by_pk = {
shape_cls._get_pk(c): c
for c in current_items
if shape_cls._get_pk(c) is not None
}
incoming_by_pk = {
shape_cls._get_pk(c): c
for c in incoming_items
if shape_cls._get_pk(c) is not None
}
nested[name] = NestedDiff(
created=[c for c in incoming_items if shape_cls._get_pk(c) is None],
updated=[
c
for pk, c in incoming_by_pk.items()
if pk in current_by_pk and c != current_by_pk[pk]
],
deleted=[pk for pk in current_by_pk if pk not in incoming_by_pk],
)
return Diff(is_new=current is None, changed=changed, _nested=nested)
def diff(self) -> Diff:
cls = type(self)
pk = cls._get_pk(self)
if pk is not None:
results = cls.query(cls._model.objects.filter(pk=pk))
if not results:
raise cls._model.DoesNotExist(
f"{cls._model.__name__} with {cls._pk_field}={pk} does not exist"
)
current = results[0]
else:
current = None
return cls._diff_one(self, current)
class NestedDiff:
__slots__ = ("created", "updated", "deleted")
def __init__(self, created=(), updated=(), deleted=()):
self.created = list(created)
self.updated = list(updated)
self.deleted = list(deleted)
class Diff:
__slots__ = ("is_new", "changed", "_nested")
def __init__(
self, is_new: bool, changed: dict[str, Any], _nested: dict[str, NestedDiff]
):
self.is_new = is_new
self.changed = changed
self._nested = _nested
def nested(self, name: str) -> NestedDiff:
"""Strict access to nested diffs. Raises KeyError for invalid names."""
if name not in self._nested:
valid = ", ".join(sorted(self._nested)) or "(none)"
raise KeyError(f"No nested diff for '{name}'. Valid nested shapes: {valid}")
return self._nested[name]
def __getattr__(self, name: str) -> NestedDiff:
if name.startswith("_"):
raise AttributeError(name)
if name not in self._nested:
valid = ", ".join(sorted(self._nested)) or "(none)"
raise AttributeError(
f"No nested diff for '{name}'. Valid nested shapes: {valid}"
)
return self._nested[name]

View File

@@ -0,0 +1,554 @@
"""
Authentication Tests for mizan Server Functions
Tests all combinations of:
- Transport: HTTP vs WebSocket RPC
- JWT: Present (valid), Present (invalid), Absent
- Session: Present (valid), Absent
Expected behavior:
- JWT present (valid) → JWTUser (no DB query)
- JWT present (invalid) → Reject (401), do NOT fall back to session
- JWT absent + Session present → Session auth (DB query)
- JWT absent + Session absent → AnonymousUser
"""
from django.test import TestCase, RequestFactory, override_settings
from django.contrib.auth import get_user_model
from django.contrib.sessions.backends.db import SessionStore
from unittest.mock import patch, MagicMock
import json
from mizan.jwt.tokens import (
create_token_pair,
decode_token,
JWTUser,
)
from mizan.client.executor import (
_try_jwt_auth,
execute_function,
FunctionError,
FunctionResult,
ErrorCode,
)
from mizan.client import client
from mizan.setup.registry import clear_registry, register
from pydantic import BaseModel
User = get_user_model()
# =============================================================================
# Test Output Models (proper Pydantic models, not raw dicts)
# =============================================================================
class WhoamiOutput(BaseModel):
is_authenticated: bool
user_id: int | None
user_type: str
is_staff: bool
class OkOutput(BaseModel):
ok: bool
class UserTypeOutput(BaseModel):
user_type: str
# =============================================================================
# Test Server Functions - defined as plain functions, registered in setUp
# =============================================================================
def _whoami_fn(request) -> WhoamiOutput:
"""Returns info about the authenticated user."""
user = request.user
return WhoamiOutput(
is_authenticated=user.is_authenticated,
user_id=getattr(user, "id", None),
user_type=type(user).__name__,
is_staff=getattr(user, "is_staff", False),
)
@override_settings(
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
JWT_ALGORITHM="HS256",
)
class HTTPAuthTests(TestCase):
"""Test HTTP transport authentication combinations."""
def setUp(self):
clear_registry()
self.factory = RequestFactory()
self.user = User.objects.create_user(
email="test@example.com",
password="testpass123",
is_staff=True,
is_superuser=False,
)
# Create a session
self.session = SessionStore()
self.session.create()
self.session_key = self.session.session_key
# Register test function
@client
def whoami(request) -> WhoamiOutput:
user = request.user
return WhoamiOutput(
is_authenticated=user.is_authenticated,
user_id=getattr(user, "id", None),
user_type=type(user).__name__,
is_staff=getattr(user, "is_staff", False),
)
register(whoami, "whoami")
def tearDown(self):
self.user.delete()
self.session.delete()
clear_registry()
def test_jwt_valid_no_session(self):
"""Valid JWT without session → JWTUser (no DB query)."""
tokens = create_token_pair(
self.user.pk,
self.session_key,
is_staff=True,
is_superuser=False,
)
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
request.user = MagicMock(is_authenticated=False) # No session auth
# Try JWT auth
result = _try_jwt_auth(request)
self.assertTrue(result)
self.assertIsInstance(request.user, JWTUser)
self.assertEqual(request.user.id, self.user.pk)
self.assertTrue(request.user.is_staff)
self.assertTrue(request.user.is_authenticated)
def test_jwt_valid_with_session(self):
"""Valid JWT with session → JWT takes precedence (no DB query)."""
tokens = create_token_pair(
self.user.pk,
self.session_key,
is_staff=True,
is_superuser=False,
)
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
request.user = self.user # Session auth already set user
# JWT should still be processed and take precedence
result = _try_jwt_auth(request)
self.assertTrue(result)
self.assertIsInstance(request.user, JWTUser)
def test_jwt_invalid_with_session(self):
"""Invalid JWT with valid session → Reject (do NOT fall back)."""
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = "Bearer invalid-token-here"
request.user = self.user # Session would work
# JWT auth should fail
result = _try_jwt_auth(request)
self.assertFalse(result)
# User should NOT be changed to session user - that happens elsewhere
# The point is _try_jwt_auth returns False, indicating JWT failed
def test_jwt_expired_with_session(self):
"""Expired JWT with valid session → Reject (do NOT fall back)."""
# Create token with past expiration by mocking time
with patch("mizan.jwt.tokens.time.time", return_value=0):
tokens = create_token_pair(
self.user.pk,
self.session_key,
is_staff=True,
is_superuser=False,
)
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
request.user = self.user # Session would work
# JWT auth should fail (expired)
result = _try_jwt_auth(request)
self.assertFalse(result)
def test_no_jwt_with_session(self):
"""No JWT with valid session → Session auth (normal Django flow)."""
request = self.factory.post("/")
request.user = self.user # Session auth set user
# No JWT auth attempted
result = _try_jwt_auth(request)
self.assertFalse(result) # No JWT to process
# User remains the session user
self.assertEqual(request.user, self.user)
def test_no_jwt_no_session(self):
"""No JWT, no session → AnonymousUser."""
from django.contrib.auth.models import AnonymousUser
request = self.factory.post("/")
request.user = AnonymousUser()
result = _try_jwt_auth(request)
self.assertFalse(result)
self.assertIsInstance(request.user, AnonymousUser)
def test_execute_function_with_jwt(self):
"""Execute server function with JWT auth."""
tokens = create_token_pair(
self.user.pk,
self.session_key,
is_staff=True,
is_superuser=False,
)
request = self.factory.post("/")
request.META["HTTP_AUTHORIZATION"] = f"Bearer {tokens.access_token}"
# Simulate what the view does: try JWT auth first
_try_jwt_auth(request)
# Use the whoami function which returns WhoamiOutput (Pydantic model)
result = execute_function(request, "whoami", {})
self.assertIsInstance(result, FunctionResult)
self.assertTrue(result.data["is_authenticated"])
self.assertEqual(result.data["user_type"], "JWTUser")
self.assertTrue(result.data["is_staff"])
@override_settings(
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
JWT_ALGORITHM="HS256",
)
class JWTUserTests(TestCase):
"""Test JWTUser behavior."""
def setUp(self):
clear_registry()
def tearDown(self):
clear_registry()
def test_jwt_user_attributes(self):
"""JWTUser has expected attributes."""
from mizan.jwt.tokens import TokenPayload
payload = TokenPayload(
user_id=42,
session_key="test-session",
token_type="access",
is_staff=True,
is_superuser=False,
exp=9999999999,
iat=0,
)
user = JWTUser(payload)
self.assertEqual(user.id, 42)
self.assertEqual(user.pk, 42)
self.assertTrue(user.is_staff)
self.assertFalse(user.is_superuser)
self.assertTrue(user.is_authenticated)
self.assertFalse(user.is_anonymous)
self.assertTrue(user.is_active)
def test_jwt_user_string_id(self):
"""JWTUser handles string user_id (converted to int)."""
from mizan.jwt.tokens import TokenPayload
payload = TokenPayload(
user_id="42", # String, as stored in JWT
session_key="test-session",
token_type="access",
is_staff=False,
is_superuser=False,
exp=9999999999,
iat=0,
)
user = JWTUser(payload)
self.assertEqual(user.id, 42)
self.assertIsInstance(user.id, int)
@override_settings(
JWT_PRIVATE_KEY="test-secret-key-for-testing-only",
JWT_ALGORITHM="HS256",
)
class AuthDecoratorTests(TestCase):
"""Test @client(auth=...) decorator."""
def setUp(self):
clear_registry()
self.factory = RequestFactory()
self.user = User.objects.create_user(
email="test@example.com",
password="testpass123",
is_staff=False,
is_superuser=False,
)
self.staff_user = User.objects.create_user(
email="staff@example.com",
password="testpass123",
is_staff=True,
is_superuser=False,
)
self.superuser = User.objects.create_user(
email="super@example.com",
password="testpass123",
is_staff=True,
is_superuser=True,
)
def tearDown(self):
self.user.delete()
self.staff_user.delete()
self.superuser.delete()
clear_registry()
def test_auth_required_with_anonymous(self):
"""@client(auth=True) rejects anonymous users."""
from django.contrib.auth.models import AnonymousUser
# Register a test function with proper Pydantic model
@client(auth=True)
def protected_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(protected_fn, "protected_fn")
request = self.factory.post("/")
request.user = AnonymousUser()
result = execute_function(request, "protected_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.UNAUTHORIZED)
def test_auth_required_with_authenticated(self):
"""@client(auth=True) allows authenticated users."""
@client(auth=True)
def protected_fn2(request) -> OkOutput:
return OkOutput(ok=True)
register(protected_fn2, "protected_fn2")
request = self.factory.post("/")
request.user = self.user
result = execute_function(request, "protected_fn2", {})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["ok"], True)
def test_auth_staff_with_regular_user(self):
"""@client(auth='staff') rejects non-staff users."""
@client(auth="staff")
def staff_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(staff_fn, "staff_fn")
request = self.factory.post("/")
request.user = self.user # Not staff
result = execute_function(request, "staff_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
def test_auth_staff_with_staff_user(self):
"""@client(auth='staff') allows staff users."""
@client(auth="staff")
def staff_fn2(request) -> OkOutput:
return OkOutput(ok=True)
register(staff_fn2, "staff_fn2")
request = self.factory.post("/")
request.user = self.staff_user
result = execute_function(request, "staff_fn2", {})
self.assertIsInstance(result, FunctionResult)
def test_auth_superuser_with_staff(self):
"""@client(auth='superuser') rejects non-superusers."""
@client(auth="superuser")
def super_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(super_fn, "super_fn")
request = self.factory.post("/")
request.user = self.staff_user # Staff but not superuser
result = execute_function(request, "super_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
def test_auth_superuser_with_superuser(self):
"""@client(auth='superuser') allows superusers."""
@client(auth="superuser")
def super_fn2(request) -> OkOutput:
return OkOutput(ok=True)
register(super_fn2, "super_fn2")
request = self.factory.post("/")
request.user = self.superuser
result = execute_function(request, "super_fn2", {})
self.assertIsInstance(result, FunctionResult)
def test_auth_with_jwt_user(self):
"""Auth checks work with JWTUser (stateless)."""
from mizan.jwt.tokens import TokenPayload
@client(auth="staff")
def jwt_staff_fn(request) -> UserTypeOutput:
return UserTypeOutput(user_type=type(request.user).__name__)
register(jwt_staff_fn, "jwt_staff_fn")
# Create JWTUser with is_staff=True
payload = TokenPayload(
user_id=99,
session_key="test",
token_type="access",
is_staff=True,
is_superuser=False,
exp=9999999999,
iat=0,
)
jwt_user = JWTUser(payload)
request = self.factory.post("/")
request.user = jwt_user
result = execute_function(request, "jwt_staff_fn", {})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["user_type"], "JWTUser")
def test_auth_invalid_string_raises(self):
"""Invalid auth string raises ValueError at decoration time."""
with self.assertRaises(ValueError) as ctx:
@client(auth="admin") # 'admin' is not valid
def bad_fn(request) -> OkOutput:
return OkOutput(ok=True)
self.assertIn("Invalid auth value 'admin'", str(ctx.exception))
self.assertIn("required", str(ctx.exception))
def test_auth_callable_returns_true(self):
"""Callable auth returning True allows access."""
@client(auth=lambda r: r.user.email.endswith("@example.com"))
def email_check_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(email_check_fn, "email_check_fn")
request = self.factory.post("/")
request.user = self.user # email is test@example.com
result = execute_function(request, "email_check_fn", {})
self.assertIsInstance(result, FunctionResult)
self.assertTrue(result.data["ok"])
def test_auth_callable_returns_false(self):
"""Callable auth returning False denies access."""
@client(auth=lambda r: r.user.email.endswith("@admin.com"))
def admin_email_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(admin_email_fn, "admin_email_fn")
request = self.factory.post("/")
request.user = self.user # email is test@example.com, not @admin.com
result = execute_function(request, "admin_email_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
self.assertEqual(result.message, "Access denied")
def test_auth_callable_raises_permission_error(self):
"""Callable auth raising PermissionError uses custom message."""
def check_premium(request):
if not getattr(request.user, "is_premium", False):
raise PermissionError("Premium subscription required")
return True
@client(auth=check_premium)
def premium_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(premium_fn, "premium_fn")
request = self.factory.post("/")
request.user = self.user # No is_premium attribute
result = execute_function(request, "premium_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
self.assertEqual(result.message, "Premium subscription required")
def test_auth_callable_with_anonymous_user(self):
"""Callable auth can check for anonymous users."""
from django.contrib.auth.models import AnonymousUser
def must_be_authenticated(request):
if not request.user.is_authenticated:
raise PermissionError("Please log in")
return True
@client(auth=must_be_authenticated)
def needs_login_fn(request) -> OkOutput:
return OkOutput(ok=True)
register(needs_login_fn, "needs_login_fn")
request = self.factory.post("/")
request.user = AnonymousUser()
result = execute_function(request, "needs_login_fn", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.FORBIDDEN)
self.assertEqual(result.message, "Please log in")

View File

@@ -0,0 +1,567 @@
"""
Protocol Benchmark: HTTP vs WebSocket Server Functions
Compares performance of HTTP POST vs WebSocket RPC for server function calls.
Includes realistic scenarios with ORM queries.
Usage:
python manage.py test mizan.tests.test_benchmarks --verbosity=2
Note:
These are not unit tests - they measure performance. Results are printed
to stdout and should be run in isolation for accurate measurements.
"""
import asyncio
import json
import statistics
import time
from typing import Any
from unittest.mock import MagicMock, AsyncMock
from django.contrib.auth import get_user_model
from django.contrib.auth.models import AnonymousUser
from django.db import connection
from django.http import HttpRequest
from django.test import RequestFactory, TestCase, TransactionTestCase, override_settings
from pydantic import BaseModel
from mizan.client.executor import FunctionResult, execute_function, function_call_view
from mizan.setup.registry import clear_registry
from mizan.client import client
User = get_user_model()
# =============================================================================
# Benchmark Output Models
# =============================================================================
class SimpleOutput(BaseModel):
value: int
class UserOutput(BaseModel):
id: int
email: str
class UserListOutput(BaseModel):
users: list[dict[str, Any]]
count: int
class StatsOutput(BaseModel):
total_users: int
active_users: int
staff_count: int
# =============================================================================
# Benchmark Functions
# =============================================================================
def setup_benchmark_functions():
"""Register benchmark server functions."""
from mizan.setup.registry import register
clear_registry()
# 1. Simple computation (no I/O)
@client
def bench_simple(request: HttpRequest, a: int, b: int) -> SimpleOutput:
"""Simple addition - baseline with no I/O."""
return SimpleOutput(value=a + b)
register(bench_simple, "bench_simple")
# 2. Single ORM query
@client
def bench_get_user(request: HttpRequest, user_id: int) -> UserOutput:
"""Fetch single user by ID."""
user = User.objects.filter(id=user_id).first()
if user:
return UserOutput(id=user.id, email=user.email)
return UserOutput(id=0, email="")
register(bench_get_user, "bench_get_user")
# 3. List query with limit
@client
def bench_list_users(request: HttpRequest, limit: int) -> UserListOutput:
"""Fetch list of users with limit."""
users = User.objects.all()[:limit]
return UserListOutput(
users=[{"id": u.id, "email": u.email} for u in users],
count=len(users),
)
register(bench_list_users, "bench_list_users")
# 4. Aggregation query
@client
def bench_user_stats(request: HttpRequest) -> StatsOutput:
"""Compute user statistics with multiple queries."""
total = User.objects.count()
active = User.objects.filter(is_active=True).count()
staff = User.objects.filter(is_staff=True).count()
return StatsOutput(
total_users=total,
active_users=active,
staff_count=staff,
)
register(bench_user_stats, "bench_user_stats")
# 5. Complex query with joins
@client
def bench_user_search(
request: HttpRequest, email_contains: str, limit: int
) -> UserListOutput:
"""Search users by email pattern."""
users = User.objects.filter(
email__icontains=email_contains,
is_active=True,
).select_related()[:limit]
return UserListOutput(
users=[{"id": u.id, "email": u.email} for u in users],
count=len(users),
)
register(bench_user_search, "bench_user_search")
# =============================================================================
# Benchmark Test Cases
# =============================================================================
class ProtocolBenchmark(TransactionTestCase):
"""
Benchmark comparing HTTP vs WebSocket (simulated) performance.
Uses TransactionTestCase to ensure database state is realistic.
"""
# Number of iterations for each benchmark
ITERATIONS = 100
WARMUP = 10
@classmethod
def setUpClass(cls):
super().setUpClass()
setup_benchmark_functions()
def setUp(self):
self.factory = RequestFactory()
# Create test users for ORM benchmarks
self._create_test_users()
def _create_test_users(self):
"""Create test users for benchmarks."""
# Create 100 test users
users = []
for i in range(100):
users.append(
User(
email=f"bench{i}@example.com",
is_active=i % 10 != 0, # 90% active
is_staff=i < 5, # 5 staff
)
)
User.objects.bulk_create(users, ignore_conflicts=True)
self.test_user = User.objects.first()
def _make_request(self, body: dict | None = None) -> HttpRequest:
"""Create a request with optional JSON body."""
if body:
request = self.factory.post(
"/api/mizan/call/",
data=json.dumps(body),
content_type="application/json",
)
else:
request = self.factory.post("/api/mizan/call/")
request.user = AnonymousUser()
request._dont_enforce_csrf_checks = True
return request
def _benchmark_executor(self, fn_name: str, args: dict, label: str) -> dict:
"""
Benchmark direct executor calls (simulates WebSocket RPC).
Returns timing statistics.
"""
request = self._make_request()
times = []
# Warmup
for _ in range(self.WARMUP):
execute_function(request, fn_name, args)
# Benchmark
for _ in range(self.ITERATIONS):
start = time.perf_counter()
result = execute_function(request, fn_name, args)
end = time.perf_counter()
times.append((end - start) * 1000) # ms
return self._compute_stats(times, f"Executor ({label})")
def _benchmark_http(self, fn_name: str, args: dict, label: str) -> dict:
"""
Benchmark HTTP view calls.
Returns timing statistics.
"""
times = []
# Warmup
for _ in range(self.WARMUP):
request = self._make_request({"fn": fn_name, "args": args})
function_call_view(request)
# Benchmark
for _ in range(self.ITERATIONS):
request = self._make_request({"fn": fn_name, "args": args})
start = time.perf_counter()
response = function_call_view(request)
end = time.perf_counter()
times.append((end - start) * 1000) # ms
return self._compute_stats(times, f"HTTP ({label})")
def _compute_stats(self, times: list[float], label: str) -> dict:
"""Compute statistics from timing data."""
return {
"label": label,
"min": min(times),
"max": max(times),
"mean": statistics.mean(times),
"median": statistics.median(times),
"stdev": statistics.stdev(times) if len(times) > 1 else 0,
"p95": sorted(times)[int(len(times) * 0.95)],
"p99": sorted(times)[int(len(times) * 0.99)],
"iterations": len(times),
}
def _print_results(self, results: list[dict]):
"""Print benchmark results in a table."""
print("\n" + "=" * 80)
print(f"{'Benchmark':<40} {'Mean':>8} {'Median':>8} {'P95':>8} {'P99':>8}")
print("=" * 80)
for r in results:
print(
f"{r['label']:<40} {r['mean']:>7.3f}ms {r['median']:>7.3f}ms {r['p95']:>7.3f}ms {r['p99']:>7.3f}ms"
)
print("=" * 80)
def _print_comparison(self, executor_stats: dict, http_stats: dict):
"""Print comparison between executor and HTTP."""
overhead = (
(http_stats["mean"] - executor_stats["mean"]) / executor_stats["mean"]
) * 100
print(f" HTTP overhead vs Executor: {overhead:+.1f}%")
# -------------------------------------------------------------------------
# Benchmark Tests
# -------------------------------------------------------------------------
def test_benchmark_simple_computation(self):
"""Benchmark: Simple computation (no I/O)."""
print("\n\n### BENCHMARK: Simple Computation (no I/O) ###")
args = {"a": 100, "b": 200}
exec_stats = self._benchmark_executor("bench_simple", args, "simple")
http_stats = self._benchmark_http("bench_simple", args, "simple")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: 100 + 200 = 300
request = self._make_request()
result = execute_function(request, "bench_simple", args)
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["value"], 300)
def test_benchmark_single_query(self):
"""Benchmark: Single ORM query."""
print("\n\n### BENCHMARK: Single ORM Query ###")
args = {"user_id": self.test_user.id if self.test_user else 1}
exec_stats = self._benchmark_executor("bench_get_user", args, "single query")
http_stats = self._benchmark_http("bench_get_user", args, "single query")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: should return the test user's data
request = self._make_request()
result = execute_function(request, "bench_get_user", args)
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["id"], self.test_user.id)
def test_benchmark_list_query(self):
"""Benchmark: List query with serialization."""
print("\n\n### BENCHMARK: List Query (10 users) ###")
args = {"limit": 10}
exec_stats = self._benchmark_executor("bench_list_users", args, "list 10")
http_stats = self._benchmark_http("bench_list_users", args, "list 10")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: should return up to 10 users
request = self._make_request()
result = execute_function(request, "bench_list_users", args)
self.assertIsInstance(result, FunctionResult)
self.assertLessEqual(result.data["count"], 10)
self.assertEqual(len(result.data["users"]), result.data["count"])
def test_benchmark_aggregation(self):
"""Benchmark: Aggregation queries."""
print("\n\n### BENCHMARK: Aggregation (3 COUNT queries) ###")
args = {}
exec_stats = self._benchmark_executor("bench_user_stats", args, "aggregation")
http_stats = self._benchmark_http("bench_user_stats", args, "aggregation")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: stats should have non-negative counts
request = self._make_request()
result = execute_function(request, "bench_user_stats", args)
self.assertIsInstance(result, FunctionResult)
self.assertGreaterEqual(result.data["total_users"], 0)
self.assertGreaterEqual(result.data["active_users"], 0)
self.assertGreaterEqual(result.data["staff_count"], 0)
def test_benchmark_search_query(self):
"""Benchmark: Search with filter."""
print("\n\n### BENCHMARK: Search Query (LIKE + LIMIT) ###")
args = {"email_contains": "bench", "limit": 20}
exec_stats = self._benchmark_executor("bench_user_search", args, "search")
http_stats = self._benchmark_http("bench_user_search", args, "search")
self._print_results([exec_stats, http_stats])
self._print_comparison(exec_stats, http_stats)
# Correctness check: search results should contain "bench" in emails
request = self._make_request()
result = execute_function(request, "bench_user_search", args)
self.assertIsInstance(result, FunctionResult)
self.assertLessEqual(result.data["count"], 20)
for user in result.data["users"]:
self.assertIn("bench", user["email"].lower())
def test_summary(self):
"""Print summary of all benchmarks."""
print("\n\n" + "=" * 80)
print("BENCHMARK SUMMARY")
print("=" * 80)
print(f"Iterations per benchmark: {self.ITERATIONS}")
print(f"Warmup iterations: {self.WARMUP}")
print("\nKey findings:")
print("- 'Executor' simulates WebSocket RPC (direct function call)")
print("- 'HTTP' measures full request/response cycle")
print("- HTTP overhead includes: JSON parsing, CSRF, view dispatch")
print("- For I/O-bound operations, protocol overhead is negligible")
print("=" * 80)
# Verify bench_simple still produces correct output after all benchmarks
request = self._make_request()
result = execute_function(request, "bench_simple", {"a": 7, "b": 8})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["value"], 15)
# =============================================================================
# Throughput Benchmark
# =============================================================================
class ThroughputBenchmark(TransactionTestCase):
"""
Measure requests per second (throughput) for server functions.
Tests both sequential and concurrent scenarios.
"""
DURATION_SECONDS = 2 # How long to run each throughput test
@classmethod
def setUpClass(cls):
super().setUpClass()
setup_benchmark_functions()
def setUp(self):
self.factory = RequestFactory()
self._create_test_users()
def _create_test_users(self):
"""Create test users for benchmarks."""
users = []
for i in range(100):
users.append(
User(
email=f"bench{i}@example.com",
is_active=i % 10 != 0,
is_staff=i < 5,
)
)
User.objects.bulk_create(users, ignore_conflicts=True)
self.test_user = User.objects.first()
def _make_request(self, body: dict) -> HttpRequest:
"""Create a POST request with JSON body."""
request = self.factory.post(
"/api/mizan/call/",
data=json.dumps(body),
content_type="application/json",
)
request.user = AnonymousUser()
request._dont_enforce_csrf_checks = True
return request
def _measure_throughput_executor(self, fn_name: str, args: dict) -> float:
"""Measure requests/second using direct executor calls."""
request = self._make_request({"fn": fn_name, "args": args})
# Warmup
for _ in range(10):
execute_function(request, fn_name, args)
# Measure
count = 0
start = time.perf_counter()
deadline = start + self.DURATION_SECONDS
while time.perf_counter() < deadline:
execute_function(request, fn_name, args)
count += 1
elapsed = time.perf_counter() - start
return count / elapsed
def _measure_throughput_http(self, fn_name: str, args: dict) -> float:
"""Measure requests/second using HTTP view calls."""
# Warmup
for _ in range(10):
request = self._make_request({"fn": fn_name, "args": args})
function_call_view(request)
# Measure
count = 0
start = time.perf_counter()
deadline = start + self.DURATION_SECONDS
while time.perf_counter() < deadline:
request = self._make_request({"fn": fn_name, "args": args})
function_call_view(request)
count += 1
elapsed = time.perf_counter() - start
return count / elapsed
def _print_throughput(self, label: str, executor_rps: float, http_rps: float):
"""Print throughput results."""
print(f"\n{label}:")
print(f" Executor (WebSocket): {executor_rps:,.0f} req/s")
print(f" HTTP: {http_rps:,.0f} req/s")
print(f" Ratio: {executor_rps/http_rps:.1f}x")
def test_throughput_simple(self):
"""Throughput: Simple computation (no I/O)."""
print("\n\n### THROUGHPUT: Simple Computation ###")
executor_rps = self._measure_throughput_executor(
"bench_simple", {"a": 1, "b": 2}
)
http_rps = self._measure_throughput_http("bench_simple", {"a": 1, "b": 2})
self._print_throughput("Simple (no I/O)", executor_rps, http_rps)
# Correctness check
request = self._make_request({"fn": "bench_simple", "args": {"a": 1, "b": 2}})
result = execute_function(request, "bench_simple", {"a": 1, "b": 2})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["value"], 3)
def test_throughput_single_query(self):
"""Throughput: Single ORM query."""
print("\n\n### THROUGHPUT: Single ORM Query ###")
args = {"user_id": self.test_user.id if self.test_user else 1}
executor_rps = self._measure_throughput_executor("bench_get_user", args)
http_rps = self._measure_throughput_http("bench_get_user", args)
self._print_throughput("Single Query", executor_rps, http_rps)
# Correctness check
request = self._make_request({"fn": "bench_get_user", "args": args})
result = execute_function(request, "bench_get_user", args)
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["id"], self.test_user.id)
def test_throughput_list_query(self):
"""Throughput: List query."""
print("\n\n### THROUGHPUT: List Query (10 users) ###")
executor_rps = self._measure_throughput_executor(
"bench_list_users", {"limit": 10}
)
http_rps = self._measure_throughput_http("bench_list_users", {"limit": 10})
self._print_throughput("List Query", executor_rps, http_rps)
# Correctness check
request = self._make_request({"fn": "bench_list_users", "args": {"limit": 10}})
result = execute_function(request, "bench_list_users", {"limit": 10})
self.assertIsInstance(result, FunctionResult)
self.assertLessEqual(result.data["count"], 10)
def test_throughput_aggregation(self):
"""Throughput: Aggregation queries."""
print("\n\n### THROUGHPUT: Aggregation ###")
executor_rps = self._measure_throughput_executor("bench_user_stats", {})
http_rps = self._measure_throughput_http("bench_user_stats", {})
self._print_throughput("Aggregation", executor_rps, http_rps)
# Correctness check
request = self._make_request({"fn": "bench_user_stats", "args": {}})
result = execute_function(request, "bench_user_stats", {})
self.assertIsInstance(result, FunctionResult)
self.assertGreaterEqual(result.data["total_users"], 0)
def test_throughput_summary(self):
"""Print throughput summary."""
print("\n\n" + "=" * 80)
print("THROUGHPUT SUMMARY")
print("=" * 80)
print(f"Test duration: {self.DURATION_SECONDS}s per scenario")
print("\nNotes:")
print("- These are single-threaded sequential measurements")
print("- Real throughput scales with worker processes (gunicorn -w N)")
print("- Database queries are the bottleneck, not protocol overhead")
print("- Async workers (uvicorn) can handle more concurrent connections")
print("=" * 80)
# Verify bench_simple still produces correct output after all throughput tests
request = self._make_request({"fn": "bench_simple", "args": {"a": 10, "b": 20}})
result = execute_function(request, "bench_simple", {"a": 10, "b": 20})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["value"], 30)

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,679 @@
"""
Stress tests for mizan.shapes — edge cases and deep nesting.
Models: Publisher → Author → Book → Chapter → Section (5 levels deep),
two FKs to same model, slug PK, UUID PK, self-referential FK, M2M,
nullable FKs, abstract bases, empty/zero/false values.
"""
import pytest
from typing import get_type_hints
from django.test import TestCase
from mizan.shapes import Shape, Diff, NestedDiff
import uuid
from tests.models import (
Publisher,
Author,
Book,
Chapter,
Section,
Tag,
Category,
)
# =============================================================================
# Shapes — varying projections
# =============================================================================
class TagShape(Shape[Tag]):
slug: str
label: str
class FlatAuthorShape(Shape[Author]):
id: int | None = None
name: str
class FlatBookShape(Shape[Book]):
id: int | None = None
title: str
is_published: bool
class BookCardShape(Shape[Book]):
id: int | None = None
title: str
isbn: str
page_count: int
is_published: bool
author: FlatAuthorShape # single nested, not list
class AuthorCardShape(Shape[Author]):
id: int | None = None
name: str
bio: str
books: list[FlatBookShape] = []
class SectionShape(Shape[Section]):
id: uuid.UUID | None = None
heading: str
body: str
position: int
class ChapterShape(Shape[Chapter]):
id: int | None = None
number: int
title: str
word_count: int
sections: list[SectionShape] = []
class BookDetailShape(Shape[Book]):
id: int | None = None
title: str
isbn: str
page_count: int
is_published: bool
author: FlatAuthorShape
chapters: list[ChapterShape] = []
tags: list[TagShape] = []
class AuthorDetailShape(Shape[Author]):
id: int | None = None
name: str
bio: str
books: list[BookDetailShape] = []
class PublisherDetailShape(Shape[Publisher]):
id: int | None = None
name: str
country: str
authors: list[AuthorDetailShape] = []
class BookWithEditorShape(Shape[Book]):
"""Two FKs to the same model (author + editor)."""
id: int | None = None
title: str
author: FlatAuthorShape
editor: FlatAuthorShape | None = None
class CategoryShape(Shape[Category]):
id: int | None = None
name: str
children: list["CategoryShape"] = []
# =============================================================================
# Shape class creation
# =============================================================================
class TestShapeClassCreation(TestCase):
def test_flat_shape_has_no_nested(self):
self.assertEqual(FlatAuthorShape._nested, {})
self.assertEqual(FlatAuthorShape._field_names, ["id", "name"])
def test_nested_shape_detected(self):
self.assertIn("books", AuthorCardShape._nested)
self.assertIs(AuthorCardShape._nested["books"], FlatBookShape)
def test_deep_nesting_spec_depth(self):
"""PublisherDetailShape → Author → Book → Chapter → Section."""
nested_keys = {
k for d in PublisherDetailShape._spec if isinstance(d, dict) for k in d
}
self.assertIn("authors", nested_keys)
author_spec = next(
d["authors"]
for d in PublisherDetailShape._spec
if isinstance(d, dict) and "authors" in d
)
author_nested = {k for d in author_spec if isinstance(d, dict) for k in d}
self.assertIn("books", author_nested)
def test_pk_field_resolution_integer(self):
self.assertEqual(FlatAuthorShape._pk_field, "id")
def test_pk_field_resolution_slug(self):
self.assertEqual(TagShape._pk_field, "slug")
def test_pk_field_resolution_uuid(self):
self.assertEqual(SectionShape._pk_field, "id")
def test_single_nested_not_list(self):
self.assertIn("author", BookCardShape._nested)
self.assertIs(BookCardShape._nested["author"], FlatAuthorShape)
def test_optional_nested(self):
"""BookWithEditorShape.editor is FlatAuthorShape | None.
_extract_shape_class needs to handle Optional/Union."""
# If this doesn't detect editor as nested, it's a known gap
if "editor" in BookWithEditorShape._nested:
self.assertIs(BookWithEditorShape._nested["editor"], FlatAuthorShape)
else:
self.skipTest(
"_extract_shape_class does not unwrap Optional[Shape] — known gap"
)
def test_self_referential_shape(self):
"""CategoryShape.children references itself."""
self.assertIn("children", CategoryShape._nested)
self.assertIs(CategoryShape._nested["children"], CategoryShape)
def test_multiple_shapes_same_model_independent(self):
self.assertLess(
len(FlatBookShape._field_names), len(BookDetailShape._field_names)
)
self.assertNotEqual(FlatBookShape._spec, BookDetailShape._spec)
# =============================================================================
# Queries
# =============================================================================
class TestShapeQuery(TestCase):
@classmethod
def setUpTestData(cls):
cls.publisher = Publisher.objects.create(name="Orbit", country="UK")
cls.mentor = Author.objects.create(
name="Ursula", bio="Legend", publisher=cls.publisher
)
cls.author = Author.objects.create(
name="Ann Leckie",
bio="Imperial Radch",
publisher=cls.publisher,
mentor=cls.mentor,
)
cls.editor = Author.objects.create(
name="Devi Pillai", bio="Editor", publisher=cls.publisher
)
cls.tag_sf = Tag.objects.create(slug="sci-fi", label="Science Fiction")
cls.tag_space = Tag.objects.create(slug="space-opera", label="Space Opera")
cls.book = Book.objects.create(
title="Ancillary Justice",
isbn="9780316246620",
page_count=386,
is_published=True,
author=cls.author,
editor=cls.editor,
)
cls.book.tags.add(cls.tag_sf, cls.tag_space)
cls.ch1 = Chapter.objects.create(
book=cls.book, number=1, title="The Body", word_count=5200
)
cls.ch2 = Chapter.objects.create(
book=cls.book, number=2, title="The Ship", word_count=4800
)
Section.objects.create(
chapter=cls.ch1, heading="Opening", body="...", position=0
)
Section.objects.create(
chapter=cls.ch1, heading="Discovery", body="...", position=1
)
cls.root_cat = Category.objects.create(name="Fiction")
cls.child_cat = Category.objects.create(name="Sci-Fi", parent=cls.root_cat)
Category.objects.create(name="Hard SF", parent=cls.child_cat)
# ── Flat ──
def test_flat_query_returns_minimal_fields(self):
results = FlatAuthorShape.query()
self.assertEqual(len(results), 3)
for r in results:
self.assertTrue(hasattr(r, "name"))
self.assertTrue(hasattr(r, "id"))
def test_flat_query_with_lambda_filter(self):
results = FlatAuthorShape.query(lambda qs: qs.filter(name="Ann Leckie"))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].name, "Ann Leckie")
def test_flat_query_with_raw_queryset(self):
qs = Author.objects.filter(mentor__isnull=False)
results = FlatAuthorShape.query(qs)
self.assertEqual(len(results), 1)
self.assertEqual(results[0].name, "Ann Leckie")
# ── Nested ──
def test_single_nested_fk(self):
results = BookCardShape.query(lambda qs: qs.filter(pk=self.book.pk))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].author.name, "Ann Leckie")
def test_list_nested_reverse_fk(self):
results = AuthorCardShape.query(lambda qs: qs.filter(pk=self.author.pk))
self.assertEqual(len(results), 1)
self.assertEqual(len(results[0].books), 1)
self.assertEqual(results[0].books[0].title, "Ancillary Justice")
def test_deep_nesting_book_chapters_sections(self):
results = BookDetailShape.query(lambda qs: qs.filter(pk=self.book.pk))
self.assertEqual(len(results), 1)
book = results[0]
self.assertEqual(len(book.chapters), 2)
ch1 = next(c for c in book.chapters if c.number == 1)
self.assertEqual(len(ch1.sections), 2)
def test_full_depth_publisher_to_section(self):
"""5 levels: Publisher → Author → Book → Chapter → Section."""
results = PublisherDetailShape.query(lambda qs: qs.filter(pk=self.publisher.pk))
self.assertEqual(len(results), 1)
pub = results[0]
self.assertEqual(len(pub.authors), 3)
leckie = next(a for a in pub.authors if a.name == "Ann Leckie")
self.assertEqual(len(leckie.books), 1)
self.assertEqual(len(leckie.books[0].chapters), 2)
def test_two_fks_to_same_model(self):
results = BookWithEditorShape.query(lambda qs: qs.filter(pk=self.book.pk))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].author.name, "Ann Leckie")
if "editor" in BookWithEditorShape._nested:
self.assertIsNotNone(results[0].editor)
self.assertEqual(results[0].editor.name, "Devi Pillai")
def test_nullable_fk_returns_none(self):
book_no_editor = Book.objects.create(
title="Provenance",
isbn="9780316246699",
page_count=448,
is_published=True,
author=self.author,
editor=None,
)
results = BookWithEditorShape.query(lambda qs: qs.filter(pk=book_no_editor.pk))
self.assertEqual(len(results), 1)
if "editor" in BookWithEditorShape._nested:
self.assertIsNone(results[0].editor)
def test_m2m_tags(self):
results = BookDetailShape.query(lambda qs: qs.filter(pk=self.book.pk))
book = results[0]
self.assertEqual(len(book.tags), 2)
slugs = {t.slug for t in book.tags}
self.assertEqual(slugs, {"sci-fi", "space-opera"})
def test_slug_pk_shape(self):
results = TagShape.query()
self.assertEqual(len(results), 2)
self.assertTrue(all(isinstance(r.slug, str) for r in results))
def test_relation_qs_filters_nested(self):
results = AuthorCardShape.query(
lambda qs: qs.filter(pk=self.author.pk),
books=lambda qs: qs.filter(is_published=True),
)
self.assertEqual(len(results), 1)
self.assertTrue(all(b.is_published for b in results[0].books))
def test_empty_nested_list(self):
results = AuthorCardShape.query(lambda qs: qs.filter(pk=self.editor.pk))
self.assertEqual(len(results), 1)
self.assertEqual(results[0].books, [])
# ── Query efficiency ──
def test_flat_query_is_single_query(self):
with self.assertNumQueries(1):
FlatAuthorShape.query()
def test_nested_query_uses_prefetch(self):
with self.assertNumQueries(2):
AuthorCardShape.query()
# =============================================================================
# Diff
# =============================================================================
class TestDiff(TestCase):
@classmethod
def setUpTestData(cls):
cls.publisher = Publisher.objects.create(name="Tor", country="US")
cls.author = Author.objects.create(
name="Brandon Sanderson", bio="Cosmere", publisher=cls.publisher
)
cls.book = Book.objects.create(
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=cls.author,
)
cls.ch1 = Chapter.objects.create(
book=cls.book, number=1, title="Ash", word_count=6000
)
cls.ch2 = Chapter.objects.create(
book=cls.book, number=2, title="Mist", word_count=5500
)
# ── Single item ──
def test_diff_no_changes(self):
shape = BookCardShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
)
d = shape.diff()
self.assertFalse(d.is_new)
self.assertEqual(d.changed, {})
def test_diff_detects_field_change(self):
shape = BookCardShape(
id=self.book.pk,
title="Mistborn: The Final Empire",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
)
d = shape.diff()
self.assertIn("title", d.changed)
self.assertEqual(d.changed["title"], "Mistborn: The Final Empire")
def test_diff_multiple_field_changes(self):
shape = BookCardShape(
id=self.book.pk,
title="Mistborn: TFE",
isbn="9780765311788",
page_count=600,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
)
d = shape.diff()
self.assertIn("title", d.changed)
self.assertIn("page_count", d.changed)
self.assertNotIn("isbn", d.changed)
def test_diff_new_item(self):
shape = FlatBookShape(id=None, title="Elantris", is_published=True)
d = shape.diff()
self.assertTrue(d.is_new)
self.assertIn("title", d.changed)
def test_diff_nonexistent_pk_raises(self):
shape = FlatBookShape(id=999999, title="Nope", is_published=False)
with self.assertRaises(Book.DoesNotExist):
shape.diff()
# ── Nested ──
def test_nested_diff_detects_updated_chapter(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[
ChapterShape(
id=self.ch1.pk,
number=1,
title="Ash Falls",
word_count=6000,
sections=[],
),
ChapterShape(
id=self.ch2.pk, number=2, title="Mist", word_count=5500, sections=[]
),
],
tags=[],
)
d = shape.diff()
self.assertEqual(len(d.chapters.updated), 1)
self.assertEqual(d.chapters.updated[0].title, "Ash Falls")
def test_nested_diff_detects_created(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[
ChapterShape(
id=self.ch1.pk, number=1, title="Ash", word_count=6000, sections=[]
),
ChapterShape(
id=self.ch2.pk, number=2, title="Mist", word_count=5500, sections=[]
),
ChapterShape(
id=None, number=3, title="New Chapter", word_count=0, sections=[]
),
],
tags=[],
)
d = shape.diff()
self.assertEqual(len(d.chapters.created), 1)
def test_nested_diff_detects_deleted(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[
ChapterShape(
id=self.ch1.pk, number=1, title="Ash", word_count=6000, sections=[]
),
],
tags=[],
)
d = shape.diff()
self.assertIn(self.ch2.pk, d.chapters.deleted)
def test_nested_diff_combined_operations(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[
ChapterShape(
id=self.ch1.pk,
number=1,
title="Ash Rewritten",
word_count=7000,
sections=[],
),
ChapterShape(
id=None, number=3, title="Epilogue", word_count=2000, sections=[]
),
],
tags=[],
)
d = shape.diff()
self.assertEqual(len(d.chapters.updated), 1)
self.assertEqual(len(d.chapters.deleted), 1)
self.assertEqual(len(d.chapters.created), 1)
# ── Strict Diff access ──
def test_diff_strict_getattr_raises_on_typo(self):
shape = FlatBookShape(id=self.book.pk, title="Mistborn", is_published=True)
d = shape.diff()
with self.assertRaises(AttributeError):
_ = d.chapterz
def test_diff_strict_nested_raises_on_typo(self):
shape = FlatBookShape(id=self.book.pk, title="Mistborn", is_published=True)
d = shape.diff()
with self.assertRaises(KeyError):
d.nested("chapterz")
def test_diff_strict_shows_valid_names(self):
shape = BookDetailShape(
id=self.book.pk,
title="Mistborn",
isbn="9780765311788",
page_count=541,
is_published=True,
author=FlatAuthorShape(id=self.author.pk, name="Brandon Sanderson"),
chapters=[],
tags=[],
)
d = shape.diff()
with self.assertRaises(AttributeError) as ctx:
_ = d.bogus
self.assertIn("chapters", str(ctx.exception))
# ── diff_many ──
def test_diff_many_single_query_for_existing(self):
items = [FlatBookShape(id=self.book.pk, title="Renamed", is_published=True)]
results = FlatBookShape.diff_many(items)
self.assertEqual(len(results), 1)
_, d = results[0]
self.assertIn("title", d.changed)
def test_diff_many_mixed_new_and_existing(self):
items = [
FlatBookShape(id=self.book.pk, title="Mistborn", is_published=True),
FlatBookShape(id=None, title="New Book", is_published=False),
]
results = FlatBookShape.diff_many(items)
new = [d for _, d in results if d.is_new]
existing = [d for _, d in results if not d.is_new]
self.assertEqual(len(new), 1)
self.assertEqual(len(existing), 1)
def test_diff_many_nonexistent_raises(self):
items = [FlatBookShape(id=999999, title="Ghost", is_published=False)]
with self.assertRaises(Book.DoesNotExist):
FlatBookShape.diff_many(items)
def test_diff_many_batched_query(self):
book2 = Book.objects.create(
title="Warbreaker",
isbn="9780765320308",
page_count=592,
is_published=True,
author=self.author,
)
items = [
FlatBookShape(id=self.book.pk, title="Mistborn", is_published=True),
FlatBookShape(id=book2.pk, title="Warbreaker Updated", is_published=True),
]
with self.assertNumQueries(1):
FlatBookShape.diff_many(items)
def test_diff_many_empty(self):
self.assertEqual(FlatBookShape.diff_many([]), [])
# =============================================================================
# Edge cases
# =============================================================================
class TestEdgeCases(TestCase):
@classmethod
def setUpTestData(cls):
cls.publisher = Publisher.objects.create(name="Edge Cases Ltd", country="XX")
cls.author = Author.objects.create(
name="Edge Author", bio="", publisher=cls.publisher
)
def test_empty_table_returns_empty_list(self):
Tag.objects.all().delete()
results = TagShape.query()
self.assertEqual(results, [])
def test_empty_string_fields(self):
results = AuthorCardShape.query(lambda qs: qs.filter(pk=self.author.pk))
self.assertEqual(results[0].bio, "")
def test_boolean_false_is_not_missing(self):
book = Book.objects.create(
title="Unpublished",
isbn="0000000000000",
page_count=0,
is_published=False,
author=self.author,
)
results = FlatBookShape.query(lambda qs: qs.filter(pk=book.pk))
self.assertIs(results[0].is_published, False)
def test_zero_integer_is_not_missing(self):
book = Book.objects.create(
title="Empty",
isbn="0000000000001",
page_count=0,
is_published=False,
author=self.author,
)
results = BookCardShape.query(lambda qs: qs.filter(pk=book.pk))
self.assertEqual(results[0].page_count, 0)
def test_large_queryset(self):
books = [
Book(
title=f"Book {i}",
isbn=f"{i:013d}",
page_count=i * 10,
is_published=i % 2 == 0,
author=self.author,
)
for i in range(100)
]
Book.objects.bulk_create(books)
results = FlatBookShape.query(lambda qs: qs.filter(author=self.author))
self.assertGreaterEqual(len(results), 100)
def test_diff_on_boolean_change(self):
book = Book.objects.create(
title="Toggle",
isbn="1111111111111",
page_count=100,
is_published=False,
author=self.author,
)
shape = FlatBookShape(id=book.pk, title="Toggle", is_published=True)
d = shape.diff()
self.assertIn("is_published", d.changed)
self.assertIs(d.changed["is_published"], True)
def test_diff_unchanged_returns_empty(self):
book = Book.objects.create(
title="Same",
isbn="2222222222222",
page_count=200,
is_published=True,
author=self.author,
)
shape = FlatBookShape(id=book.pk, title="Same", is_published=True)
d = shape.diff()
self.assertEqual(d.changed, {})
self.assertFalse(d.is_new)

View File

@@ -0,0 +1,42 @@
"""
mizan URL Configuration
HTTP endpoints:
- GET /session/ - Initialize session and get CSRF token (for SSR)
- POST /call/ - Server function calls (HTTP transport)
- GET /ctx/<name>/ - Bundled context fetch (all functions in a named context)
Security:
- Schema export is NOT exposed over HTTP to prevent API enumeration
- Use the management command instead: python manage.py export_mizan_schema
"""
from django.http import JsonResponse
from django.middleware.csrf import get_token
from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie
from .client.executor import function_call_view, context_fetch_view
app_name = "mizan"
@ensure_csrf_cookie
def session_init_view(request):
"""
Initialize a Django session and return the CSRF token.
Used by SSR to establish a session before making authenticated requests.
The @ensure_csrf_cookie decorator ensures the csrftoken cookie is set.
Returns:
{ "csrfToken": "..." }
"""
return JsonResponse({"csrfToken": get_token(request)})
urlpatterns = [
path("session/", session_init_view, name="session-init"),
path("call/", function_call_view, name="function-call"),
path("ctx/<str:context_name>/", context_fetch_view, name="context-fetch"),
]