""" RPC dispatch — looks up registered functions, validates input against the function's Pydantic Input model, executes, and returns the serialized result. Errors raise typed exceptions (MizanError subclasses). Wire those to JSON responses by registering `mizan_exception_handler` on the FastAPI app, or let them propagate to your own handler. """ from __future__ import annotations from enum import Enum from typing import Any from fastapi.encoders import jsonable_encoder from pydantic import BaseModel, ValidationError from mizan_core.registry import get_context_groups, get_function from mizan_core.type_utils import types_match_for_merge # ─── Error taxonomy ───────────────────────────────────────────────────────── class ErrorCode(str, Enum): NOT_FOUND = "NOT_FOUND" BAD_REQUEST = "BAD_REQUEST" VALIDATION_ERROR = "VALIDATION_ERROR" UNAUTHORIZED = "UNAUTHORIZED" FORBIDDEN = "FORBIDDEN" NOT_IMPLEMENTED = "NOT_IMPLEMENTED" INTERNAL_ERROR = "INTERNAL_ERROR" _STATUS = { ErrorCode.NOT_FOUND: 404, ErrorCode.BAD_REQUEST: 400, ErrorCode.VALIDATION_ERROR: 422, ErrorCode.UNAUTHORIZED: 401, ErrorCode.FORBIDDEN: 403, ErrorCode.NOT_IMPLEMENTED: 501, ErrorCode.INTERNAL_ERROR: 500, } class MizanError(Exception): """Base for protocol-level dispatch errors.""" code: ErrorCode = ErrorCode.INTERNAL_ERROR def __init__(self, message: str, *, details: dict[str, Any] | None = None) -> None: super().__init__(message) self.message = message self.details = details @property def status_code(self) -> int: return _STATUS[self.code] class NotFound(MizanError): code = ErrorCode.NOT_FOUND # noqa: E701 class BadRequest(MizanError): code = ErrorCode.BAD_REQUEST # noqa: E701 class ValidationFailed(MizanError): code = ErrorCode.VALIDATION_ERROR # noqa: E701 class Unauthorized(MizanError): code = ErrorCode.UNAUTHORIZED # noqa: E701 class Forbidden(MizanError): code = ErrorCode.FORBIDDEN # noqa: E701 class NotImplementedYet(MizanError): code = ErrorCode.NOT_IMPLEMENTED # noqa: E701 class InternalError(MizanError): code = ErrorCode.INTERNAL_ERROR # noqa: E701 # ─── Auth ─────────────────────────────────────────────────────────────────── def _user(request: Any) -> Any: return getattr(getattr(request, "state", None), "user", None) def _is_authenticated(user: Any) -> bool: return bool(user) and getattr(user, "is_authenticated", True) def _enforce_auth(request: Any, requirement: Any) -> None: """Verify the request meets the function's @client(auth=...) requirement, or raise.""" if requirement is None: return user = _user(request) match requirement: case True | "required": if not _is_authenticated(user): raise Unauthorized("Authentication required") case "staff": if not _is_authenticated(user): raise Unauthorized("Authentication required") if not getattr(user, "is_staff", False): raise Forbidden("Staff access required") case "superuser": if not _is_authenticated(user): raise Unauthorized("Authentication required") if not getattr(user, "is_superuser", False): raise Forbidden("Superuser access required") case f if callable(f): if not f(request): raise Forbidden("Permission denied") case other: raise InternalError(f"Unknown auth requirement: {other!r}") # ─── Input validation ─────────────────────────────────────────────────────── def _validate_input(input_cls: Any, input_data: Any) -> BaseModel | None: """Validate input_data against the function's Input model. Returns the instance or None.""" if input_cls in (None, BaseModel) or not getattr(input_cls, "model_fields", None): return None fields = input_cls.model_fields required = [name for name, f in fields.items() if f.is_required()] if not input_data: if required: raise ValidationFailed( "Input validation failed", details={"fields": {name: ["Field required"] for name in required}}, ) return input_cls() if not isinstance(input_data, dict): raise BadRequest(f"Input must be an object, got {type(input_data).__name__}") try: return input_cls(**input_data) except ValidationError as e: raise ValidationFailed( "Input validation failed", details={"errors": e.errors()}, ) from e # ─── Dispatch ─────────────────────────────────────────────────────────────── def _resolve_function(fn_name: str) -> Any: view_class = get_function(fn_name) if view_class is None: raise NotFound("Function not found") if getattr(view_class, "_meta", {}).get("private"): raise Forbidden("Function is not client-callable") return view_class def _serialize(result: Any) -> Any: # jsonable_encoder walks BaseModel / list / dict recursively, so list[BaseModel] # (and nested shapes) come out wire-ready without a per-shape branch here. return jsonable_encoder(result) async def execute_function( request: Any, fn_name: str, input_data: dict[str, Any] | None = None, ) -> Any: """Dispatch a registered function. Returns the serialized result, or raises MizanError. Awaits `view.acall` — async handlers run on the loop, sync handlers run in the default threadpool, both via the same entrypoint. """ view_class = _resolve_function(fn_name) _enforce_auth(request, view_class._meta.get("auth")) view = view_class(request) validated = _validate_input(view.Input, input_data) try: result = await view.acall(validated) except NotImplementedError as e: raise NotImplementedYet(str(e) or "Not implemented") from e except MizanError: raise except Exception as e: raise InternalError(str(e)) from e return _serialize(result) # ─── Invalidation ─────────────────────────────────────────────────────────── def compute_invalidation(view_class: Any, input_data: dict[str, Any] | None) -> list[Any]: """Build the `invalidate` list from @client(affects=...) metadata, auto-scoping when arg names match context params.""" affects = getattr(view_class, "_meta", {}).get("affects") or [] return [_invalidation_target(target, input_data or {}) for target in affects] def compute_merges(view_class: Any, input_data: dict[str, Any] | None, result: Any) -> list[dict[str, Any]]: """Build the `merge` list from @client(merge=...) metadata. Each entry is `{context, slot, value, params?}` where `slot` names the function inside the context bundle the value lands in. The slot is resolved server-side via `types_match_for_merge` so the kernel does no shape inference — the server has the schema, type-checked routing lives here. Entries whose slot can't be uniquely resolved are dropped with a warning; the consumer falls back to refetch via `affects`. """ targets = getattr(view_class, "_meta", {}).get("merge") or [] if not targets: return [] mutation_output = getattr(view_class, "Output", None) out: list[dict[str, Any]] = [] for ctx_name in targets: slot = _resolve_merge_slot(ctx_name, mutation_output) if slot is None: continue entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result} scoped = _scoped_params(ctx_name, input_data or {}) if scoped: entry["params"] = scoped out.append(entry) return out def _resolve_merge_slot(context_name: str, mutation_output: Any) -> str | None: """Find the unique function-name slot whose return type matches the mutation's output. Returns None on no match or ambiguous match (multiple candidates). """ if mutation_output is None: return None matches: list[str] = [] for fn_name in get_context_groups().get(context_name, []): fn_cls = get_function(fn_name) if fn_cls is None: continue fn_output = getattr(fn_cls, "Output", None) if fn_output is not None and types_match_for_merge(fn_output, mutation_output): matches.append(fn_name) return matches[0] if len(matches) == 1 else None def _scoped_params(context_name: str, input_data: dict[str, Any]) -> dict[str, Any]: """Match input args against the context's declared Input field names.""" fn_names = get_context_groups().get(context_name, []) declared: set[str] = set() for fn_name in fn_names: fn_cls = get_function(fn_name) if fn_cls is None: continue input_cls = getattr(fn_cls, "Input", None) if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"): declared.update(input_cls.model_fields.keys()) return {k: v for k, v in input_data.items() if k in declared} def _invalidation_target(target: dict[str, Any], input_data: dict[str, Any]) -> Any: match target.get("type"): case "context": name = target["name"] scoped = _scoped_params(name, input_data) return {"context": name, "params": scoped} if scoped else name case "function": return {"function": target["name"]} case _: return target