FastAPI and TypeScript improved

This commit is contained in:
2026-06-04 05:14:29 -04:00
parent 67ad91b673
commit 66b2db81fb
28 changed files with 1864 additions and 717 deletions

View File

@@ -35,12 +35,18 @@ from .executor import (
execute_function,
)
from .router import router, mizan_exception_handler, mizan_validation_handler
from .auth import MizanAuthMiddleware, mizan_auth
from .config import MizanConfig, from_env
from mizan_core.upload import File, Upload, UploadedFile
__all__ = [
"Upload",
"File",
"UploadedFile",
"mizan_auth",
"MizanAuthMiddleware",
"MizanConfig",
"from_env",
"router",
"mizan_exception_handler",
"mizan_validation_handler",

View File

@@ -0,0 +1,54 @@
"""
Built-in identity for FastAPI — Django-equivalent automatic `request.state.user`.
Opt in via `Depends(mizan_auth())` on a route/router, or mount `MizanAuthMiddleware`
app-wide. Both decode a bearer-JWT (`Authorization: Bearer`) or MWT (`X-Mizan-Token`)
via the shared core and set `request.state.user`. A present-but-invalid token is
rejected (401) rather than silently downgraded — the `INVALID` sentinel contract.
If you'd rather resolve identity yourself, set `request.state.user` upstream and skip
these; dispatch reads it directly.
"""
from __future__ import annotations
from typing import Callable
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from mizan_core.auth import INVALID, authenticate
from mizan_core.errors import Unauthorized
from .config import get_config
def _resolve(request: Request) -> None:
ident = authenticate(request.headers, get_config(request).auth)
if ident is INVALID:
raise Unauthorized("Invalid or expired token")
if ident is not None:
request.state.user = ident
def mizan_auth() -> Callable:
"""FastAPI dependency that populates `request.state.user` from a token."""
async def _dep(request: Request) -> None:
_resolve(request)
return _dep
class MizanAuthMiddleware(BaseHTTPMiddleware):
"""App-wide variant of `mizan_auth` — resolves identity on every request."""
async def dispatch(self, request, call_next):
try:
_resolve(request)
except Unauthorized:
from .router import _no_store
from mizan_core.errors import ErrorCode
return _no_store(
{"error": {"code": ErrorCode.UNAUTHORIZED.value, "message": "Invalid or expired token"}},
status_code=401,
)
return await call_next(request)

View File

@@ -0,0 +1,80 @@
"""
FastAPI configuration — the "no settings.py" seam.
Builds the shared core's `AuthConfig` (JWT + MWT) and a `CacheOrchestrator`
from environment variables, overridable per-app via `app.state.mizan_config`.
Env:
MIZAN_CACHE_SECRET HMAC cache signing key (enables origin cache)
MIZAN_CACHE_REDIS_URL Redis URL (else in-memory cache)
MIZAN_MWT_SECRET MWT signing key
MIZAN_MWT_AUDIENCE MWT audience (default "mizan")
JWT_PRIVATE_KEY JWT signing key (enables bearer-JWT auth)
JWT_PUBLIC_KEY JWT verify key (default: private key, HS256)
JWT_ALGORITHM default "HS256"
JWT_ACCESS_TOKEN_EXPIRES_IN / JWT_REFRESH_TOKEN_EXPIRES_IN
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from mizan_core.auth import AuthConfig, JWTConfig
from mizan_core.cache.backend import CacheBackend, MemoryCache
from mizan_core.dispatch import CacheOrchestrator
@dataclass(frozen=True)
class MizanConfig:
auth: AuthConfig
cache: CacheOrchestrator
def _cache_backend(secret: str | None, redis_url: str | None) -> CacheBackend | None:
if not secret:
return None
if redis_url:
from mizan_core.cache.backend import RedisCache
return RedisCache(redis_url)
return MemoryCache()
def _jwt_config() -> JWTConfig | None:
key = os.getenv("JWT_PRIVATE_KEY")
if not key:
return None
return JWTConfig(
private_key=key,
public_key=os.getenv("JWT_PUBLIC_KEY", key),
algorithm=os.getenv("JWT_ALGORITHM", "HS256"),
access_token_expires_in=int(os.getenv("JWT_ACCESS_TOKEN_EXPIRES_IN", "300")),
refresh_token_expires_in=int(os.getenv("JWT_REFRESH_TOKEN_EXPIRES_IN", "604800")),
)
def from_env() -> MizanConfig:
secret = os.getenv("MIZAN_CACHE_SECRET")
backend = _cache_backend(secret, os.getenv("MIZAN_CACHE_REDIS_URL"))
auth = AuthConfig(
jwt=_jwt_config(),
mwt_secret=os.getenv("MIZAN_MWT_SECRET"),
mwt_audience=os.getenv("MIZAN_MWT_AUDIENCE", "mizan"),
)
return MizanConfig(auth=auth, cache=CacheOrchestrator(backend, secret))
def get_config(request) -> MizanConfig:
"""Per-app config: `app.state.mizan_config` if set, else built from env (cached)."""
app = getattr(request, "app", None)
state = getattr(app, "state", None)
override = getattr(state, "mizan_config", None) if state is not None else None
if override is not None:
return override
global _DEFAULT
if _DEFAULT is None:
_DEFAULT = from_env()
return _DEFAULT
_DEFAULT: MizanConfig | None = None

View File

@@ -1,263 +1,69 @@
"""
RPC dispatch — looks up registered functions, validates input against the
function's Pydantic Input model, executes, and returns the serialized result.
Dispatch — a thin shim over the shared core (`mizan_core.dispatch`).
Errors raise typed exceptions (MizanError subclasses). Wire those to JSON
responses by registering `mizan_exception_handler` on the FastAPI app, or
let them propagate to your own handler.
The protocol machinery (auth, validation, execution, invalidation, merge, cache)
lives in `mizan_core`; this module re-exports the canonical error taxonomy and
keeps backward-compatible helpers. The router drives `dispatch_call` /
`dispatch_context` directly to get invalidation + origin cache.
"""
from __future__ import annotations
from enum import Enum
from typing import Any
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, ValidationError
from mizan_core.dispatch import CacheOrchestrator, DispatchRequest, dispatch_call
from mizan_core.errors import (
BadRequest,
ErrorCode,
Forbidden,
InternalError,
MizanError,
NotFound,
NotImplementedYet,
Unauthorized,
ValidationFailed,
)
from mizan_core.invalidation import resolve_invalidation, resolve_merges
from mizan_core.registry import get_context_groups, get_function
from mizan_core.type_utils import types_match_for_merge
__all__ = [
"ErrorCode",
"MizanError",
"NotFound",
"BadRequest",
"ValidationFailed",
"Unauthorized",
"Forbidden",
"NotImplementedYet",
"InternalError",
"compute_invalidation",
"compute_merges",
"execute_function",
]
# ─── Error taxonomy ─────────────────────────────────────────────────────────
class ErrorCode(str, Enum):
NOT_FOUND = "NOT_FOUND"
BAD_REQUEST = "BAD_REQUEST"
VALIDATION_ERROR = "VALIDATION_ERROR"
UNAUTHORIZED = "UNAUTHORIZED"
FORBIDDEN = "FORBIDDEN"
NOT_IMPLEMENTED = "NOT_IMPLEMENTED"
INTERNAL_ERROR = "INTERNAL_ERROR"
_STATUS = {
ErrorCode.NOT_FOUND: 404,
ErrorCode.BAD_REQUEST: 400,
ErrorCode.VALIDATION_ERROR: 422,
ErrorCode.UNAUTHORIZED: 401,
ErrorCode.FORBIDDEN: 403,
ErrorCode.NOT_IMPLEMENTED: 501,
ErrorCode.INTERNAL_ERROR: 500,
}
class MizanError(Exception):
"""Base for protocol-level dispatch errors."""
code: ErrorCode = ErrorCode.INTERNAL_ERROR
def __init__(self, message: str, *, details: dict[str, Any] | None = None) -> None:
super().__init__(message)
self.message = message
self.details = details
@property
def status_code(self) -> int:
return _STATUS[self.code]
class NotFound(MizanError): code = ErrorCode.NOT_FOUND # noqa: E701
class BadRequest(MizanError): code = ErrorCode.BAD_REQUEST # noqa: E701
class ValidationFailed(MizanError): code = ErrorCode.VALIDATION_ERROR # noqa: E701
class Unauthorized(MizanError): code = ErrorCode.UNAUTHORIZED # noqa: E701
class Forbidden(MizanError): code = ErrorCode.FORBIDDEN # noqa: E701
class NotImplementedYet(MizanError): code = ErrorCode.NOT_IMPLEMENTED # noqa: E701
class InternalError(MizanError): code = ErrorCode.INTERNAL_ERROR # noqa: E701
# ─── Auth ───────────────────────────────────────────────────────────────────
def _user(request: Any) -> Any:
return getattr(getattr(request, "state", None), "user", None)
def _is_authenticated(user: Any) -> bool:
return bool(user) and getattr(user, "is_authenticated", True)
def _enforce_auth(request: Any, requirement: Any) -> None:
"""Verify the request meets the function's @client(auth=...) requirement, or raise."""
if requirement is None:
return
user = _user(request)
match requirement:
case True | "required":
if not _is_authenticated(user):
raise Unauthorized("Authentication required")
case "staff":
if not _is_authenticated(user):
raise Unauthorized("Authentication required")
if not getattr(user, "is_staff", False):
raise Forbidden("Staff access required")
case "superuser":
if not _is_authenticated(user):
raise Unauthorized("Authentication required")
if not getattr(user, "is_superuser", False):
raise Forbidden("Superuser access required")
case f if callable(f):
if not f(request):
raise Forbidden("Permission denied")
case other:
raise InternalError(f"Unknown auth requirement: {other!r}")
# ─── Input validation ───────────────────────────────────────────────────────
def _validate_input(input_cls: Any, input_data: Any) -> BaseModel | None:
"""Validate input_data against the function's Input model. Returns the instance or None."""
if input_cls in (None, BaseModel) or not getattr(input_cls, "model_fields", None):
return None
fields = input_cls.model_fields
required = [name for name, f in fields.items() if f.is_required()]
if not input_data:
if required:
raise ValidationFailed(
"Input validation failed",
details={"fields": {name: ["Field required"] for name in required}},
)
return input_cls()
if not isinstance(input_data, dict):
raise BadRequest(f"Input must be an object, got {type(input_data).__name__}")
try:
return input_cls(**input_data)
except ValidationError as e:
raise ValidationFailed(
"Input validation failed",
details={"errors": e.errors()},
) from e
# ─── Dispatch ───────────────────────────────────────────────────────────────
def _resolve_function(fn_name: str) -> Any:
view_class = get_function(fn_name)
if view_class is None:
raise NotFound("Function not found")
if getattr(view_class, "_meta", {}).get("private"):
raise Forbidden("Function is not client-callable")
return view_class
def _serialize(result: Any) -> Any:
# jsonable_encoder walks BaseModel / list / dict recursively, so list[BaseModel]
# (and nested shapes) come out wire-ready without a per-shape branch here.
return jsonable_encoder(result)
async def execute_function(
request: Any,
fn_name: str,
input_data: dict[str, Any] | None = None,
) -> Any:
"""Dispatch a registered function. Returns the serialized result, or raises MizanError.
Awaits `view.acall` — async handlers run on the loop, sync handlers run
in the default threadpool, both via the same entrypoint.
"""
view_class = _resolve_function(fn_name)
_enforce_auth(request, view_class._meta.get("auth"))
view = view_class(request)
validated = _validate_input(view.Input, input_data)
try:
result = await view.acall(validated)
except NotImplementedError as e:
raise NotImplementedYet(str(e) or "Not implemented") from e
except MizanError:
raise
except Exception as e:
raise InternalError(str(e)) from e
return _serialize(result)
# ─── Invalidation ───────────────────────────────────────────────────────────
_NO_CACHE = CacheOrchestrator(None, None)
def compute_invalidation(view_class: Any, input_data: dict[str, Any] | None) -> list[Any]:
"""Build the `invalidate` list from @client(affects=...) metadata, auto-scoping when arg names match context params."""
affects = getattr(view_class, "_meta", {}).get("affects") or []
return [_invalidation_target(target, input_data or {}) for target in affects]
"""`@client(affects=...)` → invalidation list (empty when none). Shared core."""
return resolve_invalidation(view_class, input_data) or []
def compute_merges(view_class: Any, input_data: dict[str, Any] | None, result: Any) -> list[dict[str, Any]]:
"""Build the `merge` list from @client(merge=...) metadata.
"""`@client(merge=...)` → merge list (empty when none). Shared core."""
return resolve_merges(view_class, input_data, result) or []
Each entry is `{context, slot, value, params?}` where `slot` names the
function inside the context bundle the value lands in. The slot is
resolved server-side via `types_match_for_merge` so the kernel does
no shape inference — the server has the schema, type-checked routing
lives here. Entries whose slot can't be uniquely resolved are dropped
with a warning; the consumer falls back to refetch via `affects`.
async def execute_function(request: Any, fn_name: str, input_data: dict[str, Any] | None = None) -> Any:
"""Dispatch a function and return its serialized result (auth enforced via core).
Backward-compat entry point; the router uses `dispatch_call` directly to also
capture invalidation/merge and run the origin cache.
"""
targets = getattr(view_class, "_meta", {}).get("merge") or []
if not targets:
return []
mutation_output = getattr(view_class, "Output", None)
out: list[dict[str, Any]] = []
for ctx_name in targets:
slot = _resolve_merge_slot(ctx_name, mutation_output)
if slot is None:
continue
entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result}
scoped = _scoped_params(ctx_name, input_data or {})
if scoped:
entry["params"] = scoped
out.append(entry)
return out
def _resolve_merge_slot(context_name: str, mutation_output: Any) -> str | None:
"""Find the unique function-name slot whose return type matches the mutation's output.
Returns None on no match or ambiguous match (multiple candidates).
"""
if mutation_output is None:
return None
matches: list[str] = []
for fn_name in get_context_groups().get(context_name, []):
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
fn_output = getattr(fn_cls, "Output", None)
if fn_output is not None and types_match_for_merge(fn_output, mutation_output):
matches.append(fn_name)
return matches[0] if len(matches) == 1 else None
def _scoped_params(context_name: str, input_data: dict[str, Any]) -> dict[str, Any]:
"""Match input args against the context's declared Input field names."""
fn_names = get_context_groups().get(context_name, [])
declared: set[str] = set()
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
declared.update(input_cls.model_fields.keys())
return {k: v for k, v in input_data.items() if k in declared}
def _invalidation_target(target: dict[str, Any], input_data: dict[str, Any]) -> Any:
match target.get("type"):
case "context":
name = target["name"]
scoped = _scoped_params(name, input_data)
return {"context": name, "params": scoped} if scoped else name
case "function":
return {"function": target["name"]}
case _:
return target
identity = getattr(getattr(request, "state", None), "user", None)
res = await dispatch_call(
DispatchRequest(identity=identity, args=input_data, native_request=request),
fn_name,
_NO_CACHE,
)
return res.data

View File

@@ -19,22 +19,17 @@ from typing import Any
from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel, Field, ValidationError
from starlette.datastructures import UploadFile
from mizan_core.registry import get_context_groups, get_function
from mizan_core.auth import INVALID, authenticate
from mizan_core.dispatch import DispatchRequest, dispatch_call, dispatch_context
from mizan_core.errors import BadRequest, ErrorCode, MizanError, Unauthorized
from mizan_core.registry import get_function
from mizan_core.upload import UploadedFile, bind_uploads
from .executor import (
BadRequest,
ErrorCode,
MizanError,
NotFound,
compute_invalidation,
compute_merges,
execute_function,
)
from .config import MizanConfig, get_config
router = APIRouter()
@@ -106,31 +101,52 @@ async def _parse_call(request: Request) -> tuple[str, dict[str, Any]]:
return body.fn, body.args
def _identity(request: Request, cfg: MizanConfig):
"""Identity for dispatch: a host-set `request.state.user`, else a token decode.
A present-but-invalid token rejects (401); no token → None (anonymous).
"""
existing = getattr(getattr(request, "state", None), "user", None)
if existing is not None:
return existing
ident = authenticate(request.headers, cfg.auth)
if ident is INVALID:
raise Unauthorized("Invalid or expired token")
return ident
@router.post("/call/")
async def function_call(request: Request) -> JSONResponse:
"""RPC dispatch — `{"fn": "...", "args": {...}}` (JSON) or multipart with file
parts → `{"result": ..., "invalidate": [...], "merge"?: [...]}`."""
"""RPC dispatch — JSON or multipart → `{"result", "invalidate", "merge"?}` with
the `X-Mizan-Invalidate` header alongside the body."""
cfg = get_config(request)
fn, args = await _parse_call(request)
fn_class = get_function(fn)
result = await execute_function(request, fn, args)
invalidate = compute_invalidation(fn_class, args)
merges = compute_merges(fn_class, args, result)
payload: dict[str, Any] = {"result": result, "invalidate": invalidate}
if merges:
payload["merge"] = merges
return _no_store(payload)
res = await dispatch_call(
DispatchRequest(identity=_identity(request, cfg), args=args, native_request=request),
fn, cfg.cache,
)
payload: dict[str, Any] = {"result": res.data, "invalidate": res.invalidate or []}
if res.merge:
payload["merge"] = res.merge
headers = {"Cache-Control": "no-store"}
if res.invalidate_header:
headers["X-Mizan-Invalidate"] = res.invalidate_header
return JSONResponse(payload, headers=headers)
@router.get("/ctx/{context_name}/")
async def context_fetch(context_name: str, request: Request) -> JSONResponse:
"""Bundled context fetch — `{function_name: result, ...}` for every function in the context."""
fn_names = get_context_groups().get(context_name)
if not fn_names:
raise NotFound(f"Context '{context_name}' not found")
params = dict(request.query_params)
bundled = {fn: await execute_function(request, fn, params) for fn in fn_names}
return _no_store(bundled)
async def context_fetch(context_name: str, request: Request) -> Response:
"""Bundled context fetch — origin-cached. `{function_name: result, ...}`."""
cfg = get_config(request)
res = await dispatch_context(
DispatchRequest(identity=_identity(request, cfg), args=dict(request.query_params),
native_request=request),
context_name, cfg.cache,
)
headers = {"Cache-Control": "no-store"}
if res.cache_status:
headers["X-Mizan-Cache"] = res.cache_status
return Response(content=res.body_bytes, media_type="application/json", headers=headers)
# ─── Exception handler ──────────────────────────────────────────────────────

View File

@@ -0,0 +1,98 @@
"""FastAPI parity with Django: X-Mizan-Invalidate header, origin cache, token auth."""
from __future__ import annotations
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel
from mizan_core.auth import AuthConfig, JWTConfig, create_access_token
from mizan_core.cache.backend import MemoryCache
from mizan_core.client.function import client
from mizan_core.dispatch import CacheOrchestrator
from mizan_core.registry import clear_registry, register
from mizan_fastapi import (
MizanAuthMiddleware,
MizanConfig,
MizanError,
mizan_auth,
mizan_exception_handler,
router as mizan_router,
)
class Out(BaseModel):
ok: bool
SECRET = "x" * 32
JWT = JWTConfig(private_key=SECRET, public_key=SECRET)
def _app(*, with_cache=False, with_auth_dep=False) -> FastAPI:
clear_registry()
UserCtx = "user"
@client(context=UserCtx)
def user_profile(request, user_id: int) -> Out:
return Out(ok=True)
@client(affects=UserCtx)
def update_profile(request, user_id: int) -> Out:
return Out(ok=True)
@client(auth=True)
def whoami(request) -> Out:
return Out(ok=True)
register(user_profile, "user_profile")
register(update_profile, "update_profile")
register(whoami, "whoami")
app = FastAPI()
cache = CacheOrchestrator(MemoryCache(), SECRET) if with_cache else CacheOrchestrator(None, None)
app.state.mizan_config = MizanConfig(auth=AuthConfig(jwt=JWT), cache=cache)
deps = [Depends(mizan_auth())] if with_auth_dep else []
app.include_router(mizan_router, prefix="/api/mizan", dependencies=deps)
app.add_exception_handler(MizanError, mizan_exception_handler)
return app
def test_mutation_emits_invalidate_header():
c = TestClient(_app())
r = c.post("/api/mizan/call/", json={"fn": "update_profile", "args": {"user_id": 5}})
assert r.status_code == 200
assert r.json()["invalidate"] == [{"context": "user", "params": {"user_id": 5}}]
assert r.headers["X-Mizan-Invalidate"] == "user;user_id=5"
def test_origin_cache_hit_miss():
c = TestClient(_app(with_cache=True))
r1 = c.get("/api/mizan/ctx/user/", params={"user_id": 5})
assert r1.status_code == 200 and r1.headers["X-Mizan-Cache"] == "MISS"
r2 = c.get("/api/mizan/ctx/user/", params={"user_id": 5})
assert r2.headers["X-Mizan-Cache"] == "HIT"
assert r1.content == r2.content
def test_auth_required_rejects_anonymous():
c = TestClient(_app())
r = c.post("/api/mizan/call/", json={"fn": "whoami", "args": {}})
assert r.status_code == 401
def test_auth_required_passes_with_bearer_jwt():
c = TestClient(_app(with_auth_dep=True))
tok = create_access_token("7", "sess", JWT, is_staff=True)
r = c.post("/api/mizan/call/", json={"fn": "whoami", "args": {}},
headers={"Authorization": f"Bearer {tok}"})
assert r.status_code == 200 and r.json()["result"] == {"ok": True}
def test_invalid_bearer_token_rejected():
c = TestClient(_app())
r = c.post("/api/mizan/call/", json={"fn": "update_profile", "args": {"user_id": 1}},
headers={"Authorization": "Bearer not-a-real-token"})
assert r.status_code == 401