在Torch中,可以通過繼承nn.Module
類來定義一個神經網絡結構。以下是一個簡單的示例:
import torch
import torch.nn as nn
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(784, 128)
self.relu = nn.ReLU()
self.fc2 = nn.Linear(128, 10)
def forward(self, x):
x = self.fc1(x)
x = self.relu(x)
x = self.fc2(x)
return x
model = SimpleNN()
在這個示例中,SimpleNN
類繼承自nn.Module
,并在__init__
方法中定義了神經網絡的結構,包括兩個全連接層和一個ReLU激活函數。在forward
方法中定義了數據在神經網絡中的流動,也就是前向傳播過程。最后通過實例化SimpleNN
類來創建一個神經網絡模型。