记录 PyTorch 中 masked_fill 的一个坑

Author Avatar
patrickcty 11月 20, 2021

这个函数一般被用于将 pad 的部分填充 0,即 value.masked_fill(mask[..., None], float(0))。从文档中来看,mask 的 dtype 应该为 bool(int 也行,只能为 0 或 1),为 True 的部分就会就会被填充值替换掉。但是如果 mask dtype 为 float 则不会被 masked 掉。

在使用的时候,记得将 pad 部分的 mask 设为 True,其他设为 False,不然被 masked 掉的就是有用的部分了……