在PyTorch中,nn.Linear
是一個用于定義線性變換的類。可以使用它來定義一個全連接層(也稱為線性層)。
以下是如何使用nn.Linear
的例子:
首先,導入需要的模塊:
import torch
import torch.nn as nn
接下來,定義一個包含輸入和輸出大小的線性層:
input_size = 10
output_size = 5
linear_layer = nn.Linear(input_size, output_size)
這將創建一個線性層,將輸入維度為10的特征映射到輸出維度為5的特征。
然后,可以將數據傳遞給線性層進行轉換:
input_data = torch.randn(1, input_size)
output_data = linear_layer(input_data)
這將根據線性層的權重和偏差將輸入數據進行線性變換,并返回輸出數據。
最后,可以查看線性層的權重和偏差:
print(linear_layer.weight)
print(linear_layer.bias)
這將打印出線性層的權重矩陣和偏差向量。
注意:nn.Linear
類還可以接受一些其他參數,例如是否添加偏差(默認為True)、權重初始化方法等。你可以查閱PyTorch的官方文檔以獲取更多詳細信息。