From 1a63bdb18059046219b33062ad856cf37be3324f Mon Sep 17 00:00:00 2001 From: Kang Li Date: Tue, 23 Sep 2025 10:58:53 +0800 Subject: [PATCH] Refactor DataStore error handling and update tests Replaced ValidationError with AssertionError for checksum validation in DataStore.scan, and updated test cases to expect AssertionError. Added ValidationError for corrupted record handling in DataStore iterator. Updated sync.py usage instructions for clarity. --- swanlab/data/porter/__init__.py | 4 ++-- swanlab/data/porter/datastore.py | 8 +++++--- test/sync/sync.py | 4 ++-- test/unit/data/porter/test_datastore.py | 16 ++++++++-------- 4 files changed, 17 insertions(+), 15 deletions(-) diff --git a/swanlab/data/porter/__init__.py b/swanlab/data/porter/__init__.py index 913077aa..e6e69bc9 100644 --- a/swanlab/data/porter/__init__.py +++ b/swanlab/data/porter/__init__.py @@ -18,7 +18,7 @@ import wrapt from swanlab.core_python import get_client -from swanlab.core_python.uploader import ColumnModel, ScalarModel, MediaModel, LogModel +from swanlab.core_python.uploader import ColumnModel, ScalarModel, MediaModel from swanlab.core_python.uploader.thread import ThreadPool, UploadType from swanlab.data.store import RunStore, get_run_store, reset_run_store from swanlab.error import ValidationError @@ -529,7 +529,7 @@ def _filter_media_by_step(self, metric: MediaModel) -> bool: """ return filter_metric(metric.key, metric.step, self._run_store.metrics) - def _filter_log_by_epoch(self, log: LogContent) -> LogModel: + def _filter_log_by_epoch(self, log: LogContent) -> bool: """ 筛选日志数据,排除已经上传的日志 :param log: 日志数据 diff --git a/swanlab/data/porter/datastore.py b/swanlab/data/porter/datastore.py index abfc2fea..df52dd0b 100644 --- a/swanlab/data/porter/datastore.py +++ b/swanlab/data/porter/datastore.py @@ -106,8 +106,7 @@ def _scan_record(self) -> Optional[Tuple[int, bytes]]: self._index += LEVELDBLOG_HEADER_LEN data = self._fp.read(data_length) checksum_computed = zlib.crc32(data, self._crc[data_type]) & 0xFFFFFFFF - if checksum != checksum_computed: - raise ValidationError("Invalid record checksum, data may be corrupt") + assert checksum == checksum_computed, "Invalid record checksum, data may be corrupt" self._index += data_length # 3. 返回数据 return int(data_type), data @@ -154,7 +153,10 @@ def __iter__(self): return self def __next__(self): - record = self.scan() + try: + record = self.scan() + except AssertionError: + raise ValidationError("One record is corrupted, cannot continue scanning") if record is None: raise StopIteration("End of file reached") return record diff --git a/test/sync/sync.py b/test/sync/sync.py index 6bd1d715..a27e4726 100644 --- a/test/sync/sync.py +++ b/test/sync/sync.py @@ -3,8 +3,8 @@ @file: sync.py @time: 2025/7/20 17:23 @description: 测试同步功能,与同目录下的sync.sh文件配合使用,使用方式为: -1. python test/sync/sync.py: 先运行 sync.sh,此时会要求输入 run_dir -2. bash ./test/sync/sync.s: 然后在项目根目录下运行此脚本,他会打印出当前实验的路径,复制此路径到步骤一,按回车 +1. python test/sync/sync.py: 先运行 sync.py,此时会提示输出的日志文件夹路径,复制此路径 +2. bash ./test/sync/sync.sh: 然后在项目根目录下运行此脚本,将上一步复制的路径粘贴到命令行中,回车进行下一步 这样将实现每五秒钟同步一次数据到云端 """ diff --git a/test/unit/data/porter/test_datastore.py b/test/unit/data/porter/test_datastore.py index 3959613d..8404b233 100644 --- a/test/unit/data/porter/test_datastore.py +++ b/test/unit/data/porter/test_datastore.py @@ -61,8 +61,8 @@ def test_tampered_data_content(temp_log_file): # 尝试读取应触发异常 ds = DataStore() ds.open_for_scan(str(temp_log_file)) - with pytest.raises(ValidationError, match="Invalid record checksum"): - next(ds) # 读取第一条记录 + with pytest.raises(AssertionError, match="Invalid record checksum"): + ds.scan() # 读取第一条记录 def test_tampered_checksum(temp_log_file): @@ -74,8 +74,8 @@ def test_tampered_checksum(temp_log_file): ds = DataStore() ds.open_for_scan(str(temp_log_file)) - with pytest.raises(ValidationError, match="Invalid record checksum"): - next(ds) + with pytest.raises(AssertionError, match="Invalid record checksum"): + ds.scan() def test_tampered_record_type(temp_log_file): @@ -87,8 +87,8 @@ def test_tampered_record_type(temp_log_file): ds = DataStore() ds.open_for_scan(str(temp_log_file)) - with pytest.raises(ValidationError, match="Invalid record checksum"): - next(ds) + with pytest.raises(AssertionError, match="Invalid record checksum"): + ds.scan() def test_tampered_multi_block_record(tmp_path): @@ -113,7 +113,7 @@ def test_tampered_multi_block_record(tmp_path): # 验证读取时触发异常 ds = DataStore() ds.open_for_scan(str(temp_log_file)) - with pytest.raises(ValidationError, match="Invalid record checksum"): + with pytest.raises(ValidationError, match="One record is corrupted, cannot continue scanning"): next(ds) @@ -139,7 +139,7 @@ def test_valid_records_after_tampered_one(temp_log_file): ds.open_for_scan(str(temp_log_file)) # 第一条记录应失败 - with pytest.raises(ValidationError): + with pytest.raises(ValidationError, match="One record is corrupted, cannot continue scanning"): next(ds) # 第二条记录应正常读取