循环神经网络

循环神经网络

代码功能概述

1
torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))

这段代码合并了以下两部分的计算:

逐步解析

(1) 拼接输入和隐状态

1
torch.cat((X, H), 1)
  • torch.cat:将张量在指定维度上拼接。

(2) 拼接权重矩阵

1
torch.cat((W_xh, W_hh), 0)

(3) 矩阵乘法

1
torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))
  • 拼接后的矩阵:


对比:原始计算 vs 合并计算

原始计算

1
torch.matmul(X, W_xh) + torch.matmul(H, W_hh)

合并计算

1
torch.matmul(torch.cat((X, H), 1), torch.cat((W_xh, W_hh), 0))

结果维度

  • 批量大小 b=3
  • 隐状态维度 d_h=4

优点

  1. 计算效率
    • 合并计算只需要一次矩阵乘法,节省了计算成本。
  2. 代码简洁性
    • 用一个矩阵乘法替代了两次独立的操作,简化了代码。

循环神经网络
http://example.com/2024/11/28/20241128_循环神经网络/
作者
XuanYa
发布于
2024年11月28日
许可协议