"""Pydantic "shapes" that project Django model instances via django-readers.

A ``Shape`` subclass parametrized with a Django model
(``class UserShape(Shape[User])``) declares the fields to fetch as ordinary
pydantic annotations. Annotations whose type is another concrete ``Shape``
(bare, ``list[...]``, or a union with ``None`` of either) are treated as nested
relationships. The class can load instances with ``query()`` and compare
in-memory instances with the database using ``diff()`` / ``diff_many()``.
"""

from __future__ import annotations

import types
from typing import Any, ClassVar, Generic, TypeVar, Union, get_type_hints

from pydantic import BaseModel

from django_readers import pairs, specs
from django_readers import qs as readers_qs

_M = TypeVar("_M")
_S = TypeVar("_S", bound="Shape")


def _is_concrete_shape(obj: Any) -> bool:
    """Return True if *obj* is a class deriving from ``Shape`` (but not ``Shape``)."""
    return isinstance(obj, type) and issubclass(obj, Shape) and obj is not Shape


def _is_classvar(hint: Any) -> bool:
    """Return True if *hint* is a bare ``ClassVar`` or ``ClassVar[...]``."""
    return hint is ClassVar or getattr(hint, "__origin__", None) is ClassVar


def _extract_shape_class(hint: Any) -> type[Shape] | None:
    """Return the nested Shape class referenced by a field annotation, if any.

    Recognised forms are ``SomeShape``, ``list[SomeShape]``, and unions of
    either with ``None`` (``SomeShape | None``, ``Optional[list[SomeShape]]``).
    Any other annotation returns ``None`` and is treated as a plain field.
    """
    origin = getattr(hint, "__origin__", None)
    args = getattr(hint, "__args__", ())

    # list[SomeShape]
    if origin is list and args and _is_concrete_shape(args[0]):
        return args[0]

    # SomeShape (bare)
    if _is_concrete_shape(hint):
        return hint

    # SomeShape | None, Optional[list[SomeShape]], ...: recurse into each member
    # so that optional to-many relations are detected as well.
    if origin is Union or isinstance(hint, types.UnionType):
        for arg in args:
            if arg is type(None):
                continue
            if (shape_cls := _extract_shape_class(arg)) is not None:
                return shape_cls

    return None


def _resolve_model(cls: type) -> Any | None:
    """Return ``M`` if *cls* directly subclasses ``Shape[M]``, else ``None``."""
    for base in cls.__bases__:
        meta = getattr(base, "__pydantic_generic_metadata__", None) or {}
        if meta.get("origin") is Shape and (args := meta.get("args")):
            return args[0]
    return None


def _as_list(value: Any) -> list:
    """Normalise a nested field value (empty/None, one shape, or a list) to a list."""
    if not value:
        return []
    return value if isinstance(value, list) else [value]


def _index_by_pk(shape_cls: type[Shape], items: list) -> dict[Any, Any]:
    """Map each item's primary key to the item; items without a pk are omitted."""
    indexed: dict[Any, Any] = {}
    for item in items:
        pk = shape_cls._get_pk(item)
        if pk is not None:
            indexed[pk] = item
    return indexed


class Shape(BaseModel, Generic[_M]):
    """Base class for pydantic models that mirror a Django model ``_M``.

    Subclass as ``class UserShape(Shape[User])``. At class-creation time the
    annotations are split into plain field names and nested shapes, and a
    django-readers spec and ``(prepare, project)`` pair are built from them.
    """

    # The ClassVars below are assigned by __init_subclass__ on each subclass
    # that resolves a model.
    _model: ClassVar[Any]  # the Django model class
    _nested: ClassVar[dict[str, type[Shape]]]  # field name -> nested Shape class
    _field_names: ClassVar[list[str]]  # non-nested field names
    _pk_field: ClassVar[str]  # primary-key field name ("id" if the model has none)
    _spec: ClassVar[list]  # django-readers spec built from the fields above
    _pair: ClassVar[tuple]  # (prepare, project) from specs.process(_spec)

    def __init_subclass__(cls, **kwargs):
        super().__init_subclass__(**kwargs)
        # A direct ``Shape[Model]`` base wins. Otherwise inherit the model from a
        # concrete parent shape, so that ``class Extended(UserShape)`` rebuilds
        # its spec with any extra fields. Subclasses with no model are skipped.
        model = _resolve_model(cls) or getattr(cls, "_model", None)
        if not model:
            return
        cls._model = model
        cls._nested = {}
        cls._pk_field = model._meta.pk.name if model._meta.pk else "id"
        hints = get_type_hints(cls, include_extras=False) or cls.__annotations__
        field_names = []
        for name, hint in hints.items():
            # Private names and class variables are never model fields.
            if name.startswith("_") or _is_classvar(hint):
                continue
            if shape_cls := _extract_shape_class(hint):
                cls._nested[name] = shape_cls
            else:
                field_names.append(name)
        cls._field_names = field_names
        cls._spec = [
            *field_names,
            *({name: shape._spec} for name, shape in cls._nested.items()),
        ]
        cls._pair = specs.process(cls._spec)

    @classmethod
    def _build_pair(cls, relation_qs: dict[str, Any]):
        """Build a ``(prepare, project)`` pair with extra per-relation queryset fns.

        ``relation_qs`` maps a nested field name to a django-readers queryset
        function. That function is piped before the nested shape's own prepare
        step for that relationship. Relations not in the mapping use the nested
        shape's prepare step unchanged.
        """
        field_pairs = [
            pairs.producer_to_projector(name, pairs.field(name))
            for name in cls._field_names
        ]
        rel_pairs = []
        for name, shape_cls in cls._nested.items():
            child_prepare, child_project = shape_cls._pair
            prepare = (
                readers_qs.pipe(relation_qs[name], child_prepare)
                if name in relation_qs
                else child_prepare
            )
            rel_pairs.append(
                pairs.producer_to_projector(
                    name, pairs.relationship(name, (prepare, child_project))
                )
            )
        return pairs.combine(*field_pairs, *rel_pairs)

    @classmethod
    def _get_pk(cls, instance) -> Any | None:
        """Return *instance*'s primary-key attribute, or ``None`` if it has none."""
        return getattr(instance, cls._pk_field, None)

    @classmethod
    def query(cls: type[_S], *qs_fns, **relation_qs) -> list[_S]:
        """Fetch model rows and return them validated as instances of this shape.

        Args:
            *qs_fns: django-readers queryset functions applied after the shape's
                prepare step. If the first positional argument has a ``query``
                attribute (e.g. a Django ``QuerySet``), it is used as the base
                queryset instead of ``_model.objects.all()``.
            **relation_qs: per-relation queryset functions keyed by nested field
                name (see ``_build_pair``).

        Raises:
            TypeError: if a ``relation_qs`` key is not a nested field of this shape.
        """
        if relation_qs:
            unknown = sorted(relation_qs.keys() - cls._nested.keys())
            if unknown:
                valid = ", ".join(sorted(cls._nested)) or "(none)"
                raise TypeError(
                    f"{cls.__name__}.query() got unknown relation(s): "
                    f"{', '.join(unknown)}. Valid nested shapes: {valid}"
                )
            prepare, project = cls._build_pair(relation_qs)
        else:
            prepare, project = cls._pair

        # Accept a raw QuerySet as the first arg, or qs functions, or nothing.
        if qs_fns and hasattr(qs_fns[0], "query"):
            base, qs_fns = qs_fns[0], qs_fns[1:]
        else:
            base = cls._model.objects.all()
        queryset = readers_qs.pipe(prepare, *qs_fns)(base)
        return [cls.model_validate(project(obj)) for obj in queryset]

    @classmethod
    def diff_many(cls: type[_S], items: list[_S]) -> list[tuple[_S, Diff]]:
        """Diff several instances against the database using a single query.

        Returns one ``(item, diff)`` pair per input item. Items without a
        primary key (new) come first, then items with a primary key; each group
        keeps its input order. Items sharing a primary key are each diffed
        against the same current row.

        Raises:
            ``cls._model.DoesNotExist``: if an item's primary key has no row.
        """
        pk_field = cls._pk_field
        new_items: list[_S] = []
        existing: list[tuple[Any, _S]] = []
        for item in items:
            pk = cls._get_pk(item)
            if pk is None:
                new_items.append(item)
            else:
                existing.append((pk, item))

        # Single query for all existing items; dict.fromkeys dedupes in order.
        current_map: dict[Any, _S] = {}
        if existing:
            pks = list(dict.fromkeys(pk for pk, _ in existing))
            current_items = cls.query(
                cls._model.objects.filter(**{f"{pk_field}__in": pks})
            )
            current_map = {cls._get_pk(c): c for c in current_items}

        results: list[tuple[_S, Diff]] = [
            (item, cls._diff_one(item, None)) for item in new_items
        ]
        for pk, item in existing:
            current = current_map.get(pk)
            if current is None:
                raise cls._model.DoesNotExist(
                    f"{cls._model.__name__} with {pk_field}={pk} does not exist"
                )
            results.append((item, cls._diff_one(item, current)))
        return results

    @classmethod
    def _diff_one(cls, incoming: _S, current: _S | None) -> Diff:
        """Compare *incoming* with *current*; ``None`` means no current row.

        ``changed`` holds every non-pk plain field whose value differs (all of
        them when *current* is ``None``). Each nested relationship is compared
        by child primary key.
        """
        pk_field = cls._pk_field
        compared = [k for k in cls._field_names if k != pk_field]
        if current is None:
            changed = {k: getattr(incoming, k) for k in compared}
        else:
            changed = {
                k: getattr(incoming, k)
                for k in compared
                if getattr(incoming, k) != getattr(current, k)
            }

        nested: dict[str, NestedDiff] = {}
        for name, shape_cls in cls._nested.items():
            incoming_items = _as_list(getattr(incoming, name, None))
            current_items = (
                _as_list(getattr(current, name, None)) if current is not None else []
            )
            current_by_pk = _index_by_pk(shape_cls, current_items)
            incoming_by_pk = _index_by_pk(shape_cls, incoming_items)
            # NOTE(review): an incoming child whose pk is set but absent from the
            # current relation appears in none of the three lists; confirm
            # whether that case should be surfaced.
            nested[name] = NestedDiff(
                created=[c for c in incoming_items if shape_cls._get_pk(c) is None],
                updated=[
                    c
                    for pk, c in incoming_by_pk.items()
                    if pk in current_by_pk and c != current_by_pk[pk]
                ],
                deleted=[pk for pk in current_by_pk if pk not in incoming_by_pk],
            )

        return Diff(is_new=current is None, changed=changed, _nested=nested)

    def diff(self) -> Diff:
        """Diff this instance against its current database state.

        An instance without a primary key is diffed as new.

        Raises:
            ``_model.DoesNotExist``: if the primary key is set but has no row.
        """
        cls = type(self)
        pk = cls._get_pk(self)
        if pk is None:
            return cls._diff_one(self, None)
        results = cls.query(cls._model.objects.filter(pk=pk))
        if not results:
            raise cls._model.DoesNotExist(
                f"{cls._model.__name__} with {cls._pk_field}={pk} does not exist"
            )
        return cls._diff_one(self, results[0])


class NestedDiff:
    """Changes to one nested relationship, as computed by ``Shape._diff_one``.

    Attributes:
        created: incoming child items that have no primary key.
        updated: incoming child items whose pk is in the current relation but
            whose value differs from the current item.
        deleted: primary keys in the current relation that are missing from
            the incoming items.
    """

    __slots__ = ("created", "updated", "deleted")

    def __init__(self, created=(), updated=(), deleted=()):
        self.created = list(created)
        self.updated = list(updated)
        self.deleted = list(deleted)

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(created={self.created!r}, "
            f"updated={self.updated!r}, deleted={self.deleted!r})"
        )


class Diff:
    """Result of comparing an incoming shape instance with its current state.

    Attributes:
        is_new: True when no current instance was compared (no primary key).
        changed: field name -> incoming value for each differing non-pk field.

    Nested relationship diffs are reached with ``diff.nested("name")`` (raises
    ``KeyError``) or ``diff.name`` (raises ``AttributeError``).
    """

    __slots__ = ("is_new", "changed", "_nested")

    def __init__(
        self, is_new: bool, changed: dict[str, Any], _nested: dict[str, NestedDiff]
    ):
        self.is_new = is_new
        self.changed = changed
        self._nested = _nested

    def __repr__(self) -> str:
        return (
            f"{type(self).__name__}(is_new={self.is_new!r}, "
            f"changed={self.changed!r}, _nested={self._nested!r})"
        )

    def _unknown_nested_message(self, name: str) -> str:
        """Build the error text for an unknown nested-diff name."""
        valid = ", ".join(sorted(self._nested)) or "(none)"
        return f"No nested diff for '{name}'. Valid nested shapes: {valid}"

    def nested(self, name: str) -> NestedDiff:
        """Strict access to nested diffs. Raises KeyError for invalid names."""
        if name not in self._nested:
            raise KeyError(self._unknown_nested_message(name))
        return self._nested[name]

    def __getattr__(self, name: str) -> NestedDiff:
        """Attribute-style access to nested diffs, e.g. ``diff.children``."""
        # Private names (including an unset ``_nested`` slot) must not recurse.
        if name.startswith("_"):
            raise AttributeError(name)
        if name not in self._nested:
            raise AttributeError(self._unknown_nested_message(name))
        return self._nested[name]