Skip to content

Commit 2189962

Browse files
atharva-2001Noah Negin-Ulster
authored and
Noah Negin-Ulster
committed
refactor: scaffolding to support custom context in extensions
NOTE: Since syrupy v4 migrated from instance methods to classmethods, this new context is not actual usable. This lays the groundwork for a switch back to instance methods though (if we continue along this path).
1 parent b240712 commit 2189962

File tree

6 files changed

+86
-21
lines changed

6 files changed

+86
-21
lines changed

src/syrupy/assertion.py

Lines changed: 24 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -64,6 +64,9 @@ class SnapshotAssertion:
6464
exclude: Optional["PropertyFilter"] = None
6565
matcher: Optional["PropertyMatcher"] = None
6666

67+
# context is reserved exclusively for custom extensions
68+
context: Optional[Dict[str, Any]] = None
69+
6770
_exclude: Optional["PropertyFilter"] = field(
6871
init=False,
6972
default=None,
@@ -109,7 +112,8 @@ def __post_init__(self) -> None:
109112
def __init_extension(
110113
self, extension_class: Type["AbstractSyrupyExtension"]
111114
) -> "AbstractSyrupyExtension":
112-
return extension_class()
115+
kwargs = {"context": self.context} if self.context else {}
116+
return extension_class(**kwargs)
113117

114118
@property
115119
def extension(self) -> "AbstractSyrupyExtension":
@@ -178,6 +182,7 @@ def with_defaults(
178182
include: Optional["PropertyFilter"] = None,
179183
matcher: Optional["PropertyMatcher"] = None,
180184
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
185+
context: Optional[Dict[str, Any]] = None,
181186
) -> "SnapshotAssertion":
182187
"""
183188
Create new snapshot assertion fixture with provided values. This preserves
@@ -191,6 +196,7 @@ def with_defaults(
191196
test_location=self.test_location,
192197
extension_class=extension_class or self.extension_class,
193198
session=self.session,
199+
context=context or self.context,
194200
)
195201

196202
def use_extension(
@@ -207,7 +213,10 @@ def assert_match(self, data: "SerializableData") -> None:
207213

208214
def _serialize(self, data: "SerializableData") -> "SerializedData":
209215
return self.extension.serialize(
210-
data, exclude=self._exclude, include=self._include, matcher=self.__matcher
216+
data,
217+
exclude=self._exclude,
218+
include=self._include,
219+
matcher=self.__matcher,
211220
)
212221

213222
def get_assert_diff(self) -> List[str]:
@@ -264,6 +273,7 @@ def __call__(
264273
extension_class: Optional[Type["AbstractSyrupyExtension"]] = None,
265274
matcher: Optional["PropertyMatcher"] = None,
266275
name: Optional["SnapshotIndex"] = None,
276+
context: Optional[Dict[str, Any]] = None,
267277
) -> "SnapshotAssertion":
268278
"""
269279
Modifies assertion instance options
@@ -272,14 +282,18 @@ def __call__(
272282
self.__with_prop("_exclude", exclude)
273283
if include:
274284
self.__with_prop("_include", include)
275-
if extension_class:
276-
self.__with_prop("_extension", self.__init_extension(extension_class))
277285
if matcher:
278286
self.__with_prop("_matcher", matcher)
279287
if name:
280288
self.__with_prop("_custom_index", name)
281289
if diff is not None:
282290
self.__with_prop("_snapshot_diff", diff)
291+
if context and context != self.context:
292+
self.__with_prop("context", context)
293+
# We need to force the extension to be re-initialized if the context changes
294+
extension_class = extension_class or self.extension_class
295+
if extension_class:
296+
self.__with_prop("_extension", self.__init_extension(extension_class))
283297
return self
284298

285299
def __repr__(self) -> str:
@@ -290,10 +304,12 @@ def __eq__(self, other: "SerializableData") -> bool:
290304

291305
def _assert(self, data: "SerializableData") -> bool:
292306
snapshot_location = self.extension.get_location(
293-
test_location=self.test_location, index=self.index
307+
test_location=self.test_location,
308+
index=self.index,
294309
)
295310
snapshot_name = self.extension.get_snapshot_name(
296-
test_location=self.test_location, index=self.index
311+
test_location=self.test_location,
312+
index=self.index,
297313
)
298314
snapshot_data: Optional["SerializedData"] = None
299315
serialized_data: Optional["SerializedData"] = None
@@ -316,7 +332,8 @@ def _assert(self, data: "SerializableData") -> bool:
316332
not tainted
317333
and snapshot_data is not None
318334
and self.extension.matches(
319-
serialized_data=serialized_data, snapshot_data=snapshot_data
335+
serialized_data=serialized_data,
336+
snapshot_data=snapshot_data,
320337
)
321338
)
322339
assertion_success = matches

src/syrupy/extensions/base.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -80,7 +80,10 @@ class SnapshotCollectionStorage(ABC):
8080

8181
@classmethod
8282
def get_snapshot_name(
83-
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0
83+
cls,
84+
*,
85+
test_location: "PyTestLocation",
86+
index: "SnapshotIndex" = 0,
8487
) -> str:
8588
"""Get the snapshot name for the assertion index in a test location"""
8689
index_suffix = ""
@@ -225,7 +228,11 @@ def _read_snapshot_collection(
225228

226229
@abstractmethod
227230
def _read_snapshot_data_from_location(
228-
self, *, snapshot_location: str, snapshot_name: str, session_id: str
231+
self,
232+
*,
233+
snapshot_location: str,
234+
snapshot_name: str,
235+
session_id: str,
229236
) -> Optional["SerializedData"]:
230237
"""
231238
Get only the snapshot data from location for assertion
@@ -259,15 +266,19 @@ class SnapshotReporter(ABC):
259266
_context_line_count = 1
260267

261268
def diff_snapshots(
262-
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
269+
self,
270+
serialized_data: "SerializedData",
271+
snapshot_data: "SerializedData",
263272
) -> "SerializedData":
264273
env = {DISABLE_COLOR_ENV_VAR: "true"}
265274
attrs = {"_context_line_count": 0}
266275
with env_context(**env), obj_attrs(self, attrs):
267276
return "\n".join(self.diff_lines(serialized_data, snapshot_data))
268277

269278
def diff_lines(
270-
self, serialized_data: "SerializedData", snapshot_data: "SerializedData"
279+
self,
280+
serialized_data: "SerializedData",
281+
snapshot_data: "SerializedData",
271282
) -> Iterator[str]:
272283
for line in self.__diff_lines(str(snapshot_data), str(serialized_data)):
273284
yield reset(line)

src/syrupy/extensions/json/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -145,6 +145,7 @@ def serialize(
145145
exclude: Optional["PropertyFilter"] = None,
146146
include: Optional["PropertyFilter"] = None,
147147
matcher: Optional["PropertyMatcher"] = None,
148+
**kwargs: Any,
148149
) -> "SerializedData":
149150
data = self._filter(
150151
data=data,

src/syrupy/extensions/single_file.py

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -54,7 +54,10 @@ def serialize(
5454

5555
@classmethod
5656
def get_snapshot_name(
57-
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0
57+
cls,
58+
*,
59+
test_location: "PyTestLocation",
60+
index: "SnapshotIndex" = 0,
5861
) -> str:
5962
return cls.__clean_filename(
6063
AbstractSyrupyExtension.get_snapshot_name(
@@ -79,7 +82,9 @@ def dirname(cls, *, test_location: "PyTestLocation") -> str:
7982
return str(Path(original_dirname).joinpath(test_location.basename))
8083

8184
def _read_snapshot_collection(
82-
self, *, snapshot_location: str
85+
self,
86+
*,
87+
snapshot_location: str,
8388
) -> "SnapshotCollection":
8489
file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0
8590
filename_wo_ext = snapshot_location[:-file_ext_len]
@@ -90,7 +95,11 @@ def _read_snapshot_collection(
9095
return snapshot_collection
9196

9297
def _read_snapshot_data_from_location(
93-
self, *, snapshot_location: str, snapshot_name: str, session_id: str
98+
self,
99+
*,
100+
snapshot_location: str,
101+
snapshot_name: str,
102+
session_id: str,
94103
) -> Optional["SerializableData"]:
95104
try:
96105
with open(
@@ -116,7 +125,9 @@ def get_write_encoding(cls) -> Optional[str]:
116125

117126
@classmethod
118127
def _write_snapshot_collection(
119-
cls, *, snapshot_collection: "SnapshotCollection"
128+
cls,
129+
*,
130+
snapshot_collection: "SnapshotCollection",
120131
) -> None:
121132
filepath, data = (
122133
snapshot_collection.location,

src/syrupy/session.py

Lines changed: 27 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import pickle
12
from collections import defaultdict
23
from dataclasses import (
34
dataclass,
@@ -54,7 +55,7 @@ class SnapshotSession:
5455
)
5556

5657
_queued_snapshot_writes: Dict[
57-
Tuple[Type["AbstractSyrupyExtension"], str],
58+
Tuple[Type["AbstractSyrupyExtension"], Optional[bytes], str],
5859
List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]],
5960
] = field(default_factory=dict)
6061

@@ -68,19 +69,41 @@ def queue_snapshot_write(
6869
snapshot_location = extension.get_location(
6970
test_location=test_location, index=index
7071
)
71-
key = (extension.__class__, snapshot_location)
72+
73+
extension_context = getattr(extension, "context", None)
74+
75+
try:
76+
extension_kwargs_bytes = (
77+
pickle.dumps(extension_context) if extension_context else None
78+
)
79+
except pickle.PicklingError:
80+
print("Extension context must be serializable.")
81+
raise
82+
83+
key = (extension.__class__, extension_kwargs_bytes, snapshot_location)
7284
queue = self._queued_snapshot_writes.get(key, [])
7385
queue.append((data, test_location, index))
7486
self._queued_snapshot_writes[key] = queue
7587

7688
def flush_snapshot_write_queue(self) -> None:
7789
for (
7890
extension_class,
91+
extension_kwargs_bytes,
7992
snapshot_location,
8093
), queued_write in self._queued_snapshot_writes.items():
8194
if queued_write:
82-
extension_class.write_snapshot(
83-
snapshot_location=snapshot_location, snapshots=queued_write
95+
# It's possible to instantiate an extension with context. We need to
96+
# ensure we never lose context between instantiations (since we may
97+
# instantiate multiple times in a test session).
98+
extension_kwargs = (
99+
{"context": pickle.loads(extension_kwargs_bytes)}
100+
if extension_kwargs_bytes
101+
else {}
102+
)
103+
extension = extension_class(**extension_kwargs)
104+
extension.write_snapshot(
105+
snapshot_location=snapshot_location,
106+
snapshots=queued_write,
84107
)
85108
self._queued_snapshot_writes = {}
86109

tests/examples/test_custom_snapshot_name.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""
22
Example: Custom Snapshot Name
33
"""
4+
from typing import Any
5+
46
import pytest
57

68
from syrupy.extensions.amber import AmberSnapshotExtension
@@ -11,10 +13,10 @@
1113
class CanadianNameExtension(AmberSnapshotExtension):
1214
@classmethod
1315
def get_snapshot_name(
14-
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex"
16+
cls, *, test_location: "PyTestLocation", index: "SnapshotIndex", **kwargs: Any
1517
) -> str:
1618
original_name = AmberSnapshotExtension.get_snapshot_name(
17-
test_location=test_location, index=index
19+
test_location=test_location, index=index, **kwargs
1820
)
1921
return f"{original_name}🇨🇦"
2022

0 commit comments

Comments
 (0)