When Zygote imagines gradients

Automatic differentiation (AD) is a new kind of black magic that enables entire programs to be differentiated. Any old fool can tell you that the derivative of $(x^3 + 1) (x^2 + 1) (x + 1)$ at $x = 2$ is $333$, but having a computer casually tell you so is surreal, especially when there's a loop involved:

julia> using Zygote # v0.6.10
julia> gradient(2.0) do x
           s = 1.0
           for n in 1:3
               s += s * x^n
           end
           s
       end
(333.0,)
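
For the skeptical, here is a minimal sketch that double-checks the number using only Base Julia. The names `looped`, `central_diff`, and `by_hand` are made up for this illustration; the finite-difference step size is an arbitrary choice, and none of this reflects how Zygote arrives at its answer.

```julia
# The same loop as in the example above, wrapped in a function.
function looped(x)
    s = 1.0
    for n in 1:3
        s += s * x^n
    end
    return s
end

# Hypothetical helper: central finite difference with step h.
central_diff(f, x; h=1e-6) = (f(x + h) - f(x - h)) / (2 * h)

# Product rule applied by hand to (x^3 + 1)(x^2 + 1)(x + 1).
by_hand(x) = 3 * x^2 * (x^2 + 1) * (x + 1) +
             (x^3 + 1) * 2 * x * (x + 1) +
             (x^3 + 1) * (x^2 + 1)

approx = central_diff(looped, 2.0)  # close to 333, up to rounding
exact = by_hand(2.0)                # 333.0
```

The finite difference only agrees approximately, which is one reason to prefer AD in the first place.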

Unlike other varieties of black magic, which tend to rely on clever tricks, AD is constructed predominantly from edge cases. Needless to say, things sometimes get weird.

In the following, we're going to take a look into why Zygote can give spurious imaginary components for gradients of real-valued functions of real variables:

julia> gradient(1.0) do x
           x^4 + x^2
       end
(6.0,)
julia> gradient(1.0) do x
           abs2(x^2 - im * x)
       end
(6.0 - 2.0im,)

Chain rule

At the heart of AD lies the chain rule of calculus, which instructs us to multiply intermediate derivatives to build up more interesting ones. To be concrete, we'll use the function $f : \mathbb{R}^2 \to \mathbb{R}$,
$$ f(g, h) = g^2 + h^2, $$
along with the two functions $g, h : \mathbb{R} \to \mathbb{R}$:
$$ \begin{aligned} g(x) &= x^2, \\ h(x) &= -x. \end{aligned} $$
The computational graph of this setup is very simple:

[Figure: computational graph of $f(g(x), h(x))$ with edges x -> g (real), x -> h (real), g -> f (real), h -> f (real).]

As indicated at the edges, each of the four partial derivatives is real-valued.

Using the chain rule, we see that the derivative of $f$ with respect to $x$ is
$$ \begin{aligned} \pdv{f}{x} &= \pdv{f}{g} \pdv{g}{x} + \pdv{f}{h} \pdv{h}{x} = 2 g \cdot 2 x + 2 h \cdot (-1) = 4 x^3 + 2 x. \end{aligned} $$
This holds at $x = 1$, so it must be true for all $x$:

julia> f1(g, h) = g^2 + h^2;
julia> g1(x) = x^2;
julia> h1(x) = -x;
julia> gradient(1.0) do x
           f1(g1(x), h1(x))
       end
(6.0,)
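
To make the two paths through the graph explicit, here is a small sketch that applies the chain rule by hand, one edge at a time. The edge functions (`df_dg`, `df_dh`, `dg_dx`, `dh_dx`) and `chain_rule_derivative` are hypothetical names introduced for this illustration; an AD system derives the equivalent pieces for you.

```julia
g1(x) = x^2
h1(x) = -x
f1(g, h) = g^2 + h^2

# Hand-written partial derivatives, one per edge of the graph.
df_dg(g, h) = 2 * g
df_dh(g, h) = 2 * h
dg_dx(x) = 2 * x
dh_dx(x) = -one(x)

# Chain rule: add up the contributions of both paths from x to f.
function chain_rule_derivative(x)
    g, h = g1(x), h1(x)
    return df_dg(g, h) * dg_dx(x) + df_dh(g, h) * dh_dx(x)
end

chain_rule_derivative(1.0)  # 6.0
```

Every factor in the sum is real, so there is no opportunity for an imaginary part to sneak in.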

Complex numbers as pairs of reals

The above example is equivalent to the function $f : \mathbb{C} \to \mathbb{R}$,
$$ f(z) = |z|^2 = (\Re(z))^2 + (\Im(z))^2, $$
whose input is given by the function $z : \mathbb{R} \to \mathbb{C}$:
$$ \begin{aligned} z(x) &= x^2 - i x, \end{aligned} $$
with
$$ \begin{aligned} \Re(z(x)) &= x^2, \\ \Im(z(x)) &= -x. \end{aligned} $$
Despite the additional complexity, the computational graph looks identical:

[Figure: computational graph of $f(z(x))$ with edges x -> Re (real), x -> Im (real), Re -> f (real), Im -> f (real).]

In particular, the partial derivatives are all still real, as all we've done is change two labels.

Treating $f$ as a function of the two real values $\Re(z)$ and $\Im(z)$, we can apply the chain rule and get the same answer as above:
$$ \begin{aligned} \pdv{f}{x} &= \pdv{f}{\Re(z)} \pdv{\Re(z)}{x} + \pdv{f}{\Im(z)} \pdv{\Im(z)}{x} = 2 \Re(z) \cdot 2 x + 2 \Im(z) \cdot (-1) = 4 x^3 + 2 x. \end{aligned} $$
Since this is merely a relabelling of the previous version ($g$ became $\Re$, $h$ became $\Im$), we don't expect any difference; yet we see

julia> f2(z) = abs2(z);
julia> z2(x) = x^2 - im * x;
julia> gradient(1.0) do x
           f2(z2(x))
       end
(6.0 - 2.0im,)

The real part is the same, but there's an unexpected imaginary part, which is not quite right. Zygote clearly doesn't split $z$ into real and imaginary components as we've done here.

Wirtinger derivatives

In the context of AD, $z$ (the argument of $f$) is a single node in the computational graph. Therefore, $f$ should only have one arrow pointing at it, unlike in the previous arrangements we considered. That would require having definitions for $\partial z / \partial x$ and $\partial f / \partial z$, which are not obvious in the general (non-holomorphic) case. Since
$$ z = \Re(z) + i \Im(z), $$
by linearity, we should probably write
$$ \pdv{z}{x} = \pdv{\Re(z)}{x} + i \pdv{\Im(z)}{x} $$
for the former. However, it's not immediately clear what to do about the latter.

Thankfully, this problem has a solution in the form of Wirtinger derivatives. In short, one defines
$$ \pdv{f}{z} = \frac{1}{2} \left( \pdv{f}{\Re(z)} - i \pdv{f}{\Im(z)} \right), $$
as well as the conjugate expression
$$ \pdv{f}{z^*} = \frac{1}{2} \left( \pdv{f}{\Re(z)} + i \pdv{f}{\Im(z)} \right). $$
To use these properly, we should think about a change of variables not from $g$ and $h$ to just $z$, but rather to $z$ and its complex conjugate $z^*$:
$$ f(g, h) = f(z, z^*). $$
Doing so results in this graph, which features complex-valued derivatives:

[Figure: computational graph of $f(z(x), z^*(x))$ with edges x -> z (complex), x -> z* (complex), z -> f (complex), z* -> f (complex).]

Immediately, we see that this won't work, as there are again two arrows to $f$.

Undeterred, we press on. With these definitions in hand, formal application of the chain rule results in
$$ \begin{aligned} \pdv{f}{x} &= \pdv{f}{z} \pdv{z}{x} + \pdv{f}{z^*} \pdv{z^*}{x} \\ &= \frac{1}{2} \left( \pdv{f}{\Re(z)} - i \pdv{f}{\Im(z)} \right) \left( \pdv{\Re(z)}{x} + i \pdv{\Im(z)}{x} \right) \\ &\qquad + \frac{1}{2} \left( \pdv{f}{\Re(z)} + i \pdv{f}{\Im(z)} \right) \left( \pdv{\Re(z)}{x} - i \pdv{\Im(z)}{x} \right) \\ &= \pdv{f}{\Re(z)} \pdv{\Re(z)}{x} + \pdv{f}{\Im(z)} \pdv{\Im(z)}{x}. \end{aligned} $$
All the cross terms cancel! Despite the individual derivatives being complex, the final result is strictly real. That's great and all, but how can we achieve this without introducing an additional node?
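
As a sanity check on the cancellation, here is a minimal sketch that plugs the running example into both Wirtinger terms at $x = 1$ using nothing but Base Julia. The variable names (`re_z`, `df_dre`, and so on) are made up for this illustration, and the partial derivatives are typed in by hand rather than computed by any AD library.

```julia
x = 1.0
re_z, im_z = x^2, -x             # Re(z(x)) and Im(z(x)) for z(x) = x^2 - i x

# Real partial derivatives of f = Re(z)^2 + Im(z)^2 and of z(x).
df_dre, df_dim = 2 * re_z, 2 * im_z
dre_dx, dim_dx = 2 * x, -1.0

# Wirtinger derivatives of f, including the factor of 1/2.
df_dz  = (df_dre - im * df_dim) / 2
df_dzc = (df_dre + im * df_dim) / 2

# Derivatives of z and its conjugate with respect to x.
dz_dx  = dre_dx + im * dim_dx
dzc_dx = dre_dx - im * dim_dx

# Both terms of the chain rule; the imaginary parts cancel in the sum.
total = df_dz * dz_dx + df_dzc * dzc_dx
```

Each of the two products is complex on its own; only their sum has a vanishing imaginary part and a real part of 6.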

Zygote

If we take a glance at the documentation for Zygote, we see that it uses the definition
$$ \pdv{f}{z^*} = \pdv{f}{\Re(z)} + i \pdv{f}{\Im(z)}, $$
which (up to a factor) is one of the Wirtinger derivatives. The other one is nowhere to be found, and that's admittedly a straightforward way to get rid of the extra variable. The computational graph can finally be what we've been striving for all along:

[Figure: computational graph of $f(z(x))$ with edges x -> z (complex), z -> f (complex).]

If it were that easy, Wirtinger wouldn't have bothered coming up with two expressions, so there must be a consequence to this. It's not hard to see that
$$ \begin{aligned} \pdv{f}{z^*} \pdv{z^*}{x} &= \left( \pdv{f}{\Re(z)} + i \pdv{f}{\Im(z)} \right) \left( \pdv{\Re(z)}{x} - i \pdv{\Im(z)}{x} \right) \\ &= \pdv{f}{\Re(z)} \pdv{\Re(z)}{x} + \pdv{f}{\Im(z)} \pdv{\Im(z)}{x} \\ &\qquad - i \pdv{f}{\Re(z)} \pdv{\Im(z)}{x} + i \pdv{f}{\Im(z)} \pdv{\Re(z)}{x} \end{aligned} $$
has an unwanted imaginary part from the cross terms. (If you're wondering about all the complex conjugation that's going on, it's because Zygote actually computes the conjugate of the gradient (the "adjoint"), so everything gets conjugated along the way. See, for example, the definition for literal_pow.)

For our concrete example, this yields
$$ \begin{aligned} \pdv{f}{z^*} \pdv{z^*}{x} &= 2 \Re(z) \cdot 2 x + 2 \Im(z) \cdot (-1) - i \cdot 2 \Re(z) \cdot (-1) + i \cdot 2 \Im(z) \cdot 2 x \\ &= 4 x^3 + 2 x - 2 i x^2, \end{aligned} $$
which differs from the desired result by $-2 i x^2$. That's exactly what we had computed earlier at $x = 1$:

julia> gradient(1.0) do x
           f2(z2(x))
       end
(6.0 - 2.0im,)
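
The same number can be reproduced without Zygote by multiplying out the single term by hand. In the sketch below, `single_term` is a hypothetical name, and the function mimics only the formula above for $f(z) = |z|^2$ and $z(x) = x^2 - i x$; it says nothing about Zygote's actual internals.

```julia
# The lone product (df/dRe + i df/dIm) * (dRe/dx - i dIm/dx), written out by hand.
function single_term(x)
    re_z, im_z = x^2, -x                   # Re(z(x)), Im(z(x))
    df_dre, df_dim = 2 * re_z, 2 * im_z    # partials of f = Re^2 + Im^2
    dre_dx, dim_dx = 2 * x, -one(x)        # partials of z(x)
    return (df_dre + im * df_dim) * (dre_dx - im * dim_dx)
end

single_term(1.0)  # 6.0 - 2.0im
```

Evaluating at other points shows the imaginary part tracking $-2 x^2$, while the real part stays at $4 x^3 + 2 x$.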

Because the real part is precisely what we're after, we can simply discard the rest:

julia> gradient(1.0) do x
           f2(z2(x))
       end .|> real
(6.0,)

Discarding the imaginary part is justified by the fact that
$$ \begin{aligned} \Re\left( \pdv{f}{z^*} \pdv{z^*}{x} \right) &= \pdv{f}{\Re(z)} \pdv{\Re(z)}{x} + \pdv{f}{\Im(z)} \pdv{\Im(z)}{x} = \pdv{f}{x}. \end{aligned} $$

Further reading