diff --git a/packages/mizan-django/src/mizan/client/executor.py b/packages/mizan-django/src/mizan/client/executor.py index 84d96d8..f517c63 100644 --- a/packages/mizan-django/src/mizan/client/executor.py +++ b/packages/mizan-django/src/mizan/client/executor.py @@ -159,21 +159,67 @@ def _check_auth_requirement( return None +def _resolve_affects_target(target_name: str) -> tuple[str, str, str | None]: + """ + Determine whether an affects target is a context name or function name. + + Returns: + ("context", "user", None) — full context invalidation + ("function", "user_profile", "user") — function within context + """ + groups = get_context_groups() + + # Check if it's a context name directly + if target_name in groups: + return ("context", target_name, None) + + # Check if it's a function name within a context + for ctx_name, fn_names in groups.items(): + if target_name in fn_names: + return ("function", target_name, ctx_name) + + # Not a context or context function — treat as context name anyway + # (it might be a non-context function or an as-yet-unregistered context) + return ("context", target_name, None) + + +def _get_context_param_names(context_name: str) -> set[str]: + """ + Get the set of parameter names used by functions in a context. + + Returns the union of all Input field names across context functions. + """ + groups = get_context_groups() + fn_names = groups.get(context_name, []) + param_names: set[str] = set() + + for fn_name in fn_names: + fn_cls = get_function(fn_name) + if fn_cls is None: + continue + input_cls = getattr(fn_cls, "Input", None) + if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"): + param_names.update(input_cls.model_fields.keys()) + + return param_names + + def _resolve_invalidation( view_class: type | None, - request: HttpRequest | None = None, + input_data: dict[str, Any] | None = None, ) -> list[str | dict[str, Any]] | None: """ - Resolve the invalidation targets from a function's affects metadata. + Resolve invalidation targets with three-tier auto-scoping. - If affects_params is declared, calls it with the request to produce - scoped invalidation entries. + Tier 1: Argument name matching — if the mutation's input args overlap + with the context's params by name, auto-scope. + Tier 2: Auth inference — Edge-side concern, not handled here. + Tier 3: Broad fallback — invalidate all instances. - Returns a list suitable for both JSON body and header serialization: - - Simple: ["user", "notifications"] - - Scoped: [{"context": "user", "params": {"user_id": 5}}] - - Mixed: ["notifications", {"context": "user", "params": {"user_id": 5}}] + Also handles function-level targeting: affects='user_profile' resolves + to the function name (v1: runtime refetches the whole context anyway). + Returns a list suitable for both JSON body and header serialization. Returns None if no invalidation needed. """ if view_class is None: @@ -184,35 +230,41 @@ def _resolve_invalidation( if not affects: return None - # Resolve context names from affects targets - context_names = [] + result = [] + seen = set() + for target in affects: if target["type"] == "context": - context_names.append(target["name"]) + target_name = target["name"] elif target["type"] == "function" and target.get("context"): - context_names.append(target["context"]) + # Function-level: use the function name as the invalidation key + target_name = target["name"] + else: + continue - if not context_names: - return None + if target_name in seen: + continue + seen.add(target_name) - # Dedupe while preserving order - context_names = list(dict.fromkeys(context_names)) + # Resolve the context this target belongs to (for param lookup) + resolved = _resolve_affects_target(target_name) + ctx_for_params = resolved[2] if resolved[0] == "function" else resolved[1] - # If affects_params is declared, produce scoped entries - affects_params_fn = meta.get("affects_params") - if affects_params_fn and request is not None: - try: - params = affects_params_fn(request) - if params and isinstance(params, dict): - return [ - {"context": name, "params": params} - for name in context_names - ] - except Exception as e: - logger.warning(f"affects_params callable failed: {e}") - # Fall through to broad invalidation + # Tier 1: argument name matching + if input_data and ctx_for_params: + context_params = _get_context_param_names(ctx_for_params) + matched = { + k: v for k, v in input_data.items() + if k in context_params + } + if matched: + result.append({"context": target_name, "params": matched}) + continue - return context_names + # Tier 3: broad fallback + result.append(target_name) + + return result if result else None def _format_invalidate_header( @@ -568,7 +620,7 @@ def function_call_view(request: HttpRequest) -> JsonResponse: # Build response with server-driven invalidation (both transports) view_class = get_function(fn_name) response_data = {"result": result.data} - invalidate_contexts = _resolve_invalidation(view_class, request) + invalidate_contexts = _resolve_invalidation(view_class, input_data) if invalidate_contexts: response_data["invalidate"] = invalidate_contexts diff --git a/packages/mizan-django/src/mizan/client/function.py b/packages/mizan-django/src/mizan/client/function.py index f39a041..6e768dc 100644 --- a/packages/mizan-django/src/mizan/client/function.py +++ b/packages/mizan-django/src/mizan/client/function.py @@ -251,7 +251,6 @@ def client( *, context: ContextMode = False, affects: AffectsMode = None, - affects_params: Callable[[Any], dict[str, Any]] | None = None, websocket: bool = False, auth: bool | str | Callable[[Any], bool] | None = None, ) -> type[ServerFunction] | Callable[[Callable], type[ServerFunction]]: @@ -264,16 +263,14 @@ def client( - ReactContext instance: groups functions into a named context. - GlobalContext: reserved, auto-mounted at root, SSR-hydrated. - affects: Declare which contexts this mutation invalidates. - - A ReactContext instance or list of them + affects: Declare which contexts or functions this mutation invalidates. + - A ReactContext instance or context name string: invalidates entire context + - A function name string: invalidates just that function within its context + - A list of the above: invalidates multiple targets Mutually exclusive with context=. - affects_params: Callable that extracts scoped invalidation params. - Called with the request after function execution. - Returns a dict of params that scope the invalidation. - Produces: invalidate: [{context: "user", params: {user_id: 5}}] - And header: X-Mizan-Invalidate: user;user_id=5 - Requires affects= to be set. + Scoping is automatic: if the mutation's arguments overlap with the + context's params by name, the invalidation is scoped to those values. websocket: Enable WebSocket RPC transport (default: False). @@ -289,12 +286,17 @@ def client( @client(context=UserContext) def user_profile(request, user_id: int) -> ProfileOutput: ... + # Broad invalidation — all UserContext instances @client(affects=UserContext) - def edit_profile(request, name: str) -> dict: ... + def reset_all_profiles(request) -> dict: ... - # Scoped: only invalidate user context for this specific user - @client(affects=UserContext, affects_params=lambda req: {'user_id': req.user.pk}) - def update_avatar(request, url: str) -> dict: ... + # Auto-scoped — user_id matches, only invalidates UserContext(user_id=5) + @client(affects=UserContext) + def update_profile(request, user_id: int, name: str) -> dict: ... + + # Function-level — only user_profile refetches, not user_orders + @client(affects='user_profile') + def update_name(request, user_id: int, name: str) -> dict: ... Returns: A ServerFunction class that wraps the function @@ -310,12 +312,6 @@ def client( "A function cannot be both a context reader and a mutation." ) - # Validate affects_params - if affects_params is not None and affects is None: - raise ValueError( - "affects_params= requires affects= to be set." - ) - # Validate auth parameter if auth is not None: if isinstance(auth, str) and auth not in _VALID_AUTH_STRINGS: @@ -326,14 +322,14 @@ def client( def decorator(fn: Callable) -> type[ServerFunction]: return _create_server_function( - fn, context=resolved_context, affects=affects, affects_params=affects_params, + fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth, ) # Support both @client and @client(...) if fn is not None: return _create_server_function( - fn, context=resolved_context, affects=affects, affects_params=affects_params, + fn, context=resolved_context, affects=affects, websocket=websocket, auth=auth, ) return decorator @@ -372,7 +368,6 @@ def _create_server_function( *, context: str | Literal[False] = False, affects: str | type["ServerFunction"] | list[str | type["ServerFunction"]] | None = None, - affects_params: Callable | None = None, websocket: bool = False, auth: bool | str | None = None, ) -> type[ServerFunction]: @@ -476,8 +471,6 @@ def _create_server_function( normalized_affects = _normalize_affects(affects) if normalized_affects: meta["affects"] = normalized_affects - if affects_params is not None: - meta["affects_params"] = affects_params # WebSocket: enable WebSocket transport if websocket: diff --git a/packages/mizan-django/src/mizan/tests/test_core.py b/packages/mizan-django/src/mizan/tests/test_core.py index 6b24e6d..69c7e80 100644 --- a/packages/mizan-django/src/mizan/tests/test_core.py +++ b/packages/mizan-django/src/mizan/tests/test_core.py @@ -829,55 +829,112 @@ class ServerDrivenInvalidationTests(TestCase): self.assertEqual(data["invalidate"], ["user", "notifications"]) self.assertEqual(response["X-Mizan-Invalidate"], "user, notifications") - def test_scoped_invalidation_with_affects_params(self): - """affects_params produces scoped invalidation in body and header.""" + def test_auto_scoped_invalidation(self): + """Mutation args overlapping with context params auto-scope.""" from mizan.client.executor import function_call_view - from django.contrib.auth.models import User UserCtx = ReactContext("user") - @client( - affects=UserCtx, - affects_params=lambda req: {"user_id": getattr(req.user, "pk", 0)}, - auth=True, - ) - def update_avatar(request: HttpRequest, url: str) -> ValidOutput: + @client(context=UserCtx) + def user_profile(request: HttpRequest, user_id: int) -> ValidOutput: return ValidOutput(valid=True) - register(update_avatar, "update_avatar") + @client(affects=UserCtx) + def update_profile(request: HttpRequest, user_id: int, name: str) -> ValidOutput: + return ValidOutput(valid=True) + + register(user_profile, "user_profile") + register(update_profile, "update_profile") request = self.factory.post( "/api/mizan/call/", - json.dumps({"fn": "update_avatar", "args": {"url": "https://example.com/pic.jpg"}}), + json.dumps({"fn": "update_profile", "args": {"user_id": 5, "name": "Ryth"}}), content_type="application/json", ) - # Create a mock user with pk - user = MagicMock() - user.pk = 42 - user.is_authenticated = True - request.user = user + request.user = AnonymousUser() request._dont_enforce_csrf_checks = True response = function_call_view(request) data = json.loads(response.content) - # Scoped invalidation in JSON body + # Auto-scoped: user_id matched between mutation args and context params self.assertEqual(len(data["invalidate"]), 1) self.assertEqual(data["invalidate"][0]["context"], "user") - self.assertEqual(data["invalidate"][0]["params"]["user_id"], 42) + self.assertEqual(data["invalidate"][0]["params"]["user_id"], 5) + self.assertEqual(response["X-Mizan-Invalidate"], "user;user_id=5") - # Scoped invalidation in header - self.assertEqual(response["X-Mizan-Invalidate"], "user;user_id=42") + def test_broad_invalidation_no_matching_args(self): + """Mutation with no matching args falls back to broad invalidation.""" + from mizan.client.executor import function_call_view - def test_affects_params_without_affects_raises(self): - """affects_params without affects raises ValueError.""" - with self.assertRaises(ValueError) as cm: + UserCtx = ReactContext("user") - @client(affects_params=lambda req: {"user_id": 1}) - def bad(request: HttpRequest) -> ValidOutput: - return ValidOutput(valid=True) + @client(context=UserCtx) + def user_profile(request: HttpRequest, user_id: int) -> ValidOutput: + return ValidOutput(valid=True) - self.assertIn("requires affects", str(cm.exception)) + # url doesn't match any context param + @client(affects=UserCtx) + def update_avatar(request: HttpRequest, url: str) -> ValidOutput: + return ValidOutput(valid=True) + + register(user_profile, "user_profile") + register(update_avatar, "update_avatar") + + request = self.factory.post( + "/api/mizan/call/", + json.dumps({"fn": "update_avatar", "args": {"url": "pic.jpg"}}), + content_type="application/json", + ) + request.user = AnonymousUser() + request._dont_enforce_csrf_checks = True + + response = function_call_view(request) + data = json.loads(response.content) + + # Broad: no param match + self.assertEqual(data["invalidate"], ["user"]) + self.assertEqual(response["X-Mizan-Invalidate"], "user") + + def test_function_level_affects(self): + """affects='user_profile' targets a specific function, not the whole context.""" + from mizan.client.executor import function_call_view + + UserCtx = ReactContext("user") + + @client(context=UserCtx) + def user_profile(request: HttpRequest, user_id: int) -> ValidOutput: + return ValidOutput(valid=True) + + @client(context=UserCtx) + def user_orders(request: HttpRequest, user_id: int) -> ValidOutput: + return ValidOutput(valid=True) + + # Targets user_profile specifically, not the whole 'user' context + @client(affects="user_profile") + def update_name(request: HttpRequest, user_id: int, name: str) -> ValidOutput: + return ValidOutput(valid=True) + + register(user_profile, "user_profile") + register(user_orders, "user_orders") + register(update_name, "update_name") + + request = self.factory.post( + "/api/mizan/call/", + json.dumps({"fn": "update_name", "args": {"user_id": 7, "name": "Ryth"}}), + content_type="application/json", + ) + request.user = AnonymousUser() + request._dont_enforce_csrf_checks = True + + response = function_call_view(request) + data = json.loads(response.content) + + # Function-level + auto-scoped + self.assertEqual(len(data["invalidate"]), 1) + self.assertEqual(data["invalidate"][0]["context"], "user_profile") + self.assertEqual(data["invalidate"][0]["params"]["user_id"], 7) + self.assertEqual(response["X-Mizan-Invalidate"], "user_profile;user_id=7") def test_mutation_without_affects_has_no_invalidate(self): """Mutation without affects= returns result only."""