Skip to content

Commit ea04570

Browse files
committed
update to the more efficient frac-connections
1 parent 756223c commit ea04570

File tree

3 files changed

+15
-7
lines changed

3 files changed

+15
-7
lines changed

README.md

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -265,3 +265,14 @@ $ pip install -U diffusers transformers accelerate scipy ftfy safetensors
265265
url = {https://api.semanticscholar.org/CorpusID:272987528}
266266
}
267267
```
268+
269+
```bibtex
270+
@article{Zhu2025FracConnectionsFE,
271+
title = {Frac-Connections: Fractional Extension of Hyper-Connections},
272+
author = {Defa Zhu and Hongzhi Huang and Jundong Zhou and Zihao Huang and Yutao Zeng and Banggu Wu and Qiyang Min and Xun Zhou},
273+
journal = {ArXiv},
274+
year = {2025},
275+
volume = {abs/2503.14125},
276+
url = {https://api.semanticscholar.org/CorpusID:277104144}
277+
}
278+
```

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
[project]
22
name = "transfusion-pytorch"
3-
version = "0.10.5"
3+
version = "0.11.0"
44
description = "Transfusion in Pytorch"
55
authors = [
66
{ name = "Phil Wang", email = "[email protected]" }

transfusion_pytorch/transfusion.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1007,7 +1007,8 @@ def __init__(
10071007
attn_laser = False,
10081008
unet_skips = True,
10091009
use_flex_attn = False,
1010-
num_residual_streams = 4
1010+
num_residual_streams = 1,
1011+
num_residual_fracs = 4
10111012
):
10121013
super().__init__()
10131014
self.use_flex_attn = use_flex_attn
@@ -1023,13 +1024,9 @@ def __init__(
10231024

10241025
# hyper connections
10251026

1026-
assert num_residual_streams > 0
1027-
is_hyper_connection = num_residual_streams > 1
1028-
self.num_residual_streams = num_residual_streams
1029-
10301027
counter = count()
10311028

1032-
init_residual_fn, self.expand_stream, self.reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, disable = num_residual_streams == 1)
1029+
init_residual_fn, self.expand_stream, self.reduce_stream = HyperConnections.get_init_and_expand_reduce_stream_functions(num_residual_streams, num_fracs = num_residual_fracs)
10331030

10341031
# layers
10351032

0 commit comments

Comments
 (0)