在Torch中定義一個損失函數,一般是通過繼承nn.Module類來實現的。以下是一個示例:
import torch
import torch.nn as nn
class CustomLoss(nn.Module):
def __init__(self):
super(CustomLoss, self).__init__()
def forward(self, output, target):
loss = torch.mean((output - target) ** 2) # 以均方誤差為例
return loss
在上面的示例中,定義了一個名為CustomLoss的自定義損失函數類,其forward方法接受模型的輸出output和目標值target作為輸入,并計算損失值。這里使用的是均方誤差作為損失函數的計算方式,可以根據需要自定義不同的損失函數。