PyTorch 中mm和bmm函数的使用详解
torch.mm 是 PyTorch 中用于 二维矩阵乘法(matrix-matrix multiplication) 的函数,等价于数学中的 A × B 矩阵乘积。
一、函数定义
torch.mm(input, mat2) → Tensor
执行的是两个 2D Tensor(矩阵)的标准矩阵乘法。
- input: 第一个二维张量,形状为 (n × m)
- mat2: 第二个二维张量,形状为 (m × p)
- 返回:形状为 (n × p) 的张量
二、使用条件和注意事项
条件 说明 仅支持 2D 张量 一维或三维以上使用 torch.matmul 或 @ 操作符 维度要匹配 即 input.shape[1] == mat2.shape[0] 不支持广播 两个矩阵维度不匹配会直接报错 结果是普通矩阵乘积 不是逐元素乘法(Hadamard),即不是 * 或 torch.mul() 三、示例代码
示例 1:基本矩阵乘法
import torch A = torch.tensor([[1., 2.], [3., 4.]]) # 2x2 B = torch.tensor([[5., 6.], [7., 8.]]) # 2x2 C = torch.mm(A, B) print(C)
输出:
tensor([[19., 22.], [43., 50.]])
计算步骤:
C[0][0] = 1*5 + 2*7 = 19 C[0][1] = 1*6 + 2*8 = 22 ...
示例 2:不匹配维度导致报错
A = torch.rand(2, 3) B = torch.rand(4, 2) C = torch.mm(A, B) # ❌ 会报错
报错:
RuntimeError: mat1 and mat2 shapes cannot be multiplied (2x3 and 4x2)
示例 3:推荐写法(推荐使用 @ 或 matmul)
A = torch.rand(3, 4) B = torch.rand(4, 5) C1 = torch.mm(A, B) C2 = A @ B # 推荐用法 C3 = torch.matmul(A, B) # 推荐用法
四、与其他乘法函数的比较
函数名 支持维度 运算类型 支持广播 torch.mm 仅限二维 矩阵乘法 ❌ 不支持 torch.matmul 1D, 2D, ND 自动判断点乘 / 矩阵乘 ✅ 支持 torch.bmm 批量二维乘法 3D Tensor batch × batch ❌ 不支持 torch.mul 任意维度 元素乘(Hadamard) ✅ 支持 * 运算符 任意维度 元素乘 ✅ 支持 @ 运算符 ND(推荐用) 矩阵乘法(和 matmul 一样) ✅ 五、典型应用场景
- 神经网络权重乘法:output = torch.mm(W, x)
- 点云 / 图像变换:x' = torch.mm(R, x) + t
- 多层感知机中的矩阵计算
- 注意力机制中 QK^T 乘积
六、总结:什么时候用 mm?
使用场景 用什么 仅二维矩阵乘法 torch.mm 高维或支持广播乘法 torch.matmul / @ 批量矩阵乘法 (如 batch_size×3×3) torch.bmm 元素乘 torch.mul or * 在 PyTorch 中,torch.bmm 是 批量矩阵乘法(batch matrix multiplication) 的操作,专用于处理三维张量(batch of matrices)。它的主要作用是对一组矩阵成对进行乘法,效率远高于手动循环计算。
一、torch.bmm 语法
torch.bmm(input, mat2, *, out=None) → Tensor
- input: Tensor,形状为 (B, N, M)
- mat2: Tensor,形状为 (B, M, P)
- 返回结果形状为 (B, N, P)
这表示对 B 对 N×M 和 M×P 的矩阵进行成对相乘。
二、示例演示
示例 1:基础用法
import torch # 定义两个 batch 矩阵 A = torch.randn(4, 2, 3) # shape: (B=4, N=2, M=3) B = torch.randn(4, 3, 5) # shape: (B=4, M=3, P=5) # 批量矩阵乘法 C = torch.bmm(A, B) # shape: (4, 2, 5) print(C.shape) # 输出: torch.Size([4, 2, 5])
示例 2:手动循环 vs bmm 效率对比
# 慢速手动方式 C_manual = torch.stack([A[i] @ B[i] for i in range(A.size(0))]) # 等效于 bmm C_bmm = torch.bmm(A, B) print(torch.allclose(C_manual, C_bmm)) # True
三、注意事项
1. 维度必须是三维张量
- 否则会报错:
RuntimeError: batch1 must be a 3D tensor
你可以通过 .unsqueeze() 手动调整维度:
a = torch.randn(2, 3) b = torch.randn(3, 4) # 升维 a_batch = a.unsqueeze(0) # (1, 2, 3) b_batch = b.unsqueeze(0) # (1, 3, 4) c = torch.bmm(a_batch, b_batch) # (1, 2, 4)
2. 维度必须满足矩阵乘法规则
- (B, N, M) × (B, M, P) → (B, N, P)
- 若 M 不一致会报错:
RuntimeError: Expected size for the second dimension of batch2 tensor to match the first dimension of batch1 tensor
3. bmm 不支持广播(broadcasting)
- 必须显式提供相同的 batch size。
- 如果只有一个矩阵固定,可以使用 .expand():
A = torch.randn(1, 2, 3) # 单个矩阵 B = torch.randn(4, 3, 5) # 4 个矩阵 # 扩展 A 以进行 batch 乘法 A_expand = A.expand(4, -1, -1) C = torch.bmm(A_expand, B) # (4, 2, 5)
四、在实际应用中的例子
在点云变换中:批量乘旋转矩阵
# 假设有 B 个旋转矩阵和点坐标 R = torch.randn(B, 3, 3) # 旋转矩阵 points = torch.randn(B, 3, N) # 点云 # 先转置点坐标为 (B, N, 3) points_T = points.transpose(1, 2) # (B, N, 3) # 用 bmm 做点变换:每组点乘旋转 transformed = torch.bmm(points_T, R.transpose(1, 2)) # (B, N, 3)
五、总结
特性 torch.bmm 操作对象 三维张量(batch of matrices) 核心规则 (B, N, M) x (B, M, P) = (B, N, P) 是否支持广播 ❌ 不支持,需要手动 .expand() 与 matmul 区别 matmul 支持更多广播,bmm 更高效用于纯批量矩阵乘法 应用场景 批量线性变换、点云配准、神经网络前向传播等
- 否则会报错:
免责声明:我们致力于保护作者版权,注重分享,被刊用文章因无法核实真实出处,未能及时与作者取得联系,或有版权异议的,请联系管理员,我们会立即处理! 部分文章是来自自研大数据AI进行生成,内容摘自(百度百科,百度知道,头条百科,中国民法典,刑法,牛津词典,新华词典,汉语词典,国家院校,科普平台)等数据,内容仅供学习参考,不准确地方联系删除处理! 图片声明:本站部分配图来自人工智能系统AI生成,觅知网授权图片,PxHere摄影无版权图库和百度,360,搜狗等多加搜索引擎自动关键词搜索配图,如有侵权的图片,请第一时间联系我们。