Skip to content

Commit 418cc18

Browse files
committed
add is_non_differentiable
1 parent 2f76da0 commit 418cc18

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

src/ChainRulesCore.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@ export frule_via_ad, rrule_via_ad
1313
export @non_differentiable, @opt_out, @scalar_rule, @thunk, @not_implemented
1414
export ProjectTo, canonicalize, unthunk # tangent operations
1515
export add!! # gradient accumulation operations
16-
export ignore_derivatives, @ignore_derivatives
16+
export ignore_derivatives, @ignore_derivatives, is_non_differentiable
1717
# tangents
1818
export Tangent, NoTangent, InplaceableThunk, Thunk, ZeroTangent, AbstractZero, AbstractThunk
1919

src/projection.jl

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -142,16 +142,42 @@ end
142142
# dx::AbstractArray (when both are possible), or the reverse. So for now we just pass them through:
143143
(::ProjectTo{T})(dx::Tangent{<:T}) where {T} = dx
144144

145+
#####
146+
##### A related utility which wants to live nearby
147+
#####
148+
149+
"""
150+
is_non_differentiable(x) == is_non_differentiable(typeof(x))
151+
152+
Returns `true` if `x` is known from its type not to have derivatives, else `false`.
153+
154+
Should mostly agree with whether `ProjectTo(x)` maps to `AbstractZero`,
155+
which is what the fallback method checks. The exception is that it will not look
156+
inside abstractly typed containers like `x = Any[true, false]`.
157+
"""
158+
is_non_differentiable(x) = is_non_differentiable(typeof(x))
159+
160+
is_non_differentiable(::Type{<:Number}) = false
161+
is_non_differentiable(::Type{<:NTuple{N,T}}) where {N,T} = is_non_differentiable(T)
162+
is_non_differentiable(::Type{<:AbstractArray{T}}) where {T} = is_non_differentiable(T)
163+
164+
function is_non_differentiable(::Type{T}) where {T} # fallback
165+
PT = Base._return_type(ProjectTo, Tuple{T}) # might be Union{} if unstable
166+
return isconcretetype(PT) && PT <: ProjectTo{<:AbstractZero}
167+
end
168+
145169
#####
146170
##### `Base`
147171
#####
148172

149173
# Bool
150174
ProjectTo(::Bool) = ProjectTo{NoTangent}() # same projector as ProjectTo(::AbstractZero) above
175+
is_non_differentiable(::Type{Bool}) = true
151176

152177
# Other never-differentiable types
153178
for T in (:Symbol, :Char, :AbstractString, :RoundingMode, :IndexStyle)
154179
@eval ProjectTo(::$T) = ProjectTo{NoTangent}()
180+
@eval is_non_differentiable(::Type{<:$T}) = true
155181
end
156182

157183
# Numbers

0 commit comments

Comments
 (0)