WA_automat

Gumbel-max trick

N 人看过

12月因为期末没有更博客,寒假开始补上一些内容。

最近在训模型时遇到需要用01向量来作为门控的时候,发现01向量的梯度会消失,需要使用特殊的技巧来处理,师兄让俺来学一下Gumbel-max,所以有了这篇博客。

这里主要讲一下01向量实现可微分的代码实现,原理部分讲解可能不太清楚。

基本原理

从离散分布采样(argmax)并参与训练,反向传播不能对argmax进行梯度计算,因此需要另辟蹊径;

对于一个Categorical distribution:

其中C表示C个类别,表示第c类的概率

常见的采样方式是:(其中U是标准正态分布)

但这样的方法也不能被微分,于是我们引入了Gumbel-Softmax的采样方法:

其中分布, 再用代替就可以微分了。

代码实现:

def gumbel_max(logits, tau):
    gumbel_noise = -torch.log(-torch.log(torch.rand_like(logits)))
    samples = (logits + gumbel_noise) / tau
    out = F.softmax(samples)
    return out

本作品采用 知识共享署名-非商业性使用-禁止演绎 4.0 国际许可协议 (CC BY-NC-ND 4.0) 进行许可。