跨境派

跨境派

跨境派,专注跨境行业新闻资讯、跨境电商知识分享!

当前位置:首页 > 综合服务 > 社群媒体 > nn.Flatten()函数详解及示例

nn.Flatten()函数详解及示例

时间:2024-04-16 18:20:26 来源:网络cs 作者:往北 栏目:社群媒体 阅读:

标签: 示例  函数 

torch.nn.Flatten(start_dim=1end_dim=- 1)

作用:将连续的维度范围展平为张量。 经常在nn.Sequential()中出现,一般写在某个神经网络模型之后,用于对神经网络模型的输出进行处理,得到tensor类型的数据。

 有俩个参数,start_dim和end_dim,分别表示开始的维度和终止的维度,默认值分别是1和-1,其中1表示第一维度,-1表示最后的维度。结合起来看意思就是从第一维度到最后一个维度全部给展平为张量。(注意:数据的维度是从0开始的,也就是存在第0维度,第一维度并不是真正意义上的第一个)

同理,如果我这么写:

self.flat = nn.Flatten(start_dim=2, end_dim=3)

那么意思就是从第二维度开始,到第三维度全部给展平,也就是将2、3两个维度展平。

官网给出的示例:

input = torch.randn(32, 1, 5, 5)# With default parametersm = nn.Flatten()output = m(input)output.size()#torch.Size([32, 25])# With non-default parametersm = nn.Flatten(0, 2)output = m(input)output.size()#torch.Size([160, 5])

#开头的代码是注释

整段代码的意思是:给定一个维度为(32,1,5,5)的随机数据。

1.先使用一次nn.Flatten(),使用默认参数:

m = nn.Flatten()

也就是说从第一维度展平到最后一个维度,数据的维度是从0开始的,第一维度实际上是数据的第二个位置代表的维度,也就是样例中的1。

因此进行展平后的结果也就是[32,1×5×5]➡[32,25]

2.接着再使用一次指定参数的nn.Flatten(),即

m = nn.Flatten(0, 2)

也就是说从第0维度展平到第2维度,0~2,对应的也就是前三个维度。

因此结果就是[32×1×5,5]➡[160,5]

本文链接:https://www.kjpai.cn/news/2024-04-16/159586.html,文章来源:网络cs,作者:往北,版权归作者所有,如需转载请注明来源和作者,否则将追究法律责任!

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌抄袭侵权/违法违规的内容,一经查实,本站将立刻删除。

文章评论