Commit dd5c8ff

Merge pull request #22 from pfizer-opensource/feature/bfloat16
* Feature: add bfloat16 support * breaking change: output dimensionality is now: batch x sequence x tracks
2 parents 512aa3f + a8b6a1c commit dd5c8ff

29 files changed

+6738
-226
lines changed
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Check pre-commit hooks
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
workflow_dispatch:
8+
9+
env:
10+
FORCE_COLOR: "1"
11+
12+
jobs:
13+
tests-using-pixi:
14+
timeout-minutes: 10
15+
runs-on: ubuntu-latest
16+
steps:
17+
- name: Check out the repository
18+
uses: actions/checkout@v4
19+
with:
20+
fetch-depth: 0
21+
22+
- name: Install Pixi
23+
uses: prefix-dev/[email protected]
24+
with:
25+
pixi-version: "latest"
26+
run-install: false
27+
28+
- name: Install pre-commit
29+
run: pixi global install pre-commit
30+
31+
- name: install pre-commit hooks
32+
run: pre-commit install
33+
34+
- name: Run pre-commit hooks
35+
run: pre-commit run --all --show-diff-on-failure

.github/workflows/run_tests.yaml

Lines changed: 0 additions & 29 deletions
This file was deleted.

.github/workflows/tests.yml

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
name: Tests-With-Pixi
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
workflow_dispatch:
8+
9+
env:
10+
FORCE_COLOR: "1"
11+
12+
jobs:
13+
tests-using-pixi:
14+
timeout-minutes: 10
15+
runs-on: ubuntu-22.04-gpu-t4
16+
steps:
17+
- name: Check out the repository
18+
uses: actions/checkout@v4
19+
with:
20+
fetch-depth: 0
21+
22+
- name: run nvidia-smi
23+
run: nvidia-smi
24+
25+
- name: Install Pixi
26+
uses: prefix-dev/[email protected]
27+
with:
28+
pixi-version: "latest"
29+
run-install: false
30+
31+
- name: Install pre-commit
32+
run: pixi global install pre-commit
33+
34+
- name: Install dev environment using pixi
35+
run: pixi install --environment dev
36+
37+
- name: Run tests
38+
run: pixi run --environment dev test

.pre-commit-config.yaml

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -29,16 +29,16 @@ repos:
2929
hooks:
3030
- id: mypy
3131
additional_dependencies:
32-
- pydantic==2.1.1
33-
- pydantic-settings==2.0.3
32+
- pydantic==2.12.4
33+
- pydantic-settings==2.12.0
3434
exclude: "^(build|docs|tests|benchmark|examples)"
3535
- repo: https://github.com/asottile/pyupgrade
3636
rev: v3.10.1
3737
hooks:
3838
- id: pyupgrade
3939
args: [--py37-plus, --keep-runtime-typing]
4040
- repo: https://github.com/PyCQA/bandit
41-
rev: '1.7.5'
41+
rev: '1.8.6'
4242
hooks:
4343
- id: bandit
4444
args: [ "-c", "pyproject.toml" ]

README.md

Lines changed: 43 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,33 @@
33
Fast batched dataloading of BigWig files containing epigenetic track data and corresponding sequences powered by GPU
44
for deep learning applications.
55

6+
> ⚠️ **BREAKING CHANGE (v0.3.0+)**: The output matrix dimensionality has changed from `(n_tracks, batch_size, sequence_length)` to `(batch_size, sequence_length, n_tracks)`. This change was long overdue and eliminates the need for (potentially memory-expensive) transpose operations downstream. If you're upgrading from an earlier version, please update your code accordingly (you probably need to delete one transpose in your code).
7+
8+
> **NEW FEATURE (v0.3.0+)**: Full `bfloat16` support! You can now specify `dtype="bfloat16"` to get output tensors in bfloat16 format, halving the memory used by the output values compared to `float32`.
9+
10+
11+
12+
613
## Quickstart
714

15+
### Installation with Pixi
16+
Using [pixi](https://pixi.sh/) to install bigwig-loader is highly recommended.
17+
Please take a look at the pixi.toml file. If you just want to use bigwig-loader, just
18+
copy that pixi.toml, add the other libraries you need and use the "prod" environment
19+
(you don't need to clone this repo, pixi will download bigwig-loader from the
20+
conda "dataloading" channel):
21+
22+
* Install pixi, if not installed:
23+
```shell
24+
curl -fsSL https://pixi.sh/install.sh | sh
25+
```
26+
27+
* change directory to wherever you put the pixi.toml, and:
28+
```shell
29+
pixi run -e prod <my_training_command>
30+
```
31+
32+
833
### Installation with conda/mamba
934

1035
Bigwig-loader mainly depends on the rapidsai kvikio library and cupy, both of which are best installed using
@@ -65,16 +90,17 @@ dataset = PytorchBigWigDataset(
6590
regions_of_interest=train_regions,
6691
collection=example_bigwigs_directory,
6792
reference_genome_path=reference_genome_file,
68-
sequence_length=1000,
69-
center_bin_to_predict=500,
93+
sequence_length=1_000_000,
94+
center_bin_to_predict=500_000,
7095
window_size=1,
71-
batch_size=32,
72-
super_batch_size=1024,
73-
batches_per_epoch=20,
96+
batch_size=1,
97+
super_batch_size=4,
98+
batches_per_epoch=100,
7499
maximum_unknown_bases_fraction=0.1,
75100
sequence_encoder="onehot",
76101
n_threads=4,
77102
return_batch_objects=True,
103+
dtype="bfloat16"
78104
)
79105
80106
# Don't use num_workers > 0 in DataLoader. The heavy
@@ -88,7 +114,7 @@ class MyTerribleModel(torch.nn.Module):
88114
self.linear = torch.nn.Linear(4, 2)
89115
90116
def forward(self, batch):
91-
return self.linear(batch).transpose(1, 2)
117+
return self.linear(batch)
92118
93119
94120
model = MyTerribleModel()
@@ -98,10 +124,10 @@ def poisson_loss(pred, target):
98124
return (pred - target * torch.log(pred.clamp(min=1e-8))).mean()
99125
100126
for batch in dataloader:
101-
# batch.sequences.shape = n_batch (32), sequence_length (1000), onehot encoding (4)
127+
# batch.sequences.shape = n_batch x sequence_length x onehot encoding (4)
102128
pred = model(batch.sequences)
103-
# batch.values.shape = n_batch (32), n_tracks (2) center_bin_to_predict (500)
104-
loss = poisson_loss(pred[:, :, 250:750], batch.values)
129+
# batch.values.shape = n_batch x center_bin_to_predict x n_tracks
130+
loss = poisson_loss(pred[:, 250000:750000, :], batch.values)
105131
print(loss)
106132
optimizer.zero_grad()
107133
loss.backward()
@@ -166,19 +192,23 @@ anything is unclear, please open an issue.
166192
167193
### Environment
168194
195+
The pixi.toml includes a dev environment that has bigwig-loader installed
196+
as an editable pypi dependency.
197+
169198
1. `git clone git@github.com:pfizer-opensource/bigwig-loader`
170199
2. `cd bigwig-loader`
171-
3. create the conda environment" `conda env create -f environment.yml`
172-
4. `pip install -e '.[dev]'`
173-
5. run `pre-commit install` to install the pre-commit hooks
200+
3. optional: `pixi install -e dev`
201+
4. run `pre-commit install` to install the pre-commit hooks
174202
175203
### Run Tests
176204
Tests are in the tests directory. One of the most important tests is
177205
test_against_pybigwig which makes sure that if there is a mistake in
178206
pyBigWig, it is also in bigwig-loader.
179207
208+
Running these tests requires a GPU.
209+
180210
```shell
181-
pytest -vv .
211+
pixi run -e dev test
182212
```
183213
184214
When github runners with GPU's will become available we would also

bigwig_loader/batch_processor.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
# from typing import TYPE_CHECKING
44
from typing import Callable
5+
from typing import Literal
56
from typing import Optional
67
from typing import Sequence
78
from typing import Union
@@ -12,6 +13,7 @@
1213

1314
from bigwig_loader.bigwig import BigWig
1415
from bigwig_loader.decompressor import Decoder
16+
from bigwig_loader.default_value import replace_out_tensor_if_needed
1517
from bigwig_loader.functional import load_decode_search
1618
from bigwig_loader.intervals_to_values import intervals_to_values
1719
from bigwig_loader.memory_bank import MemoryBank
@@ -59,19 +61,29 @@ def decoder(self) -> Decoder:
5961
def memory_bank(self) -> MemoryBank:
6062
return MemoryBank(elastic=True)
6163

62-
def _get_out_tensor(self, batch_size: int, sequence_length: int) -> cp.ndarray:
64+
def _get_out_tensor(
65+
self,
66+
batch_size: int,
67+
sequence_length: int,
68+
dtype: Literal["bfloat16", "float32"] = "float32",
69+
) -> cp.ndarray:
6370
"""Reuses a reserved tensor if possible (when out shape is constant),
6471
otherwise creates a new one.
6572
args:
6673
batch_size: batch size
6774
sequence_length: length of genomic sequence
75+
dtype: output dtype ('float32' or 'bfloat16')
6876
returns:
69-
tensor of shape (number of bigwig files, batch_size, sequence_length)
77+
tensor of shape (batch_size, sequence_length, number of bigwig files)
7078
"""
7179

72-
shape = (len(self._bigwigs), batch_size, sequence_length)
73-
if self._out.shape != shape:
74-
self._out = cp.zeros(shape, dtype=cp.float32)
80+
self._out = replace_out_tensor_if_needed(
81+
self._out,
82+
batch_size=batch_size,
83+
sequence_length=sequence_length,
84+
number_of_tracks=len(self._bigwigs),
85+
dtype=dtype,
86+
)
7587
return self._out
7688

7789
def preprocess(
@@ -105,6 +117,7 @@ def get_batch(
105117
window_size: int = 1,
106118
scaling_factors_cupy: Optional[cp.ndarray] = None,
107119
default_value: float = 0.0,
120+
dtype: Literal["float32", "bfloat16"] = "float32",
108121
out: Optional[cp.ndarray] = None,
109122
) -> cp.ndarray:
110123
(
@@ -139,9 +152,10 @@ def get_batch(
139152
query_ends=abs_end,
140153
window_size=window_size,
141154
default_value=default_value,
155+
dtype=dtype,
142156
out=out,
143157
)
144-
batch = cp.transpose(out, (1, 0, 2))
158+
# batch = cp.transpose(out, (1, 0, 2))
145159
if scaling_factors_cupy is not None:
146-
batch *= scaling_factors_cupy
147-
return batch
160+
out *= scaling_factors_cupy
161+
return out
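
Review note: with the transpose removed, `out` keeps the `(batch_size, sequence_length, n_tracks)` layout, so the scaling factors have to broadcast over the last axis. That matches the `(1, 1, n_tracks)` reshape in `collection.py`. The sketch below illustrates the broadcasting with NumPy standing in for CuPy (an assumption made so it runs without a GPU); the sizes and factor values are made up for illustration.

```python
import numpy as np

# Illustrative sizes only; not taken from the repository.
n_tracks, batch_size, sequence_length = 3, 2, 5

# New layout after this commit: (batch_size, sequence_length, n_tracks).
out = np.ones((batch_size, sequence_length, n_tracks), dtype=np.float32)

# One scaling factor per track, shaped (1, 1, n_tracks) so that it
# broadcasts along the batch and sequence axes.
scaling_factors = np.asarray([1.0, 2.0, 0.5], dtype=np.float32).reshape(
    1, 1, n_tracks
)

# In-place multiply, as in get_batch: track i is scaled by factor i.
out *= scaling_factors

# Downstream code that still expects (batch, tracks, sequence) can permute
# the last two axes instead of relying on the old transpose.
legacy_view = np.transpose(out, (0, 2, 1))
```

Because the multiplication is in place, no extra array of the full output size is allocated; `np.transpose` returns a view rather than a copy.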

bigwig_loader/bigwig.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -424,7 +424,7 @@ def _guess_max_rows_per_chunk(
424424
data_offsets = self.rtree_leaf_nodes["data_offset"]
425425
data_sizes = self.rtree_leaf_nodes["data_size"]
426426
if len(data_offsets) > sample_size:
427-
sample_indices = sample(range(len(data_offsets)), sample_size)
427+
sample_indices = sample(range(len(data_offsets)), sample_size) # nosec
428428
data_offsets = data_offsets[sample_indices]
429429
data_sizes = data_sizes[sample_indices]
430430

bigwig_loader/collection.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def reset_gpu(self) -> None:
9797
need to be recreated on the new gpu.
9898
"""
9999

100-
self._out = cp.zeros((len(self), 1, 1), dtype=cp.float32)
100+
# self._out = cp.zeros(1, (len(self), 1), dtype=cp.float32)
101101
if "decoder" in self.__dict__:
102102
del self.__dict__["decoder"]
103103
if "memory_bank" in self.__dict__:
@@ -131,7 +131,7 @@ def batch_processor(self) -> BatchProcessor:
131131
@cached_property
132132
def scaling_factors_cupy(self) -> cp.ndarray:
133133
return cp.asarray(self._scaling_factors, dtype=cp.float32).reshape(
134-
1, len(self._scaling_factors), 1
134+
1, 1, len(self._scaling_factors)
135135
)
136136

137137
def get_batch(

bigwig_loader/dataset.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -87,6 +87,12 @@ class BigWigDataset:
8787
tracks in case sub_sample_tracks is set. Should be Iterable batches of track indices.
8888
return_batch_objects: if True, the batches will be returned as instances of
8989
bigwig_loader.batch.Batch
90+
dtype: float32 or bfloat16 output encoding of the target values (not the sequence encoding).
91+
Cupy does not support bfloat16 yet, but the cuda kernel that creates the target values
92+
does. When bfloat16 is chosen, the cupy array will report the data type uint16
93+
which can, for example, be converted to a torch.bfloat16 by
94+
torch_tensor = torch.as_tensor(out) # torch uint16
95+
torch_tensor = torch_tensor.view(torch.bfloat16) # Reinterpret as bfloat16
9096
"""
9197

9298
def __init__(
@@ -107,7 +113,7 @@ def __init__(
107113
] = "onehot",
108114
file_extensions: Sequence[str] = (".bigWig", ".bw"),
109115
crawl: bool = True,
110-
scale: Optional[dict[Union[str | Path], Any]] = None,
116+
scale: Optional[dict[Union[str, Path], Any]] = None,
111117
default_value: float = 0.0,
112118
first_n_files: Optional[int] = None,
113119
position_sampler_buffer_size: int = 100000,
@@ -117,6 +123,7 @@ def __init__(
117123
custom_position_sampler: Optional[Iterable[tuple[str, int]]] = None,
118124
custom_track_sampler: Optional[Iterable[list[int]]] = None,
119125
return_batch_objects: bool = False,
126+
dtype: Literal["float32", "bfloat16"] = "float32",
120127
):
121128
super().__init__()
122129

@@ -176,6 +183,8 @@ def __init__(
176183
else:
177184
self._track_sampler = None
178185

186+
self._dtype = dtype
187+
179188
def _create_dataloader(self) -> StreamedDataloader:
180189
sequence_sampler = GenomicSequenceSampler(
181190
reference_genome_path=self.reference_genome_path,
@@ -199,6 +208,7 @@ def _create_dataloader(self) -> StreamedDataloader:
199208
slice_size=self.batch_size,
200209
window_size=self.window_size,
201210
default_value=self._default_value,
211+
dtype=self._dtype,
202212
)
203213

204214
def __iter__(
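
Review note: the docstring above says that bfloat16 output arrives as a `uint16` array that must be reinterpreted. This works because a bfloat16 value is the upper 16 bits of the corresponding float32 bit pattern. The sketch below shows that relationship in plain NumPy. The helper names are hypothetical, and truncation is used as the rounding rule purely for illustration; this diff does not show how the CUDA kernel rounds.

```python
import numpy as np


def float32_to_bfloat16_bits(values):
    """Return the bfloat16 bit patterns (as uint16) of float32 values.

    Hypothetical helper: keeps the upper 16 bits (truncation). The actual
    kernel in bigwig-loader may round differently.
    """
    as_uint32 = np.asarray(values, dtype=np.float32).view(np.uint32)
    return (as_uint32 >> 16).astype(np.uint16)


def bfloat16_bits_to_float32(bits):
    """Reinterpret uint16 bfloat16 bit patterns as float32 values."""
    widened = np.asarray(bits, dtype=np.uint16).astype(np.uint32) << 16
    return widened.view(np.float32)


# 1.0 is 0x3F800000 in float32, so its bfloat16 bit pattern is 0x3F80.
bits = float32_to_bfloat16_bits([1.0, -2.0, 0.5])
decoded = bfloat16_bits_to_float32(bits)
```

With PyTorch available, the reinterpretation described in the docstring (`torch_tensor.view(torch.bfloat16)`) avoids the widening step and keeps the data in 16 bits.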
