diff --git a/swanlab/data/porter/__init__.py b/swanlab/data/porter/__init__.py index 913077aa..e6e69bc9 100644 --- a/swanlab/data/porter/__init__.py +++ b/swanlab/data/porter/__init__.py @@ -18,7 +18,7 @@ import wrapt from swanlab.core_python import get_client -from swanlab.core_python.uploader import ColumnModel, ScalarModel, MediaModel, LogModel +from swanlab.core_python.uploader import ColumnModel, ScalarModel, MediaModel from swanlab.core_python.uploader.thread import ThreadPool, UploadType from swanlab.data.store import RunStore, get_run_store, reset_run_store from swanlab.error import ValidationError @@ -529,7 +529,7 @@ def _filter_media_by_step(self, metric: MediaModel) -> bool: """ return filter_metric(metric.key, metric.step, self._run_store.metrics) - def _filter_log_by_epoch(self, log: LogContent) -> LogModel: + def _filter_log_by_epoch(self, log: LogContent) -> bool: """ 筛选日志数据,排除已经上传的日志 :param log: 日志数据 diff --git a/swanlab/data/porter/datastore.py b/swanlab/data/porter/datastore.py index abfc2fea..df52dd0b 100644 --- a/swanlab/data/porter/datastore.py +++ b/swanlab/data/porter/datastore.py @@ -106,8 +106,7 @@ def _scan_record(self) -> Optional[Tuple[int, bytes]]: self._index += LEVELDBLOG_HEADER_LEN data = self._fp.read(data_length) checksum_computed = zlib.crc32(data, self._crc[data_type]) & 0xFFFFFFFF - if checksum != checksum_computed: - raise ValidationError("Invalid record checksum, data may be corrupt") + assert checksum == checksum_computed, "Invalid record checksum, data may be corrupt" self._index += data_length # 3. 返回数据 return int(data_type), data @@ -154,7 +153,10 @@ def __iter__(self): return self def __next__(self): - record = self.scan() + try: + record = self.scan() + except AssertionError: + raise ValidationError("One record is corrupted, cannot continue scanning") if record is None: raise StopIteration("End of file reached") return record diff --git a/test/sync/sync.py b/test/sync/sync.py index 6bd1d715..a27e4726 100644 --- a/test/sync/sync.py +++ b/test/sync/sync.py @@ -3,8 +3,8 @@ @file: sync.py @time: 2025/7/20 17:23 @description: 测试同步功能,与同目录下的sync.sh文件配合使用,使用方式为: -1. python test/sync/sync.py: 先运行 sync.sh,此时会要求输入 run_dir -2. bash ./test/sync/sync.s: 然后在项目根目录下运行此脚本,他会打印出当前实验的路径,复制此路径到步骤一,按回车 +1. python test/sync/sync.py: 先运行 sync.py,此时会提示输出的日志文件夹路径,复制此路径 +2. bash ./test/sync/sync.sh: 然后在项目根目录下运行此脚本,将上一步复制的路径粘贴到命令行中,回车进行下一步 这样将实现每五秒钟同步一次数据到云端 """ diff --git a/test/unit/data/porter/test_datastore.py b/test/unit/data/porter/test_datastore.py index 3959613d..8404b233 100644 --- a/test/unit/data/porter/test_datastore.py +++ b/test/unit/data/porter/test_datastore.py @@ -61,8 +61,8 @@ def test_tampered_data_content(temp_log_file): # 尝试读取应触发异常 ds = DataStore() ds.open_for_scan(str(temp_log_file)) - with pytest.raises(ValidationError, match="Invalid record checksum"): - next(ds) # 读取第一条记录 + with pytest.raises(AssertionError, match="Invalid record checksum"): + ds.scan() # 读取第一条记录 def test_tampered_checksum(temp_log_file): @@ -74,8 +74,8 @@ def test_tampered_checksum(temp_log_file): ds = DataStore() ds.open_for_scan(str(temp_log_file)) - with pytest.raises(ValidationError, match="Invalid record checksum"): - next(ds) + with pytest.raises(AssertionError, match="Invalid record checksum"): + ds.scan() def test_tampered_record_type(temp_log_file): @@ -87,8 +87,8 @@ def test_tampered_record_type(temp_log_file): ds = DataStore() ds.open_for_scan(str(temp_log_file)) - with pytest.raises(ValidationError, match="Invalid record checksum"): - next(ds) + with pytest.raises(AssertionError, match="Invalid record checksum"): + ds.scan() def test_tampered_multi_block_record(tmp_path): @@ -113,7 +113,7 @@ def test_tampered_multi_block_record(tmp_path): # 验证读取时触发异常 ds = DataStore() ds.open_for_scan(str(temp_log_file)) - with pytest.raises(ValidationError, match="Invalid record checksum"): + with pytest.raises(ValidationError, match="One record is corrupted, cannot continue scanning"): next(ds) @@ -139,7 +139,7 @@ def test_valid_records_after_tampered_one(temp_log_file): ds.open_for_scan(str(temp_log_file)) # 第一条记录应失败 - with pytest.raises(ValidationError): + with pytest.raises(ValidationError, match="One record is corrupted, cannot continue scanning"): next(ds) # 第二条记录应正常读取