keras 中 dense 层输入秩大于二

Author Avatar
patrickcty 11月 03, 2020

通常情况下输入到 Dense 层(又或者叫 FC 层)的张量是一个 (batch_size, length) 的秩为 2 的形式。比如 AlexNet 和 VGG,输出的特征图都会经过 flatten 操作降维到 (None, 1024),然后才输入到 Dense 层中。

但是今天我在看 MaskX R-CNN 的时候发现输入并不是一个二维矩阵,而是一个三维的张量。Keras 文档中在处理 Dense 秩大于二的时候会将其通过一个矩阵乘法来改变输出最后一维的长度(秩不变)。这样处理不是真正意义上所有神经元全连接,参数上也比 flatten 再还原要小很多。

x = layers.Input((81, 1024))  # (None, 81, 1024)
y = layers.Dense(256)  # y shape: (None, 81, 256)
# 参数 W shape: (1024, 256) 
# 参数 b shape: (256)