@@ -45,6 +45,20 @@ wrapped by `x`, such that mutating `extern(x)` might mutate `x` itself.
45
45
# #### `AbstractWirtinger`
46
46
# ####
47
47
48
+ """
49
+ AbstractWirtinger <: AbstractDifferential
50
+
51
+ Represents the differential of a non-holomorphic function taking complex input.
52
+
53
+ All subtypes implement [`wirtinger_primal`](@ref) and [`wirtinger_conjugate`](@ref).
54
+
55
+ All subtypes wrap real/holomorphic differentials, and should always be the outermost wrapper.
56
+ E.g., a typical differential would look like this:
57
+ ```
58
+ Wirtinger(@thunk(::AbstractArray{Number}), @thunk(::AbstractArray{<:Number}))
59
+ ```
60
+ `@thunk` and `AbstractArray` are, of course, optional.
61
+ """
48
62
abstract type AbstractWirtinger <: AbstractDifferential end
49
63
50
64
wirtinger_primal (x) = x
@@ -64,7 +78,7 @@ Base.conj(x::AbstractWirtinger) = throw(MethodError(conj, x))
64
78
# ####
65
79
66
80
"""
67
- Wirtinger(primal, conjugate)
81
+ Wirtinger(primal, conjugate) <: [`AbstractWirtinger`](@ref)
68
82
69
83
Returns a `Wirtinger` instance representing the complex differential:
70
84
@@ -95,12 +109,23 @@ Base.iterate(::Wirtinger, ::Any) = nothing
95
109
# #### `ComplexGradient`
96
110
# ####
97
111
112
+ """
113
+ ComplexGradient(val) <: [`AbstractWirtinger`](@ref)
114
+
115
+ Returns a `ComplexGradient` instance representing the complex differential:
116
+
117
+ ```
118
+ df = ∂f/∂Re(z) * dRe(z) + im * ∂f/∂Im(z) * dIm(z)
119
+ ```
120
+
121
+ where `f` is a `ℂ(^n) -> ℝ(^m)` function and `val` corresponds to `df`.
122
+ """
98
123
struct ComplexGradient{T} <: AbstractWirtinger
99
124
val:: T
100
125
end
101
126
102
127
wirtinger_primal (x:: ComplexGradient ) = conj (wirtinger_conjugate (x))
103
- wirtinger_conjugate (x:: ComplexGradient ) = x. val / 2
128
+ wirtinger_conjugate (x:: ComplexGradient ) = ( 1 // 2 ) * x. val
104
129
105
130
Base. Broadcast. broadcastable (x:: ComplexGradient ) = ComplexGradient (broadcastable (x. val))
106
131
@@ -268,6 +293,12 @@ struct Thunk{F} <: AbstractThunk
268
293
f:: F
269
294
end
270
295
296
+ """
297
+ @thunk body
298
+
299
+ Returns `Thunk(() -> body)`, except for when `body` is a call to (`Wirtinger`)[@ref] or (`ComplexGradient`)[@ref].
300
+ In this case, it is equivalent to `Wirtinger(@thunk(primal), @thunk(conjugate))` / `ComplexGradient(@thunk primal)`.
301
+ """
271
302
macro thunk (body)
272
303
if body isa Expr && body. head == :call
273
304
fname = body. args[1 ]
0 commit comments