mizan-fastapi: more Pythonic and declarative

Reworked the MVP code along the lines Ryth flagged. Same behavior
(11/11 tests still pass), tighter idiom.

executor.py:
- Replaced FunctionResult / FunctionError dataclasses with a MizanError
  exception hierarchy (NotFound, BadRequest, ValidationFailed,
  Unauthorized, Forbidden, NotImplementedYet, InternalError). Each
  carries its own ErrorCode + HTTP status; the dispatcher path raises
  rather than returning sentinel objects.
- Auth check uses match/case for the requirement (True / 'staff' /
  'superuser' / callable / other) — single declarative dispatch instead
  of an if/elif chain.
- Broke up the single 80-line execute_function into focused helpers:
  _resolve_function, _enforce_auth, _validate_input, _serialize,
  _invalidation_target. The execute_function body now reads as five
  declarative steps.
- Input validation uses Pydantic's model_fields[name].is_required()
  directly and a list comprehension for required-field reporting,
  instead of round-tripping through model_json_schema().

router.py:
- POST /call/ now declares its body as a Pydantic CallBody model;
  FastAPI handles parsing + envelope validation. No more manual
  await request.json() + dict[get] dancing.
- Endpoint bodies shrink to 3-5 lines each. Context fetch uses a
  dict comprehension over the function group.
- mizan_exception_handler renders MizanError to the protocol's
  {error: {code, message, details}} envelope.
- mizan_validation_handler maps FastAPI's RequestValidationError to
  the same envelope under BAD_REQUEST so the wire format is uniform
  whether the failure is body-shape or business validation.

__init__.py: exposes the full exception hierarchy + both handlers
so consumers can wire them onto their FastAPI app declaratively:

    app.add_exception_handler(MizanError, mizan_exception_handler)
    app.add_exception_handler(RequestValidationError, mizan_validation_handler)

Verified: mizan-core 15/15, mizan-django 348 pass, mizan-fastapi 11/11.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-06 16:51:03 -04:00
parent 4e4d1bb6b1
commit 63c9a9c4ce
4 changed files with 228 additions and 221 deletions

View File

@@ -1,17 +1,14 @@
"""
RPC dispatch — looks up registered functions by name, validates input
against the function's Pydantic Input model, executes, and returns the
result wrapped in a normalized FunctionResult / FunctionError.
RPC dispatch — looks up registered functions, validates input against the
function's Pydantic Input model, executes, and returns the serialized result.
Backend-agnostic where possible. The only FastAPI-specific bits are the
Request type hint (kept loose as Any) and the auth-check mechanism, which
expects FastAPI to populate request.state.user via dependency injection
or middleware before dispatch.
Errors raise typed exceptions (MizanError subclasses). Wire those to JSON
responses by registering `mizan_exception_handler` on the FastAPI app, or
let them propagate to your own handler.
"""
from __future__ import annotations
from dataclasses import dataclass
from enum import Enum
from typing import Any
@@ -20,7 +17,7 @@ from pydantic import BaseModel, ValidationError
from mizan_core.registry import get_function
# ─── Error / Result types ───────────────────────────────────────────────────
# ─── Error taxonomy ─────────────────────────────────────────────────────────
class ErrorCode(str, Enum):
@@ -33,164 +30,168 @@ class ErrorCode(str, Enum):
INTERNAL_ERROR = "INTERNAL_ERROR"
@dataclass
class FunctionError:
code: ErrorCode
message: str
details: dict[str, Any] | None = None
_STATUS = {
ErrorCode.NOT_FOUND: 404,
ErrorCode.BAD_REQUEST: 400,
ErrorCode.VALIDATION_ERROR: 422,
ErrorCode.UNAUTHORIZED: 401,
ErrorCode.FORBIDDEN: 403,
ErrorCode.NOT_IMPLEMENTED: 501,
ErrorCode.INTERNAL_ERROR: 500,
}
@dataclass
class FunctionResult:
data: Any # serialized return value (dict, list, primitive, or Pydantic model_dump)
class MizanError(Exception):
"""Base for protocol-level dispatch errors."""
code: ErrorCode = ErrorCode.INTERNAL_ERROR
def __init__(self, message: str, *, details: dict[str, Any] | None = None) -> None:
super().__init__(message)
self.message = message
self.details = details
@property
def status_code(self) -> int:
return _STATUS[self.code]
class NotFound(MizanError): code = ErrorCode.NOT_FOUND # noqa: E701
class BadRequest(MizanError): code = ErrorCode.BAD_REQUEST # noqa: E701
class ValidationFailed(MizanError): code = ErrorCode.VALIDATION_ERROR # noqa: E701
class Unauthorized(MizanError): code = ErrorCode.UNAUTHORIZED # noqa: E701
class Forbidden(MizanError): code = ErrorCode.FORBIDDEN # noqa: E701
class NotImplementedYet(MizanError): code = ErrorCode.NOT_IMPLEMENTED # noqa: E701
class InternalError(MizanError): code = ErrorCode.INTERNAL_ERROR # noqa: E701
# ─── Auth ───────────────────────────────────────────────────────────────────
def _check_auth(request: Any, auth_requirement: Any) -> FunctionError | None:
"""
Verify the request meets the function's auth requirement.
def _user(request: Any) -> Any:
return getattr(getattr(request, "state", None), "user", None)
The auth value comes from @client(auth=...). FastAPI projects are expected
to populate `request.state.user` (or compatible) via middleware. If
request has no `.state` or `.state.user`, treats the user as anonymous.
"""
if auth_requirement is None:
def _is_authenticated(user: Any) -> bool:
return bool(user) and getattr(user, "is_authenticated", True)
def _enforce_auth(request: Any, requirement: Any) -> None:
"""Verify the request meets the function's @client(auth=...) requirement, or raise."""
if requirement is None:
return
user = _user(request)
match requirement:
case True:
if not _is_authenticated(user):
raise Unauthorized("Authentication required")
case "staff":
if not _is_authenticated(user):
raise Unauthorized("Authentication required")
if not getattr(user, "is_staff", False):
raise Forbidden("Staff access required")
case "superuser":
if not _is_authenticated(user):
raise Unauthorized("Authentication required")
if not getattr(user, "is_superuser", False):
raise Forbidden("Superuser access required")
case f if callable(f):
if not f(request):
raise Forbidden("Permission denied")
case other:
raise InternalError(f"Unknown auth requirement: {other!r}")
# ─── Input validation ───────────────────────────────────────────────────────
def _validate_input(input_cls: Any, input_data: Any) -> BaseModel | None:
"""Validate input_data against the function's Input model. Returns the instance or None."""
if input_cls in (None, BaseModel) or not getattr(input_cls, "model_fields", None):
return None
user = getattr(getattr(request, "state", None), "user", None)
is_authenticated = bool(user) and getattr(user, "is_authenticated", True)
fields = input_cls.model_fields
required = [name for name, f in fields.items() if f.is_required()]
if auth_requirement is True:
if not is_authenticated:
return FunctionError(ErrorCode.UNAUTHORIZED, "Authentication required")
return None
if not input_data:
if required:
raise ValidationFailed(
"Input validation failed",
details={"fields": {name: ["Field required"] for name in required}},
)
return input_cls()
if auth_requirement == "staff":
if not is_authenticated:
return FunctionError(ErrorCode.UNAUTHORIZED, "Authentication required")
if not getattr(user, "is_staff", False):
return FunctionError(ErrorCode.FORBIDDEN, "Staff access required")
return None
if not isinstance(input_data, dict):
raise BadRequest(f"Input must be an object, got {type(input_data).__name__}")
if auth_requirement == "superuser":
if not is_authenticated:
return FunctionError(ErrorCode.UNAUTHORIZED, "Authentication required")
if not getattr(user, "is_superuser", False):
return FunctionError(ErrorCode.FORBIDDEN, "Superuser access required")
return None
if callable(auth_requirement):
if not auth_requirement(request):
return FunctionError(ErrorCode.FORBIDDEN, "Permission denied")
return None
return FunctionError(
ErrorCode.INTERNAL_ERROR,
f"Unknown auth requirement: {auth_requirement!r}",
)
try:
return input_cls(**input_data)
except ValidationError as e:
raise ValidationFailed(
"Input validation failed",
details={"errors": e.errors()},
) from e
# ─── Dispatch ───────────────────────────────────────────────────────────────
def _resolve_function(fn_name: str) -> Any:
view_class = get_function(fn_name)
if view_class is None:
raise NotFound("Function not found")
if getattr(view_class, "_meta", {}).get("private"):
raise Forbidden("Function is not client-callable")
return view_class
def _serialize(result: Any) -> Any:
return result.model_dump(mode="json") if isinstance(result, BaseModel) else result
def execute_function(
request: Any,
fn_name: str,
input_data: dict[str, Any] | None = None,
) -> FunctionResult | FunctionError:
"""
Look up a registered function by name, validate input, execute, return result.
"""
view_class = get_function(fn_name)
if view_class is None:
return FunctionError(ErrorCode.NOT_FOUND, "Function not found")
meta = getattr(view_class, "_meta", {})
if meta.get("private"):
return FunctionError(ErrorCode.FORBIDDEN, "Function is not client-callable")
auth_error = _check_auth(request, meta.get("auth"))
if auth_error is not None:
return auth_error
) -> Any:
"""Dispatch a registered function. Returns the serialized result, or raises MizanError."""
view_class = _resolve_function(fn_name)
_enforce_auth(request, view_class._meta.get("auth"))
view = view_class(request)
input_cls = view.Input
validated = _validate_input(view.Input, input_data)
# Pydantic input validation
has_fields = bool(getattr(input_cls, "model_fields", None)) if input_cls else False
if input_data is not None and has_fields:
if not isinstance(input_data, dict):
return FunctionError(
ErrorCode.BAD_REQUEST,
f"Input must be an object, got {type(input_data).__name__}",
)
try:
validated_input = input_cls(**input_data)
except ValidationError as e:
return FunctionError(
ErrorCode.VALIDATION_ERROR,
"Input validation failed",
details={"errors": e.errors()},
)
elif has_fields:
# Function expects input but none provided
required = input_cls.model_json_schema().get("required", [])
if required:
return FunctionError(
ErrorCode.VALIDATION_ERROR,
"Input validation failed",
details={"fields": {field: ["Field required"] for field in required}},
)
validated_input = input_cls()
else:
validated_input = None
# Execute. The wrapper's call(input) always takes the input arg,
# passing None when the function has no fields.
try:
result = view.call(validated_input)
result = view.call(validated)
except NotImplementedError as e:
return FunctionError(ErrorCode.NOT_IMPLEMENTED, str(e) or "Not implemented")
raise NotImplementedYet(str(e) or "Not implemented") from e
except MizanError:
raise
except Exception as e:
return FunctionError(ErrorCode.INTERNAL_ERROR, str(e))
raise InternalError(str(e)) from e
# Serialize Pydantic models to plain dicts
if isinstance(result, BaseModel):
return FunctionResult(data=result.model_dump(mode="json"))
return FunctionResult(data=result)
return _serialize(result)
# ─── Invalidation ───────────────────────────────────────────────────────────
def compute_invalidation(view_class: Any, input_data: dict[str, Any] | None) -> list[Any]:
"""
Build the invalidate list for a mutation response from the function's
@client(affects=...) metadata. Auto-scopes to params when the mutation's
arg names overlap with a context's params.
"""
meta = getattr(view_class, "_meta", {})
affects = meta.get("affects")
if not affects:
return []
"""Build the `invalidate` list from @client(affects=...) metadata, auto-scoping when arg names match context params."""
affects = getattr(view_class, "_meta", {}).get("affects") or []
return [_invalidation_target(target, input_data or {}) for target in affects]
out: list[Any] = []
for target in affects:
target_type = target.get("type")
target_name = target.get("name")
if target_type == "context":
scope_params = target.get("params") or {}
if scope_params and input_data:
# Auto-scope: include matching param values
matched = {k: input_data[k] for k in scope_params if k in input_data}
if matched:
out.append({"context": target_name, "params": matched})
continue
out.append(target_name)
elif target_type == "function":
out.append({"function": target_name})
return out
def _invalidation_target(target: dict[str, Any], input_data: dict[str, Any]) -> Any:
match target.get("type"):
case "context":
name = target["name"]
scope_keys = (target.get("params") or {}).keys()
scoped = {k: input_data[k] for k in scope_keys if k in input_data}
return {"context": name, "params": scoped} if scoped else name
case "function":
return {"function": target["name"]}
case _:
return target