diff --git a/src/syrupy/assertion.py b/src/syrupy/assertion.py index 6bdb4fcf..aeb83eef 100644 --- a/src/syrupy/assertion.py +++ b/src/syrupy/assertion.py @@ -64,6 +64,9 @@ class SnapshotAssertion: exclude: Optional["PropertyFilter"] = None matcher: Optional["PropertyMatcher"] = None + # context is reserved exclusively for custom extensions + context: Optional[Dict[str, Any]] = None + _exclude: Optional["PropertyFilter"] = field( init=False, default=None, @@ -109,7 +112,8 @@ def __post_init__(self) -> None: def __init_extension( self, extension_class: Type["AbstractSyrupyExtension"] ) -> "AbstractSyrupyExtension": - return extension_class() + kwargs = {"context": self.context} if self.context else {} + return extension_class(**kwargs) @property def extension(self) -> "AbstractSyrupyExtension": @@ -178,6 +182,7 @@ def with_defaults( include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, + context: Optional[Dict[str, Any]] = None, ) -> "SnapshotAssertion": """ Create new snapshot assertion fixture with provided values. This preserves @@ -191,6 +196,7 @@ def with_defaults( test_location=self.test_location, extension_class=extension_class or self.extension_class, session=self.session, + context=context or self.context, ) def use_extension( @@ -207,7 +213,10 @@ def assert_match(self, data: "SerializableData") -> None: def _serialize(self, data: "SerializableData") -> "SerializedData": return self.extension.serialize( - data, exclude=self._exclude, include=self._include, matcher=self.__matcher + data, + exclude=self._exclude, + include=self._include, + matcher=self.__matcher, ) def get_assert_diff(self) -> List[str]: @@ -264,6 +273,7 @@ def __call__( extension_class: Optional[Type["AbstractSyrupyExtension"]] = None, matcher: Optional["PropertyMatcher"] = None, name: Optional["SnapshotIndex"] = None, + context: Optional[Dict[str, Any]] = None, ) -> "SnapshotAssertion": """ Modifies assertion instance options @@ -272,14 +282,18 @@ def __call__( self.__with_prop("_exclude", exclude) if include: self.__with_prop("_include", include) - if extension_class: - self.__with_prop("_extension", self.__init_extension(extension_class)) if matcher: self.__with_prop("_matcher", matcher) if name: self.__with_prop("_custom_index", name) if diff is not None: self.__with_prop("_snapshot_diff", diff) + if context and context != self.context: + self.__with_prop("context", context) + # We need to force the extension to be re-initialized if the context changes + extension_class = extension_class or self.extension_class + if extension_class: + self.__with_prop("_extension", self.__init_extension(extension_class)) return self def __repr__(self) -> str: @@ -290,10 +304,12 @@ def __eq__(self, other: "SerializableData") -> bool: def _assert(self, data: "SerializableData") -> bool: snapshot_location = self.extension.get_location( - test_location=self.test_location, index=self.index + test_location=self.test_location, + index=self.index, ) snapshot_name = self.extension.get_snapshot_name( - test_location=self.test_location, index=self.index + test_location=self.test_location, + index=self.index, ) snapshot_data: Optional["SerializedData"] = None serialized_data: Optional["SerializedData"] = None @@ -316,7 +332,8 @@ def _assert(self, data: "SerializableData") -> bool: not tainted and snapshot_data is not None and self.extension.matches( - serialized_data=serialized_data, snapshot_data=snapshot_data + serialized_data=serialized_data, + snapshot_data=snapshot_data, ) ) assertion_success = matches diff --git a/src/syrupy/extensions/base.py b/src/syrupy/extensions/base.py index 945cf20b..acafd4c7 100644 --- a/src/syrupy/extensions/base.py +++ b/src/syrupy/extensions/base.py @@ -80,7 +80,10 @@ class SnapshotCollectionStorage(ABC): @classmethod def get_snapshot_name( - cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0 + cls, + *, + test_location: "PyTestLocation", + index: "SnapshotIndex" = 0, ) -> str: """Get the snapshot name for the assertion index in a test location""" index_suffix = "" @@ -225,7 +228,11 @@ def _read_snapshot_collection( @abstractmethod def _read_snapshot_data_from_location( - self, *, snapshot_location: str, snapshot_name: str, session_id: str + self, + *, + snapshot_location: str, + snapshot_name: str, + session_id: str, ) -> Optional["SerializedData"]: """ Get only the snapshot data from location for assertion @@ -259,7 +266,9 @@ class SnapshotReporter(ABC): _context_line_count = 1 def diff_snapshots( - self, serialized_data: "SerializedData", snapshot_data: "SerializedData" + self, + serialized_data: "SerializedData", + snapshot_data: "SerializedData", ) -> "SerializedData": env = {DISABLE_COLOR_ENV_VAR: "true"} attrs = {"_context_line_count": 0} @@ -267,7 +276,9 @@ def diff_snapshots( return "\n".join(self.diff_lines(serialized_data, snapshot_data)) def diff_lines( - self, serialized_data: "SerializedData", snapshot_data: "SerializedData" + self, + serialized_data: "SerializedData", + snapshot_data: "SerializedData", ) -> Iterator[str]: for line in self.__diff_lines(str(snapshot_data), str(serialized_data)): yield reset(line) diff --git a/src/syrupy/extensions/json/__init__.py b/src/syrupy/extensions/json/__init__.py index 5b52a8d5..ccc9488e 100644 --- a/src/syrupy/extensions/json/__init__.py +++ b/src/syrupy/extensions/json/__init__.py @@ -145,6 +145,7 @@ def serialize( exclude: Optional["PropertyFilter"] = None, include: Optional["PropertyFilter"] = None, matcher: Optional["PropertyMatcher"] = None, + **kwargs: Any, ) -> "SerializedData": data = self._filter( data=data, diff --git a/src/syrupy/extensions/single_file.py b/src/syrupy/extensions/single_file.py index 0b216115..90d2ad3c 100644 --- a/src/syrupy/extensions/single_file.py +++ b/src/syrupy/extensions/single_file.py @@ -54,7 +54,10 @@ def serialize( @classmethod def get_snapshot_name( - cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" = 0 + cls, + *, + test_location: "PyTestLocation", + index: "SnapshotIndex" = 0, ) -> str: return cls.__clean_filename( AbstractSyrupyExtension.get_snapshot_name( @@ -79,7 +82,9 @@ def dirname(cls, *, test_location: "PyTestLocation") -> str: return str(Path(original_dirname).joinpath(test_location.basename)) def _read_snapshot_collection( - self, *, snapshot_location: str + self, + *, + snapshot_location: str, ) -> "SnapshotCollection": file_ext_len = len(self._file_extension) + 1 if self._file_extension else 0 filename_wo_ext = snapshot_location[:-file_ext_len] @@ -90,7 +95,11 @@ def _read_snapshot_collection( return snapshot_collection def _read_snapshot_data_from_location( - self, *, snapshot_location: str, snapshot_name: str, session_id: str + self, + *, + snapshot_location: str, + snapshot_name: str, + session_id: str, ) -> Optional["SerializableData"]: try: with open( @@ -116,7 +125,9 @@ def get_write_encoding(cls) -> Optional[str]: @classmethod def _write_snapshot_collection( - cls, *, snapshot_collection: "SnapshotCollection" + cls, + *, + snapshot_collection: "SnapshotCollection", ) -> None: filepath, data = ( snapshot_collection.location, diff --git a/src/syrupy/session.py b/src/syrupy/session.py index 6b612145..941c756f 100644 --- a/src/syrupy/session.py +++ b/src/syrupy/session.py @@ -1,3 +1,4 @@ +import pickle from collections import defaultdict from dataclasses import ( dataclass, @@ -54,7 +55,7 @@ class SnapshotSession: ) _queued_snapshot_writes: Dict[ - Tuple[Type["AbstractSyrupyExtension"], str], + Tuple[Type["AbstractSyrupyExtension"], Optional[bytes], str], List[Tuple["SerializedData", "PyTestLocation", "SnapshotIndex"]], ] = field(default_factory=dict) @@ -68,7 +69,18 @@ def queue_snapshot_write( snapshot_location = extension.get_location( test_location=test_location, index=index ) - key = (extension.__class__, snapshot_location) + + extension_context = getattr(extension, "context", None) + + try: + extension_kwargs_bytes = ( + pickle.dumps(extension_context) if extension_context else None + ) + except pickle.PicklingError: + print("Extension context must be serializable.") + raise + + key = (extension.__class__, extension_kwargs_bytes, snapshot_location) queue = self._queued_snapshot_writes.get(key, []) queue.append((data, test_location, index)) self._queued_snapshot_writes[key] = queue @@ -76,11 +88,22 @@ def queue_snapshot_write( def flush_snapshot_write_queue(self) -> None: for ( extension_class, + extension_kwargs_bytes, snapshot_location, ), queued_write in self._queued_snapshot_writes.items(): if queued_write: - extension_class.write_snapshot( - snapshot_location=snapshot_location, snapshots=queued_write + # It's possible to instantiate an extension with context. We need to + # ensure we never lose context between instantiations (since we may + # instantiate multiple times in a test session). + extension_kwargs = ( + {"context": pickle.loads(extension_kwargs_bytes)} + if extension_kwargs_bytes + else {} + ) + extension = extension_class(**extension_kwargs) + extension.write_snapshot( + snapshot_location=snapshot_location, + snapshots=queued_write, ) self._queued_snapshot_writes = {} diff --git a/tests/examples/test_custom_snapshot_name.py b/tests/examples/test_custom_snapshot_name.py index c627f8c7..8cae106e 100644 --- a/tests/examples/test_custom_snapshot_name.py +++ b/tests/examples/test_custom_snapshot_name.py @@ -1,6 +1,8 @@ """ Example: Custom Snapshot Name """ +from typing import Any + import pytest from syrupy.extensions.amber import AmberSnapshotExtension @@ -11,10 +13,10 @@ class CanadianNameExtension(AmberSnapshotExtension): @classmethod def get_snapshot_name( - cls, *, test_location: "PyTestLocation", index: "SnapshotIndex" + cls, *, test_location: "PyTestLocation", index: "SnapshotIndex", **kwargs: Any ) -> str: original_name = AmberSnapshotExtension.get_snapshot_name( - test_location=test_location, index=index + test_location=test_location, index=index, **kwargs ) return f"{original_name}🇨🇦"