FSDP

Fully Sharded Data Parallel

FSDPとは

FSDP(Fully Sharded Data Parallel)とは、PyTorchが提供する分散学習の手法で、モデルのパラメータ、勾配、オプティマイザ状態をすべてのGPUに分散(シャーディング)することでメモリ効率を最大化する技術です。MicrosoftのDeepSpeed ZeRO Stage 3に相当する機能をPyTorchネイティブで実現します。

FSDPの仕組み

通常のデータ並列(DDP)では各GPUがモデル全体のコピーを保持しますが、FSDPではモデルパラメータをGPU間で分割して保持します。フォワードパスやバックワードパスで必要になった際にAllGather操作でパラメータを一時的に再構成し、計算後にはメモリから解放します。これにより、単一GPUのメモリに収まらない大規模モデルの学習が可能になります。

FSDPの利点

FSDPの最大の利点はPyTorchネイティブであることです。外部ライブラリへの依存なしに大規模モデルの分散学習が可能で、PyTorchのエコシステム(自動微分、モジュールAPI、チェックポイント)とシームレスに統合されます。Mixed Precision学習やActivation Checkpointingとの組み合わせも容易です。

PyTorch FSDPの進化

PyTorch 2.0以降ではFSDP2が開発されており、通信の効率化やAPIの簡素化が進んでいます。Metaの大規模言語モデル(LLaMAなど)の学習にもFSDPが使用されており、実績あるスケーラブルな分散学習ソリューションとして定着しています。