r/MachineLearning🔥 138
💬 12

要注意:AI生成のCUDAカーネルが学習と推論を密かに破壊する現象が発生

laginimaineb
約23時間前

ディスカッション (12件)

0
laginimainebOP🔥 138
約23時間前

先月、NVIDIAはDeepSeekやQwen、Gemma、Kimiなどから抽出した235個の本番環境向けCUDAカーネルを含む新しいベンチマーク「SOL-ExecBench」を公開しました。私たちは、このベンチマークで高スコアを記録したAI生成コードをいくつか試しに本番ワークロードへ導入してみたのですが、驚くべきことにその多くが意図せず機能停止を引き起こしました。その一つが、Transformerの学習ステップの最後で実行される「fused embedding-gradient + RMSNorm backward pass」です。ベンチマークの検証を余裕でクリアした最速のカーネルを採用したところ、学習の損失(Loss)が発散し、二度と回復しなくなりました。原因究明のためデータセットやオプティマイザ(SGDからAdamWへ変更)を入れ替えると症状が消えるという、研究において最も厄介なパターンに陥りました。これは「アイデア自体がダメだったのか?」という疑念を生み、研究者がデータやアーキテクチャの修正に無駄な時間を費やしてしまう種類のバグです。調査の結果、実際のバグは「カーネルの勾配計算の一部がfp32ではなくbf16で蓄積されていたこと」でした。Embedding backwardは小さな勾配を積み重ねますが、実データでは特定のトークンIDにアクセスが集中するため、bf16では精度不足となり値がゼロに丸められて漂流してしまうのです。AdamWはそのバイアスを吸収するため一見問題がないように見えますが、他のアルゴリズムでは致命的なバグとなります。他の壊れた実装例についても、私たちのブログ記事で詳細に解説しています。

1
lostmsu
👍4約21時間前

AdamWでは動くのにSGDでは動かないからといって、bf16の代わりにfp32を使うのがバグだとは思えないな。

2
siegevjorn
👍15約19時間前

それこそがOPが「見つけるのがめちゃくちゃ難しい」と言っている理由だよ。でも、修正が必要なバグであることに変わりはないね。

3
JustOneAvailableName
約18時間前

いや、AdamWを使うときはそのカーネルを使いたいんだよ。バグっていうのは間違った実装のことで、今回のは主にドキュメントの記載ミスだ。

4
Bakoro
👍50約21時間前

結局、本当のバグはカーネルの埋め込み勾配(embedding-gradient)の半分がfp32ではなくbf16で累積されていたことだった。

うわ、それは多くの人が絶対に見つけられないような類のものだね。bf16はよく使われるから、それを見て見過ごしちゃう人もいそうだ。

5
az226
👍5約19時間前

MuonのV100向けカスタムNSカーネルで似たような問題があったよ。うまくいくケースもあれば失敗するケースもあったんだ。原因はメモリタイルの負荷で、kを32に制限したら100%うまくいった。kが64や128だと、何割かの確率でひっそりと間違った累積値になっちゃうんだよね。

6
max123246
👍3約16時間前

誰もカーネルの数値的精度のテストには投資しないからね。有望な研究もいくつかあるけど、エラーの許容範囲を広げて「はい終わり」にするほうが圧倒的に一般的だよ。AIが生成したカーネルなら、なおさら厳密さは低くなると思う。

7
sohang-3112
👍1約14時間前

CUDAカーネルは書いたことないんだけど、bf16とfp32ってそれぞれ16bitと32bitの浮動小数点数のことで合ってる?

8
pm_me_your_pay_slips
👍17約19時間前

じゃあ、解決策はAdamWを使うことだったんだね。

9
siegevjorn
👍21約19時間前

解決策はバグを修正することだよ。彼らは小さなTransformerで学習させたからAdamWがbf16とfp32の不一致を吸収できたかもしれない。でも、それが1T規模のLLMだったらどうするの?

10
siegevjorn
👍3約19時間前

うわ。一体どうやってそんなバグが起きたんだ?fp32が必要な箇所をbf16で置き換えちゃったってこと?

11
TailorImaginary3629
👍1約19時間前

投稿のソースコードはどうやって見るの?