如何从一串数组中找到最大的前k个数? 建一个大根堆,第一次建堆时间O(N),然后弹出k次根,每次维护堆时间O(logN),消耗时间O(N+klogN)。 或者建一个大小为k的小根堆,每次读取N个数中的一个和根比大小,大了就放进来替换这个根,更新堆,最后留下的k个就是最大的,这个消耗时间O(Nlogk)。
但是上面的都是在CPU上做,如果我们有的是一个GPU,堆真的高效吗?
非也,以上的操作全都是串行的,而且对内存的读取不连续,这很糟糕。 我们看看torch是怎么做的:
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/TensorTopK.cu
找到第K大的值,然后
- 找到所有 严格大于 topKValue 的元素。
- 找到所有 等于 topKValue 的元素,并补足到K个。
那么怎么找到第K大的值呢?
如果在CPU上,我们可以做一次伪-快速排序,排的时候统计一下左边有多少个,右边有多少个,第K个只会在左边或者右边中的一块,递归的时候只看第K个所在的那一边即可,这玩意的时间复杂度的期望是O(N)
那么在GPU上怎么做?伪-基数排序。
https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/cuda/SortingRadixSelect.cuh
我们面对的通常是浮点数,浮点数怎么才能像整数一样按基数操作呢? 如果x的最高位(符号位)是1(即负数),mask是全1;如果是0(正数),mask是0x80000000。 x=x^mask,做异或。正数的符号位翻转为1,比所有负数大;负数叠一个全1,相当于大小关系反过来。 然后各个wrap分别取一部分的数组,在基数上做统计,最后和在一起。 当然我们不需要真的做排序,我们只要找到第K大,所以和前面伪-快速排序差不多,我们的任务是逐步缩小查找的范围,找到第K大在哪里。 具体来说,我们希望得到一个mask,这个mask最终将得到第K个数的值,按当前的基数统计然后得到第K个是哪个前缀,更新mask,统计下一级基数上的数目(和这个mask的前缀不匹配的就不参与计数),然后逐渐得到这个mask。实际上操作的次数是 $CN$ ,其中C是数字的位数。
前面讨论了计算topk的方法,显然排序这事没法求导,最后做出的选择是阶跃的。 但是如果我就是想对topk或者argmax(相当于topk的k=1)操作求导怎么办? (比如说需要训练一个网络来选择跑哪个网络,这时候训练做选择的网络就会遇到这个问题)
MoE实际上就是网络来选网络(gate选哪些expert要激活),但是MoE可以不需要对topk求导。 因为一般的做法里,gate输出各个expert的weight除了topk做选择之外,还会作为权重乘到expert的输出里加到该层最后的输出,那么靠这个权重把梯度传到gate网络里即可。
但是有时候我们不希望把gate输出乘到网络里,而是只希望它在做选择。 (比如说我想给原网络插一个跳层的模块,用这个gate网络来决定是否跳层)
比较老的工作有SkipNet(ECCV'18),它们在ResNet上插门控,门控用于判断能否跳掉这层的计算。他们的门控分两步训练,第一步将gate的离散输出松弛到连续输出,forward的时候门控值>0.5则计算该层,backward的时候用连续的gate值;第二步对gate做强化学习,略去不表。
在k=1的时候,其实有一个常用的trick是gumbel_softmax(ICLR'17: Categorical Variational Autoencoders using Gumbel-Softmax)。 用 $ sample = softmax((logits+G(0,1))/temperature) $ ,其中G是指在Gumbel分布上采样,然后做将做选择变成按输出的概率做采样。 逐步降低temperature使sample的选择趋近于100%选argmax,相当于训练的时候做退火。 比如说SkipGPT(ICML'25)就用到了Gumbel这个trick,来做LLaMA上的跳层。
那么在k>1的时候怎么办呢?Gumbel的问题是它在temperature->0的时候将只选一个值,所以直接用Gumbel不太行。 DSelect-k(NIPS'21)继续讨论了MoE模型上训练gate的问题,给出了一个新的可微分的gate,用二进制编码表示expert,训练logN个选择器来从N个里选1个,要选k个就训练k个这样的选择器。选择的逻辑就是用logN输出的编码和expert自己的编码做匹配。然后这个匹配操作还得靠一个平滑函数来变得可微。另外,ICML'19还有一个Gumbel-Top-k Trick,不过这篇我还没看。
在模型上插门控来跳掉一些计算(或者降低计算精度,或者卸载到其他计算设备)的工作非常多,这类工作都面临相似的问题,就是门控要靠什么规则来跳计算,这个规则得非常轻量的同时非常高效,才能让这类工作的效果比较好,而这是一个重要的难点。
同学言:“在GPU的simt架构上实现RadixSelect + Filter和Warp-Select + Block-Select的硬件协调加速有看过,这个是对topk算法本身加速实现的一些讨论。”