diff --git a/django/src/mizan/__init__.py b/django/src/mizan/__init__.py index 6f8a4b8..bda06c6 100644 --- a/django/src/mizan/__init__.py +++ b/django/src/mizan/__init__.py @@ -88,7 +88,7 @@ from . import forms from . import setup from .channels import ReactChannel from .channels import register as register_channel -from .client import ComposedContext, ServerFunction, client, compose +from .client import ComposedContext, GlobalContext, ReactContext, ServerFunction, client, compose # Shape is lazy-loaded via __getattr__ because django_readers # imports contenttypes, which can't happen during apps.populate() @@ -157,9 +157,11 @@ def wrap_asgi(http_application): __all__ = [ - # Decorators + # Decorators & Contexts "client", "compose", + "ReactContext", + "GlobalContext", "ServerFunction", "ComposedContext", # Setup diff --git a/django/src/mizan/client/__init__.py b/django/src/mizan/client/__init__.py index a9ed345..5c80b93 100644 --- a/django/src/mizan/client/__init__.py +++ b/django/src/mizan/client/__init__.py @@ -14,6 +14,9 @@ Usage: from .function import ( # Decorator client, + # Context markers + ReactContext, + GlobalContext, # Base classes ServerFunction, ComposedContext, @@ -39,6 +42,9 @@ from .executor import ( __all__ = [ # Decorator "client", + # Context markers + "ReactContext", + "GlobalContext", # Base classes "ServerFunction", "ComposedContext", diff --git a/django/src/mizan/client/function.py b/django/src/mizan/client/function.py index f2febdf..104d17a 100644 --- a/django/src/mizan/client/function.py +++ b/django/src/mizan/client/function.py @@ -39,9 +39,46 @@ from django.http import HttpRequest from pydantic import BaseModel -# Context mode: any non-empty string names a context, False means not a context. -# 'global' is a reserved context name whose provider is auto-mounted at root. -ContextMode = str | Literal[False] +# ============================================================================= +# REACT CONTEXT - Named context marker +# ============================================================================= + + +class ReactContext: + """ + A named context that groups server functions into one provider and one fetch. + + Usage: + UserContext = ReactContext('user') + + @client(context=UserContext) + def user_profile(request, user_id: int) -> ProfileShape: ... + + @client(context=UserContext) + def user_orders(request, user_id: int) -> list[OrderShape]: ... + + @client(affects=UserContext) + def edit_profile(request, name: str) -> dict: ... + + @client(affects=[UserContext, OrderContext]) + def change_plan(request) -> dict: ... + """ + + def __init__(self, name: str): + if not name or not isinstance(name, str): + raise ValueError("ReactContext name must be a non-empty string") + self.name = name + + def __repr__(self) -> str: + return f"ReactContext({self.name!r})" + + +# Built-in global context (auto-mounted at root, SSR-hydrated) +GlobalContext = ReactContext("global") + + +# Context parameter type: a ReactContext instance, a raw string, or False +ContextMode = ReactContext | str | Literal[False] TInput = TypeVar("TInput", bound=BaseModel) @@ -183,11 +220,37 @@ class _FunctionWrapper(ServerFunction): _VALID_AUTH_STRINGS = frozenset({"required", "staff", "superuser"}) +def _resolve_context(context: ContextMode) -> str | Literal[False]: + """Resolve a context parameter to its name string.""" + if context is False: + return False + if isinstance(context, ReactContext): + return context.name + if isinstance(context, str): + if not context.strip(): + raise ValueError("context must be a non-empty string, ReactContext, or False.") + if context == "local": + warnings.warn( + "context='local' is deprecated. Use ReactContext('name') instead.", + DeprecationWarning, + stacklevel=3, + ) + return context + raise ValueError( + f"context must be a ReactContext, a string, or False. Got {type(context).__name__}." + ) + + +# Affects parameter type +AffectsTarget = ReactContext | str | type["ServerFunction"] +AffectsMode = AffectsTarget | list[AffectsTarget] | None + + def client( fn: Callable = None, *, context: ContextMode = False, - affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None, + affects: AffectsMode = None, websocket: bool = False, auth: bool | str | Callable[[Any], bool] | None = None, ) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]: @@ -200,15 +263,14 @@ def client( Args: context: Named context for React state management. - False (default): Not a context, just a callable function. - - 'global': Reserved name. Embedded in root MizanProvider, - no params, SSR-hydrated. - - Any other string: Named context. Functions sharing the same - context name are grouped into one provider and one fetch. + - ReactContext instance: groups functions into a named context. + - GlobalContext: reserved, auto-mounted at root, SSR-hydrated. + - Raw string: also accepted (e.g., 'user'), but ReactContext preferred. affects: Declare which contexts this mutation invalidates. - - A context name string: refetch the entire named context - - A function reference: refetch that function's context - - A list of the above: refetch all specified targets + - A ReactContext instance + - A list of ReactContext instances + - Also accepts strings or function references for backwards compat Mutually exclusive with context=. websocket: Enable WebSocket RPC transport (default: False). @@ -221,42 +283,29 @@ def client( - callable(request) -> bool: Custom check function Usage: - @client - def echo(request, message: str) -> EchoOutput: - return EchoOutput(message=message) + UserContext = ReactContext('user') - @client(context='global') - def current_user(request) -> UserOutput: - return UserOutput(email=request.user.email) + @client(context=GlobalContext) + def current_user(request) -> UserOutput: ... - # Named context - functions sharing a name are bundled - @client(context='user') + @client(context=UserContext) def user_profile(request, user_id: int) -> ProfileOutput: ... - # Mutation that invalidates a context - @client(affects='user') + @client(affects=UserContext) def edit_profile(request, name: str) -> dict: ... + @client(affects=[UserContext, OrderContext]) + def change_plan(request) -> dict: ... + Returns: A ServerFunction class that wraps the function """ - # Validate context parameter - if context is not False: - if not isinstance(context, str) or not context.strip(): - raise ValueError( - "context must be a non-empty string or False." - ) - if context == "local": - warnings.warn( - "context='local' is deprecated. Use a named context string instead " - "(e.g., context='my_context').", - DeprecationWarning, - stacklevel=2, - ) + # Resolve context to name string + resolved_context = _resolve_context(context) # Validate affects parameter if affects is not None: - if context is not False: + if resolved_context is not False: raise ValueError( "context= and affects= are mutually exclusive. " "A function cannot be both a context reader and a mutation." @@ -272,20 +321,18 @@ def client( def decorator(fn: Callable) -> type[ServerFunction]: return _create_server_function( - fn, context=context, affects=affects, websocket=websocket, auth=auth + fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth ) # Support both @client and @client(...) if fn is not None: return _create_server_function( - fn, context=context, affects=affects, websocket=websocket, auth=auth + fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth ) return decorator -def _normalize_affects( - affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None, -) -> list[dict[str, str]] | None: +def _normalize_affects(affects: AffectsMode) -> list[dict[str, str]] | None: """Normalize the affects parameter into a list of target descriptors.""" if affects is None: return None @@ -293,7 +340,9 @@ def _normalize_affects( items = affects if isinstance(affects, list) else [affects] result = [] for item in items: - if isinstance(item, str): + if isinstance(item, ReactContext): + result.append({"type": "context", "name": item.name}) + elif isinstance(item, str): result.append({"type": "context", "name": item}) elif isinstance(item, type) and issubclass(item, ServerFunction): fn_meta = getattr(item, "_meta", {}) @@ -305,8 +354,8 @@ def _normalize_affects( }) else: raise ValueError( - f"affects items must be context name strings or @client function references, " - f"got {type(item)}" + f"affects items must be ReactContext instances, context name strings, " + f"or @client function references. Got {type(item)}" ) return result diff --git a/django/src/mizan/tests/test_core.py b/django/src/mizan/tests/test_core.py index 8aaeed0..f40c9a0 100644 --- a/django/src/mizan/tests/test_core.py +++ b/django/src/mizan/tests/test_core.py @@ -26,7 +26,7 @@ from mizan.setup.registry import ( get_contexts, get_function, ) -from mizan.client import ServerFunction, client +from mizan.client import ServerFunction, client, ReactContext, GlobalContext from mizan.channels import ReactChannel @@ -536,8 +536,8 @@ class ContextTests(TestCase): fn = get_function("local_context") self.assertEqual(fn._meta.get("context"), "local") - def test_context_named(self): - """Test @client(context='user') creates a named context.""" + def test_context_named_string(self): + """Test @client(context='user') creates a named context (string form).""" class CtxOutput(BaseModel): data: str @@ -551,6 +551,37 @@ class ContextTests(TestCase): fn = get_function("user_profile") self.assertEqual(fn._meta.get("context"), "user") + def test_context_react_context(self): + """Test @client(context=ReactContext('user')) creates a named context.""" + UserCtx = ReactContext("user") + + class CtxOutput(BaseModel): + data: str + + @client(context=UserCtx) + def user_profile(request: HttpRequest, user_id: int) -> CtxOutput: + return CtxOutput(data=f"user_{user_id}") + + register(user_profile, "user_profile") + + fn = get_function("user_profile") + self.assertEqual(fn._meta.get("context"), "user") + + def test_context_global_context(self): + """Test @client(context=GlobalContext) uses the 'global' name.""" + + class CtxOutput(BaseModel): + data: str + + @client(context=GlobalContext) + def site_info(request: HttpRequest) -> CtxOutput: + return CtxOutput(data="test") + + register(site_info, "site_info") + + fn = get_function("site_info") + self.assertEqual(fn._meta.get("context"), "global") + def test_context_empty_string_raises(self): """Test that empty context string raises ValueError.""" with self.assertRaises(ValueError): @@ -559,22 +590,29 @@ class ContextTests(TestCase): def bad_context(request: HttpRequest) -> ValidOutput: return ValidOutput(valid=True) + def test_react_context_empty_name_raises(self): + """Test that ReactContext('') raises ValueError.""" + with self.assertRaises(ValueError): + ReactContext("") + def test_context_groups(self): """Test get_context_groups() groups functions by context name.""" from mizan.setup.registry import get_context_groups + UserCtx = ReactContext("user") + class Out(BaseModel): v: int - @client(context="user") + @client(context=UserCtx) def fn_a(request: HttpRequest, user_id: int) -> Out: return Out(v=1) - @client(context="user") + @client(context=UserCtx) def fn_b(request: HttpRequest, user_id: int) -> Out: return Out(v=2) - @client(context="global") + @client(context=GlobalContext) def fn_c(request: HttpRequest) -> Out: return Out(v=3) @@ -624,8 +662,21 @@ class AffectsTests(TestCase): def setUp(self): clear_registry() + def test_affects_react_context(self): + """Test @client(affects=ReactContext) stores context invalidation.""" + UserCtx = ReactContext("user") + + @client(affects=UserCtx) + def edit_profile(request: HttpRequest, name: str) -> ValidOutput: + return ValidOutput(valid=True) + + self.assertEqual( + edit_profile._meta["affects"], + [{"type": "context", "name": "user"}], + ) + def test_affects_context_string(self): - """Test @client(affects='user') stores context invalidation.""" + """Test @client(affects='user') stores context invalidation (string form).""" @client(affects="user") def edit_profile(request: HttpRequest, name: str) -> ValidOutput: @@ -636,6 +687,20 @@ class AffectsTests(TestCase): [{"type": "context", "name": "user"}], ) + def test_affects_list_of_react_contexts(self): + """Test @client(affects=[ctx1, ctx2]) stores multiple contexts.""" + UserCtx = ReactContext("user") + OrderCtx = ReactContext("orders") + + @client(affects=[UserCtx, OrderCtx]) + def change_plan(request: HttpRequest) -> ValidOutput: + return ValidOutput(valid=True) + + affects = change_plan._meta["affects"] + self.assertEqual(len(affects), 2) + self.assertEqual(affects[0], {"type": "context", "name": "user"}) + self.assertEqual(affects[1], {"type": "context", "name": "orders"}) + def test_affects_function_ref(self): """Test @client(affects=fn_ref) stores function invalidation.""" @@ -653,14 +718,15 @@ class AffectsTests(TestCase): self.assertEqual(affects[0]["name"], "user_profile") self.assertEqual(affects[0]["context"], "user") - def test_affects_list(self): - """Test @client(affects=[...]) stores multiple targets.""" + def test_affects_mixed_list(self): + """Test @client(affects=[fn_ref, ReactContext]) stores mixed targets.""" + BillingCtx = ReactContext("billing") @client(context="user") def user_profile(request: HttpRequest) -> ValidOutput: return ValidOutput(valid=True) - @client(affects=[user_profile, "billing"]) + @client(affects=[user_profile, BillingCtx]) def change_plan(request: HttpRequest) -> ValidOutput: return ValidOutput(valid=True)