Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions swanlab/data/porter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import wrapt

from swanlab.core_python import get_client
from swanlab.core_python.uploader import ColumnModel, ScalarModel, MediaModel, LogModel
from swanlab.core_python.uploader import ColumnModel, ScalarModel, MediaModel
from swanlab.core_python.uploader.thread import ThreadPool, UploadType
from swanlab.data.store import RunStore, get_run_store, reset_run_store
from swanlab.error import ValidationError
Expand Down Expand Up @@ -529,7 +529,7 @@ def _filter_media_by_step(self, metric: MediaModel) -> bool:
"""
return filter_metric(metric.key, metric.step, self._run_store.metrics)

def _filter_log_by_epoch(self, log: LogContent) -> LogModel:
def _filter_log_by_epoch(self, log: LogContent) -> bool:
"""
筛选日志数据,排除已经上传的日志
:param log: 日志数据
Expand Down
8 changes: 5 additions & 3 deletions swanlab/data/porter/datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,8 +106,7 @@ def _scan_record(self) -> Optional[Tuple[int, bytes]]:
self._index += LEVELDBLOG_HEADER_LEN
data = self._fp.read(data_length)
checksum_computed = zlib.crc32(data, self._crc[data_type]) & 0xFFFFFFFF
if checksum != checksum_computed:
raise ValidationError("Invalid record checksum, data may be corrupt")
assert checksum == checksum_computed, "Invalid record checksum, data may be corrupt"
self._index += data_length
# 3. 返回数据
return int(data_type), data
Expand Down Expand Up @@ -154,7 +153,10 @@ def __iter__(self):
return self

def __next__(self):
record = self.scan()
try:
record = self.scan()
except AssertionError:
raise ValidationError("One record is corrupted, cannot continue scanning")
if record is None:
raise StopIteration("End of file reached")
return record
Expand Down
4 changes: 2 additions & 2 deletions test/sync/sync.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,8 @@
@file: sync.py
@time: 2025/7/20 17:23
@description: 测试同步功能,与同目录下的sync.sh文件配合使用,使用方式为:
1. python test/sync/sync.py: 先运行 sync.sh,此时会要求输入 run_dir
2. bash ./test/sync/sync.s: 然后在项目根目录下运行此脚本,他会打印出当前实验的路径,复制此路径到步骤一,按回车
1. python test/sync/sync.py: 先运行 sync.py,此时会提示输出的日志文件夹路径,复制此路径
2. bash ./test/sync/sync.sh: 然后在项目根目录下运行此脚本,将上一步复制的路径粘贴到命令行中,回车进行下一步
这样将实现每五秒钟同步一次数据到云端
"""

Expand Down
16 changes: 8 additions & 8 deletions test/unit/data/porter/test_datastore.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,8 +61,8 @@ def test_tampered_data_content(temp_log_file):
# 尝试读取应触发异常
ds = DataStore()
ds.open_for_scan(str(temp_log_file))
with pytest.raises(ValidationError, match="Invalid record checksum"):
next(ds) # 读取第一条记录
with pytest.raises(AssertionError, match="Invalid record checksum"):
ds.scan() # 读取第一条记录


def test_tampered_checksum(temp_log_file):
Expand All @@ -74,8 +74,8 @@ def test_tampered_checksum(temp_log_file):

ds = DataStore()
ds.open_for_scan(str(temp_log_file))
with pytest.raises(ValidationError, match="Invalid record checksum"):
next(ds)
with pytest.raises(AssertionError, match="Invalid record checksum"):
ds.scan()


def test_tampered_record_type(temp_log_file):
Expand All @@ -87,8 +87,8 @@ def test_tampered_record_type(temp_log_file):

ds = DataStore()
ds.open_for_scan(str(temp_log_file))
with pytest.raises(ValidationError, match="Invalid record checksum"):
next(ds)
with pytest.raises(AssertionError, match="Invalid record checksum"):
ds.scan()


def test_tampered_multi_block_record(tmp_path):
Expand All @@ -113,7 +113,7 @@ def test_tampered_multi_block_record(tmp_path):
# 验证读取时触发异常
ds = DataStore()
ds.open_for_scan(str(temp_log_file))
with pytest.raises(ValidationError, match="Invalid record checksum"):
with pytest.raises(ValidationError, match="One record is corrupted, cannot continue scanning"):
next(ds)


Expand All @@ -139,7 +139,7 @@ def test_valid_records_after_tampered_one(temp_log_file):
ds.open_for_scan(str(temp_log_file))

# 第一条记录应失败
with pytest.raises(ValidationError):
with pytest.raises(ValidationError, match="One record is corrupted, cannot continue scanning"):
next(ds)

# 第二条记录应正常读取
Expand Down