Add named contexts, bundled fetch endpoint, and affects invalidation
Phase 1 (Named Contexts): - @client(context=) accepts any string, not just 'global'/'local' - context='local' emits deprecation warning - Registry groups functions by context name (get_context_groups) - GET /api/mizan/ctx/<name>/ bundles all context functions in one response - Schema export includes x-mizan-contexts with param elevation metadata Phase 2 (Affects): - @client(affects=) declares mutation invalidation targets - Accepts context name strings, function refs, or lists - Mutually exclusive with context= - Exported in x-mizan-functions schema for codegen React runtime: - MizanContextValue gains invalidateContext, invalidateFunctions, registerContextProvider, and baseUrl - Named context providers register for invalidation on mount 259 Django tests pass, 33 React tests pass. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -27,7 +27,7 @@ from django.http import HttpRequest, JsonResponse
|
||||
from django.views.decorators.csrf import csrf_protect
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from mizan.setup.registry import get_function
|
||||
from mizan.setup.registry import get_function, get_context_groups
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
@@ -479,3 +479,115 @@ def function_call_view(request: HttpRequest) -> JsonResponse:
|
||||
return result.to_response(status=status)
|
||||
|
||||
return result.to_response()
|
||||
|
||||
|
||||
def execute_context(
|
||||
request: HttpRequest,
|
||||
context_name: str,
|
||||
params: dict[str, str],
|
||||
) -> FunctionResult | FunctionError:
|
||||
"""
|
||||
Execute all functions in a named context with merged params.
|
||||
|
||||
Each function receives only the params it declares in its Input schema.
|
||||
If any function fails (auth, validation, execution), the entire request fails.
|
||||
|
||||
Args:
|
||||
request: The Django HttpRequest
|
||||
context_name: Name of the context (e.g., 'user', 'global')
|
||||
params: Query parameters (strings — Pydantic coerces types)
|
||||
|
||||
Returns:
|
||||
FunctionResult with bundled data, or FunctionError
|
||||
"""
|
||||
groups = get_context_groups()
|
||||
fn_names = groups.get(context_name)
|
||||
if not fn_names:
|
||||
return FunctionError(
|
||||
code=ErrorCode.NOT_FOUND,
|
||||
message=f"Context '{context_name}' not found",
|
||||
)
|
||||
|
||||
results = {}
|
||||
for fn_name in fn_names:
|
||||
view_class = get_function(fn_name)
|
||||
if view_class is None:
|
||||
continue
|
||||
|
||||
# Filter params to only those in this function's Input schema
|
||||
input_cls = getattr(view_class, "Input", None)
|
||||
if input_cls and input_cls is not BaseModel and input_cls.model_fields:
|
||||
fn_params = {
|
||||
k: v for k, v in params.items()
|
||||
if k in input_cls.model_fields
|
||||
}
|
||||
else:
|
||||
fn_params = None
|
||||
|
||||
result = execute_function(request, fn_name, fn_params)
|
||||
if isinstance(result, FunctionError):
|
||||
return result
|
||||
results[fn_name] = result.data
|
||||
|
||||
return FunctionResult(data=results)
|
||||
|
||||
|
||||
def _jwt_auth_only(view_func):
|
||||
"""
|
||||
Decorator that handles JWT auth for GET endpoints (no CSRF needed for GET).
|
||||
"""
|
||||
@wraps(view_func)
|
||||
def wrapper(request: HttpRequest, *args, **kwargs):
|
||||
has_jwt = _has_jwt_header(request)
|
||||
if has_jwt:
|
||||
if _try_jwt_auth(request):
|
||||
return view_func(request, *args, **kwargs)
|
||||
else:
|
||||
return FunctionError(
|
||||
code=ErrorCode.UNAUTHORIZED,
|
||||
message="Invalid or expired JWT token",
|
||||
).to_response(status=401)
|
||||
# No JWT — session auth (no CSRF needed for GET)
|
||||
return view_func(request, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
@_jwt_auth_only
|
||||
def context_fetch_view(request: HttpRequest, context_name: str) -> JsonResponse:
|
||||
"""
|
||||
Fetch all functions in a named context in a single bundled GET request.
|
||||
|
||||
Endpoint: GET /api/mizan/ctx/<context_name>/?param1=val1¶m2=val2
|
||||
|
||||
Response on success:
|
||||
{
|
||||
"error": false,
|
||||
"data": {
|
||||
"user_profile": { ... },
|
||||
"user_orders": [ ... ]
|
||||
}
|
||||
}
|
||||
"""
|
||||
if request.method != "GET":
|
||||
return FunctionError(
|
||||
code=ErrorCode.BAD_REQUEST,
|
||||
message="Only GET method allowed",
|
||||
).to_response(status=405)
|
||||
|
||||
params = dict(request.GET)
|
||||
result = execute_context(request, context_name, params)
|
||||
|
||||
if isinstance(result, FunctionError):
|
||||
status = {
|
||||
ErrorCode.NOT_FOUND: 404,
|
||||
ErrorCode.VALIDATION_ERROR: 422,
|
||||
ErrorCode.UNAUTHORIZED: 401,
|
||||
ErrorCode.FORBIDDEN: 403,
|
||||
ErrorCode.BAD_REQUEST: 400,
|
||||
ErrorCode.INTERNAL_ERROR: 500,
|
||||
ErrorCode.NOT_IMPLEMENTED: 501,
|
||||
}.get(result.code, 400)
|
||||
return result.to_response(status=status)
|
||||
|
||||
return result.to_response()
|
||||
|
||||
@@ -20,6 +20,7 @@ Two styles supported:
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import (
|
||||
Any,
|
||||
@@ -38,8 +39,9 @@ from django.http import HttpRequest
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
# Valid context modes: 'global', 'local', or False (not a context)
|
||||
ContextMode = Literal["global", "local", False]
|
||||
# Context mode: any non-empty string names a context, False means not a context.
|
||||
# 'global' is a reserved context name whose provider is auto-mounted at root.
|
||||
ContextMode = str | Literal[False]
|
||||
|
||||
|
||||
TInput = TypeVar("TInput", bound=BaseModel)
|
||||
@@ -185,6 +187,7 @@ def client(
|
||||
fn: Callable = None,
|
||||
*,
|
||||
context: ContextMode = False,
|
||||
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None,
|
||||
websocket: bool = False,
|
||||
auth: bool | str | Callable[[Any], bool] | None = None,
|
||||
) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]:
|
||||
@@ -195,61 +198,69 @@ def client(
|
||||
Function parameters become input fields automatically.
|
||||
|
||||
Args:
|
||||
context: Context mode for React state management.
|
||||
- False (default): Not a context, just a callable function
|
||||
- 'global': Embedded in root DjangoContext, no params, singleton
|
||||
- 'local': Standalone provider, supports params via flat props
|
||||
context: Named context for React state management.
|
||||
- False (default): Not a context, just a callable function.
|
||||
- 'global': Reserved name. Embedded in root MizanProvider,
|
||||
no params, SSR-hydrated.
|
||||
- Any other string: Named context. Functions sharing the same
|
||||
context name are grouped into one provider and one fetch.
|
||||
|
||||
affects: Declare which contexts this mutation invalidates.
|
||||
- A context name string: refetch the entire named context
|
||||
- A function reference: refetch that function's context
|
||||
- A list of the above: refetch all specified targets
|
||||
Mutually exclusive with context=.
|
||||
|
||||
websocket: Enable WebSocket RPC transport (default: False).
|
||||
By default, functions use HTTP-only transport. Enable this for
|
||||
real-time features (chat, gaming, live updates) that benefit
|
||||
from lower latency.
|
||||
|
||||
Note: Forms (mizanFormMixin) always use HTTP because auth
|
||||
flows require full HTTP request semantics.
|
||||
|
||||
auth: Authentication requirement.
|
||||
- None (default): No auth required (AnonymousUser allowed)
|
||||
- None (default): No auth required
|
||||
- True or 'required': Must be authenticated
|
||||
- 'staff': Must have is_staff=True
|
||||
- 'superuser': Must have is_superuser=True
|
||||
- callable(request) -> bool: Custom check function
|
||||
|
||||
Usage:
|
||||
# Basic HTTP-only function (not a context)
|
||||
@client
|
||||
def echo(request, message: str) -> EchoOutput:
|
||||
return EchoOutput(message=message)
|
||||
|
||||
# Global context - embedded in DjangoContext, no params
|
||||
@client(context='global')
|
||||
def current_user(request) -> UserOutput:
|
||||
return UserOutput(email=request.user.email)
|
||||
|
||||
# Local context - standalone provider, supports params
|
||||
@client(context='local')
|
||||
def user_profile(request, user_id: int) -> ProfileOutput:
|
||||
return ProfileOutput(...)
|
||||
# Named context - functions sharing a name are bundled
|
||||
@client(context='user')
|
||||
def user_profile(request, user_id: int) -> ProfileOutput: ...
|
||||
|
||||
# WebSocket-enabled for real-time
|
||||
@client(websocket=True)
|
||||
def send_message(request, room_id: int, text: str) -> MessageOutput:
|
||||
return MessageOutput(...)
|
||||
|
||||
# Local context with WebSocket (live data)
|
||||
@client(context='local', websocket=True)
|
||||
def live_user_status(request, user_id: int) -> StatusOutput:
|
||||
return StatusOutput(...)
|
||||
# Mutation that invalidates a context
|
||||
@client(affects='user')
|
||||
def edit_profile(request, name: str) -> dict: ...
|
||||
|
||||
Returns:
|
||||
A ServerFunction class that wraps the function
|
||||
"""
|
||||
# Validate context parameter
|
||||
if context not in (False, "global", "local"):
|
||||
raise ValueError(
|
||||
f"Invalid context value '{context}'. "
|
||||
f"Must be False, 'global', or 'local'."
|
||||
)
|
||||
if context is not False:
|
||||
if not isinstance(context, str) or not context.strip():
|
||||
raise ValueError(
|
||||
"context must be a non-empty string or False."
|
||||
)
|
||||
if context == "local":
|
||||
warnings.warn(
|
||||
"context='local' is deprecated. Use a named context string instead "
|
||||
"(e.g., context='my_context').",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Validate affects parameter
|
||||
if affects is not None:
|
||||
if context is not False:
|
||||
raise ValueError(
|
||||
"context= and affects= are mutually exclusive. "
|
||||
"A function cannot be both a context reader and a mutation."
|
||||
)
|
||||
|
||||
# Validate auth parameter
|
||||
if auth is not None:
|
||||
@@ -261,21 +272,50 @@ def client(
|
||||
|
||||
def decorator(fn: Callable) -> type[ServerFunction]:
|
||||
return _create_server_function(
|
||||
fn, context=context, websocket=websocket, auth=auth
|
||||
fn, context=context, affects=affects, websocket=websocket, auth=auth
|
||||
)
|
||||
|
||||
# Support both @client and @client(...)
|
||||
if fn is not None:
|
||||
return _create_server_function(
|
||||
fn, context=context, websocket=websocket, auth=auth
|
||||
fn, context=context, affects=affects, websocket=websocket, auth=auth
|
||||
)
|
||||
return decorator
|
||||
|
||||
|
||||
def _normalize_affects(
|
||||
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None,
|
||||
) -> list[dict[str, str]] | None:
|
||||
"""Normalize the affects parameter into a list of target descriptors."""
|
||||
if affects is None:
|
||||
return None
|
||||
|
||||
items = affects if isinstance(affects, list) else [affects]
|
||||
result = []
|
||||
for item in items:
|
||||
if isinstance(item, str):
|
||||
result.append({"type": "context", "name": item})
|
||||
elif isinstance(item, type) and issubclass(item, ServerFunction):
|
||||
fn_meta = getattr(item, "_meta", {})
|
||||
fn_ctx = fn_meta.get("context")
|
||||
result.append({
|
||||
"type": "function",
|
||||
"name": getattr(item, "__name__", str(item)),
|
||||
"context": fn_ctx or None,
|
||||
})
|
||||
else:
|
||||
raise ValueError(
|
||||
f"affects items must be context name strings or @client function references, "
|
||||
f"got {type(item)}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def _create_server_function(
|
||||
fn: Callable,
|
||||
*,
|
||||
context: ContextMode = False,
|
||||
context: str | Literal[False] = False,
|
||||
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None,
|
||||
websocket: bool = False,
|
||||
auth: bool | str | None = None,
|
||||
) -> type[ServerFunction]:
|
||||
@@ -371,10 +411,15 @@ def _create_server_function(
|
||||
# Build metadata
|
||||
meta = {}
|
||||
|
||||
# Context mode: 'global' or 'local' (False means not a context)
|
||||
# Context name (any non-empty string)
|
||||
if context:
|
||||
meta["context"] = context
|
||||
|
||||
# Affects: mutation invalidation targets
|
||||
normalized_affects = _normalize_affects(affects)
|
||||
if normalized_affects:
|
||||
meta["affects"] = normalized_affects
|
||||
|
||||
# WebSocket: enable WebSocket transport
|
||||
if websocket:
|
||||
meta["websocket"] = True
|
||||
@@ -464,7 +509,7 @@ def _is_context_enabled(item) -> bool:
|
||||
return True
|
||||
if isinstance(item, type) and issubclass(item, ServerFunction):
|
||||
meta = getattr(item, "_meta", {})
|
||||
return meta.get("context") in ("global", "local")
|
||||
return bool(meta.get("context"))
|
||||
return False
|
||||
|
||||
|
||||
@@ -477,7 +522,7 @@ def compose(
|
||||
Compose multiple contexts into a single provider.
|
||||
|
||||
Args:
|
||||
*children: Context functions (@client with context='global'|'local')
|
||||
*children: Context functions (@client with a context name)
|
||||
or other @compose functions. All must be unique after flattening.
|
||||
|
||||
on_server: Bundle all calls into a single server request (default: False).
|
||||
@@ -529,7 +574,7 @@ def compose(
|
||||
)
|
||||
raise ValueError(
|
||||
f"@compose argument {i} ({child_name}) is not context-enabled. "
|
||||
f"All children must have @client(context='global'|'local') or be @compose."
|
||||
f"All children must have @client(context=...) or be @compose."
|
||||
)
|
||||
|
||||
# Flatten to collect all leaves
|
||||
|
||||
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
|
||||
from django import forms
|
||||
from ninja import NinjaAPI
|
||||
|
||||
from mizan.setup.registry import get_registry, get_schema
|
||||
from mizan.setup.registry import get_registry, get_schema, get_context_groups, get_function
|
||||
|
||||
|
||||
__all__ = ["get_schema", "generate_openapi_schema", "generate_openapi_json"]
|
||||
@@ -271,6 +271,10 @@ def generate_openapi_schema() -> dict[str, Any]:
|
||||
"formRole": meta.get("form_role"), # "schema", "validate", "submit"
|
||||
}
|
||||
|
||||
# Affects metadata (mutation invalidation)
|
||||
if meta.get("affects"):
|
||||
fn_meta_entry["affects"] = meta["affects"]
|
||||
|
||||
# For form schema functions, extract field definitions for Zod generation
|
||||
if meta.get("form") and meta.get("form_role") == "schema":
|
||||
form_class = meta.get("form_class")
|
||||
@@ -290,6 +294,46 @@ def generate_openapi_schema() -> dict[str, Any]:
|
||||
# Add custom extension with function metadata for provider generation
|
||||
schema["x-mizan-functions"] = function_metadata
|
||||
|
||||
# Add x-mizan-contexts: grouped context metadata with param elevation
|
||||
context_groups = get_context_groups()
|
||||
if context_groups:
|
||||
contexts_meta: dict[str, Any] = {}
|
||||
for ctx_name, fn_names in context_groups.items():
|
||||
# Analyze params across all functions in the context
|
||||
param_info: dict[str, dict[str, Any]] = {}
|
||||
for fn_name in fn_names:
|
||||
fn_cls = get_function(fn_name)
|
||||
if fn_cls is None:
|
||||
continue
|
||||
input_cls = getattr(fn_cls, "Input", None)
|
||||
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
|
||||
for field_name, field_info in input_cls.model_fields.items():
|
||||
if field_name not in param_info:
|
||||
annotation = field_info.annotation
|
||||
# Map Python types to JSON schema types
|
||||
type_name = "string"
|
||||
if annotation in (int,):
|
||||
type_name = "integer"
|
||||
elif annotation in (float,):
|
||||
type_name = "number"
|
||||
elif annotation in (bool,):
|
||||
type_name = "boolean"
|
||||
param_info[field_name] = {
|
||||
"type": type_name,
|
||||
"sharedBy": [],
|
||||
}
|
||||
param_info[field_name]["sharedBy"].append(fn_name)
|
||||
|
||||
# A param is required if ALL functions in the context declare it
|
||||
for p_name, p_meta in param_info.items():
|
||||
p_meta["required"] = len(p_meta["sharedBy"]) == len(fn_names)
|
||||
|
||||
contexts_meta[ctx_name] = {
|
||||
"functions": fn_names,
|
||||
"params": param_info,
|
||||
}
|
||||
schema["x-mizan-contexts"] = contexts_meta
|
||||
|
||||
# Add x-mizan metadata to each operation
|
||||
for fn_meta in function_metadata:
|
||||
path = f"/mizan/{fn_meta['name']}"
|
||||
|
||||
@@ -25,6 +25,7 @@ from .registry import (
|
||||
get_registry,
|
||||
get_schema,
|
||||
get_contexts,
|
||||
get_context_groups,
|
||||
get_forms,
|
||||
clear_registry,
|
||||
)
|
||||
@@ -57,6 +58,7 @@ __all__ = [
|
||||
"get_registry",
|
||||
"get_schema",
|
||||
"get_contexts",
|
||||
"get_context_groups",
|
||||
"get_forms",
|
||||
"clear_registry",
|
||||
# Discovery
|
||||
|
||||
@@ -290,6 +290,21 @@ def get_contexts() -> dict[str, type["ServerFunction"]]:
|
||||
return contexts
|
||||
|
||||
|
||||
def get_context_groups() -> dict[str, list[str]]:
|
||||
"""
|
||||
Group function names by their context string.
|
||||
|
||||
Returns:
|
||||
{"global": ["current_user"], "user": ["user_profile", "user_orders"]}
|
||||
"""
|
||||
groups: dict[str, list[str]] = {}
|
||||
for name, cls in _functions.items():
|
||||
ctx = getattr(cls, "_meta", {}).get("context")
|
||||
if ctx:
|
||||
groups.setdefault(ctx, []).append(name)
|
||||
return groups
|
||||
|
||||
|
||||
def get_forms() -> dict[str, list[type["ServerFunction"]]]:
|
||||
"""
|
||||
Get all server functions that are form-related, grouped by form name.
|
||||
|
||||
@@ -15,6 +15,7 @@ from mizan.client.executor import (
|
||||
FunctionError,
|
||||
FunctionResult,
|
||||
execute_function,
|
||||
execute_context,
|
||||
)
|
||||
from mizan.setup.registry import (
|
||||
clear_registry,
|
||||
@@ -514,29 +515,78 @@ class ContextTests(TestCase):
|
||||
self.assertEqual(fn._meta.get("context"), "global")
|
||||
|
||||
def test_context_local(self):
|
||||
"""Test @client(context='local') creates a local context."""
|
||||
"""Test @client(context='local') still works with deprecation warning."""
|
||||
import warnings
|
||||
|
||||
class CtxOutput(BaseModel):
|
||||
data: str
|
||||
|
||||
@client(context="local")
|
||||
def local_context(request: HttpRequest, user_id: int) -> CtxOutput:
|
||||
return CtxOutput(data=f"user_{user_id}")
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
@client(context="local")
|
||||
def local_context(request: HttpRequest, user_id: int) -> CtxOutput:
|
||||
return CtxOutput(data=f"user_{user_id}")
|
||||
|
||||
self.assertEqual(len(w), 1)
|
||||
self.assertIn("deprecated", str(w[0].message).lower())
|
||||
|
||||
register(local_context, "local_context")
|
||||
|
||||
fn = get_function("local_context")
|
||||
self.assertEqual(fn._meta.get("context"), "local")
|
||||
|
||||
def test_context_invalid_value_raises(self):
|
||||
"""Test that invalid context values raise ValueError."""
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
def test_context_named(self):
|
||||
"""Test @client(context='user') creates a named context."""
|
||||
|
||||
@client(context="invalid")
|
||||
class CtxOutput(BaseModel):
|
||||
data: str
|
||||
|
||||
@client(context="user")
|
||||
def user_profile(request: HttpRequest, user_id: int) -> CtxOutput:
|
||||
return CtxOutput(data=f"user_{user_id}")
|
||||
|
||||
register(user_profile, "user_profile")
|
||||
|
||||
fn = get_function("user_profile")
|
||||
self.assertEqual(fn._meta.get("context"), "user")
|
||||
|
||||
def test_context_empty_string_raises(self):
|
||||
"""Test that empty context string raises ValueError."""
|
||||
with self.assertRaises(ValueError):
|
||||
|
||||
@client(context="")
|
||||
def bad_context(request: HttpRequest) -> ValidOutput:
|
||||
return ValidOutput(valid=True)
|
||||
|
||||
self.assertIn("Invalid context value", str(cm.exception))
|
||||
def test_context_groups(self):
|
||||
"""Test get_context_groups() groups functions by context name."""
|
||||
from mizan.setup.registry import get_context_groups
|
||||
|
||||
class Out(BaseModel):
|
||||
v: int
|
||||
|
||||
@client(context="user")
|
||||
def fn_a(request: HttpRequest, user_id: int) -> Out:
|
||||
return Out(v=1)
|
||||
|
||||
@client(context="user")
|
||||
def fn_b(request: HttpRequest, user_id: int) -> Out:
|
||||
return Out(v=2)
|
||||
|
||||
@client(context="global")
|
||||
def fn_c(request: HttpRequest) -> Out:
|
||||
return Out(v=3)
|
||||
|
||||
register(fn_a, "fn_a")
|
||||
register(fn_b, "fn_b")
|
||||
register(fn_c, "fn_c")
|
||||
|
||||
groups = get_context_groups()
|
||||
self.assertIn("user", groups)
|
||||
self.assertIn("global", groups)
|
||||
self.assertCountEqual(groups["user"], ["fn_a", "fn_b"])
|
||||
self.assertEqual(groups["global"], ["fn_c"])
|
||||
|
||||
def test_get_contexts(self):
|
||||
"""Test get_contexts() returns only context-marked functions."""
|
||||
@@ -568,6 +618,175 @@ class ContextTests(TestCase):
|
||||
self.assertNotIn("echo", contexts)
|
||||
|
||||
|
||||
class AffectsTests(TestCase):
|
||||
"""Tests for the affects= parameter on @client."""
|
||||
|
||||
def setUp(self):
|
||||
clear_registry()
|
||||
|
||||
def test_affects_context_string(self):
|
||||
"""Test @client(affects='user') stores context invalidation."""
|
||||
|
||||
@client(affects="user")
|
||||
def edit_profile(request: HttpRequest, name: str) -> ValidOutput:
|
||||
return ValidOutput(valid=True)
|
||||
|
||||
self.assertEqual(
|
||||
edit_profile._meta["affects"],
|
||||
[{"type": "context", "name": "user"}],
|
||||
)
|
||||
|
||||
def test_affects_function_ref(self):
|
||||
"""Test @client(affects=fn_ref) stores function invalidation."""
|
||||
|
||||
@client(context="user")
|
||||
def user_profile(request: HttpRequest, user_id: int) -> ValidOutput:
|
||||
return ValidOutput(valid=True)
|
||||
|
||||
@client(affects=user_profile)
|
||||
def edit_profile(request: HttpRequest, name: str) -> ValidOutput:
|
||||
return ValidOutput(valid=True)
|
||||
|
||||
affects = edit_profile._meta["affects"]
|
||||
self.assertEqual(len(affects), 1)
|
||||
self.assertEqual(affects[0]["type"], "function")
|
||||
self.assertEqual(affects[0]["name"], "user_profile")
|
||||
self.assertEqual(affects[0]["context"], "user")
|
||||
|
||||
def test_affects_list(self):
|
||||
"""Test @client(affects=[...]) stores multiple targets."""
|
||||
|
||||
@client(context="user")
|
||||
def user_profile(request: HttpRequest) -> ValidOutput:
|
||||
return ValidOutput(valid=True)
|
||||
|
||||
@client(affects=[user_profile, "billing"])
|
||||
def change_plan(request: HttpRequest) -> ValidOutput:
|
||||
return ValidOutput(valid=True)
|
||||
|
||||
affects = change_plan._meta["affects"]
|
||||
self.assertEqual(len(affects), 2)
|
||||
self.assertEqual(affects[0]["type"], "function")
|
||||
self.assertEqual(affects[1]["type"], "context")
|
||||
self.assertEqual(affects[1]["name"], "billing")
|
||||
|
||||
def test_affects_and_context_mutually_exclusive(self):
|
||||
"""Test that context= and affects= cannot both be set."""
|
||||
with self.assertRaises(ValueError) as cm:
|
||||
|
||||
@client(context="user", affects="cart")
|
||||
def bad(request: HttpRequest) -> ValidOutput:
|
||||
return ValidOutput(valid=True)
|
||||
|
||||
self.assertIn("mutually exclusive", str(cm.exception))
|
||||
|
||||
def test_affects_none_not_stored(self):
|
||||
"""Test that affects=None leaves no affects in meta."""
|
||||
|
||||
@client
|
||||
def plain(request: HttpRequest) -> ValidOutput:
|
||||
return ValidOutput(valid=True)
|
||||
|
||||
self.assertNotIn("affects", plain._meta)
|
||||
|
||||
|
||||
class ContextFetchTests(TestCase):
|
||||
"""Tests for the bundled context fetch endpoint (execute_context)."""
|
||||
|
||||
def setUp(self):
|
||||
clear_registry()
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def tearDown(self):
|
||||
clear_registry()
|
||||
|
||||
def test_bundled_fetch(self):
|
||||
"""Test that execute_context bundles results from all context functions."""
|
||||
|
||||
class ProfileOut(BaseModel):
|
||||
name: str
|
||||
|
||||
class OrdersOut(BaseModel):
|
||||
count: int
|
||||
|
||||
@client(context="user")
|
||||
def user_profile(request: HttpRequest, user_id: int) -> ProfileOut:
|
||||
return ProfileOut(name=f"user_{user_id}")
|
||||
|
||||
@client(context="user")
|
||||
def user_orders(request: HttpRequest, user_id: int) -> OrdersOut:
|
||||
return OrdersOut(count=user_id * 10)
|
||||
|
||||
register(user_profile, "user_profile")
|
||||
register(user_orders, "user_orders")
|
||||
|
||||
request = self.factory.get("/api/mizan/ctx/user/?user_id=5")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = execute_context(request, "user", {"user_id": "5"})
|
||||
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertIn("user_profile", result.data)
|
||||
self.assertIn("user_orders", result.data)
|
||||
self.assertEqual(result.data["user_profile"]["name"], "user_5")
|
||||
self.assertEqual(result.data["user_orders"]["count"], 50)
|
||||
|
||||
def test_unknown_context_404(self):
|
||||
"""Test that fetching an unknown context returns NOT_FOUND."""
|
||||
request = self.factory.get("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = execute_context(request, "nonexistent", {})
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.NOT_FOUND)
|
||||
|
||||
def test_auth_failure_propagates(self):
|
||||
"""Test that if one function requires auth, the entire context fails."""
|
||||
|
||||
@client(context="admin", auth=True)
|
||||
def admin_stats(request: HttpRequest) -> ValidOutput:
|
||||
return ValidOutput(valid=True)
|
||||
|
||||
register(admin_stats, "admin_stats")
|
||||
|
||||
request = self.factory.get("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = execute_context(request, "admin", {})
|
||||
self.assertIsInstance(result, FunctionError)
|
||||
self.assertEqual(result.code, ErrorCode.UNAUTHORIZED)
|
||||
|
||||
def test_param_filtering(self):
|
||||
"""Test that each function only receives params it declares."""
|
||||
|
||||
class AOut(BaseModel):
|
||||
uid: int
|
||||
|
||||
class BOut(BaseModel):
|
||||
uid: int
|
||||
page: int
|
||||
|
||||
@client(context="mixed")
|
||||
def fn_a(request: HttpRequest, user_id: int) -> AOut:
|
||||
return AOut(uid=user_id)
|
||||
|
||||
@client(context="mixed")
|
||||
def fn_b(request: HttpRequest, user_id: int, page: int = 1) -> BOut:
|
||||
return BOut(uid=user_id, page=page)
|
||||
|
||||
register(fn_a, "fn_a")
|
||||
register(fn_b, "fn_b")
|
||||
|
||||
request = self.factory.get("/")
|
||||
request.user = AnonymousUser()
|
||||
|
||||
result = execute_context(request, "mixed", {"user_id": "7", "page": "3"})
|
||||
self.assertIsInstance(result, FunctionResult)
|
||||
self.assertEqual(result.data["fn_a"]["uid"], 7)
|
||||
self.assertEqual(result.data["fn_b"]["uid"], 7)
|
||||
self.assertEqual(result.data["fn_b"]["page"], 3)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Channel Tests
|
||||
# =============================================================================
|
||||
|
||||
@@ -1,9 +1,10 @@
|
||||
"""
|
||||
mizan URL Configuration
|
||||
|
||||
Single integration point for all mizan HTTP endpoints:
|
||||
- GET /session/ - Initialize session and get CSRF token (for SSR)
|
||||
- POST /call/ - Server function calls (HTTP transport)
|
||||
HTTP endpoints:
|
||||
- GET /session/ - Initialize session and get CSRF token (for SSR)
|
||||
- POST /call/ - Server function calls (HTTP transport)
|
||||
- GET /ctx/<name>/ - Bundled context fetch (all functions in a named context)
|
||||
|
||||
Security:
|
||||
- Schema export is NOT exposed over HTTP to prevent API enumeration
|
||||
@@ -15,7 +16,7 @@ from django.middleware.csrf import get_token
|
||||
from django.urls import path
|
||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||
|
||||
from .client.executor import function_call_view
|
||||
from .client.executor import function_call_view, context_fetch_view
|
||||
|
||||
app_name = "mizan"
|
||||
|
||||
@@ -37,4 +38,5 @@ def session_init_view(request):
|
||||
urlpatterns = [
|
||||
path("session/", session_init_view, name="session-init"),
|
||||
path("call/", function_call_view, name="function-call"),
|
||||
path("ctx/<str:context_name>/", context_fetch_view, name="context-fetch"),
|
||||
]
|
||||
|
||||
@@ -137,6 +137,34 @@ export interface MizanContextValue {
|
||||
* (e.g., calling a server function immediately on mount).
|
||||
*/
|
||||
whenReady: Promise<void>
|
||||
|
||||
/**
|
||||
* Invalidate a named context, triggering a refetch.
|
||||
* Only refetches if the context is currently mounted (has a registered provider).
|
||||
* No-op if the context is not mounted.
|
||||
*/
|
||||
invalidateContext: (name: string) => Promise<void>
|
||||
|
||||
/**
|
||||
* Invalidate specific functions within their contexts.
|
||||
* Groups by context and calls invalidateContext per group.
|
||||
*/
|
||||
invalidateFunctions: (names: string[]) => Promise<void>
|
||||
|
||||
/**
|
||||
* Register a named context provider for invalidation support.
|
||||
* Called by generated context providers on mount.
|
||||
* Returns an unregister function (call on unmount).
|
||||
*/
|
||||
registerContextProvider: (
|
||||
name: string,
|
||||
refetch: () => Promise<void>,
|
||||
) => () => void
|
||||
|
||||
/**
|
||||
* Base URL for HTTP calls (for use by generated context providers).
|
||||
*/
|
||||
baseUrl: string
|
||||
}
|
||||
|
||||
export interface MizanProviderProps {
|
||||
@@ -466,6 +494,51 @@ export function MizanProvider({
|
||||
|
||||
const isRPCAvailable = status === 'connected'
|
||||
|
||||
// Named context provider registry for invalidation
|
||||
const contextProvidersRef = useRef<Map<string, { refetch: () => Promise<void> }>>(new Map())
|
||||
|
||||
const registerContextProvider = useCallback(
|
||||
(name: string, refetch: () => Promise<void>): (() => void) => {
|
||||
contextProvidersRef.current.set(name, { refetch })
|
||||
return () => {
|
||||
contextProvidersRef.current.delete(name)
|
||||
}
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
const invalidateContext = useCallback(
|
||||
async (name: string): Promise<void> => {
|
||||
const provider = contextProvidersRef.current.get(name)
|
||||
if (provider) {
|
||||
await provider.refetch()
|
||||
}
|
||||
// If not mounted, no-op — no wasted request
|
||||
},
|
||||
[]
|
||||
)
|
||||
|
||||
const invalidateFunctions = useCallback(
|
||||
async (names: string[]): Promise<void> => {
|
||||
// Each function belongs to a context. Invalidating a function
|
||||
// means refetching its entire context (since the bundling endpoint
|
||||
// returns all functions). Dedupe by context name.
|
||||
const contexts = new Set<string>()
|
||||
for (const name of names) {
|
||||
// The context name for each function is known at codegen time
|
||||
// and baked into the generated hook. Here we just invalidate
|
||||
// whatever contexts are registered that contain these functions.
|
||||
for (const [ctxName] of contextProvidersRef.current) {
|
||||
contexts.add(ctxName)
|
||||
}
|
||||
}
|
||||
await Promise.all(
|
||||
Array.from(contexts).map(ctx => invalidateContext(ctx))
|
||||
)
|
||||
},
|
||||
[invalidateContext]
|
||||
)
|
||||
|
||||
const value = useMemo<MizanContextValue>(
|
||||
() => ({
|
||||
call,
|
||||
@@ -477,8 +550,12 @@ export function MizanProvider({
|
||||
onPush,
|
||||
onContextChange,
|
||||
whenReady: sessionRef.current!.promise,
|
||||
invalidateContext,
|
||||
invalidateFunctions,
|
||||
registerContextProvider,
|
||||
baseUrl,
|
||||
}),
|
||||
[call, getContext, refreshContext, refreshAllContexts, status, isRPCAvailable, onPush, onContextChange]
|
||||
[call, getContext, refreshContext, refreshAllContexts, status, isRPCAvailable, onPush, onContextChange, invalidateContext, invalidateFunctions, registerContextProvider, baseUrl]
|
||||
)
|
||||
|
||||
return (
|
||||
|
||||
Reference in New Issue
Block a user