通过重计算来节省显存,参考论文《Training Deep Nets with Sublinear Memory Cost》。
本程序已经内置在bert4keras中
首先,确保环境变量加上RECOMPUTE=1
。
然后,在自定义层的时候,用recompute_grad
装饰call函数即可:
from recompute import recompute_grad
class MyLayer(Layer):
@recompute_grad
def call(self, inputs):
return inputs * 2
如果是现成的层,可以通过继承的方式来装饰:
from recompute import recompute_grad
class MyDense(Dense):
@recompute_grad
def call(self, inputs):
return super(MyDense, self).call(inputs)
在下面的环境下测试通过:
tensorflow 1.14 + keras 2.3.1
tensorflow 1.15 + keras 2.3.1
tensorflow 2.0 + keras 2.3.1
tensorflow 2.1 + keras 2.3.1
tensorflow 2.0 + 自带tf.keras
tensorflow 2.1 + 自带tf.keras
确认不支持的环境:
tensorflow 1.x + 自带tf.keras
欢迎报告更多的测试结果。
强烈建议用keras 2.3.1配合tensorflow来跑,强烈不建议使用tensorflow 2.x自带的tf.keras来跑
- 在BERT Base版本下,batch_size可以增大为原来的3倍左右;
- 在BERT Large版本下,batch_size可以增大为原来的4倍左右;
- 平均每个样本的训练时间大约增加25%;
- 理论上,层数越多,batch_size可以增大的倍数越大。