_extract_shape_class now handles `Shape | None` (union types) by checking `isinstance(hint, types.UnionType)` and iterating the union's args for Shape subclasses. This fixes nullable FK detection: any field such as `editor: AuthorShape | None` is now correctly recognized as a nested shape.

48 stress tests covering:
- 5-level deep nesting (Publisher → Author → Book → Chapter → Section)
- Two FKs to the same model (author + editor)
- Slug PK (Tag), UUID PK (Section)
- M2M relationships (Book.tags)
- Nullable FKs returning None
- Empty strings, zero integers, false booleans (truthiness traps)
- 100-record smoke test
- Query efficiency (assertNumQueries)
- All diff operations with deep nesting

Known gap (documented): self-referential forward references (e.g. CategoryShape) crash get_type_hints() at __init_subclass__ time. This needs deferred resolution.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
261 lines
8.3 KiB
Python
261 lines
8.3 KiB
Python
from __future__ import annotations
|
|
|
|
import types
|
|
from typing import Any, ClassVar, Generic, TypeVar, Union, get_type_hints
|
|
|
|
from pydantic import BaseModel
|
|
|
|
from django_readers import pairs, specs
|
|
from django_readers import qs as readers_qs
|
|
|
|
_M = TypeVar("_M")
|
|
_S = TypeVar("_S", bound="Shape")
|
|
|
|
|
|
def _extract_shape_class(hint) -> type[Shape] | None:
    """Return the Shape subclass referenced by a type hint, if there is one.

    Recognised forms:

    * ``SomeShape`` - a Shape subclass used directly;
    * ``list[SomeShape]`` - a list of nested shapes;
    * ``SomeShape | None`` / ``Optional[SomeShape]`` - a nullable nested shape;
    * ``list[SomeShape] | None`` - a nullable list of nested shapes.

    Args:
        hint: A type hint object, e.g. a value from ``get_type_hints``.

    Returns:
        The first Shape subclass found in ``hint``, or ``None`` when the hint
        does not reference one.
    """
    origin = getattr(hint, "__origin__", None)
    args = getattr(hint, "__args__", ())

    # list[SomeShape]
    if (
        origin is list
        and args
        and isinstance(args[0], type)
        and issubclass(args[0], Shape)
    ):
        return args[0]

    # SomeShape (bare); the Shape base class itself is never a nested shape.
    if isinstance(hint, type) and issubclass(hint, Shape) and hint is not Shape:
        return hint

    # Union / Optional: apply the same rules to every member, so that both
    # ``SomeShape | None`` and ``list[SomeShape] | None`` are recognised.
    # ``NoneType`` and other non-Shape members simply yield ``None`` here.
    if origin is Union or isinstance(hint, types.UnionType):
        for arg in args:
            if (found := _extract_shape_class(arg)) is not None:
                return found

    return None
|
|
|
|
|
|
def _resolve_model(cls) -> Any | None:
|
|
for base in cls.__bases__:
|
|
meta = getattr(base, "__pydantic_generic_metadata__", None) or {}
|
|
if meta.get("origin") is Shape and (args := meta.get("args")):
|
|
return args[0]
|
|
return None
|
|
|
|
|
|
class Shape(BaseModel, Generic[_M]):
    """Pydantic model bound to a model class and loaded through django-readers.

    A subclass declared as ``SomeShape(Shape[SomeModel])`` is analysed once,
    when the class is created: its public annotations are split into plain
    fields and nested shapes, and a django-readers spec plus the pair returned
    by ``specs.process`` are stored on the class.  ``query`` loads instances
    through that pair; ``diff`` and ``diff_many`` compare in-memory instances
    with what ``query`` currently returns for the same primary keys.
    """

    # All of the following are set by __init_subclass__ on subclasses for
    # which _resolve_model() finds a model.
    _model: ClassVar[Any]  # model class given as the generic argument
    _nested: ClassVar[dict[str, type[Shape]]]  # field name -> nested Shape class
    _field_names: ClassVar[list[str]]  # names of the non-nested fields
    _pk_field: ClassVar[str]  # name of the model's primary-key field
    _spec: ClassVar[list]  # spec handed to specs.process()
    _pair: ClassVar[tuple]  # result of specs.process(_spec)

    def __init_subclass__(cls, **kwargs):
        """Bind ``cls`` to its model and precompute its spec and reader pair.

        Subclasses for which no model can be resolved are left untouched.
        """
        super().__init_subclass__(**kwargs)

        if not (model := _resolve_model(cls)):
            return

        cls._model = model
        cls._nested = {}
        cls._pk_field = model._meta.pk.name if model._meta.pk else "id"

        # NOTE(review): get_type_hints() resolves annotations immediately, so
        # an annotation naming a class that does not exist yet (such as a
        # self-reference) presumably fails here - confirm before relying on it.
        hints = get_type_hints(cls, include_extras=False) or cls.__annotations__
        field_names = []

        for name, hint in hints.items():
            if name.startswith("_"):
                continue
            # ClassVar annotations are class-level attributes, not model
            # fields, so they must not end up in the spec.
            if hint is ClassVar or getattr(hint, "__origin__", None) is ClassVar:
                continue
            if shape_cls := _extract_shape_class(hint):
                cls._nested[name] = shape_cls
            else:
                field_names.append(name)

        cls._field_names = field_names
        # Plain fields are listed by name; each nested shape contributes a
        # one-entry dict mapping its field name to the nested shape's spec.
        cls._spec = [
            *field_names,
            *({name: shape._spec} for name, shape in cls._nested.items()),
        ]
        cls._pair = specs.process(cls._spec)

    @classmethod
    def _build_pair(cls, relation_qs: dict[str, Any]):
        """Build a reader pair that applies extra queryset functions to relations.

        Args:
            relation_qs: Maps a nested field name to a queryset function.  For
                a listed relation the function is piped in front of the nested
                shape's own prepare function; relations that are not listed
                use the nested shape's prepare function unchanged.

        Returns:
            The result of ``pairs.combine`` over the field pairs and the
            relationship pairs.
        """
        field_pairs = [
            pairs.producer_to_projector(name, pairs.field(name))
            for name in cls._field_names
        ]

        rel_pairs = []
        for name, shape_cls in cls._nested.items():
            child_prepare, child_project = shape_cls._pair
            prepare = (
                readers_qs.pipe(relation_qs[name], child_prepare)
                if name in relation_qs
                else child_prepare
            )
            rel_pairs.append(
                pairs.producer_to_projector(
                    name, pairs.relationship(name, (prepare, child_project))
                )
            )

        return pairs.combine(*field_pairs, *rel_pairs)

    @classmethod
    def _get_pk(cls, instance) -> Any | None:
        """Return the primary-key value of ``instance``, or ``None`` if absent."""
        return getattr(instance, cls._pk_field, None)

    @classmethod
    def query(cls: type[_S], *qs_fns, **relation_qs) -> list[_S]:
        """Load rows and validate them into instances of this shape.

        Args:
            *qs_fns: Queryset functions piped after the shape's prepare
                function.  The first positional argument may instead be a raw
                QuerySet (recognised by having a ``query`` attribute), which
                then replaces ``cls._model.objects.all()`` as the base.
            **relation_qs: Per-relation queryset functions, keyed by nested
                field name; when given, the pair is rebuilt via
                ``_build_pair``.

        Returns:
            One validated instance per object in the resulting queryset.
        """
        prepare, project = cls._build_pair(relation_qs) if relation_qs else cls._pair
        base = cls._model.objects.all()

        # Accept a raw QuerySet as the first arg, or qs functions, or nothing
        if qs_fns and hasattr(qs_fns[0], "query"):
            base, qs_fns = qs_fns[0], qs_fns[1:]

        queryset = readers_qs.pipe(prepare, *qs_fns)(base)
        return [cls.model_validate(project(obj)) for obj in queryset]

    @classmethod
    def diff_many(cls: type[_S], items: list[_S]) -> list[tuple[_S, Diff]]:
        """Diff several instances against stored data using a single query.

        Args:
            items: Instances to compare.  Items whose primary key is ``None``
                are treated as new.

        Returns:
            ``(item, diff)`` tuples: first the new items, then the items that
            carry a primary key, each group in input order.  Every input item
            yields exactly one tuple, including items that share a primary
            key.

        Raises:
            cls._model.DoesNotExist: If an item's primary key is not among
                the rows returned by the query.
        """
        pk_field = cls._pk_field
        new_items: list[_S] = []
        # A list of pairs rather than a dict, so that two items carrying the
        # same primary key are both diffed instead of one replacing the other.
        existing: list[tuple[Any, _S]] = []

        for item in items:
            pk = cls._get_pk(item)
            if pk is None:
                new_items.append(item)
            else:
                existing.append((pk, item))

        # Single query for all existing items
        current_map: dict[Any, _S] = {}
        if existing:
            # De-duplicate while keeping first-seen order.
            wanted_pks = list(dict.fromkeys(pk for pk, _ in existing))
            current_items = cls.query(
                cls._model.objects.filter(**{f"{pk_field}__in": wanted_pks})
            )
            current_map = {cls._get_pk(c): c for c in current_items}

        results: list[tuple[_S, Diff]] = [
            (item, cls._diff_one(item, None)) for item in new_items
        ]

        for pk, item in existing:
            current = current_map.get(pk)
            if current is None:
                raise cls._model.DoesNotExist(
                    f"{cls._model.__name__} with {pk_field}={pk} does not exist"
                )
            results.append((item, cls._diff_one(item, current)))

        return results

    @classmethod
    def _diff_one(cls, incoming: _S, current: _S | None) -> Diff:
        """Compare ``incoming`` with ``current`` and describe the differences.

        Args:
            incoming: The instance holding the desired state.
            current: The instance holding the stored state, or ``None`` when
                ``incoming`` has no stored counterpart.

        Returns:
            A ``Diff`` whose ``changed`` maps each non-pk field that differs
            (every non-pk field when ``current`` is ``None``) to its incoming
            value, and which holds one ``NestedDiff`` per nested shape.
        """
        pk_field = cls._pk_field
        # Explicit None check: "no current" must not depend on truthiness.
        has_current = current is not None

        changed = {}
        for k in cls._field_names:
            if k == pk_field:
                continue
            value = getattr(incoming, k)
            if not has_current or value != getattr(current, k):
                changed[k] = value

        nested = {}
        for name, shape_cls in cls._nested.items():
            incoming_items = getattr(incoming, name, None) or []
            current_items = (getattr(current, name, None) or []) if has_current else []

            # A single nested shape (not a list) is handled as a one-item list.
            if not isinstance(incoming_items, list):
                incoming_items = [incoming_items]
            if not isinstance(current_items, list):
                current_items = [current_items]

            current_by_pk = {
                shape_cls._get_pk(c): c
                for c in current_items
                if shape_cls._get_pk(c) is not None
            }
            incoming_by_pk = {
                shape_cls._get_pk(c): c
                for c in incoming_items
                if shape_cls._get_pk(c) is not None
            }

            nested[name] = NestedDiff(
                # No pk on the incoming side: created.
                created=[c for c in incoming_items if shape_cls._get_pk(c) is None],
                # Present on both sides but not equal: updated.
                updated=[
                    c
                    for pk, c in incoming_by_pk.items()
                    if pk in current_by_pk and c != current_by_pk[pk]
                ],
                # Present only on the current side: deleted (reported by pk).
                deleted=[pk for pk in current_by_pk if pk not in incoming_by_pk],
            )

        return Diff(is_new=current is None, changed=changed, _nested=nested)

    def diff(self) -> Diff:
        """Diff this instance against its stored counterpart.

        Returns:
            The ``Diff`` from ``_diff_one``; when this instance has no primary
            key it is compared against ``None`` and reported as new.

        Raises:
            cls._model.DoesNotExist: If the instance has a primary key but
                the query returns no row for it.
        """
        cls = type(self)
        pk = cls._get_pk(self)
        if pk is not None:
            results = cls.query(cls._model.objects.filter(pk=pk))
            if not results:
                raise cls._model.DoesNotExist(
                    f"{cls._model.__name__} with {cls._pk_field}={pk} does not exist"
                )
            current = results[0]
        else:
            current = None

        return cls._diff_one(self, current)
|
|
|
|
|
|
class NestedDiff:
    """Created, updated and deleted entries for one nested relation.

    Each constructor argument may be any iterable; it is copied into a new
    list stored under the attribute of the same name.
    """

    __slots__ = ("created", "updated", "deleted")

    def __init__(self, created=(), updated=(), deleted=()):
        buckets = (
            ("created", created),
            ("updated", updated),
            ("deleted", deleted),
        )
        # Materialise every iterable so the instance owns independent lists.
        for attr, values in buckets:
            setattr(self, attr, list(values))
|
|
|
|
|
|
class Diff:
    """Outcome of comparing one shape instance with its stored state.

    Attributes are ``is_new`` (whether there was no stored counterpart) and
    ``changed`` (field name -> incoming value).  Per-relation diffs are
    reachable through ``nested(name)`` or as attributes named after the
    relation.
    """

    __slots__ = ("is_new", "changed", "_nested")

    def __init__(
        self, is_new: bool, changed: dict[str, Any], _nested: dict[str, NestedDiff]
    ):
        self.is_new = is_new
        self.changed = changed
        self._nested = _nested

    def _unknown_nested_message(self, name: str) -> str:
        """Build the error text for a nested name that is not present."""
        listing = ", ".join(sorted(self._nested)) or "(none)"
        return f"No nested diff for '{name}'. Valid nested shapes: {listing}"

    def nested(self, name: str) -> NestedDiff:
        """Strict access to nested diffs. Raises KeyError for invalid names."""
        if name in self._nested:
            return self._nested[name]
        raise KeyError(self._unknown_nested_message(name))

    def __getattr__(self, name: str) -> NestedDiff:
        """Attribute-style access to nested diffs.

        Only reached when normal attribute lookup fails.  Underscore-prefixed
        names are never treated as nested shapes.
        """
        if name.startswith("_"):
            raise AttributeError(name)
        if name in self._nested:
            return self._nested[name]
        raise AttributeError(self._unknown_nested_message(name))
|