Add named contexts, bundled fetch endpoint, and affects invalidation

Phase 1 (Named Contexts):
- @client(context=) accepts any string, not just 'global'/'local'
- context='local' emits deprecation warning
- Registry groups functions by context name (get_context_groups)
- GET /api/mizan/ctx/<name>/ bundles all context functions in one response
- Schema export includes x-mizan-contexts with param elevation metadata

Phase 2 (Affects):
- @client(affects=) declares mutation invalidation targets
- Accepts context name strings, function refs, or lists
- Mutually exclusive with context=
- Exported in x-mizan-functions schema for codegen

React runtime:
- MizanContextValue gains invalidateContext, invalidateFunctions,
  registerContextProvider, and baseUrl
- Named context providers register for invalidation on mount

259 Django tests pass, 33 React tests pass.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-03-31 22:47:55 -04:00
parent f3c225ef49
commit 3523f2e3fe
8 changed files with 572 additions and 56 deletions

View File

@@ -27,7 +27,7 @@ from django.http import HttpRequest, JsonResponse
from django.views.decorators.csrf import csrf_protect
from pydantic import BaseModel, ValidationError
from mizan.setup.registry import get_function
from mizan.setup.registry import get_function, get_context_groups
if TYPE_CHECKING:
pass
@@ -479,3 +479,115 @@ def function_call_view(request: HttpRequest) -> JsonResponse:
return result.to_response(status=status)
return result.to_response()
def execute_context(
request: HttpRequest,
context_name: str,
params: dict[str, str],
) -> FunctionResult | FunctionError:
"""
Execute all functions in a named context with merged params.
Each function receives only the params it declares in its Input schema.
If any function fails (auth, validation, execution), the entire request fails.
Args:
request: The Django HttpRequest
context_name: Name of the context (e.g., 'user', 'global')
params: Query parameters (strings — Pydantic coerces types)
Returns:
FunctionResult with bundled data, or FunctionError
"""
groups = get_context_groups()
fn_names = groups.get(context_name)
if not fn_names:
return FunctionError(
code=ErrorCode.NOT_FOUND,
message=f"Context '{context_name}' not found",
)
results = {}
for fn_name in fn_names:
view_class = get_function(fn_name)
if view_class is None:
continue
# Filter params to only those in this function's Input schema
input_cls = getattr(view_class, "Input", None)
if input_cls and input_cls is not BaseModel and input_cls.model_fields:
fn_params = {
k: v for k, v in params.items()
if k in input_cls.model_fields
}
else:
fn_params = None
result = execute_function(request, fn_name, fn_params)
if isinstance(result, FunctionError):
return result
results[fn_name] = result.data
return FunctionResult(data=results)
def _jwt_auth_only(view_func):
"""
Decorator that handles JWT auth for GET endpoints (no CSRF needed for GET).
"""
@wraps(view_func)
def wrapper(request: HttpRequest, *args, **kwargs):
has_jwt = _has_jwt_header(request)
if has_jwt:
if _try_jwt_auth(request):
return view_func(request, *args, **kwargs)
else:
return FunctionError(
code=ErrorCode.UNAUTHORIZED,
message="Invalid or expired JWT token",
).to_response(status=401)
# No JWT — session auth (no CSRF needed for GET)
return view_func(request, *args, **kwargs)
return wrapper
@_jwt_auth_only
def context_fetch_view(request: HttpRequest, context_name: str) -> JsonResponse:
"""
Fetch all functions in a named context in a single bundled GET request.
Endpoint: GET /api/mizan/ctx/<context_name>/?param1=val1&param2=val2
Response on success:
{
"error": false,
"data": {
"user_profile": { ... },
"user_orders": [ ... ]
}
}
"""
if request.method != "GET":
return FunctionError(
code=ErrorCode.BAD_REQUEST,
message="Only GET method allowed",
).to_response(status=405)
params = dict(request.GET)
result = execute_context(request, context_name, params)
if isinstance(result, FunctionError):
status = {
ErrorCode.NOT_FOUND: 404,
ErrorCode.VALIDATION_ERROR: 422,
ErrorCode.UNAUTHORIZED: 401,
ErrorCode.FORBIDDEN: 403,
ErrorCode.BAD_REQUEST: 400,
ErrorCode.INTERNAL_ERROR: 500,
ErrorCode.NOT_IMPLEMENTED: 501,
}.get(result.code, 400)
return result.to_response(status=status)
return result.to_response()

View File

@@ -20,6 +20,7 @@ Two styles supported:
from __future__ import annotations
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import (
Any,
@@ -38,8 +39,9 @@ from django.http import HttpRequest
from pydantic import BaseModel
# Valid context modes: 'global', 'local', or False (not a context)
ContextMode = Literal["global", "local", False]
# Context mode: any non-empty string names a context, False means not a context.
# 'global' is a reserved context name whose provider is auto-mounted at root.
ContextMode = str | Literal[False]
TInput = TypeVar("TInput", bound=BaseModel)
@@ -185,6 +187,7 @@ def client(
fn: Callable = None,
*,
context: ContextMode = False,
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None,
websocket: bool = False,
auth: bool | str | Callable[[Any], bool] | None = None,
) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]:
@@ -195,61 +198,69 @@ def client(
Function parameters become input fields automatically.
Args:
context: Context mode for React state management.
- False (default): Not a context, just a callable function
- 'global': Embedded in root DjangoContext, no params, singleton
- 'local': Standalone provider, supports params via flat props
context: Named context for React state management.
- False (default): Not a context, just a callable function.
- 'global': Reserved name. Embedded in root MizanProvider,
no params, SSR-hydrated.
- Any other string: Named context. Functions sharing the same
context name are grouped into one provider and one fetch.
affects: Declare which contexts this mutation invalidates.
- A context name string: refetch the entire named context
- A function reference: refetch that function's context
- A list of the above: refetch all specified targets
Mutually exclusive with context=.
websocket: Enable WebSocket RPC transport (default: False).
By default, functions use HTTP-only transport. Enable this for
real-time features (chat, gaming, live updates) that benefit
from lower latency.
Note: Forms (mizanFormMixin) always use HTTP because auth
flows require full HTTP request semantics.
auth: Authentication requirement.
- None (default): No auth required (AnonymousUser allowed)
- None (default): No auth required
- True or 'required': Must be authenticated
- 'staff': Must have is_staff=True
- 'superuser': Must have is_superuser=True
- callable(request) -> bool: Custom check function
Usage:
# Basic HTTP-only function (not a context)
@client
def echo(request, message: str) -> EchoOutput:
return EchoOutput(message=message)
# Global context - embedded in DjangoContext, no params
@client(context='global')
def current_user(request) -> UserOutput:
return UserOutput(email=request.user.email)
# Local context - standalone provider, supports params
@client(context='local')
def user_profile(request, user_id: int) -> ProfileOutput:
return ProfileOutput(...)
# Named context - functions sharing a name are bundled
@client(context='user')
def user_profile(request, user_id: int) -> ProfileOutput: ...
# WebSocket-enabled for real-time
@client(websocket=True)
def send_message(request, room_id: int, text: str) -> MessageOutput:
return MessageOutput(...)
# Local context with WebSocket (live data)
@client(context='local', websocket=True)
def live_user_status(request, user_id: int) -> StatusOutput:
return StatusOutput(...)
# Mutation that invalidates a context
@client(affects='user')
def edit_profile(request, name: str) -> dict: ...
Returns:
A ServerFunction class that wraps the function
"""
# Validate context parameter
if context not in (False, "global", "local"):
raise ValueError(
f"Invalid context value '{context}'. "
f"Must be False, 'global', or 'local'."
)
if context is not False:
if not isinstance(context, str) or not context.strip():
raise ValueError(
"context must be a non-empty string or False."
)
if context == "local":
warnings.warn(
"context='local' is deprecated. Use a named context string instead "
"(e.g., context='my_context').",
DeprecationWarning,
stacklevel=2,
)
# Validate affects parameter
if affects is not None:
if context is not False:
raise ValueError(
"context= and affects= are mutually exclusive. "
"A function cannot be both a context reader and a mutation."
)
# Validate auth parameter
if auth is not None:
@@ -261,21 +272,50 @@ def client(
def decorator(fn: Callable) -> type[ServerFunction]:
return _create_server_function(
fn, context=context, websocket=websocket, auth=auth
fn, context=context, affects=affects, websocket=websocket, auth=auth
)
# Support both @client and @client(...)
if fn is not None:
return _create_server_function(
fn, context=context, websocket=websocket, auth=auth
fn, context=context, affects=affects, websocket=websocket, auth=auth
)
return decorator
def _normalize_affects(
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None,
) -> list[dict[str, str]] | None:
"""Normalize the affects parameter into a list of target descriptors."""
if affects is None:
return None
items = affects if isinstance(affects, list) else [affects]
result = []
for item in items:
if isinstance(item, str):
result.append({"type": "context", "name": item})
elif isinstance(item, type) and issubclass(item, ServerFunction):
fn_meta = getattr(item, "_meta", {})
fn_ctx = fn_meta.get("context")
result.append({
"type": "function",
"name": getattr(item, "__name__", str(item)),
"context": fn_ctx or None,
})
else:
raise ValueError(
f"affects items must be context name strings or @client function references, "
f"got {type(item)}"
)
return result
def _create_server_function(
fn: Callable,
*,
context: ContextMode = False,
context: str | Literal[False] = False,
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None,
websocket: bool = False,
auth: bool | str | None = None,
) -> type[ServerFunction]:
@@ -371,10 +411,15 @@ def _create_server_function(
# Build metadata
meta = {}
# Context mode: 'global' or 'local' (False means not a context)
# Context name (any non-empty string)
if context:
meta["context"] = context
# Affects: mutation invalidation targets
normalized_affects = _normalize_affects(affects)
if normalized_affects:
meta["affects"] = normalized_affects
# WebSocket: enable WebSocket transport
if websocket:
meta["websocket"] = True
@@ -464,7 +509,7 @@ def _is_context_enabled(item) -> bool:
return True
if isinstance(item, type) and issubclass(item, ServerFunction):
meta = getattr(item, "_meta", {})
return meta.get("context") in ("global", "local")
return bool(meta.get("context"))
return False
@@ -477,7 +522,7 @@ def compose(
Compose multiple contexts into a single provider.
Args:
*children: Context functions (@client with context='global'|'local')
*children: Context functions (@client with a context name)
or other @compose functions. All must be unique after flattening.
on_server: Bundle all calls into a single server request (default: False).
@@ -529,7 +574,7 @@ def compose(
)
raise ValueError(
f"@compose argument {i} ({child_name}) is not context-enabled. "
f"All children must have @client(context='global'|'local') or be @compose."
f"All children must have @client(context=...) or be @compose."
)
# Flatten to collect all leaves

View File

@@ -26,7 +26,7 @@ if TYPE_CHECKING:
from django import forms
from ninja import NinjaAPI
from mizan.setup.registry import get_registry, get_schema
from mizan.setup.registry import get_registry, get_schema, get_context_groups, get_function
__all__ = ["get_schema", "generate_openapi_schema", "generate_openapi_json"]
@@ -271,6 +271,10 @@ def generate_openapi_schema() -> dict[str, Any]:
"formRole": meta.get("form_role"), # "schema", "validate", "submit"
}
# Affects metadata (mutation invalidation)
if meta.get("affects"):
fn_meta_entry["affects"] = meta["affects"]
# For form schema functions, extract field definitions for Zod generation
if meta.get("form") and meta.get("form_role") == "schema":
form_class = meta.get("form_class")
@@ -290,6 +294,46 @@ def generate_openapi_schema() -> dict[str, Any]:
# Add custom extension with function metadata for provider generation
schema["x-mizan-functions"] = function_metadata
# Add x-mizan-contexts: grouped context metadata with param elevation
context_groups = get_context_groups()
if context_groups:
contexts_meta: dict[str, Any] = {}
for ctx_name, fn_names in context_groups.items():
# Analyze params across all functions in the context
param_info: dict[str, dict[str, Any]] = {}
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
for field_name, field_info in input_cls.model_fields.items():
if field_name not in param_info:
annotation = field_info.annotation
# Map Python types to JSON schema types
type_name = "string"
if annotation in (int,):
type_name = "integer"
elif annotation in (float,):
type_name = "number"
elif annotation in (bool,):
type_name = "boolean"
param_info[field_name] = {
"type": type_name,
"sharedBy": [],
}
param_info[field_name]["sharedBy"].append(fn_name)
# A param is required if ALL functions in the context declare it
for p_name, p_meta in param_info.items():
p_meta["required"] = len(p_meta["sharedBy"]) == len(fn_names)
contexts_meta[ctx_name] = {
"functions": fn_names,
"params": param_info,
}
schema["x-mizan-contexts"] = contexts_meta
# Add x-mizan metadata to each operation
for fn_meta in function_metadata:
path = f"/mizan/{fn_meta['name']}"

View File

@@ -25,6 +25,7 @@ from .registry import (
get_registry,
get_schema,
get_contexts,
get_context_groups,
get_forms,
clear_registry,
)
@@ -57,6 +58,7 @@ __all__ = [
"get_registry",
"get_schema",
"get_contexts",
"get_context_groups",
"get_forms",
"clear_registry",
# Discovery

View File

@@ -290,6 +290,21 @@ def get_contexts() -> dict[str, type["ServerFunction"]]:
return contexts
def get_context_groups() -> dict[str, list[str]]:
"""
Group function names by their context string.
Returns:
{"global": ["current_user"], "user": ["user_profile", "user_orders"]}
"""
groups: dict[str, list[str]] = {}
for name, cls in _functions.items():
ctx = getattr(cls, "_meta", {}).get("context")
if ctx:
groups.setdefault(ctx, []).append(name)
return groups
def get_forms() -> dict[str, list[type["ServerFunction"]]]:
"""
Get all server functions that are form-related, grouped by form name.

View File

@@ -15,6 +15,7 @@ from mizan.client.executor import (
FunctionError,
FunctionResult,
execute_function,
execute_context,
)
from mizan.setup.registry import (
clear_registry,
@@ -514,29 +515,78 @@ class ContextTests(TestCase):
self.assertEqual(fn._meta.get("context"), "global")
def test_context_local(self):
"""Test @client(context='local') creates a local context."""
"""Test @client(context='local') still works with deprecation warning."""
import warnings
class CtxOutput(BaseModel):
data: str
@client(context="local")
def local_context(request: HttpRequest, user_id: int) -> CtxOutput:
return CtxOutput(data=f"user_{user_id}")
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
@client(context="local")
def local_context(request: HttpRequest, user_id: int) -> CtxOutput:
return CtxOutput(data=f"user_{user_id}")
self.assertEqual(len(w), 1)
self.assertIn("deprecated", str(w[0].message).lower())
register(local_context, "local_context")
fn = get_function("local_context")
self.assertEqual(fn._meta.get("context"), "local")
def test_context_invalid_value_raises(self):
"""Test that invalid context values raise ValueError."""
with self.assertRaises(ValueError) as cm:
def test_context_named(self):
"""Test @client(context='user') creates a named context."""
@client(context="invalid")
class CtxOutput(BaseModel):
data: str
@client(context="user")
def user_profile(request: HttpRequest, user_id: int) -> CtxOutput:
return CtxOutput(data=f"user_{user_id}")
register(user_profile, "user_profile")
fn = get_function("user_profile")
self.assertEqual(fn._meta.get("context"), "user")
def test_context_empty_string_raises(self):
"""Test that empty context string raises ValueError."""
with self.assertRaises(ValueError):
@client(context="")
def bad_context(request: HttpRequest) -> ValidOutput:
return ValidOutput(valid=True)
self.assertIn("Invalid context value", str(cm.exception))
def test_context_groups(self):
"""Test get_context_groups() groups functions by context name."""
from mizan.setup.registry import get_context_groups
class Out(BaseModel):
v: int
@client(context="user")
def fn_a(request: HttpRequest, user_id: int) -> Out:
return Out(v=1)
@client(context="user")
def fn_b(request: HttpRequest, user_id: int) -> Out:
return Out(v=2)
@client(context="global")
def fn_c(request: HttpRequest) -> Out:
return Out(v=3)
register(fn_a, "fn_a")
register(fn_b, "fn_b")
register(fn_c, "fn_c")
groups = get_context_groups()
self.assertIn("user", groups)
self.assertIn("global", groups)
self.assertCountEqual(groups["user"], ["fn_a", "fn_b"])
self.assertEqual(groups["global"], ["fn_c"])
def test_get_contexts(self):
"""Test get_contexts() returns only context-marked functions."""
@@ -568,6 +618,175 @@ class ContextTests(TestCase):
self.assertNotIn("echo", contexts)
class AffectsTests(TestCase):
"""Tests for the affects= parameter on @client."""
def setUp(self):
clear_registry()
def test_affects_context_string(self):
"""Test @client(affects='user') stores context invalidation."""
@client(affects="user")
def edit_profile(request: HttpRequest, name: str) -> ValidOutput:
return ValidOutput(valid=True)
self.assertEqual(
edit_profile._meta["affects"],
[{"type": "context", "name": "user"}],
)
def test_affects_function_ref(self):
"""Test @client(affects=fn_ref) stores function invalidation."""
@client(context="user")
def user_profile(request: HttpRequest, user_id: int) -> ValidOutput:
return ValidOutput(valid=True)
@client(affects=user_profile)
def edit_profile(request: HttpRequest, name: str) -> ValidOutput:
return ValidOutput(valid=True)
affects = edit_profile._meta["affects"]
self.assertEqual(len(affects), 1)
self.assertEqual(affects[0]["type"], "function")
self.assertEqual(affects[0]["name"], "user_profile")
self.assertEqual(affects[0]["context"], "user")
def test_affects_list(self):
"""Test @client(affects=[...]) stores multiple targets."""
@client(context="user")
def user_profile(request: HttpRequest) -> ValidOutput:
return ValidOutput(valid=True)
@client(affects=[user_profile, "billing"])
def change_plan(request: HttpRequest) -> ValidOutput:
return ValidOutput(valid=True)
affects = change_plan._meta["affects"]
self.assertEqual(len(affects), 2)
self.assertEqual(affects[0]["type"], "function")
self.assertEqual(affects[1]["type"], "context")
self.assertEqual(affects[1]["name"], "billing")
def test_affects_and_context_mutually_exclusive(self):
"""Test that context= and affects= cannot both be set."""
with self.assertRaises(ValueError) as cm:
@client(context="user", affects="cart")
def bad(request: HttpRequest) -> ValidOutput:
return ValidOutput(valid=True)
self.assertIn("mutually exclusive", str(cm.exception))
def test_affects_none_not_stored(self):
"""Test that affects=None leaves no affects in meta."""
@client
def plain(request: HttpRequest) -> ValidOutput:
return ValidOutput(valid=True)
self.assertNotIn("affects", plain._meta)
class ContextFetchTests(TestCase):
"""Tests for the bundled context fetch endpoint (execute_context)."""
def setUp(self):
clear_registry()
self.factory = RequestFactory()
def tearDown(self):
clear_registry()
def test_bundled_fetch(self):
"""Test that execute_context bundles results from all context functions."""
class ProfileOut(BaseModel):
name: str
class OrdersOut(BaseModel):
count: int
@client(context="user")
def user_profile(request: HttpRequest, user_id: int) -> ProfileOut:
return ProfileOut(name=f"user_{user_id}")
@client(context="user")
def user_orders(request: HttpRequest, user_id: int) -> OrdersOut:
return OrdersOut(count=user_id * 10)
register(user_profile, "user_profile")
register(user_orders, "user_orders")
request = self.factory.get("/api/mizan/ctx/user/?user_id=5")
request.user = AnonymousUser()
result = execute_context(request, "user", {"user_id": "5"})
self.assertIsInstance(result, FunctionResult)
self.assertIn("user_profile", result.data)
self.assertIn("user_orders", result.data)
self.assertEqual(result.data["user_profile"]["name"], "user_5")
self.assertEqual(result.data["user_orders"]["count"], 50)
def test_unknown_context_404(self):
"""Test that fetching an unknown context returns NOT_FOUND."""
request = self.factory.get("/")
request.user = AnonymousUser()
result = execute_context(request, "nonexistent", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.NOT_FOUND)
def test_auth_failure_propagates(self):
"""Test that if one function requires auth, the entire context fails."""
@client(context="admin", auth=True)
def admin_stats(request: HttpRequest) -> ValidOutput:
return ValidOutput(valid=True)
register(admin_stats, "admin_stats")
request = self.factory.get("/")
request.user = AnonymousUser()
result = execute_context(request, "admin", {})
self.assertIsInstance(result, FunctionError)
self.assertEqual(result.code, ErrorCode.UNAUTHORIZED)
def test_param_filtering(self):
"""Test that each function only receives params it declares."""
class AOut(BaseModel):
uid: int
class BOut(BaseModel):
uid: int
page: int
@client(context="mixed")
def fn_a(request: HttpRequest, user_id: int) -> AOut:
return AOut(uid=user_id)
@client(context="mixed")
def fn_b(request: HttpRequest, user_id: int, page: int = 1) -> BOut:
return BOut(uid=user_id, page=page)
register(fn_a, "fn_a")
register(fn_b, "fn_b")
request = self.factory.get("/")
request.user = AnonymousUser()
result = execute_context(request, "mixed", {"user_id": "7", "page": "3"})
self.assertIsInstance(result, FunctionResult)
self.assertEqual(result.data["fn_a"]["uid"], 7)
self.assertEqual(result.data["fn_b"]["uid"], 7)
self.assertEqual(result.data["fn_b"]["page"], 3)
# =============================================================================
# Channel Tests
# =============================================================================

View File

@@ -1,9 +1,10 @@
"""
mizan URL Configuration
Single integration point for all mizan HTTP endpoints:
- GET /session/ - Initialize session and get CSRF token (for SSR)
- POST /call/ - Server function calls (HTTP transport)
HTTP endpoints:
- GET /session/ - Initialize session and get CSRF token (for SSR)
- POST /call/ - Server function calls (HTTP transport)
- GET /ctx/<name>/ - Bundled context fetch (all functions in a named context)
Security:
- Schema export is NOT exposed over HTTP to prevent API enumeration
@@ -15,7 +16,7 @@ from django.middleware.csrf import get_token
from django.urls import path
from django.views.decorators.csrf import ensure_csrf_cookie
from .client.executor import function_call_view
from .client.executor import function_call_view, context_fetch_view
app_name = "mizan"
@@ -37,4 +38,5 @@ def session_init_view(request):
urlpatterns = [
path("session/", session_init_view, name="session-init"),
path("call/", function_call_view, name="function-call"),
path("ctx/<str:context_name>/", context_fetch_view, name="context-fetch"),
]

View File

@@ -137,6 +137,34 @@ export interface MizanContextValue {
* (e.g., calling a server function immediately on mount).
*/
whenReady: Promise<void>
/**
* Invalidate a named context, triggering a refetch.
* Only refetches if the context is currently mounted (has a registered provider).
* No-op if the context is not mounted.
*/
invalidateContext: (name: string) => Promise<void>
/**
* Invalidate specific functions within their contexts.
* Groups by context and calls invalidateContext per group.
*/
invalidateFunctions: (names: string[]) => Promise<void>
/**
* Register a named context provider for invalidation support.
* Called by generated context providers on mount.
* Returns an unregister function (call on unmount).
*/
registerContextProvider: (
name: string,
refetch: () => Promise<void>,
) => () => void
/**
* Base URL for HTTP calls (for use by generated context providers).
*/
baseUrl: string
}
export interface MizanProviderProps {
@@ -466,6 +494,51 @@ export function MizanProvider({
const isRPCAvailable = status === 'connected'
// Named context provider registry for invalidation
const contextProvidersRef = useRef<Map<string, { refetch: () => Promise<void> }>>(new Map())
const registerContextProvider = useCallback(
(name: string, refetch: () => Promise<void>): (() => void) => {
contextProvidersRef.current.set(name, { refetch })
return () => {
contextProvidersRef.current.delete(name)
}
},
[]
)
const invalidateContext = useCallback(
async (name: string): Promise<void> => {
const provider = contextProvidersRef.current.get(name)
if (provider) {
await provider.refetch()
}
// If not mounted, no-op — no wasted request
},
[]
)
const invalidateFunctions = useCallback(
async (names: string[]): Promise<void> => {
// Each function belongs to a context. Invalidating a function
// means refetching its entire context (since the bundling endpoint
// returns all functions). Dedupe by context name.
const contexts = new Set<string>()
for (const name of names) {
// The context name for each function is known at codegen time
// and baked into the generated hook. Here we just invalidate
// whatever contexts are registered that contain these functions.
for (const [ctxName] of contextProvidersRef.current) {
contexts.add(ctxName)
}
}
await Promise.all(
Array.from(contexts).map(ctx => invalidateContext(ctx))
)
},
[invalidateContext]
)
const value = useMemo<MizanContextValue>(
() => ({
call,
@@ -477,8 +550,12 @@ export function MizanProvider({
onPush,
onContextChange,
whenReady: sessionRef.current!.promise,
invalidateContext,
invalidateFunctions,
registerContextProvider,
baseUrl,
}),
[call, getContext, refreshContext, refreshAllContexts, status, isRPCAvailable, onPush, onContextChange]
[call, getContext, refreshContext, refreshAllContexts, status, isRPCAvailable, onPush, onContextChange, invalidateContext, invalidateFunctions, registerContextProvider, baseUrl]
)
return (