在Torch中定義一個神經網絡模型通常需要使用nn.Module類。下面是一個示例代碼,展示了如何定義一個簡單的全連接神經網絡模型:
import torch
import torch.nn as nn
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
model = SimpleNN()
在上面的代碼中,我們定義了一個名為SimpleNN的神經網絡模型,它包含兩個全連接層和一個ReLU激活函數。在__init__
方法中,我們定義了模型的各個層,然后在forward
方法中定義了數據在模型中的流動路徑。
需要注意的是,在定義神經網絡模型時,通常需要繼承nn.Module類,并實現__init__
和forward
方法。__init__
方法用于初始化模型的結構,forward
方法用于定義數據在模型中的傳播路徑。