@@ -4,9 +4,7 @@ export RandomFourierFeatures
44export RandomOrientationFeatures
55export rand_rigid, get_rigid
66
7- using Functors: @functor
8- import Optimisers
9-
7+ using Flux
108using BatchedTransformations
119
1210sumdrop (f, A:: AbstractArray ; dims) = dropdims (sum (f, A; dims); dims)
@@ -35,8 +33,7 @@ struct RandomFourierFeatures{T<:Real,A<:AbstractMatrix{T}}
3533 W:: A
3634end
3735
38- @functor RandomFourierFeatures
39- Optimisers. trainable (:: RandomFourierFeatures ) = (;) # no trainable parameters
36+ Flux. @layer RandomFourierFeatures trainable= ()
4037
4138RandomFourierFeatures (dims:: Pair{<:Integer, <:Integer} , σ:: Real ) = RandomFourierFeatures (dims, float (σ))
4239
@@ -115,8 +112,7 @@ struct RandomOrientationFeatures{A<:AbstractArray{<:Real}}
115112 FB:: A
116113end
117114
118- @functor RandomOrientationFeatures
119- Optimisers. trainable (:: RandomOrientationFeatures ) = (;) # no trainable parameters
115+ Flux. @layer RandomOrientationFeatures trainable= ()
120116
121117"""
122118 RandomOrientationFeatures(m, σ)
@@ -151,8 +147,7 @@ handle batch dimensions.
151147The transformation gets applied according to `NNlib.batched_mul(R, x) .+ t`
152148"""
153149function get_rigid (R:: AbstractArray , t:: AbstractArray )
154- batch_size = size (R)[3 : end ]
155- t = reshape (t, 3 , 1 , batch_size... )
150+ t = reshape (t, 3 , 1 , size (R)[3 : end ]. .. )
156151 Translation (t) ∘ Rotation (R)
157152end
158153
0 commit comments