FastAPI and TypeScript improved

This commit is contained in:
2026-06-04 05:14:29 -04:00
parent 67ad91b673
commit 66b2db81fb
28 changed files with 1864 additions and 717 deletions

View File

@@ -0,0 +1,27 @@
from mizan_core.auth.authenticate import INVALID, AuthConfig, authenticate
from mizan_core.auth.jwt import (
JWTConfig,
JWTUser,
TokenPair,
TokenPayload,
create_access_token,
create_refresh_token,
create_token_pair,
decode_token,
refresh_tokens,
)
__all__ = [
"AuthConfig",
"authenticate",
"INVALID",
"JWTConfig",
"JWTUser",
"TokenPair",
"TokenPayload",
"create_access_token",
"create_refresh_token",
"create_token_pair",
"decode_token",
"refresh_tokens",
]

View File

@@ -0,0 +1,53 @@
"""
Token → identity resolution, shared by every adapter.
`authenticate(headers, config)` reads `X-Mizan-Token` (MWT) first, then
`Authorization: Bearer` (JWT), and returns an `Identity`, `None`, or the
`INVALID` sentinel.
The `INVALID` sentinel is load-bearing: when a token is PRESENT but bad, the
adapter must REJECT — never silently fall back to session auth (that would let
a forged/expired token degrade into anonymous-or-session access). `None` means
"no token offered" → the adapter may fall back to its own session identity.
"""
from __future__ import annotations
from dataclasses import dataclass
from typing import Mapping
from mizan_core.auth.jwt import JWTConfig, JWTUser, decode_token
from mizan_core.identity import Identity
from mizan_core.mwt import MWTUser, decode_mwt
class _Invalid:
"""Sentinel: a token was presented but failed validation."""
def __repr__(self) -> str:
return "INVALID"
INVALID = _Invalid()
@dataclass(frozen=True)
class AuthConfig:
jwt: JWTConfig | None = None
mwt_secret: str | None = None
mwt_audience: str = "mizan"
def authenticate(headers: Mapping[str, str], config: AuthConfig) -> Identity | _Invalid | None:
"""Resolve identity from request headers. Returns Identity | INVALID | None."""
mwt = headers.get("X-Mizan-Token") or headers.get("x-mizan-token")
if mwt and config.mwt_secret:
payload = decode_mwt(mwt, config.mwt_secret, audience=config.mwt_audience)
return MWTUser(payload) if payload else INVALID
bearer = headers.get("Authorization") or headers.get("authorization") or ""
if bearer.startswith("Bearer ") and config.jwt:
payload = decode_token(bearer[7:], config.jwt, expected_type="access")
return JWTUser(payload) if payload else INVALID
return None

View File

@@ -0,0 +1,137 @@
"""
JWT access/refresh tokens — adapter-agnostic (PyJWT).
Config is injected (`JWTConfig`) rather than read from any framework's settings.
`validate_session` (the immediate-logout-revocation check) is Django-session-bound
and stays in the Django adapter; `refresh_tokens` takes a `session_validator`
callable so the core stays framework-free.
"""
from __future__ import annotations
import time
from dataclasses import dataclass
from typing import Callable, NamedTuple
import jwt
@dataclass(frozen=True)
class JWTConfig:
private_key: str
public_key: str
algorithm: str = "HS256"
access_token_expires_in: int = 300
refresh_token_expires_in: int = 604800
class TokenPair(NamedTuple):
access_token: str
refresh_token: str
expires_in: int
class TokenPayload(NamedTuple):
user_id: int | str
session_key: str
token_type: str
is_staff: bool
is_superuser: bool
exp: int
iat: int
class JWTUser:
"""Minimal `Identity` built from JWT claims — no DB query."""
def __init__(self, payload: TokenPayload):
self.id = int(payload.user_id) if isinstance(payload.user_id, str) else payload.user_id
self.pk = self.id
self.is_staff = payload.is_staff
self.is_superuser = payload.is_superuser
self.is_authenticated = True
self.is_anonymous = False
self.is_active = True
def __str__(self) -> str:
return f"JWTUser(id={self.id})"
def __repr__(self) -> str:
return f"JWTUser(id={self.id}, is_staff={self.is_staff}, is_superuser={self.is_superuser})"
def _mint(user_id: int | str, session_key: str, token_type: str, ttl: int,
config: JWTConfig, is_staff: bool, is_superuser: bool) -> str:
now = int(time.time())
payload = {
"sub": str(user_id),
"sid": session_key,
"staff": is_staff,
"super": is_superuser,
"type": token_type,
"iat": now,
"exp": now + ttl,
}
return jwt.encode(payload, config.private_key, algorithm=config.algorithm)
def create_access_token(user_id, session_key, config: JWTConfig, *,
is_staff: bool = False, is_superuser: bool = False) -> str:
return _mint(user_id, session_key, "access", config.access_token_expires_in,
config, is_staff, is_superuser)
def create_refresh_token(user_id, session_key, config: JWTConfig, *,
is_staff: bool = False, is_superuser: bool = False) -> str:
return _mint(user_id, session_key, "refresh", config.refresh_token_expires_in,
config, is_staff, is_superuser)
def create_token_pair(user_id, session_key, config: JWTConfig, *,
is_staff: bool = False, is_superuser: bool = False) -> TokenPair:
return TokenPair(
access_token=create_access_token(user_id, session_key, config,
is_staff=is_staff, is_superuser=is_superuser),
refresh_token=create_refresh_token(user_id, session_key, config,
is_staff=is_staff, is_superuser=is_superuser),
expires_in=config.access_token_expires_in,
)
def decode_token(token: str, config: JWTConfig, expected_type: str | None = None) -> TokenPayload | None:
"""Decode + validate. None on invalid/expired token, or type mismatch."""
try:
payload = jwt.decode(token, config.public_key, algorithms=[config.algorithm])
except jwt.PyJWTError:
return None
if expected_type and payload.get("type") != expected_type:
return None
return TokenPayload(
user_id=payload["sub"],
session_key=payload["sid"],
token_type=payload["type"],
is_staff=payload.get("staff", False),
is_superuser=payload.get("super", False),
exp=payload["exp"],
iat=payload["iat"],
)
def refresh_tokens(
refresh_token: str,
config: JWTConfig,
session_validator: Callable[[str], bool] | None = None,
) -> TokenPair | None:
"""Exchange a refresh token for a new pair. None if invalid or the session is gone.
`session_validator(session_key) -> bool` lets the Django adapter enforce
immediate-logout revocation; omit it (or pass a always-True) where there is
no session store.
"""
payload = decode_token(refresh_token, config, expected_type="refresh")
if payload is None:
return None
if session_validator is not None and not session_validator(payload.session_key):
return None
return create_token_pair(payload.user_id, payload.session_key, config,
is_staff=payload.is_staff, is_superuser=payload.is_superuser)

View File

@@ -0,0 +1,52 @@
"""
Auth-guard evaluation — the adapter-agnostic core.
`enforce_auth` evaluates a function's `@client(auth=...)` requirement against an
`Identity` and raises `Unauthorized`/`Forbidden` on failure. A custom `auth=callable`
receives the adapter's NATIVE request (it may read request-specific state), passed
through opaquely — the core never introspects it.
"""
from __future__ import annotations
from typing import Any
from mizan_core.errors import Forbidden, InternalError, Unauthorized
from mizan_core.identity import Identity
def enforce_auth(
identity: Identity | None,
requirement: Any,
native_request: Any = None,
) -> None:
"""Raise `Unauthorized`/`Forbidden` if `identity` fails `requirement`; else return.
Requirement: None | True | "required" | "staff" | "superuser" | callable(native_request)->bool.
"""
if requirement is None:
return
if callable(requirement):
try:
if not requirement(native_request):
raise Forbidden("Access denied")
except PermissionError as e:
raise Forbidden(str(e) or "Access denied") from e
return
if not getattr(identity, "is_authenticated", False):
raise Unauthorized("Authentication required")
if requirement in (True, "required"):
return
if requirement == "staff":
if not getattr(identity, "is_staff", False):
raise Forbidden("Staff access required")
return
if requirement == "superuser":
if not getattr(identity, "is_superuser", False):
raise Forbidden("Superuser access required")
return
raise InternalError(f"Unknown auth requirement: {requirement!r}")

View File

@@ -0,0 +1,250 @@
"""
The adapter-agnostic dispatch core.
Both `dispatch_call` (mutations/RPC) and `dispatch_context` (bundled reads) run
the full protocol: auth → input validation → execute (`await view.acall`, which
threadpools sync handlers) → serialize → resolve invalidation/merge → orchestrate
origin cache. They return a `DispatchResult` the adapter renders to its native
response. Errors raise `MizanError` (the adapter catches at its boundary).
The adapter owns native request parsing (multipart/JSON) and native response
construction; it hands the core a `DispatchRequest` carrying only what the core
reads, and renders what the core returns.
"""
from __future__ import annotations
import json
from dataclasses import dataclass, field
from typing import Any, Literal
from pydantic import BaseModel, ValidationError
from pydantic_core import to_jsonable_python
from mizan_core.authguard import enforce_auth
from mizan_core.cache.backend import CacheBackend
from mizan_core.cache.keys import CONTEXT_KEY_PREFIX, derive_cache_key
from mizan_core.errors import (
BadRequest,
InternalError,
MizanError,
NotFound,
NotImplementedYet,
ValidationFailed,
)
from mizan_core.identity import Identity, user_id_of
from mizan_core.invalidation import (
format_invalidate_header,
resolve_invalidation,
resolve_merges,
)
from mizan_core.registry import get_context_groups, get_function
# ─── Request / result ───────────────────────────────────────────────────────
@dataclass
class DispatchRequest:
"""What the dispatch core reads. The adapter resolves `identity` (session OR
token) and parses `args`/`files`; `native_request` is an opaque passthrough
handed to `view_class(...)` and to `auth=callable`."""
identity: Identity | None = None
args: dict[str, Any] | None = None
files: dict[str, list[Any]] | None = None
native_request: Any = None
@dataclass
class DispatchResult:
kind: Literal["rpc", "view", "context"] = "rpc"
native_response: Any | None = None # view-path: the handler's own response
data: Any | None = None # rpc: serialized payload; context: bundle dict
body_bytes: bytes | None = None # context: canonical JSON to send/cache
cache_status: str | None = None # context: "HIT" | "MISS" | None
invalidate: list[Any] | None = None
merge: list[dict[str, Any]] | None = None
invalidate_header: str | None = None
# ─── Cache orchestration ────────────────────────────────────────────────────
class CacheOrchestrator:
"""Origin-side cache, backend + secret injected by the adapter (config seam)."""
def __init__(self, backend: CacheBackend | None, secret: str | None):
self.backend = backend
self.secret = secret
@property
def enabled(self) -> bool:
return self.backend is not None and bool(self.secret)
def get(self, context: str, params: dict[str, Any], user_id: str | None, rev: int) -> bytes | None:
if not self.enabled:
return None
return self.backend.get(derive_cache_key(self.secret, context, params, user_id, rev))
def put(self, context: str, params: dict[str, Any], value: bytes, user_id: str | None, rev: int) -> None:
if not self.enabled:
return
self.backend.set(derive_cache_key(self.secret, context, params, user_id, rev), value)
def purge(self, invalidate: list[Any], user_id: str | None) -> None:
if not self.enabled:
return
for entry in invalidate:
if isinstance(entry, str):
self.backend.delete_by_prefix(f"{CONTEXT_KEY_PREFIX}{entry}:")
elif isinstance(entry, dict):
ctx = entry["context"]
params = entry.get("params")
if params:
self.backend.delete(derive_cache_key(self.secret, ctx, params, user_id, 0))
else:
self.backend.delete_by_prefix(f"{CONTEXT_KEY_PREFIX}{ctx}:")
# ─── Shared dispatch helpers ────────────────────────────────────────────────
def _resolve_function(fn_name: str) -> Any:
view_class = get_function(fn_name)
if view_class is None:
raise NotFound("Function not found")
if getattr(view_class, "_meta", {}).get("private"):
from mizan_core.errors import Forbidden
raise Forbidden("Function is not client-callable")
return view_class
def _validate_input(input_cls: Any, input_data: Any) -> BaseModel | None:
"""Validate `input_data` against the function's Input model."""
if input_cls in (None, BaseModel) or not getattr(input_cls, "model_fields", None):
return None
required = [name for name, f in input_cls.model_fields.items() if f.is_required()]
if not input_data:
if required:
raise ValidationFailed(
"Input validation failed",
details={"fields": {name: ["Field required"] for name in required}},
)
return input_cls()
if not isinstance(input_data, dict):
raise BadRequest(f"Input must be an object, got {type(input_data).__name__}")
try:
return input_cls(**input_data)
except ValidationError as e:
raise ValidationFailed("Input validation failed", details={"errors": e.errors()}) from e
def _serialize(result: Any) -> Any:
return to_jsonable_python(result)
async def _run(view: Any, validated: Any) -> Any:
try:
return await view.acall(validated)
except NotImplementedError as e:
raise NotImplementedYet(str(e) or "Not implemented") from e
except MizanError:
raise
except Exception as e:
raise InternalError(str(e)) from e
def _canonical_bytes(data: Any) -> bytes:
return json.dumps(data, sort_keys=True).encode("utf-8")
# ─── Entry points ───────────────────────────────────────────────────────────
async def dispatch_call(req: DispatchRequest, fn_name: str, cache: CacheOrchestrator) -> DispatchResult:
"""Mutation / RPC dispatch."""
view_class = _resolve_function(fn_name)
meta = getattr(view_class, "_meta", {})
enforce_auth(req.identity, meta.get("auth"), req.native_request)
view = view_class(req.native_request)
validated = _validate_input(view.Input, req.args)
result = await _run(view, validated)
invalidate = resolve_invalidation(view_class, req.args)
header = format_invalidate_header(invalidate) if invalidate else None
if invalidate:
cache.purge(invalidate, user_id_of(req.identity))
if meta.get("view_path"):
# Handler returned its own native response; carry it through + the header.
return DispatchResult(kind="view", native_response=result,
invalidate=invalidate, invalidate_header=header)
serialized = _serialize(result)
return DispatchResult(
kind="rpc",
data=serialized,
invalidate=invalidate,
merge=resolve_merges(view_class, req.args, serialized),
invalidate_header=header,
)
def _effective_policy(fn_names: list[str]) -> tuple[int, int | bool]:
"""(effective_rev, effective_cache) across a context's functions."""
rev = 0
cache_policy: int | bool = True # True=forever, False=no-store, int=TTL
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
m = getattr(fn_cls, "_meta", {})
rev = max(rev, m.get("rev", 0))
fn_cache = m.get("cache", True)
if fn_cache is False:
return rev, False
if isinstance(fn_cache, int):
cache_policy = fn_cache if cache_policy is True else min(cache_policy, fn_cache)
return rev, cache_policy
async def dispatch_context(req: DispatchRequest, context_name: str, cache: CacheOrchestrator) -> DispatchResult:
"""Bundled context read with origin-cache get/put."""
groups = get_context_groups()
fn_names = groups.get(context_name)
if not fn_names:
raise NotFound(f"Context '{context_name}' not found")
params = req.args or {}
rev, cache_policy = _effective_policy(fn_names)
user_id = user_id_of(req.identity)
use_cache = cache.enabled and cache_policy is not False
if use_cache:
cached = cache.get(context_name, params, user_id, rev)
if cached is not None:
return DispatchResult(kind="context", body_bytes=cached, cache_status="HIT")
bundle: dict[str, Any] = {}
for fn_name in fn_names:
view_class = _resolve_function(fn_name)
enforce_auth(req.identity, getattr(view_class, "_meta", {}).get("auth"), req.native_request)
view = view_class(req.native_request)
validated = _validate_input(view.Input, {k: v for k, v in params.items() if _declares(view.Input, k)})
bundle[fn_name] = _serialize(await _run(view, validated))
body = _canonical_bytes(bundle)
if use_cache:
cache.put(context_name, params, body, user_id, rev)
return DispatchResult(kind="context", data=bundle, body_bytes=body,
cache_status="MISS" if use_cache else None)
def _declares(input_cls: Any, name: str) -> bool:
return bool(
input_cls and input_cls is not BaseModel
and getattr(input_cls, "model_fields", None)
and name in input_cls.model_fields
)

View File

@@ -0,0 +1,58 @@
"""
Canonical protocol-level error taxonomy.
Dispatch raises these typed exceptions; each backend adapter renders them to
its native response (Django `JsonResponse`, FastAPI exception handler, …). The
shared dispatch core never returns error envelopes — it raises, and the adapter
catches at its boundary.
"""
from __future__ import annotations
from enum import Enum
from typing import Any
class ErrorCode(str, Enum):
NOT_FOUND = "NOT_FOUND"
BAD_REQUEST = "BAD_REQUEST"
VALIDATION_ERROR = "VALIDATION_ERROR"
UNAUTHORIZED = "UNAUTHORIZED"
FORBIDDEN = "FORBIDDEN"
NOT_IMPLEMENTED = "NOT_IMPLEMENTED"
INTERNAL_ERROR = "INTERNAL_ERROR"
STATUS = {
ErrorCode.NOT_FOUND: 404,
ErrorCode.BAD_REQUEST: 400,
ErrorCode.VALIDATION_ERROR: 422,
ErrorCode.UNAUTHORIZED: 401,
ErrorCode.FORBIDDEN: 403,
ErrorCode.NOT_IMPLEMENTED: 501,
ErrorCode.INTERNAL_ERROR: 500,
}
class MizanError(Exception):
"""Base for protocol-level dispatch errors."""
code: ErrorCode = ErrorCode.INTERNAL_ERROR
def __init__(self, message: str, *, details: dict[str, Any] | None = None) -> None:
super().__init__(message)
self.message = message
self.details = details
@property
def status_code(self) -> int:
return STATUS[self.code]
class NotFound(MizanError): code = ErrorCode.NOT_FOUND # noqa: E701
class BadRequest(MizanError): code = ErrorCode.BAD_REQUEST # noqa: E701
class ValidationFailed(MizanError): code = ErrorCode.VALIDATION_ERROR # noqa: E701
class Unauthorized(MizanError): code = ErrorCode.UNAUTHORIZED # noqa: E701
class Forbidden(MizanError): code = ErrorCode.FORBIDDEN # noqa: E701
class NotImplementedYet(MizanError): code = ErrorCode.NOT_IMPLEMENTED # noqa: E701
class InternalError(MizanError): code = ErrorCode.INTERNAL_ERROR # noqa: E701

View File

@@ -0,0 +1,32 @@
"""
The minimal identity contract the dispatch core reads.
Auth-guard evaluation and per-user cache scoping need exactly these four
attributes — nothing about how the identity was established. Django's session
`User`, `JWTUser`, `MWTUser`, and any token-user an adapter constructs all
satisfy this structurally; no inheritance required.
`get_all_permissions()` (Django ORM) is deliberately NOT here — the MWT
permission-key is a Django-side concern, and adding it would force every
adapter to implement a Django-shaped method.
"""
from __future__ import annotations
from typing import Protocol, runtime_checkable
@runtime_checkable
class Identity(Protocol):
is_authenticated: bool
is_staff: bool
is_superuser: bool
@property
def pk(self) -> object | None: ... # str | int | None; cache scoping stringifies it
def user_id_of(identity: Identity | None) -> str | None:
"""The cache-scoping user id — `str(pk)`, or None for anonymous/no-pk."""
pk = getattr(identity, "pk", None)
return str(pk) if pk is not None else None

View File

@@ -0,0 +1,174 @@
"""
Server-driven invalidation + merge resolution — the adapter-agnostic core.
This is the canonical implementation (formerly housed in the Django executor).
Every adapter calls `resolve_invalidation` / `resolve_merges` / `format_invalidate_header`
so the wire shape is identical across backends.
Invalidation entries take one of two shapes:
- a bare context/function name string → broad invalidation
- {"context": name, "params": {...}} → scoped invalidation
Function-level `affects=` resolves to the function NAME as the key (v1 refetches
the whole context anyway).
"""
from __future__ import annotations
from typing import Any
from urllib.parse import quote
from pydantic import BaseModel
from mizan_core.registry import get_context_groups, get_function
from mizan_core.type_utils import types_match_for_merge
__all__ = ["resolve_invalidation", "resolve_merges", "format_invalidate_header"]
def _resolve_affects_target(target_name: str) -> tuple[str, str, str | None]:
"""Classify an affects target → ("context", name, None) | ("function", name, ctx)."""
groups = get_context_groups()
if target_name in groups:
return ("context", target_name, None)
for ctx_name, fn_names in groups.items():
if target_name in fn_names:
return ("function", target_name, ctx_name)
# Unknown — treat as a context name (non-context fn, or not-yet-registered).
return ("context", target_name, None)
def _context_param_names(context_name: str) -> set[str]:
"""Union of Input field names across the functions in a context."""
param_names: set[str] = set()
for fn_name in get_context_groups().get(context_name, []):
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
param_names.update(input_cls.model_fields.keys())
return param_names
def resolve_invalidation(
view_class: type | None,
input_data: dict[str, Any] | None = None,
) -> list[str | dict[str, Any]] | None:
"""Three-tier auto-scoping over `@client(affects=...)`. None if nothing to invalidate.
Tier 1: arg-name matching against the context's params → scoped entry.
Tier 2: auth inference — Edge-side, not handled here.
Tier 3: broad fallback — bare name.
"""
if view_class is None:
return None
affects = getattr(view_class, "_meta", {}).get("affects")
if not affects:
return None
result: list[str | dict[str, Any]] = []
seen: set[str] = set()
for target in affects:
if target["type"] == "context":
target_name = target["name"]
elif target["type"] == "function" and target.get("context"):
target_name = target["name"]
else:
continue
if target_name in seen:
continue
seen.add(target_name)
resolved = _resolve_affects_target(target_name)
ctx_for_params = resolved[2] if resolved[0] == "function" else resolved[1]
if input_data and ctx_for_params:
context_params = _context_param_names(ctx_for_params)
matched = {k: v for k, v in input_data.items() if k in context_params}
if matched:
result.append({"context": target_name, "params": matched})
continue
result.append(target_name)
return result or None
def _resolve_merge_slot(context_name: str, mutation_output: Any) -> str | None:
"""The unique function-name slot in a context whose return type matches the mutation output."""
if mutation_output is None:
return None
matches: list[str] = []
for fn_name in get_context_groups().get(context_name, []):
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
fn_output = getattr(fn_cls, "Output", None)
if fn_output is not None and types_match_for_merge(fn_output, mutation_output):
matches.append(fn_name)
return matches[0] if len(matches) == 1 else None
def resolve_merges(
view_class: type | None,
input_data: dict[str, Any] | None,
result_data: Any,
) -> list[dict[str, Any]] | None:
"""Build the `merge` list from `@client(merge=...)`. None when no targets resolve.
Each entry is `{context, slot, value, params?}`; `slot` is the context-function
whose return type matches the mutation output (server-side type-checked routing,
no client shape inference). Ambiguous/unmatched targets are dropped.
"""
if view_class is None:
return None
targets = getattr(view_class, "_meta", {}).get("merge") or []
if not targets:
return None
mutation_output = getattr(view_class, "Output", None)
out: list[dict[str, Any]] = []
seen: set[str] = set()
for ctx_name in targets:
if ctx_name in seen:
continue
seen.add(ctx_name)
slot = _resolve_merge_slot(ctx_name, mutation_output)
if slot is None:
continue
entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result_data}
if input_data:
matched = {k: v for k, v in input_data.items() if k in _context_param_names(ctx_name)}
if matched:
entry["params"] = matched
out.append(entry)
return out or None
def format_invalidate_header(invalidate: list[str | dict[str, Any]]) -> str:
"""Serialize invalidation targets to the `X-Mizan-Invalidate` header value.
Comma-separated contexts; semicolon-separated URL-encoded params per context.
["user"] → "user"
["user", "notifications"] → "user, notifications"
[{"context": "user", "params": {"user_id": 5}}] → "user;user_id=5"
[{"context": "search", "params": {"q": "hello world"}}] → "search;q=hello%20world"
"""
parts: list[str] = []
for entry in invalidate:
if isinstance(entry, str):
parts.append(entry)
elif isinstance(entry, dict):
ctx = entry["context"]
params = entry.get("params", {})
if params:
param_str = ";".join(
f"{quote(str(k), safe='')}={quote(str(v), safe='')}"
for k, v in sorted(params.items())
)
parts.append(f"{ctx};{param_str}")
else:
parts.append(ctx)
return ", ".join(parts)

View File

@@ -0,0 +1,147 @@
"""Unit tests for the adapter-agnostic dispatch core."""
import asyncio
import pytest
from pydantic import BaseModel
from mizan_core.auth import AuthConfig, JWTConfig, INVALID, authenticate, create_access_token
from mizan_core.authguard import enforce_auth
from mizan_core.client.function import client
from mizan_core.dispatch import CacheOrchestrator, DispatchRequest, dispatch_call
from mizan_core.errors import Forbidden, Unauthorized
from mizan_core.invalidation import format_invalidate_header, resolve_invalidation
from mizan_core.registry import clear_registry, register
class Ident:
def __init__(self, authed=True, staff=False, su=False, pk=1):
self.is_authenticated = authed
self.is_staff = staff
self.is_superuser = su
self.pk = pk
# ─── authguard ──────────────────────────────────────────────────────────────
def test_auth_required_anonymous():
with pytest.raises(Unauthorized):
enforce_auth(None, True)
def test_auth_required_authenticated():
enforce_auth(Ident(), True) # no raise
def test_auth_staff_denied_then_allowed():
with pytest.raises(Forbidden):
enforce_auth(Ident(staff=False), "staff")
enforce_auth(Ident(staff=True), "staff")
def test_auth_superuser():
with pytest.raises(Forbidden):
enforce_auth(Ident(su=False), "superuser")
enforce_auth(Ident(su=True), "superuser")
def test_auth_callable_false_and_raise():
with pytest.raises(Forbidden):
enforce_auth(Ident(), lambda r: False)
with pytest.raises(Forbidden, match="custom"):
enforce_auth(Ident(), lambda r: (_ for _ in ()).throw(PermissionError("custom")))
# ─── authenticate / INVALID sentinel ────────────────────────────────────────
def _cfg():
return AuthConfig(jwt=JWTConfig(private_key="k" * 32, public_key="k" * 32))
def test_authenticate_jwt_ok():
cfg = _cfg()
tok = create_access_token("7", "sess", cfg.jwt, is_staff=True)
ident = authenticate({"Authorization": f"Bearer {tok}"}, cfg)
assert ident.pk == 7 and ident.is_staff and ident.is_authenticated
def test_authenticate_bad_token_is_invalid_sentinel():
assert authenticate({"Authorization": "Bearer garbage"}, _cfg()) is INVALID
def test_authenticate_no_token_is_none():
assert authenticate({}, _cfg()) is None
# ─── invalidation + header ──────────────────────────────────────────────────
def test_invalidation_three_tier_and_header():
clear_registry()
UserCtx = "user"
class Out(BaseModel):
ok: bool
@client(context=UserCtx)
def user_profile(request, user_id: int) -> Out:
return Out(ok=True)
@client(affects=UserCtx)
def update_profile(request, user_id: int, name: str) -> Out:
return Out(ok=True)
register(user_profile, "user_profile")
register(update_profile, "update_profile")
# Tier 1: user_id matches context param → scoped
inv = resolve_invalidation(update_profile, {"user_id": 5, "name": "x"})
assert inv == [{"context": "user", "params": {"user_id": 5}}]
assert format_invalidate_header(inv) == "user;user_id=5"
# Tier 3: no matching param → broad
inv2 = resolve_invalidation(update_profile, {"name": "x"})
assert inv2 == ["user"]
clear_registry()
# ─── dispatch_call end to end ───────────────────────────────────────────────
def test_dispatch_call_auth_and_invalidation():
clear_registry()
class Out(BaseModel):
ok: bool
@client(context="user")
def user_profile(request, user_id: int) -> Out:
return Out(ok=True)
@client(affects="user", auth="staff")
def secure_update(request, user_id: int) -> Out:
return Out(ok=True)
register(user_profile, "user_profile")
register(secure_update, "secure_update")
cache = CacheOrchestrator(None, None)
# non-staff rejected
with pytest.raises(Forbidden):
asyncio.run(dispatch_call(
DispatchRequest(identity=Ident(staff=False), args={"user_id": 1}),
"secure_update", cache,
))
# staff passes, invalidation resolved
res = asyncio.run(dispatch_call(
DispatchRequest(identity=Ident(staff=True), args={"user_id": 1}),
"secure_update", cache,
))
assert res.kind == "rpc" and res.data == {"ok": True}
assert res.invalidate == [{"context": "user", "params": {"user_id": 1}}]
assert res.invalidate_header == "user;user_id=1"
clear_registry()