在PyTorch中,flatten
函數用于將輸入張量展平為一維張量。它的用法如下:
torch.flatten(input, start_dim=0, end_dim=-1)
參數說明:
input
:輸入的張量。start_dim
:開始展平的維度,默認為0。end_dim
:結束展平的維度,默認為-1,表示展平到最后一維。flatten
函數將沿著指定的維度范圍將輸入張量展平為一維張量。展平后的張量將包含原始張量中的所有元素,并將其重新排列為一維。
示例:
import torch
x = torch.randn(3, 4, 5)
flattened = torch.flatten(x)
print(flattened.shape) # 輸出: torch.Size([60])
flattened_dim1 = torch.flatten(x, start_dim=1)
print(flattened_dim1.shape) # 輸出: torch.Size([3, 20])
flattened_dim1_dim2 = torch.flatten(x, start_dim=1, end_dim=2)
print(flattened_dim1_dim2.shape) # 輸出: torch.Size([3, 20, 5])
在上面的示例中,flatten
函數首先將形狀為(3, 4, 5)的張量x
展平為形狀為(60,)的一維張量。然后,通過指定start_dim=1
,將張量x
的第二個維度展平,得到形狀為(3, 20)的張量。最后,通過指定start_dim=1, end_dim=2
,將張量x
的第二個和第三個維度展平,得到形狀為(3, 20, 5)的張量。