tensorstore_python_benchmark_read.py
#!/usr/bin/env python3
import asyncio
import timeit
from functools import wraps

import click
import numpy as np
import tensorstore as ts
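

# Helper so a click command can be written as an async function: the wrapper
# drives the coroutine to completion with asyncio.run().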
def coro(f):
    @wraps(f)
    def wrapper(*args, **kwargs):
        return asyncio.run(f(*args, **kwargs))
    return wrapper


@click.command()
@click.argument('path', type=str)
@click.option('--concurrent_chunks', type=int, default=None,
              help='Number of concurrent async chunk reads. Ignored if --read_all is set.')
@click.option('--read_all', is_flag=True, show_default=True, default=False,
              help='Read the entire array in one operation.')
@coro
async def main(path, concurrent_chunks, read_all):
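    # Choose the key-value store backend from the path: an http(s) URL goes
    # through TensorStore's 'http' driver, anything else is treated as a
    # local directory for the 'file' driver.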
    if path.startswith("http"):
        kvstore = {
            'driver': 'http',
            'base_url': path,
        }
    else:
        kvstore = {
            'driver': 'file',
            'path': path,
        }
    dataset_future = ts.open({
        'driver': 'zarr3',
        'kvstore': kvstore,
        # 'context': {
        #     'cache_pool': {
        #         'total_bytes_limit': 100_000_000
        #     }
        # },
        # 'recheck_cached_data': 'open',
    })
    dataset = dataset_future.result()
    print(dataset)
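
    # The write-chunk shape is the effective read granularity (the shard
    # shape for sharded arrays, otherwise the chunk shape). Chunks per
    # dimension is the ceiling of domain / chunk: (d + c - 1) // c.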
    domain_shape = dataset.domain.shape
    chunk_shape = dataset.chunk_layout.write_chunk.shape
    print("Domain shape", domain_shape)
    print("Chunk shape", chunk_shape)
    num_chunks = [(d + c - 1) // c for (d, c) in zip(domain_shape, chunk_shape)]
    print("Number of chunks", num_chunks)
    async def chunk_read(chunk_index):
        chunk_slice = [
            ts.Dim(inclusive_min=index * cshape,
                   exclusive_max=min(index * cshape + cshape, dshape))
            for (index, cshape, dshape) in zip(chunk_index, chunk_shape, domain_shape)
        ]
        # print("Reading", chunk_index)
        return await dataset[ts.IndexDomain(chunk_slice)].read()
    start_time = timeit.default_timer()
    if read_all:
        # Await the whole-array read instead of blocking on .result(),
        # so the event loop is not stalled inside the coroutine.
        print((await dataset.read()).shape)
    elif concurrent_chunks is None:
        # Unbounded concurrency: one task per chunk.
        # Note: asyncio.TaskGroup requires Python 3.11+.
        async with asyncio.TaskGroup() as tg:
            for chunk_index in np.ndindex(*num_chunks):
                tg.create_task(chunk_read(chunk_index))
    elif concurrent_chunks == 1:
        # Fully sequential: one chunk read at a time.
        for chunk_index in np.ndindex(*num_chunks):
            await chunk_read(chunk_index)
    else:
        # Bounded concurrency: a semaphore caps how many chunk reads are in
        # flight at once.
        # TODO: Not sure if this is the fastest API for this
        semaphore = asyncio.Semaphore(concurrent_chunks)

        async def chunk_read_concurrent_limit(chunk_index):
            async with semaphore:
                return await chunk_read(chunk_index)

        async with asyncio.TaskGroup() as tg:
            for chunk_index in np.ndindex(*num_chunks):
                tg.create_task(chunk_read_concurrent_limit(chunk_index))

    elapsed = timeit.default_timer() - start_time
    elapsed_ms = elapsed * 1000.0
    print(f"Decoded in {elapsed_ms:.2f}ms")


if __name__ == "__main__":
    # main is a click command whose @coro wrapper already calls asyncio.run,
    # so it is invoked directly rather than via asyncio.run(main()).
    main()
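
# Example invocations (hypothetical paths, assuming a zarr3 array at each
# location):
#   ./tensorstore_python_benchmark_read.py /data/array.zarr --concurrent_chunks 8
#   ./tensorstore_python_benchmark_read.py https://example.com/array.zarr --read_all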