keras中BatchNormalization在迁移学习中的坑

Author Avatar
patrickcty 5月 13, 2020

这个问题已经在 TF 2.0 中修复了,见文档

However, in the case of the BatchNormalization layer, setting trainable = False on the layer means that the layer will be subsequently run in inference mode (meaning that it will use the moving mean and the moving variance to normalize the current batch, rather than using the mean and variance of the current batch).

This behavior has been introduced in TensorFlow 2.0, in order to enable layer.trainable = False to produce the most commonly expected behavior in the convnet fine-tuning use case.

背景

最近 Keras 的文档更新了,我发现了一个迁移学习的文档,结果进去之后发现一个奇怪的地方:

base_model = keras.applications.Xception(
    weights="imagenet",  # Load weights pre-trained on ImageNet.
    input_shape=(150, 150, 3),
    include_top=False,
)  # Do not include the ImageNet classifier at the top.

# Freeze the base_model
base_model.trainable = False

# Create new model on top
inputs = keras.Input(shape=(150, 150, 3))
x = data_augmentation(inputs)  # Apply random data augmentation
x = keras.layers.experimental.preprocessing.Rescaling(1.0 / 255.0)(
    x
)  # Scale inputs to [0. 1]
# The base model contains batchnorm layers. We want to keep them in inference mode
# when we unfreeze the base model for fine-tuning, so we make sure that the
# base_model is running in inference mode here.
x = base_model(x, training=False)
x = keras.layers.GlobalAveragePooling2D()(x)
x = keras.layers.Dropout(0.2)(x)  # Regularize with dropout
outputs = keras.layers.Dense(1)(x)
model = keras.Model(inputs, outputs)

model.summary()

这里是说迁移学习的时候,使用其他的模型,然后冻结之,再增加新的层进行训练。其中 x = base_model(x, training=False) 以及上面的注释引起了我的注意。在这里设置 training=False 是为了让 backbone 处于 inference 状态,这个状态主要是对 BN 起作用,那就是不更新 BN 的参数,即使 unfreeze 之后也不更新。

这个 inference 状态和 training 状态有什么用呢?我们知道,有一些层在训练和测试的时候表现是不同的,比如 BN 和 Dropout。其中 BN 在训练的时候使用 mini-batch 的数据来进行归一化,同时更新 moving mean 和 moving variance,在测试的时候就使用上面的 moving mean 和 moving variance 来进行归一化。这里的 training 是用来控制这些层的表现。

这个 training 出现在 keras Layer 和 Model 的 call 方法中:

def call(self, inputs, training=False):
    pass

当调用层的时候可以指定,比如:

x = BatchNormalization()(x, training=False)
model = Xception(input_shape=(150, 150, 3))(x, training=True)

本着想更深入地理解这个参数,就查了一下,结果发现了一个 issue 和一个 PR,这里面就描述了 keras BatchNormalization 在迁移学习中的坑。

问题

大佬的博客清楚地解释了问题,我这里再重新复述一下。

上面提到过了,BN 层在训练状态和测试状态下的表现是不同的,一个是使用 mini-batch 的数据,另一个是使用积累下来的 moving mean 和 moving variance。而在迁移学习中,我们通常会把 backbone 直接 freeze,训练新加的层,再 unfreeze backbone,然后一起训练。

不过如果不按照上面那样设置 x = base_model(x, training=False),而是直接像下面这样使用已有的模型然后进行训练(实际上大多数数据增强都不会整合到 model 中,因为只有训练的时候才需要,也就是说,基本不会出现上面这种调用形式),那么就会出现问题。(以下代码来自提到的博客)

import numpy as np
from tensorflow.keras.datasets import cifar10
 
from tensorflow.keras.preprocessing.image import ImageDataGenerator
from tensorflow.keras.applications.resnet50 import ResNet50, preprocess_input
from tensorflow.keras.models import Model, load_model
from tensorflow.keras.layers import Dense, Flatten
from tensorflow.keras import backend as K

seed = 42
epochs = 10
records_per_class = 100

# We take only 2 classes from CIFAR10 and a very small sample to intentionally overfit the model.
# We will also use the same data for train/test and expect that Keras will give the same accuracy.
(x, y), _ = cifar10.load_data()
 
def filter_resize(category):
   # We do the preprocessing here instead in the Generator to get around a bug on Keras 2.1.5.
   return [preprocess_input(img) for img in x[y.flatten()==category][:records_per_class]]
 
x = np.stack(filter_resize(3)+filter_resize(5))
records_per_class = x.shape[0] // 2
y = np.array([[1,0]]*records_per_class + [[0,1]]*records_per_class)
 
 
# We will use a pre-trained model and finetune the top layers.
np.random.seed(seed)
base_model = ResNet50(weights='imagenet', include_top=False, input_shape=(32, 32, 3))
l = Flatten()(base_model.output)
predictions = Dense(2, activation='softmax')(l)
model = Model(inputs=base_model.input, outputs=predictions)
 
# for layer in model.layers[:140]:
#    layer.trainable = False
 
# for layer in model.layers[140:]:
#    layer.trainable = True
base_model.trainable = False
 
model.compile(optimizer='sgd', loss='categorical_crossentropy', metrics=['accuracy'])
model.fit_generator(ImageDataGenerator().flow(x, y, seed=42), 
                    steps_per_epoch=7,
                    epochs=epochs, 
                    validation_data=ImageDataGenerator().flow(x, y, seed=42),
                    validation_steps=7
                    )
 
# Store the model on disk
model.save('tmp.h5')
 
 
# In every test we will clear the session and reload the model to force Learning_Phase values to change.
print('DYNAMIC LEARNING_PHASE')
K.clear_session()
model = load_model('tmp.h5')
# This accuracy should match exactly the one of the validation set on the last iteration.
print(model.evaluate(ImageDataGenerator().flow(x, y, seed=42), steps=7))
 
 
print('STATIC LEARNING_PHASE = 0')
K.clear_session()
K.set_learning_phase(0)
model = load_model('tmp.h5')
# Again the accuracy should match the above.
print(model.evaluate(ImageDataGenerator().flow(x, y, seed=42), steps=7))
 
 
print('STATIC LEARNING_PHASE = 1')
K.clear_session()
K.set_learning_phase(1)
model = load_model('tmp.h5')
# The accuracy will be close to the one of the training set on the last iteration.
print(model.evaluate(ImageDataGenerator().flow(x, y, seed=42), steps=7))

运行上面的代码,其中训练集和验证集是同一个数据集。我们可以看到这两者的结果截然不同,训练的结果远好于验证的结果:

Epoch 1/10
6/7 [========================>.....] - ETA: 0s - loss: 1.1314 - acc: 0.5298Epoch 1/10
7/7 [==============================] - 3s 394ms/step - loss: 2.0678 - acc: 0.5700
7/7 [==============================] - 5s 760ms/step - loss: 1.2129 - acc: 0.5300 - val_loss: 2.0678 - val_acc: 0.5700
Epoch 2/10
6/7 [========================>.....] - ETA: 0s - loss: 0.9528 - acc: 0.6012Epoch 1/10
7/7 [==============================] - 2s 265ms/step - loss: 1.4357 - acc: 0.5600
7/7 [==============================] - 4s 558ms/step - loss: 0.8973 - acc: 0.6150 - val_loss: 1.4357 - val_acc: 0.5600
Epoch 3/10
6/7 [========================>.....] - ETA: 0s - loss: 0.7655 - acc: 0.6667Epoch 1/10
7/7 [==============================] - 2s 215ms/step - loss: 1.4113 - acc: 0.5950
7/7 [==============================] - 4s 535ms/step - loss: 0.8119 - acc: 0.6550 - val_loss: 1.4113 - val_acc: 0.5950
Epoch 4/10
6/7 [========================>.....] - ETA: 0s - loss: 0.7548 - acc: 0.7440Epoch 1/10
7/7 [==============================] - 1s 151ms/step - loss: 1.9380 - acc: 0.5800
7/7 [==============================] - 2s 331ms/step - loss: 0.7230 - acc: 0.7350 - val_loss: 1.9380 - val_acc: 0.5800
Epoch 5/10
6/7 [========================>.....] - ETA: 0s - loss: 0.5866 - acc: 0.7202Epoch 1/10
7/7 [==============================] - 1s 150ms/step - loss: 1.8147 - acc: 0.6000
7/7 [==============================] - 2s 322ms/step - loss: 0.5802 - acc: 0.7150 - val_loss: 1.8147 - val_acc: 0.6000
Epoch 6/10
6/7 [========================>.....] - ETA: 0s - loss: 0.3704 - acc: 0.8095Epoch 1/10
7/7 [==============================] - 1s 151ms/step - loss: 1.5603 - acc: 0.6450
7/7 [==============================] - 2s 321ms/step - loss: 0.3881 - acc: 0.7950 - val_loss: 1.5603 - val_acc: 0.6450
Epoch 7/10
6/7 [========================>.....] - ETA: 0s - loss: 0.5056 - acc: 0.7738Epoch 1/10
7/7 [==============================] - 1s 151ms/step - loss: 1.9539 - acc: 0.6250
7/7 [==============================] - 2s 322ms/step - loss: 0.5618 - acc: 0.7400 - val_loss: 1.9539 - val_acc: 0.6250
Epoch 8/10
6/7 [========================>.....] - ETA: 0s - loss: 0.5849 - acc: 0.7976Epoch 1/10
7/7 [==============================] - 1s 153ms/step - loss: 1.4035 - acc: 0.6600
7/7 [==============================] - 2s 323ms/step - loss: 0.5465 - acc: 0.8050 - val_loss: 1.4035 - val_acc: 0.6600
Epoch 9/10
6/7 [========================>.....] - ETA: 0s - loss: 0.4055 - acc: 0.8512Epoch 1/10
7/7 [==============================] - 1s 147ms/step - loss: 1.0538 - acc: 0.6650
7/7 [==============================] - 2s 322ms/step - loss: 0.3984 - acc: 0.8450 - val_loss: 1.0538 - val_acc: 0.6650
Epoch 10/10
6/7 [========================>.....] - ETA: 0s - loss: 0.4082 - acc: 0.8452Epoch 1/10
7/7 [==============================] - 1s 152ms/step - loss: 1.8019 - acc: 0.6000
7/7 [==============================] - 2s 322ms/step - loss: 0.4177 - acc: 0.8400 - val_loss: 1.8019 - val_acc: 0.6000

再看看最后输出的结果:

DYNAMIC LEARNING_PHASE
7/7 [==============================] - 2s 256ms/step - loss: 2.0028 - acc: 0.6000
[2.002779943602426, 0.6]
STATIC LEARNING_PHASE = 0
7/7 [==============================] - 1s 204ms/step - loss: 2.0028 - acc: 0.6000
[2.002779943602426, 0.6]
STATIC LEARNING_PHASE = 1
7/7 [==============================] - 1s 212ms/step - loss: 0.3017 - acc: 0.8650
[0.30170093051024843, 0.865]

第一个结果是 keras 直接自动设置运行状态,第二个结果是手动设定运行状态为测试状态,第三个结果是手动设定运行结果为训练状态。可以看出来,keras 在测试的时候自动设置为测试状态,但这个时候结果出现了明显的下滑,而设置为训练状态的时候结果很正常。

其原因在于,在训练的时候,虽然 freeze 了 BN 的参数,但是 keras 仍然认为 BN 是在训练状态,因此会使用 mini-batch 的数据来标准化。也就是说,这时候后层网络学习到的是 mini-batch(训练数据集) 的分布。但是当测试的时候,BN 使用 moving mean 和 moving variance 来标准化,这两个参数由于没更新,是来自于原来数据集的。因为二者分布偏差很大,因此在测试模式下得到的结果非常差。

这个 PR 的改进就是当 freeze BN 的时候,就让 BN 层按照测试状态来进行,而不使用 mini-batch 的数据。

结果

看了半天,keras 官方好像没有改这个 bug,但是 TF 2.0 版本已经修改了这个 bug 了,以下是在 TF 2.0 下运行同样代码的结果,可以看到训练和验证的结果是相差不大的。另外值得一提的是,改进过后收敛速度明显快了很多,loss 从 0.3 直接降到了 0.01。

Epoch 1/10
7/7 [==============================] - 2s 332ms/step - loss: 7.3916 - accuracy: 0.4700 - val_loss: 3.1501 - val_accuracy: 0.6500
Epoch 2/10
7/7 [==============================] - 1s 207ms/step - loss: 2.8816 - accuracy: 0.6700 - val_loss: 8.4492 - val_accuracy: 0.5100
Epoch 3/10
7/7 [==============================] - 1s 206ms/step - loss: 4.1846 - accuracy: 0.6750 - val_loss: 11.3409 - val_accuracy: 0.5600
Epoch 4/10
7/7 [==============================] - 1s 204ms/step - loss: 3.4036 - accuracy: 0.7800 - val_loss: 0.4167 - val_accuracy: 0.8650
Epoch 5/10
7/7 [==============================] - 1s 210ms/step - loss: 0.8244 - accuracy: 0.8150 - val_loss: 9.1833 - val_accuracy: 0.5400
Epoch 6/10
7/7 [==============================] - 1s 210ms/step - loss: 2.3888 - accuracy: 0.7600 - val_loss: 0.7993 - val_accuracy: 0.8100
Epoch 7/10
7/7 [==============================] - 1s 207ms/step - loss: 0.5801 - accuracy: 0.8600 - val_loss: 2.9707 - val_accuracy: 0.6700
Epoch 8/10
7/7 [==============================] - 1s 205ms/step - loss: 4.2250 - accuracy: 0.6050 - val_loss: 1.0646 - val_accuracy: 0.8500
Epoch 9/10
7/7 [==============================] - 1s 206ms/step - loss: 0.4886 - accuracy: 0.8900 - val_loss: 0.0866 - val_accuracy: 0.9800
Epoch 10/10
7/7 [==============================] - 1s 206ms/step - loss: 0.0969 - accuracy: 0.9700 - val_loss: 0.0109 - val_accuracy: 1.0000
DYNAMIC LEARNING_PHASE
7/7 [==============================] - 1s 95ms/step - loss: 0.0118 - accuracy: 1.0000
[0.011801988817751408, 1.0]
STATIC LEARNING_PHASE = 0
7/7 [==============================] - 1s 94ms/step - loss: 0.0118 - accuracy: 1.0000
[0.011801988817751408, 1.0]
STATIC LEARNING_PHASE = 1
7/7 [==============================] - 1s 92ms/step - loss: 0.0118 - accuracy: 1.0000
[0.011801988817751408, 1.0]

最后

这种问题真的是防不胜防,毕竟很少人会去训练集和验证集使用同一个数据集,训练集和验证集相差大大家也只会怪罪到过拟合头上去。所以平常对于一些关键的东西还是得把他摸透,并且要多看官方文档,遇到问题多思考(所以深度学习就是这一点不好,出了问题有太多可能的原因,很难定位到问题所在)。

P.S. 今天在训练 SOD 的时候并没有出现这个问题,其原因可能在于:

  • 我的模型在 backbone 之外增加了很多的参数,减弱了 BN 的影响,因此结果是差不多的。(回头再多做一点实验)
  • DUTS 的数据本来就来自 ImageNet Detection,因此分布非常接近。