
Commit 8b807fa

danieldk and MekkCyber authored
Support functions as layers (#188)
This change adds two pieces of new functionality.

First, it introduces the `(Locked|Local)?FuncRepository` classes; these can be used to extend a layer with a kernel function. For instance, a layer like

```python
@use_kernel_forward_from_hub("SiluAndMul")
class SiluAndMul(nn.Module):
    def forward(self, input: torch.Tensor) -> torch.Tensor:
        d = input.shape[-1] // 2
        return F.silu(input[..., :d]) * input[..., d:]
```

can now also be kernelized using a function `silu_and_mul` from the Hub:

```python
with use_kernel_mapping({
    "SiluAndMul": {
        "cuda": FuncRepository(
            repo_id="kernels-community/activation",
            func_name="silu_and_mul",
        ),
    }
}):
    kernelize(...)
```

This makes it easier to kernelize pure layers (layers that do not use module state), since the Hub kernel does not have to provide a `layers` Python module with wrappers.

Second, it introduces a decorator `use_kernel_func_from_hub` that turns functions into layers that can be kernelized. For example:

```python
@use_kernel_func_from_hub("silu_and_mul")
def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
    d = x.shape[-1] // 2
    return F.silu(x[..., :d]) * x[..., d:]
```

will implicitly create an instance of the following class:

```python
class Func(nn.Module):
    # We add some magic to preserve the function's signature.
    def forward(self, *args, **kwargs):
        return silu_and_mul(*args, **kwargs)
```

Due to the `__call__` implementation of `nn.Module`, the instance still behaves as a function:

```python
out = silu_and_mul(x)
```

However, when the function is used as a member of an `nn.Module`, it will be kernelized:

```python
class FeedForward(nn.Module):
    def __init__(self, in_features: int, out_features: int):
        super().__init__()
        self.linear = nn.Linear(in_features, out_features)
        # Note: silu_and_mul is a Torch module.
        self.silu_and_mul = silu_and_mul

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.silu_and_mul(self.linear(x))
```

Co-authored-by: Mohamed Mekkouri <[email protected]>
1 parent 5b0067b commit 8b807fa
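The wrapper mechanism described in the commit message can be sketched in plain Python, without a torch dependency. The names `FuncModule` and `as_callable_module` below are hypothetical stand-ins for illustration only; the real `use_kernel_func_from_hub` additionally preserves the function's exact signature and participates in `kernelize` dispatch.

```python
import functools


class FuncModule:
    """Minimal stand-in for the nn.Module wrapper described above.

    Calling the instance delegates to the wrapped function, mirroring
    how nn.Module.__call__ dispatches to forward().
    """

    def __init__(self, func):
        self._func = func
        # Copy the function's __name__, __doc__, etc. onto the instance.
        functools.update_wrapper(self, func)

    def forward(self, *args, **kwargs):
        return self._func(*args, **kwargs)

    def __call__(self, *args, **kwargs):
        # nn.Module.__call__ would add hooks and kernel dispatch here;
        # this sketch just forwards directly.
        return self.forward(*args, **kwargs)


def as_callable_module(kernel_name):
    # Hypothetical counterpart of use_kernel_func_from_hub: in the real
    # library, kernel_name identifies the Hub kernel to substitute.
    def wrap(func):
        return FuncModule(func)

    return wrap


@as_callable_module("silu_and_mul")
def silu_and_mul(x):
    return x * 2  # placeholder body standing in for F.silu(...) * ...
```

After decoration, `silu_and_mul` is a module instance, yet `silu_and_mul(3)` still works like a plain function call, which is what lets it be assigned as a submodule and discovered during kernelization.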

File tree

14 files changed: +672, -74 lines


docs/source/api/layers.md

Lines changed: 16 additions & 0 deletions
```diff
@@ -6,6 +6,10 @@
 
 [[autodoc]] kernels.use_kernel_forward_from_hub
 
+### use_kernel_func_from_hub
+
+[[autodoc]] kernels.use_kernel_func_from_hub
+
 ### replace_kernel_forward_from_hub
 
 [[autodoc]] kernels.replace_kernel_forward_from_hub
@@ -36,14 +40,26 @@
 
 [[autodoc]] kernels.Mode
 
+### FuncRepository
+
+[[autodoc]] kernels.FuncRepository
+
 ### LayerRepository
 
 [[autodoc]] kernels.LayerRepository
 
+### LocalFuncRepository
+
+[[autodoc]] kernels.LocalFuncRepository
+
 ### LocalLayerRepository
 
 [[autodoc]] kernels.LocalLayerRepository
 
+### LockedFuncRepository
+
+[[autodoc]] kernels.LockedFuncRepository
+
 ### LockedLayerRepository
 
 [[autodoc]] kernels.LockedLayerRepository
```

docs/source/layers.md

Lines changed: 45 additions & 0 deletions
````diff
@@ -43,6 +43,36 @@ replace_kernel_forward_from_hub(SiluAndMul, "SiluAndMul")
 it signifies that the maintainer intends to keep the `forward` signature
 compatible with layers from the hub.
 
+### Using a function as a layer
+
+Sometimes it can be useful to make a function extensible, for example
+because the function cannot be replaced by a layer. In such cases, you
+can annotate the function with the `use_kernel_func_from_hub` decorator:
+
+```python
+@use_kernel_func_from_hub("silu_and_mul")
+def silu_and_mul(x: torch.Tensor) -> torch.Tensor:
+    d = x.shape[-1] // 2
+    return F.silu(x[..., :d]) * x[..., d:]
+```
+
+This will replace the function by an instantiated `torch.nn.Module`
+(singleton) that calls the function itself in its forward method.
+
+**Note:** for kernelization to see the function, it must be a member of
+another `torch.nn.Module` that is part of the model. For example:
+
+```python
+class FeedForward(nn.Module):
+    def __init__(self, in_features: int, out_features: int):
+        self.linear = nn.Linear(in_features, out_features)
+        # Note: silu_and_mul is a Torch module.
+        self.silu_and_mul = silu_and_mul
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.silu_and_mul(self.linear(x))
+```
+
 ## Kernelizing a model
 
 A model will not use Hub kernels by default, even if it contains extensible
@@ -157,6 +187,21 @@ with use_kernel_mapping(kernel_layer_mapping):
 This ensures that the mapping is not active anymore outside the
 `with`-scope.
 
+If the layer is stateless (it does not use member variables in its forward _or_ it was
+originally a function that was converted into a kernel layer with
+`use_kernel_func_from_hub`), it can also be mapped to a kernel function:
+
+```python
+kernel_layer_mapping = {
+    "SiluAndMul": {
+        "cuda": FuncRepository(
+            repo_id="kernels-community/activation",
+            func_name="silu_and_mul",
+        ),
+    }
+}
+```
+
 ### Using version bounds
 
 Kernels are versioned using tags of the form `v<major>.<minor>.<patch>`.
````

src/kernels/__init__.py

Lines changed: 13 additions & 4 deletions
```diff
@@ -2,15 +2,22 @@
 
 __version__ = importlib.metadata.version("kernels")
 
-from kernels.layer import Device, CUDAProperties
-from kernels.layer import kernelize, register_kernel_mapping, use_kernel_mapping
-from kernels.layer import Mode
 from kernels.layer import (
+    CUDAProperties,
+    Device,
+    FuncRepository,
     LayerRepository,
+    LocalFuncRepository,
     LocalLayerRepository,
+    LockedFuncRepository,
     LockedLayerRepository,
+    Mode,
+    kernelize,
+    register_kernel_mapping,
     replace_kernel_forward_from_hub,
     use_kernel_forward_from_hub,
+    use_kernel_func_from_hub,
+    use_kernel_mapping,
 )
 from kernels.utils import (
     get_kernel,
@@ -25,8 +32,11 @@
     "__version__",
     "CUDAProperties",
     "Device",
+    "FuncRepository",
     "LayerRepository",
+    "LocalFuncRepository",
     "LocalLayerRepository",
+    "LockedFuncRepository",
     "LockedLayerRepository",
     "Mode",
     "get_kernel",
@@ -38,7 +48,6 @@
     "load_kernel",
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
-    "replace_kernel_func_from_hub",
     "use_kernel_forward_from_hub",
     "use_kernel_func_from_hub",
     "use_kernel_mapping",
```
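The `src/kernels/__init__.py` diff above also drops a stale `"replace_kernel_func_from_hub"` entry from `__all__` (the exported name is `use_kernel_func_from_hub`). A small consistency check like the following catches such entries, since `from kernels import *` fails on names listed in `__all__` but never defined. The helper below is a generic sketch, not part of the library:

```python
import types


def stale_all_entries(module):
    """Return names listed in __all__ that the module does not define."""
    return [
        name
        for name in getattr(module, "__all__", [])
        if not hasattr(module, name)
    ]


# Demonstrate on a throwaway module containing one stale entry.
mod = types.ModuleType("demo")
mod.use_kernel_func_from_hub = lambda name: name
mod.__all__ = ["use_kernel_func_from_hub", "replace_kernel_func_from_hub"]
print(stale_all_entries(mod))  # → ['replace_kernel_func_from_hub']
```

Running such a check in the test suite keeps `__all__` honest as exports are renamed, which is exactly the kind of drift this commit cleans up.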

src/kernels/layer/__init__.py

Lines changed: 11 additions & 1 deletion
```diff
@@ -1,4 +1,10 @@
-from .device import Device, CUDAProperties
+from .device import CUDAProperties, Device
+from .func import (
+    FuncRepository,
+    LocalFuncRepository,
+    LockedFuncRepository,
+    use_kernel_func_from_hub,
+)
 from .kernelize import (
     kernelize,
     register_kernel_mapping,
@@ -16,13 +22,17 @@
 __all__ = [
     "CUDAProperties",
     "Device",
+    "FuncRepository",
     "LayerRepository",
+    "LocalFuncRepository",
     "LocalLayerRepository",
+    "LockedFuncRepository",
     "LockedLayerRepository",
     "Mode",
     "kernelize",
     "register_kernel_mapping",
     "replace_kernel_forward_from_hub",
     "use_kernel_forward_from_hub",
+    "use_kernel_func_from_hub",
     "use_kernel_mapping",
 ]
```
