PyTorch中的LSTM(Long Short-Term Memory)和GRU(Gated Recurrent Unit)是通過torch.nn模塊實現的。在PyTorch中,可以使用torch.nn.LSTM和torch.nn.GRU類來創建LSTM和GRU模型。
下面是一個簡單的例子,演示如何使用PyTorch中的LSTM和GRU:
import torch
import torch.nn as nn
# 定義輸入數據
input_size = 10
hidden_size = 20
seq_len = 5
batch_size = 3
input_data = torch.randn(seq_len, batch_size, input_size)
# 使用LSTM
lstm = nn.LSTM(input_size, hidden_size)
output, (h_n, c_n) = lstm(input_data)
print("LSTM output shape:", output.shape)
print("LSTM hidden state shape:", h_n.shape)
print("LSTM cell state shape:", c_n.shape)
# 使用GRU
gru = nn.GRU(input_size, hidden_size)
output, h_n = gru(input_data)
print("GRU output shape:", output.shape)
print("GRU hidden state shape:", h_n.shape)
在上面的例子中,我們首先定義了輸入數據的維度,并使用torch.nn.LSTM和torch.nn.GRU類分別創建了一個LSTM和一個GRU模型。然后,我們將輸入數據傳遞給這兩個模型,并輸出它們的輸出和隱藏狀態的形狀。
值得注意的是,LSTM和GRU模型的輸出形狀可能會有所不同,具體取決于輸入數據的維度和模型的參數設置。通常,輸出形狀將包含序列長度、批次大小和隱藏單元數量等信息。