
While studying, I noticed that some lecture notes of mine hand-wavingly use the chain rule formula to design a gradient descent algorithm with respect to a matrix $W \in \mathbb{R}^{m \times n}$ (i.e. the variable we want to minimize over is an $m \times n$ matrix). Curious about how one would define the Jacobian with respect to a matrix, I stumbled upon this question and the idea of representing rank-3 tensors as stacked matrices. With this, we define the Jacobian of a function $f: \mathbb{R}^{m \times n} \rightarrow \mathbb{R}^{l}$ with respect to the variable $W$ as $$\nabla_W f = \left(\frac{\partial f_i}{\partial w_{jk}}\right)_{ijk} = \left(\left(\frac{\partial f_i}{\partial w_{j k}}\right)_{jk}\right)_{i}$$

Visually, this Jacobian consists of $l$ stacked $m \times n$ matrices, each containing the partial derivatives with respect to the entries of $W$ for a fixed output component $f_i$.
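To make sure I am reading this definition the same way automatic differentiation tools do, here is a small sketch (the concrete $f$ below is just an arbitrary example of mine; `jax.jacobian` returns exactly this stacked object of shape $(l, m, n)$):

```python
import jax
import jax.numpy as jnp

m, n, l = 3, 4, 2
A = jnp.arange(float(l * m)).reshape(l, m)   # fixed coefficients, only to build some nonlinear f
b = jnp.linspace(1.0, 2.0, n)

def f(W):                        # f : R^{m x n} -> R^l
    return jnp.tanh(A @ W @ b)   # output shape (l,)

W = jnp.ones((m, n))
Jf = jax.jacobian(f)(W)          # shape (l, m, n): l stacked (m x n) matrices
print(Jf.shape)                  # (2, 3, 4)
# Jf[i, j, k] equals  d f_i / d w_{jk},  i.e. the i-th stacked matrix in the definition above
```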

With this, we can try to prove the chain rule for matrix derivatives. Let $f: \mathbb{R}^{m \times n} \rightarrow \mathbb{R}^{l}$ and $g: \mathbb{R}^l \rightarrow \mathbb{R}^s$. Writing $J_W(f) = \nabla_W f$ for the object defined above, the chain rule should read $$J_W(g \circ f) = J(g)(f) \cdot J_W(f)$$

On the one hand, treating the matrix $J(g)(f)$ in this product as a "scalar" and multiplying it elementwise into the rank-3 tensor (which seems natural in a module-theoretic sense) would yield $$\frac{\partial g_i}{\partial w_{jk}}(f(W)) = (\nabla g_j)(f(W)) \cdot \left(\frac{\partial f_i}{\partial w_{ak}}(W)\right)_{a} = \sum\limits_{a = 1}^{l} \partial_a g_j(f(W)) \frac{\partial f_i}{\partial w_{ak}}(W)$$

On the other hand, employing the usual chain rule and some mildly sketchy argumentation, we have $$\frac{\partial}{\partial w_{jk}} (g_i(f(W))) = \frac{\partial}{\partial w_{jk}} (g_i(f_1(W),\ldots,f_l(W))) = \sum\limits_{a = 1}^{l} (\partial_a g_i(f(W)))\frac{\partial f_a}{\partial w_{jk}}(W) = (\nabla g_i)(f(W)) \cdot \left(\frac{\partial f_a}{\partial w_{jk}}(W)\right)_a$$

for arbitrary $i,j,k$. This means our Jacobian now looks like

$$\nabla_W (g \circ f) = \left(\left((\nabla g_i)(f(W)) \cdot \left(\frac{\partial f_a}{\partial w_{jk}}(W)\right)_a\right)_{jk}\right)_i = \left(\left(\sum\limits_{a = 1}^{l} \partial_a g_i(f(W)) \, \frac{\partial f_a}{\partial w_{jk}}(W)\right)_{jk}\right)_{i}$$

Writing this out in the stacked matrix format, it seems that this calculation differs from the first one only by a swap of indices. This is where I am stuck. If I take $s = 1$ (which corresponds to the case examined in my notes), the definition seems to fit, but my derivation of the chain rule seems to be wrong. It might also be that my understanding of tensor multiplication is incorrect, so that both the definition and the calculation are right and I simply don't realize it.
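Explicitly, for $s = 1$ (the case from my notes) the formula I would expect, writing $\nabla_W f_a$ for the $a$-th stacked matrix of $\nabla_W f$, is $$\nabla_W (g \circ f) = \sum\limits_{a = 1}^{l} \partial_a g(f(W)) \, \nabla_W f_a(W),$$ i.e. a linear combination of the $l$ stacked matrices.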

EDIT: Reading through the (concerningly sparse) material on the net about this, I may have found that both formulations are equivalent, provided that matrix-tensor multiplications (generalized by tensor contractions) are the right way to go about this. Visually, employing the stacked matrix representation again, multiplying a rank-3 tensor by a matrix from the left would look like this: to obtain the value at index $ijk$, take the $i$-th row vector of $J(g)(f)$ and multiply it by the column vector obtained by taking the $(j,k)$-th entry of each submatrix of $J_W(f)$ in order (!).

This should make both arguments coincide (in particular, we recover scalar-matrix multiplication for $l = s = 1$), and it seems natural for some reason. Is my reasoning about tensor multiplication correct?
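As a sanity check on this contraction (again just a numerical sketch with made-up $f$ and $g$; the contraction I described above is expressed with an `einsum`), the contracted tensor does coincide with the Jacobian of the composition computed directly:

```python
import jax
import jax.numpy as jnp

m, n, l, s = 3, 4, 2, 2
A = jnp.arange(float(l * m)).reshape(l, m)   # made-up f and g, only for the check
C = jnp.arange(float(s * l)).reshape(s, l)
b = jnp.linspace(1.0, 2.0, n)

def f(W):                        # f : R^{m x n} -> R^l
    return jnp.tanh(A @ W @ b)

def g(y):                        # g : R^l -> R^s
    return jnp.sin(C @ y)

W = jnp.ones((m, n))
Jg_at_fW = jax.jacobian(g)(f(W))             # shape (s, l)
Jf = jax.jacobian(f)(W)                      # shape (l, m, n)

# the contraction described above: entry (i, j, k) is the i-th row of J(g)(f(W))
# dotted with the vector of (j, k)-entries of the stacked submatrices of J_W(f)
J_contracted = jnp.einsum('ia,ajk->ijk', Jg_at_fW, Jf)   # shape (s, m, n)

J_direct = jax.jacobian(lambda V: g(f(V)))(W)            # shape (s, m, n)
print(jnp.allclose(J_contracted, J_direct, atol=1e-6))   # True
```

(For $l = s = 1$ the `einsum` degenerates to multiplying the single submatrix by a scalar, matching the remark above.)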

Summarizing, my questions now are:

  1. Is my derivation of the chain rule up to the last point correct? If not, where exactly am I going wrong?
  2. How does the tensor multiplication work in the presumably correct statement of the chain rule above? With this I might be able to reverse engineer my error(s).

Thank you for reading through this barrage of indices and matrices!

TheOutZ
  • Might be relevant https://math.stackexchange.com/questions/4035579/composite-function-gradient/4038098#4038098 – Koncopd Dec 28 '22 at 11:38
  • Ehm... could you elaborate on how the usual chain rule applies with this tensor notation? – TheOutZ Dec 28 '22 at 13:40

0 Answers