Skip to content

Commit b63f6ee

Browse files
committed
add some docstrings
1 parent 491f781 commit b63f6ee

File tree

1 file changed

+33
-2
lines changed

1 file changed

+33
-2
lines changed

src/differentials.jl

Lines changed: 33 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,20 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
4545
##### `AbstractWirtinger`
4646
#####
4747

48+
"""
49+
AbstractWirtinger <: AbstractDifferential
50+
51+
Represents the differential of a non-holomorphic function taking complex input.
52+
53+
All subtypes implement [`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref).
54+
55+
All subtypes wrap real/holomorphic differentials, and should always be the outermost wrapper.
56+
E.g., a typical differential would look like this:
57+
```
58+
Wirtinger(@thunk(::AbstractArray{Number}), @thunk(::AbstractArray{<:Number}))
59+
```
60+
`@thunk` and `AbstractArray` are, of course, optional.
61+
"""
4862
abstract type AbstractWirtinger <: AbstractDifferential end
4963

5064
wirtinger_primal(x) = x
@@ -64,7 +78,7 @@ Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x))
6478
#####
6579

6680
"""
67-
Wirtinger(primal, conjugate)
81+
Wirtinger(primal, conjugate) <: [`AbstractWirtinger`](@ref)
6882
6983
Returns a `Wirtinger` instance representing the complex differential:
7084
@@ -95,12 +109,23 @@ Base.iterate(::Wirtinger, ::Any) = nothing
95109
##### `ComplexGradient`
96110
#####
97111

112+
"""
113+
ComplexGradient(val) <: [`AbstractWirtinger`](@ref)
114+
115+
Returns a `ComplexGradient` instance representing the complex differential:
116+
117+
```
118+
df = ∂f/∂Re(z) * dRe(z) + im * ∂f/∂Im(z) * dIm(z)
119+
```
120+
121+
where `f` is a `ℂ(^n) -> ℝ(^m)` function and `val` corresponds to `df`.
122+
"""
98123
struct ComplexGradient{T} <: AbstractWirtinger
99124
val::T
100125
end
101126

102127
wirtinger_primal(x::ComplexGradient) = conj(wirtinger_conjugate(x))
103-
wirtinger_conjugate(x::ComplexGradient) = x.val / 2
128+
wirtinger_conjugate(x::ComplexGradient) = (1//2) * x.val
104129

105130
Base.Broadcast.broadcastable(x::ComplexGradient) = ComplexGradient(broadcastable(x.val))
106131

@@ -268,6 +293,12 @@ struct Thunk{F} <: AbstractThunk
268293
f::F
269294
end
270295

296+
"""
297+
@thunk body
298+
299+
Returns `Thunk(() -> body)`, except for when `body` is a call to (`Wirtinger`)[@ref] or (`ComplexGradient`)[@ref].
300+
In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
301+
"""
271302
macro thunk(body)
272303
if body isa Expr && body.head == :call
273304
fname = body.args[1]

0 commit comments

Comments
 (0)