Add ReactContext class for type-safe context and affects declarations
ReactContext('user') creates a reusable context marker that provides
proper linting, find-references, and autocomplete:
UserContext = ReactContext('user')
@client(context=UserContext)
def user_profile(request, user_id: int) -> ProfileShape: ...
@client(affects=UserContext)
def edit_profile(request, name: str) -> dict: ...
@client(affects=[UserContext, OrderContext])
def change_plan(request) -> dict: ...
- ReactContext class with name validation
- GlobalContext built-in instance for context='global'
- affects= accepts ReactContext, lists, strings, or function refs
- Backwards compat: raw strings still work for context= and affects=
- Exported from mizan and mizan.client
Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -88,7 +88,7 @@ from . import forms
|
|||||||
from . import setup
|
from . import setup
|
||||||
from .channels import ReactChannel
|
from .channels import ReactChannel
|
||||||
from .channels import register as register_channel
|
from .channels import register as register_channel
|
||||||
from .client import ComposedContext, ServerFunction, client, compose
|
from .client import ComposedContext, GlobalContext, ReactContext, ServerFunction, client, compose
|
||||||
|
|
||||||
# Shape is lazy-loaded via __getattr__ because django_readers
|
# Shape is lazy-loaded via __getattr__ because django_readers
|
||||||
# imports contenttypes, which can't happen during apps.populate()
|
# imports contenttypes, which can't happen during apps.populate()
|
||||||
@@ -157,9 +157,11 @@ def wrap_asgi(http_application):
|
|||||||
|
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
# Decorators
|
# Decorators & Contexts
|
||||||
"client",
|
"client",
|
||||||
"compose",
|
"compose",
|
||||||
|
"ReactContext",
|
||||||
|
"GlobalContext",
|
||||||
"ServerFunction",
|
"ServerFunction",
|
||||||
"ComposedContext",
|
"ComposedContext",
|
||||||
# Setup
|
# Setup
|
||||||
|
|||||||
@@ -14,6 +14,9 @@ Usage:
|
|||||||
from .function import (
|
from .function import (
|
||||||
# Decorator
|
# Decorator
|
||||||
client,
|
client,
|
||||||
|
# Context markers
|
||||||
|
ReactContext,
|
||||||
|
GlobalContext,
|
||||||
# Base classes
|
# Base classes
|
||||||
ServerFunction,
|
ServerFunction,
|
||||||
ComposedContext,
|
ComposedContext,
|
||||||
@@ -39,6 +42,9 @@ from .executor import (
|
|||||||
__all__ = [
|
__all__ = [
|
||||||
# Decorator
|
# Decorator
|
||||||
"client",
|
"client",
|
||||||
|
# Context markers
|
||||||
|
"ReactContext",
|
||||||
|
"GlobalContext",
|
||||||
# Base classes
|
# Base classes
|
||||||
"ServerFunction",
|
"ServerFunction",
|
||||||
"ComposedContext",
|
"ComposedContext",
|
||||||
|
|||||||
@@ -39,9 +39,46 @@ from django.http import HttpRequest
|
|||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
# Context mode: any non-empty string names a context, False means not a context.
|
# =============================================================================
|
||||||
# 'global' is a reserved context name whose provider is auto-mounted at root.
|
# REACT CONTEXT - Named context marker
|
||||||
ContextMode = str | Literal[False]
|
# =============================================================================
|
||||||
|
|
||||||
|
|
||||||
|
class ReactContext:
|
||||||
|
"""
|
||||||
|
A named context that groups server functions into one provider and one fetch.
|
||||||
|
|
||||||
|
Usage:
|
||||||
|
UserContext = ReactContext('user')
|
||||||
|
|
||||||
|
@client(context=UserContext)
|
||||||
|
def user_profile(request, user_id: int) -> ProfileShape: ...
|
||||||
|
|
||||||
|
@client(context=UserContext)
|
||||||
|
def user_orders(request, user_id: int) -> list[OrderShape]: ...
|
||||||
|
|
||||||
|
@client(affects=UserContext)
|
||||||
|
def edit_profile(request, name: str) -> dict: ...
|
||||||
|
|
||||||
|
@client(affects=[UserContext, OrderContext])
|
||||||
|
def change_plan(request) -> dict: ...
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, name: str):
|
||||||
|
if not name or not isinstance(name, str):
|
||||||
|
raise ValueError("ReactContext name must be a non-empty string")
|
||||||
|
self.name = name
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"ReactContext({self.name!r})"
|
||||||
|
|
||||||
|
|
||||||
|
# Built-in global context (auto-mounted at root, SSR-hydrated)
|
||||||
|
GlobalContext = ReactContext("global")
|
||||||
|
|
||||||
|
|
||||||
|
# Context parameter type: a ReactContext instance, a raw string, or False
|
||||||
|
ContextMode = ReactContext | str | Literal[False]
|
||||||
|
|
||||||
|
|
||||||
TInput = TypeVar("TInput", bound=BaseModel)
|
TInput = TypeVar("TInput", bound=BaseModel)
|
||||||
@@ -183,11 +220,37 @@ class _FunctionWrapper(ServerFunction):
|
|||||||
_VALID_AUTH_STRINGS = frozenset({"required", "staff", "superuser"})
|
_VALID_AUTH_STRINGS = frozenset({"required", "staff", "superuser"})
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_context(context: ContextMode) -> str | Literal[False]:
|
||||||
|
"""Resolve a context parameter to its name string."""
|
||||||
|
if context is False:
|
||||||
|
return False
|
||||||
|
if isinstance(context, ReactContext):
|
||||||
|
return context.name
|
||||||
|
if isinstance(context, str):
|
||||||
|
if not context.strip():
|
||||||
|
raise ValueError("context must be a non-empty string, ReactContext, or False.")
|
||||||
|
if context == "local":
|
||||||
|
warnings.warn(
|
||||||
|
"context='local' is deprecated. Use ReactContext('name') instead.",
|
||||||
|
DeprecationWarning,
|
||||||
|
stacklevel=3,
|
||||||
|
)
|
||||||
|
return context
|
||||||
|
raise ValueError(
|
||||||
|
f"context must be a ReactContext, a string, or False. Got {type(context).__name__}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
# Affects parameter type
|
||||||
|
AffectsTarget = ReactContext | str | type["ServerFunction"]
|
||||||
|
AffectsMode = AffectsTarget | list[AffectsTarget] | None
|
||||||
|
|
||||||
|
|
||||||
def client(
|
def client(
|
||||||
fn: Callable = None,
|
fn: Callable = None,
|
||||||
*,
|
*,
|
||||||
context: ContextMode = False,
|
context: ContextMode = False,
|
||||||
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None,
|
affects: AffectsMode = None,
|
||||||
websocket: bool = False,
|
websocket: bool = False,
|
||||||
auth: bool | str | Callable[[Any], bool] | None = None,
|
auth: bool | str | Callable[[Any], bool] | None = None,
|
||||||
) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]:
|
) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]:
|
||||||
@@ -200,15 +263,14 @@ def client(
|
|||||||
Args:
|
Args:
|
||||||
context: Named context for React state management.
|
context: Named context for React state management.
|
||||||
- False (default): Not a context, just a callable function.
|
- False (default): Not a context, just a callable function.
|
||||||
- 'global': Reserved name. Embedded in root MizanProvider,
|
- ReactContext instance: groups functions into a named context.
|
||||||
no params, SSR-hydrated.
|
- GlobalContext: reserved, auto-mounted at root, SSR-hydrated.
|
||||||
- Any other string: Named context. Functions sharing the same
|
- Raw string: also accepted (e.g., 'user'), but ReactContext preferred.
|
||||||
context name are grouped into one provider and one fetch.
|
|
||||||
|
|
||||||
affects: Declare which contexts this mutation invalidates.
|
affects: Declare which contexts this mutation invalidates.
|
||||||
- A context name string: refetch the entire named context
|
- A ReactContext instance
|
||||||
- A function reference: refetch that function's context
|
- A list of ReactContext instances
|
||||||
- A list of the above: refetch all specified targets
|
- Also accepts strings or function references for backwards compat
|
||||||
Mutually exclusive with context=.
|
Mutually exclusive with context=.
|
||||||
|
|
||||||
websocket: Enable WebSocket RPC transport (default: False).
|
websocket: Enable WebSocket RPC transport (default: False).
|
||||||
@@ -221,42 +283,29 @@ def client(
|
|||||||
- callable(request) -> bool: Custom check function
|
- callable(request) -> bool: Custom check function
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
@client
|
UserContext = ReactContext('user')
|
||||||
def echo(request, message: str) -> EchoOutput:
|
|
||||||
return EchoOutput(message=message)
|
|
||||||
|
|
||||||
@client(context='global')
|
@client(context=GlobalContext)
|
||||||
def current_user(request) -> UserOutput:
|
def current_user(request) -> UserOutput: ...
|
||||||
return UserOutput(email=request.user.email)
|
|
||||||
|
|
||||||
# Named context - functions sharing a name are bundled
|
@client(context=UserContext)
|
||||||
@client(context='user')
|
|
||||||
def user_profile(request, user_id: int) -> ProfileOutput: ...
|
def user_profile(request, user_id: int) -> ProfileOutput: ...
|
||||||
|
|
||||||
# Mutation that invalidates a context
|
@client(affects=UserContext)
|
||||||
@client(affects='user')
|
|
||||||
def edit_profile(request, name: str) -> dict: ...
|
def edit_profile(request, name: str) -> dict: ...
|
||||||
|
|
||||||
|
@client(affects=[UserContext, OrderContext])
|
||||||
|
def change_plan(request) -> dict: ...
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A ServerFunction class that wraps the function
|
A ServerFunction class that wraps the function
|
||||||
"""
|
"""
|
||||||
# Validate context parameter
|
# Resolve context to name string
|
||||||
if context is not False:
|
resolved_context = _resolve_context(context)
|
||||||
if not isinstance(context, str) or not context.strip():
|
|
||||||
raise ValueError(
|
|
||||||
"context must be a non-empty string or False."
|
|
||||||
)
|
|
||||||
if context == "local":
|
|
||||||
warnings.warn(
|
|
||||||
"context='local' is deprecated. Use a named context string instead "
|
|
||||||
"(e.g., context='my_context').",
|
|
||||||
DeprecationWarning,
|
|
||||||
stacklevel=2,
|
|
||||||
)
|
|
||||||
|
|
||||||
# Validate affects parameter
|
# Validate affects parameter
|
||||||
if affects is not None:
|
if affects is not None:
|
||||||
if context is not False:
|
if resolved_context is not False:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"context= and affects= are mutually exclusive. "
|
"context= and affects= are mutually exclusive. "
|
||||||
"A function cannot be both a context reader and a mutation."
|
"A function cannot be both a context reader and a mutation."
|
||||||
@@ -272,20 +321,18 @@ def client(
|
|||||||
|
|
||||||
def decorator(fn: Callable) -> type[ServerFunction]:
|
def decorator(fn: Callable) -> type[ServerFunction]:
|
||||||
return _create_server_function(
|
return _create_server_function(
|
||||||
fn, context=context, affects=affects, websocket=websocket, auth=auth
|
fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth
|
||||||
)
|
)
|
||||||
|
|
||||||
# Support both @client and @client(...)
|
# Support both @client and @client(...)
|
||||||
if fn is not None:
|
if fn is not None:
|
||||||
return _create_server_function(
|
return _create_server_function(
|
||||||
fn, context=context, affects=affects, websocket=websocket, auth=auth
|
fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth
|
||||||
)
|
)
|
||||||
return decorator
|
return decorator
|
||||||
|
|
||||||
|
|
||||||
def _normalize_affects(
|
def _normalize_affects(affects: AffectsMode) -> list[dict[str, str]] | None:
|
||||||
affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None,
|
|
||||||
) -> list[dict[str, str]] | None:
|
|
||||||
"""Normalize the affects parameter into a list of target descriptors."""
|
"""Normalize the affects parameter into a list of target descriptors."""
|
||||||
if affects is None:
|
if affects is None:
|
||||||
return None
|
return None
|
||||||
@@ -293,7 +340,9 @@ def _normalize_affects(
|
|||||||
items = affects if isinstance(affects, list) else [affects]
|
items = affects if isinstance(affects, list) else [affects]
|
||||||
result = []
|
result = []
|
||||||
for item in items:
|
for item in items:
|
||||||
if isinstance(item, str):
|
if isinstance(item, ReactContext):
|
||||||
|
result.append({"type": "context", "name": item.name})
|
||||||
|
elif isinstance(item, str):
|
||||||
result.append({"type": "context", "name": item})
|
result.append({"type": "context", "name": item})
|
||||||
elif isinstance(item, type) and issubclass(item, ServerFunction):
|
elif isinstance(item, type) and issubclass(item, ServerFunction):
|
||||||
fn_meta = getattr(item, "_meta", {})
|
fn_meta = getattr(item, "_meta", {})
|
||||||
@@ -305,8 +354,8 @@ def _normalize_affects(
|
|||||||
})
|
})
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"affects items must be context name strings or @client function references, "
|
f"affects items must be ReactContext instances, context name strings, "
|
||||||
f"got {type(item)}"
|
f"or @client function references. Got {type(item)}"
|
||||||
)
|
)
|
||||||
return result
|
return result
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from mizan.setup.registry import (
|
|||||||
get_contexts,
|
get_contexts,
|
||||||
get_function,
|
get_function,
|
||||||
)
|
)
|
||||||
from mizan.client import ServerFunction, client
|
from mizan.client import ServerFunction, client, ReactContext, GlobalContext
|
||||||
from mizan.channels import ReactChannel
|
from mizan.channels import ReactChannel
|
||||||
|
|
||||||
|
|
||||||
@@ -536,8 +536,8 @@ class ContextTests(TestCase):
|
|||||||
fn = get_function("local_context")
|
fn = get_function("local_context")
|
||||||
self.assertEqual(fn._meta.get("context"), "local")
|
self.assertEqual(fn._meta.get("context"), "local")
|
||||||
|
|
||||||
def test_context_named(self):
|
def test_context_named_string(self):
|
||||||
"""Test @client(context='user') creates a named context."""
|
"""Test @client(context='user') creates a named context (string form)."""
|
||||||
|
|
||||||
class CtxOutput(BaseModel):
|
class CtxOutput(BaseModel):
|
||||||
data: str
|
data: str
|
||||||
@@ -551,6 +551,37 @@ class ContextTests(TestCase):
|
|||||||
fn = get_function("user_profile")
|
fn = get_function("user_profile")
|
||||||
self.assertEqual(fn._meta.get("context"), "user")
|
self.assertEqual(fn._meta.get("context"), "user")
|
||||||
|
|
||||||
|
def test_context_react_context(self):
|
||||||
|
"""Test @client(context=ReactContext('user')) creates a named context."""
|
||||||
|
UserCtx = ReactContext("user")
|
||||||
|
|
||||||
|
class CtxOutput(BaseModel):
|
||||||
|
data: str
|
||||||
|
|
||||||
|
@client(context=UserCtx)
|
||||||
|
def user_profile(request: HttpRequest, user_id: int) -> CtxOutput:
|
||||||
|
return CtxOutput(data=f"user_{user_id}")
|
||||||
|
|
||||||
|
register(user_profile, "user_profile")
|
||||||
|
|
||||||
|
fn = get_function("user_profile")
|
||||||
|
self.assertEqual(fn._meta.get("context"), "user")
|
||||||
|
|
||||||
|
def test_context_global_context(self):
|
||||||
|
"""Test @client(context=GlobalContext) uses the 'global' name."""
|
||||||
|
|
||||||
|
class CtxOutput(BaseModel):
|
||||||
|
data: str
|
||||||
|
|
||||||
|
@client(context=GlobalContext)
|
||||||
|
def site_info(request: HttpRequest) -> CtxOutput:
|
||||||
|
return CtxOutput(data="test")
|
||||||
|
|
||||||
|
register(site_info, "site_info")
|
||||||
|
|
||||||
|
fn = get_function("site_info")
|
||||||
|
self.assertEqual(fn._meta.get("context"), "global")
|
||||||
|
|
||||||
def test_context_empty_string_raises(self):
|
def test_context_empty_string_raises(self):
|
||||||
"""Test that empty context string raises ValueError."""
|
"""Test that empty context string raises ValueError."""
|
||||||
with self.assertRaises(ValueError):
|
with self.assertRaises(ValueError):
|
||||||
@@ -559,22 +590,29 @@ class ContextTests(TestCase):
|
|||||||
def bad_context(request: HttpRequest) -> ValidOutput:
|
def bad_context(request: HttpRequest) -> ValidOutput:
|
||||||
return ValidOutput(valid=True)
|
return ValidOutput(valid=True)
|
||||||
|
|
||||||
|
def test_react_context_empty_name_raises(self):
|
||||||
|
"""Test that ReactContext('') raises ValueError."""
|
||||||
|
with self.assertRaises(ValueError):
|
||||||
|
ReactContext("")
|
||||||
|
|
||||||
def test_context_groups(self):
|
def test_context_groups(self):
|
||||||
"""Test get_context_groups() groups functions by context name."""
|
"""Test get_context_groups() groups functions by context name."""
|
||||||
from mizan.setup.registry import get_context_groups
|
from mizan.setup.registry import get_context_groups
|
||||||
|
|
||||||
|
UserCtx = ReactContext("user")
|
||||||
|
|
||||||
class Out(BaseModel):
|
class Out(BaseModel):
|
||||||
v: int
|
v: int
|
||||||
|
|
||||||
@client(context="user")
|
@client(context=UserCtx)
|
||||||
def fn_a(request: HttpRequest, user_id: int) -> Out:
|
def fn_a(request: HttpRequest, user_id: int) -> Out:
|
||||||
return Out(v=1)
|
return Out(v=1)
|
||||||
|
|
||||||
@client(context="user")
|
@client(context=UserCtx)
|
||||||
def fn_b(request: HttpRequest, user_id: int) -> Out:
|
def fn_b(request: HttpRequest, user_id: int) -> Out:
|
||||||
return Out(v=2)
|
return Out(v=2)
|
||||||
|
|
||||||
@client(context="global")
|
@client(context=GlobalContext)
|
||||||
def fn_c(request: HttpRequest) -> Out:
|
def fn_c(request: HttpRequest) -> Out:
|
||||||
return Out(v=3)
|
return Out(v=3)
|
||||||
|
|
||||||
@@ -624,8 +662,21 @@ class AffectsTests(TestCase):
|
|||||||
def setUp(self):
|
def setUp(self):
|
||||||
clear_registry()
|
clear_registry()
|
||||||
|
|
||||||
|
def test_affects_react_context(self):
|
||||||
|
"""Test @client(affects=ReactContext) stores context invalidation."""
|
||||||
|
UserCtx = ReactContext("user")
|
||||||
|
|
||||||
|
@client(affects=UserCtx)
|
||||||
|
def edit_profile(request: HttpRequest, name: str) -> ValidOutput:
|
||||||
|
return ValidOutput(valid=True)
|
||||||
|
|
||||||
|
self.assertEqual(
|
||||||
|
edit_profile._meta["affects"],
|
||||||
|
[{"type": "context", "name": "user"}],
|
||||||
|
)
|
||||||
|
|
||||||
def test_affects_context_string(self):
|
def test_affects_context_string(self):
|
||||||
"""Test @client(affects='user') stores context invalidation."""
|
"""Test @client(affects='user') stores context invalidation (string form)."""
|
||||||
|
|
||||||
@client(affects="user")
|
@client(affects="user")
|
||||||
def edit_profile(request: HttpRequest, name: str) -> ValidOutput:
|
def edit_profile(request: HttpRequest, name: str) -> ValidOutput:
|
||||||
@@ -636,6 +687,20 @@ class AffectsTests(TestCase):
|
|||||||
[{"type": "context", "name": "user"}],
|
[{"type": "context", "name": "user"}],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_affects_list_of_react_contexts(self):
|
||||||
|
"""Test @client(affects=[ctx1, ctx2]) stores multiple contexts."""
|
||||||
|
UserCtx = ReactContext("user")
|
||||||
|
OrderCtx = ReactContext("orders")
|
||||||
|
|
||||||
|
@client(affects=[UserCtx, OrderCtx])
|
||||||
|
def change_plan(request: HttpRequest) -> ValidOutput:
|
||||||
|
return ValidOutput(valid=True)
|
||||||
|
|
||||||
|
affects = change_plan._meta["affects"]
|
||||||
|
self.assertEqual(len(affects), 2)
|
||||||
|
self.assertEqual(affects[0], {"type": "context", "name": "user"})
|
||||||
|
self.assertEqual(affects[1], {"type": "context", "name": "orders"})
|
||||||
|
|
||||||
def test_affects_function_ref(self):
|
def test_affects_function_ref(self):
|
||||||
"""Test @client(affects=fn_ref) stores function invalidation."""
|
"""Test @client(affects=fn_ref) stores function invalidation."""
|
||||||
|
|
||||||
@@ -653,14 +718,15 @@ class AffectsTests(TestCase):
|
|||||||
self.assertEqual(affects[0]["name"], "user_profile")
|
self.assertEqual(affects[0]["name"], "user_profile")
|
||||||
self.assertEqual(affects[0]["context"], "user")
|
self.assertEqual(affects[0]["context"], "user")
|
||||||
|
|
||||||
def test_affects_list(self):
|
def test_affects_mixed_list(self):
|
||||||
"""Test @client(affects=[...]) stores multiple targets."""
|
"""Test @client(affects=[fn_ref, ReactContext]) stores mixed targets."""
|
||||||
|
BillingCtx = ReactContext("billing")
|
||||||
|
|
||||||
@client(context="user")
|
@client(context="user")
|
||||||
def user_profile(request: HttpRequest) -> ValidOutput:
|
def user_profile(request: HttpRequest) -> ValidOutput:
|
||||||
return ValidOutput(valid=True)
|
return ValidOutput(valid=True)
|
||||||
|
|
||||||
@client(affects=[user_profile, "billing"])
|
@client(affects=[user_profile, BillingCtx])
|
||||||
def change_plan(request: HttpRequest) -> ValidOutput:
|
def change_plan(request: HttpRequest) -> ValidOutput:
|
||||||
return ValidOutput(valid=True)
|
return ValidOutput(valid=True)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user