目录
- 一. 用法
- 二. 参数
- 三. 实例
一. 用法
Flatten层主要是用来将输入“压平”,即把多维的输入一维化,用在卷积层到全连接层的过渡。其不会影响batch的大小,可以理解为把高纬度的数组按照x轴或者y轴进行拉伸,变成一维的数组。
二. 参数
1.start_dim(可选参数):指定从哪个维度开始展平张量。默认情况下,start_dim
被设置为0,表示从第一个维度(通常是批大小)开始展平。如果设置为其他整数值,则会从指定的维度开始展平。
2.end_dim(可选参数):指定在哪个维度结束展平张量。默认情况下,end_dim
被设置为-1,表示展平直到最后一个维度。如果设置为其他整数值,则会在指定的维度结束展平。
三. 实例
(1). 首先随机定义一个满足正态分布的(2,3,4)的数据x
import torch x = torch.randn(2,3,4) print(x) x = x.flatten(0) print(x) ------------------------------------ tensor([[[ 0.1281, 1.6878, 0.2301, -0.0721], [ 1.2374, -0.6929, 1.1186, 0.4372], [ 0.5122, 1.4653, -0.1673, 0.7258]], [[ 0.2772, -1.9994, -1.2284, 0.2764], [-0.0451, -0.9195, 0.5749, 0.1942], [ 0.8539, -0.0434, -0.7313, 0.0234]]]) tensor([ 0.1281, 1.6878, 0.2301, -0.0721, 1.2374, -0.6929, 1.1186, 0.4372, 0.5122, 1.4653, -0.1673, 0.7258, 0.2772, -1.9994, -1.2284, 0.2764, -0.0451, -0.9195, 0.5749, 0.1942, 0.8539, -0.0434, -0.7313, 0.0234]) import torch x = torch.randn(2,3,4) print(x) x = x.flatten(0) print(x) ------------------------------------ tensor([[[ 0.1281, 1.6878, 0.2301, -0.0721], [ 1.2374, -0.6929, 1.1186, 0.4372], [ 0.5122, 1.4653, -0.1673, 0.7258]], [[ 0.2772, -1.9994, -1.2284, 0.2764], [-0.0451, -0.9195, 0.5749, 0.1942], [ 0.8539, -0.0434, -0.7313, 0.0234]]]) tensor([ 0.1281, 1.6878, 0.2301, -0.0721, 1.2374, -0.6929, 1.1186, 0.4372, 0.5122, 1.4653, -0.1673, 0.7258, 0.2772, -1.9994, -1.2284, 0.2764, -0.0451, -0.9195, 0.5749, 0.1942, 0.8539, -0.0434, -0.7313, 0.0234])
此时x的维度是2×3×4=24,x = flatten(0) 和 x = flatten()的结果相同。
(2).
import torch x = torch.randn(2,3,4) print(x) x = x.flatten(1) print(x) =========================================== tensor([[[-0.7137, -0.0859, -1.5284, 0.7284], [ 0.8425, 0.3606, 1.7639, 0.1848], [ 0.4040, -1.6575, 1.9134, -1.0787]], [[ 0.6981, 1.3494, -0.5817, -1.1824], [-0.4972, 0.4179, 2.1742, -0.2462], [ 0.2429, -1.9315, -0.3497, 0.7190]]]) tensor([[-0.7137, -0.0859, -1.5284, 0.7284, 0.8425, 0.3606, 1.7639, 0.1848, 0.4040, -1.6575, 1.9134, -1.0787], [ 0.6981, 1.3494, -0.5817, -1.1824, -0.4972, 0.4179, 2.1742, -0.2462, 0.2429, -1.9315, -0.3497, 0.7190]])
此时x是从1维度开始展开,最后的x维度为(2,3×4),也就是(2,12)
注意:start_dim
和end_dim
参数的取值范围应该在 -x.dim() <= start_dim <= end_dim < x.dim()
之间。
到此这篇关于PyTorch中flatten() 函数的用法的文章就介绍到这了,更多相关PyTorch flatten() 函数内容请搜索本网站以前的文章或继续浏览下面的相关文章希望大家以后多多支持本网站!
您可能感兴趣的文章:
- pytorch中nn.Flatten()函数详解及示例
- pytorch中的reshape()、view()、nn.flatten()和flatten()使用
- Pytorch中torch.flatten()和torch.nn.Flatten()实例详解
- Pytorch阅读文档中的flatten函数