nn.Linear
是 PyTorch 中的一個類,用來定義一個線性變換(線性層)的操作。
具體來說,nn.Linear
用于定義一個線性映射,將輸入張量的每個元素與權重矩陣相乘,并加上偏置向量。其功能可以總結如下:
線性變換:將輸入張量與權重矩陣相乘,得到輸出張量。輸入張量的形狀為 (batch_size, input_size)
,權重矩陣的形狀為 (output_size, input_size)
。輸出張量的形狀為 (batch_size, output_size)
。
加偏置:將輸出張量加上一個偏置向量,該偏置向量的形狀為 (output_size,)
。偏置向量會被廣播到每個樣本的輸出上。
自動創建參數:nn.Linear
創建線性層時會自動創建權重矩陣和偏置向量,并將它們保存在模型的參數列表中。
自動梯度計算:通過 PyTorch 的自動求導機制,nn.Linear
可以自動計算權重矩陣和偏置向量的梯度,并進行優化。
nn.Linear
通常在神經網絡模型中被用作全連接層(全連接神經網絡),用來將輸入特征映射到輸出特征。