keras自定义训练流程

标准流程

keras 的 api 集成度都非常高,在你没有额外需求的时候的时候能非常轻松地完成整个训练流程:

  • 加载数据
    • 可以选择 generator
    • 也可以直接传入内存的数据
    • 还可以按照一定格式组织成文件夹然后直接传文件夹名
  • 构造模型
  • 编译模型
    • 指定优化器
    • 指定损失函数
    • 指定评价标准
  • 训练模型
    • 指定训练轮次(epoch)
    • 指定回调

自定义流程

标准流程在大多数情况下都能满足需求,但是对于一些需要获取网络中细节的情况下就需要自定义流程了。自定义主要也是对训练步骤进行处理,基本步骤如下:

  • 定义一个 step 的操作:
    • 取出这个 batch 的数据
    • 传入网络得到输出
    • 计算 loss
    • 计算梯度
    • 梯度下降

写成代码如下所示:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')

for idx, (x_train, y_train) in enumerate(train_gen):
with tf.GradientTape() as tape:
predictions = model(x_train, training=True) # 传入网络得到输出
loss = loss_object(y_train, predictions) # 计算 loss
grads = tape.gradient(loss, model.trainable_variables) # 计算梯度
optimizer.apply_gradients(zip(grads, model.trainable_variables)) # 梯度下降

# 一些指标
train_loss(loss)
train_accuracy(y_train, predictions)

和 TensorBoard 一起作用

尽管 keras 的 callback 里面也有 tensorboard,但是默认情况下它只能每个 ep 来保存评价指标和直方图,不能看一个 step 中的变化情况,也不能将参数或者梯度来画成图表。在这里我们将参数和梯度的 l2 范数变化情况画成图表,并且原本就有的直方图也不落下。

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
# create a tensorboard file writer
summary_writer = tf.summary.create_file_writer(some_path)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
train_loss = tf.keras.metrics.Mean('train_loss', dtype=tf.float32)
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy('train_accuracy')

for epoch in range(1, num_epochs + 1):
for idx, (x_train, y_train) in enumerate(train_gen):
n_iter = (epoch - 1) * len(train_gen) + idx + 1

with tf.GradientTape() as tape:
predictions = model(x_train, training=True) # 传入网络得到输出
loss = loss_object(y_train, predictions) # 计算 loss
gradients = tape.gradient(loss, model.trainable_variables) # 计算梯度
optimizer.apply_gradients(zip(grads, model.trainable_variables)) # 梯度下降
trainable_vars = model.trainable_variables

# 一些指标,都是标量
train_loss(loss)
train_accuracy(y_train, predictions)

with summary_writer.as_default():
# 写入评价指标
tf.summary.scalar('loss', train_loss.result(), step=n_iter)
tf.summary.scalar('accuracy', train_accuracy.result(), step=n_iter)

for var, grad in zip(trainable_vars, gradients)
# 写入各个可训练元素的直方图、梯度和参数
tf.summary.histograme(var.name, var, n_iter)
tf.summary.scalar('Grads:' + var.name, tf.norm(grad), n_iter)
tf.summart.scalar('Weights' + var.name, ty.norm(var), n_iter)

print('Epoch {:03d} finished.'.format(epoch))

TensorBoard 里面的数据不能导出来,也可以单独将其写入 csv 来方便后续的处理。

参考教程