From a726fd6863d0b4a3db626cc57fe9c231896aa8b9 Mon Sep 17 00:00:00 2001 From: Ryth Azhur Date: Tue, 31 Mar 2026 01:53:20 -0400 Subject: [PATCH] Add shapes module: Pydantic API surface for Django models Imported from separate development branch. Provides Shape, Diff, and NestedDiff classes for defining typed Pydantic schemas backed by Django model querysets via django-readers. Optional dependency: install with djarea[shapes] to get django-readers. Import is guarded so the rest of djarea works without it. Co-Authored-By: Claude Opus 4.6 (1M context) --- django/pyproject.toml | 3 + django/src/djarea/__init__.py | 6 + django/src/djarea/shapes/__init__.py | 3 + django/src/djarea/shapes/core.py | 168 +++++++++++++++++++++++++++ 4 files changed, 180 insertions(+) create mode 100644 django/src/djarea/shapes/__init__.py create mode 100644 django/src/djarea/shapes/core.py diff --git a/django/pyproject.toml b/django/pyproject.toml index e7dc351..dc6a579 100644 --- a/django/pyproject.toml +++ b/django/pyproject.toml @@ -22,6 +22,9 @@ allauth = [ webauthn = [ "fido2>=2.0", ] +shapes = [ + "django-readers>=2.0", +] dev = [ "pytest>=8.0", "pytest-django>=4.9", diff --git a/django/src/djarea/__init__.py b/django/src/djarea/__init__.py index 4096e7a..dbbe155 100644 --- a/django/src/djarea/__init__.py +++ b/django/src/djarea/__init__.py @@ -89,6 +89,10 @@ from . import setup from .channels import ReactChannel from .channels import register as register_channel from .client import ComposedContext, ServerFunction, client, compose +try: + from .shapes import Shape +except ImportError: + pass # django-readers not installed from .setup import ( djarea_clients, djarea_module, @@ -167,6 +171,8 @@ __all__ = [ # Channels "ReactChannel", "register_channel", + # Shapes + "Shape", # Submodules "client_module", "setup", diff --git a/django/src/djarea/shapes/__init__.py b/django/src/djarea/shapes/__init__.py new file mode 100644 index 0000000..2957d08 --- /dev/null +++ b/django/src/djarea/shapes/__init__.py @@ -0,0 +1,3 @@ +from djarea.shapes.core import Diff, NestedDiff, Shape + +__all__ = ["Diff", "NestedDiff", "Shape"] diff --git a/django/src/djarea/shapes/core.py b/django/src/djarea/shapes/core.py new file mode 100644 index 0000000..58a50b2 --- /dev/null +++ b/django/src/djarea/shapes/core.py @@ -0,0 +1,168 @@ +from __future__ import annotations + +from typing import Any, ClassVar, Generic, TypeVar, get_type_hints + +from pydantic import BaseModel + +from django_readers import pairs, specs +from django_readers import qs as readers_qs + +_M = TypeVar("_M") +_S = TypeVar("_S", bound="Shape") + + +def _extract_shape_class(hint) -> type[Shape] | None: + origin = getattr(hint, "__origin__", None) + args = getattr(hint, "__args__", ()) + + if origin is list and args and isinstance(args[0], type) and issubclass(args[0], Shape): + return args[0] + if isinstance(hint, type) and issubclass(hint, Shape) and hint is not Shape: + return hint + return None + + +def _resolve_model(cls) -> Any | None: + for base in cls.__bases__: + meta = getattr(base, "__pydantic_generic_metadata__", None) or {} + if meta.get("origin") is Shape and (args := meta.get("args")): + return args[0] + return None + + +class Shape(BaseModel, Generic[_M]): + _model: ClassVar[Any] + _nested: ClassVar[dict[str, type[Shape]]] + _field_names: ClassVar[list[str]] + _spec: ClassVar[list] + _pair: ClassVar[tuple] + + def __init_subclass__(cls, **kwargs): + super().__init_subclass__(**kwargs) + + if not (model := _resolve_model(cls)): + return + + cls._model = model + cls._nested = {} + + hints = get_type_hints(cls, include_extras=False) or cls.__annotations__ + field_names = [] + + for name, hint in hints.items(): + if name.startswith("_"): + continue + if shape_cls := _extract_shape_class(hint): + cls._nested[name] = shape_cls + else: + field_names.append(name) + + cls._field_names = field_names + cls._spec = [ + *field_names, + *({name: shape._spec} for name, shape in cls._nested.items()), + ] + cls._pair = specs.process(cls._spec) + + @classmethod + def _build_pair(cls, relation_qs: dict[str, Any]): + field_pairs = [ + pairs.producer_to_projector(name, pairs.field(name)) + for name in cls._field_names + ] + + rel_pairs = [] + for name, shape_cls in cls._nested.items(): + child_prepare, child_project = shape_cls._pair + prepare = ( + readers_qs.pipe(relation_qs[name], child_prepare) + if name in relation_qs + else child_prepare + ) + rel_pairs.append( + pairs.producer_to_projector( + name, pairs.relationship(name, (prepare, child_project)) + ) + ) + + return pairs.combine(*field_pairs, *rel_pairs) + + @classmethod + def query(cls: type[_S], *qs_fns, **relation_qs) -> list[_S]: + prepare, project = cls._build_pair(relation_qs) if relation_qs else cls._pair + base = cls._model.objects.all() + + # Accept a raw QuerySet as the first arg, or qs functions, or nothing + if qs_fns and hasattr(qs_fns[0], "query"): + base, qs_fns = qs_fns[0], qs_fns[1:] + + queryset = readers_qs.pipe(prepare, *qs_fns)(base) + return [cls.model_validate(project(obj)) for obj in queryset] + + def diff(self) -> Diff: + cls = type(self) + pk = getattr(self, "id", None) + if pk: + results = cls.query(cls._model.objects.filter(pk=pk)) + if not results: + raise cls._model.DoesNotExist(f"{cls._model.__name__} with id={pk} does not exist") + current = results[0] + else: + current = None + + changed = ( + { + k: getattr(self, k) + for k in cls._field_names + if k != "id" and getattr(self, k) != getattr(current, k) + } + if current + else {k: getattr(self, k) for k in cls._field_names if k != "id"} + ) + + nested = {} + for name in cls._nested: + incoming_items = getattr(self, name, None) or [] + current_items = getattr(current, name, None) or [] if current else [] + + if not isinstance(incoming_items, list): + incoming_items = [incoming_items] + if not isinstance(current_items, list): + current_items = [current_items] + + current_by_id = {c.id: c for c in current_items if c.id is not None} + incoming_by_id = {c.id: c for c in incoming_items if c.id is not None} + + nested[name] = NestedDiff( + created=[c for c in incoming_items if c.id is None], + updated=[ + c for id, c in incoming_by_id.items() + if id in current_by_id and c != current_by_id[id] + ], + deleted=[id for id in current_by_id if id not in incoming_by_id], + ) + + return Diff(is_new=current is None, changed=changed, _nested=nested) + + +class NestedDiff: + __slots__ = ("created", "updated", "deleted") + + def __init__(self, created=(), updated=(), deleted=()): + self.created = list(created) + self.updated = list(updated) + self.deleted = list(deleted) + + +class Diff: + __slots__ = ("is_new", "changed", "_nested") + + def __init__(self, is_new: bool, changed: dict[str, Any], _nested: dict[str, NestedDiff]): + self.is_new = is_new + self.changed = changed + self._nested = _nested + + def __getattr__(self, name: str) -> NestedDiff: + if name in self._nested: + return self._nested[name] + return NestedDiff()