Mutation→context merge primitive across the stack

The @client(merge=[context, ...]) decorator lets a mutation patch its
return value directly into the cached context bundle by matching the
mutation's Output type against each context-function's Output type
to identify the slot, then splicing server-side. Kernel runs
splice_slot on the response to apply locally — no refetch, no
invalidate-cascade.

Lands H14, H15, H16, M19, M20 from ISSUES.md.

Backends (Django + FastAPI):
  _resolve_merges() in both executors walks @client(merge=...) targets,
  resolves the per-context slot via types_match_for_merge, and emits
  {context, slot, value, params?} entries on the response. Param
  auto-scoping mirrors _resolve_invalidation's tier-1 logic.

Frontend kernel (mizan-base):
  Response handler reads the merge[] array and applies splice_slot
  for each entry — locates the cached context bundle by name+params,
  overwrites the named slot with the new value, notifies subscribers.

Core (mizan-python):
  @client decorator extended with merge= parameter. Schema export
  threads merge metadata onto the OpenAPI x-mizan-functions entries.

Examples / fixtures:
  fastapi-react-site harness exercises merge + Playwright spec covers
  the end-to-end happy path (mutation → instant UI update without
  network refetch). AFI fixture's rename_user function is the
  canonical merge target.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
2026-05-17 18:29:06 -04:00
parent 43bcf3f26f
commit 7fb0c4a400
22 changed files with 1142 additions and 56 deletions

View File

@@ -12,9 +12,11 @@ from __future__ import annotations
from enum import Enum
from typing import Any
from fastapi.encoders import jsonable_encoder
from pydantic import BaseModel, ValidationError
from mizan_core.registry import get_function
from mizan_core.registry import get_context_groups, get_function
from mizan_core.type_utils import types_match_for_merge
# ─── Error taxonomy ─────────────────────────────────────────────────────────
@@ -148,15 +150,21 @@ def _resolve_function(fn_name: str) -> Any:
def _serialize(result: Any) -> Any:
return result.model_dump(mode="json") if isinstance(result, BaseModel) else result
# jsonable_encoder walks BaseModel / list / dict recursively, so list[BaseModel]
# (and nested shapes) come out wire-ready without a per-shape branch here.
return jsonable_encoder(result)
def execute_function(
async def execute_function(
request: Any,
fn_name: str,
input_data: dict[str, Any] | None = None,
) -> Any:
"""Dispatch a registered function. Returns the serialized result, or raises MizanError."""
"""Dispatch a registered function. Returns the serialized result, or raises MizanError.
Awaits `view.acall` — async handlers run on the loop, sync handlers run
in the default threadpool, both via the same entrypoint.
"""
view_class = _resolve_function(fn_name)
_enforce_auth(request, view_class._meta.get("auth"))
@@ -164,7 +172,7 @@ def execute_function(
validated = _validate_input(view.Input, input_data)
try:
result = view.call(validated)
result = await view.acall(validated)
except NotImplementedError as e:
raise NotImplementedYet(str(e) or "Not implemented") from e
except MizanError:
@@ -184,12 +192,70 @@ def compute_invalidation(view_class: Any, input_data: dict[str, Any] | None) ->
return [_invalidation_target(target, input_data or {}) for target in affects]
def compute_merges(view_class: Any, input_data: dict[str, Any] | None, result: Any) -> list[dict[str, Any]]:
"""Build the `merge` list from @client(merge=...) metadata.
Each entry is `{context, slot, value, params?}` where `slot` names the
function inside the context bundle the value lands in. The slot is
resolved server-side via `types_match_for_merge` so the kernel does
no shape inference — the server has the schema, type-checked routing
lives here. Entries whose slot can't be uniquely resolved are dropped
with a warning; the consumer falls back to refetch via `affects`.
"""
targets = getattr(view_class, "_meta", {}).get("merge") or []
if not targets:
return []
mutation_output = getattr(view_class, "Output", None)
out: list[dict[str, Any]] = []
for ctx_name in targets:
slot = _resolve_merge_slot(ctx_name, mutation_output)
if slot is None:
continue
entry: dict[str, Any] = {"context": ctx_name, "slot": slot, "value": result}
scoped = _scoped_params(ctx_name, input_data or {})
if scoped:
entry["params"] = scoped
out.append(entry)
return out
def _resolve_merge_slot(context_name: str, mutation_output: Any) -> str | None:
"""Find the unique function-name slot whose return type matches the mutation's output.
Returns None on no match or ambiguous match (multiple candidates).
"""
if mutation_output is None:
return None
matches: list[str] = []
for fn_name in get_context_groups().get(context_name, []):
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
fn_output = getattr(fn_cls, "Output", None)
if fn_output is not None and types_match_for_merge(fn_output, mutation_output):
matches.append(fn_name)
return matches[0] if len(matches) == 1 else None
def _scoped_params(context_name: str, input_data: dict[str, Any]) -> dict[str, Any]:
"""Match input args against the context's declared Input field names."""
fn_names = get_context_groups().get(context_name, [])
declared: set[str] = set()
for fn_name in fn_names:
fn_cls = get_function(fn_name)
if fn_cls is None:
continue
input_cls = getattr(fn_cls, "Input", None)
if input_cls and input_cls is not BaseModel and hasattr(input_cls, "model_fields"):
declared.update(input_cls.model_fields.keys())
return {k: v for k, v in input_data.items() if k in declared}
def _invalidation_target(target: dict[str, Any], input_data: dict[str, Any]) -> Any:
match target.get("type"):
case "context":
name = target["name"]
scope_keys = (target.get("params") or {}).keys()
scoped = {k: input_data[k] for k in scope_keys if k in input_data}
scoped = _scoped_params(name, input_data)
return {"context": name, "params": scoped} if scoped else name
case "function":
return {"function": target["name"]}

View File

@@ -28,6 +28,7 @@ from .executor import (
MizanError,
NotFound,
compute_invalidation,
compute_merges,
execute_function,
)
@@ -49,10 +50,15 @@ class CallBody(BaseModel):
@router.post("/call/")
async def function_call(body: CallBody, request: Request) -> JSONResponse:
"""RPC dispatch — `{"fn": "...", "args": {...}}` → `{"result": ..., "invalidate": [...]}`."""
result = execute_function(request, body.fn, body.args)
invalidate = compute_invalidation(get_function(body.fn), body.args)
return _no_store({"result": result, "invalidate": invalidate})
"""RPC dispatch — `{"fn": "...", "args": {...}}` → `{"result": ..., "invalidate": [...], "merge"?: [...]}`."""
fn_class = get_function(body.fn)
result = await execute_function(request, body.fn, body.args)
invalidate = compute_invalidation(fn_class, body.args)
merges = compute_merges(fn_class, body.args, result)
payload: dict[str, Any] = {"result": result, "invalidate": invalidate}
if merges:
payload["merge"] = merges
return _no_store(payload)
@router.get("/ctx/{context_name}/")
@@ -63,7 +69,7 @@ async def context_fetch(context_name: str, request: Request) -> JSONResponse:
raise NotFound(f"Context '{context_name}' not found")
params = dict(request.query_params)
bundled = {fn: execute_function(request, fn, params) for fn in fn_names}
bundled = {fn: await execute_function(request, fn, params) for fn in fn_names}
return _no_store(bundled)

View File

@@ -2,6 +2,8 @@
from __future__ import annotations
import asyncio
import pytest
from fastapi import FastAPI
from fastapi.exceptions import RequestValidationError
@@ -34,6 +36,11 @@ class UserOutput(BaseModel):
authenticated: bool
class ItemOutput(BaseModel):
id: int
name: str
@pytest.fixture
def app():
"""Build a fresh FastAPI app + Mizan router with a few @client functions."""
@@ -63,12 +70,39 @@ def app():
def whoami(request) -> UserOutput:
return UserOutput(email="real@example.com", authenticated=True)
@client
def list_items(request) -> list[ItemOutput]:
return [ItemOutput(id=1, name="a"), ItemOutput(id=2, name="b")]
@client
def find_item(request, item_id: int) -> ItemOutput | None:
return ItemOutput(id=item_id, name="found") if item_id > 0 else None
@client(merge="items")
def set_item_name(request, id: int, name: str) -> ItemOutput:
return ItemOutput(id=id, name=name)
@client(context="items")
def items_list(request) -> list[ItemOutput]:
return [ItemOutput(id=1, name="orig")]
@client
async def async_echo(request, text: str) -> EchoOutput:
# await something on the loop to prove we're really running async
await asyncio.sleep(0)
return EchoOutput(message=f"async: {text}")
register(echo, "echo")
register(add, "add")
register(current_user, "current_user")
register(user_count, "user_count")
register(update_email, "update_email")
register(whoami, "whoami")
register(list_items, "list_items")
register(find_item, "find_item")
register(set_item_name, "set_item_name")
register(items_list, "items_list")
register(async_echo, "async_echo")
fastapi_app = FastAPI()
fastapi_app.include_router(mizan_router, prefix="/api/mizan")
@@ -171,3 +205,58 @@ class InvalidationTests:
body = r.json()
# affects='user' is a context-name string → invalidate list contains 'user'
assert "user" in body["invalidate"]
# ─── Structured-output shapes ───────────────────────────────────────────────
class StructuredOutputTests:
"""list[BaseModel] and Optional[BaseModel] should reach the wire as bare values, not {result: ...}."""
def test_list_of_basemodel_returns_bare_array(self, http):
r = http.post("/api/mizan/call/", json={"fn": "list_items", "args": {}})
assert r.status_code == 200
body = r.json()
assert body["result"] == [
{"id": 1, "name": "a"},
{"id": 2, "name": "b"},
]
def test_optional_basemodel_returns_inner_or_none(self, http):
r_found = http.post("/api/mizan/call/", json={"fn": "find_item", "args": {"item_id": 5}})
assert r_found.status_code == 200
assert r_found.json()["result"] == {"id": 5, "name": "found"}
r_missing = http.post("/api/mizan/call/", json={"fn": "find_item", "args": {"item_id": 0}})
assert r_missing.status_code == 200
assert r_missing.json()["result"] is None
# ─── Merge protocol ─────────────────────────────────────────────────────────
class AsyncHandlerTests:
"""`async def` handlers dispatch on the loop via view.acall."""
def test_async_handler_returns_awaited_result(self, http):
r = http.post("/api/mizan/call/", json={"fn": "async_echo", "args": {"text": "hello"}})
assert r.status_code == 200
assert r.json()["result"] == {"message": "async: hello"}
class MergeTests:
"""@client(merge=...) emits a `merge` field in the response so the kernel can splice without refetch."""
def test_merge_target_emits_merge_entry(self, http):
r = http.post(
"/api/mizan/call/",
json={"fn": "set_item_name", "args": {"id": 42, "name": "renamed"}},
)
assert r.status_code == 200
body = r.json()
# Server resolves slot — items_list returns list[ItemOutput], mutation returns ItemOutput
assert body["merge"] == [
{"context": "items", "slot": "items_list", "value": {"id": 42, "name": "renamed"}}
]
# invalidate stays empty when only merge is declared
assert body["invalidate"] == []