Structural Layers
Structural layers are mid-level layers, like Chain or SkipConnection, that have sub-layers as fields of the data structure. They must store references to these sub-layers as well as define how the sub-layers integrate to produce a forward pass. An ML Community Call led to a discussion on how to implement and support such layers. The design aspects of that discussion are summarized below.
Other frameworks (e.g. PyTorch) define custom types, like BasicBlock or InceptionBlock, to implement non-standard layers. This Metalhead.jl PR instead chooses to define functions, like inceptionblock, that return a Flux model made of only standard components.
This approach makes it easier to pass pre-trained models into functions that dispatch by layer type. For example, a user might want to replace all the activations in Conv layers with tanh. This can be achieved by dispatching on Chain, Conv, and Any, provided the model contains only standard layers. If a non-standard layer is used, then a special dispatch rule must be added (e.g. on Bottleneck) to de-sugar it into function calls on its submodules, just as is done for standard layers like Chain.
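To make this concrete, here is a minimal sketch of that kind of type-based traversal. The helper name replace_activation is hypothetical, and the sketch assumes Flux's Conv(weight, bias, activation; ...) constructor:

using Flux

# Hypothetical helper: rebuild a model with tanh as every Conv activation.
replace_activation(c::Chain) = Chain(map(replace_activation, c.layers)...)
replace_activation(c::Conv) =
    Conv(c.weight, c.bias, tanh; stride = c.stride, pad = c.pad, dilation = c.dilation)
replace_activation(x) = x  # the Any fallback: leave all other layers untouched

model = Chain(Conv((3, 3), 3 => 8, relu; pad = 1), Conv((3, 3), 8 => 16, relu; pad = 1))
tanh_model = replace_activation(model)  # both Conv activations are now tanh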
On the other hand, with this approach alone, it becomes difficult to recognize structures like Bottleneck after they have been created. Moreover, for some blocks, like the bottleneck, SkipConnection can be used (see the sketch after the snippet below), but for others, like inception blocks, an anonymous closure must be used. Here is a snippet of that in practice:
function inceptionblock(inplanes, out_1x1,
                        red_3x3, out_3x3,
                        red_5x5, out_5x5, pool_proj)
    # conv_block is a Metalhead.jl helper that returns a collection of
    # standard layers, hence the splatting into Chain
    branch1 = Chain(
        conv_block((1, 1), inplanes, out_1x1)...)
    branch2 = Chain(
        conv_block((1, 1), inplanes, red_3x3)...,
        conv_block((3, 3), red_3x3, out_3x3; pad = 1)...)
    branch3 = Chain(
        conv_block((1, 1), inplanes, red_5x5)...,
        conv_block((5, 5), red_5x5, out_5x5; pad = 2)...)
    branch4 = Chain(
        MaxPool((3, 3), stride = 1, pad = 1),
        conv_block((1, 1), inplanes, pool_proj)...)

    return x -> begin
        y1 = branch1(x)
        y2 = branch2(x)
        y3 = branch3(x)
        y4 = branch4(x)
        return cat(y1, y2, y3, y4; dims = 3)
    end
end
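By contrast, a bottleneck-style block can be written with SkipConnection and standard layers only. A minimal sketch (the channel sizes are illustrative, not taken from any particular model):

using Flux

# SkipConnection applies the wrapped branch, then combines its output
# with the original input using the given connection (here, +).
bottleneck = SkipConnection(
    Chain(Conv((1, 1), 64 => 16, relu),
          Conv((3, 3), 16 => 16, relu; pad = 1),
          Conv((1, 1), 16 => 64)),
    +)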
Defining a custom type is the approach used by most frameworks. There are two possible implementations, both of which introduce a new type (e.g. Bottleneck):
- Create a struct that stores the submodules as separate, unique fields (i.e. what PyTorch does).
- Create a “wrapper” type that stores a reference to the standard-layer implementation of the structure (i.e. what’s returned by the functional approach). Here, the wrapper type defers to the underlying reference for the forward pass, etc. Wrapper types can be standardized in how they wrap the underlying implementation. This allows some assumptions when dispatching on a wrapper.
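Here is a rough sketch of both options. The type names (BottleneckA, BottleneckB) and the residual structure are hypothetical placeholders:

using Flux

# Option 1: a struct storing the submodules as separate, unique fields.
struct BottleneckA
    conv1
    conv2
    conv3
end
Flux.@functor BottleneckA
(b::BottleneckA)(x) = x + b.conv3(b.conv2(b.conv1(x)))  # identity shortcut

# Option 2: a thin wrapper around the standard-layer implementation;
# the forward pass simply defers to the wrapped model.
struct BottleneckB{T}
    layers::T    # e.g. the SkipConnection built from standard layers above
end
Flux.@functor BottleneckB
(b::BottleneckB)(x) = b.layers(x)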
One known disadvantage of using custom structures is the additional dependencies they introduce when serializing a model: the package defining the custom type must be loaded to deserialize it.
Functors are an abstraction for structures like the layers in this discussion (including standard ones like Chain). In Functors.jl, operations like fmap allow someone to apply a function to the parameters underlying a complex layer like Chain. Note that this can be extended to allow shallow traversal of a complex layer: instead of going all the way down to the parameters, we traverse down to “primitive” layers like Conv.
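A brief sketch of both traversal styles. The exclude keyword shown for the shallow case is available in recent Functors.jl versions; older versions may require a custom walk:

using Flux, Functors

model = Chain(Dense(4 => 8, relu), SkipConnection(Dense(8 => 8), +), Dense(8 => 2))

# Deep traversal: by default fmap recurses down to the parameter arrays,
# e.g. to promote every parameter to Float64.
model64 = fmap(x -> x isa AbstractArray ? Float64.(x) : x, model)

# Shallow traversal: stop the recursion at "primitive" layers instead.
fmap(model; exclude = l -> l isa Dense || Functors.isleaf(l)) do l
    println(typeof(l))   # Dense layers (and any remaining leaves) land here
    l                    # return each node unchanged
end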
Importantly, this method allows the custom type approach to gain some of the benefits of the functional approach. In other words, functors remove the ambiguity about how to de-sugar an arbitrary layer.
Sometimes, we don’t want to dispatch on a particular structure (e.g. a “branch-reduce” style layer), but on a specific instance of that structure (e.g. a bottleneck). A custom type allows us to do this. Having custom types also documents the model through the type system: a Chain of only standard layers is hard to decipher, but if the sub-layers are labeled “bottleneck,” then a user can recognize their purpose.
This proposal extends Chain to accept NamedTuples, which restores the “naming” that is lost when foregoing a custom type. Notably, the change adds a lot of extensibility, addresses some of the issues raised, and seems like it should be easy to implement.
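A sketch of what the naming might look like in practice. The keyword constructor shown here is the form available in recent Flux releases; the original proposal's syntax may have differed, and the layer names and sizes are illustrative:

using Flux

model = Chain(
    stem   = Conv((7, 7), 3 => 64, relu; stride = 2, pad = 3),
    block1 = SkipConnection(Chain(Conv((3, 3), 64 => 64, relu; pad = 1),
                                  Conv((3, 3), 64 => 64; pad = 1)), +),
    head   = Chain(GlobalMeanPool(), Flux.flatten, Dense(64 => 10)),
)

model.layers.block1   # the sub-layer can now be retrieved (and recognized) by name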
There is a current PR to Flux.jl that adds a Parallel layer. This takes the “branch-reduce” style layer and turns it into a standard implementation (note: that PR is also a good description of the user-facing issues surrounding this discussion).
There are some additional benefits when paired with other approaches:
- Used with the functional approach, this method allows code to recognize and dispatch on the “branch-reduce” structure within a model.
- Used with the NamedTuple approach, this method allows Parallel layers within a model to be labeled as bottlenecks, etc.
Having a standard type not only provides static structural information (something that the functor approach can do too), but also encodes how the sub-modules integrate to produce a forward pass. None of the other methods provide this information (except custom types, which lack the same generality as Parallel).
Here is an example of what the forward pass might look like for Parallel:
struct Parallel{F, T}
    fs::F
    op::T
end

# init might need to be specified depending on l.op
# init might need to be a field in Parallel
(l::Parallel)(x::AbstractArray) =
    mapreduce(f -> f(x), l.op, l.fs; init = zero(x))
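For reference, a standalone usage example of the sketch above, using plain functions as branches (note that init = zero(x) only makes sense when the branch outputs have the same shape as the input):

p = Parallel((x -> 2 .* x, x -> x .+ 1), +)
p(ones(Float32, 3))   # [2, 2, 2] .+ [2, 2, 2] .+ zero(x) == [4.0, 4.0, 4.0]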
Additionally, we can provide a branch(x, reduce_op, layers...; config) function that executes layers on x in parallel according to config (for speed), then reduces the outputs with reduce_op. The branch function can be the forward pass for Parallel:
(l::Parallel)(x) =
    branch(x, l.op, l.fs...; config = l.config)  # assumes a config field is added to Parallel
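To make the idea concrete, here is a hypothetical sketch of such a branch function. The name, the config values, and the threading strategy are all assumptions rather than an agreed-upon API:

# Hypothetical: run each branch, optionally on its own task, then reduce.
function branch(x, reduce_op, layers...; config = :serial)
    if config === :threaded
        tasks = [Threads.@spawn f(x) for f in layers]   # one task per branch
        return reduce(reduce_op, map(fetch, tasks))
    else
        return mapreduce(f -> f(x), reduce_op, layers)
    end
end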