FastAPI and TypeScript improved

This commit is contained in:
2026-06-04 05:14:29 -04:00
parent 67ad91b673
commit 66b2db81fb
28 changed files with 1864 additions and 717 deletions

View File

@@ -30,6 +30,9 @@ from pydantic import BaseModel, ValidationError
from mizan.cache import get_cache, cache_get, cache_put, cache_purge
from mizan_core.registry import get_function, get_context_groups
from mizan_core.upload import UploadedFile, bind_uploads
from mizan_core import invalidation as _core_inval
from mizan_core.authguard import enforce_auth as _core_enforce_auth
from mizan_core.errors import MizanError as _CoreMizanError
from mizan.setup.settings import get_settings
if TYPE_CHECKING:
@@ -113,53 +116,14 @@ def _check_auth_requirement(
Django User (from session). Either way, no additional DB query is made
for the built-in checks. Custom callables may query DB if they choose.
"""
if auth_requirement is None:
# Evaluation lives in the shared core (mizan_core.authguard); the callable
# path receives the native Django request. Core raises; we render to the
# Django-shim FunctionError shape the executor expects.
try:
_core_enforce_auth(getattr(request, "user", None), auth_requirement, request)
return None
user = request.user
# Handle callable auth
if callable(auth_requirement):
try:
result = auth_requirement(request)
if result:
return None # Authorized
else:
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Access denied",
)
except PermissionError as e:
# Custom error message from the callable
return FunctionError(
code=ErrorCode.FORBIDDEN,
message=str(e) or "Access denied",
)
# Check authentication (required for all string-based auth)
if not getattr(user, "is_authenticated", False):
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Authentication required",
)
# Check staff requirement
if auth_requirement == "staff":
if not getattr(user, "is_staff", False):
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Staff access required",
)
# Check superuser requirement
elif auth_requirement == "superuser":
if not getattr(user, "is_superuser", False):
return FunctionError(
code=ErrorCode.FORBIDDEN,
message="Superuser access required",
)
return None
except _CoreMizanError as e:
return FunctionError(code=ErrorCode(e.code.value), message=e.message)
_cache_log = logging.getLogger("mizan.cache")
@@ -198,51 +162,6 @@ def _purge_cache_for_invalidation(
_cache_log.warning("Cache purge failed", exc_info=True)
def _resolve_affects_target(target_name: str) -> tuple[str, str, str | None]:
"""
Determine whether an affects target is a context name or function name.
Returns:
("context", "user", None) — full context invalidation
("function", "user_profile", "user") — function within context
"""
groups = get_context_groups()
# Check if it's a context name directly
if target_name in groups:
return ("context", target_name, None)
# Check if it's a function name within a context
for ctx_name, fn_names in groups.items():
if target_name in fn_names:
return ("function", target_name, ctx_name)
# Not a context or context function — treat as context name anyway
# (it might be a non-context function or an as-yet-unregistered context)
return ("context", target_name, None)
def _get_context_param_names(context_name: str) -> set[str]:
"""
Get the set of parameter names used by functions in a context.
Returns the union of all Input field names across context functions.
"""
groups = get_context_groups()
fn_names = groups.get(context_name, [])
param_names: set[str] = set()
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
param_names.update(input_cls.model_fields.keys())
return param_names
def _resolve_invalidation(
view_class: type | None,
input_data: dict[str, Any] | None = None,
@@ -261,49 +180,7 @@ def _resolve_invalidation(
Returns a list suitable for both JSON body and header serialization.
Returns None if no invalidation needed.
"""
if view_class is None:
return None
meta = getattr(view_class, "_meta", {})
affects = meta.get("affects")
if not affects:
return None
result = []
seen = set()
for target in affects:
if target["type"] == "context":
target_name = target["name"]
elif target["type"] == "function" and target.get("context"):
# Function-level: use the function name as the invalidation key
target_name = target["name"]
else:
continue
if target_name in seen:
continue
seen.add(target_name)
# Resolve the context this target belongs to (for param lookup)
resolved = _resolve_affects_target(target_name)
ctx_for_params = resolved[2] if resolved[0] == "function" else resolved[1]
# Tier 1: argument name matching
if input_data and ctx_for_params:
context_params = _get_context_param_names(ctx_for_params)
matched = {
k: v for k, v in input_data.items()
if k in context_params
}
if matched:
result.append({"context": target_name, "params": matched})
continue
# Tier 3: broad fallback
result.append(target_name)
return result if result else None
return _core_inval.resolve_invalidation(view_class, input_data)
def _resolve_merges(
@@ -322,94 +199,12 @@ def _resolve_merges(
Mirrors _resolve_invalidation's tier-1 auto-scoping for params.
Entries whose slot can't be uniquely resolved are dropped.
"""
if view_class is None:
return None
from mizan_core.type_utils import types_match_for_merge
meta = getattr(view_class, "_meta", {})
targets = meta.get("merge") or []
if not targets:
return None
mutation_output = getattr(view_class, "Output", None)
out: list[dict[str, Any]] = []
seen: set[str] = set()
for ctx_name in targets:
if ctx_name in seen:
continue
seen.add(ctx_name)
slot = _resolve_merge_slot(ctx_name, mutation_output, types_match_for_merge)
if slot is None:
continue
entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result_data}
if input_data:
context_params = _get_context_param_names(ctx_name)
matched = {
k: v for k, v in input_data.items()
if k in context_params
}
if matched:
entry["params"] = matched
out.append(entry)
return out
return _core_inval.resolve_merges(view_class, input_data, result_data)
def _resolve_merge_slot(context_name: str, mutation_output: Any, type_matcher: Any) -> str | None:
"""Find the unique function-name slot in context whose return type matches mutation's output."""
if mutation_output is None:
return None
groups = get_context_groups()
fn_names = groups.get(context_name, [])
matches: list[str] = []
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
fn_output = getattr(fn_cls, "Output", None)
if fn_output is not None and type_matcher(fn_output, mutation_output):
matches.append(fn_name)
return matches[0] if len(matches) == 1 else None
def _format_invalidate_header(
invalidate: list[str | dict[str, Any]],
) -> str:
"""
Format invalidation targets as X-Mizan-Invalidate header value.
Format: comma-separated contexts. Semicolon-separated params per context.
Param values are URL-encoded to prevent delimiter collisions.
Examples:
["user"] → "user"
["user", "notifications"] → "user, notifications"
[{"context": "user", "params": {"user_id": 5}}]
"user;user_id=5"
[{"context": "search", "params": {"q": "hello world"}}]
"search;q=hello%20world"
"""
from urllib.parse import quote
parts = []
for entry in invalidate:
if isinstance(entry, str):
parts.append(entry)
elif isinstance(entry, dict):
ctx = entry["context"]
params = entry.get("params", {})
if params:
param_str = ";".join(
f"{quote(str(k), safe='')}={quote(str(v), safe='')}"
for k, v in sorted(params.items())
)
parts.append(f"{ctx};{param_str}")
else:
parts.append(ctx)
return ", ".join(parts)
def _format_invalidate_header(invalidate: list[str | dict[str, Any]]) -> str:
"""Format invalidation targets as the X-Mizan-Invalidate header value (shared core)."""
return _core_inval.format_invalidate_header(invalidate)
def execute_function(

View File

@@ -1,245 +1,79 @@
"""
JWT Token Creation and Validation
JWT tokens — the Django adapter over the shared core (`mizan_core.auth.jwt`).
Uses PyJWT directly - no allauth dependency.
Tokens are tied to Django sessions for immediate revocation on logout.
The token logic (mint/decode/refresh, `JWTUser`, `TokenPair`, `TokenPayload`)
lives in the core; this module binds it to Django settings and keeps the
session-revocation check (`validate_session`), which is Django-session-specific.
"""
import time
from typing import NamedTuple
from __future__ import annotations
import jwt
from django.contrib.sessions.backends.base import SessionBase
from mizan_core.auth import jwt as _core_jwt
from mizan_core.auth.jwt import JWTConfig, JWTUser, TokenPair, TokenPayload
from .settings import get_settings
class TokenPair(NamedTuple):
"""Access and refresh token pair."""
access_token: str
refresh_token: str
expires_in: int
__all__ = [
"TokenPair",
"TokenPayload",
"JWTUser",
"create_access_token",
"create_refresh_token",
"create_token_pair",
"decode_token",
"validate_session",
"refresh_tokens",
]
class TokenPayload(NamedTuple):
"""Decoded token payload."""
user_id: int | str
session_key: str
token_type: str
is_staff: bool
is_superuser: bool
exp: int
iat: int
class JWTUser:
"""
Minimal user object created from JWT claims.
Used as request.user for JWT-authenticated requests.
No database query required - all data comes from the token.
If you need the full User object with all fields, query explicitly:
user = User.objects.get(pk=request.user.id)
"""
def __init__(self, payload: TokenPayload):
self.id = int(payload.user_id) if isinstance(payload.user_id, str) else payload.user_id
self.pk = self.id
self.is_staff = payload.is_staff
self.is_superuser = payload.is_superuser
self.is_authenticated = True
self.is_anonymous = False
self.is_active = True # Assumed active if they have a valid token
def __str__(self):
return f"JWTUser(id={self.id})"
def __repr__(self):
return f"JWTUser(id={self.id}, is_staff={self.is_staff}, is_superuser={self.is_superuser})"
def create_access_token(
user_id: int | str,
session_key: str,
*,
is_staff: bool = False,
is_superuser: bool = False,
) -> str:
"""
Create a short-lived access token.
The token contains:
- sub: user ID
- sid: session key (for revocation checking)
- staff: is_staff flag
- super: is_superuser flag
- type: "access"
- iat: issued at
- exp: expiration
"""
settings = get_settings()
now = int(time.time())
payload = {
"sub": str(user_id),
"sid": session_key,
"staff": is_staff,
"super": is_superuser,
"type": "access",
"iat": now,
"exp": now + settings.access_token_expires_in,
}
return jwt.encode(
payload,
settings.private_key,
algorithm=settings.algorithm,
def _config() -> JWTConfig:
s = get_settings()
return JWTConfig(
private_key=s.private_key,
public_key=s.public_key,
algorithm=s.algorithm,
access_token_expires_in=s.access_token_expires_in,
refresh_token_expires_in=s.refresh_token_expires_in,
)
def create_refresh_token(
user_id: int | str,
session_key: str,
*,
is_staff: bool = False,
is_superuser: bool = False,
) -> str:
"""
Create a longer-lived refresh token.
The token contains:
- sub: user ID
- sid: session key (for revocation checking)
- staff: is_staff flag
- super: is_superuser flag
- type: "refresh"
- iat: issued at
- exp: expiration
"""
settings = get_settings()
now = int(time.time())
payload = {
"sub": str(user_id),
"sid": session_key,
"staff": is_staff,
"super": is_superuser,
"type": "refresh",
"iat": now,
"exp": now + settings.refresh_token_expires_in,
}
return jwt.encode(
payload,
settings.private_key,
algorithm=settings.algorithm,
)
def create_access_token(user_id, session_key, *, is_staff=False, is_superuser=False) -> str:
return _core_jwt.create_access_token(user_id, session_key, _config(),
is_staff=is_staff, is_superuser=is_superuser)
def create_token_pair(
user_id: int | str,
session_key: str,
*,
is_staff: bool = False,
is_superuser: bool = False,
) -> TokenPair:
"""Create both access and refresh tokens."""
settings = get_settings()
return TokenPair(
access_token=create_access_token(
user_id, session_key, is_staff=is_staff, is_superuser=is_superuser
),
refresh_token=create_refresh_token(
user_id, session_key, is_staff=is_staff, is_superuser=is_superuser
),
expires_in=settings.access_token_expires_in,
)
def create_refresh_token(user_id, session_key, *, is_staff=False, is_superuser=False) -> str:
return _core_jwt.create_refresh_token(user_id, session_key, _config(),
is_staff=is_staff, is_superuser=is_superuser)
def decode_token(token: str, expected_type: str = None) -> TokenPayload | None:
"""
Decode and validate a JWT token.
def create_token_pair(user_id, session_key, *, is_staff=False, is_superuser=False) -> TokenPair:
return _core_jwt.create_token_pair(user_id, session_key, _config(),
is_staff=is_staff, is_superuser=is_superuser)
Returns None if:
- Token is invalid or expired
- Token type doesn't match expected_type (if specified)
"""
settings = get_settings()
try:
payload = jwt.decode(
token,
settings.public_key,
algorithms=[settings.algorithm],
)
except jwt.PyJWTError:
return None
# Validate token type if specified
if expected_type and payload.get("type") != expected_type:
return None
return TokenPayload(
user_id=payload["sub"],
session_key=payload["sid"],
token_type=payload["type"],
is_staff=payload.get("staff", False),
is_superuser=payload.get("super", False),
exp=payload["exp"],
iat=payload["iat"],
)
def decode_token(token: str, expected_type: str | None = None) -> TokenPayload | None:
return _core_jwt.decode_token(token, _config(), expected_type=expected_type)
def validate_session(session_key: str) -> bool:
"""
Check if a session is still valid (exists and not expired).
"""Immediate-logout revocation: is this Django session still alive?
This is the key to immediate logout revocation - if the session
is destroyed, tokens tied to it become invalid.
Honors `JWT_VALIDATE_SESSION` — when disabled, always True. This is the one
Django-session-bound piece; the core's `refresh_tokens` takes it as an
injected `session_validator`.
"""
from importlib import import_module
from django.conf import settings as django_settings
jwt_settings = get_settings()
if not jwt_settings.validate_session:
if not get_settings().validate_session:
return True
# Use the configured session engine
engine = import_module(django_settings.SESSION_ENGINE)
SessionStore = engine.SessionStore
# Try to load the session
session = SessionStore(session_key=session_key)
# Check if session exists and is not empty
# exists() is more reliable than checking load() result
session = engine.SessionStore(session_key=session_key)
return session.exists(session_key)
def refresh_tokens(refresh_token: str) -> TokenPair | None:
"""
Use a refresh token to obtain new tokens.
Returns None if:
- Refresh token is invalid or expired
- Associated session no longer exists
"""
payload = decode_token(refresh_token, expected_type="refresh")
if payload is None:
return None
# Validate the session still exists
if not validate_session(payload.session_key):
return None
# Issue new token pair with same claims
return create_token_pair(
payload.user_id,
payload.session_key,
is_staff=payload.is_staff,
is_superuser=payload.is_superuser,
)
return _core_jwt.refresh_tokens(refresh_token, _config(), session_validator=validate_session)

View File

@@ -170,8 +170,8 @@ class HTTPAuthTests(TestCase):
def test_jwt_expired_with_session(self):
"""Expired JWT with valid session → Reject (do NOT fall back)."""
# Create token with past expiration by mocking time
with patch("mizan.jwt.tokens.time.time", return_value=0):
# Create token with past expiration by mocking time (minting lives in the core now)
with patch("mizan_core.auth.jwt.time.time", return_value=0):
tokens = create_token_pair(
self.user.pk,
self.session_key,

View File

@@ -35,12 +35,18 @@ from .executor import (
execute_function,
)
from .router import router, mizan_exception_handler, mizan_validation_handler
from .auth import MizanAuthMiddleware, mizan_auth
from .config import MizanConfig, from_env
from mizan_core.upload import File, Upload, UploadedFile
__all__ = [
"Upload",
"File",
"UploadedFile",
"mizan_auth",
"MizanAuthMiddleware",
"MizanConfig",
"from_env",
"router",
"mizan_exception_handler",
"mizan_validation_handler",

View File

@@ -0,0 +1,54 @@
"""
Built-in identity for FastAPI — Django-equivalent automatic `request.state.user`.
Opt in via `Depends(mizan_auth())` on a route/router, or mount `MizanAuthMiddleware`
app-wide. Both decode a bearer-JWT (`Authorization: Bearer`) or MWT (`X-Mizan-Token`)
via the shared core and set `request.state.user`. A present-but-invalid token is
rejected (401) rather than silently downgraded — the `INVALID` sentinel contract.
If you'd rather resolve identity yourself, set `request.state.user` upstream and skip
these; dispatch reads it directly.
"""
from __future__ import annotations
from typing import Callable
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware
from mizan_core.auth import INVALID, authenticate
from mizan_core.errors import Unauthorized
from .config import get_config
def _resolve(request: Request) -> None:
ident = authenticate(request.headers, get_config(request).auth)
if ident is INVALID:
raise Unauthorized("Invalid or expired token")
if ident is not None:
request.state.user = ident
def mizan_auth() -> Callable:
"""FastAPI dependency that populates `request.state.user` from a token."""
async def _dep(request: Request) -> None:
_resolve(request)
return _dep
class MizanAuthMiddleware(BaseHTTPMiddleware):
"""App-wide variant of `mizan_auth` — resolves identity on every request."""
async def dispatch(self, request, call_next):
try:
_resolve(request)
except Unauthorized:
from .router import _no_store
from mizan_core.errors import ErrorCode
return _no_store(
{"error": {"code": ErrorCode.UNAUTHORIZED.value, "message": "Invalid or expired token"}},
status_code=401,
)
return await call_next(request)

View File

@@ -0,0 +1,80 @@
"""
FastAPI configuration — the "no settings.py" seam.
Builds the shared core's `AuthConfig` (JWT + MWT) and a `CacheOrchestrator`
from environment variables, overridable per-app via `app.state.mizan_config`.
Env:
MIZAN_CACHE_SECRET HMAC cache signing key (enables origin cache)
MIZAN_CACHE_REDIS_URL Redis URL (else in-memory cache)
MIZAN_MWT_SECRET MWT signing key
MIZAN_MWT_AUDIENCE MWT audience (default "mizan")
JWT_PRIVATE_KEY JWT signing key (enables bearer-JWT auth)
JWT_PUBLIC_KEY JWT verify key (default: private key, HS256)
JWT_ALGORITHM default "HS256"
JWT_ACCESS_TOKEN_EXPIRES_IN / JWT_REFRESH_TOKEN_EXPIRES_IN
"""
from __future__ import annotations
import os
from dataclasses import dataclass
from mizan_core.auth import AuthConfig, JWTConfig
from mizan_core.cache.backend import CacheBackend, MemoryCache
from mizan_core.dispatch import CacheOrchestrator
@dataclass(frozen=True)
class MizanConfig:
auth: AuthConfig
cache: CacheOrchestrator
def _cache_backend(secret: str | None, redis_url: str | None) -> CacheBackend | None:
if not secret:
return None
if redis_url:
from mizan_core.cache.backend import RedisCache
return RedisCache(redis_url)
return MemoryCache()
def _jwt_config() -> JWTConfig | None:
key = os.getenv("JWT_PRIVATE_KEY")
if not key:
return None
return JWTConfig(
private_key=key,
public_key=os.getenv("JWT_PUBLIC_KEY", key),
algorithm=os.getenv("JWT_ALGORITHM", "HS256"),
access_token_expires_in=int(os.getenv("JWT_ACCESS_TOKEN_EXPIRES_IN", "300")),
refresh_token_expires_in=int(os.getenv("JWT_REFRESH_TOKEN_EXPIRES_IN", "604800")),
)
def from_env() -> MizanConfig:
secret = os.getenv("MIZAN_CACHE_SECRET")
backend = _cache_backend(secret, os.getenv("MIZAN_CACHE_REDIS_URL"))
auth = AuthConfig(
jwt=_jwt_config(),
mwt_secret=os.getenv("MIZAN_MWT_SECRET"),
mwt_audience=os.getenv("MIZAN_MWT_AUDIENCE", "mizan"),
)
return MizanConfig(auth=auth, cache=CacheOrchestrator(backend, secret))
def get_config(request) -> MizanConfig:
"""Per-app config: `app.state.mizan_config` if set, else built from env (cached)."""
app = getattr(request, "app", None)
state = getattr(app, "state", None)
override = getattr(state, "mizan_config", None) if state is not None else None
if override is not None:
return override
global _DEFAULT
if _DEFAULT is None:
_DEFAULT = from_env()
return _DEFAULT
_DEFAULT: MizanConfig | None = None

View File

@@ -1,263 +1,69 @@
"""
RPC dispatch — looks up registered functions, validates input against the
function's Pydantic Input model, executes, and returns the serialized result.
Dispatch — a thin shim over the shared core (`mizan_core.dispatch`).
Errors raise typed exceptions (MizanError subclasses). Wire those to JSON
responses by registering `mizan_exception_handler` on the FastAPI app, or
let them propagate to your own handler.
The protocol machinery (auth, validation, execution, invalidation, merge, cache)
lives in `mizan_core`; this module re-exports the canonical error taxonomy and
keeps backward-compatible helpers. The router drives `dispatch_call` /
`dispatch_context` directly to get invalidation + origin cache.
"""
from __future__ import annotations
from enum import Enum
from typing import Any
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, ValidationError
from mizan_core.dispatch import CacheOrchestrator, DispatchRequest, dispatch_call
from mizan_core.errors import (
BadRequest,
ErrorCode,
Forbidden,
InternalError,
MizanError,
NotFound,
NotImplementedYet,
Unauthorized,
ValidationFailed,
)
from mizan_core.invalidation import resolve_invalidation, resolve_merges
from mizan_core.registry import get_context_groups, get_function
from mizan_core.type_utils import types_match_for_merge
__all__ = [
"ErrorCode",
"MizanError",
"NotFound",
"BadRequest",
"ValidationFailed",
"Unauthorized",
"Forbidden",
"NotImplementedYet",
"InternalError",
"compute_invalidation",
"compute_merges",
"execute_function",
]
# ─── Error taxonomy ─────────────────────────────────────────────────────────
class ErrorCode(str, Enum):
NOT_FOUND = "NOT_FOUND"
BAD_REQUEST = "BAD_REQUEST"
VALIDATION_ERROR = "VALIDATION_ERROR"
UNAUTHORIZED = "UNAUTHORIZED"
FORBIDDEN = "FORBIDDEN"
NOT_IMPLEMENTED = "NOT_IMPLEMENTED"
INTERNAL_ERROR = "INTERNAL_ERROR"
_STATUS = {
ErrorCode.NOT_FOUND: 404,
ErrorCode.BAD_REQUEST: 400,
ErrorCode.VALIDATION_ERROR: 422,
ErrorCode.UNAUTHORIZED: 401,
ErrorCode.FORBIDDEN: 403,
ErrorCode.NOT_IMPLEMENTED: 501,
ErrorCode.INTERNAL_ERROR: 500,
}
class MizanError(Exception):
"""Base for protocol-level dispatch errors."""
code: ErrorCode = ErrorCode.INTERNAL_ERROR
def __init__(self, message: str, *, details: dict[str, Any] | None = None) -> None:
super().__init__(message)
self.message = message
self.details = details
@property
def status_code(self) -> int:
return _STATUS[self.code]
class NotFound(MizanError): code = ErrorCode.NOT_FOUND # noqa: E701
class BadRequest(MizanError): code = ErrorCode.BAD_REQUEST # noqa: E701
class ValidationFailed(MizanError): code = ErrorCode.VALIDATION_ERROR # noqa: E701
class Unauthorized(MizanError): code = ErrorCode.UNAUTHORIZED # noqa: E701
class Forbidden(MizanError): code = ErrorCode.FORBIDDEN # noqa: E701
class NotImplementedYet(MizanError): code = ErrorCode.NOT_IMPLEMENTED # noqa: E701
class InternalError(MizanError): code = ErrorCode.INTERNAL_ERROR # noqa: E701
# ─── Auth ───────────────────────────────────────────────────────────────────
def _user(request: Any) -> Any:
return getattr(getattr(request, "state", None), "user", None)
def _is_authenticated(user: Any) -> bool:
return bool(user) and getattr(user, "is_authenticated", True)
def _enforce_auth(request: Any, requirement: Any) -> None:
"""Verify the request meets the function's @client(auth=...) requirement, or raise."""
if requirement is None:
return
user = _user(request)
match requirement:
case True | "required":
if not _is_authenticated(user):
raise Unauthorized("Authentication required")
case "staff":
if not _is_authenticated(user):
raise Unauthorized("Authentication required")
if not getattr(user, "is_staff", False):
raise Forbidden("Staff access required")
case "superuser":
if not _is_authenticated(user):
raise Unauthorized("Authentication required")
if not getattr(user, "is_superuser", False):
raise Forbidden("Superuser access required")
case f if callable(f):
if not f(request):
raise Forbidden("Permission denied")
case other:
raise InternalError(f"Unknown auth requirement: {other!r}")
# ─── Input validation ───────────────────────────────────────────────────────
def _validate_input(input_cls: Any, input_data: Any) -> BaseModel | None:
"""Validate input_data against the function's Input model. Returns the instance or None."""
if input_cls in (None, BaseModel) or not getattr(input_cls, "model_fields", None):
return None
fields = input_cls.model_fields
required = [name for name, f in fields.items() if f.is_required()]
if not input_data:
if required:
raise ValidationFailed(
"Input validation failed",
details={"fields": {name: ["Field required"] for name in required}},
)
return input_cls()
if not isinstance(input_data, dict):
raise BadRequest(f"Input must be an object, got {type(input_data).__name__}")
try:
return input_cls(**input_data)
except ValidationError as e:
raise ValidationFailed(
"Input validation failed",
details={"errors": e.errors()},
) from e
# ─── Dispatch ───────────────────────────────────────────────────────────────
def _resolve_function(fn_name: str) -> Any:
view_class = get_function(fn_name)
if view_class is None:
raise NotFound("Function not found")
if getattr(view_class, "_meta", {}).get("private"):
raise Forbidden("Function is not client-callable")
return view_class
def _serialize(result: Any) -> Any:
# jsonable_encoder walks BaseModel / list / dict recursively, so list[BaseModel]
# (and nested shapes) come out wire-ready without a per-shape branch here.
return jsonable_encoder(result)
async def execute_function(
request: Any,
fn_name: str,
input_data: dict[str, Any] | None = None,
) -> Any:
"""Dispatch a registered function. Returns the serialized result, or raises MizanError.
Awaits `view.acall` — async handlers run on the loop, sync handlers run
in the default threadpool, both via the same entrypoint.
"""
view_class = _resolve_function(fn_name)
_enforce_auth(request, view_class._meta.get("auth"))
view = view_class(request)
validated = _validate_input(view.Input, input_data)
try:
result = await view.acall(validated)
except NotImplementedError as e:
raise NotImplementedYet(str(e) or "Not implemented") from e
except MizanError:
raise
except Exception as e:
raise InternalError(str(e)) from e
return _serialize(result)
# ─── Invalidation ───────────────────────────────────────────────────────────
_NO_CACHE = CacheOrchestrator(None, None)
def compute_invalidation(view_class: Any, input_data: dict[str, Any] | None) -> list[Any]:
"""Build the `invalidate` list from @client(affects=...) metadata, auto-scoping when arg names match context params."""
affects = getattr(view_class, "_meta", {}).get("affects") or []
return [_invalidation_target(target, input_data or {}) for target in affects]
"""`@client(affects=...)` → invalidation list (empty when none). Shared core."""
return resolve_invalidation(view_class, input_data) or []
def compute_merges(view_class: Any, input_data: dict[str, Any] | None, result: Any) -> list[dict[str, Any]]:
"""Build the `merge` list from @client(merge=...) metadata.
"""`@client(merge=...)` → merge list (empty when none). Shared core."""
return resolve_merges(view_class, input_data, result) or []
Each entry is `{context, slot, value, params?}` where `slot` names the
function inside the context bundle the value lands in. The slot is
resolved server-side via `types_match_for_merge` so the kernel does
no shape inference — the server has the schema, type-checked routing
lives here. Entries whose slot can't be uniquely resolved are dropped
with a warning; the consumer falls back to refetch via `affects`.
async def execute_function(request: Any, fn_name: str, input_data: dict[str, Any] | None = None) -> Any:
"""Dispatch a function and return its serialized result (auth enforced via core).
Backward-compat entry point; the router uses `dispatch_call` directly to also
capture invalidation/merge and run the origin cache.
"""
targets = getattr(view_class, "_meta", {}).get("merge") or []
if not targets:
return []
mutation_output = getattr(view_class, "Output", None)
out: list[dict[str, Any]] = []
for ctx_name in targets:
slot = _resolve_merge_slot(ctx_name, mutation_output)
if slot is None:
continue
entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result}
scoped = _scoped_params(ctx_name, input_data or {})
if scoped:
entry["params"] = scoped
out.append(entry)
return out
def _resolve_merge_slot(context_name: str, mutation_output: Any) -> str | None:
"""Find the unique function-name slot whose return type matches the mutation's output.
Returns None on no match or ambiguous match (multiple candidates).
"""
if mutation_output is None:
return None
matches: list[str] = []
for fn_name in get_context_groups().get(context_name, []):
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
fn_output = getattr(fn_cls, "Output", None)
if fn_output is not None and types_match_for_merge(fn_output, mutation_output):
matches.append(fn_name)
return matches[0] if len(matches) == 1 else None
def _scoped_params(context_name: str, input_data: dict[str, Any]) -> dict[str, Any]:
"""Match input args against the context's declared Input field names."""
fn_names = get_context_groups().get(context_name, [])
declared: set[str] = set()
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
declared.update(input_cls.model_fields.keys())
return {k: v for k, v in input_data.items() if k in declared}
def _invalidation_target(target: dict[str, Any], input_data: dict[str, Any]) -> Any:
match target.get("type"):
case "context":
name = target["name"]
scoped = _scoped_params(name, input_data)
return {"context": name, "params": scoped} if scoped else name
case "function":
return {"function": target["name"]}
case _:
return target
identity = getattr(getattr(request, "state", None), "user", None)
res = await dispatch_call(
DispatchRequest(identity=identity, args=input_data, native_request=request),
fn_name,
_NO_CACHE,
)
return res.data

View File

@@ -19,22 +19,17 @@ from typing import Any
from fastapi import APIRouter, Request
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse
from fastapi.responses import JSONResponse, Response
from pydantic import BaseModel, Field, ValidationError
from starlette.datastructures import UploadFile
from mizan_core.registry import get_context_groups, get_function
from mizan_core.auth import INVALID, authenticate
from mizan_core.dispatch import DispatchRequest, dispatch_call, dispatch_context
from mizan_core.errors import BadRequest, ErrorCode, MizanError, Unauthorized
from mizan_core.registry import get_function
from mizan_core.upload import UploadedFile, bind_uploads
from .executor import (
BadRequest,
ErrorCode,
MizanError,
NotFound,
compute_invalidation,
compute_merges,
execute_function,
)
from .config import MizanConfig, get_config
router = APIRouter()
@@ -106,31 +101,52 @@ async def _parse_call(request: Request) -> tuple[str, dict[str, Any]]:
return body.fn, body.args
def _identity(request: Request, cfg: MizanConfig):
"""Identity for dispatch: a host-set `request.state.user`, else a token decode.
A present-but-invalid token rejects (401); no token → None (anonymous).
"""
existing = getattr(getattr(request, "state", None), "user", None)
if existing is not None:
return existing
ident = authenticate(request.headers, cfg.auth)
if ident is INVALID:
raise Unauthorized("Invalid or expired token")
return ident
@router.post("/call/")
async def function_call(request: Request) -> JSONResponse:
"""RPC dispatch — `{"fn": "...", "args": {...}}` (JSON) or multipart with file
parts → `{"result": ..., "invalidate": [...], "merge"?: [...]}`."""
"""RPC dispatch — JSON or multipart → `{"result", "invalidate", "merge"?}` with
the `X-Mizan-Invalidate` header alongside the body."""
cfg = get_config(request)
fn, args = await _parse_call(request)
fn_class = get_function(fn)
result = await execute_function(request, fn, args)
invalidate = compute_invalidation(fn_class, args)
merges = compute_merges(fn_class, args, result)
payload: dict[str, Any] = {"result": result, "invalidate": invalidate}
if merges:
payload["merge"] = merges
return _no_store(payload)
res = await dispatch_call(
DispatchRequest(identity=_identity(request, cfg), args=args, native_request=request),
fn, cfg.cache,
)
payload: dict[str, Any] = {"result": res.data, "invalidate": res.invalidate or []}
if res.merge:
payload["merge"] = res.merge
headers = {"Cache-Control": "no-store"}
if res.invalidate_header:
headers["X-Mizan-Invalidate"] = res.invalidate_header
return JSONResponse(payload, headers=headers)
@router.get("/ctx/{context_name}/")
async def context_fetch(context_name: str, request: Request) -> JSONResponse:
"""Bundled context fetch — `{function_name: result, ...}` for every function in the context."""
fn_names = get_context_groups().get(context_name)
if not fn_names:
raise NotFound(f"Context '{context_name}' not found")
params = dict(request.query_params)
bundled = {fn: await execute_function(request, fn, params) for fn in fn_names}
return _no_store(bundled)
async def context_fetch(context_name: str, request: Request) -> Response:
"""Bundled context fetch — origin-cached. `{function_name: result, ...}`."""
cfg = get_config(request)
res = await dispatch_context(
DispatchRequest(identity=_identity(request, cfg), args=dict(request.query_params),
native_request=request),
context_name, cfg.cache,
)
headers = {"Cache-Control": "no-store"}
if res.cache_status:
headers["X-Mizan-Cache"] = res.cache_status
return Response(content=res.body_bytes, media_type="application/json", headers=headers)
# ─── Exception handler ──────────────────────────────────────────────────────

View File

@@ -0,0 +1,98 @@
"""FastAPI parity with Django: X-Mizan-Invalidate header, origin cache, token auth."""
from __future__ import annotations
import pytest
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient
from pydantic import BaseModel
from mizan_core.auth import AuthConfig, JWTConfig, create_access_token
from mizan_core.cache.backend import MemoryCache
from mizan_core.client.function import client
from mizan_core.dispatch import CacheOrchestrator
from mizan_core.registry import clear_registry, register
from mizan_fastapi import (
MizanAuthMiddleware,
MizanConfig,
MizanError,
mizan_auth,
mizan_exception_handler,
router as mizan_router,
)
class Out(BaseModel):
ok: bool
SECRET = "x" * 32
JWT = JWTConfig(private_key=SECRET, public_key=SECRET)
def _app(*, with_cache=False, with_auth_dep=False) -> FastAPI:
clear_registry()
UserCtx = "user"
@client(context=UserCtx)
def user_profile(request, user_id: int) -> Out:
return Out(ok=True)
@client(affects=UserCtx)
def update_profile(request, user_id: int) -> Out:
return Out(ok=True)
@client(auth=True)
def whoami(request) -> Out:
return Out(ok=True)
register(user_profile, "user_profile")
register(update_profile, "update_profile")
register(whoami, "whoami")
app = FastAPI()
cache = CacheOrchestrator(MemoryCache(), SECRET) if with_cache else CacheOrchestrator(None, None)
app.state.mizan_config = MizanConfig(auth=AuthConfig(jwt=JWT), cache=cache)
deps = [Depends(mizan_auth())] if with_auth_dep else []
app.include_router(mizan_router, prefix="/api/mizan", dependencies=deps)
app.add_exception_handler(MizanError, mizan_exception_handler)
return app
def test_mutation_emits_invalidate_header():
c = TestClient(_app())
r = c.post("/api/mizan/call/", json={"fn": "update_profile", "args": {"user_id": 5}})
assert r.status_code == 200
assert r.json()["invalidate"] == [{"context": "user", "params": {"user_id": 5}}]
assert r.headers["X-Mizan-Invalidate"] == "user;user_id=5"
def test_origin_cache_hit_miss():
c = TestClient(_app(with_cache=True))
r1 = c.get("/api/mizan/ctx/user/", params={"user_id": 5})
assert r1.status_code == 200 and r1.headers["X-Mizan-Cache"] == "MISS"
r2 = c.get("/api/mizan/ctx/user/", params={"user_id": 5})
assert r2.headers["X-Mizan-Cache"] == "HIT"
assert r1.content == r2.content
def test_auth_required_rejects_anonymous():
c = TestClient(_app())
r = c.post("/api/mizan/call/", json={"fn": "whoami", "args": {}})
assert r.status_code == 401
def test_auth_required_passes_with_bearer_jwt():
c = TestClient(_app(with_auth_dep=True))
tok = create_access_token("7", "sess", JWT, is_staff=True)
r = c.post("/api/mizan/call/", json={"fn": "whoami", "args": {}},
headers={"Authorization": f"Bearer {tok}"})
assert r.status_code == 200 and r.json()["result"] == {"ok": True}
def test_invalid_bearer_token_rejected():
c = TestClient(_app())
r = c.post("/api/mizan/call/", json={"fn": "update_profile", "args": {"user_id": 1}},
headers={"Authorization": "Bearer not-a-real-token"})
assert r.status_code == 401

View File

@@ -13,7 +13,7 @@
* }
*/
import { ReactContext, type ClientOptions, type RegistryEntry, type ParamDef } from './types'
import { ReactContext, type ClientOptions, type RegistryEntry, type ParamDef, type AuthRequirement } from './types'
import { register } from './registry'
function resolveContext(ctx: ReactContext | string | undefined): string | undefined {
@@ -21,6 +21,19 @@ function resolveContext(ctx: ReactContext | string | undefined): string | undefi
return ctx
}
/**
* Normalize the public auth option into the stored requirement.
* Mirrors Python: undefined→undefined, true→'required', callable→callable,
* 'staff'/'superuser' pass through, anything else throws at decoration time.
*/
function normalizeAuth(auth: ClientOptions['auth']): AuthRequirement | undefined {
if (auth === undefined) return undefined
if (auth === true) return 'required'
if (typeof auth === 'function') return auth
if (auth === 'staff' || auth === 'superuser') return auth
throw new Error(`Invalid auth value ${JSON.stringify(auth)}`)
}
function normalizeAffects(
affects: ClientOptions['affects'],
): RegistryEntry['affects'] | undefined {
@@ -97,7 +110,7 @@ export function client(optionsOrFn: ClientOptions | ClientOptions, fn?: Function
viewPath: isView,
route: options.route,
methods: options.methods,
auth: options.auth,
auth: normalizeAuth(options.auth),
rev: options.rev,
cache: options.cache,
}
@@ -129,7 +142,7 @@ export function client(optionsOrFn: ClientOptions | ClientOptions, fn?: Function
viewPath: false,
route: options.route,
methods: options.methods,
auth: options.auth,
auth: normalizeAuth(options.auth),
rev: options.rev,
cache: options.cache,
}

View File

@@ -8,6 +8,8 @@
import { getFunction, getContextGroups } from './registry'
import { resolveInvalidation, formatInvalidateHeader } from './invalidation'
import { getCache, cacheGet, cachePut, cachePurge } from './cache'
import { ANONYMOUS, type Identity } from './identity'
import type { AuthRequirement } from './types'
let _cacheSecret: string | null = null
@@ -22,6 +24,54 @@ export interface MizanResponse {
headers: Record<string, string>
}
interface AuthDenial {
status: 401 | 403
code: 'UNAUTHORIZED' | 'FORBIDDEN'
message: string
}
/**
* Check whether `identity` satisfies the stored `auth` requirement.
* Ports Django's _check_auth_requirement exactly. Returns an AuthDenial
* on failure, or null when access is allowed.
*/
function checkAuth(auth: AuthRequirement | undefined, identity: Identity): AuthDenial | null {
if (auth === undefined) return null
// Callable runs first — before the authentication gate.
if (typeof auth === 'function') {
try {
return auth(identity)
? null
: { status: 403, code: 'FORBIDDEN', message: 'Access denied' }
} catch (e: any) {
return { status: 403, code: 'FORBIDDEN', message: e?.message || 'Access denied' }
}
}
if (!identity.isAuthenticated) {
return { status: 401, code: 'UNAUTHORIZED', message: 'Authentication required' }
}
if (auth === 'staff' && !identity.isStaff) {
return { status: 403, code: 'FORBIDDEN', message: 'Staff access required' }
}
if (auth === 'superuser' && !identity.isSuperuser) {
return { status: 403, code: 'FORBIDDEN', message: 'Superuser access required' }
}
return null
}
function authDenialResponse(denial: AuthDenial): MizanResponse {
return {
status: denial.status,
body: { error: true, code: denial.code, message: denial.message },
headers: { 'Cache-Control': 'no-store', 'Content-Type': 'application/json' },
}
}
/**
* Handle GET /api/mizan/ctx/:contextName/
*
@@ -30,6 +80,7 @@ export interface MizanResponse {
export async function handleContextFetch(
contextName: string,
params: Record<string, string>,
identity: Identity = ANONYMOUS,
): Promise<MizanResponse> {
const groups = getContextGroups()
const fnNames = groups[contextName]
@@ -42,6 +93,15 @@ export async function handleContextFetch(
}
}
// Auth pre-pass — run BEFORE the cache lookup so a cache HIT can never
// leak to an unauthorized caller. Any denial short-circuits, uncached.
for (const fnName of fnNames) {
const entry = getFunction(fnName)
if (!entry) continue
const denial = checkAuth(entry.auth, identity)
if (denial) return authDenialResponse(denial)
}
// Resolve effective rev (max across functions) and cache policy (min TTL)
let effectiveRev = 0
for (const fnName of fnNames) {
@@ -133,6 +193,7 @@ export async function handleContextFetch(
export async function handleMutationCall(
fnName: string,
args: Record<string, any>,
identity: Identity = ANONYMOUS,
): Promise<MizanResponse> {
const entry = getFunction(fnName)
@@ -153,6 +214,10 @@ export async function handleMutationCall(
}
}
// Auth enforcement — after private rejection, before execution.
const denial = checkAuth(entry.auth, identity)
if (denial) return authDenialResponse(denial)
try {
const argValues = entry.params.map(p => args[p.name])
const result = await entry.fn(...argValues)

View File

@@ -0,0 +1,22 @@
/**
* Identity abstraction — the request-bound caller identity.
*
* Framework-agnostic. Adapters construct an Identity (from MWT, JWT,
* session, etc.) and pass it into dispatch. ANONYMOUS is the default.
*/
export interface Identity {
isAuthenticated: boolean
isStaff: boolean
isSuperuser: boolean
id: number | string | null
}
export const ANONYMOUS: Identity = {
isAuthenticated: false,
isStaff: false,
isSuperuser: false,
id: null,
}
export type AuthPredicate = (identity: Identity) => boolean

View File

@@ -1,5 +1,11 @@
export { ReactContext } from './types'
export type { ClientOptions, EdgeManifest, RegistryEntry } from './types'
export type { ClientOptions, EdgeManifest, RegistryEntry, AuthOption, AuthRequirement } from './types'
export { ANONYMOUS } from './identity'
export type { Identity, AuthPredicate } from './identity'
export { decodeMwt, decodeJwtBearer, identityFromMwt } from './token'
export type { MwtPayload } from './token'
export { client } from './decorator'

View File

@@ -0,0 +1,110 @@
/**
* MWT / JWT decode — HS256 verification, cross-language parity with
* cores/mizan-python/src/mizan_core/mwt.py.
*
* Returns null on ANY failure (bad signature, expired, future nbf, wrong
* aud, malformed). Never throws.
*/
import { createHmac, timingSafeEqual } from 'crypto'
import type { Identity } from './identity'
export interface MwtPayload {
sub: string
staff: boolean
super: boolean
pkey: string
kid: string
aud: string
iat: number
exp: number
}
function base64urlDecode(input: string): Buffer | null {
if (!/^[A-Za-z0-9_-]*$/.test(input)) return null
return Buffer.from(input, 'base64url')
}
function constantTimeEqual(a: Buffer, b: Buffer): boolean {
if (a.length !== b.length) return false
return timingSafeEqual(a, b)
}
/**
* Decode and validate an MWT (HS256 JWT with Mizan claims).
* Returns MwtPayload on success, null on any failure.
*/
export function decodeMwt(
token: string,
secret: string,
audience: string = 'mizan',
): MwtPayload | null {
try {
const parts = token.split('.')
if (parts.length !== 3) return null
const [headerB64, payloadB64, signatureB64] = parts
const headerBytes = base64urlDecode(headerB64)
const payloadBytes = base64urlDecode(payloadB64)
const signatureBytes = base64urlDecode(signatureB64)
if (!headerBytes || !payloadBytes || !signatureBytes) return null
const header = JSON.parse(headerBytes.toString('utf-8'))
if (header.alg !== 'HS256') return null
// Recompute HMAC over `${headerB64}.${payloadB64}`
const expected = createHmac('sha256', secret)
.update(`${headerB64}.${payloadB64}`)
.digest()
if (!constantTimeEqual(expected, signatureBytes)) return null
const data = JSON.parse(payloadBytes.toString('utf-8'))
const now = Math.floor(Date.now() / 1000)
if (typeof data.exp !== 'number' || data.exp <= now) return null
if (data.nbf !== undefined && typeof data.nbf === 'number' && data.nbf > now) return null
if (data.aud !== audience) return null
const kid = typeof header.kid === 'string' ? header.kid : 'v1'
return {
sub: String(data.sub),
staff: Boolean(data.staff),
super: Boolean(data.super),
pkey: typeof data.pkey === 'string' ? data.pkey : '',
kid,
aud: audience,
iat: data.iat,
exp: data.exp,
}
} catch {
return null
}
}
/**
* Decode a Bearer JWT from an Authorization header value.
* Strips the "Bearer " prefix, then validates as an MWT.
*/
export function decodeJwtBearer(
authHeader: string,
secret: string,
audience: string = 'mizan',
): MwtPayload | null {
if (!authHeader) return null
const prefix = 'Bearer '
const token = authHeader.startsWith(prefix)
? authHeader.slice(prefix.length)
: authHeader
return decodeMwt(token, secret, audience)
}
/** Build an Identity from a decoded MWT payload. */
export function identityFromMwt(payload: MwtPayload): Identity {
return {
isAuthenticated: true,
isStaff: payload.staff,
isSuperuser: payload.super,
id: Number(payload.sub),
}
}

View File

@@ -2,6 +2,8 @@
* Mizan TypeScript Adapter — Shared Types
*/
import type { AuthPredicate } from './identity'
export class ReactContext {
constructor(public readonly name: string) {
if (!name) throw new Error('ReactContext name must be non-empty')
@@ -10,13 +12,19 @@ export class ReactContext {
export type AffectsTarget = ReactContext | string
/** Public auth option on the decorator. `true` normalizes to `'required'` when stored. */
export type AuthOption = true | 'staff' | 'superuser' | AuthPredicate
/** Normalized auth requirement as stored on the registry entry. */
export type AuthRequirement = 'required' | 'staff' | 'superuser' | AuthPredicate
export interface ClientOptions {
context?: ReactContext | string
affects?: AffectsTarget | AffectsTarget[]
private?: boolean
route?: string
methods?: string[]
auth?: boolean
auth?: AuthOption
rev?: number
cache?: number | false
}
@@ -37,7 +45,7 @@ export interface RegistryEntry {
viewPath: boolean
route?: string
methods?: string[]
auth?: boolean
auth?: AuthRequirement
rev?: number
cache?: number | false
}

View File

@@ -0,0 +1,163 @@
/**
* Auth-parity tests — mirrors Django's auth enforcement in
* mizan-django/src/mizan/client/executor.py (_check_auth_requirement).
*/
import { describe, test, expect, beforeEach } from 'bun:test'
import {
ReactContext, client, clearRegistry,
handleContextFetch, handleMutationCall,
setCache, resetCache, setCacheSecret, MemoryCache,
type Identity,
} from '../src'
function anon(): Identity {
return { isAuthenticated: false, isStaff: false, isSuperuser: false, id: null }
}
function user(): Identity {
return { isAuthenticated: true, isStaff: false, isSuperuser: false, id: 1 }
}
function staff(): Identity {
return { isAuthenticated: true, isStaff: true, isSuperuser: false, id: 2 }
}
function superuser(): Identity {
return { isAuthenticated: true, isStaff: true, isSuperuser: true, id: 3 }
}
describe('Auth — mutation dispatch', () => {
beforeEach(() => clearRegistry())
test('auth:true + anon → 401', async () => {
client({ auth: true }, async function secret() { return { ok: true } })
const r = await handleMutationCall('secret', {}, anon())
expect(r.status).toBe(401)
expect(r.body.code).toBe('UNAUTHORIZED')
expect(r.body.message).toBe('Authentication required')
expect(r.headers['Cache-Control']).toBe('no-store')
})
test('auth:true + user → 200', async () => {
client({ auth: true }, async function secret() { return { ok: true } })
const r = await handleMutationCall('secret', {}, user())
expect(r.status).toBe(200)
expect(r.body.result).toEqual({ ok: true })
})
test("auth:'staff' + user → 403", async () => {
client({ auth: 'staff' }, async function adminAction() { return { ok: true } })
const r = await handleMutationCall('adminAction', {}, user())
expect(r.status).toBe(403)
expect(r.body.code).toBe('FORBIDDEN')
expect(r.body.message).toBe('Staff access required')
})
test("auth:'staff' + staff → 200", async () => {
client({ auth: 'staff' }, async function adminAction() { return { ok: true } })
const r = await handleMutationCall('adminAction', {}, staff())
expect(r.status).toBe(200)
})
test("auth:'superuser' + staff → 403", async () => {
client({ auth: 'superuser' }, async function nuke() { return { ok: true } })
const r = await handleMutationCall('nuke', {}, staff())
expect(r.status).toBe(403)
expect(r.body.message).toBe('Superuser access required')
})
test("auth:'superuser' + superuser → 200", async () => {
client({ auth: 'superuser' }, async function nuke() { return { ok: true } })
const r = await handleMutationCall('nuke', {}, superuser())
expect(r.status).toBe(200)
})
test('callable → true → 200', async () => {
client({ auth: (id) => id.isAuthenticated }, async function gated() { return { ok: true } })
const r = await handleMutationCall('gated', {}, user())
expect(r.status).toBe(200)
})
test("callable → false → 403 'Access denied'", async () => {
client({ auth: () => false }, async function gated() { return { ok: true } })
const r = await handleMutationCall('gated', {}, user())
expect(r.status).toBe(403)
expect(r.body.message).toBe('Access denied')
})
test("callable throws Error('msg') → 403 'msg'", async () => {
client({ auth: () => { throw new Error('msg') } }, async function gated() { return { ok: true } })
const r = await handleMutationCall('gated', {}, user())
expect(r.status).toBe(403)
expect(r.body.message).toBe('msg')
})
test('callable runs before authentication gate (anon allowed if predicate true)', async () => {
client({ auth: () => true }, async function gated() { return { ok: true } })
const r = await handleMutationCall('gated', {}, anon())
expect(r.status).toBe(200)
})
test('invalid auth string at decoration → throws', () => {
expect(() => {
client({ auth: 'admin' as any }, async function bad() { return {} })
}).toThrow('Invalid auth value')
})
test('no auth + anon → 200 (default ANONYMOUS path stays open)', async () => {
client({}, async function open() { return { ok: true } })
const r = await handleMutationCall('open', {})
expect(r.status).toBe(200)
})
})
describe('Auth — context fetch', () => {
beforeEach(() => clearRegistry())
test('auth-gated context member + anon → 401', async () => {
const Ctx = new ReactContext('secure')
client({ context: Ctx, auth: true }, async function secureData(itemId: number) {
return { id: itemId }
})
const r = await handleContextFetch('secure', { itemId: '1' }, anon())
expect(r.status).toBe(401)
expect(r.body.message).toBe('Authentication required')
})
test('auth-gated context + user → 200', async () => {
const Ctx = new ReactContext('secure')
client({ context: Ctx, auth: true }, async function secureData(itemId: number) {
return { id: itemId }
})
const r = await handleContextFetch('secure', { itemId: '1' }, user())
expect(r.status).toBe(200)
expect(r.body.secureData).toEqual({ id: '1' })
})
test('context fetch denial pre-empts a would-be cache HIT', async () => {
const SECRET = 'auth-test-secret-32bytes-padding!'
const Ctx = new ReactContext('secure')
client({ context: Ctx, auth: true }, async function secureData(itemId: number) {
return { id: itemId }
})
const cache = new MemoryCache()
setCache(cache)
setCacheSecret(SECRET)
// Prime the cache as an authorized caller.
const primed = await handleContextFetch('secure', { itemId: '1' }, user())
expect(primed.status).toBe(200)
expect(primed.headers['X-Mizan-Cache']).toBe('MISS')
// Confirm it's now a cache HIT for an authorized caller.
const hit = await handleContextFetch('secure', { itemId: '1' }, user())
expect(hit.headers['X-Mizan-Cache']).toBe('HIT')
// Anon must get 401 even though the cache holds the entry.
const denied = await handleContextFetch('secure', { itemId: '1' }, anon())
expect(denied.status).toBe(401)
expect(denied.headers['X-Mizan-Cache']).toBeUndefined()
resetCache()
setCacheSecret(null)
})
})

View File

@@ -0,0 +1,126 @@
/**
* MWT decode tests — round-trip + cross-language pin against Python create_mwt.
*/
import { describe, test, expect } from 'bun:test'
import { createHmac } from 'crypto'
import { decodeMwt, decodeJwtBearer, identityFromMwt } from '../src'
function b64url(buf: Buffer | string): string {
return Buffer.from(buf).toString('base64url')
}
/** Mint an HS256 MWT with node crypto, mirroring Python create_mwt. */
function mint(payload: Record<string, any>, secret: string, kid = 'v1'): string {
const header = b64url(JSON.stringify({ alg: 'HS256', kid, typ: 'JWT' }))
const body = b64url(JSON.stringify(payload))
const sig = createHmac('sha256', secret).update(`${header}.${body}`).digest('base64url')
return `${header}.${body}.${sig}`
}
const SECRET = 'round-trip-secret'
const now = Math.floor(Date.now() / 1000)
function basePayload(overrides: Record<string, any> = {}) {
return {
sub: '7',
staff: true,
super: false,
pkey: 'abc123',
aud: 'mizan',
iat: now,
nbf: now,
exp: now + 300,
...overrides,
}
}
describe('MWT round-trip', () => {
test('valid token decodes', () => {
const token = mint(basePayload(), SECRET)
const p = decodeMwt(token, SECRET)
expect(p).not.toBeNull()
expect(p!.sub).toBe('7')
expect(p!.staff).toBe(true)
expect(p!.super).toBe(false)
expect(p!.pkey).toBe('abc123')
expect(p!.kid).toBe('v1')
expect(p!.aud).toBe('mizan')
})
test('identityFromMwt maps claims', () => {
const token = mint(basePayload({ sub: '99', staff: false, super: true }), SECRET)
const p = decodeMwt(token, SECRET)!
expect(identityFromMwt(p)).toEqual({
isAuthenticated: true,
isStaff: false,
isSuperuser: true,
id: 99,
})
})
test('decodeJwtBearer strips Bearer prefix', () => {
const token = mint(basePayload(), SECRET)
const p = decodeJwtBearer(`Bearer ${token}`, SECRET)
expect(p).not.toBeNull()
expect(p!.sub).toBe('7')
})
test('null on tampered signature', () => {
const token = mint(basePayload(), SECRET)
const tampered = token.slice(0, -2) + (token.endsWith('AA') ? 'BB' : 'AA')
expect(decodeMwt(tampered, SECRET)).toBeNull()
})
test('null on wrong secret', () => {
const token = mint(basePayload(), SECRET)
expect(decodeMwt(token, 'other-secret')).toBeNull()
})
test('null on expired exp', () => {
const token = mint(basePayload({ exp: now - 10 }), SECRET)
expect(decodeMwt(token, SECRET)).toBeNull()
})
test('null on future nbf', () => {
const token = mint(basePayload({ nbf: now + 1000 }), SECRET)
expect(decodeMwt(token, SECRET)).toBeNull()
})
test('null on wrong aud', () => {
const token = mint(basePayload({ aud: 'other' }), SECRET)
expect(decodeMwt(token, SECRET)).toBeNull()
})
test('null on malformed token', () => {
expect(decodeMwt('not.a.jwt', SECRET)).toBeNull()
expect(decodeMwt('onlyonepart', SECRET)).toBeNull()
expect(decodeMwt('', SECRET)).toBeNull()
})
})
describe('MWT cross-language pin (Python create_mwt)', () => {
const TOKEN = 'eyJhbGciOiJIUzI1NiIsImtpZCI6InYxIiwidHlwIjoiSldUIn0.eyJzdWIiOiI0MiIsInN0YWZmIjp0cnVlLCJzdXBlciI6ZmFsc2UsInBrZXkiOiIwZTk5OGE5ZmYxNjkwNDYzN2EwM2QyZWEwZmJkYmY5NzQyOTdhOWQxYTVkMjViOGQ0Mjk0ZmE4ODIxMTVlNDU3IiwiYXVkIjoibWl6YW4iLCJpYXQiOjE3MDAwMDAwMDAsIm5iZiI6MTcwMDAwMDAwMCwiZXhwIjo0MTAyNDQ0ODAwfQ._V92JXiLSLXoyuSwbNvvJjwzgmczmC7dvX34kVSLIa8'
const PIN_SECRET = 'pin-test-secret-mwt'
test('decodes the Python-minted token', () => {
const p = decodeMwt(TOKEN, PIN_SECRET)
expect(p).not.toBeNull()
expect(p!.sub).toBe('42')
expect(p!.staff).toBe(true)
expect(p!.super).toBe(false)
expect(p!.pkey).toBe('0e998a9ff16904637a03d2ea0fbdbf974297a9d1a5d25b8d4294fa882115e457')
expect(p!.kid).toBe('v1')
expect(p!.aud).toBe('mizan')
})
test('identity from Python-minted token', () => {
const p = decodeMwt(TOKEN, PIN_SECRET)!
expect(identityFromMwt(p)).toEqual({
isAuthenticated: true,
isStaff: true,
isSuperuser: false,
id: 42,
})
})
})