The @client(merge=[context, ...]) decorator lets a mutation patch its
return value directly into the cached context bundle by matching the
mutation's Output type against each context-function's Output type
to identify the slot server-side. The kernel then runs splice_slot on
the response to apply the value locally — no refetch, no
invalidation cascade.
Lands H14, H15, H16, M19, M20 from ISSUES.md.
Backends (Django + FastAPI):
_resolve_merges() in both executors walks @client(merge=...) targets,
resolves the per-context slot via types_match_for_merge, and emits
{context, slot, value, params?} entries on the response. Param
auto-scoping mirrors _resolve_invalidation's tier-1 logic.
Frontend kernel (mizan-base):
Response handler reads the merge[] array and applies splice_slot
for each entry — locates the cached context bundle by name+params,
overwrites the named slot with the new value, notifies subscribers.
Core (mizan-python):
@client decorator extended with merge= parameter. Schema export
threads merge metadata onto the OpenAPI x-mizan-functions entries.
Examples / fixtures:
fastapi-react-site harness exercises merge + Playwright spec covers
the end-to-end happy path (mutation → instant UI update without
network refetch). AFI fixture's rename_user function is the
canonical merge target.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
264 lines
10 KiB
Python
264 lines
10 KiB
Python
"""
|
|
RPC dispatch — looks up registered functions, validates input against the
|
|
function's Pydantic Input model, executes, and returns the serialized result.
|
|
|
|
Errors raise typed exceptions (MizanError subclasses). Wire those to JSON
|
|
responses by registering `mizan_exception_handler` on the FastAPI app, or
|
|
let them propagate to your own handler.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
from enum import Enum
|
|
from typing import Any
|
|
|
|
from fastapi.encoders import jsonable_encoder
|
|
from pydantic import BaseModel, ValidationError
|
|
|
|
from mizan_core.registry import get_context_groups, get_function
|
|
from mizan_core.type_utils import types_match_for_merge
|
|
|
|
|
|
# ─── Error taxonomy ─────────────────────────────────────────────────────────
|
|
|
|
|
|
class ErrorCode(str, Enum):
|
|
NOT_FOUND = "NOT_FOUND"
|
|
BAD_REQUEST = "BAD_REQUEST"
|
|
VALIDATION_ERROR = "VALIDATION_ERROR"
|
|
UNAUTHORIZED = "UNAUTHORIZED"
|
|
FORBIDDEN = "FORBIDDEN"
|
|
NOT_IMPLEMENTED = "NOT_IMPLEMENTED"
|
|
INTERNAL_ERROR = "INTERNAL_ERROR"
|
|
|
|
|
|
_STATUS = {
|
|
ErrorCode.NOT_FOUND: 404,
|
|
ErrorCode.BAD_REQUEST: 400,
|
|
ErrorCode.VALIDATION_ERROR: 422,
|
|
ErrorCode.UNAUTHORIZED: 401,
|
|
ErrorCode.FORBIDDEN: 403,
|
|
ErrorCode.NOT_IMPLEMENTED: 501,
|
|
ErrorCode.INTERNAL_ERROR: 500,
|
|
}
|
|
|
|
|
|
class MizanError(Exception):
    """Base for protocol-level dispatch errors.

    Each subclass pins ``code``; ``status_code`` looks that code up in
    ``_STATUS``. ``message`` repeats the exception text, and ``details`` holds
    optional structured data (e.g. validation errors) or None.
    """

    code: ErrorCode = ErrorCode.INTERNAL_ERROR

    def __init__(self, message: str, *, details: dict[str, Any] | None = None) -> None:
        self.message = message
        self.details = details
        super().__init__(message)

    @property
    def status_code(self) -> int:
        """HTTP status for this error, looked up from its ``code``."""
        status = _STATUS[self.code]
        return status


class NotFound(MizanError):
    """Error carrying ErrorCode.NOT_FOUND (HTTP 404)."""

    code = ErrorCode.NOT_FOUND


class BadRequest(MizanError):
    """Error carrying ErrorCode.BAD_REQUEST (HTTP 400)."""

    code = ErrorCode.BAD_REQUEST


class ValidationFailed(MizanError):
    """Error carrying ErrorCode.VALIDATION_ERROR (HTTP 422)."""

    code = ErrorCode.VALIDATION_ERROR


class Unauthorized(MizanError):
    """Error carrying ErrorCode.UNAUTHORIZED (HTTP 401)."""

    code = ErrorCode.UNAUTHORIZED


class Forbidden(MizanError):
    """Error carrying ErrorCode.FORBIDDEN (HTTP 403)."""

    code = ErrorCode.FORBIDDEN


class NotImplementedYet(MizanError):
    """Error carrying ErrorCode.NOT_IMPLEMENTED (HTTP 501)."""

    code = ErrorCode.NOT_IMPLEMENTED


class InternalError(MizanError):
    """Error carrying ErrorCode.INTERNAL_ERROR (HTTP 500)."""

    code = ErrorCode.INTERNAL_ERROR
|
|
|
|
|
|
# ─── Auth ───────────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _user(request: Any) -> Any:
|
|
return getattr(getattr(request, "state", None), "user", None)
|
|
|
|
|
|
def _is_authenticated(user: Any) -> bool:
|
|
return bool(user) and getattr(user, "is_authenticated", True)
|
|
|
|
|
|
def _enforce_auth(request: Any, requirement: Any) -> None:
|
|
"""Verify the request meets the function's @client(auth=...) requirement, or raise."""
|
|
if requirement is None:
|
|
return
|
|
|
|
user = _user(request)
|
|
|
|
match requirement:
|
|
case True | "required":
|
|
if not _is_authenticated(user):
|
|
raise Unauthorized("Authentication required")
|
|
case "staff":
|
|
if not _is_authenticated(user):
|
|
raise Unauthorized("Authentication required")
|
|
if not getattr(user, "is_staff", False):
|
|
raise Forbidden("Staff access required")
|
|
case "superuser":
|
|
if not _is_authenticated(user):
|
|
raise Unauthorized("Authentication required")
|
|
if not getattr(user, "is_superuser", False):
|
|
raise Forbidden("Superuser access required")
|
|
case f if callable(f):
|
|
if not f(request):
|
|
raise Forbidden("Permission denied")
|
|
case other:
|
|
raise InternalError(f"Unknown auth requirement: {other!r}")
|
|
|
|
|
|
# ─── Input validation ───────────────────────────────────────────────────────
|
|
|
|
|
|
def _validate_input(input_cls: Any, input_data: Any) -> BaseModel | None:
    """Validate input_data against the function's Input model. Returns the instance or None.

    Returns None when the function declares no usable Input model (``None``,
    bare ``BaseModel``, or a class whose ``model_fields`` is empty/missing).
    Otherwise:

    - ``None`` is treated as an empty object.
    - A non-dict input raises BadRequest. This includes falsy values such as
      ``[]``, ``""`` and ``0``.
    - An empty input with required fields raises ValidationFailed with
      ``details={"fields": {name: ["Field required"], ...}}``.
    - Any pydantic rejection while constructing the model raises
      ValidationFailed with ``details={"errors": e.errors()}``.
    """
    if input_cls in (None, BaseModel) or not getattr(input_cls, "model_fields", None):
        return None

    if input_data is None:
        input_data = {}

    # Type-check before the emptiness check so falsy non-objects are rejected
    # instead of silently meaning "no input".
    if not isinstance(input_data, dict):
        raise BadRequest(f"Input must be an object, got {type(input_data).__name__}")

    if not input_data:
        # Only computed when needed: report every missing required field at once.
        required = [name for name, f in input_cls.model_fields.items() if f.is_required()]
        if required:
            raise ValidationFailed(
                "Input validation failed",
                details={"fields": {name: ["Field required"] for name in required}},
            )

    try:
        # Construction runs inside the try even for empty input: model-level
        # validators can still reject a defaults-only instance, and that must
        # surface as ValidationFailed, not as a raw pydantic error.
        return input_cls(**input_data)
    except ValidationError as e:
        raise ValidationFailed(
            "Input validation failed",
            details={"errors": e.errors()},
        ) from e
|
|
|
|
|
|
# ─── Dispatch ───────────────────────────────────────────────────────────────
|
|
|
|
|
|
def _resolve_function(fn_name: str) -> Any:
    """Return the registered function class for ``fn_name``.

    Raises NotFound when nothing is registered under that name, and Forbidden
    when the class's ``_meta`` marks it ``private``.
    """
    view_class = get_function(fn_name)
    if view_class is None:
        raise NotFound("Function not found")

    meta = getattr(view_class, "_meta", {})
    if meta.get("private"):
        raise Forbidden("Function is not client-callable")

    return view_class
|
|
|
|
|
|
def _serialize(result: Any) -> Any:
    """Convert a handler's return value into JSON-compatible data.

    ``jsonable_encoder`` recurses through models, lists and dicts, so nested
    shapes such as ``list[BaseModel]`` need no per-shape branch here.
    """
    encoded = jsonable_encoder(result)
    return encoded
|
|
|
|
|
|
async def execute_function(
    request: Any,
    fn_name: str,
    input_data: dict[str, Any] | None = None,
) -> Any:
    """Dispatch a registered function. Returns the serialized result, or raises MizanError.

    The dispatch runs these steps in order:

    1. Resolve ``fn_name`` (NotFound / Forbidden).
    2. Enforce the ``@client(auth=...)`` requirement.
    3. Validate ``input_data`` against the view's Input model.
    4. Await ``view.acall``.

    Exceptions raised by the handler are mapped as follows:

    - ``NotImplementedError`` becomes NotImplementedYet.
    - MizanError passes through unchanged.
    - Any other ``Exception`` becomes InternalError.

    A failure while serializing the result is also reported as InternalError.

    Awaits `view.acall` — async handlers run on the loop, sync handlers run
    in the default threadpool, both via the same entrypoint.
    """
    view_class = _resolve_function(fn_name)
    # Read _meta defensively, as _resolve_function does: a registered class
    # without _meta has no auth requirement instead of crashing with AttributeError.
    meta = getattr(view_class, "_meta", {})
    _enforce_auth(request, meta.get("auth"))

    view = view_class(request)
    validated = _validate_input(view.Input, input_data)

    try:
        result = await view.acall(validated)
    except NotImplementedError as e:
        raise NotImplementedYet(str(e) or "Not implemented") from e
    except MizanError:
        raise
    except Exception as e:
        raise InternalError(str(e)) from e

    # Serialization can fail on values jsonable_encoder cannot handle; keep the
    # "raises MizanError" contract instead of leaking an arbitrary exception.
    try:
        return _serialize(result)
    except Exception as e:
        raise InternalError(f"Failed to serialize result: {e}") from e
|
|
|
|
|
|
# ─── Invalidation ───────────────────────────────────────────────────────────
|
|
|
|
|
|
def compute_invalidation(view_class: Any, input_data: dict[str, Any] | None) -> list[Any]:
    """Build the `invalidate` list from a function's @client(affects=...) metadata.

    Each ``affects`` entry is translated by ``_invalidation_target``. Context
    entries are auto-scoped with any input args whose names match the
    context's declared Input fields. A class with no ``_meta`` or no
    ``affects`` yields an empty list.
    """
    affected = getattr(view_class, "_meta", {}).get("affects")
    if not affected:
        return []
    args = input_data or {}
    return [_invalidation_target(entry, args) for entry in affected]
|
|
|
|
|
|
def compute_merges(view_class: Any, input_data: dict[str, Any] | None, result: Any) -> list[dict[str, Any]]:
    """Build the `merge` list from @client(merge=...) metadata.

    Each entry is `{context, slot, value, params?}` where `slot` names the
    function inside the context bundle the value lands in. The slot is
    resolved server-side via `types_match_for_merge` so the kernel does
    no shape inference — the server has the schema, type-checked routing
    lives here.

    An entry is dropped, and a warning is logged, when its slot can't be
    uniquely resolved. That covers three cases: no matching function, several
    matching functions, or no ``Output`` on the mutation. The consumer then
    falls back to refetch via `affects`.

    ``params`` is included only when some ``input_data`` keys match Input
    field names declared by the target context's functions.
    """
    targets = getattr(view_class, "_meta", {}).get("merge") or []
    if not targets:
        return []

    import logging  # local import: only needed once merge targets exist

    logger = logging.getLogger(__name__)
    mutation_output = getattr(view_class, "Output", None)
    args = input_data or {}
    out: list[dict[str, Any]] = []
    for ctx_name in targets:
        slot = _resolve_merge_slot(ctx_name, mutation_output)
        if slot is None:
            # Without this, a misconfigured merge target silently degrades to a
            # refetch and is hard to notice.
            logger.warning(
                "Dropping merge into context %r for %s: no unique slot matches output type %r",
                ctx_name,
                getattr(view_class, "__name__", view_class),
                mutation_output,
            )
            continue
        entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result}
        scoped = _scoped_params(ctx_name, args)
        if scoped:
            entry["params"] = scoped
        out.append(entry)
    return out
|
|
|
|
|
|
def _resolve_merge_slot(context_name: str, mutation_output: Any) -> str | None:
|
|
"""Find the unique function-name slot whose return type matches the mutation's output.
|
|
|
|
Returns None on no match or ambiguous match (multiple candidates).
|
|
"""
|
|
if mutation_output is None:
|
|
return None
|
|
matches: list[str] = []
|
|
for fn_name in get_context_groups().get(context_name, []):
|
|
fn_cls = get_function(fn_name)
|
|
if fn_cls is None:
|
|
continue
|
|
fn_output = getattr(fn_cls, "Output", None)
|
|
if fn_output is not None and types_match_for_merge(fn_output, mutation_output):
|
|
matches.append(fn_name)
|
|
return matches[0] if len(matches) == 1 else None
|
|
|
|
|
|
def _scoped_params(context_name: str, input_data: dict[str, Any]) -> dict[str, Any]:
    """Keep only the input args whose names are declared Input fields of the context.

    The declared names are the union of ``model_fields`` keys across every
    function registered under ``context_name`` whose Input is a concrete
    model class (not missing and not bare ``BaseModel``).
    """
    declared_names: set[str] = set()
    for fn_cls in map(get_function, get_context_groups().get(context_name, [])):
        if fn_cls is None:
            continue
        model = getattr(fn_cls, "Input", None)
        if not model or model is BaseModel or not hasattr(model, "model_fields"):
            continue
        declared_names |= set(model.model_fields)
    return {name: value for name, value in input_data.items() if name in declared_names}
|
|
|
|
|
|
def _invalidation_target(target: dict[str, Any], input_data: dict[str, Any]) -> Any:
|
|
match target.get("type"):
|
|
case "context":
|
|
name = target["name"]
|
|
scoped = _scoped_params(name, input_data)
|
|
return {"context": name, "params": scoped} if scoped else name
|
|
case "function":
|
|
return {"function": target["name"]}
|
|
case _:
|
|
return target
|