From 66b2db81fbf55186b7aa7420b286523eb7b6bd4b Mon Sep 17 00:00:00 2001 From: Ryth Azhur Date: Thu, 4 Jun 2026 05:14:29 -0400 Subject: [PATCH] FastAPI and TypeScript improved --- .gitignore | 3 + README.md | 22 +- .../mizan-django/src/mizan/client/executor.py | 235 +------------- backends/mizan-django/src/mizan/jwt/tokens.py | 254 +++------------ .../mizan-django/src/mizan/tests/test_auth.py | 4 +- .../src/mizan_fastapi/__init__.py | 6 + .../mizan-fastapi/src/mizan_fastapi/auth.py | 54 ++++ .../mizan-fastapi/src/mizan_fastapi/config.py | 80 +++++ .../src/mizan_fastapi/executor.py | 294 +++--------------- .../mizan-fastapi/src/mizan_fastapi/router.py | 76 +++-- backends/mizan-fastapi/tests/test_parity.py | 98 ++++++ backends/mizan-ts/src/decorator.ts | 19 +- backends/mizan-ts/src/dispatch.ts | 65 ++++ backends/mizan-ts/src/identity.ts | 22 ++ backends/mizan-ts/src/index.ts | 8 +- backends/mizan-ts/src/token.ts | 110 +++++++ backends/mizan-ts/src/types.ts | 12 +- backends/mizan-ts/tests/auth.test.ts | 163 ++++++++++ backends/mizan-ts/tests/token.test.ts | 126 ++++++++ .../src/mizan_core/auth/__init__.py | 27 ++ .../src/mizan_core/auth/authenticate.py | 53 ++++ cores/mizan-python/src/mizan_core/auth/jwt.py | 137 ++++++++ .../mizan-python/src/mizan_core/authguard.py | 52 ++++ cores/mizan-python/src/mizan_core/dispatch.py | 250 +++++++++++++++ cores/mizan-python/src/mizan_core/errors.py | 58 ++++ cores/mizan-python/src/mizan_core/identity.py | 32 ++ .../src/mizan_core/invalidation.py | 174 +++++++++++ .../mizan-python/tests/test_dispatch_core.py | 147 +++++++++ 28 files changed, 1864 insertions(+), 717 deletions(-) create mode 100644 backends/mizan-fastapi/src/mizan_fastapi/auth.py create mode 100644 backends/mizan-fastapi/src/mizan_fastapi/config.py create mode 100644 backends/mizan-fastapi/tests/test_parity.py create mode 100644 backends/mizan-ts/src/identity.ts create mode 100644 backends/mizan-ts/src/token.ts create mode 100644 backends/mizan-ts/tests/auth.test.ts create mode 100644 backends/mizan-ts/tests/token.test.ts create mode 100644 cores/mizan-python/src/mizan_core/auth/__init__.py create mode 100644 cores/mizan-python/src/mizan_core/auth/authenticate.py create mode 100644 cores/mizan-python/src/mizan_core/auth/jwt.py create mode 100644 cores/mizan-python/src/mizan_core/authguard.py create mode 100644 cores/mizan-python/src/mizan_core/dispatch.py create mode 100644 cores/mizan-python/src/mizan_core/errors.py create mode 100644 cores/mizan-python/src/mizan_core/identity.py create mode 100644 cores/mizan-python/src/mizan_core/invalidation.py create mode 100644 cores/mizan-python/tests/test_dispatch_core.py diff --git a/.gitignore b/.gitignore index 0a8911b..dcbbf8d 100644 --- a/.gitignore +++ b/.gitignore @@ -34,3 +34,6 @@ examples/django-react-site/harness/test-results/ .env.* *.pem *.key + +# Agent worktrees (transient scratch — never tracked) +.claude/worktrees/ diff --git a/README.md b/README.md index a6980a4..2a08f40 100644 --- a/README.md +++ b/README.md @@ -61,15 +61,18 @@ Protocol transports and guarantees co-equal with the body channel in the spec. | Capability | Django | FastAPI | Rust / Axum | Tauri | TypeScript | |---|:---:|:---:|:---:|:---:|:---:| -| Invalidation — `X-Mizan-Invalidate` header | ✅ | ❌ | ❌ | — ¹ | ✅ | -| Auth-guard enforcement (`auth=…` rejects) | ✅ | ✅ | ❌ ⁵ | ◑ ⁵ | ❌ | -| Origin-side HMAC cache | ✅ | ❌ | ❌ | ❌ | ✅ | +| Invalidation — `X-Mizan-Invalidate` header | ✅ | ✅ | ❌ | — ¹ | ✅ | +| Auth-guard enforcement (`auth=…` rejects) | ✅ | ✅ | ❌ ⁵ | ◑ ⁵ | ✅ ¹¹ | +| Origin-side HMAC cache | ✅ | ✅ | ❌ | ❌ | ✅ | | Edge manifest export | ✅ | ❌ | ❌ | — | ✅ | | PSR (`render_strategy` in manifest) | ✅ | ❌ | ❌ | — | ✅ | | Session / CSRF init endpoint | ✅ | ◑ ⁷ | ◑ ⁷ | — | ❌ | > **Caveat:** Rust/Axum and Tauri accept `auth=` on a function but do not yet enforce > it — do not rely on `auth=` for access control on those adapters. +> +> Django, FastAPI, and TypeScript share one auth/invalidation/cache implementation +> (`mizan_core` for the Python adapters; the same spec, pinned cross-language, for TS). ### Stack extensions (Django) @@ -82,8 +85,8 @@ target stack calls for them. | Forms (schema / validate / submit) | ✅ | ❌ | ◑ ³ | ❌ | ❌ | | Formsets | ✅ | ❌ | ❌ | ❌ | ❌ | | API shapes (ORM query projection) ⁴ | ✅ | — | — | — | — | -| JWT auth (access / refresh, session validation) | ✅ | ❌ | ❌ | ❌ | ❌ | -| MWT (edge identity token) | ✅ | ❌ | ❌ | — | ❌ | +| JWT auth (access / refresh) ¹² | ✅ | ✅ | ❌ | ❌ | ◑ ¹³ | +| MWT (edge identity token) | ✅ | ✅ | ❌ | — | ◑ ¹³ | | SSR bridge | ✅ | ❌ | ❌ | — | ❌ | | Auth-provider integration (allauth) | ✅ | ❌ | ❌ | ❌ | ❌ | @@ -113,6 +116,15 @@ target stack calls for them. 10. The TypeScript column is the `mizan-ts` backend adapter, which has no upload dispatch. The matching client side lives in the kernel (`@mizan/base`): `mizanCall` auto-switches to `multipart/form-data` when any argument is a `File`. +11. `mizan-ts` dispatch now enforces `auth=` (`true`/`'staff'`/`'superuser'`/predicate) + against a host-supplied `Identity`, byte-matching the Python guard's denial messages. +12. JWT/MWT token logic is single-sourced in `mizan_core.auth`; Django and FastAPI ride + it. Session-validation (immediate-logout revocation) is Django-only — FastAPI mints + from its own credential check. +13. `mizan-ts` ships an optional `decodeMwt`/`decodeJwtBearer`/`identityFromMwt` helper + (HS256 via Node `crypto`, cross-language pin-tested against a Python-minted MWT) so a + TS edge worker can derive `Identity` from a Python-issued token. Identity source stays + host-supplied; `mizan-ts` does not mint from a session. ## Conformance diff --git a/backends/mizan-django/src/mizan/client/executor.py b/backends/mizan-django/src/mizan/client/executor.py index ec908c8..b528ed9 100644 --- a/backends/mizan-django/src/mizan/client/executor.py +++ b/backends/mizan-django/src/mizan/client/executor.py @@ -30,6 +30,9 @@ from pydantic import BaseModel, ValidationError from mizan.cache import get_cache, cache_get, cache_put, cache_purge from mizan_core.registry import get_function, get_context_groups from mizan_core.upload import UploadedFile, bind_uploads +from mizan_core import invalidation as _core_inval +from mizan_core.authguard import enforce_auth as _core_enforce_auth +from mizan_core.errors import MizanError as _CoreMizanError from mizan.setup.settings import get_settings if TYPE_CHECKING: @@ -113,53 +116,14 @@ def _check_auth_requirement( Django User (from session). Either way, no additional DB query is made for the built-in checks. Custom callables may query DB if they choose. """ - if auth_requirement is None: + # Evaluation lives in the shared core (mizan_core.authguard); the callable + # path receives the native Django request. Core raises; we render to the + # Django-shim FunctionError shape the executor expects. + try: + _core_enforce_auth(getattr(request, "user", None), auth_requirement, request) return None - - user = request.user - - # Handle callable auth - if callable(auth_requirement): - try: - result = auth_requirement(request) - if result: - return None # Authorized - else: - return FunctionError( - code=ErrorCode.FORBIDDEN, - message="Access denied", - ) - except PermissionError as e: - # Custom error message from the callable - return FunctionError( - code=ErrorCode.FORBIDDEN, - message=str(e) or "Access denied", - ) - - # Check authentication (required for all string-based auth) - if not getattr(user, "is_authenticated", False): - return FunctionError( - code=ErrorCode.UNAUTHORIZED, - message="Authentication required", - ) - - # Check staff requirement - if auth_requirement == "staff": - if not getattr(user, "is_staff", False): - return FunctionError( - code=ErrorCode.FORBIDDEN, - message="Staff access required", - ) - - # Check superuser requirement - elif auth_requirement == "superuser": - if not getattr(user, "is_superuser", False): - return FunctionError( - code=ErrorCode.FORBIDDEN, - message="Superuser access required", - ) - - return None + except _CoreMizanError as e: + return FunctionError(code=ErrorCode(e.code.value), message=e.message) _cache_log = logging.getLogger("mizan.cache") @@ -198,51 +162,6 @@ def _purge_cache_for_invalidation( _cache_log.warning("Cache purge failed", exc_info=True) -def _resolve_affects_target(target_name: str) -> tuple[str, str, str | None]: - """ - Determine whether an affects target is a context name or function name. - - Returns: - ("context", "user", None) — full context invalidation - ("function", "user_profile", "user") — function within context - """ - groups = get_context_groups() - - # Check if it's a context name directly - if target_name in groups: - return ("context", target_name, None) - - # Check if it's a function name within a context - for ctx_name, fn_names in groups.items(): - if target_name in fn_names: - return ("function", target_name, ctx_name) - - # Not a context or context function — treat as context name anyway - # (it might be a non-context function or an as-yet-unregistered context) - return ("context", target_name, None) - - -def _get_context_param_names(context_name: str) -> set[str]: - """ - Get the set of parameter names used by functions in a context. - - Returns the union of all Input field names across context functions. - """ - groups = get_context_groups() - fn_names = groups.get(context_name, []) - param_names: set[str] = set() - - for fn_name in fn_names: - fn_cls = get_function(fn_name) - if fn_cls is None: - continue - input_cls = getattr(fn_cls, "Input", None) - if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"): - param_names.update(input_cls.model_fields.keys()) - - return param_names - - def _resolve_invalidation( view_class: type | None, input_data: dict[str, Any] | None = None, @@ -261,49 +180,7 @@ def _resolve_invalidation( Returns a list suitable for both JSON body and header serialization. Returns None if no invalidation needed. """ - if view_class is None: - return None - - meta = getattr(view_class, "_meta", {}) - affects = meta.get("affects") - if not affects: - return None - - result = [] - seen = set() - - for target in affects: - if target["type"] == "context": - target_name = target["name"] - elif target["type"] == "function" and target.get("context"): - # Function-level: use the function name as the invalidation key - target_name = target["name"] - else: - continue - - if target_name in seen: - continue - seen.add(target_name) - - # Resolve the context this target belongs to (for param lookup) - resolved = _resolve_affects_target(target_name) - ctx_for_params = resolved[2] if resolved[0] == "function" else resolved[1] - - # Tier 1: argument name matching - if input_data and ctx_for_params: - context_params = _get_context_param_names(ctx_for_params) - matched = { - k: v for k, v in input_data.items() - if k in context_params - } - if matched: - result.append({"context": target_name, "params": matched}) - continue - - # Tier 3: broad fallback - result.append(target_name) - - return result if result else None + return _core_inval.resolve_invalidation(view_class, input_data) def _resolve_merges( @@ -322,94 +199,12 @@ def _resolve_merges( Mirrors _resolve_invalidation's tier-1 auto-scoping for params. Entries whose slot can't be uniquely resolved are dropped. """ - if view_class is None: - return None - - from mizan_core.type_utils import types_match_for_merge - - meta = getattr(view_class, "_meta", {}) - targets = meta.get("merge") or [] - if not targets: - return None - - mutation_output = getattr(view_class, "Output", None) - - out: list[dict[str, Any]] = [] - seen: set[str] = set() - for ctx_name in targets: - if ctx_name in seen: - continue - seen.add(ctx_name) - - slot = _resolve_merge_slot(ctx_name, mutation_output, types_match_for_merge) - if slot is None: - continue - - entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result_data} - if input_data: - context_params = _get_context_param_names(ctx_name) - matched = { - k: v for k, v in input_data.items() - if k in context_params - } - if matched: - entry["params"] = matched - out.append(entry) - return out + return _core_inval.resolve_merges(view_class, input_data, result_data) -def _resolve_merge_slot(context_name: str, mutation_output: Any, type_matcher: Any) -> str | None: - """Find the unique function-name slot in context whose return type matches mutation's output.""" - if mutation_output is None: - return None - groups = get_context_groups() - fn_names = groups.get(context_name, []) - matches: list[str] = [] - for fn_name in fn_names: - fn_cls = get_function(fn_name) - if fn_cls is None: - continue - fn_output = getattr(fn_cls, "Output", None) - if fn_output is not None and type_matcher(fn_output, mutation_output): - matches.append(fn_name) - return matches[0] if len(matches) == 1 else None - - -def _format_invalidate_header( - invalidate: list[str | dict[str, Any]], -) -> str: - """ - Format invalidation targets as X-Mizan-Invalidate header value. - - Format: comma-separated contexts. Semicolon-separated params per context. - Param values are URL-encoded to prevent delimiter collisions. - - Examples: - ["user"] → "user" - ["user", "notifications"] → "user, notifications" - [{"context": "user", "params": {"user_id": 5}}] - → "user;user_id=5" - [{"context": "search", "params": {"q": "hello world"}}] - → "search;q=hello%20world" - """ - from urllib.parse import quote - - parts = [] - for entry in invalidate: - if isinstance(entry, str): - parts.append(entry) - elif isinstance(entry, dict): - ctx = entry["context"] - params = entry.get("params", {}) - if params: - param_str = ";".join( - f"{quote(str(k), safe='')}={quote(str(v), safe='')}" - for k, v in sorted(params.items()) - ) - parts.append(f"{ctx};{param_str}") - else: - parts.append(ctx) - return ", ".join(parts) +def _format_invalidate_header(invalidate: list[str | dict[str, Any]]) -> str: + """Format invalidation targets as the X-Mizan-Invalidate header value (shared core).""" + return _core_inval.format_invalidate_header(invalidate) def execute_function( diff --git a/backends/mizan-django/src/mizan/jwt/tokens.py b/backends/mizan-django/src/mizan/jwt/tokens.py index b973e17..c0ff6e6 100644 --- a/backends/mizan-django/src/mizan/jwt/tokens.py +++ b/backends/mizan-django/src/mizan/jwt/tokens.py @@ -1,245 +1,79 @@ """ -JWT Token Creation and Validation +JWT tokens — the Django adapter over the shared core (`mizan_core.auth.jwt`). -Uses PyJWT directly - no allauth dependency. -Tokens are tied to Django sessions for immediate revocation on logout. +The token logic (mint/decode/refresh, `JWTUser`, `TokenPair`, `TokenPayload`) +lives in the core; this module binds it to Django settings and keeps the +session-revocation check (`validate_session`), which is Django-session-specific. """ -import time -from typing import NamedTuple +from __future__ import annotations -import jwt -from django.contrib.sessions.backends.base import SessionBase +from mizan_core.auth import jwt as _core_jwt +from mizan_core.auth.jwt import JWTConfig, JWTUser, TokenPair, TokenPayload from .settings import get_settings - -class TokenPair(NamedTuple): - """Access and refresh token pair.""" - access_token: str - refresh_token: str - expires_in: int +__all__ = [ + "TokenPair", + "TokenPayload", + "JWTUser", + "create_access_token", + "create_refresh_token", + "create_token_pair", + "decode_token", + "validate_session", + "refresh_tokens", +] -class TokenPayload(NamedTuple): - """Decoded token payload.""" - user_id: int | str - session_key: str - token_type: str - is_staff: bool - is_superuser: bool - exp: int - iat: int - - -class JWTUser: - """ - Minimal user object created from JWT claims. - - Used as request.user for JWT-authenticated requests. - No database query required - all data comes from the token. - - If you need the full User object with all fields, query explicitly: - user = User.objects.get(pk=request.user.id) - """ - - def __init__(self, payload: TokenPayload): - self.id = int(payload.user_id) if isinstance(payload.user_id, str) else payload.user_id - self.pk = self.id - self.is_staff = payload.is_staff - self.is_superuser = payload.is_superuser - self.is_authenticated = True - self.is_anonymous = False - self.is_active = True # Assumed active if they have a valid token - - def __str__(self): - return f"JWTUser(id={self.id})" - - def __repr__(self): - return f"JWTUser(id={self.id}, is_staff={self.is_staff}, is_superuser={self.is_superuser})" - - -def create_access_token( - user_id: int | str, - session_key: str, - *, - is_staff: bool = False, - is_superuser: bool = False, -) -> str: - """ - Create a short-lived access token. - - The token contains: - - sub: user ID - - sid: session key (for revocation checking) - - staff: is_staff flag - - super: is_superuser flag - - type: "access" - - iat: issued at - - exp: expiration - """ - settings = get_settings() - now = int(time.time()) - - payload = { - "sub": str(user_id), - "sid": session_key, - "staff": is_staff, - "super": is_superuser, - "type": "access", - "iat": now, - "exp": now + settings.access_token_expires_in, - } - - return jwt.encode( - payload, - settings.private_key, - algorithm=settings.algorithm, +def _config() -> JWTConfig: + s = get_settings() + return JWTConfig( + private_key=s.private_key, + public_key=s.public_key, + algorithm=s.algorithm, + access_token_expires_in=s.access_token_expires_in, + refresh_token_expires_in=s.refresh_token_expires_in, ) -def create_refresh_token( - user_id: int | str, - session_key: str, - *, - is_staff: bool = False, - is_superuser: bool = False, -) -> str: - """ - Create a longer-lived refresh token. - - The token contains: - - sub: user ID - - sid: session key (for revocation checking) - - staff: is_staff flag - - super: is_superuser flag - - type: "refresh" - - iat: issued at - - exp: expiration - """ - settings = get_settings() - now = int(time.time()) - - payload = { - "sub": str(user_id), - "sid": session_key, - "staff": is_staff, - "super": is_superuser, - "type": "refresh", - "iat": now, - "exp": now + settings.refresh_token_expires_in, - } - - return jwt.encode( - payload, - settings.private_key, - algorithm=settings.algorithm, - ) +def create_access_token(user_id, session_key, *, is_staff=False, is_superuser=False) -> str: + return _core_jwt.create_access_token(user_id, session_key, _config(), + is_staff=is_staff, is_superuser=is_superuser) -def create_token_pair( - user_id: int | str, - session_key: str, - *, - is_staff: bool = False, - is_superuser: bool = False, -) -> TokenPair: - """Create both access and refresh tokens.""" - settings = get_settings() - return TokenPair( - access_token=create_access_token( - user_id, session_key, is_staff=is_staff, is_superuser=is_superuser - ), - refresh_token=create_refresh_token( - user_id, session_key, is_staff=is_staff, is_superuser=is_superuser - ), - expires_in=settings.access_token_expires_in, - ) +def create_refresh_token(user_id, session_key, *, is_staff=False, is_superuser=False) -> str: + return _core_jwt.create_refresh_token(user_id, session_key, _config(), + is_staff=is_staff, is_superuser=is_superuser) -def decode_token(token: str, expected_type: str = None) -> TokenPayload | None: - """ - Decode and validate a JWT token. +def create_token_pair(user_id, session_key, *, is_staff=False, is_superuser=False) -> TokenPair: + return _core_jwt.create_token_pair(user_id, session_key, _config(), + is_staff=is_staff, is_superuser=is_superuser) - Returns None if: - - Token is invalid or expired - - Token type doesn't match expected_type (if specified) - """ - settings = get_settings() - try: - payload = jwt.decode( - token, - settings.public_key, - algorithms=[settings.algorithm], - ) - except jwt.PyJWTError: - return None - - # Validate token type if specified - if expected_type and payload.get("type") != expected_type: - return None - - return TokenPayload( - user_id=payload["sub"], - session_key=payload["sid"], - token_type=payload["type"], - is_staff=payload.get("staff", False), - is_superuser=payload.get("super", False), - exp=payload["exp"], - iat=payload["iat"], - ) +def decode_token(token: str, expected_type: str | None = None) -> TokenPayload | None: + return _core_jwt.decode_token(token, _config(), expected_type=expected_type) def validate_session(session_key: str) -> bool: - """ - Check if a session is still valid (exists and not expired). + """Immediate-logout revocation: is this Django session still alive? - This is the key to immediate logout revocation - if the session - is destroyed, tokens tied to it become invalid. + Honors `JWT_VALIDATE_SESSION` — when disabled, always True. This is the one + Django-session-bound piece; the core's `refresh_tokens` takes it as an + injected `session_validator`. """ from importlib import import_module from django.conf import settings as django_settings - jwt_settings = get_settings() - - if not jwt_settings.validate_session: + if not get_settings().validate_session: return True - # Use the configured session engine engine = import_module(django_settings.SESSION_ENGINE) - SessionStore = engine.SessionStore - - # Try to load the session - session = SessionStore(session_key=session_key) - - # Check if session exists and is not empty - # exists() is more reliable than checking load() result + session = engine.SessionStore(session_key=session_key) return session.exists(session_key) def refresh_tokens(refresh_token: str) -> TokenPair | None: - """ - Use a refresh token to obtain new tokens. - - Returns None if: - - Refresh token is invalid or expired - - Associated session no longer exists - """ - payload = decode_token(refresh_token, expected_type="refresh") - - if payload is None: - return None - - # Validate the session still exists - if not validate_session(payload.session_key): - return None - - # Issue new token pair with same claims - return create_token_pair( - payload.user_id, - payload.session_key, - is_staff=payload.is_staff, - is_superuser=payload.is_superuser, - ) + return _core_jwt.refresh_tokens(refresh_token, _config(), session_validator=validate_session) diff --git a/backends/mizan-django/src/mizan/tests/test_auth.py b/backends/mizan-django/src/mizan/tests/test_auth.py index 661960c..58c18bb 100644 --- a/backends/mizan-django/src/mizan/tests/test_auth.py +++ b/backends/mizan-django/src/mizan/tests/test_auth.py @@ -170,8 +170,8 @@ class HTTPAuthTests(TestCase): def test_jwt_expired_with_session(self): """Expired JWT with valid session → Reject (do NOT fall back).""" - # Create token with past expiration by mocking time - with patch("mizan.jwt.tokens.time.time", return_value=0): + # Create token with past expiration by mocking time (minting lives in the core now) + with patch("mizan_core.auth.jwt.time.time", return_value=0): tokens = create_token_pair( self.user.pk, self.session_key, diff --git a/backends/mizan-fastapi/src/mizan_fastapi/__init__.py b/backends/mizan-fastapi/src/mizan_fastapi/__init__.py index 3003cf1..9b27ea2 100644 --- a/backends/mizan-fastapi/src/mizan_fastapi/__init__.py +++ b/backends/mizan-fastapi/src/mizan_fastapi/__init__.py @@ -35,12 +35,18 @@ from .executor import ( execute_function, ) from .router import router, mizan_exception_handler, mizan_validation_handler +from .auth import MizanAuthMiddleware, mizan_auth +from .config import MizanConfig, from_env from mizan_core.upload import File, Upload, UploadedFile __all__ = [ "Upload", "File", "UploadedFile", + "mizan_auth", + "MizanAuthMiddleware", + "MizanConfig", + "from_env", "router", "mizan_exception_handler", "mizan_validation_handler", diff --git a/backends/mizan-fastapi/src/mizan_fastapi/auth.py b/backends/mizan-fastapi/src/mizan_fastapi/auth.py new file mode 100644 index 0000000..cb76068 --- /dev/null +++ b/backends/mizan-fastapi/src/mizan_fastapi/auth.py @@ -0,0 +1,54 @@ +""" +Built-in identity for FastAPI — Django-equivalent automatic `request.state.user`. + +Opt in via `Depends(mizan_auth())` on a route/router, or mount `MizanAuthMiddleware` +app-wide. Both decode a bearer-JWT (`Authorization: Bearer`) or MWT (`X-Mizan-Token`) +via the shared core and set `request.state.user`. A present-but-invalid token is +rejected (401) rather than silently downgraded — the `INVALID` sentinel contract. + +If you'd rather resolve identity yourself, set `request.state.user` upstream and skip +these; dispatch reads it directly. +""" + +from __future__ import annotations + +from typing import Callable + +from fastapi import Request +from starlette.middleware.base import BaseHTTPMiddleware + +from mizan_core.auth import INVALID, authenticate +from mizan_core.errors import Unauthorized + +from .config import get_config + + +def _resolve(request: Request) -> None: + ident = authenticate(request.headers, get_config(request).auth) + if ident is INVALID: + raise Unauthorized("Invalid or expired token") + if ident is not None: + request.state.user = ident + + +def mizan_auth() -> Callable: + """FastAPI dependency that populates `request.state.user` from a token.""" + async def _dep(request: Request) -> None: + _resolve(request) + return _dep + + +class MizanAuthMiddleware(BaseHTTPMiddleware): + """App-wide variant of `mizan_auth` — resolves identity on every request.""" + + async def dispatch(self, request, call_next): + try: + _resolve(request) + except Unauthorized: + from .router import _no_store + from mizan_core.errors import ErrorCode + return _no_store( + {"error": {"code": ErrorCode.UNAUTHORIZED.value, "message": "Invalid or expired token"}}, + status_code=401, + ) + return await call_next(request) diff --git a/backends/mizan-fastapi/src/mizan_fastapi/config.py b/backends/mizan-fastapi/src/mizan_fastapi/config.py new file mode 100644 index 0000000..0641da3 --- /dev/null +++ b/backends/mizan-fastapi/src/mizan_fastapi/config.py @@ -0,0 +1,80 @@ +""" +FastAPI configuration — the "no settings.py" seam. + +Builds the shared core's `AuthConfig` (JWT + MWT) and a `CacheOrchestrator` +from environment variables, overridable per-app via `app.state.mizan_config`. + +Env: + MIZAN_CACHE_SECRET HMAC cache signing key (enables origin cache) + MIZAN_CACHE_REDIS_URL Redis URL (else in-memory cache) + MIZAN_MWT_SECRET MWT signing key + MIZAN_MWT_AUDIENCE MWT audience (default "mizan") + JWT_PRIVATE_KEY JWT signing key (enables bearer-JWT auth) + JWT_PUBLIC_KEY JWT verify key (default: private key, HS256) + JWT_ALGORITHM default "HS256" + JWT_ACCESS_TOKEN_EXPIRES_IN / JWT_REFRESH_TOKEN_EXPIRES_IN +""" + +from __future__ import annotations + +import os +from dataclasses import dataclass + +from mizan_core.auth import AuthConfig, JWTConfig +from mizan_core.cache.backend import CacheBackend, MemoryCache +from mizan_core.dispatch import CacheOrchestrator + + +@dataclass(frozen=True) +class MizanConfig: + auth: AuthConfig + cache: CacheOrchestrator + + +def _cache_backend(secret: str | None, redis_url: str | None) -> CacheBackend | None: + if not secret: + return None + if redis_url: + from mizan_core.cache.backend import RedisCache + return RedisCache(redis_url) + return MemoryCache() + + +def _jwt_config() -> JWTConfig | None: + key = os.getenv("JWT_PRIVATE_KEY") + if not key: + return None + return JWTConfig( + private_key=key, + public_key=os.getenv("JWT_PUBLIC_KEY", key), + algorithm=os.getenv("JWT_ALGORITHM", "HS256"), + access_token_expires_in=int(os.getenv("JWT_ACCESS_TOKEN_EXPIRES_IN", "300")), + refresh_token_expires_in=int(os.getenv("JWT_REFRESH_TOKEN_EXPIRES_IN", "604800")), + ) + + +def from_env() -> MizanConfig: + secret = os.getenv("MIZAN_CACHE_SECRET") + backend = _cache_backend(secret, os.getenv("MIZAN_CACHE_REDIS_URL")) + auth = AuthConfig( + jwt=_jwt_config(), + mwt_secret=os.getenv("MIZAN_MWT_SECRET"), + mwt_audience=os.getenv("MIZAN_MWT_AUDIENCE", "mizan"), + ) + return MizanConfig(auth=auth, cache=CacheOrchestrator(backend, secret)) + + +def get_config(request) -> MizanConfig: + """Per-app config: `app.state.mizan_config` if set, else built from env (cached).""" + app = getattr(request, "app", None) + state = getattr(app, "state", None) + override = getattr(state, "mizan_config", None) if state is not None else None + if override is not None: + return override + global _DEFAULT + if _DEFAULT is None: + _DEFAULT = from_env() + return _DEFAULT + + +_DEFAULT: MizanConfig | None = None diff --git a/backends/mizan-fastapi/src/mizan_fastapi/executor.py b/backends/mizan-fastapi/src/mizan_fastapi/executor.py index 64bdac1..6d7a395 100644 --- a/backends/mizan-fastapi/src/mizan_fastapi/executor.py +++ b/backends/mizan-fastapi/src/mizan_fastapi/executor.py @@ -1,263 +1,69 @@ """ -RPC dispatch — looks up registered functions, validates input against the -function's Pydantic Input model, executes, and returns the serialized result. +Dispatch — a thin shim over the shared core (`mizan_core.dispatch`). -Errors raise typed exceptions (MizanError subclasses). Wire those to JSON -responses by registering `mizan_exception_handler` on the FastAPI app, or -let them propagate to your own handler. +The protocol machinery (auth, validation, execution, invalidation, merge, cache) +lives in `mizan_core`; this module re-exports the canonical error taxonomy and +keeps backward-compatible helpers. The router drives `dispatch_call` / +`dispatch_context` directly to get invalidation + origin cache. """ from __future__ import annotations -from enum import Enum from typing import Any -from fastapi.encoders import jsonable_encoder -from pydantic import BaseModel, ValidationError +from mizan_core.dispatch import CacheOrchestrator, DispatchRequest, dispatch_call +from mizan_core.errors import ( + BadRequest, + ErrorCode, + Forbidden, + InternalError, + MizanError, + NotFound, + NotImplementedYet, + Unauthorized, + ValidationFailed, +) +from mizan_core.invalidation import resolve_invalidation, resolve_merges -from mizan_core.registry import get_context_groups, get_function -from mizan_core.type_utils import types_match_for_merge +__all__ = [ + "ErrorCode", + "MizanError", + "NotFound", + "BadRequest", + "ValidationFailed", + "Unauthorized", + "Forbidden", + "NotImplementedYet", + "InternalError", + "compute_invalidation", + "compute_merges", + "execute_function", +] -# ─── Error taxonomy ───────────────────────────────────────────────────────── - - -class ErrorCode(str, Enum): - NOT_FOUND = "NOT_FOUND" - BAD_REQUEST = "BAD_REQUEST" - VALIDATION_ERROR = "VALIDATION_ERROR" - UNAUTHORIZED = "UNAUTHORIZED" - FORBIDDEN = "FORBIDDEN" - NOT_IMPLEMENTED = "NOT_IMPLEMENTED" - INTERNAL_ERROR = "INTERNAL_ERROR" - - -_STATUS = { - ErrorCode.NOT_FOUND: 404, - ErrorCode.BAD_REQUEST: 400, - ErrorCode.VALIDATION_ERROR: 422, - ErrorCode.UNAUTHORIZED: 401, - ErrorCode.FORBIDDEN: 403, - ErrorCode.NOT_IMPLEMENTED: 501, - ErrorCode.INTERNAL_ERROR: 500, -} - - -class MizanError(Exception): - """Base for protocol-level dispatch errors.""" - - code: ErrorCode = ErrorCode.INTERNAL_ERROR - - def __init__(self, message: str, *, details: dict[str, Any] | None = None) -> None: - super().__init__(message) - self.message = message - self.details = details - - @property - def status_code(self) -> int: - return _STATUS[self.code] - - -class NotFound(MizanError): code = ErrorCode.NOT_FOUND # noqa: E701 -class BadRequest(MizanError): code = ErrorCode.BAD_REQUEST # noqa: E701 -class ValidationFailed(MizanError): code = ErrorCode.VALIDATION_ERROR # noqa: E701 -class Unauthorized(MizanError): code = ErrorCode.UNAUTHORIZED # noqa: E701 -class Forbidden(MizanError): code = ErrorCode.FORBIDDEN # noqa: E701 -class NotImplementedYet(MizanError): code = ErrorCode.NOT_IMPLEMENTED # noqa: E701 -class InternalError(MizanError): code = ErrorCode.INTERNAL_ERROR # noqa: E701 - - -# ─── Auth ─────────────────────────────────────────────────────────────────── - - -def _user(request: Any) -> Any: - return getattr(getattr(request, "state", None), "user", None) - - -def _is_authenticated(user: Any) -> bool: - return bool(user) and getattr(user, "is_authenticated", True) - - -def _enforce_auth(request: Any, requirement: Any) -> None: - """Verify the request meets the function's @client(auth=...) requirement, or raise.""" - if requirement is None: - return - - user = _user(request) - - match requirement: - case True | "required": - if not _is_authenticated(user): - raise Unauthorized("Authentication required") - case "staff": - if not _is_authenticated(user): - raise Unauthorized("Authentication required") - if not getattr(user, "is_staff", False): - raise Forbidden("Staff access required") - case "superuser": - if not _is_authenticated(user): - raise Unauthorized("Authentication required") - if not getattr(user, "is_superuser", False): - raise Forbidden("Superuser access required") - case f if callable(f): - if not f(request): - raise Forbidden("Permission denied") - case other: - raise InternalError(f"Unknown auth requirement: {other!r}") - - -# ─── Input validation ─────────────────────────────────────────────────────── - - -def _validate_input(input_cls: Any, input_data: Any) -> BaseModel | None: - """Validate input_data against the function's Input model. Returns the instance or None.""" - if input_cls in (None, BaseModel) or not getattr(input_cls, "model_fields", None): - return None - - fields = input_cls.model_fields - required = [name for name, f in fields.items() if f.is_required()] - - if not input_data: - if required: - raise ValidationFailed( - "Input validation failed", - details={"fields": {name: ["Field required"] for name in required}}, - ) - return input_cls() - - if not isinstance(input_data, dict): - raise BadRequest(f"Input must be an object, got {type(input_data).__name__}") - - try: - return input_cls(**input_data) - except ValidationError as e: - raise ValidationFailed( - "Input validation failed", - details={"errors": e.errors()}, - ) from e - - -# ─── Dispatch ─────────────────────────────────────────────────────────────── - - -def _resolve_function(fn_name: str) -> Any: - view_class = get_function(fn_name) - if view_class is None: - raise NotFound("Function not found") - if getattr(view_class, "_meta", {}).get("private"): - raise Forbidden("Function is not client-callable") - return view_class - - -def _serialize(result: Any) -> Any: - # jsonable_encoder walks BaseModel / list / dict recursively, so list[BaseModel] - # (and nested shapes) come out wire-ready without a per-shape branch here. - return jsonable_encoder(result) - - -async def execute_function( - request: Any, - fn_name: str, - input_data: dict[str, Any] | None = None, -) -> Any: - """Dispatch a registered function. Returns the serialized result, or raises MizanError. - - Awaits `view.acall` — async handlers run on the loop, sync handlers run - in the default threadpool, both via the same entrypoint. - """ - view_class = _resolve_function(fn_name) - _enforce_auth(request, view_class._meta.get("auth")) - - view = view_class(request) - validated = _validate_input(view.Input, input_data) - - try: - result = await view.acall(validated) - except NotImplementedError as e: - raise NotImplementedYet(str(e) or "Not implemented") from e - except MizanError: - raise - except Exception as e: - raise InternalError(str(e)) from e - - return _serialize(result) - - -# ─── Invalidation ─────────────────────────────────────────────────────────── +_NO_CACHE = CacheOrchestrator(None, None) def compute_invalidation(view_class: Any, input_data: dict[str, Any] | None) -> list[Any]: - """Build the `invalidate` list from @client(affects=...) metadata, auto-scoping when arg names match context params.""" - affects = getattr(view_class, "_meta", {}).get("affects") or [] - return [_invalidation_target(target, input_data or {}) for target in affects] + """`@client(affects=...)` → invalidation list (empty when none). Shared core.""" + return resolve_invalidation(view_class, input_data) or [] def compute_merges(view_class: Any, input_data: dict[str, Any] | None, result: Any) -> list[dict[str, Any]]: - """Build the `merge` list from @client(merge=...) metadata. + """`@client(merge=...)` → merge list (empty when none). Shared core.""" + return resolve_merges(view_class, input_data, result) or [] - Each entry is `{context, slot, value, params?}` where `slot` names the - function inside the context bundle the value lands in. The slot is - resolved server-side via `types_match_for_merge` so the kernel does - no shape inference — the server has the schema, type-checked routing - lives here. Entries whose slot can't be uniquely resolved are dropped - with a warning; the consumer falls back to refetch via `affects`. + +async def execute_function(request: Any, fn_name: str, input_data: dict[str, Any] | None = None) -> Any: + """Dispatch a function and return its serialized result (auth enforced via core). + + Backward-compat entry point; the router uses `dispatch_call` directly to also + capture invalidation/merge and run the origin cache. """ - targets = getattr(view_class, "_meta", {}).get("merge") or [] - if not targets: - return [] - mutation_output = getattr(view_class, "Output", None) - out: list[dict[str, Any]] = [] - for ctx_name in targets: - slot = _resolve_merge_slot(ctx_name, mutation_output) - if slot is None: - continue - entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result} - scoped = _scoped_params(ctx_name, input_data or {}) - if scoped: - entry["params"] = scoped - out.append(entry) - return out - - -def _resolve_merge_slot(context_name: str, mutation_output: Any) -> str | None: - """Find the unique function-name slot whose return type matches the mutation's output. - - Returns None on no match or ambiguous match (multiple candidates). - """ - if mutation_output is None: - return None - matches: list[str] = [] - for fn_name in get_context_groups().get(context_name, []): - fn_cls = get_function(fn_name) - if fn_cls is None: - continue - fn_output = getattr(fn_cls, "Output", None) - if fn_output is not None and types_match_for_merge(fn_output, mutation_output): - matches.append(fn_name) - return matches[0] if len(matches) == 1 else None - - -def _scoped_params(context_name: str, input_data: dict[str, Any]) -> dict[str, Any]: - """Match input args against the context's declared Input field names.""" - fn_names = get_context_groups().get(context_name, []) - declared: set[str] = set() - for fn_name in fn_names: - fn_cls = get_function(fn_name) - if fn_cls is None: - continue - input_cls = getattr(fn_cls, "Input", None) - if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"): - declared.update(input_cls.model_fields.keys()) - return {k: v for k, v in input_data.items() if k in declared} - - -def _invalidation_target(target: dict[str, Any], input_data: dict[str, Any]) -> Any: - match target.get("type"): - case "context": - name = target["name"] - scoped = _scoped_params(name, input_data) - return {"context": name, "params": scoped} if scoped else name - case "function": - return {"function": target["name"]} - case _: - return target + identity = getattr(getattr(request, "state", None), "user", None) + res = await dispatch_call( + DispatchRequest(identity=identity, args=input_data, native_request=request), + fn_name, + _NO_CACHE, + ) + return res.data diff --git a/backends/mizan-fastapi/src/mizan_fastapi/router.py b/backends/mizan-fastapi/src/mizan_fastapi/router.py index 77c67e1..0881ada 100644 --- a/backends/mizan-fastapi/src/mizan_fastapi/router.py +++ b/backends/mizan-fastapi/src/mizan_fastapi/router.py @@ -19,22 +19,17 @@ from typing import Any from fastapi import APIRouter, Request from fastapi.exceptions import RequestValidationError -from fastapi.responses import JSONResponse +from fastapi.responses import JSONResponse, Response from pydantic import BaseModel, Field, ValidationError from starlette.datastructures import UploadFile -from mizan_core.registry import get_context_groups, get_function +from mizan_core.auth import INVALID, authenticate +from mizan_core.dispatch import DispatchRequest, dispatch_call, dispatch_context +from mizan_core.errors import BadRequest, ErrorCode, MizanError, Unauthorized +from mizan_core.registry import get_function from mizan_core.upload import UploadedFile, bind_uploads -from .executor import ( - BadRequest, - ErrorCode, - MizanError, - NotFound, - compute_invalidation, - compute_merges, - execute_function, -) +from .config import MizanConfig, get_config router = APIRouter() @@ -106,31 +101,52 @@ async def _parse_call(request: Request) -> tuple[str, dict[str, Any]]: return body.fn, body.args +def _identity(request: Request, cfg: MizanConfig): + """Identity for dispatch: a host-set `request.state.user`, else a token decode. + + A present-but-invalid token rejects (401); no token → None (anonymous). + """ + existing = getattr(getattr(request, "state", None), "user", None) + if existing is not None: + return existing + ident = authenticate(request.headers, cfg.auth) + if ident is INVALID: + raise Unauthorized("Invalid or expired token") + return ident + + @router.post("/call/") async def function_call(request: Request) -> JSONResponse: - """RPC dispatch — `{"fn": "...", "args": {...}}` (JSON) or multipart with file - parts → `{"result": ..., "invalidate": [...], "merge"?: [...]}`.""" + """RPC dispatch — JSON or multipart → `{"result", "invalidate", "merge"?}` with + the `X-Mizan-Invalidate` header alongside the body.""" + cfg = get_config(request) fn, args = await _parse_call(request) - fn_class = get_function(fn) - result = await execute_function(request, fn, args) - invalidate = compute_invalidation(fn_class, args) - merges = compute_merges(fn_class, args, result) - payload: dict[str, Any] = {"result": result, "invalidate": invalidate} - if merges: - payload["merge"] = merges - return _no_store(payload) + res = await dispatch_call( + DispatchRequest(identity=_identity(request, cfg), args=args, native_request=request), + fn, cfg.cache, + ) + payload: dict[str, Any] = {"result": res.data, "invalidate": res.invalidate or []} + if res.merge: + payload["merge"] = res.merge + headers = {"Cache-Control": "no-store"} + if res.invalidate_header: + headers["X-Mizan-Invalidate"] = res.invalidate_header + return JSONResponse(payload, headers=headers) @router.get("/ctx/{context_name}/") -async def context_fetch(context_name: str, request: Request) -> JSONResponse: - """Bundled context fetch — `{function_name: result, ...}` for every function in the context.""" - fn_names = get_context_groups().get(context_name) - if not fn_names: - raise NotFound(f"Context '{context_name}' not found") - - params = dict(request.query_params) - bundled = {fn: await execute_function(request, fn, params) for fn in fn_names} - return _no_store(bundled) +async def context_fetch(context_name: str, request: Request) -> Response: + """Bundled context fetch — origin-cached. `{function_name: result, ...}`.""" + cfg = get_config(request) + res = await dispatch_context( + DispatchRequest(identity=_identity(request, cfg), args=dict(request.query_params), + native_request=request), + context_name, cfg.cache, + ) + headers = {"Cache-Control": "no-store"} + if res.cache_status: + headers["X-Mizan-Cache"] = res.cache_status + return Response(content=res.body_bytes, media_type="application/json", headers=headers) # ─── Exception handler ────────────────────────────────────────────────────── diff --git a/backends/mizan-fastapi/tests/test_parity.py b/backends/mizan-fastapi/tests/test_parity.py new file mode 100644 index 0000000..66ce8f2 --- /dev/null +++ b/backends/mizan-fastapi/tests/test_parity.py @@ -0,0 +1,98 @@ +"""FastAPI parity with Django: X-Mizan-Invalidate header, origin cache, token auth.""" + +from __future__ import annotations + +import pytest +from fastapi import Depends, FastAPI +from fastapi.testclient import TestClient +from pydantic import BaseModel + +from mizan_core.auth import AuthConfig, JWTConfig, create_access_token +from mizan_core.cache.backend import MemoryCache +from mizan_core.client.function import client +from mizan_core.dispatch import CacheOrchestrator +from mizan_core.registry import clear_registry, register +from mizan_fastapi import ( + MizanAuthMiddleware, + MizanConfig, + MizanError, + mizan_auth, + mizan_exception_handler, + router as mizan_router, +) + + +class Out(BaseModel): + ok: bool + + +SECRET = "x" * 32 +JWT = JWTConfig(private_key=SECRET, public_key=SECRET) + + +def _app(*, with_cache=False, with_auth_dep=False) -> FastAPI: + clear_registry() + + UserCtx = "user" + + @client(context=UserCtx) + def user_profile(request, user_id: int) -> Out: + return Out(ok=True) + + @client(affects=UserCtx) + def update_profile(request, user_id: int) -> Out: + return Out(ok=True) + + @client(auth=True) + def whoami(request) -> Out: + return Out(ok=True) + + register(user_profile, "user_profile") + register(update_profile, "update_profile") + register(whoami, "whoami") + + app = FastAPI() + cache = CacheOrchestrator(MemoryCache(), SECRET) if with_cache else CacheOrchestrator(None, None) + app.state.mizan_config = MizanConfig(auth=AuthConfig(jwt=JWT), cache=cache) + deps = [Depends(mizan_auth())] if with_auth_dep else [] + app.include_router(mizan_router, prefix="/api/mizan", dependencies=deps) + app.add_exception_handler(MizanError, mizan_exception_handler) + return app + + +def test_mutation_emits_invalidate_header(): + c = TestClient(_app()) + r = c.post("/api/mizan/call/", json={"fn": "update_profile", "args": {"user_id": 5}}) + assert r.status_code == 200 + assert r.json()["invalidate"] == [{"context": "user", "params": {"user_id": 5}}] + assert r.headers["X-Mizan-Invalidate"] == "user;user_id=5" + + +def test_origin_cache_hit_miss(): + c = TestClient(_app(with_cache=True)) + r1 = c.get("/api/mizan/ctx/user/", params={"user_id": 5}) + assert r1.status_code == 200 and r1.headers["X-Mizan-Cache"] == "MISS" + r2 = c.get("/api/mizan/ctx/user/", params={"user_id": 5}) + assert r2.headers["X-Mizan-Cache"] == "HIT" + assert r1.content == r2.content + + +def test_auth_required_rejects_anonymous(): + c = TestClient(_app()) + r = c.post("/api/mizan/call/", json={"fn": "whoami", "args": {}}) + assert r.status_code == 401 + + +def test_auth_required_passes_with_bearer_jwt(): + c = TestClient(_app(with_auth_dep=True)) + tok = create_access_token("7", "sess", JWT, is_staff=True) + r = c.post("/api/mizan/call/", json={"fn": "whoami", "args": {}}, + headers={"Authorization": f"Bearer {tok}"}) + assert r.status_code == 200 and r.json()["result"] == {"ok": True} + + +def test_invalid_bearer_token_rejected(): + c = TestClient(_app()) + r = c.post("/api/mizan/call/", json={"fn": "update_profile", "args": {"user_id": 1}}, + headers={"Authorization": "Bearer not-a-real-token"}) + assert r.status_code == 401 diff --git a/backends/mizan-ts/src/decorator.ts b/backends/mizan-ts/src/decorator.ts index 34b3a06..9758ff7 100644 --- a/backends/mizan-ts/src/decorator.ts +++ b/backends/mizan-ts/src/decorator.ts @@ -13,7 +13,7 @@ * } */ -import { ReactContext, type ClientOptions, type RegistryEntry, type ParamDef } from './types' +import { ReactContext, type ClientOptions, type RegistryEntry, type ParamDef, type AuthRequirement } from './types' import { register } from './registry' function resolveContext(ctx: ReactContext | string | undefined): string | undefined { @@ -21,6 +21,19 @@ function resolveContext(ctx: ReactContext | string | undefined): string | undefi return ctx } +/** + * Normalize the public auth option into the stored requirement. + * Mirrors Python: undefined→undefined, true→'required', callable→callable, + * 'staff'/'superuser' pass through, anything else throws at decoration time. + */ +function normalizeAuth(auth: ClientOptions['auth']): AuthRequirement | undefined { + if (auth === undefined) return undefined + if (auth === true) return 'required' + if (typeof auth === 'function') return auth + if (auth === 'staff' || auth === 'superuser') return auth + throw new Error(`Invalid auth value ${JSON.stringify(auth)}`) +} + function normalizeAffects( affects: ClientOptions['affects'], ): RegistryEntry['affects'] | undefined { @@ -97,7 +110,7 @@ export function client(optionsOrFn: ClientOptions | ClientOptions, fn?: Function viewPath: isView, route: options.route, methods: options.methods, - auth: options.auth, + auth: normalizeAuth(options.auth), rev: options.rev, cache: options.cache, } @@ -129,7 +142,7 @@ export function client(optionsOrFn: ClientOptions | ClientOptions, fn?: Function viewPath: false, route: options.route, methods: options.methods, - auth: options.auth, + auth: normalizeAuth(options.auth), rev: options.rev, cache: options.cache, } diff --git a/backends/mizan-ts/src/dispatch.ts b/backends/mizan-ts/src/dispatch.ts index 5159769..48c7f62 100644 --- a/backends/mizan-ts/src/dispatch.ts +++ b/backends/mizan-ts/src/dispatch.ts @@ -8,6 +8,8 @@ import { getFunction, getContextGroups } from './registry' import { resolveInvalidation, formatInvalidateHeader } from './invalidation' import { getCache, cacheGet, cachePut, cachePurge } from './cache' +import { ANONYMOUS, type Identity } from './identity' +import type { AuthRequirement } from './types' let _cacheSecret: string | null = null @@ -22,6 +24,54 @@ export interface MizanResponse { headers: Record } +interface AuthDenial { + status: 401 | 403 + code: 'UNAUTHORIZED' | 'FORBIDDEN' + message: string +} + +/** + * Check whether `identity` satisfies the stored `auth` requirement. + * Ports Django's _check_auth_requirement exactly. Returns an AuthDenial + * on failure, or null when access is allowed. + */ +function checkAuth(auth: AuthRequirement | undefined, identity: Identity): AuthDenial | null { + if (auth === undefined) return null + + // Callable runs first — before the authentication gate. + if (typeof auth === 'function') { + try { + return auth(identity) + ? null + : { status: 403, code: 'FORBIDDEN', message: 'Access denied' } + } catch (e: any) { + return { status: 403, code: 'FORBIDDEN', message: e?.message || 'Access denied' } + } + } + + if (!identity.isAuthenticated) { + return { status: 401, code: 'UNAUTHORIZED', message: 'Authentication required' } + } + + if (auth === 'staff' && !identity.isStaff) { + return { status: 403, code: 'FORBIDDEN', message: 'Staff access required' } + } + + if (auth === 'superuser' && !identity.isSuperuser) { + return { status: 403, code: 'FORBIDDEN', message: 'Superuser access required' } + } + + return null +} + +function authDenialResponse(denial: AuthDenial): MizanResponse { + return { + status: denial.status, + body: { error: true, code: denial.code, message: denial.message }, + headers: { 'Cache-Control': 'no-store', 'Content-Type': 'application/json' }, + } +} + /** * Handle GET /api/mizan/ctx/:contextName/ * @@ -30,6 +80,7 @@ export interface MizanResponse { export async function handleContextFetch( contextName: string, params: Record, + identity: Identity = ANONYMOUS, ): Promise { const groups = getContextGroups() const fnNames = groups[contextName] @@ -42,6 +93,15 @@ export async function handleContextFetch( } } + // Auth pre-pass — run BEFORE the cache lookup so a cache HIT can never + // leak to an unauthorized caller. Any denial short-circuits, uncached. + for (const fnName of fnNames) { + const entry = getFunction(fnName) + if (!entry) continue + const denial = checkAuth(entry.auth, identity) + if (denial) return authDenialResponse(denial) + } + // Resolve effective rev (max across functions) and cache policy (min TTL) let effectiveRev = 0 for (const fnName of fnNames) { @@ -133,6 +193,7 @@ export async function handleContextFetch( export async function handleMutationCall( fnName: string, args: Record, + identity: Identity = ANONYMOUS, ): Promise { const entry = getFunction(fnName) @@ -153,6 +214,10 @@ export async function handleMutationCall( } } + // Auth enforcement — after private rejection, before execution. + const denial = checkAuth(entry.auth, identity) + if (denial) return authDenialResponse(denial) + try { const argValues = entry.params.map(p => args[p.name]) const result = await entry.fn(...argValues) diff --git a/backends/mizan-ts/src/identity.ts b/backends/mizan-ts/src/identity.ts new file mode 100644 index 0000000..b5dac16 --- /dev/null +++ b/backends/mizan-ts/src/identity.ts @@ -0,0 +1,22 @@ +/** + * Identity abstraction — the request-bound caller identity. + * + * Framework-agnostic. Adapters construct an Identity (from MWT, JWT, + * session, etc.) and pass it into dispatch. ANONYMOUS is the default. + */ + +export interface Identity { + isAuthenticated: boolean + isStaff: boolean + isSuperuser: boolean + id: number | string | null +} + +export const ANONYMOUS: Identity = { + isAuthenticated: false, + isStaff: false, + isSuperuser: false, + id: null, +} + +export type AuthPredicate = (identity: Identity) => boolean diff --git a/backends/mizan-ts/src/index.ts b/backends/mizan-ts/src/index.ts index bdd1eee..50b147d 100644 --- a/backends/mizan-ts/src/index.ts +++ b/backends/mizan-ts/src/index.ts @@ -1,5 +1,11 @@ export { ReactContext } from './types' -export type { ClientOptions, EdgeManifest, RegistryEntry } from './types' +export type { ClientOptions, EdgeManifest, RegistryEntry, AuthOption, AuthRequirement } from './types' + +export { ANONYMOUS } from './identity' +export type { Identity, AuthPredicate } from './identity' + +export { decodeMwt, decodeJwtBearer, identityFromMwt } from './token' +export type { MwtPayload } from './token' export { client } from './decorator' diff --git a/backends/mizan-ts/src/token.ts b/backends/mizan-ts/src/token.ts new file mode 100644 index 0000000..5289ab6 --- /dev/null +++ b/backends/mizan-ts/src/token.ts @@ -0,0 +1,110 @@ +/** + * MWT / JWT decode — HS256 verification, cross-language parity with + * cores/mizan-python/src/mizan_core/mwt.py. + * + * Returns null on ANY failure (bad signature, expired, future nbf, wrong + * aud, malformed). Never throws. + */ + +import { createHmac, timingSafeEqual } from 'crypto' +import type { Identity } from './identity' + +export interface MwtPayload { + sub: string + staff: boolean + super: boolean + pkey: string + kid: string + aud: string + iat: number + exp: number +} + +function base64urlDecode(input: string): Buffer | null { + if (!/^[A-Za-z0-9_-]*$/.test(input)) return null + return Buffer.from(input, 'base64url') +} + +function constantTimeEqual(a: Buffer, b: Buffer): boolean { + if (a.length !== b.length) return false + return timingSafeEqual(a, b) +} + +/** + * Decode and validate an MWT (HS256 JWT with Mizan claims). + * Returns MwtPayload on success, null on any failure. + */ +export function decodeMwt( + token: string, + secret: string, + audience: string = 'mizan', +): MwtPayload | null { + try { + const parts = token.split('.') + if (parts.length !== 3) return null + const [headerB64, payloadB64, signatureB64] = parts + + const headerBytes = base64urlDecode(headerB64) + const payloadBytes = base64urlDecode(payloadB64) + const signatureBytes = base64urlDecode(signatureB64) + if (!headerBytes || !payloadBytes || !signatureBytes) return null + + const header = JSON.parse(headerBytes.toString('utf-8')) + if (header.alg !== 'HS256') return null + + // Recompute HMAC over `${headerB64}.${payloadB64}` + const expected = createHmac('sha256', secret) + .update(`${headerB64}.${payloadB64}`) + .digest() + if (!constantTimeEqual(expected, signatureBytes)) return null + + const data = JSON.parse(payloadBytes.toString('utf-8')) + + const now = Math.floor(Date.now() / 1000) + if (typeof data.exp !== 'number' || data.exp <= now) return null + if (data.nbf !== undefined && typeof data.nbf === 'number' && data.nbf > now) return null + if (data.aud !== audience) return null + + const kid = typeof header.kid === 'string' ? header.kid : 'v1' + + return { + sub: String(data.sub), + staff: Boolean(data.staff), + super: Boolean(data.super), + pkey: typeof data.pkey === 'string' ? data.pkey : '', + kid, + aud: audience, + iat: data.iat, + exp: data.exp, + } + } catch { + return null + } +} + +/** + * Decode a Bearer JWT from an Authorization header value. + * Strips the "Bearer " prefix, then validates as an MWT. + */ +export function decodeJwtBearer( + authHeader: string, + secret: string, + audience: string = 'mizan', +): MwtPayload | null { + if (!authHeader) return null + const prefix = 'Bearer ' + const token = authHeader.startsWith(prefix) + ? authHeader.slice(prefix.length) + : authHeader + return decodeMwt(token, secret, audience) +} + +/** Build an Identity from a decoded MWT payload. */ +export function identityFromMwt(payload: MwtPayload): Identity { + return { + isAuthenticated: true, + isStaff: payload.staff, + isSuperuser: payload.super, + id: Number(payload.sub), + } +} diff --git a/backends/mizan-ts/src/types.ts b/backends/mizan-ts/src/types.ts index bde256b..17309e8 100644 --- a/backends/mizan-ts/src/types.ts +++ b/backends/mizan-ts/src/types.ts @@ -2,6 +2,8 @@ * Mizan TypeScript Adapter — Shared Types */ +import type { AuthPredicate } from './identity' + export class ReactContext { constructor(public readonly name: string) { if (!name) throw new Error('ReactContext name must be non-empty') @@ -10,13 +12,19 @@ export class ReactContext { export type AffectsTarget = ReactContext | string +/** Public auth option on the decorator. `true` normalizes to `'required'` when stored. */ +export type AuthOption = true | 'staff' | 'superuser' | AuthPredicate + +/** Normalized auth requirement as stored on the registry entry. */ +export type AuthRequirement = 'required' | 'staff' | 'superuser' | AuthPredicate + export interface ClientOptions { context?: ReactContext | string affects?: AffectsTarget | AffectsTarget[] private?: boolean route?: string methods?: string[] - auth?: boolean + auth?: AuthOption rev?: number cache?: number | false } @@ -37,7 +45,7 @@ export interface RegistryEntry { viewPath: boolean route?: string methods?: string[] - auth?: boolean + auth?: AuthRequirement rev?: number cache?: number | false } diff --git a/backends/mizan-ts/tests/auth.test.ts b/backends/mizan-ts/tests/auth.test.ts new file mode 100644 index 0000000..10b1fe6 --- /dev/null +++ b/backends/mizan-ts/tests/auth.test.ts @@ -0,0 +1,163 @@ +/** + * Auth-parity tests — mirrors Django's auth enforcement in + * mizan-django/src/mizan/client/executor.py (_check_auth_requirement). + */ + +import { describe, test, expect, beforeEach } from 'bun:test' +import { + ReactContext, client, clearRegistry, + handleContextFetch, handleMutationCall, + setCache, resetCache, setCacheSecret, MemoryCache, + type Identity, +} from '../src' + +function anon(): Identity { + return { isAuthenticated: false, isStaff: false, isSuperuser: false, id: null } +} +function user(): Identity { + return { isAuthenticated: true, isStaff: false, isSuperuser: false, id: 1 } +} +function staff(): Identity { + return { isAuthenticated: true, isStaff: true, isSuperuser: false, id: 2 } +} +function superuser(): Identity { + return { isAuthenticated: true, isStaff: true, isSuperuser: true, id: 3 } +} + +describe('Auth — mutation dispatch', () => { + beforeEach(() => clearRegistry()) + + test('auth:true + anon → 401', async () => { + client({ auth: true }, async function secret() { return { ok: true } }) + const r = await handleMutationCall('secret', {}, anon()) + expect(r.status).toBe(401) + expect(r.body.code).toBe('UNAUTHORIZED') + expect(r.body.message).toBe('Authentication required') + expect(r.headers['Cache-Control']).toBe('no-store') + }) + + test('auth:true + user → 200', async () => { + client({ auth: true }, async function secret() { return { ok: true } }) + const r = await handleMutationCall('secret', {}, user()) + expect(r.status).toBe(200) + expect(r.body.result).toEqual({ ok: true }) + }) + + test("auth:'staff' + user → 403", async () => { + client({ auth: 'staff' }, async function adminAction() { return { ok: true } }) + const r = await handleMutationCall('adminAction', {}, user()) + expect(r.status).toBe(403) + expect(r.body.code).toBe('FORBIDDEN') + expect(r.body.message).toBe('Staff access required') + }) + + test("auth:'staff' + staff → 200", async () => { + client({ auth: 'staff' }, async function adminAction() { return { ok: true } }) + const r = await handleMutationCall('adminAction', {}, staff()) + expect(r.status).toBe(200) + }) + + test("auth:'superuser' + staff → 403", async () => { + client({ auth: 'superuser' }, async function nuke() { return { ok: true } }) + const r = await handleMutationCall('nuke', {}, staff()) + expect(r.status).toBe(403) + expect(r.body.message).toBe('Superuser access required') + }) + + test("auth:'superuser' + superuser → 200", async () => { + client({ auth: 'superuser' }, async function nuke() { return { ok: true } }) + const r = await handleMutationCall('nuke', {}, superuser()) + expect(r.status).toBe(200) + }) + + test('callable → true → 200', async () => { + client({ auth: (id) => id.isAuthenticated }, async function gated() { return { ok: true } }) + const r = await handleMutationCall('gated', {}, user()) + expect(r.status).toBe(200) + }) + + test("callable → false → 403 'Access denied'", async () => { + client({ auth: () => false }, async function gated() { return { ok: true } }) + const r = await handleMutationCall('gated', {}, user()) + expect(r.status).toBe(403) + expect(r.body.message).toBe('Access denied') + }) + + test("callable throws Error('msg') → 403 'msg'", async () => { + client({ auth: () => { throw new Error('msg') } }, async function gated() { return { ok: true } }) + const r = await handleMutationCall('gated', {}, user()) + expect(r.status).toBe(403) + expect(r.body.message).toBe('msg') + }) + + test('callable runs before authentication gate (anon allowed if predicate true)', async () => { + client({ auth: () => true }, async function gated() { return { ok: true } }) + const r = await handleMutationCall('gated', {}, anon()) + expect(r.status).toBe(200) + }) + + test('invalid auth string at decoration → throws', () => { + expect(() => { + client({ auth: 'admin' as any }, async function bad() { return {} }) + }).toThrow('Invalid auth value') + }) + + test('no auth + anon → 200 (default ANONYMOUS path stays open)', async () => { + client({}, async function open() { return { ok: true } }) + const r = await handleMutationCall('open', {}) + expect(r.status).toBe(200) + }) +}) + +describe('Auth — context fetch', () => { + beforeEach(() => clearRegistry()) + + test('auth-gated context member + anon → 401', async () => { + const Ctx = new ReactContext('secure') + client({ context: Ctx, auth: true }, async function secureData(itemId: number) { + return { id: itemId } + }) + const r = await handleContextFetch('secure', { itemId: '1' }, anon()) + expect(r.status).toBe(401) + expect(r.body.message).toBe('Authentication required') + }) + + test('auth-gated context + user → 200', async () => { + const Ctx = new ReactContext('secure') + client({ context: Ctx, auth: true }, async function secureData(itemId: number) { + return { id: itemId } + }) + const r = await handleContextFetch('secure', { itemId: '1' }, user()) + expect(r.status).toBe(200) + expect(r.body.secureData).toEqual({ id: '1' }) + }) + + test('context fetch denial pre-empts a would-be cache HIT', async () => { + const SECRET = 'auth-test-secret-32bytes-padding!' + const Ctx = new ReactContext('secure') + client({ context: Ctx, auth: true }, async function secureData(itemId: number) { + return { id: itemId } + }) + + const cache = new MemoryCache() + setCache(cache) + setCacheSecret(SECRET) + + // Prime the cache as an authorized caller. + const primed = await handleContextFetch('secure', { itemId: '1' }, user()) + expect(primed.status).toBe(200) + expect(primed.headers['X-Mizan-Cache']).toBe('MISS') + + // Confirm it's now a cache HIT for an authorized caller. + const hit = await handleContextFetch('secure', { itemId: '1' }, user()) + expect(hit.headers['X-Mizan-Cache']).toBe('HIT') + + // Anon must get 401 even though the cache holds the entry. + const denied = await handleContextFetch('secure', { itemId: '1' }, anon()) + expect(denied.status).toBe(401) + expect(denied.headers['X-Mizan-Cache']).toBeUndefined() + + resetCache() + setCacheSecret(null) + }) +}) diff --git a/backends/mizan-ts/tests/token.test.ts b/backends/mizan-ts/tests/token.test.ts new file mode 100644 index 0000000..2f7f115 --- /dev/null +++ b/backends/mizan-ts/tests/token.test.ts @@ -0,0 +1,126 @@ +/** + * MWT decode tests — round-trip + cross-language pin against Python create_mwt. + */ + +import { describe, test, expect } from 'bun:test' +import { createHmac } from 'crypto' +import { decodeMwt, decodeJwtBearer, identityFromMwt } from '../src' + +function b64url(buf: Buffer | string): string { + return Buffer.from(buf).toString('base64url') +} + +/** Mint an HS256 MWT with node crypto, mirroring Python create_mwt. */ +function mint(payload: Record, secret: string, kid = 'v1'): string { + const header = b64url(JSON.stringify({ alg: 'HS256', kid, typ: 'JWT' })) + const body = b64url(JSON.stringify(payload)) + const sig = createHmac('sha256', secret).update(`${header}.${body}`).digest('base64url') + return `${header}.${body}.${sig}` +} + +const SECRET = 'round-trip-secret' +const now = Math.floor(Date.now() / 1000) + +function basePayload(overrides: Record = {}) { + return { + sub: '7', + staff: true, + super: false, + pkey: 'abc123', + aud: 'mizan', + iat: now, + nbf: now, + exp: now + 300, + ...overrides, + } +} + +describe('MWT round-trip', () => { + test('valid token decodes', () => { + const token = mint(basePayload(), SECRET) + const p = decodeMwt(token, SECRET) + expect(p).not.toBeNull() + expect(p!.sub).toBe('7') + expect(p!.staff).toBe(true) + expect(p!.super).toBe(false) + expect(p!.pkey).toBe('abc123') + expect(p!.kid).toBe('v1') + expect(p!.aud).toBe('mizan') + }) + + test('identityFromMwt maps claims', () => { + const token = mint(basePayload({ sub: '99', staff: false, super: true }), SECRET) + const p = decodeMwt(token, SECRET)! + expect(identityFromMwt(p)).toEqual({ + isAuthenticated: true, + isStaff: false, + isSuperuser: true, + id: 99, + }) + }) + + test('decodeJwtBearer strips Bearer prefix', () => { + const token = mint(basePayload(), SECRET) + const p = decodeJwtBearer(`Bearer ${token}`, SECRET) + expect(p).not.toBeNull() + expect(p!.sub).toBe('7') + }) + + test('null on tampered signature', () => { + const token = mint(basePayload(), SECRET) + const tampered = token.slice(0, -2) + (token.endsWith('AA') ? 'BB' : 'AA') + expect(decodeMwt(tampered, SECRET)).toBeNull() + }) + + test('null on wrong secret', () => { + const token = mint(basePayload(), SECRET) + expect(decodeMwt(token, 'other-secret')).toBeNull() + }) + + test('null on expired exp', () => { + const token = mint(basePayload({ exp: now - 10 }), SECRET) + expect(decodeMwt(token, SECRET)).toBeNull() + }) + + test('null on future nbf', () => { + const token = mint(basePayload({ nbf: now + 1000 }), SECRET) + expect(decodeMwt(token, SECRET)).toBeNull() + }) + + test('null on wrong aud', () => { + const token = mint(basePayload({ aud: 'other' }), SECRET) + expect(decodeMwt(token, SECRET)).toBeNull() + }) + + test('null on malformed token', () => { + expect(decodeMwt('not.a.jwt', SECRET)).toBeNull() + expect(decodeMwt('onlyonepart', SECRET)).toBeNull() + expect(decodeMwt('', SECRET)).toBeNull() + }) +}) + +describe('MWT cross-language pin (Python create_mwt)', () => { + const TOKEN = 'eyJhbGciOiJIUzI1NiIsImtpZCI6InYxIiwidHlwIjoiSldUIn0.eyJzdWIiOiI0MiIsInN0YWZmIjp0cnVlLCJzdXBlciI6ZmFsc2UsInBrZXkiOiIwZTk5OGE5ZmYxNjkwNDYzN2EwM2QyZWEwZmJkYmY5NzQyOTdhOWQxYTVkMjViOGQ0Mjk0ZmE4ODIxMTVlNDU3IiwiYXVkIjoibWl6YW4iLCJpYXQiOjE3MDAwMDAwMDAsIm5iZiI6MTcwMDAwMDAwMCwiZXhwIjo0MTAyNDQ0ODAwfQ._V92JXiLSLXoyuSwbNvvJjwzgmczmC7dvX34kVSLIa8' + const PIN_SECRET = 'pin-test-secret-mwt' + + test('decodes the Python-minted token', () => { + const p = decodeMwt(TOKEN, PIN_SECRET) + expect(p).not.toBeNull() + expect(p!.sub).toBe('42') + expect(p!.staff).toBe(true) + expect(p!.super).toBe(false) + expect(p!.pkey).toBe('0e998a9ff16904637a03d2ea0fbdbf974297a9d1a5d25b8d4294fa882115e457') + expect(p!.kid).toBe('v1') + expect(p!.aud).toBe('mizan') + }) + + test('identity from Python-minted token', () => { + const p = decodeMwt(TOKEN, PIN_SECRET)! + expect(identityFromMwt(p)).toEqual({ + isAuthenticated: true, + isStaff: true, + isSuperuser: false, + id: 42, + }) + }) +}) diff --git a/cores/mizan-python/src/mizan_core/auth/__init__.py b/cores/mizan-python/src/mizan_core/auth/__init__.py new file mode 100644 index 0000000..2c4f237 --- /dev/null +++ b/cores/mizan-python/src/mizan_core/auth/__init__.py @@ -0,0 +1,27 @@ +from mizan_core.auth.authenticate import INVALID, AuthConfig, authenticate +from mizan_core.auth.jwt import ( + JWTConfig, + JWTUser, + TokenPair, + TokenPayload, + create_access_token, + create_refresh_token, + create_token_pair, + decode_token, + refresh_tokens, +) + +__all__ = [ + "AuthConfig", + "authenticate", + "INVALID", + "JWTConfig", + "JWTUser", + "TokenPair", + "TokenPayload", + "create_access_token", + "create_refresh_token", + "create_token_pair", + "decode_token", + "refresh_tokens", +] diff --git a/cores/mizan-python/src/mizan_core/auth/authenticate.py b/cores/mizan-python/src/mizan_core/auth/authenticate.py new file mode 100644 index 0000000..2bcd8b9 --- /dev/null +++ b/cores/mizan-python/src/mizan_core/auth/authenticate.py @@ -0,0 +1,53 @@ +""" +Token → identity resolution, shared by every adapter. + +`authenticate(headers, config)` reads `X-Mizan-Token` (MWT) first, then +`Authorization: Bearer` (JWT), and returns an `Identity`, `None`, or the +`INVALID` sentinel. + +The `INVALID` sentinel is load-bearing: when a token is PRESENT but bad, the +adapter must REJECT — never silently fall back to session auth (that would let +a forged/expired token degrade into anonymous-or-session access). `None` means +"no token offered" → the adapter may fall back to its own session identity. +""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import Mapping + +from mizan_core.auth.jwt import JWTConfig, JWTUser, decode_token +from mizan_core.identity import Identity +from mizan_core.mwt import MWTUser, decode_mwt + + +class _Invalid: + """Sentinel: a token was presented but failed validation.""" + + def __repr__(self) -> str: + return "INVALID" + + +INVALID = _Invalid() + + +@dataclass(frozen=True) +class AuthConfig: + jwt: JWTConfig | None = None + mwt_secret: str | None = None + mwt_audience: str = "mizan" + + +def authenticate(headers: Mapping[str, str], config: AuthConfig) -> Identity | _Invalid | None: + """Resolve identity from request headers. Returns Identity | INVALID | None.""" + mwt = headers.get("X-Mizan-Token") or headers.get("x-mizan-token") + if mwt and config.mwt_secret: + payload = decode_mwt(mwt, config.mwt_secret, audience=config.mwt_audience) + return MWTUser(payload) if payload else INVALID + + bearer = headers.get("Authorization") or headers.get("authorization") or "" + if bearer.startswith("Bearer ") and config.jwt: + payload = decode_token(bearer[7:], config.jwt, expected_type="access") + return JWTUser(payload) if payload else INVALID + + return None diff --git a/cores/mizan-python/src/mizan_core/auth/jwt.py b/cores/mizan-python/src/mizan_core/auth/jwt.py new file mode 100644 index 0000000..9bf531f --- /dev/null +++ b/cores/mizan-python/src/mizan_core/auth/jwt.py @@ -0,0 +1,137 @@ +""" +JWT access/refresh tokens — adapter-agnostic (PyJWT). + +Config is injected (`JWTConfig`) rather than read from any framework's settings. +`validate_session` (the immediate-logout-revocation check) is Django-session-bound +and stays in the Django adapter; `refresh_tokens` takes a `session_validator` +callable so the core stays framework-free. +""" + +from __future__ import annotations + +import time +from dataclasses import dataclass +from typing import Callable, NamedTuple + +import jwt + + +@dataclass(frozen=True) +class JWTConfig: + private_key: str + public_key: str + algorithm: str = "HS256" + access_token_expires_in: int = 300 + refresh_token_expires_in: int = 604800 + + +class TokenPair(NamedTuple): + access_token: str + refresh_token: str + expires_in: int + + +class TokenPayload(NamedTuple): + user_id: int | str + session_key: str + token_type: str + is_staff: bool + is_superuser: bool + exp: int + iat: int + + +class JWTUser: + """Minimal `Identity` built from JWT claims — no DB query.""" + + def __init__(self, payload: TokenPayload): + self.id = int(payload.user_id) if isinstance(payload.user_id, str) else payload.user_id + self.pk = self.id + self.is_staff = payload.is_staff + self.is_superuser = payload.is_superuser + self.is_authenticated = True + self.is_anonymous = False + self.is_active = True + + def __str__(self) -> str: + return f"JWTUser(id={self.id})" + + def __repr__(self) -> str: + return f"JWTUser(id={self.id}, is_staff={self.is_staff}, is_superuser={self.is_superuser})" + + +def _mint(user_id: int | str, session_key: str, token_type: str, ttl: int, + config: JWTConfig, is_staff: bool, is_superuser: bool) -> str: + now = int(time.time()) + payload = { + "sub": str(user_id), + "sid": session_key, + "staff": is_staff, + "super": is_superuser, + "type": token_type, + "iat": now, + "exp": now + ttl, + } + return jwt.encode(payload, config.private_key, algorithm=config.algorithm) + + +def create_access_token(user_id, session_key, config: JWTConfig, *, + is_staff: bool = False, is_superuser: bool = False) -> str: + return _mint(user_id, session_key, "access", config.access_token_expires_in, + config, is_staff, is_superuser) + + +def create_refresh_token(user_id, session_key, config: JWTConfig, *, + is_staff: bool = False, is_superuser: bool = False) -> str: + return _mint(user_id, session_key, "refresh", config.refresh_token_expires_in, + config, is_staff, is_superuser) + + +def create_token_pair(user_id, session_key, config: JWTConfig, *, + is_staff: bool = False, is_superuser: bool = False) -> TokenPair: + return TokenPair( + access_token=create_access_token(user_id, session_key, config, + is_staff=is_staff, is_superuser=is_superuser), + refresh_token=create_refresh_token(user_id, session_key, config, + is_staff=is_staff, is_superuser=is_superuser), + expires_in=config.access_token_expires_in, + ) + + +def decode_token(token: str, config: JWTConfig, expected_type: str | None = None) -> TokenPayload | None: + """Decode + validate. None on invalid/expired token, or type mismatch.""" + try: + payload = jwt.decode(token, config.public_key, algorithms=[config.algorithm]) + except jwt.PyJWTError: + return None + if expected_type and payload.get("type") != expected_type: + return None + return TokenPayload( + user_id=payload["sub"], + session_key=payload["sid"], + token_type=payload["type"], + is_staff=payload.get("staff", False), + is_superuser=payload.get("super", False), + exp=payload["exp"], + iat=payload["iat"], + ) + + +def refresh_tokens( + refresh_token: str, + config: JWTConfig, + session_validator: Callable[[str], bool] | None = None, +) -> TokenPair | None: + """Exchange a refresh token for a new pair. None if invalid or the session is gone. + + `session_validator(session_key) -> bool` lets the Django adapter enforce + immediate-logout revocation; omit it (or pass a always-True) where there is + no session store. + """ + payload = decode_token(refresh_token, config, expected_type="refresh") + if payload is None: + return None + if session_validator is not None and not session_validator(payload.session_key): + return None + return create_token_pair(payload.user_id, payload.session_key, config, + is_staff=payload.is_staff, is_superuser=payload.is_superuser) diff --git a/cores/mizan-python/src/mizan_core/authguard.py b/cores/mizan-python/src/mizan_core/authguard.py new file mode 100644 index 0000000..9d760fb --- /dev/null +++ b/cores/mizan-python/src/mizan_core/authguard.py @@ -0,0 +1,52 @@ +""" +Auth-guard evaluation — the adapter-agnostic core. + +`enforce_auth` evaluates a function's `@client(auth=...)` requirement against an +`Identity` and raises `Unauthorized`/`Forbidden` on failure. A custom `auth=callable` +receives the adapter's NATIVE request (it may read request-specific state), passed +through opaquely — the core never introspects it. +""" + +from __future__ import annotations + +from typing import Any + +from mizan_core.errors import Forbidden, InternalError, Unauthorized +from mizan_core.identity import Identity + + +def enforce_auth( + identity: Identity | None, + requirement: Any, + native_request: Any = None, +) -> None: + """Raise `Unauthorized`/`Forbidden` if `identity` fails `requirement`; else return. + + Requirement: None | True | "required" | "staff" | "superuser" | callable(native_request)->bool. + """ + if requirement is None: + return + + if callable(requirement): + try: + if not requirement(native_request): + raise Forbidden("Access denied") + except PermissionError as e: + raise Forbidden(str(e) or "Access denied") from e + return + + if not getattr(identity, "is_authenticated", False): + raise Unauthorized("Authentication required") + + if requirement in (True, "required"): + return + if requirement == "staff": + if not getattr(identity, "is_staff", False): + raise Forbidden("Staff access required") + return + if requirement == "superuser": + if not getattr(identity, "is_superuser", False): + raise Forbidden("Superuser access required") + return + + raise InternalError(f"Unknown auth requirement: {requirement!r}") diff --git a/cores/mizan-python/src/mizan_core/dispatch.py b/cores/mizan-python/src/mizan_core/dispatch.py new file mode 100644 index 0000000..fb35f1d --- /dev/null +++ b/cores/mizan-python/src/mizan_core/dispatch.py @@ -0,0 +1,250 @@ +""" +The adapter-agnostic dispatch core. + +Both `dispatch_call` (mutations/RPC) and `dispatch_context` (bundled reads) run +the full protocol: auth → input validation → execute (`await view.acall`, which +threadpools sync handlers) → serialize → resolve invalidation/merge → orchestrate +origin cache. They return a `DispatchResult` the adapter renders to its native +response. Errors raise `MizanError` (the adapter catches at its boundary). + +The adapter owns native request parsing (multipart/JSON) and native response +construction; it hands the core a `DispatchRequest` carrying only what the core +reads, and renders what the core returns. +""" + +from __future__ import annotations + +import json +from dataclasses import dataclass, field +from typing import Any, Literal + +from pydantic import BaseModel, ValidationError +from pydantic_core import to_jsonable_python + +from mizan_core.authguard import enforce_auth +from mizan_core.cache.backend import CacheBackend +from mizan_core.cache.keys import CONTEXT_KEY_PREFIX, derive_cache_key +from mizan_core.errors import ( + BadRequest, + InternalError, + MizanError, + NotFound, + NotImplementedYet, + ValidationFailed, +) +from mizan_core.identity import Identity, user_id_of +from mizan_core.invalidation import ( + format_invalidate_header, + resolve_invalidation, + resolve_merges, +) +from mizan_core.registry import get_context_groups, get_function + + +# ─── Request / result ─────────────────────────────────────────────────────── + + +@dataclass +class DispatchRequest: + """What the dispatch core reads. The adapter resolves `identity` (session OR + token) and parses `args`/`files`; `native_request` is an opaque passthrough + handed to `view_class(...)` and to `auth=callable`.""" + + identity: Identity | None = None + args: dict[str, Any] | None = None + files: dict[str, list[Any]] | None = None + native_request: Any = None + + +@dataclass +class DispatchResult: + kind: Literal["rpc", "view", "context"] = "rpc" + native_response: Any | None = None # view-path: the handler's own response + data: Any | None = None # rpc: serialized payload; context: bundle dict + body_bytes: bytes | None = None # context: canonical JSON to send/cache + cache_status: str | None = None # context: "HIT" | "MISS" | None + invalidate: list[Any] | None = None + merge: list[dict[str, Any]] | None = None + invalidate_header: str | None = None + + +# ─── Cache orchestration ──────────────────────────────────────────────────── + + +class CacheOrchestrator: + """Origin-side cache, backend + secret injected by the adapter (config seam).""" + + def __init__(self, backend: CacheBackend | None, secret: str | None): + self.backend = backend + self.secret = secret + + @property + def enabled(self) -> bool: + return self.backend is not None and bool(self.secret) + + def get(self, context: str, params: dict[str, Any], user_id: str | None, rev: int) -> bytes | None: + if not self.enabled: + return None + return self.backend.get(derive_cache_key(self.secret, context, params, user_id, rev)) + + def put(self, context: str, params: dict[str, Any], value: bytes, user_id: str | None, rev: int) -> None: + if not self.enabled: + return + self.backend.set(derive_cache_key(self.secret, context, params, user_id, rev), value) + + def purge(self, invalidate: list[Any], user_id: str | None) -> None: + if not self.enabled: + return + for entry in invalidate: + if isinstance(entry, str): + self.backend.delete_by_prefix(f"{CONTEXT_KEY_PREFIX}{entry}:") + elif isinstance(entry, dict): + ctx = entry["context"] + params = entry.get("params") + if params: + self.backend.delete(derive_cache_key(self.secret, ctx, params, user_id, 0)) + else: + self.backend.delete_by_prefix(f"{CONTEXT_KEY_PREFIX}{ctx}:") + + +# ─── Shared dispatch helpers ──────────────────────────────────────────────── + + +def _resolve_function(fn_name: str) -> Any: + view_class = get_function(fn_name) + if view_class is None: + raise NotFound("Function not found") + if getattr(view_class, "_meta", {}).get("private"): + from mizan_core.errors import Forbidden + raise Forbidden("Function is not client-callable") + return view_class + + +def _validate_input(input_cls: Any, input_data: Any) -> BaseModel | None: + """Validate `input_data` against the function's Input model.""" + if input_cls in (None, BaseModel) or not getattr(input_cls, "model_fields", None): + return None + required = [name for name, f in input_cls.model_fields.items() if f.is_required()] + if not input_data: + if required: + raise ValidationFailed( + "Input validation failed", + details={"fields": {name: ["Field required"] for name in required}}, + ) + return input_cls() + if not isinstance(input_data, dict): + raise BadRequest(f"Input must be an object, got {type(input_data).__name__}") + try: + return input_cls(**input_data) + except ValidationError as e: + raise ValidationFailed("Input validation failed", details={"errors": e.errors()}) from e + + +def _serialize(result: Any) -> Any: + return to_jsonable_python(result) + + +async def _run(view: Any, validated: Any) -> Any: + try: + return await view.acall(validated) + except NotImplementedError as e: + raise NotImplementedYet(str(e) or "Not implemented") from e + except MizanError: + raise + except Exception as e: + raise InternalError(str(e)) from e + + +def _canonical_bytes(data: Any) -> bytes: + return json.dumps(data, sort_keys=True).encode("utf-8") + + +# ─── Entry points ─────────────────────────────────────────────────────────── + + +async def dispatch_call(req: DispatchRequest, fn_name: str, cache: CacheOrchestrator) -> DispatchResult: + """Mutation / RPC dispatch.""" + view_class = _resolve_function(fn_name) + meta = getattr(view_class, "_meta", {}) + enforce_auth(req.identity, meta.get("auth"), req.native_request) + + view = view_class(req.native_request) + validated = _validate_input(view.Input, req.args) + result = await _run(view, validated) + + invalidate = resolve_invalidation(view_class, req.args) + header = format_invalidate_header(invalidate) if invalidate else None + if invalidate: + cache.purge(invalidate, user_id_of(req.identity)) + + if meta.get("view_path"): + # Handler returned its own native response; carry it through + the header. + return DispatchResult(kind="view", native_response=result, + invalidate=invalidate, invalidate_header=header) + + serialized = _serialize(result) + return DispatchResult( + kind="rpc", + data=serialized, + invalidate=invalidate, + merge=resolve_merges(view_class, req.args, serialized), + invalidate_header=header, + ) + + +def _effective_policy(fn_names: list[str]) -> tuple[int, int | bool]: + """(effective_rev, effective_cache) across a context's functions.""" + rev = 0 + cache_policy: int | bool = True # True=forever, False=no-store, int=TTL + for fn_name in fn_names: + fn_cls = get_function(fn_name) + if fn_cls is None: + continue + m = getattr(fn_cls, "_meta", {}) + rev = max(rev, m.get("rev", 0)) + fn_cache = m.get("cache", True) + if fn_cache is False: + return rev, False + if isinstance(fn_cache, int): + cache_policy = fn_cache if cache_policy is True else min(cache_policy, fn_cache) + return rev, cache_policy + + +async def dispatch_context(req: DispatchRequest, context_name: str, cache: CacheOrchestrator) -> DispatchResult: + """Bundled context read with origin-cache get/put.""" + groups = get_context_groups() + fn_names = groups.get(context_name) + if not fn_names: + raise NotFound(f"Context '{context_name}' not found") + + params = req.args or {} + rev, cache_policy = _effective_policy(fn_names) + user_id = user_id_of(req.identity) + use_cache = cache.enabled and cache_policy is not False + + if use_cache: + cached = cache.get(context_name, params, user_id, rev) + if cached is not None: + return DispatchResult(kind="context", body_bytes=cached, cache_status="HIT") + + bundle: dict[str, Any] = {} + for fn_name in fn_names: + view_class = _resolve_function(fn_name) + enforce_auth(req.identity, getattr(view_class, "_meta", {}).get("auth"), req.native_request) + view = view_class(req.native_request) + validated = _validate_input(view.Input, {k: v for k, v in params.items() if _declares(view.Input, k)}) + bundle[fn_name] = _serialize(await _run(view, validated)) + + body = _canonical_bytes(bundle) + if use_cache: + cache.put(context_name, params, body, user_id, rev) + return DispatchResult(kind="context", data=bundle, body_bytes=body, + cache_status="MISS" if use_cache else None) + + +def _declares(input_cls: Any, name: str) -> bool: + return bool( + input_cls and input_cls is not BaseModel + and getattr(input_cls, "model_fields", None) + and name in input_cls.model_fields + ) diff --git a/cores/mizan-python/src/mizan_core/errors.py b/cores/mizan-python/src/mizan_core/errors.py new file mode 100644 index 0000000..6aa5f63 --- /dev/null +++ b/cores/mizan-python/src/mizan_core/errors.py @@ -0,0 +1,58 @@ +""" +Canonical protocol-level error taxonomy. + +Dispatch raises these typed exceptions; each backend adapter renders them to +its native response (Django `JsonResponse`, FastAPI exception handler, …). The +shared dispatch core never returns error envelopes — it raises, and the adapter +catches at its boundary. +""" + +from __future__ import annotations + +from enum import Enum +from typing import Any + + +class ErrorCode(str, Enum): + NOT_FOUND = "NOT_FOUND" + BAD_REQUEST = "BAD_REQUEST" + VALIDATION_ERROR = "VALIDATION_ERROR" + UNAUTHORIZED = "UNAUTHORIZED" + FORBIDDEN = "FORBIDDEN" + NOT_IMPLEMENTED = "NOT_IMPLEMENTED" + INTERNAL_ERROR = "INTERNAL_ERROR" + + +STATUS = { + ErrorCode.NOT_FOUND: 404, + ErrorCode.BAD_REQUEST: 400, + ErrorCode.VALIDATION_ERROR: 422, + ErrorCode.UNAUTHORIZED: 401, + ErrorCode.FORBIDDEN: 403, + ErrorCode.NOT_IMPLEMENTED: 501, + ErrorCode.INTERNAL_ERROR: 500, +} + + +class MizanError(Exception): + """Base for protocol-level dispatch errors.""" + + code: ErrorCode = ErrorCode.INTERNAL_ERROR + + def __init__(self, message: str, *, details: dict[str, Any] | None = None) -> None: + super().__init__(message) + self.message = message + self.details = details + + @property + def status_code(self) -> int: + return STATUS[self.code] + + +class NotFound(MizanError): code = ErrorCode.NOT_FOUND # noqa: E701 +class BadRequest(MizanError): code = ErrorCode.BAD_REQUEST # noqa: E701 +class ValidationFailed(MizanError): code = ErrorCode.VALIDATION_ERROR # noqa: E701 +class Unauthorized(MizanError): code = ErrorCode.UNAUTHORIZED # noqa: E701 +class Forbidden(MizanError): code = ErrorCode.FORBIDDEN # noqa: E701 +class NotImplementedYet(MizanError): code = ErrorCode.NOT_IMPLEMENTED # noqa: E701 +class InternalError(MizanError): code = ErrorCode.INTERNAL_ERROR # noqa: E701 diff --git a/cores/mizan-python/src/mizan_core/identity.py b/cores/mizan-python/src/mizan_core/identity.py new file mode 100644 index 0000000..a9abfb9 --- /dev/null +++ b/cores/mizan-python/src/mizan_core/identity.py @@ -0,0 +1,32 @@ +""" +The minimal identity contract the dispatch core reads. + +Auth-guard evaluation and per-user cache scoping need exactly these four +attributes — nothing about how the identity was established. Django's session +`User`, `JWTUser`, `MWTUser`, and any token-user an adapter constructs all +satisfy this structurally; no inheritance required. + +`get_all_permissions()` (Django ORM) is deliberately NOT here — the MWT +permission-key is a Django-side concern, and adding it would force every +adapter to implement a Django-shaped method. +""" + +from __future__ import annotations + +from typing import Protocol, runtime_checkable + + +@runtime_checkable +class Identity(Protocol): + is_authenticated: bool + is_staff: bool + is_superuser: bool + + @property + def pk(self) -> object | None: ... # str | int | None; cache scoping stringifies it + + +def user_id_of(identity: Identity | None) -> str | None: + """The cache-scoping user id — `str(pk)`, or None for anonymous/no-pk.""" + pk = getattr(identity, "pk", None) + return str(pk) if pk is not None else None diff --git a/cores/mizan-python/src/mizan_core/invalidation.py b/cores/mizan-python/src/mizan_core/invalidation.py new file mode 100644 index 0000000..7f88960 --- /dev/null +++ b/cores/mizan-python/src/mizan_core/invalidation.py @@ -0,0 +1,174 @@ +""" +Server-driven invalidation + merge resolution — the adapter-agnostic core. + +This is the canonical implementation (formerly housed in the Django executor). +Every adapter calls `resolve_invalidation` / `resolve_merges` / `format_invalidate_header` +so the wire shape is identical across backends. + +Invalidation entries take one of two shapes: + - a bare context/function name string → broad invalidation + - {"context": name, "params": {...}} → scoped invalidation +Function-level `affects=` resolves to the function NAME as the key (v1 refetches +the whole context anyway). +""" + +from __future__ import annotations + +from typing import Any +from urllib.parse import quote + +from pydantic import BaseModel + +from mizan_core.registry import get_context_groups, get_function +from mizan_core.type_utils import types_match_for_merge + + +__all__ = ["resolve_invalidation", "resolve_merges", "format_invalidate_header"] + + +def _resolve_affects_target(target_name: str) -> tuple[str, str, str | None]: + """Classify an affects target → ("context", name, None) | ("function", name, ctx).""" + groups = get_context_groups() + if target_name in groups: + return ("context", target_name, None) + for ctx_name, fn_names in groups.items(): + if target_name in fn_names: + return ("function", target_name, ctx_name) + # Unknown — treat as a context name (non-context fn, or not-yet-registered). + return ("context", target_name, None) + + +def _context_param_names(context_name: str) -> set[str]: + """Union of Input field names across the functions in a context.""" + param_names: set[str] = set() + for fn_name in get_context_groups().get(context_name, []): + fn_cls = get_function(fn_name) + if fn_cls is None: + continue + input_cls = getattr(fn_cls, "Input", None) + if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"): + param_names.update(input_cls.model_fields.keys()) + return param_names + + +def resolve_invalidation( + view_class: type | None, + input_data: dict[str, Any] | None = None, +) -> list[str | dict[str, Any]] | None: + """Three-tier auto-scoping over `@client(affects=...)`. None if nothing to invalidate. + + Tier 1: arg-name matching against the context's params → scoped entry. + Tier 2: auth inference — Edge-side, not handled here. + Tier 3: broad fallback — bare name. + """ + if view_class is None: + return None + affects = getattr(view_class, "_meta", {}).get("affects") + if not affects: + return None + + result: list[str | dict[str, Any]] = [] + seen: set[str] = set() + for target in affects: + if target["type"] == "context": + target_name = target["name"] + elif target["type"] == "function" and target.get("context"): + target_name = target["name"] + else: + continue + + if target_name in seen: + continue + seen.add(target_name) + + resolved = _resolve_affects_target(target_name) + ctx_for_params = resolved[2] if resolved[0] == "function" else resolved[1] + + if input_data and ctx_for_params: + context_params = _context_param_names(ctx_for_params) + matched = {k: v for k, v in input_data.items() if k in context_params} + if matched: + result.append({"context": target_name, "params": matched}) + continue + + result.append(target_name) + + return result or None + + +def _resolve_merge_slot(context_name: str, mutation_output: Any) -> str | None: + """The unique function-name slot in a context whose return type matches the mutation output.""" + if mutation_output is None: + return None + matches: list[str] = [] + for fn_name in get_context_groups().get(context_name, []): + fn_cls = get_function(fn_name) + if fn_cls is None: + continue + fn_output = getattr(fn_cls, "Output", None) + if fn_output is not None and types_match_for_merge(fn_output, mutation_output): + matches.append(fn_name) + return matches[0] if len(matches) == 1 else None + + +def resolve_merges( + view_class: type | None, + input_data: dict[str, Any] | None, + result_data: Any, +) -> list[dict[str, Any]] | None: + """Build the `merge` list from `@client(merge=...)`. None when no targets resolve. + + Each entry is `{context, slot, value, params?}`; `slot` is the context-function + whose return type matches the mutation output (server-side type-checked routing, + no client shape inference). Ambiguous/unmatched targets are dropped. + """ + if view_class is None: + return None + targets = getattr(view_class, "_meta", {}).get("merge") or [] + if not targets: + return None + + mutation_output = getattr(view_class, "Output", None) + out: list[dict[str, Any]] = [] + seen: set[str] = set() + for ctx_name in targets: + if ctx_name in seen: + continue + seen.add(ctx_name) + slot = _resolve_merge_slot(ctx_name, mutation_output) + if slot is None: + continue + entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result_data} + if input_data: + matched = {k: v for k, v in input_data.items() if k in _context_param_names(ctx_name)} + if matched: + entry["params"] = matched + out.append(entry) + return out or None + + +def format_invalidate_header(invalidate: list[str | dict[str, Any]]) -> str: + """Serialize invalidation targets to the `X-Mizan-Invalidate` header value. + + Comma-separated contexts; semicolon-separated URL-encoded params per context. + ["user"] → "user" + ["user", "notifications"] → "user, notifications" + [{"context": "user", "params": {"user_id": 5}}] → "user;user_id=5" + [{"context": "search", "params": {"q": "hello world"}}] → "search;q=hello%20world" + """ + parts: list[str] = [] + for entry in invalidate: + if isinstance(entry, str): + parts.append(entry) + elif isinstance(entry, dict): + ctx = entry["context"] + params = entry.get("params", {}) + if params: + param_str = ";".join( + f"{quote(str(k), safe='')}={quote(str(v), safe='')}" + for k, v in sorted(params.items()) + ) + parts.append(f"{ctx};{param_str}") + else: + parts.append(ctx) + return ", ".join(parts) diff --git a/cores/mizan-python/tests/test_dispatch_core.py b/cores/mizan-python/tests/test_dispatch_core.py new file mode 100644 index 0000000..eb59b5e --- /dev/null +++ b/cores/mizan-python/tests/test_dispatch_core.py @@ -0,0 +1,147 @@ +"""Unit tests for the adapter-agnostic dispatch core.""" + +import asyncio + +import pytest +from pydantic import BaseModel + +from mizan_core.auth import AuthConfig, JWTConfig, INVALID, authenticate, create_access_token +from mizan_core.authguard import enforce_auth +from mizan_core.client.function import client +from mizan_core.dispatch import CacheOrchestrator, DispatchRequest, dispatch_call +from mizan_core.errors import Forbidden, Unauthorized +from mizan_core.invalidation import format_invalidate_header, resolve_invalidation +from mizan_core.registry import clear_registry, register + + +class Ident: + def __init__(self, authed=True, staff=False, su=False, pk=1): + self.is_authenticated = authed + self.is_staff = staff + self.is_superuser = su + self.pk = pk + + +# ─── authguard ────────────────────────────────────────────────────────────── + + +def test_auth_required_anonymous(): + with pytest.raises(Unauthorized): + enforce_auth(None, True) + + +def test_auth_required_authenticated(): + enforce_auth(Ident(), True) # no raise + + +def test_auth_staff_denied_then_allowed(): + with pytest.raises(Forbidden): + enforce_auth(Ident(staff=False), "staff") + enforce_auth(Ident(staff=True), "staff") + + +def test_auth_superuser(): + with pytest.raises(Forbidden): + enforce_auth(Ident(su=False), "superuser") + enforce_auth(Ident(su=True), "superuser") + + +def test_auth_callable_false_and_raise(): + with pytest.raises(Forbidden): + enforce_auth(Ident(), lambda r: False) + with pytest.raises(Forbidden, match="custom"): + enforce_auth(Ident(), lambda r: (_ for _ in ()).throw(PermissionError("custom"))) + + +# ─── authenticate / INVALID sentinel ──────────────────────────────────────── + + +def _cfg(): + return AuthConfig(jwt=JWTConfig(private_key="k" * 32, public_key="k" * 32)) + + +def test_authenticate_jwt_ok(): + cfg = _cfg() + tok = create_access_token("7", "sess", cfg.jwt, is_staff=True) + ident = authenticate({"Authorization": f"Bearer {tok}"}, cfg) + assert ident.pk == 7 and ident.is_staff and ident.is_authenticated + + +def test_authenticate_bad_token_is_invalid_sentinel(): + assert authenticate({"Authorization": "Bearer garbage"}, _cfg()) is INVALID + + +def test_authenticate_no_token_is_none(): + assert authenticate({}, _cfg()) is None + + +# ─── invalidation + header ────────────────────────────────────────────────── + + +def test_invalidation_three_tier_and_header(): + clear_registry() + UserCtx = "user" + + class Out(BaseModel): + ok: bool + + @client(context=UserCtx) + def user_profile(request, user_id: int) -> Out: + return Out(ok=True) + + @client(affects=UserCtx) + def update_profile(request, user_id: int, name: str) -> Out: + return Out(ok=True) + + register(user_profile, "user_profile") + register(update_profile, "update_profile") + + # Tier 1: user_id matches context param → scoped + inv = resolve_invalidation(update_profile, {"user_id": 5, "name": "x"}) + assert inv == [{"context": "user", "params": {"user_id": 5}}] + assert format_invalidate_header(inv) == "user;user_id=5" + + # Tier 3: no matching param → broad + inv2 = resolve_invalidation(update_profile, {"name": "x"}) + assert inv2 == ["user"] + clear_registry() + + +# ─── dispatch_call end to end ─────────────────────────────────────────────── + + +def test_dispatch_call_auth_and_invalidation(): + clear_registry() + + class Out(BaseModel): + ok: bool + + @client(context="user") + def user_profile(request, user_id: int) -> Out: + return Out(ok=True) + + @client(affects="user", auth="staff") + def secure_update(request, user_id: int) -> Out: + return Out(ok=True) + + register(user_profile, "user_profile") + register(secure_update, "secure_update") + + cache = CacheOrchestrator(None, None) + + # non-staff rejected + with pytest.raises(Forbidden): + asyncio.run(dispatch_call( + DispatchRequest(identity=Ident(staff=False), args={"user_id": 1}), + "secure_update", cache, + )) + + # staff passes, invalidation resolved + res = asyncio.run(dispatch_call( + DispatchRequest(identity=Ident(staff=True), args={"user_id": 1}), + "secure_update", cache, + )) + assert res.kind == "rpc" and res.data == {"ok": True} + assert res.invalidate == [{"context": "user", "params": {"user_id": 1}}] + assert res.invalidate_header == "user;user_id=1" + clear_registry()