在PyTorch中,可以通過定義一個函數來初始化模型的權重。以下是一個示例代碼:
import torch
import torch.nn as nn
def init_weights(m):
if type(m) == nn.Linear or type(m) == nn.Conv2d:
nn.init.xavier_uniform_(m.weight)
nn.init.zeros_(m.bias)
# 定義模型
model = nn.Sequential(
nn.Conv2d(3, 64, kernel_size=3),
nn.ReLU(),
nn.Linear(64*28*28, 10)
)
# 初始化模型權重
model.apply(init_weights)
在上面的代碼中,定義了一個init_weights
函數,該函數根據模型的類型對權重進行初始化。然后通過調用model.apply(init_weights)
來初始化模型的權重。這樣就可以保證模型的權重被正確地初始化。