CUDA C++でNeRFをほぼ0から実装してみた(Part1/3): 概要~MLP編

概要

CUDA C++を使用してMLP,Multiresolution Hash Encoding,NeRFを実装しました.NVIDIAの実装と比較すると遅いですが,数分でおおよそ綺麗に学習できました.本記事では一番初めに実装の大まかな雰囲気を示し,3編に分けて具体的に実装を説明していきます.

想定する読者の対象

C++やCUDAの基本文法が分かっていれば大丈夫だと思ってます.

GitHubとかにプロジェクト全体のソースコード公開しないのか

個人的な研究用に書いてるプログラムもあったりするのでしばらくは公開しないと思います.

一応

内容には気を付けて書きましたが,もし誤りや実装の改善点等があれば教えていただけると嬉しいです.

NeRFとは

NeRFはNeural Radiance Fieldsの略称です.ある物体を様々な視点から撮った写真(2次元)の情報から3次元の形状を推定するView Synthesisの一手法です."NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis" [Mildenhall et al. 2020] において発表されました.

ざっくりとした実装

実装のレポート的な記事なので実装の話をメインでやりましょう.次のような実装をしました.

<前処理>

(1): ファイルからのカメラ情報,教師画像のロード

(2): 学習時に使用しやすい形に入力データを処理

(3): 学習に必要なパラメーターや配列等を定義

<学習ループ開始>

(4): カメラから飛ばすレイを計算

(5): レイ上の点(Position),レイの方向(Direction)をサンプリング

(6): サンプル点の情報をMLPに入力し,出力としてサンプル点の色と密度を得る

(7): カメラに入るレイの色を計算

(8): MLPの出力の誤差を計算

(9): MLP誤差逆伝播

(10): MLPのパラメーター最適化

<学習ループ終了>

本シリーズの記事においては(4) ~ (10)を3編に分けて説明しようと思います.(1/3)としてはMLPの実装,(2/3)としてはMLPの入力時に使用するEncoderの実装,(3/3)としてはNeRF全体の実装について扱います.

MLP編 概要

Fully Fused MLP [Müller et al. 2021] の概念を参考にしてMLPを実装しました.NeRFで使用するMLPは全結合層,即ち行列の乗算と活性化層の組み合わせです.ここで,行列の乗算は,行列を細かく分割した小行列の乗算を繰り返すことで計算できます.今回は16x16の小行列に分割し,そしてCUDAの提供するwmmaのAPIを使用して高速に行列演算を行い,高速にMLPの処理を行いました.さらに,GPUのglobal memoryをなるべく使用しない実装を行いました.

MLP編 はじめに

MLPとは多層パーセプトロン(Multilayer perceptron)というものですが,入力データの線形結合を施した結果を用いてニューロンの発火(活性化層)を考えるといった,神経を模したものとなっております.……結局は巨大な非線形関数です.MLPを実装するという人は勿論あまりおらず,多くの人はPyTorchやLibTorchなどのフレームワークを使用しています.もちろんそちらの方が「動作保証」もありますし,「MLPの高速化」が目的でない場合はMLPの実装自体が非本質的な過程となるため,避けるのが良いでしょう.そんな中,2021年にNVIDIAのグループがFully Fused MLPというMLPの設計を発表しました.

research.nvidia.com

今回はこのMLPの設計を元にして行ったMLPの実装を説明していこうと思います.

MLPの実装: コンセプト1: 行列乗算の分割

NeRFの学習に使用するMLPの基本構造は,全結合層(バイアス項なし)と活性化層のセットが繰り返されたものとなっています.全結合層はすなわち,行列(重み W)とベクトル(入力 x)の乗算です.ここで,行列とベクトルの乗算については,次のようにベクトルを並べて同時に行うことが出来ます.

 \displaystyle
WX =
W
\begin{pmatrix}
x_1&x_2&x_3& ... &x_N 
\end{pmatrix}
= 
\begin{pmatrix}
Wx_1&Wx_2&Wx_3& ... &Wx_N
\end{pmatrix}


機械学習的にみると,これは N個の入力ベクトルのバッチに対して同時に全結合層の処理を作用させることを意味しています.

さて,ここで簡単のため, Nは16の倍数であるとしましょう.つまり,自然数 Mが存在して N = 16Mであるとしましょう.じゃあ先ほどの式をもう一度分割しなおしましょう.

 \displaystyle
WX =
W
\begin{pmatrix}
X_1&X_2&X_3& ... &X_M 
\end{pmatrix}
= 
\begin{pmatrix}
WX_1&WX_2&WX_3& ... &WX_M
\end{pmatrix}


本当に分割しなおしただけです.先ほどの式から変わった場所としては,ベクトルだったx_1, x_2, ... x_Nが横幅16の行列 X_1, X_2, ... , X_Mとなっただけです.では今度は,出力のベクトルの次元( OutDimとします)も16の倍数としましょう.即ち,自然数 Kが存在して OutDim = 16Kを満たしているとします.この時,先ほどの式をさらに分割すると,

 \displaystyle
WX =
\begin{pmatrix}
W_1 \\
W_2  \\ 
... \\
W_K 
\end{pmatrix} 
\begin{pmatrix}
X_1& ... &X_M 
\end{pmatrix}
= 
\begin{pmatrix}
W_1 X_1 & W_1 X_2 & ... & W_1 X_M \\
W_2 X_1 & W_2 X_2 & ... & W_2 X_M \\
... \\
W_K X_1 & W_K X_2 & ... & W_K X_M \\
\end{pmatrix}


となります.まあ式見ても直感的でないので図で見ましょう.ここで入力ベクトルの次元 InDimも16の倍数とします.( 16L = InDimを満たすようにLを定義します)

こんな感じで,行列同士の掛け算は細かい小行列の掛け算に分割できます.この図においては Xにおける横幅16縦幅64の小行列(灰色で塗られていく領域)に対して Wにおける横幅64縦幅16の小行列(橙などに塗られた各々の領域)を乗算しています.じゃあ,今度はその小行列同士の掛け算に注目しましょう.この小行列同士の掛け算 をさらに分割すると,


 \displaystyle
W_i X_j = 
\begin{pmatrix}
W_{i1}&W_{i2}&...&W_{iL}
\end{pmatrix}
\begin{pmatrix}
X_{1j} \\
X_{2j} \\
...\\
X_{Lj}
\end{pmatrix}
=
\sum_{k=1}^L W_{ik}X_{kj}


となります.これによって行列の乗算を16x16の行列同士の掛け算にまで分割することが出来ました.

さて,この分割の何が嬉しいかを書きます.唐突ですがGPU上にはTensorコアといった行列乗算に最適化された演算回路があります.CUDAにおいてはWMMA(Wave Matrix Multiply-accumulate)のAPIがあります.実はこのAPIを用いた行列演算では演算に使用できる行列のサイズが固定されています(この記事に纏まっております) (n,m,k) = (16,16,16)の行列サイズを今回は採用することとしましょう.すると,これまで見てきた行列演算の分割によって,このAPIを使用した行列演算が可能となりました.このAPIを利用することでハードウェア最適化の恩恵を受けて高速に全結合層の処理ができるという塩梅です.ここで一つだけ注意点が必要ですが,この演算を行う場合,行列の各要素の値はfloat以上の精度を持っていると怒られます.wmmaのAPIから__halfという,fp16,つまり半精度浮動小数点の型が提供されているのでそれを使用する必要があります.

並列性の確認

小行列に分割したら遅くなるんじゃないかと疑問に思うかもしれませんが,実はある程度並列化可能です.

図中に示した黄色矢印はすべて並列に計算できます.よって,この図において,完全に並列化できない場所は,黄色矢印自体の演算,すなわち16x64と64x16の行列の乗算です.この部分は先ほど示した通り,16x16行列同士の乗算に分割し,加算してあげる必要があります.即ち4回のループが必要です.  Xの横方向についても並列化は可能です.しかしながら,問題(実装の説明時に説明します)が発生するため,今回では行いません.結局,この図では4x8の32回のループが必要です.

行列のサイズが16の倍数でない時

先ほどは行列のサイズが全て16の倍数であるという制約を設けましたが,実際にMLPを使用するうえで,層の次元,特に入力層や出力層が16の倍数次元になることは稀です.その場合はどうするのかというと,今回の実装では単純にゼロパディングを行い,行列サイズを16の倍数に揃えました.次の関数は今回の実装にて何度も出てきます.

// input以上の最小のKの倍数を返す
__host__ __device__ constexpr int next_multiple(const int input, const int K) {
    const int Q = input % K;
    if (Q == 0) {
        return input;
    }
    else {
        return input + K - Q;
    }
}

MLPの実装: コンセプト2: 遅いメモリと速いメモリがある

CUDAをある程度書いたことのある人はおそらくglobal memory,shared memoryと言った言葉を聞いたことがあるでしょう.
global memoryとはGPUすべてのスレッドからアクセスできるメモリで,最近のGPUだと10GBやそれ以上あったりして結構な大容量メモリです.しかしながらこのglobal memory,マジで遅いのでなるべく読み書きの回数を減らしたいです.
一方でshared memoryはカーネル実行時のブロック内で共有されているオンチップのメモリで,容量は数十KB程度のオーダーしかないものの(ないために?),非常に高速です.よって,メモリとのやり取りはなるべくshared memoryで行いたいですね.
ではまず,容量の確認から行きましょう.先ほども書きましたが,shared memoryはカーネル実行時のブロック内で共有されているメモリです.つまり,処理を実行するブロック内のスレッドが大きければ多いほど,各スレッドで使用できるshared memoryの容量は減ってしまいます.では,具体的にどういう感じにメモリを使用するか見ていきましょう.

CUDAでのMLP並列化?

私は大学の講義で,なぜかC言語でCPU上で動かすMLPを実装させられたことがあります.良い勉強になりましたが二度とやりません.それはともかく,MLPの処理の並列化とはどういうことなのでしょうか?

正確に書くと,並列化するのは「MLPバッチ処理」です.各バッチの処理は,最適化処理を除いて完全に独立です.ということで,理想的にはバッチサイズだけ並列化が可能です.CUDAにおいては,ブロック単位を意識してプログラムを書きます.これは何度も書いている通り,shared memoryがブロック内部でのみ共有されているためです.では,適度な量にバッチを分割して各ブロックの処理データ量を調整し,それらをブロック単位で扱うこととしましょう.

この図のように,全体の非常に大きなバッチを,各ブロックでshared memoryに載せられる程度に分割します.今回の実装では1ブロックにつき128バッチを処理する設計としました.よって,GridSizeは(全体のバッチ数)/ 128 ということになります.

ブロック内部の構造はどうする

これが何だかんだで一番その後の実装の全てに関わってきます.まず,CUDAの仕様として,warpという,スレッドの塊を表現する単位があります.このwarpですが,あるwarp内のスレッドは必ず同じ命令を実行します.if分岐とかがあってスレッド内部で処理が分かれた場合は,片方の分岐先の処理(if文の処理)が終わるまで他方に分岐するスレッドの処理(else if, elseの処理)は待機します.このように,warp内部での処理がバラバラになるとカーネル実行の並列性,効率が大幅に低下します.これをwarp divergenceと呼びます.今回の私の環境(RTX3080: compute 86)では32スレッドが1warpとして扱われます.よってwarpを意識したカーネル実行を行うべきですね.ということで,BlockSizeは{32u, ty, 1u}としましょう.ty(BlockSize.y)に関してはのちに決めます.というわけで,各ブロックは32*tyのスレッドを実行することとなりました.

shared memoryに載せるもの,global memoryに載せるもの

一旦まとめましょう.これまでの議論より,カーネルの実行設定は<<< nBatch/128, {32u, ty, 1u}>>>となりました.各ブロックにおいては128バッチを扱うことになりました.
今回の設計のMLPにおいて使用するメモリ領域は次の通りです.
(1): NNの入力データ(一番最初の入力)
(2): NNの出力データ(一番最後の出力)
(3): MLPを流れていくデータ
(4): 全結合層の重みパラメーター
(5): 誤差逆伝播用のバッファ
(6): バッチ全体によるパラメーターの勾配を記録する配列
(7): ブロック内部でのバッチによるパラメーターの勾配を記録する配列
(1), (2)に関してはMLP内部では扱いません.MLP全体をモジュール化した際のI/Oの役割を果たします.当然CPUとのデータ通信はglobal memoryと行われるので,global memoryに載せることになります.
(3)はMLP内部で何度も読み書きを行う配列です.MLP内部でのデータの最大次元をMaxDimとして,sizeof(half) * MaxDim * MaxDimです.MaxDimを128とすると,sizeof(half)が2byteであることに注意して,32KBです.MaxDimが64なら8KBです.shared memoryに載せましょう.
(4)(5)は順伝播と逆伝播において使用するデータです.これらは実は工夫することにより,読み書きの回数を1,2回に抑えることが可能です.また,順伝播と逆伝播で使用することを考えると,これらの二つは異なるCUDAカーネルで行われるので,shared memoryではなくglobal memoryに載せるべきでしょう.
(6)はスレッド全体で書き込むデータです.ゆえにglobal memoryに載せます.
(7): これは私の実装ではshared memoryに載せました.しかしこれは何とも言えません.MaxDimが128であるときを考えましょう.この時のこの配列が占めるメモリ領域は128 * 128 * sizeof(half) = 32KBです.(3)と合わせてみると,64KBとなります.この記事を参考にして48KB以上のshared memoryを使用することが出来ましたが,果たしてこれは適切な実装なのかは不明です.今回の記事に載せるソースコードではshared memoryに載せたものとして処理をしています.私の環境では動きます……

shared memoryの載せ方

shared memoryに載せるものを決めました.じゃあ載せましょうということで単純にmemcpy的なことをする前に少しお待ちください.shared memoryを扱う上で考慮する必要がある概念として,バンクコンフリクトというものがあります.具体的な説明は他の記事に任せるとして,ざっくり書くと,shared memoryはメモリ領域が16とか32とかの数のバンクに振り分けられており,同じバンクに複数のスレッドが同時にアクセスするとメモリの処理が逐次的になり,ゆえにカーネルの性能低下につながります.実際に設計したMLPでバンクコンフリクトを意図的に起こした場合,最悪50%の性能低下を起こしました.
今回の実装ではshared memoryには「層を流れる特徴ベクトルのバッチ」が載っています(今後,これをintermediateと呼びます).よって,これらは1バッチあたり16の倍数要素の配列となっております.実は,各バッチの最後にいくつかzero-paddingを行うことでバッチごとのバンクがずれ,バンクコンフリクトを軽減できます.このzero-paddingの数をSKEWとしましょう.今回はSKEWを8としましたが,他の値を試すのもいいと思います.つまり,次のようにshared memory配列を置きます.

図のように,intermediateの下側にSKEWを与えます.重み Wはglobal memoryに載せているため,当然SKEWは与えません.

MLPの実装: 順伝播 外観

さあ,これまでに説明した概念を用いて順伝播を実装していきましょう.私の実装はこちらです.

/*
 * 順伝播(学習あり)
 *
 * inDim -> HiddenDim -> HiddenDim -> HiddenDim -> ... -> HiddenDim -> outDim
 *          |<------------------- nHiddenLayer ------------------->|
 */
template <const uint32_t indim, const uint32_t hiddendim, const uint32_t outdim, const uint32_t nHiddenLayer>
MFFM_DEVICE void Kernel_Debug_train_forward(const uint32_t BatchSize, Activation ActHid, Activation ActOut, __half* intermediate, __half* weights, __half* buffers) {

    constexpr uint32_t indim_aligned = next_multiple(indim, TENSOR_ROW);
    constexpr uint32_t hiddendim_aligned = next_multiple(hiddendim, TENSOR_ROW);
    constexpr uint32_t outdim_aligned = next_multiple(outdim, TENSOR_ROW);

    uint32_t weights_elem_idx = 0;
    uint32_t buffers_elem_idx = 0;
    uint32_t buffers_elem_idx_blockbias = indim_aligned * ONEBATCH_SIZE * blockIdx.x;

    __syncthreads();

    // MLPの入力をback propagationのために保存
    int shmem_curDim = indim_aligned + SKEW;
    store_intermediate<__half>(shmem_curDim, indim_aligned, intermediate, buffers);

    // 入力層 INDIM(_ALIGNED) -> HIDDENDIM(_ALIGNED)
    MLP_Forward<indim, hiddendim, indim_aligned, hiddendim_aligned>(ActHid, intermediate, weights);

    shmem_curDim = hiddendim_aligned + SKEW;

    weights_elem_idx += indim_aligned * hiddendim_aligned;
    buffers_elem_idx += BatchSize * indim_aligned;
    buffers_elem_idx_blockbias = hiddendim_aligned * ONEBATCH_SIZE * blockIdx.x;

    // 隠れ層 HIDDENDIM(_ALIGNED) -> HIDDENDIM(_ALIGNED)
    for (int i = 0; i < nHiddenLayer - 1; i++) {
        // MLPの入力をback propagationのために保存
        store_intermediate<__half>(shmem_curDim, hiddendim_aligned, intermediate, buffers + buffers_elem_idx);

        MLP_Forward<hiddendim, hiddendim, hiddendim_aligned, hiddendim_aligned>(ActHid, intermediate, weights + weights_elem_idx);
        weights_elem_idx += hiddendim_aligned * hiddendim_aligned;
        buffers_elem_idx += BatchSize * hiddendim_aligned;
        __syncthreads();
    }

    // MLPの入力をback propagationのために保存
    store_intermediate<__half>(shmem_curDim, hiddendim_aligned, intermediate, buffers + buffers_elem_idx);
    // 出力層 HIDDENDIM(_ALIGNED) -> OUTDIM(_ALIGNED)
    MLP_Forward<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(ActOut, intermediate, weights + weights_elem_idx);

    __syncthreads();

    shmem_curDim = outdim_aligned + SKEW;

    // 結果をback propagationのために記録
    buffers_elem_idx += BatchSize * hiddendim_aligned;
    buffers_elem_idx_blockbias = outdim_aligned * ONEBATCH_SIZE * blockIdx.x;
    store_intermediate<__half>(shmem_curDim, outdim_aligned, intermediate, buffers + buffers_elem_idx);
}

順番に一つずつ見ていきましょう.

template <const uint32_t indim, const uint32_t hiddendim, const uint32_t outdim, const uint32_t nHiddenLayer>
...

indim: NNの入力層の次元
hiddendim: NNの隠れ層の次元(すべての隠れ層で等しいとします)
outdim: NNの出力層の次元 nHiddenLayer: NNの隠れ層の数

これはコンパイル時に確定させ,実行中には変化させないものとします.

MFFM_DEVICE void Kernel_Debug_train_forward(const uint32_t BatchSize, Activation ActHid, Activation ActOut, __half* intermediate, __half* weights, __half* buffers) {
    ...
}

MFFM_DEVICEは deivice に同じです.
BatchSize: 文字通りバッチサイズです.
Activation構造体:

enum class Activation {
    ReLU,
    LeakyReLU,
    Sigmoid
};

これです.
ActHid: 入力層,隠れ層における活性化層の種類(すべての隠れ層で等しいとします)
ActOut: 出力層における活性化層の種類
intermediate: MLPの層を流れていくデータです.(shared memoryに載っています!)
weights: MLP内のすべての全結合層のパラメーター Wがこの配列に入っています.
buffers: 逆伝播時に使用する,順伝播時のintermediateの値をすべて格納します.

...
constexpr uint32_t indim_aligned = next_multiple(indim, TENSOR_ROW);
constexpr uint32_t hiddendim_aligned = next_multiple(hiddendim, TENSOR_ROW);
constexpr uint32_t outdim_aligned = next_multiple(outdim, TENSOR_ROW);
...

早速出てきました.先ほども示した通り,入力次元や出力次元は16の倍数とは限りません.そのため,入力層次元,隠れ層次元,出力層次元のそれぞれ以上の16の倍数を求めておきます.TENSOR_ROWは16の即値です.

...
uint32_t weights_elem_idx = 0;
uint32_t buffers_elem_idx = 0;
uint32_t buffers_elem_idx_blockbias = indim_aligned * ONEBATCH_SIZE * blockIdx.x;
...

weights_elem_idx: weightsには各MLPの全結合層において使用するパラメーターがすべて保存されています.よって,各全結合層が終了するたびに,使用する全結合層が保存されているインデックスを記録しておく必要があります.
buffers_elem_idx: buffersにも順伝播時のすべてのMLP間のintermediateを保存する必要があります.よって,各MLPが終了するたびに使用するbuffersのインデックスが記録されている必要があります.
buffers_elem_idx_blockbias : buffersは全てのスレッドが使用する配列です.intermediateはブロック間で独立であるため,「そのintermediateが存在していたブロック」が逆伝播時に明確である必要があります.ゆえに,blockIdxに対応する量だけインデックスをずらしてintermediateを保存しておきます.ここで注意ですが,保存するintermediateは,本来のサイズが16の倍数でなくても,16の倍数にパディングされた状態を記録します.これの理由は逆伝播時に明らかになります.

...
// MLPの入力をback propagationのために保存
int shmem_curDim = indim_aligned + SKEW;
store_intermediate<__half>(shmem_curDim, indim_aligned, intermediate, buffers);
...

shmem_curDim: 現在のshared memoryの1バッチの次元を記録します.コンセプト2の最後に述べましたが,shared memoryにはバンクコンフリクトを回避するため,SKEWというパディングを与えます.
store_intermediate: 後に説明します.逆伝播時に使用するため,buffersにintermediateの値を記録します.

...
// 入力層 INDIM(_ALIGNED) -> HIDDENDIM(_ALIGNED)
MLP_Forward<indim, hiddendim, indim_aligned, hiddendim_aligned>(ActHid, intermediate, weights);

shmem_curDim = hiddendim_aligned + SKEW;

weights_elem_idx += indim_aligned * hiddendim_aligned;
buffers_elem_idx += BatchSize * indim_aligned;
buffers_elem_idx_blockbias = hiddendim_aligned * ONEBATCH_SIZE * blockIdx.x;
...

MLP_Forward: 後に説明します.順伝播時のMLPの処理です.
shmem_curDim: 入力層が終わってintermediateの次元がhiddendimになりました.shared memoryに適したサイズを与えます.
buffers_elem_idx: 入力層が終わったのでbuffersのインデックスを更新します.入力層のデータが全バッチ入るように値を加えます.
buffers_elem_idx_blockbias: 先ほどと同様です.blockIdxに対応する量だけズラします.

...
// 隠れ層 HIDDENDIM(_ALIGNED) -> HIDDENDIM(_ALIGNED)
for (int i = 0; i < nHiddenLayer - 1; i++) {
    // MLPの入力をback propagationのために保存
    store_intermediate<__half>(shmem_curDim, hiddendim_aligned, intermediate, buffers + buffers_elem_idx);

    MLP_Forward<hiddendim, hiddendim, hiddendim_aligned, hiddendim_aligned>(ActHid, intermediate, weights + weights_elem_idx);
    weights_elem_idx += hiddendim_aligned * hiddendim_aligned;
    buffers_elem_idx += BatchSize * hiddendim_aligned;
    __syncthreads();
}

// MLPの入力をback propagationのために保存
store_intermediate<__half>(shmem_curDim, hiddendim_aligned, intermediate, buffers + buffers_elem_idx);
...

隠れ層の処理です.処理は入力層と同じですね.

...
// 出力層 HIDDENDIM(_ALIGNED) -> OUTDIM(_ALIGNED)
MLP_Forward<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(ActOut, intermediate, weights + weights_elem_idx);
shmem_curDim = outdim_aligned + SKEW;

// 結果をback propagationのために記録
buffers_elem_idx += BatchSize * hiddendim_aligned;
buffers_elem_idx_blockbias = outdim_aligned * ONEBATCH_SIZE * blockIdx.x;
store_intermediate<__half>(shmem_curDim, outdim_aligned, intermediate, buffers + buffers_elem_idx);

出力層の処理です.

以上で順伝播の外見は出来上がりです.今回の関数には入っていないが説明すべき関数が一つあります(入力データのロードと出力データのストアを行っていないですよね?).また,関数内でも説明を省いた関数が2つあります.それらについて見ていきましょう.

MLPの実装: 順伝播 I/O

先ほど示した関数には出てきていないですが,load_input()という関数があります.

////////////////////////////////////////// LOAD AND STORE ///////////////////////////////////////////////////////////
/*
 * inDim * BatchSizeの入力データがCPUからGPUのグローバルメモリに保存された
 * 入力データ1つの次元はinDimであるが,これを16の倍数にゼロパディングする
 * バンクコンフリクト対策のため,skewだけさらに次元を拡張する(ゼロパディングする).すなわち入力データ1つにつきnext_multiple(inDim, 16) + skew次元になる.これをnewDimとする
 * ここで,1Blockにつき,128バッチを担当する.
 * また,カーネル実行の設定はBlockNum = GridSize = (BatchSize/128, 1, 1),BlockSize = (32, MAX_REQUIRED, 1)
 * 以上より,BlockIdx.x = bx, ThreadIdx = tx, tyとすると,
 *
 * input + inDim * 128 * bx + inDim * 32 * ty + inDim * tx からのinDim要素を
 * shmem_input + newDim * 32 * ty + newDim * txからのinDim要素に格納する
*/
MFFM_DEVICE void load_input(const int inDim, const int newDim, const float* __restrict__ input, __half* __restrict__ shmem_input) {
    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    const int warp_index_required = ONEBATCH_SIZE / 32;

    if (ty >= warp_index_required) {
        return;
    }

    for (int i = 0; i < inDim; i++) {
        shmem_input[newDim * 32 * ty + newDim * tx + i]
            = (__half)input[inDim * ONEBATCH_SIZE * bx + inDim * 32 * ty + inDim * tx + i];
    }
        for (int i = inDim; i < newDim; i++) {
            shmem_input[newDim * 32 * ty + newDim * tx + i] = 0.0f;
        }
}

さて,私の実装時のコメントが凄く丁寧ですが,このコメントを書いたときはまだ頭がこんがらがっておりました.こういうコメントはやはり安全でいいですね.というわけで部分的にみていきましょう.

MFFM_DEVICE void load_input(const int inDim, const int newDim, const float* __restrict__ input, __half* __restrict__ shmem_input) {
...

inDim: 入力データの次元
newDim: 入力データの次元を,それ以上の最小の16の倍数にし(next_multiple(inDim, 16)),それにSKEWを与える.つまり入力データをロードしたあとのintermediateの1バッチの次元
input: 入力データ(global memory) shmem_input: 入力データをshared memory上に,適切な構造となるようにコピーしたもの

...
const int warp_index_required = ONEBATCH_SIZE / 32;
if (ty >= warp_index_required) {
    return;
}
...

先ほどはカーネルの実行時にBlockSize.xを32とし,BlockSize.yの値を未定としていました.今回の関数における処理ではブロックごとに128バッチ,つまりONEBATCH_SIZEのみ読み込めばよいため,各スレッドで1バッチロードするとすればthreadIdx.yはONEBATCH_SIZE/32以上は必要ありません.ゆえにreturnしてもらいます.

...
for (int i = 0; i < inDim; i++) {
    shmem_input[newDim * 32 * ty + newDim * tx + i]
        = (__half)input[inDim * ONEBATCH_SIZE * bx + inDim * 32 * ty + inDim * tx + i];
}
for (int i = inDim; i < newDim; i++) {
    shmem_input[newDim * 32 * ty + newDim * tx + i] = 0.0f;
}
...

嫌な式ですね.バグ埋め込んだ時に眺めるのが苦痛な処理です.図で説明しましょう.

図に色々詰め込んでいますが,「バッチの次元を16の倍数にし,SKEWを与えるよ.増加した次元部分には0を埋め込むよ」と言っているだけです.ty=2の下に残っている矢印は消し忘れです.
この図においてメモリ上の配置が図中下部に示されています.図示上は2次元配列に見えますが,実際には1次元配列で扱います.inputは全てのバッチの入力データを保持しています.よって,blockIdx.xが処理すべきインデックスを指す必要があります.一つのブロックが扱うバッチは128なので,128 * indim * blockIdx.xのバイアスを与えると良いことになります.この状態が図中の左です.さらに,各スレッドはそれぞれ対応する1バッチをinputからshmem_inputにコピーします.よって,図中左におけるk番目のバッチを処理する際には,input側ではindim * k,shmem_input側ではnewdim * kのインデックス調整が必要です.インデックス調整が完了したら,そこからindim要素をコピーしてあげると入力データのロードが完了です.ちなみに2つ目のforを抜かしているとnanが出ることがあります(ゴミ値が入ってるため).

この処理を逆方向に行うのがstore_intermediate()です.

/*
 * outDim * BatchSizeの入力データをGPUのグローバルメモリに保存する
 * 現在所持している出力データはnext_multiple(outDim, 16) + SKEWの次元である.
 * これをoutDim次元に整形して出力したい
 *
 * ここで,1Blockにつき,128バッチを担当する.
 * また,カーネル実行の設定はBlockNum = GridSize = (BatchSize/128, 1, 1),BlockSize = (32, MAX_REQUIRED, 1)
 * 以上より,BlockIdx.x = bx, ThreadIdx = tx, tyとすると,
 *
 * shmem_output + shmem_outDim * 32 * ty + shmem_outDim * tx からのoutDim要素を
 * output + outDim * 128 * bx + outDim * 32 * ty + outDim * tx からのoutDim要素に格納する
*/
template<typename T>
MFFM_DEVICE void store_intermediate(const int shmem_curDim, const int curDim, const __half* __restrict__ shmem_intermediate, T* __restrict__ destination) {
    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    const int warp_index_required = ONEBATCH_SIZE / 32;

    if (ty >= warp_index_required) {
        return;
    }

    for (int i = 0; i < curDim; i++) {
        destination[curDim * ONEBATCH_SIZE * bx + curDim * 32 * ty + curDim * tx + i]
            = (T)shmem_intermediate[shmem_curDim * 32 * ty + shmem_curDim * tx + i];
    }
}

shmem_curDim: 現在のデータの次元 (indim, hiddendim, outdimのいずれか)以上の最小の16の倍数にSKEWを与えたもの
curDim: 現在のデータの次元
shmem_intermediate: intermediate
destination: 出力先
残りの処理は先ほど示した図と見比べてみてください.

MLPの実装: 順伝播 MLP_Forward

順伝播のメインディッシュです.この関数では全結合層 + 活性化層の処理を行います.すなわち
 Y = Activation(WX) を行います.実際にはYとXは両方ともにintermediateという領域を使用します.
wmmaを使用するために次のヘッダーファイルをインクルードします.CUDAを利用していれば既に存在しているはずです.

#include <mma.h>

では関数を見ていきましょう.

/////////////////////////////// MLP IMPLEMENTATION //////////////////////////////////////////////////

/*
 * inDim ---(FC)--> outDim ---(Activation)--> outDim
 * return: ActivationFunc(matmul(weight,intermediate))
 */
template <const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void MLP_Forward(Activation activation, __half* __restrict__ intermediate, __half* __restrict__ weight, __half* __restrict__ buffer = nullptr) {
    using namespace nvcuda;

    // Y = WXを順伝播時に行う.(正確には Y_t = X_t * W_tを行う)
    // ここでのRow, Colとは転置前の行列の行と列の長さを表す
    constexpr int XRow = ONEBATCH_SIZE;
    constexpr int XCol = inDim;
    constexpr int WRow = inDim;
    constexpr int WCol = outDim;
    constexpr int YRow = ONEBATCH_SIZE;
    constexpr int YCol = outDim;

    constexpr int XRow_aligned = ONEBATCH_SIZE;
    constexpr int XCol_aligned = inDim_aligned;
    constexpr int WRow_aligned = inDim_aligned;
    constexpr int WCol_aligned = outDim_aligned;
    constexpr int YRow_aligned = ONEBATCH_SIZE;
    constexpr int YCol_aligned = outDim_aligned;

    constexpr int nBlock_XRow = XRow_aligned / TENSOR_ROW;
    constexpr int nBlock_XCol = XCol_aligned / TENSOR_ROW;
    constexpr int nBlock_WRow = WRow_aligned / TENSOR_ROW;
    constexpr int nBlock_WCol = WCol_aligned / TENSOR_ROW;
    constexpr int nBlock_YRow = YRow_aligned / TENSOR_ROW;
    constexpr int nBlock_YCol = YCol_aligned / TENSOR_ROW;

    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    constexpr int warp_index_required = nBlock_WCol;

    // Fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> inputs_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag[nBlock_WRow];
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> outputs_frag[nBlock_YRow];

    // weightをロード
#pragma unroll
    for (int i = 0; i < nBlock_WRow; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::load_matrix_sync(weights_frag[i], weight + WRow_aligned * TENSOR_ROW * ty + TENSOR_ROW * i, WRow_aligned);
        }
    }

    __syncthreads();

    // 入力をロードしてY_t = X_t*W_t
#pragma unroll
    for (int i = 0; i < nBlock_YRow; i++) {
        // Yを0で初期化しておく
        wmma::fill_fragment(outputs_frag[i], __float2half(0.0f));

#pragma unroll
        for (int j = 0; j < nBlock_XCol; j++) {
            // inputsをTENSORにロード
            wmma::load_matrix_sync(inputs_frag, intermediate + TENSOR_ROW * j + (TENSOR_ROW * i) * (inDim_aligned + SKEW), inDim_aligned + SKEW);
            // matmal
            wmma::mma_sync(outputs_frag[i], inputs_frag, weights_frag[j], outputs_frag[i]);
        }

        // Activation
        Activation_Forward<__half>(activation, outputs_frag[i], outputs_frag[i]);
    }

    __syncthreads();

    // 次の入力へ記録
#pragma unroll
    for (int i = 0; i < nBlock_YRow; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::store_matrix_sync(intermediate + TENSOR_ROW * ty + i * TENSOR_ROW * (outDim_aligned + SKEW), outputs_frag[i], outDim_aligned + SKEW, wmma::mem_row_major);
        }
    }
    __syncthreads();
}

これも一つずつ見ていきましょう.

template <const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void MLP_Forward(Activation activation, __half* __restrict__ intermediate, __half* __restrict__ weight, __half* __restrict__ buffer = nullptr) {
...

inDim: このMLP一層における入力次元
outDim: このMLP一層における出力次元
inDim_aligned: inDim以上の最小の16の倍数
outDim_aligned: outDim以上の最小の16の倍数
activation: このMLP一層における活性化層の種類
intermediate: このMLP一層のIO weight: この一層における全結合層のパラメーター W
buffer: あるタイミングで実装を変えた際に不使用となった変数です(バッファをこの関数の外で記録するようになったためです.いつか消します.)

...
using namespace nvcuda;
...

wmmaの関数はnvcuda::wmma::の中にあります.(じゃあwmmaもusing namespaceすればいいのではって?確かに……)

...
// Y = WXを順伝播時に行う.(正確には Y_t = X_t * W_tを行う)
// ここでのRow, Colとは転置前の行列の行と列の長さを表す
constexpr int XRow = ONEBATCH_SIZE;
...
constexpr int YCol_aligned = outDim_aligned;
...

行列演算としては
 Y = WX
をするのですが,実は次が成り立ちます.

 Y^T = X^T W^T

今回は後者で実装します.いずれにせよ,転置前の行列 X, Y, Wにおける横幅サイズ(Row),SKEWを考慮しない縦幅サイズ(Col)をすべて書き出しておきました.全部は使いません.

...
constexpr int nBlock_XRow = XRow_aligned / TENSOR_ROW;
...
constexpr int nBlock_YCol = YCol_aligned / TENSOR_ROW;
...

先ほど求めた各行列の横幅サイズと縦幅サイズはすべて16の倍数です(そうなるようにゼロパディングを行っています).各行列を一マス16の正方形の集まりで表現した際の縦に並ぶブロック数と横に並ぶブロック数を計算しておきます.これらはすべてコンパイル時に確定します.

これ以降の説明に入る前に,いったん図で確認しておきましょう.

関数内で行う計算は次の図の計算です.

今回の実装では,図内で橙,緑,青,灰で塗られた4色の領域に関しては並列で計算します.wmmaの行列計算は1warpが共同作業という形で行われます.つまり,一回の16x16行列同士の掛け算には32スレッド必要ということです.今回は4色,すなわち行列の乗算タスクを4個同時にこなしていくこととなります.よって必要なwarpは4です.ここで,出力層の次元は128以下を想定しているとしましょう.これが128である場合,この配色は8色(行列乗算タスクは8個同時)になります.よって必要なwarp数は8です.
このように,MLP内部での最大次元によって必要なwarp数が異なります.ただし,先程書いたload_input()やstore_intermediate()においては4warpを必要としています.よって,カーネル実行時に設定するBlockSize.yは
 BlockSize.y = max(4, MaxDim_aligned / 16)
となります.(入力次元は考慮しなくても良いかも)
議論の余地として,このことによってカーネル実行中のInactiveなスレッドは増加する(Occupancyが低下する)ので,warp数を4に固定したうえで,これを並列数4の必要に応じた逐次実行にしておくというのも良いと思います.というかそちらもこちらで試しておくべきでした.また試します.
いずれにせよ,仮にwarp数が沢山あったとしても,この関数が必要としているwarp数は(OutDim_aligned / 16)となるわけです.よって,

...
constexpr int warp_index_required = nBlock_WCol;
...

となります.
ここで,仮にMaxDim_alignedが128の時に Xの横方向に対しても並列化を行ったとします.すると並列数は64となります.……さて,64warp必要とするのであれば,必要なBlockSizeは2048となります.これではCUDAカーネル実行が出来ません.よってこちらに関しては並列化しません.

...
// Fragments
wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> inputs_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag[nBlock_WRow];
wmma::fragment<wmma::accumulator, 16, 16, 16, __half> outputs_frag[nBlock_YRow];
...

wmma特有の構造体が出てきました.具体的な説明は他の記事に任せるとして,
inputs_frag: [tex: XT]の小行列です.先ほど示した図において矢印の向きが右であったのでrow_majorでロードします.
weights_frag: [tex: WT]の小行列です.先ほど示した図において矢印の向きが下であったのでcol_majorでロードします.
outputs_frag: [tex:YT]の小行列です.

ではここから演算に入ります.どういう演算なのかを確認しておきましょう.

さて,この図を参考にしながら見ていきます.

...
// weightをロード
#pragma unroll
for (int i = 0; i < nBlock_WRow; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::load_matrix_sync(weights_frag[i], weight + WRow_aligned * TENSOR_ROW * ty + TENSOR_ROW * i, WRow_aligned);
    }
}

__syncthreads();
...

そのwarpが必要としているweight(図中で対応している色で塗られているブロック)をweights配列からロードします.すべてここでロードします.つまり,weightsからのロードはこの1回だけです.インデックスの指定は図を参考にしてください.あと,ブロック内での同期を取りたいので__syncthreads()を置いておきましょう.

// 入力をロードしてY_t = X_t*W_t
#pragma unroll
for (int i = 0; i < nBlock_YRow; i++) {
    // Yを0で初期化しておく
    wmma::fill_fragment(outputs_frag[i], __float2half(0.0f));

#pragma unroll
    for (int j = 0; j < nBlock_XCol; j++) {
        // inputsをTENSORにロード
        wmma::load_matrix_sync(inputs_frag, intermediate + TENSOR_ROW * j + (TENSOR_ROW * i) * (inDim_aligned + SKEW), inDim_aligned + SKEW);
        // matmal
        wmma::mma_sync(outputs_frag[i], inputs_frag, weights_frag[j], outputs_frag[i]);
    }

    // Activation
    Activation_Forward<__half>(activation, outputs_frag[i], outputs_frag[i]);
}

__syncthreads();

Yは未初期化の状態ではゴミ値が入っているので0で初期化しておきましょう.カウンタ変数がiのループは図中でXYにおいて上から下に移動していくブロックです.カウンタ変数がjのループは図中における移動する緑色の矢印です.
Activation_Forwardは文字通り活性化層の関数ですが,この後で記述することにします.

// 次の入力へ記録
#pragma unroll
for (int i = 0; i < nBlock_YRow; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::store_matrix_sync(intermediate + TENSOR_ROW * ty + i * TENSOR_ROW * (outDim_aligned + SKEW), outputs_frag[i], outDim_aligned + SKEW, wmma::mem_row_major);
    }
}
__syncthreads();

演算結果を保存する処理です.データ配列の向き(図中の黒矢印)に注意して,row_majorで保存しましょう.

活性化層

活性化層の順伝播の関数が出てきました.簡単に触れておきます.

template <typename T, typename fragment_t>
MFFM_DEVICE void Activation_Forward(Activation activation, fragment_t& AfterFC, fragment_t& Activated) {
    switch (activation) {
    case Activation::ReLU:
#pragma unroll
        for (int i = 0; i < AfterFC.num_elements; i++) {
            Activated.x[i] = ((T)AfterFC.x[i] > (T)0.0f) ? AfterFC.x[i] : (T)0.0f;
        }
        break;
    case Activation::LeakyReLU:
#pragma unroll
        for (int i = 0; i < AfterFC.num_elements; i++) {
            Activated.x[i] = ((T)AfterFC.x[i] > (T)0.0f) ? AfterFC.x[i] : (T)0.05f * AfterFC.x[i];
        }
        break;
    case Activation::Sigmoid:
#pragma unroll
        for (int i = 0; i < AfterFC.num_elements; i++) {
            Activated.x[i] = (T)1.0f / ((T)1.0 + (T)expf(-AfterFC.x[i]));
        }
        break;
    default:
        printf("Invalid Activation Type\n");
        break;
    }
}

この関数ではReLU,LeakyReLU,Sigmoidの実装が書かれています.fragment_tにはwmma::fragmentが入ります.各活性化層の計算式についてはここでは説明しません.
.num_elements()はfragment_tのもつメンバ関数で,そのスレッドのfragmentが保持する要素の数を返してくれます.

順伝播のglobal memoryの関わるメモリ読み書きの確認

NNの入力(global memory)をshared memoryにロード: 1回
MLPの全結合層のパラメーターをロード: 1パラメーターにつき1回
MLPのintermediateをbuffersに保存: およそ層数回 (推論時には不要!)
NNの出力を保存: 1回

このようにみるとglobal memoryとのメモリのやり取りがほとんど無く,特に推論処理のみでは実現しうる最小回数であり,メモリ的にかなり効率が良いことが分かります.

MLPの実装: 逆伝播 外観

Forwardはかなり丁寧に説明しました.Backwardに関してはこれまでの処理を逆方向にやっていくだけなので,これまでの処理が分かっていれば特に理解に困ることはないと思います.

template <const uint32_t indim, const uint32_t hiddendim, const uint32_t outdim, const uint32_t nHiddenLayer>
MFFM_DEVICE void Kernel_Debug_train_backward(const uint32_t BatchSize, Optimize Optim, Activation ActHid, Activation ActOut, const int epoch, __half* intermediate, __half* LossDerivativeSumOfBlock, __half* weights, const __half* buffers,
    float* LossDerivativeSumALL, float* AdditionalParam) 
{

    // 即値
    constexpr uint32_t indim_aligned = next_multiple(indim, TENSOR_ROW);
    constexpr uint32_t hiddendim_aligned = next_multiple(hiddendim, TENSOR_ROW);
    constexpr uint32_t outdim_aligned = next_multiple(outdim, TENSOR_ROW);
    constexpr uint32_t weightsize = indim_aligned * hiddendim_aligned + (nHiddenLayer - 1) * hiddendim_aligned * hiddendim_aligned + hiddendim_aligned * outdim_aligned;
    const uint32_t buffersize = (indim_aligned + nHiddenLayer * hiddendim_aligned + outdim_aligned) * BatchSize;

    uint32_t weights_elem_idx = weightsize - outdim_aligned * hiddendim_aligned;
    uint32_t buffers_elem_idx = buffersize - hiddendim_aligned * BatchSize - outdim_aligned * BatchSize;
    uint32_t buffers_elem_idx_blockbias = outdim_aligned * ONEBATCH_SIZE * blockIdx.x;

    __syncthreads();
    // 出力層の逆伝播
    MLP_Backward<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(BatchSize, ActOut, intermediate, weights + weights_elem_idx,
        buffers + buffers_elem_idx, LossDerivativeSumOfBlock);
    SumUp_LossDerivative<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(LossDerivativeSumOfBlock, LossDerivativeSumALL + weights_elem_idx);
    __syncthreads();

    // 隠れ層の逆伝播
    for (int i = 0; i < nHiddenLayer - 1; i++) {
        weights_elem_idx -= hiddendim_aligned * hiddendim_aligned;
        buffers_elem_idx -= hiddendim_aligned * BatchSize;
        buffers_elem_idx_blockbias = hiddendim_aligned * ONEBATCH_SIZE * blockIdx.x;
        MLP_Backward<hiddendim, hiddendim, hiddendim_aligned, hiddendim_aligned>(BatchSize, ActHid, intermediate, weights + weights_elem_idx, buffers + buffers_elem_idx, LossDerivativeSumOfBlock);
        SumUp_LossDerivative<hiddendim, hiddendim, hiddendim_aligned, hiddendim_aligned>(LossDerivativeSumOfBlock, LossDerivativeSumALL + weights_elem_idx);
    }
    weights_elem_idx -= indim_aligned * hiddendim_aligned;
    buffers_elem_idx -= indim_aligned * BatchSize;
    buffers_elem_idx_blockbias = indim_aligned * ONEBATCH_SIZE * blockIdx.x;

    // 入力層の逆伝播
    MLP_Backward<indim, hiddendim, indim_aligned, hiddendim_aligned>(BatchSize, ActHid, intermediate, weights + weights_elem_idx, buffers + buffers_elem_idx, LossDerivativeSumOfBlock);
    SumUp_LossDerivative<indim, hiddendim, indim_aligned, hiddendim_aligned>(LossDerivativeSumOfBlock, LossDerivativeSumALL + weights_elem_idx);
    __syncthreads();
}

これもざっくり見ていきましょう.

template <const uint32_t indim, const uint32_t hiddendim, const uint32_t outdim, const uint32_t nHiddenLayer>
MFFM_DEVICE void Kernel_Debug_train_backward(const uint32_t BatchSize, Optimize Optim, Activation ActHid, Activation ActOut, const int epoch, __half* intermediate, __half* LossDerivativeSumOfBlock, __half* weights, const __half* buffers, float* LossDerivativeSumALL, float* AdditionalParam) {
...

indim ~ nHiddenLayer: 順伝播と同じ
BatchSize: バッチサイズ
Optimize構造体:

enum class Optimize {
    GD,
    Adam
};

これです.最適化関数を設定します.GDはGradient Descendant,Adamは名前の通りです.
ActHid, ActOut: 順伝播と同じ
epoch: 学習のイテレーション
intermediate: 順伝播と同じ
LossDerivativeSumOfBlock: コンセプト2で出てきた,(7): ブロック内部でのバッチによるパラメーターの勾配を記録する配列です.今回はshared memoryに載ってます.
weights,buffers: 順伝播と同じ
LossDerivativeSumALL: コンセプト2で出てきた,(6): バッチ全体によるパラメーターの勾配を記録する配列です.
AdditionalParam: 過去の実装では使用していた配列です.今回は使いません.というかいつか消します.

...
// 即値
constexpr uint32_t indim_aligned = next_multiple(indim, TENSOR_ROW);
constexpr uint32_t hiddendim_aligned = next_multiple(hiddendim, TENSOR_ROW);
constexpr uint32_t outdim_aligned = next_multiple(outdim, TENSOR_ROW);
constexpr uint32_t weightsize = indim_aligned * hiddendim_aligned + (nHiddenLayer - 1) * hiddendim_aligned * hiddendim_aligned + hiddendim_aligned * outdim_aligned;
const uint32_t buffersize = (indim_aligned + nHiddenLayer * hiddendim_aligned + outdim_aligned) * BatchSize;

uint32_t weights_elem_idx = weightsize - outdim_aligned * hiddendim_aligned;
uint32_t buffers_elem_idx = buffersize - hiddendim_aligned * BatchSize - outdim_aligned * BatchSize;
uint32_t buffers_elem_idx_blockbias = outdim_aligned * ONEBATCH_SIZE * blockIdx.x;

__syncthreads();
...

順伝播時にやってたことと同じです.

...
// 出力層の逆伝播
MLP_Backward<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(BatchSize, ActOut, intermediate, weights + weights_elem_idx,
    buffers + buffers_elem_idx, LossDerivativeSumOfBlock);
SumUp_LossDerivative<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(LossDerivativeSumOfBlock, LossDerivativeSumALL + weights_elem_idx);
__syncthreads();
...

MLP_Backward: 活性化層→全結合層の順番で逆伝播する関数です.後述します.
SumUp_LossDerivative: 各ブロックでは128バッチの誤差逆伝播によって得られたパラメーターの勾配を持っていますが,それを全スレッドで合計してあげる必要があります.それをする関数です.後述します.

隠れ層,入力層の逆伝播も出力層と同様に行います.

MLPの実装: 逆伝播 MLP_Backward

というわけで軽く逆伝播の処理を見ていきましょう.

template <const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void MLP_Backward(const uint32_t BatchSize, Activation activation, __half* __restrict__ intermediate, __half* __restrict__ weight, const __half* __restrict__ buffer, __half* __restrict__ LossDerivative) {
    using namespace nvcuda;

    // Y = WXを順伝播時に行った.(正確には Y_t = X_t * W_tを行った)
    // ここでのRow, Colとは転置前の行列の行と列の長さを表す
    constexpr int XRow_aligned = ONEBATCH_SIZE;
    constexpr int XCol_aligned = inDim_aligned;
    constexpr int WRow_aligned = inDim_aligned;
    constexpr int WCol_aligned = outDim_aligned;
    constexpr int YRow_aligned = ONEBATCH_SIZE;
    constexpr int YCol_aligned = outDim_aligned;

    constexpr int nBlock_XRow = XRow_aligned / TENSOR_ROW;
    constexpr int nBlock_XCol = XCol_aligned / TENSOR_ROW;
    constexpr int nBlock_WRow = WRow_aligned / TENSOR_ROW;
    constexpr int nBlock_WCol = WCol_aligned / TENSOR_ROW;
    constexpr int nBlock_YRow = YRow_aligned / TENSOR_ROW;
    constexpr int nBlock_YCol = YCol_aligned / TENSOR_ROW;

    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    constexpr int warp_index_required = nBlock_WRow;

    const __half* Buffer_MLPin = buffer + inDim_aligned * ONEBATCH_SIZE * bx;
    const __half* Buffer_MLPout = buffer + inDim_aligned * BatchSize + outDim_aligned * ONEBATCH_SIZE * bx;

    // Fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> dLdY_frag_for_dLdX;      // dLdoutputを指す
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> weights_frag[nBlock_WCol];
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> dLdX_frag[nBlock_XRow];               // dLdinputを指す

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::col_major> dLdY_frag_for_dLdW;      // dLdoutputを指す
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> buffer_frag[nBlock_XRow];  // MLPinを格納する       
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> dLdweights_frag[nBlock_WCol];            // dLdW

    // Activation
    // 32行のブロックごとに処理する
    if (ty * 32 + tx < ONEBATCH_SIZE) {
        Activation_Backward<__half>(activation, outDim, Buffer_MLPout + outDim_aligned * (ty * 32 + tx),
            intermediate + (outDim_aligned + SKEW) * (ty * 32 + tx),
            intermediate + (outDim_aligned + SKEW) * (ty * 32 + tx));
    }

    // weightをロード
#pragma unroll
    for (int i = 0; i < nBlock_WCol; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::load_matrix_sync(weights_frag[i], weight + TENSOR_ROW * ty + WRow_aligned * TENSOR_ROW * i, WRow_aligned);
        }
    }

    // MLPinをロード
#pragma unroll
    for (int i = 0; i < nBlock_XRow; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::load_matrix_sync(buffer_frag[i], Buffer_MLPin + TENSOR_ROW * ty + XCol_aligned * TENSOR_ROW * i, XCol_aligned);
        }
    }
    __syncthreads();

    // dLdX_t = dLdY_t * W
#pragma unroll
    for (int i = 0; i < nBlock_XRow; i++) {
        // 初期化
        wmma::fill_fragment(dLdX_frag[i], __float2half(0.0f));

#pragma unroll
        for (int j = 0; j < nBlock_YCol; j++) {
            // dLdoutputをTENSORにロード
            wmma::load_matrix_sync(dLdY_frag_for_dLdX, intermediate + TENSOR_ROW * j + (TENSOR_ROW * i) * (YCol_aligned + SKEW), YCol_aligned + SKEW);
            // matmal
            wmma::mma_sync(dLdX_frag[i], dLdY_frag_for_dLdX, weights_frag[j], dLdX_frag[i]);
        }
    }

    // dLdW = dLdY * X_t
#pragma unroll
    for (int i = 0; i < nBlock_WCol; i++) {
        // 初期化
        wmma::fill_fragment(dLdweights_frag[i], __float2half(0.0f));

#pragma unroll
        for (int j = 0; j < nBlock_XRow; j++) {
            // dLdoutputをTENSORにロード
            wmma::load_matrix_sync(dLdY_frag_for_dLdW, intermediate + TENSOR_ROW * i + TENSOR_ROW * (outDim_aligned + SKEW) * j, outDim_aligned + SKEW);
            // matmal
            wmma::mma_sync(dLdweights_frag[i], dLdY_frag_for_dLdW, buffer_frag[j], dLdweights_frag[i]);
        }
    }

    __syncthreads();

    // 次の入力へ記録
#pragma unroll
    for (int i = 0; i < nBlock_XRow; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::store_matrix_sync(intermediate + TENSOR_ROW * ty + i * TENSOR_ROW * (inDim_aligned + SKEW), dLdX_frag[i], inDim_aligned + SKEW, wmma::mem_row_major);
        }
    }
#pragma unroll
    for (int i = 0; i < nBlock_WCol; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::store_matrix_sync(LossDerivative + TENSOR_ROW * ty + i * TENSOR_ROW * (WRow_aligned + SKEW), dLdweights_frag[i], WRow_aligned + SKEW, wmma::mem_row_major);
        }
    }
    __syncthreads();
}

部分部分でみていきましょう.

...
template <const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
...

念のためここでも定義しておきます.
inDim: 順伝播時のMLPの入力次元
outDim: 順伝播時のMLPの出力次元
つまり,"inとoutの定義"は順伝播を基準とします.

...
MFFM_DEVICE void MLP_Backward(const uint32_t BatchSize, Activation activation, __half* __restrict__ intermediate, __half* __restrict__ weight, const __half* __restrict__ buffer, __half* __restrict__ LossDerivative) {
...

BatchSize: バッチサイズ
activation: このMLPの活性化層の種類
intermediate: 略
weight: このMLPの全結合層のパラメーター
buffer: このMLPの順伝播時の入力側のintermediateの先頭を指すポインタ
LossDerivative: このMLPの全結合層の,ブロックレベルの勾配を記録する配列(の先頭のポインタ)

const __half* Buffer_MLPin = buffer + inDim_aligned * ONEBATCH_SIZE * bx;
const __half* Buffer_MLPout = buffer + inDim_aligned * BatchSize + outDim_aligned * ONEBATCH_SIZE * bx;

bufferはこのMLPの順伝播時の入力側のintermediateの先頭を指すポインタなので,入力側のデータ数だけポインタをずらすことで出力側のintermediateの先頭を指すポインタを得られます.順伝播時にもやりましたが,bufferにはすべてのバッチの記録情報が入っています.なので今回でもblockIdxに応じてインデックス調整を入れてあげます.

処理的には先に活性化層がやってきます.先に活性化層を確認しましょう.

// Activation
// 32行のブロックごとに処理する
if (ty * 32 + tx < ONEBATCH_SIZE) {
        Activation_Backward<__half>(activation, outDim, Buffer_MLPout + outDim_aligned * (ty * 32 + tx),
        intermediate + (outDim_aligned + SKEW) * (ty * 32 + tx),
        intermediate + (outDim_aligned + SKEW) * (ty * 32 + tx));
    }

Activation_Backwardが活性化層の逆伝播の処理です.この関数ではブロック内にある128バッチの誤差情報を,128並列で処理します.

/*
 * Activation Back Propatation
 * DIM次元の特徴ベクトルがONEBATCH_SIZE個並べられている.
 * 各ワープはそれらのうち32個を担当する
 */
template <typename T>
MFFM_DEVICE void Activation_Backward(Activation activation, const int DIM, const __half* Activated, __half* dLdAfterFC, __half* dLdActivated) {
    __syncthreads();
    switch (activation) {
    case Activation::ReLU:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = ((T)Activated[i] > (T)0.0f) ? dLdActivated[i] : (__half)0.0f;
        }
        break;
    case Activation::LeakyReLU:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = ((T)Activated[i] > (T)0.0f) ? dLdActivated[i] : (__half)0.05f * dLdActivated[i];
        }
        break;
    case Activation::Sigmoid:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = (T)dLdActivated[i] * ((T)1.0 - Activated[i]) * Activated[i];
        }
        break;
    default:
        printf("Invalid Activation Type\n");
        break;
    }
}
template <typename T>
MFFM_DEVICE void Activation_Backward(Activation activation, const int DIM, const __half* Activated, __half* dLdAfterFC, __half* dLdActivated) {
...

activation: 略
DIM: その時点での,align処理をしていない生の次元数です.
Activated: 順伝播時の活性化層の出力です.bufferから読み出します.
dLdAfterFC: 全結合層の出力に流していく誤差情報です.(FC: Fully Connected)
dLdActivated: 活性化層の出力に流れ込んできた誤差情報です.

...
switch (activation) {
    case Activation::ReLU:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = ((T)Activated[i] > (T)0.0f) ? dLdActivated[i] : (__half)0.0f;
        }
        break;
    case Activation::LeakyReLU:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = ((T)Activated[i] > (T)0.0f) ? dLdActivated[i] : (__half)0.05f * dLdActivated[i];
        }
        break;
    case Activation::Sigmoid:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = (T)dLdActivated[i] * ((T)1.0 - Activated[i]) * Activated[i];
        }
        break;
    default:
        printf("Invalid Activation Type\n");
        break;
    }
...

順伝播時にはfragment_tとして処理中のデータが流れていましたが,今回は単純なhalf配列として扱います.この配列において,各スレッドが扱う要素のうち最初のDIM個以降はゼロパディングされている領域なので処理しません.順伝播同様,各活性化層の処理の説明はしません.

では全結合層の逆伝播の処理を見ていきます.ここで処理を確認しておきましょう.

 \displaystyle
\begin{align}
\dfrac{\partial L}{\partial X^T} &= \dfrac{\partial L}{\partial Y^T}W \\
\dfrac{\partial L}{\partial W} &= \dfrac{\partial L}{\partial Y}X^T
\end{align}

順伝播時と同じように行列を16x16の小行列に分割して計算していきます.図で確認しましょう.

 \dfrac{\partial L}{\partial X^T} = \dfrac{\partial L}{\partial Y^T}W

この計算に必要なものは WとdLdYです.データの向きに注目してfragmentにロードします.

...

// Fragments
wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> dLdY_frag_for_dLdX;      // dLdoutputを指す
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> weights_frag[nBlock_WCol];
wmma::fragment<wmma::accumulator, 16, 16, 16, __half> dLdX_frag[nBlock_XRow];               // dLdinputを指す

...

// weightをロード
#pragma unroll
for (int i = 0; i < nBlock_WCol; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::load_matrix_sync(weights_frag[i], weight + TENSOR_ROW * ty + WRow_aligned * TENSOR_ROW * i, WRow_aligned);
    }
}

...

// dLdX_t = dLdY_t * W
#pragma unroll
for (int i = 0; i < nBlock_XRow; i++) {
    // 初期化
    wmma::fill_fragment(dLdX_frag[i], __float2half(0.0f));

#pragma unroll
    for (int j = 0; j < nBlock_YCol; j++) {
        // dLdoutputをTENSORにロード
        wmma::load_matrix_sync(dLdY_frag_for_dLdX, intermediate + TENSOR_ROW * j + (TENSOR_ROW * i) * (YCol_aligned + SKEW), YCol_aligned + SKEW);
        // matmal
        wmma::mma_sync(dLdX_frag[i], dLdY_frag_for_dLdX, weights_frag[j], dLdX_frag[i]);
    }
}

...

// 次の入力へ記録
#pragma unroll
for (int i = 0; i < nBlock_XRow; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::store_matrix_sync(intermediate + TENSOR_ROW * ty + i * TENSOR_ROW * (inDim_aligned + SKEW), dLdX_frag[i], inDim_aligned + SKEW, wmma::mem_row_major);
    }
}
...

図中に表示しているデータの配置方向(黒い矢印)に注意してrow_majorかcol_majorかを確認しましょう.順伝播と同様にして,最初にweightをすべてロードしておきます.また,この処理においてはこれ以上weightをロードする必要はありません.やはり1パラメーターにつき1回のロードですみます.ここまでこれはやることは順伝播と同じなので,あとは図を参考にしてfragmentの初期化に使用するintermediateのインデックスを丁寧に設定してあげてください.

 \dfrac{\partial L}{\partial W} = \dfrac{\partial L}{\partial Y}X^T

この計算に必要なものはdLdYとXです.データの向きに注目してfragmentにロードします.ここで,順伝播時にbufferにはintermediateの1バッチの次元を,本来の次元に戻さずに16の倍数次元のまま保存しました.これは,ここでfragmentにロードする際に特にレイアウトの整形作業等をすることなく直接ロードすることが出来ます.メモリ的には少し無駄ですが,こうした方が実装しやすいのと速そうなのでこうしました.では,コードを見ましょう.

...
wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::col_major> dLdY_frag_for_dLdW;      // dLdoutputを指す
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> buffer_frag[nBlock_XRow];  // MLPinを格納する       
wmma::fragment<wmma::accumulator, 16, 16, 16, __half> dLdweights_frag[nBlock_WCol];            // dLdW

...

// MLPinをロード
#pragma unroll
for (int i = 0; i < nBlock_XRow; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::load_matrix_sync(buffer_frag[i], Buffer_MLPin + TENSOR_ROW * ty + XCol_aligned * TENSOR_ROW * i, XCol_aligned);
    }
}
__syncthreads();

...

// dLdW = dLdY * X_t
#pragma unroll
for (int i = 0; i < nBlock_WCol; i++) {
    // 初期化
    wmma::fill_fragment(dLdweights_frag[i], __float2half(0.0f));

#pragma unroll
    for (int j = 0; j < nBlock_XRow; j++) {
        // dLdoutputをTENSORにロード
        wmma::load_matrix_sync(dLdY_frag_for_dLdW, intermediate + TENSOR_ROW * i + TENSOR_ROW * (outDim_aligned + SKEW) * j, outDim_aligned + SKEW);
        // matmal
        wmma::mma_sync(dLdweights_frag[i], dLdY_frag_for_dLdW, buffer_frag[j], dLdweights_frag[i]);
    }
}

...

#pragma unroll
for (int i = 0; i < nBlock_WCol; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::store_matrix_sync(LossDerivative + TENSOR_ROW * ty + i * TENSOR_ROW * (WRow_aligned + SKEW), dLdweights_frag[i], WRow_aligned + SKEW, wmma::mem_row_major);
    }
}

図中の黒い矢印(データの配置方向)に注意してfragmentの設定を丁寧にしましょう.メモリのロードとしてはbuffer(global memory)からのロードが1回,dLdYのロードが1回(2回目)あります.dLdYに関してはdLdXの計算時にロードしていたじゃないかと思うかもしれませんが,転置しているので構造が異なります.転置しなければしなかったで今度はmatrix_bの方に行ってしまうので,やはりダメです.ここ悔しいですね.
細かい処理の部分は図を参考にしてください.

勾配の集約

さて,各ブロックにおける全結合層のパラメーターの勾配の計算は終わりました.これは128バッチの伝搬によって得られた勾配なので,全ブロック(即ち全スレッド)における勾配をaccumulateする必要があります.実はここは未だに実装が確定しきっていない箇所です.具体的な話は最後にして,とりあえず現在の私の実装を説明します.

///////////////////////////////// SUM UP THE LOSS DERIVATIVES ////////////////////////////////////////////////////////////
template<const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void SumUp_LossDerivative(__half* __restrict__ LossDerivativeSumOfBlock, float* LossDerivativeSumALL) {
    int bx = blockIdx.x;
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int nThreads = 32 * blockDim.y;
    int threadId = 32 * ty + tx;

    const int WRow = inDim_aligned;
    const int WeightSize_this_layer = inDim_aligned * outDim_aligned;

#pragma unroll
    for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
        const int elem_idx = (nThreads * bx + i) % WeightSize_this_layer;
        const int row_idx = elem_idx % WRow;
        const int col_idx = elem_idx / WRow;
        const int elem_idx_shmem = (WRow + SKEW) * col_idx + row_idx;

        if (row_idx >= inDim || col_idx >= outDim) {
            continue;
        }

        atomicAdd(&LossDerivativeSumALL[elem_idx], (float)LossDerivativeSumOfBlock[elem_idx_shmem]);

    }
    __syncthreads();
}

部分部分で見ていきましょう.

template<const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void SumUp_LossDerivative(__half* __restrict__ LossDerivativeSumOfBlock, float* LossDerivativeSumALL) {
...

inDim, outDim: その全結合層の入力次元と出力次元です(16の倍数にはしていない,生の次元) LossDerivativeSumOfBlock: コンセプト2で出てきた,(7): ブロック内部でのバッチによるパラメーターの勾配を記録する配列です.今回はshared memoryに載ってます. LossDerivativeSumALL: コンセプト2で出てきた,(6): バッチ全体によるパラメーターの勾配を記録する配列です.

int nThreads = 32 * blockDim.y;
int threadId = 32 * ty + tx;

const int WRow = inDim_aligned;
const int WeightSize_this_layer = inDim_aligned * outDim_aligned;

nThreads: 現在のCUDAカーネルで立っている,ブロック内部でのスレッド数(i.e. BlockSize)
threadId: ブロックレベルでのスレッド番号(全スレッド,すなわちグローバルなスレッド番号ではない)
WeightSize_this_layer: 現在の全結合層のパラメーター数です.ただし,zero-paddingした箇所も含めています.

#pragma unroll
for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
    const int elem_idx = (nThreads * bx + i) % WeightSize_this_layer;
    const int row_idx = elem_idx % WRow;
    const int col_idx = elem_idx / WRow;
    const int elem_idx_shmem = (WRow + SKEW) * col_idx + row_idx;

    if (row_idx >= inDim || col_idx >= outDim) {
        continue;
    }

    atomicAdd(&LossDerivativeSumALL[elem_idx], (float)LossDerivativeSumOfBlock[elem_idx_shmem]);
}
__syncthreads();

まあ何してるのって話ではありますよね.目的は単純で,全パラメーターの勾配をglobal memoryの領域にatomicAddしているだけです.もっと細かく見ましょう.

for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
...

立っているブロック内部のスレッドをなるべく全部使いたいです.なので,スレッド番号に対してユニークに初期の仕事(global memoryに値を格納するタスク)を振り分けます.その仕事が終われば,立っているスレッド数分だけストライドをまたいで次の仕事に向かいます.仕事がない(仕事の対象,すなわち格納するWeightのインデックスがout-of-range)場合は何もしません.もっと端的に書くと,スレッドIDがKのスレッドは

 
index \equiv K \mod nThreads

を満たすweightのインデックスの勾配をglobal memoryに格納します.

......だけではありません.次を見ましょう.

...
const int elem_idx = (nThreads * bx + i) % WeightSize_this_layer;
...

なんでbx(blockIdx)があるんだと.これは正直有効なのかは分からないのですが,blockIdxに対してもバイアスを与えています.これは何かというと,global memoryに勾配をaccumulateする際にはatomicAdd()を使用します.そのため,なるべく異なるブロックは異なるインデックスを参照してほしいものです.そのために,ブロックごとに「ブロック数が少ない場合はユニーク」なインデックスとなるように仕事を振り分けています.ただ,そんなに小さいブロック数になっていることは稀なので効果は分からないです.

...
const int row_idx = elem_idx % WRow;
const int col_idx = elem_idx / WRow;
const int elem_idx_shmem = (WRow + SKEW) * col_idx + row_idx;

if (row_idx >= inDim || col_idx >= outDim) {
    continue;
}

atomicAdd(&LossDerivativeSumALL[elem_idx], (float)LossDerivativeSumOfBlock[elem_idx_shmem]);
...

さて,計算したweightの勾配のインデックスがzero-paddingされている場所であれば飛ばします.そして,global memory上のメモリの位置をちゃんと計算して,atomicAddを行います.

逆伝播のglobal memoryの関わるメモリ読み書きの確認

NNの出力誤差をshared memoryにロード: 1回
MLPの全結合層のパラメーターをロード: 1パラメーターにつき1回
MLPの全結合層の勾配の集約(書き込み): 不明かつatomic
MLPのbufferをロード: およそ層の数だけ
NNの入力誤差を保存(不要な場合はしない): 1回
(もしかしたら他にもあるかも)
全体としてかなり抑えられてはいますが,やはり勾配の集約とbufferのロードがかなりメモリのやり取りを起こしています.実際,推論のみの処理と比べると順伝播,逆伝播ともに数倍から数十倍遅くなります.

最適化処理

実を言うと特に書くことはないです.単純に大量のスレッドを立てて,各スレッドが対応するパラメーターを最適化するだけです(勾配の集約と同じイメージです).実装は載せておきますが解説は省略いたします.

///////////////////////////////// OPTIMIZATION IMPLEMENTATION //////////////////////////////////////////////////////////////////////
template<const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void Optimization(bool willOptimizeInsideThisKernel, const uint32_t BatchSize, Optimize optimize, __half* __restrict__ weight, float* __restrict__ dLdW, int epoch, float* __restrict__ AdditionalParam) {
    int bx = blockIdx.x;
    int tx = threadIdx.x;
    int ty = threadIdx.y;

    const int nThreads = 32 * blockDim.y * gridDim.x;
    const int threadId = blockIdx.x * 32 * blockDim.y + 32 * ty + tx;
    const int WeightSize_this_layer = inDim * outDim;

    __syncthreads();

#pragma unroll
    for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
        // サイズ調整時にゼロパディングした箇所は更新しない
        const int Rowidx = i % inDim;
        const int Colidx = i / inDim;
        const int idx = inDim_aligned * Colidx + Rowidx;

        if (!isfinite((float)dLdW[idx])) {
            dLdW[idx] = 0.0f;
            continue;
        }

        dLdW[idx] = dLdW[idx] / (float)BatchSize;

        switch (optimize) {
        case(Optimize::GD):
            weight[idx] = weight[idx] - (__half)LEARNINGRATE * (__half)dLdW[idx];
            break;
        case(Optimize::Adam):
            if (!AdamOptimize(AdditionalParam[2 * idx], AdditionalParam[2 * idx + 1], dLdW[idx], weight[idx], epoch)) {
                printf("%d %f %f \n", idx, (float)AdditionalParam[2 * idx], (float)AdditionalParam[2 * idx + 1]);
            }
            break;
        default:
            printf("Invalid Optimization Type\n");
            break;
        }
        dLdW[idx] = (__half)0.0f;
    }
    __syncthreads();
}

Adam Optimizer:

template<typename T>
MFFM_DEVICE bool AdamOptimize(float& m, float& v, float& dLdX, T& X, int t) {

    m = ADAM_BETA1 * m + (1.0f - ADAM_BETA1) * dLdX;
    v = ADAM_BETA2 * v + (1.0f - ADAM_BETA2) * dLdX * dLdX;

    float m_hat = m / (1.0f - powf(ADAM_BETA1, t));
    float v_hat = v / (1.0f - powf(ADAM_BETA2, t));

    float eps = 1e-4f;

    if (sqrtf(v_hat) + eps == 0.0f) { // ?!?!?!?!
        dLdX = 0.0f;
        return true; 
    }
    X = X - ((T)LEARNINGRATE * (T)(m_hat / (sqrtf(v_hat) + eps)));

    dLdX = 0.0f;
    return true;
}

1次元の関数近似

さて,これまでに実装した関数を組み合わせることでNNのモジュールが完成します.私の実装ではどのような挙動をするのかを確認します.まずは学習の精度から確認しましょう.

以下の関数を近似しましょう.


f(x) = 0.5 + 0.4 \sin(8x) + 0.1 \cos(20x)  \;\;\; (0 \leq x \leq 1)


つまりfは1次元の実数を1次元の実数に移します.……特に意味はないです.これを次のネットワーク構造で近似しました.
入力層の次元: 1次元
隠れ層の次元; 64次元
出力層の次元: 1次元
隠れ層の数: 3
誤差関数: Hubor Loss(閾値: 0.05)
隠れ層の活性化関数: Sigmoid
出力層の活性化関数: Sigmoid
全結合層のパラメーター初期分布: He
最適化関数: AdamもしくはGradient Descendant 学習率: Adamの場合は0.02,Gradient Descendantの場合は0.03
Adamのパラメーター:  \beta_1 = 0.9, \beta_2 = 0.99

このネットワークの最適化を行い,イテレーション毎の推定値と真値の平均二乗誤差をグラフにプロットしました.

Adam君は頑張って収束してくれましたが,Gradient Descendant君はダメでした.次に,近似の様子を見ていきましょう.

y_hatは推定値,yは真値を意味します.エンコーダーを通さない1次元を入力としているので上出来だと思います.最後に,肝心の処理速度を確認しましょう.

処理時間の計測にはC++のchronoを使用しました.計測方法としては,学習処理のforループの手前にてchronoによる時間計測を開始し,10000回のイテレーションが終わった直後にchronoで時間計測を完了します.その後,計測結果の時間を10000(イテレーション回数)で割り,1イテレーションあたりの所要時間の平均値として記録しました.最適化関数はAdamで,活性化層はすべてSigmoidです.実行GPUはGeForce RTX3080 10GBです.

隠れ層の数による処理時間の変化を示しています.横軸は対数目盛になっております.当然ですが隠れ層の数にある程度比例しているのが分かります.次元数の違いによっても,比例関係があります.

バッチサイズによる処理時間の変化を示しています.両対数グラフです.隠れ層の数は全ての計測において4としました.バッチサイズ8192までは特に変化が見られないものの,それ以降はバッチサイズに比例した時間がかかっているように見えます.

推論処理だけを行った際の処理時間を確認します.

隠れ層の数による処理時間の変化を示しています.横軸は対数目盛になっております.

バッチサイズによる処理時間の変化を示しています.両対数グラフです.隠れ層の数は全ての計測において4としました.横軸の計測範囲が学習処理の処理時間のグラフと異なることに注意してください.

両者ともに,全体としての傾向は学習処理を入れている場合と同じですが,やはりバッファへの保存や逆伝播の処理が無いという点で,推論処理のみの方が学習処理と比較して数倍から十数倍速いことが分かります.バッチサイズが8192を超えたのちには線形に増加しているのはスループットの上限が目に見えているのでしょうかね?

最後に

これまで実装してきたもので,MLPの順伝播,逆伝播,最適化が行えます.つまりNeRFにおける機械学習(ここではMLP)処理の根幹は出来上がりました.2編ではNeRFを実装するうえで不可欠となっているエンコーダーについて実装を提示し,処理を説明していこうと思います.それの後,3編では本格的にNeRFの自分の実装を説明します.
これ以降は実装中に私が感じた疑問点を述べる領域です.


今回の実装では順伝播処理と逆伝播処理が同じカーネルで実行できました.しかし,逆伝播と最適化は同じカーネルで実行することが出来ませんでした.実は,最適化関数がGradient Descendantであれば可能だったのですが,Adamを使用すると厳しいと思ってます.この理由としては,Adamを実行する際にはmやvといった値を計算します.すなわち,実行時にはすべてのバッチの伝搬と誤差集約が終わっている必要があります.ですが,CUDAの実行として,デバイスレベルでの同期はカーネル外部でないと出来ません(そもそもとして,すべてのブロックが同時には動きません).それゆえに,デバイスレベルでの同期をとるために一度カーネルの外部に出ることになります.これに対しては何か解決策的なもの取れないんでしょうかね.現在実装しているパストレの自作レンダラに機械学習のモジュールを組み込んでいるのですが,やはりカーネルを飛び出すとなれば非常にメモリ的効率が低下し,また,実装としてもかなり面倒くさいものになりますね……

実はこれが誤差集約の実装がまだ揺れている原因です.私がこのコードを書いている時には最適化処理まで一つのカーネルで貫く予定でした.しかし,それが厳しいと分かった以上,thrustのreductionをする方が速いのでは?という気持ちもあります.まだ実装していないので何とも言えないといえばそうですが.

参考資料

Real-time Neural Radiance Caching for Path Tracing
tiny-cuda-nnのgitリポジトリ
Tensorコアを使ってみた
CUDAでShared memoryを48KiB以上使うには