When Zygote imagines gradients
2021-05-02

Automatic differentiation (AD) is a new kind of black magic that enables entire programs to be differentiated. Any old fool can tell you that the derivative of $(1 + x)(1 + x^2)(1 + x^3)$ at $x = 2$ is $333$, but having a computer casually tell you so is surreal, especially when there's a loop involved:
julia> using Zygote # v0.6.10
julia> gradient(2.0) do x
           s = 1.0
           for n in 1:3
               s += s * x^n
           end
           s
       end
(333.0,)
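For the curious, unrolling the loop by hand shows where the $333$ comes from. Each iteration multiplies $s$ by $1 + x^n$, so the loop computes a product of three factors, and the product rule does the rest:

$$
\begin{aligned}
s(x) &= (1 + x)(1 + x^2)(1 + x^3), \\
s'(x) &= (1 + x^2)(1 + x^3) + 2x\,(1 + x)(1 + x^3) + 3x^2\,(1 + x)(1 + x^2), \\
s'(2) &= 5 \cdot 9 + 4 \cdot 3 \cdot 9 + 12 \cdot 3 \cdot 5 = 45 + 108 + 180 = 333.
\end{aligned}
$$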
Unlike other varieties of black magic, which tend to rely on clever tricks, AD is constructed predominantly from edge cases. Needless to say, things sometimes get weird.
In the following, we're going to look at why Zygote can give spurious imaginary components for gradients of real-valued functions of real variables:
julia> gradient(1.0) do x
           x^4 + x^2
       end
(6.0,)
julia> gradient(1.0) do x
           abs2(x^2 - im * x)
       end
(6.0 - 2.0im,)
Chain rule
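At the heart of AD lies the chain rule of calculus, which instructs us to multiply intermediate derivatives to build up more interesting ones. To be concrete, we'll use the function $f(g, h) = g^2 + h^2$, along with the two functions $g(x) = x^2$ and $h(x) = -x$. The computational graph of this setup is very simple: $x$ feeds into both $g$ and $h$, which in turn both feed into $f$. Along the edges we have $dg/dx = 2x$, $dh/dx = -1$, $\partial f/\partial g = 2g$ and $\partial f/\partial h = 2h$; each of these four partial derivatives is real-valued.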
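Using the chain rule, we see that the derivative of $f$ with respect to $x$ is

$$\frac{df}{dx} = \frac{\partial f}{\partial g}\frac{dg}{dx} + \frac{\partial f}{\partial h}\frac{dh}{dx} = 2g \cdot 2x + 2h \cdot (-1) = 4x^3 + 2x.$$

At $x = 1$ this evaluates to $6$, and Zygote agrees: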
julia> f1(g, h) = g^2 + h^2;
julia> g1(x) = x^2;
julia> h1(x) = -x;
julia> gradient(1.0) do x
           f1(g1(x), h1(x))
       end
(6.0,)
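To make the chain rule step concrete without any AD machinery, here is a minimal sketch in plain Julia (no packages) that hard-codes the four local derivatives from the graph and adds up the two paths from $x$ to $f$. The helper names dfdg, dfdh, dgdx, dhdx, chain_rule and central_diff are made up for illustration, and the central-difference comparison with a step of 1e-6 is an assumption of this sketch, not something Zygote does.

```julia
# Plain Julia, no packages: the example from the text,
# f(g, h) = g^2 + h^2 with g(x) = x^2 and h(x) = -x.
f1(g, h) = g^2 + h^2
g1(x) = x^2
h1(x) = -x

# Hypothetical helpers: the four local derivatives, one per edge of the graph.
dfdg(g, h) = 2g     # ∂f/∂g
dfdh(g, h) = 2h     # ∂f/∂h
dgdx(x) = 2x        # dg/dx
dhdx(x) = -1.0      # dh/dx

# Chain rule: multiply along each path from x to f, then add the paths.
function chain_rule(x)
    g, h = g1(x), h1(x)
    return dfdg(g, h) * dgdx(x) + dfdh(g, h) * dhdx(x)
end

# Central-difference estimate of a real derivative, for comparison only.
central_diff(F, x; step = 1e-6) = (F(x + step) - F(x - step)) / (2step)

println(chain_rule(1.0))                           # 6.0
println(central_diff(t -> f1(g1(t), h1(t)), 1.0))  # approximately 6.0
```

Both numbers are real, as expected when every local derivative along the way is real.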
Complex numbers as pairs of reals
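The above example is equivalent to the function $f(z) = |z|^2$, whose input is given by the function $z(x) = x^2 - ix$:

$$f(z) = |z|^2 = u^2 + v^2,$$

with

$$u = \operatorname{Re} z = x^2, \qquad v = \operatorname{Im} z = -x.$$

Despite the additional complexity, the computational graph looks identical: $x$ feeds into $u$ and $v$, which both feed into $f$. In particular, the partial derivatives are all still real, as all we've done is change two labels.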
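Treating $f$ as a function of the two real values $u$ and $v$, we can apply the chain rule and get the same answer as above:

$$\frac{df}{dx} = \frac{\partial f}{\partial u}\frac{du}{dx} + \frac{\partial f}{\partial v}\frac{dv}{dx} = 2u \cdot 2x + 2v \cdot (-1) = 4x^3 + 2x.$$

Since this is merely a relabelling of the previous version ($g$ became $u$, $h$ became $v$), we don't expect any difference; yet we see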
julia> f2(z) = abs2(z);
julia> z2(x) = x^2 - im * x;
julia> gradient(1.0) do x
           f2(z2(x))
       end
(6.0 - 2.0im,)
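It's worth convincing ourselves that nothing complex survives in the composite function itself. The following plain-Julia sketch (no Zygote; F and central_diff are hypothetical helper names) evaluates $x \mapsto |x^2 - ix|^2$ and differentiates it with a central difference in the real variable $x$. The result is an ordinary Float64, so any imaginary part in a gradient of this function comes from how the derivative is computed, not from the function.

```julia
# Plain Julia, no packages: the composite from the text, x ↦ abs2(x^2 - im*x).
f2(z) = abs2(z)
z2(x) = x^2 - im * x
F(x) = f2(z2(x))

# Central difference in the real variable x only (step size is an assumption).
central_diff(F, x; step = 1e-6) = (F(x + step) - F(x - step)) / (2step)

println(F(1.0))                # 2.0, i.e. |1 - i|^2
println(typeof(F(1.0)))        # Float64: abs2 of a complex number is real
println(central_diff(F, 1.0))  # approximately 6.0, and real
```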
The real part is the same, but there's an unexpected imaginary part, which can't be right for a real-valued function of a real variable. Zygote clearly doesn't split $z$ into real and imaginary components as we've done here.
Wirtinger derivatives
In the context of AD, $z$ (the argument of $f$) is a single node in the computational graph. Therefore, $f$ should only have one arrow pointing at it, unlike in the previous arrangements we considered. That would require definitions for $dz/dx$ and $\partial f/\partial z$. Since $z = u + iv$, linearity suggests writing $dz/dx = du/dx + i\,dv/dx$ for the former. However, it's not immediately clear what to do about the latter, since a real-valued $f$ like $|z|^2$ is not holomorphic.
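Thankfully, this problem has a solution in the form of Wirtinger derivatives. In short, one defines

$$\frac{\partial f}{\partial z} = \frac{1}{2}\left(\frac{\partial f}{\partial u} - i\,\frac{\partial f}{\partial v}\right)$$

as well as the conjugate expression

$$\frac{\partial f}{\partial \bar z} = \frac{1}{2}\left(\frac{\partial f}{\partial u} + i\,\frac{\partial f}{\partial v}\right).$$

To use these properly, we should think about a change of variables not from $u$ and $v$ to just $z$, but rather to $z$ and its complex conjugate $\bar z$:

$$z = u + iv, \qquad \bar z = u - iv.$$

Doing so results in a graph with complex-valued derivatives: $x$ feeds into $z$ and $\bar z$ via $dz/dx = 2x - i$ and $d\bar z/dx = 2x + i$, and these in turn feed into $f = z\bar z$ via $\partial f/\partial z = \bar z$ and $\partial f/\partial \bar z = z$. Immediately, we see that this won't work, as there are again two arrows to $f$.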
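Undeterred, we press on. With these definitions in hand, formal application of the chain rule results in

$$
\begin{aligned}
\frac{df}{dx} &= \frac{\partial f}{\partial z}\frac{dz}{dx} + \frac{\partial f}{\partial \bar z}\frac{d\bar z}{dx} \\
&= \frac{1}{2}\left(\frac{\partial f}{\partial u} - i\frac{\partial f}{\partial v}\right)\left(\frac{du}{dx} + i\frac{dv}{dx}\right) + \frac{1}{2}\left(\frac{\partial f}{\partial u} + i\frac{\partial f}{\partial v}\right)\left(\frac{du}{dx} - i\frac{dv}{dx}\right) \\
&= \frac{\partial f}{\partial u}\frac{du}{dx} + \frac{\partial f}{\partial v}\frac{dv}{dx}.
\end{aligned}
$$

All the cross terms cancel! Despite the individual derivatives being complex, the final result is strictly real. That's great and all, but how can we achieve this without introducing an additional node?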
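The cancellation can also be checked numerically. This sketch (plain Julia, no packages) estimates $\partial f/\partial u$ and $\partial f/\partial v$ by central differences, assembles the two Wirtinger derivatives from the definitions above, and evaluates both terms of the chain rule at $x = 1$. The helper names real_partials, wirtinger and dzdx, and the finite-difference step, are assumptions for illustration; because $x$ is real, it also uses $d\bar z/dx = \overline{dz/dx}$.

```julia
# Plain Julia, no packages. Hypothetical helpers for illustration only.

# Partial derivatives of a real-valued f(z) with respect to u = real(z) and
# v = imag(z), estimated by central differences with step h.
function real_partials(f, z; h = 1e-6)
    fu = (f(z + h) - f(z - h)) / (2h)
    fv = (f(z + im * h) - f(z - im * h)) / (2h)
    return fu, fv
end

# Wirtinger derivatives from the definitions:
# ∂f/∂z = (f_u - i f_v) / 2 and ∂f/∂z̄ = (f_u + i f_v) / 2.
function wirtinger(f, z)
    fu, fv = real_partials(f, z)
    return (fu - im * fv) / 2, (fu + im * fv) / 2
end

# Derivative of a complex-valued function of a real variable x.
dzdx(zfun, x; h = 1e-6) = (zfun(x + h) - zfun(x - h)) / (2h)

f2(z) = abs2(z)
z2(x) = x^2 - im * x

x = 1.0
z = z2(x)                          # 1 - i
dfdz, dfdzbar = wirtinger(f2, z)   # ≈ conj(z) and z
dz = dzdx(z2, x)                   # ≈ 2x - i
dzbar = conj(dz)                   # x is real, so dz̄/dx = conj(dz/dx)

term1 = dfdz * dz        # ≈ 3 + 1im
term2 = dfdzbar * dzbar  # ≈ 3 - 1im
total = term1 + term2    # ≈ 6 + 0im: the imaginary parts cancel

println((term1, term2, total))
```

Each term on its own is complex; only their sum is real, which is exactly why dropping one of them is dangerous.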
Zygote
If we take a glance at the documentation for Zygote, we see that it uses the definition

$$\frac{\partial f}{\partial u} + i\,\frac{\partial f}{\partial v} = 2\,\frac{\partial f}{\partial \bar z},$$

which (up to a factor of $2$) is one of the Wirtinger derivatives. The other one is nowhere to be found, and that's admittedly a straightforward way to get rid of the extra variable. The computational graph can finally be what we've been striving for all along: a single chain $x \to z \to f$.
If it were that easy, Wirtinger wouldn't have bothered coming up with two expressions, so dropping one of them must have consequences.
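It's not hard to see that

$$2\,\frac{\partial f}{\partial \bar z}\,\overline{\left(\frac{dz}{dx}\right)} = \left(\frac{\partial f}{\partial u} + i\frac{\partial f}{\partial v}\right)\left(\frac{du}{dx} - i\frac{dv}{dx}\right) = \frac{\partial f}{\partial u}\frac{du}{dx} + \frac{\partial f}{\partial v}\frac{dv}{dx} + i\left(\frac{\partial f}{\partial v}\frac{du}{dx} - \frac{\partial f}{\partial u}\frac{dv}{dx}\right)$$

has an unwanted imaginary part from the cross terms.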
(If you're wondering about all the complex conjugation going on: Zygote actually computes the conjugate of the gradient, the "adjoint", so everything gets conjugated along the way. See, for example, the definition for literal_pow.)
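For our concrete example, with $2\,\partial f/\partial \bar z = 2u + 2iv = 2z$ and $\overline{dz/dx} = 2x + i$, this yields

$$\left(2x^2 - 2ix\right)\left(2x + i\right) = 4x^3 + 2x - 2ix^2,$$

which differs from the desired result $4x^3 + 2x$ by $-2ix^2$. That's exactly what we had computed earlier at $x = 1$: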
julia> gradient(1.0) do x
           f2(z2(x))
       end
(6.0 - 2.0im,)
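To see that the imaginary part comes from the convention rather than anything exotic, here is a hypothetical hand-rolled imitation of the backward pass described above, in plain Julia without Zygote. It takes the gradient of abs2 in the $\partial f/\partial u + i\,\partial f/\partial v$ convention (which works out to $2z$) and pulls it back through $z(x) = x^2 - ix$ by multiplying with $\overline{dz/dx}$. This is a sketch of the convention only, not Zygote's source code; grad_abs2, pullback_z2 and emulated_gradient are made-up names.

```julia
# Hypothetical imitation of the convention described in the text;
# this is NOT Zygote's implementation.
f2(z) = abs2(z)
z2(x) = x^2 - im * x

# Gradient of real-valued abs2 with respect to complex z in the convention
# f_u + i f_v (twice the conjugate Wirtinger derivative ∂f/∂z̄): 2z.
grad_abs2(z) = 2z

# Pullback through z(x) = x^2 - i x: multiply the incoming adjoint Δ by the
# conjugate of dz/dx = 2x - i. Nothing projects the result onto the reals.
pullback_z2(x, Δ) = Δ * conj(2x - im)

emulated_gradient(x) = pullback_z2(x, grad_abs2(z2(x)))

println(emulated_gradient(1.0))        # 6.0 - 2.0im
println(real(emulated_gradient(1.0)))  # 6.0
```

Because the input $x$ is real but no step maps the complex adjoint back onto the real line, the cross terms survive as $-2ix^2$ for every $x$.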
Because the real part is precisely what we're after, we can simply discard the rest:
julia> gradient(1.0) do x
           f2(z2(x))
       end .|> real
(6.0,)
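This riddance is justified by the fact that, for a real-valued $f$ and a real $x$, the two Wirtinger terms are complex conjugates of each other ($\partial f/\partial z = \overline{\partial f/\partial \bar z}$ and $dz/dx = \overline{d\bar z/dx}$), and adding a number to its conjugate gives twice its real part:

$$\operatorname{Re}\left(2\,\frac{\partial f}{\partial \bar z}\,\frac{d\bar z}{dx}\right) = \frac{\partial f}{\partial z}\frac{dz}{dx} + \frac{\partial f}{\partial \bar z}\frac{d\bar z}{dx} = \frac{df}{dx}.$$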
Further reading
- Zygote.jl #29: Automatic differenciation for complex numbers
- Zygote.jl #142: Complex Number Interfaces
- Zygote.jl #342: Complex gradient on real function with complex intermediates
- ChainRulesCore.jl #159: Complex numbers (comment)
- Discourse: Taking Complex Autodiff Seriously in ChainRules
- ChainRules.jl #196: Correct chainrules for abs2, abs, conj and angle (comment)