梯度计算总结
微积分与梯度计算专题:从导数到反向传播¶
一、知识框架总览¶
整个体系可以用一条主线串联:标量函数的导数 → 多变量函数的偏导数与梯度 → 向量函数的雅可比矩阵 → 复合函数的链式法则 → 神经网络中的反向传播。每一层都是前一层的自然延伸。
二、第一层:导数(一元函数)¶
2.1 本质是什么?¶
导数描述的是函数在某一点处的局部变化率——当自变量变化极小量 $\Delta x$ 时,因变量变化了多少。形式定义为:
$$f'(x) = \lim_{\Delta x \to 0} \frac{f(x + \Delta x) - f(x)}{\Delta x}$$
几何上,$f'(x)$ 就是曲线在 $x$ 处切线的斜率。梯度下降中"沿梯度反方向走",本质依赖的就是这个量。
2.2 必须烂熟于心的基本求导公式¶
以下是深度学习中真实会用到的:
$$\frac{d}{dx} x^n = n x^{n-1} \qquad \frac{d}{dx} e^x = e^x \qquad \frac{d}{dx} \ln x = \frac{1}{x}$$
$$\frac{d}{dx} \sigma(x) = \sigma(x)(1-\sigma(x)) \qquad \frac{d}{dx} \tanh(x) = 1 - \tanh^2(x)$$
$$\frac{d}{dx} \text{ReLU}(x) = \begin{cases}1 & x > 0 \\ 0 & x < 0\end{cases}$$
sigmoid 的导数推导是面试高频考点,值得显式推一遍:
$$\sigma(x) = \frac{1}{1+e^{-x}}$$
令 $u = 1 + e^{-x}$,则 $\sigma = u^{-1}$,用链式法则:
$$\sigma'(x) = -u^{-2} \cdot (-e^{-x}) = \frac{e^{-x}}{(1+e^{-x})^2} = \frac{1}{1+e^{-x}} \cdot \frac{e^{-x}}{1+e^{-x}} = \sigma(x)(1-\sigma(x))$$
三、第二层:偏导数与梯度(多变量函数)¶
3.1 从一元扩展到多元¶
当函数有多个输入时,比如 $f(x_1, x_2)$,就不能只谈"导数"了——需要知道沿每个方向各自的变化率。偏导数就是在固定其他变量不动的情况下,对某一个变量求导:
$$\frac{\partial f}{\partial x_1} = \lim_{\Delta x_1 \to 0} \frac{f(x_1 + \Delta x_1, x_2) - f(x_1, x_2)}{\Delta x_1}$$
计算时,把其他变量当作常数,剩余的求导规则与一元完全相同。
具体例子:$f(x_1, x_2) = x_1^2 + 3x_1 x_2 + x_2^3$
$$\frac{\partial f}{\partial x_1} = 2x_1 + 3x_2 \qquad \frac{\partial f}{\partial x_2} = 3x_1 + 3x_2^2$$
3.2 梯度:将所有偏导数打包成向量¶
梯度就是把所有偏导数收集成一个列向量,记作 $\nabla f$。对于 $f: \mathbb{R}^n \to \mathbb{R}$(输入是 $n$ 维向量,输出是一个标量),梯度为:
$$\nabla f(\mathbf{x}) = \begin{bmatrix} \frac{\partial f}{\partial x_1} \\ \frac{\partial f}{\partial x_2} \\ \vdots \\ \frac{\partial f}{\partial x_n} \end{bmatrix} \in \mathbb{R}^n$$
梯度的几何含义:它指向函数值增长最快的方向,其模长代表最快增长的速率。梯度下降之所以沿负梯度方向更新,正是因为这个原因——负梯度是下降最快的方向。
梯度下降更新规则(向量形式):
$$\mathbf{x}_{t+1} = \mathbf{x}_t - \eta \nabla f(\mathbf{x}_t)$$
四、第三层:雅可比矩阵(向量函数)¶
4.1 它解决什么问题?¶
梯度处理的是"多输入、单输出"函数($\mathbb{R}^n \to \mathbb{R}$)。但神经网络中大量的操作是"多输入、多输出"($\mathbb{R}^n \to \mathbb{R}^m$),比如一个全连接层将 $n$ 维输入映射到 $m$ 维输出。这时就需要雅可比矩阵来完整描述所有偏导数关系。
4.2 定义¶
对于 $\mathbf{f}: \mathbb{R}^n \to \mathbb{R}^m$,即:
$$\mathbf{f}(\mathbf{x}) = \begin{bmatrix} f_1(x_1,\ldots,x_n) \\ f_2(x_1,\ldots,x_n) \\ \vdots \\ f_m(x_1,\ldots,x_n) \end{bmatrix}$$
雅可比矩阵 $J \in \mathbb{R}^{m \times n}$ 的第 $i$ 行第 $j$ 列元素为:
$$J_{ij} = \frac{\partial f_i}{\partial x_j}$$
展开写:
$$J = \begin{bmatrix} \frac{\partial f_1}{\partial x_1} & \frac{\partial f_1}{\partial x_2} & \cdots & \frac{\partial f_1}{\partial x_n} \\ \frac{\partial f_2}{\partial x_1} & \frac{\partial f_2}{\partial x_2} & \cdots & \frac{\partial f_2}{\partial x_n} \\ \vdots & & \ddots & \vdots \\ \frac{\partial f_m}{\partial x_1} & \frac{\partial f_m}{\partial x_2} & \cdots & \frac{\partial f_m}{\partial x_n} \end{bmatrix}$$
直觉:$J$ 的第 $i$ 行告诉你"第 $i$ 个输出如何随所有输入变化";第 $j$ 列告诉你"第 $j$ 个输入变化时,所有输出如何响应"。
4.3 具体例子:一个全连接层¶
设 $\mathbf{y} = W\mathbf{x}$,其中 $W \in \mathbb{R}^{m \times n}$,$\mathbf{x} \in \mathbb{R}^n$,$\mathbf{y} \in \mathbb{R}^m$。则:
$$y_i = \sum_{j=1}^n W_{ij} x_j \implies \frac{\partial y_i}{\partial x_j} = W_{ij}$$
所以这个线性映射的雅可比矩阵就是 $W$ 本身,即 $J = W$。这在反向传播推导中非常重要。
4.4 梯度是雅可比的特例¶
当 $m=1$ 时,$\mathbf{f}$ 退化为标量函数 $f$,雅可比矩阵退化为 $1 \times n$ 的行向量,其转置就是梯度 $\nabla f$。三者的关系可以这样理解:
$$\text{导数(}n=1,m=1\text{)} \subset \text{梯度(}m=1\text{)} \subset \text{雅可比(一般情形)}$$
五、第四层:链式法则¶
5.1 为什么链式法则如此关键?¶
神经网络本质是复合函数的嵌套:输入经过线性变换,再经过激活函数,再经过线性变换……最终输出损失。要对最初的参数 $W$ 求导,就必须将这条"计算链"从后往前逐层拆解,这正是链式法则的用武之地。
5.2 一元版本¶
若 $y = f(g(x))$,令 $u = g(x)$,则:
$$\frac{dy}{dx} = \frac{dy}{du} \cdot \frac{du}{dx} = f'(g(x)) \cdot g'(x)$$
例:求 $y = e^{x^2}$ 的导数。令 $u = x^2$,则 $y = e^u$:
$$\frac{dy}{dx} = e^u \cdot 2x = 2x e^{x^2}$$
5.3 多元版本(全微分形式)¶
若 $z = f(u, v)$,而 $u = g(x)$,$v = h(x)$,则:
$$\frac{dz}{dx} = \frac{\partial z}{\partial u} \cdot \frac{du}{dx} + \frac{\partial z}{\partial v} \cdot \frac{dv}{dx}$$
原理:$x$ 的变化通过 $u$ 和 $v$ 两条路径影响 $z$,最终效果是两条路径贡献之和。这个"路径累加"的思想贯穿整个反向传播。
5.4 向量版本(矩阵形式)¶
若 $\mathbf{z} = f(\mathbf{y})$,$\mathbf{y} = g(\mathbf{x})$,则复合函数的雅可比为:
$$\frac{\partial \mathbf{z}}{\partial \mathbf{x}} = \frac{\partial \mathbf{z}}{\partial \mathbf{y}} \cdot \frac{\partial \mathbf{y}}{\partial \mathbf{x}} = J_f \cdot J_g$$
其中 $J_f \in \mathbb{R}^{p \times m}$,$J_g \in \mathbb{R}^{m \times n}$,结果为 $\mathbb{R}^{p \times n}$。这就是为什么说"反向传播是矩阵乘法的链式展开"。
六、第五层:反向传播¶
反向传播是链式法则在计算图上的高效实现策略。理解它需要先建立"计算图"的概念。
6.1 计算图¶
将一次前向计算拆解成有向无环图(DAG),每个节点是一个基本运算。以 $L = (\sigma(wx + b) - y)^2$ 为例:
6.2 逐步数值计算示例¶
用具体数值走一遍完整的前向 + 反向过程,设 $w=2, x=1, b=0, y=0.8$。
前向传播(从左到右计算各节点值):
$$z_1 = wx = 2 \times 1 = 2$$ $$z_2 = z_1 + b = 2 + 0 = 2$$ $$a = \sigma(z_2) = \frac{1}{1+e^{-2}} \approx 0.8808$$ $$L = (a - y)^2 = (0.8808 - 0.8)^2 \approx 0.00653$$
反向传播(从右到左,链式法则逐层传递):
$$\frac{\partial L}{\partial L} = 1 \quad \text{(出发点,恒为 1)}$$
$$\frac{\partial L}{\partial a} = 2(a - y) = 2 \times 0.0808 \approx 0.1616$$
$$\frac{\partial L}{\partial z_2} = \frac{\partial L}{\partial a} \cdot \frac{\partial a}{\partial z_2} = 0.1616 \times a(1-a) = 0.1616 \times 0.8808 \times 0.1192 \approx 0.01697$$
$$\frac{\partial L}{\partial z_1} = \frac{\partial L}{\partial z_2} \cdot \underbrace{\frac{\partial z_2}{\partial z_1}}_{=1} = 0.01697 \quad \text{(加法节点,梯度直接通过)}$$
$$\frac{\partial L}{\partial w} = \frac{\partial L}{\partial z_1} \cdot \underbrace{\frac{\partial z_1}{\partial w}}_{=x} = 0.01697 \times 1 = 0.01697$$
$$\frac{\partial L}{\partial b} = \frac{\partial L}{\partial z_2} \cdot 1 = 0.01697$$
参数更新($\eta = 0.1$):
$$w \leftarrow 2 - 0.1 \times 0.01697 = 1.9983$$
6.3 两类节点的局部导数规律¶
理解了计算图,反向传播的核心就浓缩成对每个基本节点记住其"局部导数":
加法节点:$z = x + y$,则 $\frac{\partial z}{\partial x} = 1$,$\frac{\partial z}{\partial y} = 1$。梯度原封不动地向两个输入分发,称为梯度分发器。
乘法节点:$z = x \cdot y$,则 $\frac{\partial z}{\partial x} = y$,$\frac{\partial z}{\partial y} = x$。梯度乘以对方的值传回,两路互换,称为梯度交换器。
激活函数节点:$z = \sigma(x)$,则 $\frac{\partial z}{\partial x} = \sigma(x)(1-\sigma(x))$,梯度乘以当前激活值的导数后传回。
矩阵乘法节点:$\mathbf{y} = W\mathbf{x}$,设上游传来梯度 $\delta = \frac{\partial L}{\partial \mathbf{y}}$,则:
$$\frac{\partial L}{\partial \mathbf{x}} = W^T \delta \qquad \frac{\partial L}{\partial W} = \delta \mathbf{x}^T$$
这里雅可比矩阵的作用就显现了:对 $\mathbf{x}$ 的梯度是 $W$ 转置乘以上游梯度,对 $W$ 的梯度是上游梯度与输入的外积。
七、全局知识框架总结¶
以下将所有概念的输入输出形状、计算方式和应用场景整理成一张对照表:
| 概念 | 函数类型 | 输出形状 | 梯度/导数计算 | 在深度学习中出现的场景 | |||---||| | 导数 $f'(x)$ | $\mathbb{R} \to \mathbb{R}$ | 标量 | 极限定义或查表 | 单个激活函数节点的局部导数 | | 偏导数 $\frac{\partial f}{\partial x_i}$ | $\mathbb{R}^n \to \mathbb{R}$ | 标量 | 固定其他变量求导 | 对某个权重的损失敏感度 | | 梯度 $\nabla f$ | $\mathbb{R}^n \to \mathbb{R}$ | $n$ 维向量 | 所有偏导数打包 | 参数更新的方向和步幅 | | 雅可比矩阵 $J$ | $\mathbb{R}^n \to \mathbb{R}^m$ | $m \times n$ 矩阵 | $J_{ij} = \frac{\partial f_i}{\partial x_j}$ | 全连接层、激活层的梯度反传 | | 链式法则 | 复合函数 | 同上游输出 | 局部导数相乘/矩阵连乘 | 多层网络反向传播的核心公式 | | 反向传播 | 计算图 | 每层的梯度 | 从输出向输入动态规划 | 所有深度学习框架的训练引擎 |
一个关键的形状验证原则:在反向传播中,某个参数的梯度形状必须与该参数本身的形状完全一致。若 $W \in \mathbb{R}^{m \times n}$,则 $\frac{\partial L}{\partial W}$ 也必须是 $m \times n$ 的矩阵。遇到形状对不上,往往意味着某处转置漏写或矩阵乘法顺序有误。这个验证习惯能帮你在推导时快速定位错误。
一、多层网络的参数更新顺序¶
在多层网络中,参数更新遵循严格的"先完整反传、后统一更新"流程,而非逐层算完立即更新。具体分为三个阶段。
前向传播阶段:输入数据从第一层流向最后一层,每个节点的中间值(激活值、线性变换结果等)全部被缓存,供反向传播使用。
反向传播阶段:从损失函数出发,依次计算 $\frac{\partial L}{\partial W_n}, \frac{\partial L}{\partial W_{n-1}}, \ldots, \frac{\partial L}{\partial W_1}$,所有层的梯度被完整计算并存储,但此时参数尚未改变。这一阶段之所以必须算完对中间变量(包括每层输入 $x$)的梯度,正是因为第 $k$ 层对 $W_k$ 的梯度,依赖于第 $k+1$ 层传回来的上游梯度 $\frac{\partial L}{\partial z_{k+1}}$;而该上游梯度恰恰是通过对第 $k$ 层输出求导得到的。换言之,每一层的"输入梯度"是下一层梯度计算的原材料,缺少任何一环,整条链式法则便无法连通。
统一更新阶段:所有梯度计算完毕后,执行一次梯度下降:
$$W_k \leftarrow W_k - \eta \cdot \frac{\partial L}{\partial W_k}, \quad k = 1, 2, \ldots, n$$
若在反向传播中途就修改参数,前向传播缓存的中间值将与当前参数不一致,导致后续层的梯度计算错误。这是标准深度学习框架(PyTorch、JAX 等)的统一实现逻辑。
二、Embedding 层的参数更新机制¶
Embedding 层在结构上是一个形状为 $|\mathcal{V}| \times d$ 的可学习矩阵,其中 $|\mathcal{V}|$ 为词表大小,$d$ 为嵌入维度。前向传播时,输入 token id 对应的行被"查出"作为当前输入向量;从网络的角度看,这等价于一次矩阵乘法 $\mathbf{h} = E[i, :]$,即取 Embedding 矩阵第 $i$ 行。
这里对输入求导的意义便得到了完整体现。若将原始离散输入理解为一个连续向量 $\mathbf{x}$(即 Embedding 矩阵中对应的行),则:
$$\frac{\partial L}{\partial \mathbf{x}} = \frac{\partial L}{\partial \mathbf{h}}$$
这个梯度的形状与 $\mathbf{x}$ 完全一致,因此可以直接用于更新 Embedding 矩阵中被访问的那些行。若本次 batch 中出现了 3 个 token,每个 token 以 4 维向量表示,则输入矩阵 $x \in \mathbb{R}^{3 \times 4}$,反向传播算出的 $\frac{\partial L}{\partial x} \in \mathbb{R}^{3 \times 4}$,两者形状精确对应,只有本次出现过的 3 行会被更新,其余词表行梯度为零、保持不变。
这正是"对 $x$ 求导"在实际系统中最重要的应用场景之一——它将离散符号的查表操作,转化为一个可微的连续优化问题,使得词向量能够通过梯度下降端到端地学习。
三、小矩阵完整计算示例¶
以下以一个两层全连接网络为例,$x \in \mathbb{R}^{2 \times 3}$ 表示 2 个样本、每个样本 3 维特征(即 batch size = 2),网络结构为:
$$x \in \mathbb{R}^{2\times3} \xrightarrow{W_1 \in \mathbb{R}^{3\times2}} Z \in \mathbb{R}^{2\times2} \xrightarrow{W_2 \in \mathbb{R}^{2\times1}} \hat{y} \in \mathbb{R}^{2\times1} \longrightarrow L$$
设定具体数值:
$$x = \begin{bmatrix}1 & 2 & 0\\ 0 & 1 & 3\end{bmatrix}, \quad W_1 = \begin{bmatrix}1 & 0\\ -1 & 1\\ 0 & 2\end{bmatrix}, \quad W_2 = \begin{bmatrix}1\\ -1\end{bmatrix}, \quad y_{\text{true}} = \begin{bmatrix}1\\ 0\end{bmatrix}$$
前向传播:
$$Z = x W_1 = \begin{bmatrix}1&2&0\\0&1&3\end{bmatrix}\begin{bmatrix}1&0\\-1&1\\0&2\end{bmatrix} = \begin{bmatrix}1\cdot1+2\cdot(-1)+0\cdot0 & 1\cdot0+2\cdot1+0\cdot2\\ 0\cdot1+1\cdot(-1)+3\cdot0 & 0\cdot0+1\cdot1+3\cdot2\end{bmatrix} = \begin{bmatrix}-1 & 2\\ -1 & 7\end{bmatrix}$$
$$\hat{y} = Z W_2 = \begin{bmatrix}-1&2\\-1&7\end{bmatrix}\begin{bmatrix}1\\-1\end{bmatrix} = \begin{bmatrix}-3\\-8\end{bmatrix}$$
$$L = \frac{1}{2}\|\hat{y} - y\|^2 = \frac{1}{2}\left[(-3-1)^2 + (-8-0)^2\right] = \frac{1}{2}(16 + 64) = 40$$
反向传播(链式法则逐层展开):
第一步,损失对 $\hat{y}$ 的梯度($\in \mathbb{R}^{2\times1}$):
$$\frac{\partial L}{\partial \hat{y}} = \hat{y} - y = \begin{bmatrix}-3\\-8\end{bmatrix} - \begin{bmatrix}1\\0\end{bmatrix} = \begin{bmatrix}-4\\-8\end{bmatrix}$$
第二步,对 $W_2$ 求梯度($\in \mathbb{R}^{2\times1}$,与 $W_2$ 形状一致):
$$\frac{\partial L}{\partial W_2} = Z^T \cdot \frac{\partial L}{\partial \hat{y}} = \begin{bmatrix}-1&-1\\2&7\end{bmatrix}\begin{bmatrix}-4\\-8\end{bmatrix} = \begin{bmatrix}(-1)(-4)+(-1)(-8)\\ 2(-4)+7(-8)\end{bmatrix} = \begin{bmatrix}12\\ -64\end{bmatrix}$$
第三步,梯度向前传递,计算对 $Z$ 的梯度($\in \mathbb{R}^{2\times2}$,与 $Z$ 形状一致):
$$\frac{\partial L}{\partial Z} = \frac{\partial L}{\partial \hat{y}} \cdot W_2^T = \begin{bmatrix}-4\\-8\end{bmatrix}\begin{bmatrix}1 & -1\end{bmatrix} = \begin{bmatrix}-4 & 4\\ -8 & 8\end{bmatrix}$$
第四步,对 $W_1$ 求梯度($\in \mathbb{R}^{3\times2}$,与 $W_1$ 形状一致):
$$\frac{\partial L}{\partial W_1} = x^T \cdot \frac{\partial L}{\partial Z} = \begin{bmatrix}1&0\\2&1\\0&3\end{bmatrix}\begin{bmatrix}-4&4\\-8&8\end{bmatrix} = \begin{bmatrix}-4&4\\-16&16\\-24&24\end{bmatrix}$$
第五步,对 $x$ 求梯度($\in \mathbb{R}^{2\times3}$,与 $x$ 形状一致;若 $x$ 为 Embedding,此梯度用于更新对应行):
$$\frac{\partial L}{\partial x} = \frac{\partial L}{\partial Z} \cdot W_1^T = \begin{bmatrix}-4&4\\-8&8\end{bmatrix}\begin{bmatrix}1&-1&0\\0&1&2\end{bmatrix} = \begin{bmatrix}-4&8&8\\-8&16&16\end{bmatrix}$$
参数更新($\eta = 0.01$):
$$W_2 \leftarrow \begin{bmatrix}1\\-1\end{bmatrix} - 0.01\begin{bmatrix}12\\-64\end{bmatrix} = \begin{bmatrix}0.88\\-0.36\end{bmatrix}$$
$$W_1 \leftarrow \begin{bmatrix}1&0\\-1&1\\0&2\end{bmatrix} - 0.01\begin{bmatrix}-4&4\\-16&16\\-24&24\end{bmatrix} = \begin{bmatrix}1.04&-0.04\\-0.84&0.84\\0.24&1.76\end{bmatrix}$$
四、两个通用公式与记忆方式¶
对于任意全连接层 $Y = XW$,其中 $X \in \mathbb{R}^{B \times n}$,$W \in \mathbb{R}^{n \times m}$,$Y \in \mathbb{R}^{B \times m}$,设上游传来的梯度为 $\Delta = \frac{\partial L}{\partial Y} \in \mathbb{R}^{B \times m}$,两个核心公式为:
$$\frac{\partial L}{\partial W} = X^T \Delta \qquad \frac{\partial L}{\partial X} = \Delta W^T$$
记忆方式只需抓住一个原则:谁是"另一侧",梯度就乘谁。对 $W$ 求导时,$W$ 的"另一侧"是 $X$,所以乘 $X^T$;对 $X$ 求导时,$X$ 的"另一侧"是 $W$,所以乘 $W^T$。加转置的原因并非约定,而是链式法则展开后求和结构的自然结果——前向传播用 $W$ 将 $n$ 维映射到 $m$ 维,反向传播则需要 $W^T$ 将梯度从 $m$ 维映射回 $n$ 维,方向对称,形状自洽。
验证时只需检查形状:$X^T \in \mathbb{R}^{n \times B}$ 乘以 $\Delta \in \mathbb{R}^{B \times m}$,结果为 $\mathbb{R}^{n \times m}$,与 $W$ 一致;$\Delta \in \mathbb{R}^{B \times m}$ 乘以 $W^T \in \mathbb{R}^{m \times n}$,结果为 $\mathbb{R}^{B \times n}$,与 $X$ 一致。形状验证是推导正确性的最快检验手段,建议作为推导习惯固化下来。
完整数值例子:从前向到反向,每个参数的梯度一览无余¶
网络结构设定¶
输入 $\mathbf{x} = [1, 2, 3, 4, 5, 6]$,线性层输出2维,再加一层 $\mathbf{v} \in \mathbb{R}^2$ 压缩成标量,与标签做MSE。
$$\mathbf{x} \xrightarrow{W, \mathbf{b}} \mathbf{h} = [h_1, h_2] \xrightarrow{\mathbf{v}} o = v_1 h_1 + v_2 h_2 \xrightarrow{\text{MSE}} \mathcal{L} = (o - y)^2$$
取具体参数值:
$$W = \begin{bmatrix} 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \\ 1 & 0 \\ 0 & 1 \end{bmatrix}, \quad \mathbf{b} = [0, 0], \quad \mathbf{v} = [1, 1], \quad y = 10$$
第一步:前向传播,算出每个中间值¶
$$h_1 = x_1 w_{11} + x_2 w_{21} + x_3 w_{31} + x_4 w_{41} + x_5 w_{51} + x_6 w_{61} + b_1$$ $$= 1\times1 + 2\times0 + 3\times1 + 4\times0 + 5\times1 + 6\times0 = 9$$
$$h_2 = x_1 w_{12} + x_2 w_{22} + x_3 w_{32} + x_4 w_{42} + x_5 w_{52} + x_6 w_{62} + b_2$$ $$= 1\times0 + 2\times1 + 3\times0 + 4\times1 + 5\times0 + 6\times1 = 12$$
$$o = v_1 h_1 + v_2 h_2 = 1\times9 + 1\times12 = 21$$
$$\mathcal{L} = (21 - 10)^2 = 121$$
第二步:反向传播,逐层链式求导¶
Loss → o:
$$\frac{\partial \mathcal{L}}{\partial o} = 2(o - y) = 2 \times 11 = 22$$
o → $h_1$,$h_2$:
$$\frac{\partial \mathcal{L}}{\partial h_1} = \frac{\partial \mathcal{L}}{\partial o} \cdot v_1 = 22 \times 1 = 22 \equiv \delta_1$$
$$\frac{\partial \mathcal{L}}{\partial h_2} = \frac{\partial \mathcal{L}}{\partial o} \cdot v_2 = 22 \times 1 = 22 \equiv \delta_2$$
此处 $\delta_1 = \delta_2 = 22$,因为 $v_1 = v_2$。这是巧合,不是规律。
$h_1$,$h_2$ → W 的所有参数(核心部分):
$$\frac{\partial \mathcal{L}}{\partial w_{ij}} = \delta_j \cdot x_i$$
逐一列出所有12个参数的梯度:
| 参数 | 公式 | 数值 | 说明 |
|---|---|---|---|
| $w_{11}$ | $\delta_1 \cdot x_1 = 22 \times 1$ | 22 | 特征A第1维 → $h_1$ |
| $w_{21}$ | $\delta_1 \cdot x_2 = 22 \times 2$ | 44 | 特征A第2维 → $h_1$ |
| $w_{31}$ | $\delta_1 \cdot x_3 = 22 \times 3$ | 66 | 特征B第1维 → $h_1$ |
| $w_{41}$ | $\delta_1 \cdot x_4 = 22 \times 4$ | 88 | 特征B第2维 → $h_1$ |
| $w_{51}$ | $\delta_1 \cdot x_5 = 22 \times 5$ | 110 | 特征C第1维 → $h_1$ |
| $w_{61}$ | $\delta_1 \cdot x_6 = 22 \times 6$ | 132 | 特征C第2维 → $h_1$ |
| $w_{12}$ | $\delta_2 \cdot x_1 = 22 \times 1$ | 22 | 特征A第1维 → $h_2$ |
| $w_{22}$ | $\delta_2 \cdot x_2 = 22 \times 2$ | 44 | 特征A第2维 → $h_2$ |
| $w_{32}$ | $\delta_2 \cdot x_3 = 22 \times 3$ | 66 | 特征B第1维 → $h_2$ |
| $w_{42}$ | $\delta_2 \cdot x_4 = 22 \times 4$ | 88 | 特征B第2维 → $h_2$ |
| $w_{52}$ | $\delta_2 \cdot x_5 = 22 \times 5$ | 110 | 特征C第1维 → $h_2$ |
| $w_{62}$ | $\delta_2 \cdot x_6 = 22 \times 6$ | 132 | 特征C第2维 → $h_2$ |
第三步:直接回答核心问题¶
$w_{11}$ 和 $w_{12}$ 是并行更新的吗?
是的,所有参数在一次反向传播中同时独立计算各自的梯度,然后并行更新。$w_{11}$ 和 $w_{12}$ 梯度均为22,数值相同纯属本例巧合(因为 $\delta_1 = \delta_2$ 且 $x_1$ 相同),若 $v_1 \neq v_2$ 则二者立刻不同。
为什么同一列的参数梯度不同?
$w_{11}=22$,$w_{21}=44$,$w_{31}=66$……同属第一列(均流向 $h_1$,接收同样的 $\delta_1=22$),但对应输入 $x_1=1, x_2=2, x_3=3$ 各不相同,梯度因此各异。下游信号相同,输入值不同 → 梯度不同。
三个特征如何共同影响 W?
特征A($x_1, x_2$)决定 $w_{11}, w_{21}, w_{12}, w_{22}$ 的梯度;特征B($x_3, x_4$)决定 $w_{31}, w_{41}, w_{32}, w_{42}$ 的梯度;特征C($x_5, x_6$)决定 $w_{51}, w_{61}, w_{52}, w_{62}$ 的梯度。W 被三个特征同时更新,但每个特征只负责属于自己的那几行,互不干扰。
一句话统领全局¶
$$\boxed{\frac{\partial \mathcal{L}}{\partial w_{ij}} = \underbrace{\delta_j}_{\text{来自下游,决定方向}} \times \underbrace{x_i}_{\text{来自上游,决定幅度}}}$$
梯度永远是"从哪里接收信号"乘以"输入是多少",两个因素缺一不可,这就是为什么共享下游路径不等于梯度相同。
一次SGD梯度更新的完整过程¶
网络结构回顾¶
$$\mathbf{x} = [x_1, x_2, x_3, x_4, x_5, x_6] = [1,2,3,4,5,6]$$
$$h_1 = x_1 w_{11} + x_2 w_{21} + x_3 w_{31} + x_4 w_{41} + x_5 w_{51} + x_6 w_{61} + b_1 = 9$$
$$h_2 = x_1 w_{12} + x_2 w_{22} + x_3 w_{32} + x_4 w_{42} + x_5 w_{52} + x_6 w_{62} + b_2 = 12$$
$$o = v_1 h_1 + v_2 h_2 = 21, \qquad \mathcal{L} = (o - y)^2 = (21-10)^2 = 121$$
第一层反向:$\mathcal{L} \to o$¶
$$\frac{\partial \mathcal{L}}{\partial o} = 2(o - y) = 2 \times (21 - 10) = 22$$
第二层反向:$o \to h_1, h_2, v_1, v_2$¶
因为 $o = v_1 h_1 + v_2 h_2$,所以:
$$\frac{\partial \mathcal{L}}{\partial h_1} = \frac{\partial \mathcal{L}}{\partial o} \cdot \frac{\partial o}{\partial h_1} = 22 \times v_1 = 22 \times 1 = 22 \equiv \delta_1$$
$$\frac{\partial \mathcal{L}}{\partial h_2} = \frac{\partial \mathcal{L}}{\partial o} \cdot \frac{\partial o}{\partial h_2} = 22 \times v_2 = 22 \times 1 = 22 \equiv \delta_2$$
$$\frac{\partial \mathcal{L}}{\partial v_1} = \frac{\partial \mathcal{L}}{\partial o} \cdot \frac{\partial o}{\partial v_1} = 22 \times h_1 = 22 \times 9 = 198$$
$$\frac{\partial \mathcal{L}}{\partial v_2} = \frac{\partial \mathcal{L}}{\partial o} \cdot \frac{\partial o}{\partial v_2} = 22 \times h_2 = 22 \times 12 = 264$$
第三层反向:$h_1, h_2 \to W$¶
关键:为什么每个 $w_{ij}$ 的梯度要同时经过 $h_1$ 和 $h_2$ 两条路径?
以 $w_{11}$ 为例,观察前向传播:
- $w_{11}$ 出现在 $h_1 = x_1 w_{11} + \ldots$ 中,因此 $\partial h_1 / \partial w_{11} = x_1$
- $w_{11}$ 不出现在 $h_2$ 的表达式中,因此 $\partial h_2 / \partial w_{11} = 0$
所以完整链式展开为:
$$\frac{\partial \mathcal{L}}{\partial w_{11}} = \frac{\partial \mathcal{L}}{\partial h_1} \cdot \frac{\partial h_1}{\partial w_{11}} + \frac{\partial \mathcal{L}}{\partial h_2} \cdot \frac{\partial h_2}{\partial w_{11}} = \delta_1 \cdot x_1 + \delta_2 \cdot 0 = 22 \times 1 + 22 \times 0 = 22$$
这正是你问的为什么有一项为0——因为 $w_{11}$ 根本没有参与 $h_2$ 的计算,偏导数天然为零。
同理展开所有 $W$ 的参数:
$$\frac{\partial \mathcal{L}}{\partial w_{21}} = \frac{\partial \mathcal{L}}{\partial h_1} \cdot \frac{\partial h_1}{\partial w_{21}} + \frac{\partial \mathcal{L}}{\partial h_2} \cdot \frac{\partial h_2}{\partial w_{21}} = \delta_1 \cdot x_2 + \delta_2 \cdot 0 = 22 \times 2 + 22 \times 0 = 44$$
$$\frac{\partial \mathcal{L}}{\partial w_{31}} = \frac{\partial \mathcal{L}}{\partial h_1} \cdot \frac{\partial h_1}{\partial w_{31}} + \frac{\partial \mathcal{L}}{\partial h_2} \cdot \frac{\partial h_2}{\partial w_{31}} = \delta_1 \cdot x_3 + \delta_2 \cdot 0 = 22 \times 3 + 22 \times 0 = 66$$
$$\frac{\partial \mathcal{L}}{\partial w_{12}} = \frac{\partial \mathcal{L}}{\partial h_1} \cdot \frac{\partial h_1}{\partial w_{12}} + \frac{\partial \mathcal{L}}{\partial h_2} \cdot \frac{\partial h_2}{\partial w_{12}} = \delta_1 \cdot 0 + \delta_2 \cdot x_1 = 22 \times 0 + 22 \times 1 = 22$$
$$\frac{\partial \mathcal{L}}{\partial w_{22}} = \frac{\partial \mathcal{L}}{\partial h_1} \cdot \frac{\partial h_1}{\partial w_{22}} + \frac{\partial \mathcal{L}}{\partial h_2} \cdot \frac{\partial h_2}{\partial w_{22}} = \delta_1 \cdot 0 + \delta_2 \cdot x_2 = 22 \times 0 + 22 \times 2 = 44$$
$$\frac{\partial \mathcal{L}}{\partial w_{32}} = \frac{\partial \mathcal{L}}{\partial h_1} \cdot \frac{\partial h_1}{\partial w_{32}} + \frac{\partial \mathcal{L}}{\partial h_2} \cdot \frac{\partial h_2}{\partial w_{32}} = \delta_1 \cdot 0 + \delta_2 \cdot x_3 = 22 \times 0 + 22 \times 3 = 66$$
规律已经清晰:第一列参数($w_{i1}$)只参与 $h_1$ 的计算,第二列参数($w_{i2}$)只参与 $h_2$ 的计算,另一项必然为0。
第四层反向:$h_1, h_2 \to \mathbf{x}$(Embedding更新)¶
这里是最关键的部分。 每个 $x_i$ 同时参与了 $h_1$ 和 $h_2$ 的计算,因此两条路径的梯度都不为零,必须全部加起来。
以 $x_1$ 为例:
$$\frac{\partial \mathcal{L}}{\partial x_1} = \frac{\partial \mathcal{L}}{\partial h_1} \cdot \frac{\partial h_1}{\partial x_1} + \frac{\partial \mathcal{L}}{\partial h_2} \cdot \frac{\partial h_2}{\partial x_1} = \delta_1 \cdot w_{11} + \delta_2 \cdot w_{12} = 22 \times 1 + 22 \times 0 = 22$$
以 $x_2$ 为例:
$$\frac{\partial \mathcal{L}}{\partial x_2} = \frac{\partial \mathcal{L}}{\partial h_1} \cdot \frac{\partial h_1}{\partial x_2} + \frac{\partial \mathcal{L}}{\partial h_2} \cdot \frac{\partial h_2}{\partial x_2} = \delta_1 \cdot w_{21} + \delta_2 \cdot w_{22} = 22 \times 0 + 22 \times 1 = 22$$
以 $x_3$ 为例:
$$\frac{\partial \mathcal{L}}{\partial x_3} = \frac{\partial \mathcal{L}}{\partial h_1} \cdot \frac{\partial h_1}{\partial x_3} + \frac{\partial \mathcal{L}}{\partial h_2} \cdot \frac{\partial h_2}{\partial x_3} = \delta_1 \cdot w_{31} + \delta_2 \cdot w_{32} = 22 \times 1 + 22 \times 0 = 22$$
此处两项均不为零的原因:$x_1$ 同时出现在 $h_1$(通过 $w_{11}$)和 $h_2$(通过 $w_{12}$)的表达式中,两条路径都有梯度流过来,必须全部累加。本例中 $w_{12}=0$ 导致第二项数值为零,但这是初始参数的巧合,在式子结构上两项都不应省略。
统一公式:一眼看穿所有梯度的来源¶
$$\boxed{\frac{\partial \mathcal{L}}{\partial w_{ij}} = \delta_j \cdot x_i \qquad \text{(}w_{ij}\text{ 只连接 }x_i\text{ 和 }h_j\text{,另一条路径偏导为0)}}$$
$$\boxed{\frac{\partial \mathcal{L}}{\partial x_i} = \sum_{j} \delta_j \cdot w_{ij} \qquad \text{(}x_i\text{ 同时连接所有 }h_j\text{,每条路径都有贡献)}}$$
这两个公式的差异正好解释了你的疑惑:$w_{ij}$ 只活在一条路径上,所以只有一项非零;$x_i$ 活在所有路径上,所以每条路径都要累加。