Skip to content

Commit d54ee94

Browse files
committed
compat, add Flux
1 parent fcd0e56 commit d54ee94

File tree

2 files changed

+7
-14
lines changed

2 files changed

+7
-14
lines changed

Project.toml

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -5,8 +5,7 @@ version = "0.2.1"
55

66
[deps]
77
BatchedTransformations = "8ba27c4b-52b5-4b10-bc66-a4fda05aa11b"
8-
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
9-
Optimisers = "3bd65402-5787-11e9-1adc-39752487f4e2"
8+
Flux = "587475ba-b771-5e3f-ad9e-33799f191a9c"
109

1110
[weakdeps]
1211
GraphNeuralNetworks = "cffab07f-9bc2-4db1-8861-388f63bf7694"
@@ -16,9 +15,8 @@ GraphNeuralNetworksExt = "GraphNeuralNetworks"
1615

1716
[compat]
1817
BatchedTransformations = "0.5"
19-
Functors = "0.4, 0.5"
20-
GraphNeuralNetworks = "0.6"
21-
Optimisers = "0.3, 0.4"
18+
Flux = "0.14, 0.15, 0.16"
19+
GraphNeuralNetworks = "0.6, 1.0"
2220
julia = "1.10"
2321

2422
[extras]

src/RandomFeatureMaps.jl

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,7 @@ export RandomFourierFeatures
44
export RandomOrientationFeatures
55
export rand_rigid, get_rigid
66

7-
using Functors: @functor
8-
import Optimisers
9-
7+
using Flux
108
using BatchedTransformations
119

1210
sumdrop(f, A::AbstractArray; dims) = dropdims(sum(f, A; dims); dims)
@@ -35,8 +33,7 @@ struct RandomFourierFeatures{T<:Real,A<:AbstractMatrix{T}}
3533
W::A
3634
end
3735

38-
@functor RandomFourierFeatures
39-
Optimisers.trainable(::RandomFourierFeatures) = (;) # no trainable parameters
36+
Flux.@layer RandomFourierFeatures trainable=()
4037

4138
RandomFourierFeatures(dims::Pair{<:Integer, <:Integer}, σ::Real) = RandomFourierFeatures(dims, float(σ))
4239

@@ -115,8 +112,7 @@ struct RandomOrientationFeatures{A<:AbstractArray{<:Real}}
115112
FB::A
116113
end
117114

118-
@functor RandomOrientationFeatures
119-
Optimisers.trainable(::RandomOrientationFeatures) = (;) # no trainable parameters
115+
Flux.@layer RandomOrientationFeatures trainable=()
120116

121117
"""
122118
RandomOrientationFeatures(m, σ)
@@ -151,8 +147,7 @@ handle batch dimensions.
151147
The transformation gets applied according to `NNlib.batched_mul(R, x) .+ t`
152148
"""
153149
function get_rigid(R::AbstractArray, t::AbstractArray)
154-
batch_size = size(R)[3:end]
155-
t = reshape(t, 3, 1, batch_size...)
150+
t = reshape(t, 3, 1, size(R)[3:end]...)
156151
Translation(t) Rotation(R)
157152
end
158153

0 commit comments

Comments
 (0)