勾配チェックポイントとは
勾配チェックポイント(Gradient Checkpointing、Activation Checkpointing)は、ディープラーニングの学習時にGPUメモリの使用量を削減するための技術です。通常、逆伝播(バックプロパゲーション)では順伝播時の中間活性化値をすべてメモリに保持する必要がありますが、勾配チェックポイントでは一部の活性化値のみを保持し、必要に応じて再計算することでメモリを節約します。
メモリと計算のトレードオフ
勾配チェックポイントの本質は、メモリと計算時間のトレードオフです。チェックポイントを設定した層の活性化値のみを保存し、その間の層の活性化値は逆伝播時に再計算します。これにより、メモリ使用量はO(n)からO(√n)に削減できますが、順伝播の計算が約1.3〜1.5倍増加します。大規模モデルではメモリがボトルネックになることが多いため、この追加計算は許容されることが一般的です。
実装方法
PyTorchではtorch.utils.checkpoint.checkpoint関数を使用して、特定のモジュールの活性化値の保存をスキップできます。TensorFlowではtf.recompute_gradが同等の機能を提供します。モデル全体に適用することも、メモリ消費が大きい特定の層のみに適用することもできます。Transformerの各層にチェックポイントを設定するのが一般的なアプローチです。
大規模モデル学習での重要性
GPTやLLaMAなどの大規模言語モデルの学習では、勾配チェックポイントは不可欠な技術です。数百億パラメータのモデルでは、すべての活性化値を保持するとGPUメモリが不足するため、勾配チェックポイントに加えて、モデル並列化やZeRO最適化と組み合わせてメモリ効率を最大化します。