Add shapes module: Pydantic API surface for Django models
Imported from separate development branch. Provides Shape, Diff, and NestedDiff classes for defining typed Pydantic schemas backed by Django model querysets via django-readers. Optional dependency: install with djarea[shapes] to get django-readers. Import is guarded so the rest of djarea works without it. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -89,6 +89,10 @@ from . import setup
|
||||
from .channels import ReactChannel
|
||||
from .channels import register as register_channel
|
||||
from .client import ComposedContext, ServerFunction, client, compose
|
||||
try:
|
||||
from .shapes import Shape
|
||||
except ImportError:
|
||||
pass # django-readers not installed
|
||||
from .setup import (
|
||||
djarea_clients,
|
||||
djarea_module,
|
||||
@@ -167,6 +171,8 @@ __all__ = [
|
||||
# Channels
|
||||
"ReactChannel",
|
||||
"register_channel",
|
||||
# Shapes
|
||||
"Shape",
|
||||
# Submodules
|
||||
"client_module",
|
||||
"setup",
|
||||
|
||||
3
django/src/djarea/shapes/__init__.py
Normal file
3
django/src/djarea/shapes/__init__.py
Normal file
@@ -0,0 +1,3 @@
|
||||
from djarea.shapes.core import Diff, NestedDiff, Shape
|
||||
|
||||
__all__ = ["Diff", "NestedDiff", "Shape"]
|
||||
168
django/src/djarea/shapes/core.py
Normal file
168
django/src/djarea/shapes/core.py
Normal file
@@ -0,0 +1,168 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, ClassVar, Generic, TypeVar, get_type_hints
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from django_readers import pairs, specs
|
||||
from django_readers import qs as readers_qs
|
||||
|
||||
_M = TypeVar("_M")
|
||||
_S = TypeVar("_S", bound="Shape")
|
||||
|
||||
|
||||
def _extract_shape_class(hint) -> type[Shape] | None:
|
||||
origin = getattr(hint, "__origin__", None)
|
||||
args = getattr(hint, "__args__", ())
|
||||
|
||||
if origin is list and args and isinstance(args[0], type) and issubclass(args[0], Shape):
|
||||
return args[0]
|
||||
if isinstance(hint, type) and issubclass(hint, Shape) and hint is not Shape:
|
||||
return hint
|
||||
return None
|
||||
|
||||
|
||||
def _resolve_model(cls) -> Any | None:
|
||||
for base in cls.__bases__:
|
||||
meta = getattr(base, "__pydantic_generic_metadata__", None) or {}
|
||||
if meta.get("origin") is Shape and (args := meta.get("args")):
|
||||
return args[0]
|
||||
return None
|
||||
|
||||
|
||||
class Shape(BaseModel, Generic[_M]):
|
||||
_model: ClassVar[Any]
|
||||
_nested: ClassVar[dict[str, type[Shape]]]
|
||||
_field_names: ClassVar[list[str]]
|
||||
_spec: ClassVar[list]
|
||||
_pair: ClassVar[tuple]
|
||||
|
||||
def __init_subclass__(cls, **kwargs):
|
||||
super().__init_subclass__(**kwargs)
|
||||
|
||||
if not (model := _resolve_model(cls)):
|
||||
return
|
||||
|
||||
cls._model = model
|
||||
cls._nested = {}
|
||||
|
||||
hints = get_type_hints(cls, include_extras=False) or cls.__annotations__
|
||||
field_names = []
|
||||
|
||||
for name, hint in hints.items():
|
||||
if name.startswith("_"):
|
||||
continue
|
||||
if shape_cls := _extract_shape_class(hint):
|
||||
cls._nested[name] = shape_cls
|
||||
else:
|
||||
field_names.append(name)
|
||||
|
||||
cls._field_names = field_names
|
||||
cls._spec = [
|
||||
*field_names,
|
||||
*({name: shape._spec} for name, shape in cls._nested.items()),
|
||||
]
|
||||
cls._pair = specs.process(cls._spec)
|
||||
|
||||
@classmethod
|
||||
def _build_pair(cls, relation_qs: dict[str, Any]):
|
||||
field_pairs = [
|
||||
pairs.producer_to_projector(name, pairs.field(name))
|
||||
for name in cls._field_names
|
||||
]
|
||||
|
||||
rel_pairs = []
|
||||
for name, shape_cls in cls._nested.items():
|
||||
child_prepare, child_project = shape_cls._pair
|
||||
prepare = (
|
||||
readers_qs.pipe(relation_qs[name], child_prepare)
|
||||
if name in relation_qs
|
||||
else child_prepare
|
||||
)
|
||||
rel_pairs.append(
|
||||
pairs.producer_to_projector(
|
||||
name, pairs.relationship(name, (prepare, child_project))
|
||||
)
|
||||
)
|
||||
|
||||
return pairs.combine(*field_pairs, *rel_pairs)
|
||||
|
||||
@classmethod
|
||||
def query(cls: type[_S], *qs_fns, **relation_qs) -> list[_S]:
|
||||
prepare, project = cls._build_pair(relation_qs) if relation_qs else cls._pair
|
||||
base = cls._model.objects.all()
|
||||
|
||||
# Accept a raw QuerySet as the first arg, or qs functions, or nothing
|
||||
if qs_fns and hasattr(qs_fns[0], "query"):
|
||||
base, qs_fns = qs_fns[0], qs_fns[1:]
|
||||
|
||||
queryset = readers_qs.pipe(prepare, *qs_fns)(base)
|
||||
return [cls.model_validate(project(obj)) for obj in queryset]
|
||||
|
||||
def diff(self) -> Diff:
|
||||
cls = type(self)
|
||||
pk = getattr(self, "id", None)
|
||||
if pk:
|
||||
results = cls.query(cls._model.objects.filter(pk=pk))
|
||||
if not results:
|
||||
raise cls._model.DoesNotExist(f"{cls._model.__name__} with id={pk} does not exist")
|
||||
current = results[0]
|
||||
else:
|
||||
current = None
|
||||
|
||||
changed = (
|
||||
{
|
||||
k: getattr(self, k)
|
||||
for k in cls._field_names
|
||||
if k != "id" and getattr(self, k) != getattr(current, k)
|
||||
}
|
||||
if current
|
||||
else {k: getattr(self, k) for k in cls._field_names if k != "id"}
|
||||
)
|
||||
|
||||
nested = {}
|
||||
for name in cls._nested:
|
||||
incoming_items = getattr(self, name, None) or []
|
||||
current_items = getattr(current, name, None) or [] if current else []
|
||||
|
||||
if not isinstance(incoming_items, list):
|
||||
incoming_items = [incoming_items]
|
||||
if not isinstance(current_items, list):
|
||||
current_items = [current_items]
|
||||
|
||||
current_by_id = {c.id: c for c in current_items if c.id is not None}
|
||||
incoming_by_id = {c.id: c for c in incoming_items if c.id is not None}
|
||||
|
||||
nested[name] = NestedDiff(
|
||||
created=[c for c in incoming_items if c.id is None],
|
||||
updated=[
|
||||
c for id, c in incoming_by_id.items()
|
||||
if id in current_by_id and c != current_by_id[id]
|
||||
],
|
||||
deleted=[id for id in current_by_id if id not in incoming_by_id],
|
||||
)
|
||||
|
||||
return Diff(is_new=current is None, changed=changed, _nested=nested)
|
||||
|
||||
|
||||
class NestedDiff:
|
||||
__slots__ = ("created", "updated", "deleted")
|
||||
|
||||
def __init__(self, created=(), updated=(), deleted=()):
|
||||
self.created = list(created)
|
||||
self.updated = list(updated)
|
||||
self.deleted = list(deleted)
|
||||
|
||||
|
||||
class Diff:
|
||||
__slots__ = ("is_new", "changed", "_nested")
|
||||
|
||||
def __init__(self, is_new: bool, changed: dict[str, Any], _nested: dict[str, NestedDiff]):
|
||||
self.is_new = is_new
|
||||
self.changed = changed
|
||||
self._nested = _nested
|
||||
|
||||
def __getattr__(self, name: str) -> NestedDiff:
|
||||
if name in self._nested:
|
||||
return self._nested[name]
|
||||
return NestedDiff()
|
||||
Reference in New Issue
Block a user