im2col
原文链接:https://mp.weixin.qq.com/s/GPDYKQlIOq6Su0Ta9ipzig
一句话:im2col是将一个[C,H,W]矩阵变成一个[H,W]矩阵的一个方法,其原理是利用了行列式进行等价转换。
为什么要做im2col? 减少调用gemm的次数。
重要:本次的代码只是为了方便理解im2col,不是用来做加速,所以代码写的很简单且没有做任何优化。
一、卷积的可视化
例子是一个[1, 6, 6]的输入,卷积核是[1, 3, 3],stride等于1,padding等于0。那么卷积的过程可视化如下图,一共需要做16次卷积计算,每次卷积计算有9次乘法和8次加法。
输出的公式如下,即Output_height = (6 - 3 + 2*0)/1 + 1 = 4 = Output_width
二、行列式
乘号左边的横条,跟乘号右边的竖条进行点乘(即每个元素对应相乘后再全部加起来)。
关于行列式,大家都清楚的一点,一根横条的元素个数要等于一根竖条的元素个数(这样才可以让做点乘的时候能一一对应起来,不会让小方块落单)。竖条有多少条,出来的结果就有多少个小方块(在横条的个数为1的情况下)。
出来的结果(等号的右边)的行数等于乘号左边的横条的行数,出来的结果(等号的右边)的列数等于乘号右边的横条的列数,公式表示就是[row, x] * [x, col] = [row, col]。举个例子[3, 8] * [8, 4] = [3, 4]
在这里插入图片描述
三、[1, H, W]的im2col
展开后,就可以直接做两个数组的矩阵乘积了
中间俩个for循环是来填满展开的数组/矩阵的每一列,即卷积核对应的元素,其个数等于卷积核的元素个数,举个例子,[1, 3, 3]的卷积核,那么该卷积核的元素个数等于9;最外层的两个for循环是用来填满展开的数组/矩阵的每一行,即列数,也就是卷积核在输入滑动了多少次
pytorch来做验证
四、[C, H, W]的im2col
在这里插入图片描述
前面一堆图,是我故意不写文字,希望大家能够通过图能够看明白。前面卷积核只有一行的情况,跟[1, H, W]的情况基本一摸一样,只是这一行的元素个数等于卷积核的元素个数即可5x3x3=45,展开的特征图的每一个竖条也是45。
当卷积核函数等于3的时候,就是对应的只要增加卷积核的横条数即可,展开的特征图没有改变。这里希望大家用行列式的计算和普通卷积的过程联想起来,你会发现是一摸一样的计算过程。
代码其实跟[1,H, W]只有一初不同,就是从特征图里面取数据的时候多了个维度,需要取对应的通道。这里为什么要取对应的通道数呢?原因是行列式的计算中,横条和竖条是元素一一对应做乘法。
pytorch代码的验证
五、[B, C, H, W]的im2col
问题:如何bs=9的情况呢,要怎么做im2col+gemm呢?方法 1:把filter摊平的shape变成[3,5339],把input摊平的shape变成[5339,16] – output的shape就为[3,16]了 - ❌
方法 2:把filter摊平的shape变成[39,533],把input摊平的shape变成[533,16],output的shape就为[39,16]了 – 隐患:如何filter数量是51233这种数量,那么非常占用显存/内存
方法 3:im2col+gemm外面加一层关于bs的for循环 – 隐患:加一层for循环嵌套非常耗时
经过简单分析,发现采取for循环的方式来进行im2col是相对合适的情况。我向msnh2012的作者穆士凝魂请教,得到的答案是,是用加一层for循环的方式居多,而且由于可以并发,多一层循环的开销比想象中小一些。如果是推理框架的话,有部分情况bs是等于1的,所以可以规避这个问题。
Last updated