Files
mizan/django/src/djarea/shapes/core.py
Ryth Azhur 625d8cf9b9 Fix Optional[Shape] unwrapping and add comprehensive shapes stress tests
_extract_shape_class now handles `Shape | None` (Union types) by checking
isinstance(hint, types.UnionType) and iterating args for Shape subclasses.
This fixes nullable FK detection — any `editor: AuthorShape | None` field
is now correctly recognized as a nested shape.

48 stress tests covering:
- 5-level deep nesting (Publisher → Author → Book → Chapter → Section)
- Two FKs to same model (author + editor)
- Slug PK (Tag), UUID PK (Section)
- M2M relationships (Book.tags)
- Nullable FKs returning None
- Empty strings, zero integers, false booleans (truthiness traps)
- 100-record smoke test
- Query efficiency (assertNumQueries)
- All diff operations with deep nesting

Known gap documented: self-referential forward refs (CategoryShape)
crash get_type_hints() at __init_subclass__ time. Needs deferred resolution.

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
2026-03-31 02:56:11 -04:00

261 lines
8.3 KiB
Python

from __future__ import annotations
import types
from typing import Any, ClassVar, Generic, TypeVar, Union, get_type_hints
from pydantic import BaseModel
from django_readers import pairs, specs
from django_readers import qs as readers_qs
# Type variable for the model parameter of ``Shape[Model]`` (see ``Generic[_M]`` on Shape).
_M = TypeVar("_M")
# Type variable bound to Shape so classmethods such as ``query`` are typed as
# returning instances of the concrete subclass they are called on.
_S = TypeVar("_S", bound="Shape")
def _extract_shape_class(hint) -> type[Shape] | None:
    """Return the nested Shape subclass referenced by a field annotation, if any.

    Recognised forms, where ``X`` is a concrete ``Shape`` subclass:

    * ``X``                                  -- bare nested shape
    * ``list[X]`` / ``List[X]``              -- list of nested shapes
    * ``X | None`` / ``Optional[X]``         -- nullable nested shape
    * ``list[X] | None`` / ``Optional[list[X]]`` -- nullable list

    For unions, the first non-None member that resolves to a shape wins. The
    generic ``Shape`` base itself is never returned (it has no model or spec).
    Any other annotation yields ``None``, meaning "plain field".
    """
    origin = getattr(hint, "__origin__", None)
    args = getattr(hint, "__args__", ())
    # X | None, Optional[X], Optional[list[X]]: unwrap the non-None members.
    if origin is Union or isinstance(hint, types.UnionType):
        for arg in args:
            if arg is type(None):
                continue
            shape_cls = _extract_shape_class(arg)
            if shape_cls is not None:
                return shape_cls
        return None
    # list[X]
    if origin is list:
        return args[0] if args and _is_concrete_shape(args[0]) else None
    # X (bare)
    return hint if _is_concrete_shape(hint) else None


def _is_concrete_shape(obj) -> bool:
    """Return True if *obj* is a class that strictly subclasses ``Shape``."""
    # On older Pythons, isinstance(list[int], type) can be True for generic
    # aliases while issubclass() on them raises TypeError; such hints are
    # simply "not a shape" rather than a reason to crash class creation.
    try:
        return isinstance(obj, type) and issubclass(obj, Shape) and obj is not Shape
    except TypeError:
        return False
def _resolve_model(cls) -> Any | None:
    """Find the Django model a Shape subclass was parametrized with.

    Inspects only the *direct* bases of ``cls``. For the first base whose
    pydantic generic metadata names ``Shape`` as its origin and carries
    non-empty type arguments, returns the first argument (``Model`` in
    ``Shape[Model]``). Returns ``None`` when no direct base qualifies.
    """
    for parent in cls.__bases__:
        generic_meta = getattr(parent, "__pydantic_generic_metadata__", None) or {}
        if generic_meta.get("origin") is not Shape:
            continue
        type_args = generic_meta.get("args")
        if type_args:
            return type_args[0]
    return None
class Shape(BaseModel, Generic[_M]):
    """Pydantic model describing a read "shape" of a Django model.

    Declare a concrete shape as ``class AuthorShape(Shape[Author])``. When such a
    class is created, its annotations are split into plain model fields and
    nested shapes (see ``_extract_shape_class``), and a django-readers spec plus
    a (prepare, project) pair are built. ``query()`` uses that pair to fetch rows
    and validate them into shape instances; ``diff()`` / ``diff_many()`` compare
    incoming instances with the current database state.
    """

    # Django model class this shape reads from (the ``Shape[Model]`` argument).
    _model: ClassVar[Any]
    # Attribute name -> nested Shape subclass, for relation-valued annotations.
    _nested: ClassVar[dict[str, type[Shape]]]
    # Names of the plain (non-nested) fields, in type-hint order.
    _field_names: ClassVar[list[str]]
    # Name of the model's primary-key field ("id" if the model reports none).
    _pk_field: ClassVar[str]
    # django-readers spec: field names followed by ``{name: child_spec}`` dicts.
    _spec: ClassVar[list]
    # (prepare, project) pair produced by ``specs.process(_spec)``.
    _pair: ClassVar[tuple]

    def __init_subclass__(cls, **kwargs):
        """Derive the model, fields, nested shapes and readers spec for a subclass.

        When ``_resolve_model`` finds no ``Shape[Model]`` among the direct bases,
        nothing model-derived is (re)computed for the class.
        """
        super().__init_subclass__(**kwargs)
        if not (model := _resolve_model(cls)):
            return
        cls._model = model
        cls._nested = {}
        cls._pk_field = model._meta.pk.name if model._meta.pk else "id"
        # NOTE: self-referential forward refs currently make get_type_hints() fail
        # here because the class name is not bound yet; needs deferred resolution.
        hints = get_type_hints(cls, include_extras=False) or cls.__annotations__
        field_names = []
        for name, hint in hints.items():
            if name.startswith("_"):
                continue
            # ClassVar annotations are class-level settings, not model columns;
            # pydantic does not treat them as fields either, so don't query them.
            if hint is ClassVar or getattr(hint, "__origin__", None) is ClassVar:
                continue
            if shape_cls := _extract_shape_class(hint):
                cls._nested[name] = shape_cls
            else:
                field_names.append(name)
        cls._field_names = field_names
        cls._spec = [
            *field_names,
            *({name: shape._spec} for name, shape in cls._nested.items()),
        ]
        cls._pair = specs.process(cls._spec)

    @classmethod
    def _build_pair(cls, relation_qs: dict[str, Any]):
        """Build a (prepare, project) pair with per-relation queryset functions.

        Args:
            relation_qs: nested attribute name -> queryset function, combined with
                that child's own prepare function via ``readers_qs.pipe``.

        Raises:
            TypeError: if a key of ``relation_qs`` is not a nested shape of ``cls``
                (previously such keys were silently ignored).
        """
        unknown = sorted(set(relation_qs) - set(cls._nested))
        if unknown:
            valid = ", ".join(sorted(cls._nested)) or "(none)"
            raise TypeError(
                f"{cls.__name__}.query() got unknown relation(s): "
                f"{', '.join(unknown)}. Valid nested shapes: {valid}"
            )
        field_pairs = [
            pairs.producer_to_projector(name, pairs.field(name))
            for name in cls._field_names
        ]
        rel_pairs = []
        for name, shape_cls in cls._nested.items():
            child_prepare, child_project = shape_cls._pair
            prepare = (
                readers_qs.pipe(relation_qs[name], child_prepare)
                if name in relation_qs
                else child_prepare
            )
            rel_pairs.append(
                pairs.producer_to_projector(
                    name, pairs.relationship(name, (prepare, child_project))
                )
            )
        return pairs.combine(*field_pairs, *rel_pairs)

    @classmethod
    def _get_pk(cls, instance) -> Any | None:
        """Return ``instance``'s primary-key value, or None when unset or absent."""
        return getattr(instance, cls._pk_field, None)

    @classmethod
    def query(cls: type[_S], *qs_fns, **relation_qs) -> list[_S]:
        """Fetch rows of the shape's model and validate each projection into ``cls``.

        Args:
            *qs_fns: optionally a QuerySet as the first positional argument (used
                instead of ``Model.objects.all()``), then queryset functions that
                are composed with the shape's prepare function via
                ``readers_qs.pipe``.
            **relation_qs: per-nested-relation queryset functions; see
                ``_build_pair``.

        Returns:
            One validated shape instance per row of the resulting queryset.
        """
        prepare, project = cls._build_pair(relation_qs) if relation_qs else cls._pair
        base = cls._model.objects.all()
        # Accept a raw QuerySet as the first arg (duck-typed via its ``.query``
        # attribute), or qs functions, or nothing.
        if qs_fns and hasattr(qs_fns[0], "query"):
            base, qs_fns = qs_fns[0], qs_fns[1:]
        queryset = readers_qs.pipe(prepare, *qs_fns)(base)
        return [cls.model_validate(project(obj)) for obj in queryset]

    @classmethod
    def diff_many(cls: type[_S], items: list[_S]) -> list[tuple[_S, Diff]]:
        """Diff several incoming shapes against the database using one lookup.

        Items without a primary key are diffed as new. Items with a primary key
        are compared with their current state, fetched by a single ``query()``.
        The result lists new items first, then existing ones, each group in input
        order. Items sharing a primary key each get their own entry (previously
        all but the last were silently dropped).

        Raises:
            Model.DoesNotExist: if an item's primary key has no database row.
        """
        pk_field = cls._pk_field
        new_items: list[_S] = []
        existing: list[tuple[Any, _S]] = []
        for item in items:
            pk = cls._get_pk(item)
            if pk is None:
                new_items.append(item)
            else:
                existing.append((pk, item))

        # Single query for all existing items.
        current_map: dict[Any, _S] = {}
        if existing:
            # dict.fromkeys de-duplicates the pks while keeping first-seen order.
            pks = list(dict.fromkeys(pk for pk, _ in existing))
            current_items = cls.query(
                cls._model.objects.filter(**{f"{pk_field}__in": pks})
            )
            current_map = {cls._get_pk(c): c for c in current_items}

        results: list[tuple[_S, Diff]] = [
            (item, cls._diff_one(item, None)) for item in new_items
        ]
        for pk, item in existing:
            current = current_map.get(pk)
            if current is None:
                raise cls._model.DoesNotExist(
                    f"{cls._model.__name__} with {pk_field}={pk} does not exist"
                )
            results.append((item, cls._diff_one(item, current)))
        return results

    @classmethod
    def _diff_one(cls, incoming: _S, current: _S | None) -> Diff:
        """Compare one incoming shape with its current state (None if new).

        ``changed`` maps plain, non-pk fields to their incoming values: all of
        them when ``current`` is None, otherwise only those that differ. Each
        nested shape gets a NestedDiff, matching children by primary key:

        * created -- incoming children without a primary key;
        * updated -- incoming children whose pk is also current but which differ;
        * deleted -- pks of current children absent from the incoming children.

        NOTE(review): an incoming child whose pk is set but is not among the
        current children (e.g. a nullable FK switched to another existing row)
        appears in none of the three lists; only the old pk shows in ``deleted``.
        """
        pk_field = cls._pk_field
        if current is None:
            changed = {
                k: getattr(incoming, k) for k in cls._field_names if k != pk_field
            }
        else:
            changed = {
                k: getattr(incoming, k)
                for k in cls._field_names
                if k != pk_field and getattr(incoming, k) != getattr(current, k)
            }
        nested = {}
        for name, shape_cls in cls._nested.items():
            incoming_items = getattr(incoming, name, None) or []
            current_items = (
                (getattr(current, name, None) or []) if current is not None else []
            )
            # Single-object relations hold one shape; normalise to a list.
            if not isinstance(incoming_items, list):
                incoming_items = [incoming_items]
            if not isinstance(current_items, list):
                current_items = [current_items]
            current_by_pk = {
                shape_cls._get_pk(c): c
                for c in current_items
                if shape_cls._get_pk(c) is not None
            }
            incoming_by_pk = {
                shape_cls._get_pk(c): c
                for c in incoming_items
                if shape_cls._get_pk(c) is not None
            }
            nested[name] = NestedDiff(
                created=[c for c in incoming_items if shape_cls._get_pk(c) is None],
                updated=[
                    c
                    for pk, c in incoming_by_pk.items()
                    if pk in current_by_pk and c != current_by_pk[pk]
                ],
                deleted=[pk for pk in current_by_pk if pk not in incoming_by_pk],
            )
        return Diff(is_new=current is None, changed=changed, _nested=nested)

    def diff(self) -> Diff:
        """Diff this instance against its current database state.

        Instances without a primary key are diffed as new (``is_new`` is True).

        Raises:
            Model.DoesNotExist: if the primary key is set but no such row exists.
        """
        cls = type(self)
        pk = cls._get_pk(self)
        if pk is not None:
            results = cls.query(cls._model.objects.filter(pk=pk))
            if not results:
                raise cls._model.DoesNotExist(
                    f"{cls._model.__name__} with {cls._pk_field}={pk} does not exist"
                )
            current = results[0]
        else:
            current = None
        return cls._diff_one(self, current)
class NestedDiff:
    """Child-level changes for one nested shape within a ``Diff``.

    Attributes:
        created: incoming child shapes that have no primary key.
        updated: incoming child shapes whose pk exists currently but whose data differs.
        deleted: primary keys of current children missing from the incoming data.
    """

    __slots__ = ("created", "updated", "deleted")

    def __init__(self, created=(), updated=(), deleted=()):
        # Copy into fresh lists so the caller's sequences are never aliased.
        self.created = list(created)
        self.updated = list(updated)
        self.deleted = list(deleted)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(created={self.created!r}, "
            f"updated={self.updated!r}, deleted={self.deleted!r})"
        )
class Diff:
    """Result of diffing one incoming shape against its current database state.

    Attributes:
        is_new: True when there was no current state to compare against.
        changed: plain field name -> incoming value for the fields reported as changed.

    Nested diffs are reached strictly with ``diff.nested(name)`` (KeyError on an
    unknown name) or as ``diff.<name>`` (AttributeError on an unknown name).
    Attribute access only falls back to nested diffs when normal lookup fails, so
    a nested name colliding with ``is_new``/``changed``/``nested`` is reachable
    only through ``nested()``.
    """

    __slots__ = ("is_new", "changed", "_nested")

    def __init__(
        self, is_new: bool, changed: dict[str, Any], _nested: dict[str, NestedDiff]
    ):
        self.is_new = is_new
        self.changed = changed
        self._nested = _nested

    def _unknown_nested_message(self, name: str) -> str:
        """Build the error message for a lookup of an unknown nested diff."""
        valid = ", ".join(sorted(self._nested)) or "(none)"
        return f"No nested diff for '{name}'. Valid nested shapes: {valid}"

    def nested(self, name: str) -> NestedDiff:
        """Strict access to nested diffs. Raises KeyError for invalid names."""
        try:
            return self._nested[name]
        except KeyError:
            raise KeyError(self._unknown_nested_message(name)) from None

    def __getattr__(self, name: str) -> NestedDiff:
        # Only reached when normal attribute lookup fails. Underscore names
        # (including an unset ``_nested`` slot) fail immediately, which also
        # prevents recursing back into __getattr__.
        if name.startswith("_"):
            raise AttributeError(name)
        try:
            return self._nested[name]
        except KeyError:
            raise AttributeError(self._unknown_nested_message(name)) from None

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(is_new={self.is_new!r}, "
            f"changed={self.changed!r}, nested={self._nested!r})"
        )