Skip to content

Commit db1896d

Browse files
authored
Add *args, **kwargs to read data method in reader interface (#10)
1 parent 353e125 commit db1896d

File tree

4 files changed

+4
-4
lines changed

4 files changed

+4
-4
lines changed

dataset_reader/ann_h5_multi_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def read_queries(self) -> Iterator[Query]:
5252
)
5353

5454
def read_data(
55-
self, start_idx: int = 0, end_idx: int = None, chunk_size: int = 10_000
55+
self, start_idx: int = 0, end_idx: int = None, chunk_size: int = 10_000, *args, **kwargs
5656
) -> Iterator[Record]:
5757
"""
5858
Reads the 'train' data vectors from multiple HDF5 files based on the specified range.

dataset_reader/ann_h5_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ def read_queries(self) -> Iterator[Query]:
2727
expected_scores=expected_scores.tolist(),
2828
)
2929

30-
def read_data(self) -> Iterator[Record]:
30+
def read_data(self, *args, **kwargs) -> Iterator[Record]:
3131
data = h5py.File(self.path)
3232

3333
for idx, vector in enumerate(data["train"]):

dataset_reader/base_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@ class Query:
1818

1919

2020
class BaseReader:
21-
def read_data(self) -> Iterator[Record]:
21+
def read_data(self, *args, **kwargs) -> Iterator[Record]:
2222
raise NotImplementedError()
2323

2424
def read_queries(self) -> Iterator[Query]:

dataset_reader/json_reader.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ def read_queries(self) -> Iterator[Query]:
6060

6161
yield Query(vector=vector, meta_conditions=None, expected_result=neighbours)
6262

63-
def read_data(self) -> Iterator[Record]:
63+
def read_data(self, *args, **kwargs) -> Iterator[Record]:
6464
for idx, (vector, payload) in enumerate(
6565
zip(self.read_vectors(), self.read_payloads())
6666
):

0 commit comments

Comments
 (0)