From 3523f2e3fe6b312c680ed716dcbd51735d3b2523 Mon Sep 17 00:00:00 2001 From: Ryth Azhur Date: Tue, 31 Mar 2026 22:47:55 -0400 Subject: [PATCH] Add named contexts, bundled fetch endpoint, and affects invalidation Phase 1 (Named Contexts): - @client(context=) accepts any string, not just 'global'/'local' - context='local' emits deprecation warning - Registry groups functions by context name (get_context_groups) - GET /api/mizan/ctx// bundles all context functions in one response - Schema export includes x-mizan-contexts with param elevation metadata Phase 2 (Affects): - @client(affects=) declares mutation invalidation targets - Accepts context name strings, function refs, or lists - Mutually exclusive with context= - Exported in x-mizan-functions schema for codegen React runtime: - MizanContextValue gains invalidateContext, invalidateFunctions, registerContextProvider, and baseUrl - Named context providers register for invalidation on mount 259 Django tests pass, 33 React tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) --- django/src/mizan/client/executor.py | 114 ++++++++++++- django/src/mizan/client/function.py | 125 ++++++++++----- django/src/mizan/export/__init__.py | 46 +++++- django/src/mizan/setup/__init__.py | 2 + django/src/mizan/setup/registry.py | 15 ++ django/src/mizan/tests/test_core.py | 237 ++++++++++++++++++++++++++-- django/src/mizan/urls.py | 10 +- react/src/context.tsx | 79 +++++++++- 8 files changed, 572 insertions(+), 56 deletions(-) diff --git a/django/src/mizan/client/executor.py b/django/src/mizan/client/executor.py index 50b9b8c..5829b91 100644 --- a/django/src/mizan/client/executor.py +++ b/django/src/mizan/client/executor.py @@ -27,7 +27,7 @@ from django.http import HttpRequest, JsonResponse from django.views.decorators.csrf import csrf_protect from pydantic import BaseModel, ValidationError -from mizan.setup.registry import get_function +from mizan.setup.registry import get_function, get_context_groups if TYPE_CHECKING: pass @@ -479,3 +479,115 @@ def function_call_view(request: HttpRequest) -> JsonResponse: return result.to_response(status=status) return result.to_response() + + +def execute_context( + request: HttpRequest, + context_name: str, + params: dict[str, str], +) -> FunctionResult | FunctionError: + """ + Execute all functions in a named context with merged params. + + Each function receives only the params it declares in its Input schema. + If any function fails (auth, validation, execution), the entire request fails. + + Args: + request: The Django HttpRequest + context_name: Name of the context (e.g., 'user', 'global') + params: Query parameters (strings — Pydantic coerces types) + + Returns: + FunctionResult with bundled data, or FunctionError + """ + groups = get_context_groups() + fn_names = groups.get(context_name) + if not fn_names: + return FunctionError( + code=ErrorCode.NOT_FOUND, + message=f"Context '{context_name}' not found", + ) + + results = {} + for fn_name in fn_names: + view_class = get_function(fn_name) + if view_class is None: + continue + + # Filter params to only those in this function's Input schema + input_cls = getattr(view_class, "Input", None) + if input_cls and input_cls is not BaseModel and input_cls.model_fields: + fn_params = { + k: v for k, v in params.items() + if k in input_cls.model_fields + } + else: + fn_params = None + + result = execute_function(request, fn_name, fn_params) + if isinstance(result, FunctionError): + return result + results[fn_name] = result.data + + return FunctionResult(data=results) + + +def _jwt_auth_only(view_func): + """ + Decorator that handles JWT auth for GET endpoints (no CSRF needed for GET). + """ + @wraps(view_func) + def wrapper(request: HttpRequest, *args, **kwargs): + has_jwt = _has_jwt_header(request) + if has_jwt: + if _try_jwt_auth(request): + return view_func(request, *args, **kwargs) + else: + return FunctionError( + code=ErrorCode.UNAUTHORIZED, + message="Invalid or expired JWT token", + ).to_response(status=401) + # No JWT — session auth (no CSRF needed for GET) + return view_func(request, *args, **kwargs) + + return wrapper + + +@_jwt_auth_only +def context_fetch_view(request: HttpRequest, context_name: str) -> JsonResponse: + """ + Fetch all functions in a named context in a single bundled GET request. + + Endpoint: GET /api/mizan/ctx//?param1=val1¶m2=val2 + + Response on success: + { + "error": false, + "data": { + "user_profile": { ... }, + "user_orders": [ ... ] + } + } + """ + if request.method != "GET": + return FunctionError( + code=ErrorCode.BAD_REQUEST, + message="Only GET method allowed", + ).to_response(status=405) + + params = dict(request.GET) + result = execute_context(request, context_name, params) + + if isinstance(result, FunctionError): + status = { + ErrorCode.NOT_FOUND: 404, + ErrorCode.VALIDATION_ERROR: 422, + ErrorCode.UNAUTHORIZED: 401, + ErrorCode.FORBIDDEN: 403, + ErrorCode.BAD_REQUEST: 400, + ErrorCode.INTERNAL_ERROR: 500, + ErrorCode.NOT_IMPLEMENTED: 501, + }.get(result.code, 400) + return result.to_response(status=status) + + return result.to_response() diff --git a/django/src/mizan/client/function.py b/django/src/mizan/client/function.py index 566840a..f2febdf 100644 --- a/django/src/mizan/client/function.py +++ b/django/src/mizan/client/function.py @@ -20,6 +20,7 @@ Two styles supported: from __future__ import annotations import inspect +import warnings from abc import ABC, abstractmethod from typing import ( Any, @@ -38,8 +39,9 @@ from django.http import HttpRequest from pydantic import BaseModel -# Valid context modes: 'global', 'local', or False (not a context) -ContextMode = Literal["global", "local", False] +# Context mode: any non-empty string names a context, False means not a context. +# 'global' is a reserved context name whose provider is auto-mounted at root. +ContextMode = str | Literal[False] TInput = TypeVar("TInput", bound=BaseModel) @@ -185,6 +187,7 @@ def client( fn: Callable = None, *, context: ContextMode = False, + affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None, websocket: bool = False, auth: bool | str | Callable[[Any], bool] | None = None, ) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]: @@ -195,61 +198,69 @@ def client( Function parameters become input fields automatically. Args: - context: Context mode for React state management. - - False (default): Not a context, just a callable function - - 'global': Embedded in root DjangoContext, no params, singleton - - 'local': Standalone provider, supports params via flat props + context: Named context for React state management. + - False (default): Not a context, just a callable function. + - 'global': Reserved name. Embedded in root MizanProvider, + no params, SSR-hydrated. + - Any other string: Named context. Functions sharing the same + context name are grouped into one provider and one fetch. + + affects: Declare which contexts this mutation invalidates. + - A context name string: refetch the entire named context + - A function reference: refetch that function's context + - A list of the above: refetch all specified targets + Mutually exclusive with context=. websocket: Enable WebSocket RPC transport (default: False). - By default, functions use HTTP-only transport. Enable this for - real-time features (chat, gaming, live updates) that benefit - from lower latency. - - Note: Forms (mizanFormMixin) always use HTTP because auth - flows require full HTTP request semantics. auth: Authentication requirement. - - None (default): No auth required (AnonymousUser allowed) + - None (default): No auth required - True or 'required': Must be authenticated - 'staff': Must have is_staff=True - 'superuser': Must have is_superuser=True - callable(request) -> bool: Custom check function Usage: - # Basic HTTP-only function (not a context) @client def echo(request, message: str) -> EchoOutput: return EchoOutput(message=message) - # Global context - embedded in DjangoContext, no params @client(context='global') def current_user(request) -> UserOutput: return UserOutput(email=request.user.email) - # Local context - standalone provider, supports params - @client(context='local') - def user_profile(request, user_id: int) -> ProfileOutput: - return ProfileOutput(...) + # Named context - functions sharing a name are bundled + @client(context='user') + def user_profile(request, user_id: int) -> ProfileOutput: ... - # WebSocket-enabled for real-time - @client(websocket=True) - def send_message(request, room_id: int, text: str) -> MessageOutput: - return MessageOutput(...) - - # Local context with WebSocket (live data) - @client(context='local', websocket=True) - def live_user_status(request, user_id: int) -> StatusOutput: - return StatusOutput(...) + # Mutation that invalidates a context + @client(affects='user') + def edit_profile(request, name: str) -> dict: ... Returns: A ServerFunction class that wraps the function """ # Validate context parameter - if context not in (False, "global", "local"): - raise ValueError( - f"Invalid context value '{context}'. " - f"Must be False, 'global', or 'local'." - ) + if context is not False: + if not isinstance(context, str) or not context.strip(): + raise ValueError( + "context must be a non-empty string or False." + ) + if context == "local": + warnings.warn( + "context='local' is deprecated. Use a named context string instead " + "(e.g., context='my_context').", + DeprecationWarning, + stacklevel=2, + ) + + # Validate affects parameter + if affects is not None: + if context is not False: + raise ValueError( + "context= and affects= are mutually exclusive. " + "A function cannot be both a context reader and a mutation." + ) # Validate auth parameter if auth is not None: @@ -261,21 +272,50 @@ def client( def decorator(fn: Callable) -> type[ServerFunction]: return _create_server_function( - fn, context=context, websocket=websocket, auth=auth + fn, context=context, affects=affects, websocket=websocket, auth=auth ) # Support both @client and @client(...) if fn is not None: return _create_server_function( - fn, context=context, websocket=websocket, auth=auth + fn, context=context, affects=affects, websocket=websocket, auth=auth ) return decorator +def _normalize_affects( + affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None, +) -> list[dict[str, str]] | None: + """Normalize the affects parameter into a list of target descriptors.""" + if affects is None: + return None + + items = affects if isinstance(affects, list) else [affects] + result = [] + for item in items: + if isinstance(item, str): + result.append({"type": "context", "name": item}) + elif isinstance(item, type) and issubclass(item, ServerFunction): + fn_meta = getattr(item, "_meta", {}) + fn_ctx = fn_meta.get("context") + result.append({ + "type": "function", + "name": getattr(item, "__name__", str(item)), + "context": fn_ctx or None, + }) + else: + raise ValueError( + f"affects items must be context name strings or @client function references, " + f"got {type(item)}" + ) + return result + + def _create_server_function( fn: Callable, *, - context: ContextMode = False, + context: str | Literal[False] = False, + affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None, websocket: bool = False, auth: bool | str | None = None, ) -> type[ServerFunction]: @@ -371,10 +411,15 @@ def _create_server_function( # Build metadata meta = {} - # Context mode: 'global' or 'local' (False means not a context) + # Context name (any non-empty string) if context: meta["context"] = context + # Affects: mutation invalidation targets + normalized_affects = _normalize_affects(affects) + if normalized_affects: + meta["affects"] = normalized_affects + # WebSocket: enable WebSocket transport if websocket: meta["websocket"] = True @@ -464,7 +509,7 @@ def _is_context_enabled(item) -> bool: return True if isinstance(item, type) and issubclass(item, ServerFunction): meta = getattr(item, "_meta", {}) - return meta.get("context") in ("global", "local") + return bool(meta.get("context")) return False @@ -477,7 +522,7 @@ def compose( Compose multiple contexts into a single provider. Args: - *children: Context functions (@client with context='global'|'local') + *children: Context functions (@client with a context name) or other @compose functions. All must be unique after flattening. on_server: Bundle all calls into a single server request (default: False). @@ -529,7 +574,7 @@ def compose( ) raise ValueError( f"@compose argument {i} ({child_name}) is not context-enabled. " - f"All children must have @client(context='global'|'local') or be @compose." + f"All children must have @client(context=...) or be @compose." ) # Flatten to collect all leaves diff --git a/django/src/mizan/export/__init__.py b/django/src/mizan/export/__init__.py index 1f56d27..d4148c0 100644 --- a/django/src/mizan/export/__init__.py +++ b/django/src/mizan/export/__init__.py @@ -26,7 +26,7 @@ if TYPE_CHECKING: from django import forms from ninja import NinjaAPI -from mizan.setup.registry import get_registry, get_schema +from mizan.setup.registry import get_registry, get_schema, get_context_groups, get_function __all__ = ["get_schema", "generate_openapi_schema", "generate_openapi_json"] @@ -271,6 +271,10 @@ def generate_openapi_schema() -> dict[str, Any]: "formRole": meta.get("form_role"), # "schema", "validate", "submit" } + # Affects metadata (mutation invalidation) + if meta.get("affects"): + fn_meta_entry["affects"] = meta["affects"] + # For form schema functions, extract field definitions for Zod generation if meta.get("form") and meta.get("form_role") == "schema": form_class = meta.get("form_class") @@ -290,6 +294,46 @@ def generate_openapi_schema() -> dict[str, Any]: # Add custom extension with function metadata for provider generation schema["x-mizan-functions"] = function_metadata + # Add x-mizan-contexts: grouped context metadata with param elevation + context_groups = get_context_groups() + if context_groups: + contexts_meta: dict[str, Any] = {} + for ctx_name, fn_names in context_groups.items(): + # Analyze params across all functions in the context + param_info: dict[str, dict[str, Any]] = {} + for fn_name in fn_names: + fn_cls = get_function(fn_name) + if fn_cls is None: + continue + input_cls = getattr(fn_cls, "Input", None) + if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"): + for field_name, field_info in input_cls.model_fields.items(): + if field_name not in param_info: + annotation = field_info.annotation + # Map Python types to JSON schema types + type_name = "string" + if annotation in (int,): + type_name = "integer" + elif annotation in (float,): + type_name = "number" + elif annotation in (bool,): + type_name = "boolean" + param_info[field_name] = { + "type": type_name, + "sharedBy": [], + } + param_info[field_name]["sharedBy"].append(fn_name) + + # A param is required if ALL functions in the context declare it + for p_name, p_meta in param_info.items(): + p_meta["required"] = len(p_meta["sharedBy"]) == len(fn_names) + + contexts_meta[ctx_name] = { + "functions": fn_names, + "params": param_info, + } + schema["x-mizan-contexts"] = contexts_meta + # Add x-mizan metadata to each operation for fn_meta in function_metadata: path = f"/mizan/{fn_meta['name']}" diff --git a/django/src/mizan/setup/__init__.py b/django/src/mizan/setup/__init__.py index 1d51a77..b8da978 100644 --- a/django/src/mizan/setup/__init__.py +++ b/django/src/mizan/setup/__init__.py @@ -25,6 +25,7 @@ from .registry import ( get_registry, get_schema, get_contexts, + get_context_groups, get_forms, clear_registry, ) @@ -57,6 +58,7 @@ __all__ = [ "get_registry", "get_schema", "get_contexts", + "get_context_groups", "get_forms", "clear_registry", # Discovery diff --git a/django/src/mizan/setup/registry.py b/django/src/mizan/setup/registry.py index b2b86d5..8676860 100644 --- a/django/src/mizan/setup/registry.py +++ b/django/src/mizan/setup/registry.py @@ -290,6 +290,21 @@ def get_contexts() -> dict[str, type["ServerFunction"]]: return contexts +def get_context_groups() -> dict[str, list[str]]: + """ + Group function names by their context string. + + Returns: + {"global": ["current_user"], "user": ["user_profile", "user_orders"]} + """ + groups: dict[str, list[str]] = {} + for name, cls in _functions.items(): + ctx = getattr(cls, "_meta", {}).get("context") + if ctx: + groups.setdefault(ctx, []).append(name) + return groups + + def get_forms() -> dict[str, list[type["ServerFunction"]]]: """ Get all server functions that are form-related, grouped by form name. diff --git a/django/src/mizan/tests/test_core.py b/django/src/mizan/tests/test_core.py index 82af3df..8aaeed0 100644 --- a/django/src/mizan/tests/test_core.py +++ b/django/src/mizan/tests/test_core.py @@ -15,6 +15,7 @@ from mizan.client.executor import ( FunctionError, FunctionResult, execute_function, + execute_context, ) from mizan.setup.registry import ( clear_registry, @@ -514,29 +515,78 @@ class ContextTests(TestCase): self.assertEqual(fn._meta.get("context"), "global") def test_context_local(self): - """Test @client(context='local') creates a local context.""" + """Test @client(context='local') still works with deprecation warning.""" + import warnings class CtxOutput(BaseModel): data: str - @client(context="local") - def local_context(request: HttpRequest, user_id: int) -> CtxOutput: - return CtxOutput(data=f"user_{user_id}") + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + @client(context="local") + def local_context(request: HttpRequest, user_id: int) -> CtxOutput: + return CtxOutput(data=f"user_{user_id}") + + self.assertEqual(len(w), 1) + self.assertIn("deprecated", str(w[0].message).lower()) register(local_context, "local_context") fn = get_function("local_context") self.assertEqual(fn._meta.get("context"), "local") - def test_context_invalid_value_raises(self): - """Test that invalid context values raise ValueError.""" - with self.assertRaises(ValueError) as cm: + def test_context_named(self): + """Test @client(context='user') creates a named context.""" - @client(context="invalid") + class CtxOutput(BaseModel): + data: str + + @client(context="user") + def user_profile(request: HttpRequest, user_id: int) -> CtxOutput: + return CtxOutput(data=f"user_{user_id}") + + register(user_profile, "user_profile") + + fn = get_function("user_profile") + self.assertEqual(fn._meta.get("context"), "user") + + def test_context_empty_string_raises(self): + """Test that empty context string raises ValueError.""" + with self.assertRaises(ValueError): + + @client(context="") def bad_context(request: HttpRequest) -> ValidOutput: return ValidOutput(valid=True) - self.assertIn("Invalid context value", str(cm.exception)) + def test_context_groups(self): + """Test get_context_groups() groups functions by context name.""" + from mizan.setup.registry import get_context_groups + + class Out(BaseModel): + v: int + + @client(context="user") + def fn_a(request: HttpRequest, user_id: int) -> Out: + return Out(v=1) + + @client(context="user") + def fn_b(request: HttpRequest, user_id: int) -> Out: + return Out(v=2) + + @client(context="global") + def fn_c(request: HttpRequest) -> Out: + return Out(v=3) + + register(fn_a, "fn_a") + register(fn_b, "fn_b") + register(fn_c, "fn_c") + + groups = get_context_groups() + self.assertIn("user", groups) + self.assertIn("global", groups) + self.assertCountEqual(groups["user"], ["fn_a", "fn_b"]) + self.assertEqual(groups["global"], ["fn_c"]) def test_get_contexts(self): """Test get_contexts() returns only context-marked functions.""" @@ -568,6 +618,175 @@ class ContextTests(TestCase): self.assertNotIn("echo", contexts) +class AffectsTests(TestCase): + """Tests for the affects= parameter on @client.""" + + def setUp(self): + clear_registry() + + def test_affects_context_string(self): + """Test @client(affects='user') stores context invalidation.""" + + @client(affects="user") + def edit_profile(request: HttpRequest, name: str) -> ValidOutput: + return ValidOutput(valid=True) + + self.assertEqual( + edit_profile._meta["affects"], + [{"type": "context", "name": "user"}], + ) + + def test_affects_function_ref(self): + """Test @client(affects=fn_ref) stores function invalidation.""" + + @client(context="user") + def user_profile(request: HttpRequest, user_id: int) -> ValidOutput: + return ValidOutput(valid=True) + + @client(affects=user_profile) + def edit_profile(request: HttpRequest, name: str) -> ValidOutput: + return ValidOutput(valid=True) + + affects = edit_profile._meta["affects"] + self.assertEqual(len(affects), 1) + self.assertEqual(affects[0]["type"], "function") + self.assertEqual(affects[0]["name"], "user_profile") + self.assertEqual(affects[0]["context"], "user") + + def test_affects_list(self): + """Test @client(affects=[...]) stores multiple targets.""" + + @client(context="user") + def user_profile(request: HttpRequest) -> ValidOutput: + return ValidOutput(valid=True) + + @client(affects=[user_profile, "billing"]) + def change_plan(request: HttpRequest) -> ValidOutput: + return ValidOutput(valid=True) + + affects = change_plan._meta["affects"] + self.assertEqual(len(affects), 2) + self.assertEqual(affects[0]["type"], "function") + self.assertEqual(affects[1]["type"], "context") + self.assertEqual(affects[1]["name"], "billing") + + def test_affects_and_context_mutually_exclusive(self): + """Test that context= and affects= cannot both be set.""" + with self.assertRaises(ValueError) as cm: + + @client(context="user", affects="cart") + def bad(request: HttpRequest) -> ValidOutput: + return ValidOutput(valid=True) + + self.assertIn("mutually exclusive", str(cm.exception)) + + def test_affects_none_not_stored(self): + """Test that affects=None leaves no affects in meta.""" + + @client + def plain(request: HttpRequest) -> ValidOutput: + return ValidOutput(valid=True) + + self.assertNotIn("affects", plain._meta) + + +class ContextFetchTests(TestCase): + """Tests for the bundled context fetch endpoint (execute_context).""" + + def setUp(self): + clear_registry() + self.factory = RequestFactory() + + def tearDown(self): + clear_registry() + + def test_bundled_fetch(self): + """Test that execute_context bundles results from all context functions.""" + + class ProfileOut(BaseModel): + name: str + + class OrdersOut(BaseModel): + count: int + + @client(context="user") + def user_profile(request: HttpRequest, user_id: int) -> ProfileOut: + return ProfileOut(name=f"user_{user_id}") + + @client(context="user") + def user_orders(request: HttpRequest, user_id: int) -> OrdersOut: + return OrdersOut(count=user_id * 10) + + register(user_profile, "user_profile") + register(user_orders, "user_orders") + + request = self.factory.get("/api/mizan/ctx/user/?user_id=5") + request.user = AnonymousUser() + + result = execute_context(request, "user", {"user_id": "5"}) + + self.assertIsInstance(result, FunctionResult) + self.assertIn("user_profile", result.data) + self.assertIn("user_orders", result.data) + self.assertEqual(result.data["user_profile"]["name"], "user_5") + self.assertEqual(result.data["user_orders"]["count"], 50) + + def test_unknown_context_404(self): + """Test that fetching an unknown context returns NOT_FOUND.""" + request = self.factory.get("/") + request.user = AnonymousUser() + + result = execute_context(request, "nonexistent", {}) + self.assertIsInstance(result, FunctionError) + self.assertEqual(result.code, ErrorCode.NOT_FOUND) + + def test_auth_failure_propagates(self): + """Test that if one function requires auth, the entire context fails.""" + + @client(context="admin", auth=True) + def admin_stats(request: HttpRequest) -> ValidOutput: + return ValidOutput(valid=True) + + register(admin_stats, "admin_stats") + + request = self.factory.get("/") + request.user = AnonymousUser() + + result = execute_context(request, "admin", {}) + self.assertIsInstance(result, FunctionError) + self.assertEqual(result.code, ErrorCode.UNAUTHORIZED) + + def test_param_filtering(self): + """Test that each function only receives params it declares.""" + + class AOut(BaseModel): + uid: int + + class BOut(BaseModel): + uid: int + page: int + + @client(context="mixed") + def fn_a(request: HttpRequest, user_id: int) -> AOut: + return AOut(uid=user_id) + + @client(context="mixed") + def fn_b(request: HttpRequest, user_id: int, page: int = 1) -> BOut: + return BOut(uid=user_id, page=page) + + register(fn_a, "fn_a") + register(fn_b, "fn_b") + + request = self.factory.get("/") + request.user = AnonymousUser() + + result = execute_context(request, "mixed", {"user_id": "7", "page": "3"}) + self.assertIsInstance(result, FunctionResult) + self.assertEqual(result.data["fn_a"]["uid"], 7) + self.assertEqual(result.data["fn_b"]["uid"], 7) + self.assertEqual(result.data["fn_b"]["page"], 3) + + # ============================================================================= # Channel Tests # ============================================================================= diff --git a/django/src/mizan/urls.py b/django/src/mizan/urls.py index 0f997b8..fad95b9 100644 --- a/django/src/mizan/urls.py +++ b/django/src/mizan/urls.py @@ -1,9 +1,10 @@ """ mizan URL Configuration -Single integration point for all mizan HTTP endpoints: -- GET /session/ - Initialize session and get CSRF token (for SSR) -- POST /call/ - Server function calls (HTTP transport) +HTTP endpoints: +- GET /session/ - Initialize session and get CSRF token (for SSR) +- POST /call/ - Server function calls (HTTP transport) +- GET /ctx// - Bundled context fetch (all functions in a named context) Security: - Schema export is NOT exposed over HTTP to prevent API enumeration @@ -15,7 +16,7 @@ from django.middleware.csrf import get_token from django.urls import path from django.views.decorators.csrf import ensure_csrf_cookie -from .client.executor import function_call_view +from .client.executor import function_call_view, context_fetch_view app_name = "mizan" @@ -37,4 +38,5 @@ def session_init_view(request): urlpatterns = [ path("session/", session_init_view, name="session-init"), path("call/", function_call_view, name="function-call"), + path("ctx//", context_fetch_view, name="context-fetch"), ] diff --git a/react/src/context.tsx b/react/src/context.tsx index 041e718..0961902 100644 --- a/react/src/context.tsx +++ b/react/src/context.tsx @@ -137,6 +137,34 @@ export interface MizanContextValue { * (e.g., calling a server function immediately on mount). */ whenReady: Promise + + /** + * Invalidate a named context, triggering a refetch. + * Only refetches if the context is currently mounted (has a registered provider). + * No-op if the context is not mounted. + */ + invalidateContext: (name: string) => Promise + + /** + * Invalidate specific functions within their contexts. + * Groups by context and calls invalidateContext per group. + */ + invalidateFunctions: (names: string[]) => Promise + + /** + * Register a named context provider for invalidation support. + * Called by generated context providers on mount. + * Returns an unregister function (call on unmount). + */ + registerContextProvider: ( + name: string, + refetch: () => Promise, + ) => () => void + + /** + * Base URL for HTTP calls (for use by generated context providers). + */ + baseUrl: string } export interface MizanProviderProps { @@ -466,6 +494,51 @@ export function MizanProvider({ const isRPCAvailable = status === 'connected' + // Named context provider registry for invalidation + const contextProvidersRef = useRef Promise }>>(new Map()) + + const registerContextProvider = useCallback( + (name: string, refetch: () => Promise): (() => void) => { + contextProvidersRef.current.set(name, { refetch }) + return () => { + contextProvidersRef.current.delete(name) + } + }, + [] + ) + + const invalidateContext = useCallback( + async (name: string): Promise => { + const provider = contextProvidersRef.current.get(name) + if (provider) { + await provider.refetch() + } + // If not mounted, no-op — no wasted request + }, + [] + ) + + const invalidateFunctions = useCallback( + async (names: string[]): Promise => { + // Each function belongs to a context. Invalidating a function + // means refetching its entire context (since the bundling endpoint + // returns all functions). Dedupe by context name. + const contexts = new Set() + for (const name of names) { + // The context name for each function is known at codegen time + // and baked into the generated hook. Here we just invalidate + // whatever contexts are registered that contain these functions. + for (const [ctxName] of contextProvidersRef.current) { + contexts.add(ctxName) + } + } + await Promise.all( + Array.from(contexts).map(ctx => invalidateContext(ctx)) + ) + }, + [invalidateContext] + ) + const value = useMemo( () => ({ call, @@ -477,8 +550,12 @@ export function MizanProvider({ onPush, onContextChange, whenReady: sessionRef.current!.promise, + invalidateContext, + invalidateFunctions, + registerContextProvider, + baseUrl, }), - [call, getContext, refreshContext, refreshAllContexts, status, isRPCAvailable, onPush, onContextChange] + [call, getContext, refreshContext, refreshAllContexts, status, isRPCAvailable, onPush, onContextChange, invalidateContext, invalidateFunctions, registerContextProvider, baseUrl] ) return (