From 5a56d7a4a5f17dfe9a90b6878c82514dba2797ac Mon Sep 17 00:00:00 2001 From: Ryth Azhur Date: Tue, 31 Mar 2026 02:25:27 -0400 Subject: [PATCH] Update shapes tests for pk abstraction, strict Diff, and diff_many 13 new tests covering three changes from claude.ai: - pk abstraction: _pk_field resolved from model._meta, _get_pk helper - Strict Diff.__getattr__: typos raise AttributeError with valid names, nested() method raises KeyError for explicit access - diff_many: batched query (assertNumQueries(1)), mixed new/existing, empty list, all-new, nonexistent raises 38 shapes tests total, all passing. Co-Authored-By: Claude Opus 4.6 (1M context) --- django/src/djarea/shapes/core.py | 131 +++++++++++++++++----- django/src/djarea/tests/test_shapes.py | 148 +++++++++++++++++++++++-- 2 files changed, 246 insertions(+), 33 deletions(-) diff --git a/django/src/djarea/shapes/core.py b/django/src/djarea/shapes/core.py index 58a50b2..505a37c 100644 --- a/django/src/djarea/shapes/core.py +++ b/django/src/djarea/shapes/core.py @@ -15,7 +15,12 @@ def _extract_shape_class(hint) -> type[Shape] | None: origin = getattr(hint, "__origin__", None) args = getattr(hint, "__args__", ()) - if origin is list and args and isinstance(args[0], type) and issubclass(args[0], Shape): + if ( + origin is list + and args + and isinstance(args[0], type) + and issubclass(args[0], Shape) + ): return args[0] if isinstance(hint, type) and issubclass(hint, Shape) and hint is not Shape: return hint @@ -34,6 +39,7 @@ class Shape(BaseModel, Generic[_M]): _model: ClassVar[Any] _nested: ClassVar[dict[str, type[Shape]]] _field_names: ClassVar[list[str]] + _pk_field: ClassVar[str] _spec: ClassVar[list] _pair: ClassVar[tuple] @@ -45,6 +51,7 @@ class Shape(BaseModel, Generic[_M]): cls._model = model cls._nested = {} + cls._pk_field = model._meta.pk.name if model._meta.pk else "id" hints = get_type_hints(cls, include_extras=False) or cls.__annotations__ field_names = [] @@ -87,6 +94,10 @@ class Shape(BaseModel, Generic[_M]): return pairs.combine(*field_pairs, *rel_pairs) + @classmethod + def _get_pk(cls, instance) -> Any | None: + return getattr(instance, cls._pk_field, None) + @classmethod def query(cls: type[_S], *qs_fns, **relation_qs) -> list[_S]: prepare, project = cls._build_pair(relation_qs) if relation_qs else cls._pair @@ -99,30 +110,59 @@ class Shape(BaseModel, Generic[_M]): queryset = readers_qs.pipe(prepare, *qs_fns)(base) return [cls.model_validate(project(obj)) for obj in queryset] - def diff(self) -> Diff: - cls = type(self) - pk = getattr(self, "id", None) - if pk: - results = cls.query(cls._model.objects.filter(pk=pk)) - if not results: - raise cls._model.DoesNotExist(f"{cls._model.__name__} with id={pk} does not exist") - current = results[0] - else: - current = None + @classmethod + def diff_many(cls: type[_S], items: list[_S]) -> list[tuple[_S, Diff]]: + pk_field = cls._pk_field + pk_map: dict[Any, _S] = {} + new_items: list[_S] = [] + + for item in items: + pk = cls._get_pk(item) + if pk is not None: + pk_map[pk] = item + else: + new_items.append(item) + + # Single query for all existing items + current_map: dict[Any, _S] = {} + if pk_map: + current_items = cls.query( + cls._model.objects.filter(**{f"{pk_field}__in": pk_map.keys()}) + ) + current_map = {cls._get_pk(c): c for c in current_items} + + results: list[tuple[_S, Diff]] = [] + + for item in new_items: + results.append((item, cls._diff_one(item, None))) + + for pk, item in pk_map.items(): + current = current_map.get(pk) + if current is None: + raise cls._model.DoesNotExist( + f"{cls._model.__name__} with {pk_field}={pk} does not exist" + ) + results.append((item, cls._diff_one(item, current))) + + return results + + @classmethod + def _diff_one(cls, incoming: _S, current: _S | None) -> Diff: + pk_field = cls._pk_field changed = ( { - k: getattr(self, k) + k: getattr(incoming, k) for k in cls._field_names - if k != "id" and getattr(self, k) != getattr(current, k) + if k != pk_field and getattr(incoming, k) != getattr(current, k) } if current - else {k: getattr(self, k) for k in cls._field_names if k != "id"} + else {k: getattr(incoming, k) for k in cls._field_names if k != pk_field} ) nested = {} - for name in cls._nested: - incoming_items = getattr(self, name, None) or [] + for name, shape_cls in cls._nested.items(): + incoming_items = getattr(incoming, name, None) or [] current_items = getattr(current, name, None) or [] if current else [] if not isinstance(incoming_items, list): @@ -130,20 +170,45 @@ class Shape(BaseModel, Generic[_M]): if not isinstance(current_items, list): current_items = [current_items] - current_by_id = {c.id: c for c in current_items if c.id is not None} - incoming_by_id = {c.id: c for c in incoming_items if c.id is not None} + child_pk = shape_cls._pk_field + current_by_pk = { + shape_cls._get_pk(c): c + for c in current_items + if shape_cls._get_pk(c) is not None + } + incoming_by_pk = { + shape_cls._get_pk(c): c + for c in incoming_items + if shape_cls._get_pk(c) is not None + } nested[name] = NestedDiff( - created=[c for c in incoming_items if c.id is None], + created=[c for c in incoming_items if shape_cls._get_pk(c) is None], updated=[ - c for id, c in incoming_by_id.items() - if id in current_by_id and c != current_by_id[id] + c + for pk, c in incoming_by_pk.items() + if pk in current_by_pk and c != current_by_pk[pk] ], - deleted=[id for id in current_by_id if id not in incoming_by_id], + deleted=[pk for pk in current_by_pk if pk not in incoming_by_pk], ) return Diff(is_new=current is None, changed=changed, _nested=nested) + def diff(self) -> Diff: + cls = type(self) + pk = cls._get_pk(self) + if pk is not None: + results = cls.query(cls._model.objects.filter(pk=pk)) + if not results: + raise cls._model.DoesNotExist( + f"{cls._model.__name__} with {cls._pk_field}={pk} does not exist" + ) + current = results[0] + else: + current = None + + return cls._diff_one(self, current) + class NestedDiff: __slots__ = ("created", "updated", "deleted") @@ -157,12 +222,26 @@ class NestedDiff: class Diff: __slots__ = ("is_new", "changed", "_nested") - def __init__(self, is_new: bool, changed: dict[str, Any], _nested: dict[str, NestedDiff]): + def __init__( + self, is_new: bool, changed: dict[str, Any], _nested: dict[str, NestedDiff] + ): self.is_new = is_new self.changed = changed self._nested = _nested + def nested(self, name: str) -> NestedDiff: + """Strict access to nested diffs. Raises KeyError for invalid names.""" + if name not in self._nested: + valid = ", ".join(sorted(self._nested)) or "(none)" + raise KeyError(f"No nested diff for '{name}'. Valid nested shapes: {valid}") + return self._nested[name] + def __getattr__(self, name: str) -> NestedDiff: - if name in self._nested: - return self._nested[name] - return NestedDiff() + if name.startswith("_"): + raise AttributeError(name) + if name not in self._nested: + valid = ", ".join(sorted(self._nested)) or "(none)" + raise AttributeError( + f"No nested diff for '{name}'. Valid nested shapes: {valid}" + ) + return self._nested[name] diff --git a/django/src/djarea/tests/test_shapes.py b/django/src/djarea/tests/test_shapes.py index 7e5a9a1..19b9b84 100644 --- a/django/src/djarea/tests/test_shapes.py +++ b/django/src/djarea/tests/test_shapes.py @@ -66,6 +66,18 @@ class ShapeMetaTests(TestCase): self.assertIsNotNone(BookShape._pair) self.assertEqual(len(BookShape._pair), 2) # (prepare, project) + def test_pk_field_resolved_from_model_meta(self): + self.assertEqual(BookShape._pk_field, "id") + self.assertEqual(AuthorShape._pk_field, "id") + + def test_get_pk_reads_correct_field(self): + shape = BookShape(id=42, title="Test", pages=1) + self.assertEqual(BookShape._get_pk(shape), 42) + + def test_get_pk_returns_none_for_new(self): + shape = BookShape(title="New", pages=1) + self.assertIsNone(BookShape._get_pk(shape)) + # ============================================================================= # Querying @@ -273,13 +285,135 @@ class DiffNestedTests(TestCase): self.assertEqual(len(diff.books.updated), 1) self.assertEqual(diff.books.deleted, [self.book2.id]) - def test_accessing_nonexistent_nested_returns_empty(self): + def test_accessing_nonexistent_nested_raises_attribute_error(self): shape = BookShape(title="Simple", pages=10) diff = shape.diff() - # BookShape has no nested relations - empty = diff.nonexistent_relation - self.assertIsInstance(empty, NestedDiff) - self.assertEqual(empty.created, []) - self.assertEqual(empty.updated, []) - self.assertEqual(empty.deleted, []) + with self.assertRaises(AttributeError) as ctx: + diff.nonexistent_relation + + self.assertIn("nonexistent_relation", str(ctx.exception)) + + def test_nested_method_raises_key_error(self): + shape = BookShape(title="Simple", pages=10) + diff = shape.diff() + + with self.assertRaises(KeyError) as ctx: + diff.nested("nonexistent") + + self.assertIn("nonexistent", str(ctx.exception)) + + def test_nested_method_returns_valid_nested_diff(self): + shape = AuthorShape( + id=self.author.id, + name="Alice", + bio="Writer", + books=[ + BookShape(id=self.book1.id, title="Book One", pages=100), + ], + ) + diff = shape.diff() + books_diff = diff.nested("books") + + self.assertIsInstance(books_diff, NestedDiff) + self.assertEqual(diff.books.deleted, [self.book2.id]) + + def test_diff_error_message_lists_valid_names(self): + shape = AuthorShape( + id=self.author.id, + name="Alice", + bio="Writer", + books=[], + ) + diff = shape.diff() + + with self.assertRaises(AttributeError) as ctx: + diff.typo + + self.assertIn("books", str(ctx.exception)) + + +# ============================================================================= +# diff_many +# ============================================================================= + + +class DiffManyTests(TestCase): + """diff_many batches queries instead of N+1.""" + + def setUp(self): + self.author = Author.objects.create(name="Alice", bio="Writer") + self.book1 = Book.objects.create(title="Book One", pages=100, author=self.author) + self.book2 = Book.objects.create(title="Book Two", pages=200, author=self.author) + self.book3 = Book.objects.create(title="Book Three", pages=300, author=self.author) + + def test_diff_many_no_changes(self): + items = [ + BookShape(id=self.book1.id, title="Book One", pages=100), + BookShape(id=self.book2.id, title="Book Two", pages=200), + ] + results = BookShape.diff_many(items) + + self.assertEqual(len(results), 2) + for item, diff in results: + self.assertFalse(diff.is_new) + self.assertEqual(diff.changed, {}) + + def test_diff_many_with_changes(self): + items = [ + BookShape(id=self.book1.id, title="Renamed", pages=100), + BookShape(id=self.book2.id, title="Book Two", pages=999), + ] + results = BookShape.diff_many(items) + + diffs = {item.id: diff for item, diff in results} + self.assertEqual(diffs[self.book1.id].changed, {"title": "Renamed"}) + self.assertEqual(diffs[self.book2.id].changed, {"pages": 999}) + + def test_diff_many_with_new_items(self): + items = [ + BookShape(title="Brand New", pages=50), + BookShape(id=self.book1.id, title="Book One", pages=100), + ] + results = BookShape.diff_many(items) + + self.assertEqual(len(results), 2) + new_diffs = [(item, diff) for item, diff in results if diff.is_new] + existing_diffs = [(item, diff) for item, diff in results if not diff.is_new] + + self.assertEqual(len(new_diffs), 1) + self.assertEqual(new_diffs[0][0].title, "Brand New") + self.assertEqual(len(existing_diffs), 1) + + def test_diff_many_nonexistent_raises(self): + items = [ + BookShape(id=99999, title="Ghost", pages=0), + ] + with self.assertRaises(Book.DoesNotExist): + BookShape.diff_many(items) + + def test_diff_many_single_query(self): + """diff_many should use one query for all existing items, not N queries.""" + items = [ + BookShape(id=self.book1.id, title="A", pages=1), + BookShape(id=self.book2.id, title="B", pages=2), + BookShape(id=self.book3.id, title="C", pages=3), + ] + + with self.assertNumQueries(1): + BookShape.diff_many(items) + + def test_diff_many_empty_list(self): + results = BookShape.diff_many([]) + self.assertEqual(results, []) + + def test_diff_many_all_new(self): + items = [ + BookShape(title="New A", pages=10), + BookShape(title="New B", pages=20), + ] + results = BookShape.diff_many(items) + + self.assertEqual(len(results), 2) + for item, diff in results: + self.assertTrue(diff.is_new)