5. 基础知识总结
1. 乘法全面整理¶
总览表¶
| 乘法类型 | PyTorch 写法 | 数学符号 | 输入形状 → 输出形状 |
|---|---|---|---|
| 按元素乘(Hadamard) | a * b 或 torch.mul |
$A \odot B$ | (m,n) × (m,n) → (m,n) |
| 矩阵乘法 | a @ b 或 torch.matmul |
$AB$ | (m,k) × (k,n) → (m,n) |
| 内积(点积) | torch.dot |
$a \cdot b$ | (n,) × (n,) → 标量 |
| 外积 | torch.outer |
$ab^T$ | (m,) × (n,) → (m,n) |
| 批量矩阵乘 | torch.bmm |
— | (B,m,k) × (B,k,n) → (B,m,n) |
| 标量乘 | a * scalar |
$\lambda A$ | (m,n) × 标量 → (m,n) |
逐一详解¶
① 按元素乘(Hadamard Product)*¶
A = torch.tensor([[1, 2], [3, 4]])
B = torch.tensor([[5, 6], [7, 8]])
A * B # [[5,12],[21,32]]
操作:对应位置直接相乘,形状不变。支持 Broadcasting。
常见场景:
- MMoE 中门控加权(本题正是此操作)
- Attention mask 的遮蔽
- 特征的逐元素缩放(如 SE 网络的 channel-wise 加权)
② 矩阵乘法 @ / torch.matmul¶
A = torch.randn(3, 4)
B = torch.randn(4, 5)
A @ B # (3, 5)
操作:行 × 列求和,即 $C_{ij} = \sum_k A_{ik} B_{kj}$
常见场景:
- 全连接层的线性变换 $y = Wx$
- Attention 中 $QK^T$、Attention Score × V
- 任意线性投影
③ 内积(点积)torch.dot¶
a = torch.tensor([1.0, 2.0, 3.0])
b = torch.tensor([4.0, 5.0, 6.0])
torch.dot(a, b) # 1*4 + 2*5 + 3*6 = 32(标量)
操作:对应元素相乘后求和,结果为标量,几何意义是 $|a||b|\cos\theta$。
常见场景:
- 计算两个向量的相似度(协同过滤、双塔召回)
- 余弦相似度的分子部分
- 推荐系统中 user embedding · item embedding
注意:
torch.dot只接受 1D 张量,批量场景用torch.matmul或torch.einsum。
④ 外积 torch.outer¶
a = torch.tensor([1.0, 2.0, 3.0]) # (3,)
b = torch.tensor([4.0, 5.0]) # (2,)
torch.outer(a, b) # (3, 2)
# [[4, 5],
# [8, 10],
# [12,15]]
操作:$C_{ij} = a_i \cdot b_j$,结果为矩阵,是内积的"逆操作"。
常见场景:
- FM(Factorization Machine)中二阶特征交叉
- 构造协方差矩阵
- Cross Network(DCN)中的特征交叉层 $x_0 x_l^T w$
⑤ 批量矩阵乘 torch.bmm¶
A = torch.randn(10, 3, 4) # batch=10
B = torch.randn(10, 4, 5)
torch.bmm(A, B) # (10, 3, 5)
操作:对 batch 中每个样本独立做矩阵乘法。
常见场景:
- Transformer 中批量计算 Attention Score
- 批量处理序列中每个时间步的变换
⑥ Broadcasting 机制¶
# expert_outputs_stacked: (10, 3, 72)
# current_gate_weights: (10, 3)
# .unsqueeze(-1) 后: (10, 3, 1)
weighted_experts = expert_outputs_stacked * current_gate_weights.unsqueeze(-1)
# (10, 3, 72) * (10, 3, 1) → Broadcasting → (10, 3, 72)
Broadcasting 规则:维度为 1 的轴自动扩展匹配对方:
(10, 3, 72)
(10, 3, 1) ← 最后一维从 1 广播到 72
────────────
(10, 3, 72) ← 每个 expert 的 72 维向量都乘上同一个标量权重
选择乘法的决策树¶
需要结果是标量?
├─ 是 → 内积 torch.dot(1D)或 matmul + sum
└─ 否
├─ 两个向量生成矩阵? → 外积 torch.outer
├─ 形状相同,逐位操作? → 按元素乘 *
├─ 需要线性投影/变换? → 矩阵乘 @
└─ 有 Batch 维度的矩阵乘? → bmm 或 matmul(自动支持 batch)