CUDA C++でNeRFをほぼ0から実装してみた(Part3/3): NeRF編

NeRF編: 概要

これまでに説明したMLPエンコーダーに加えて,ボリュームレンダリング(RGB)とその誤差逆伝播,さらにはOccupancy Gridに関する私の実装を説明し,最終的にどのようにNeural Radiance Fieldsが形成されていくかを確認します.

NeRF編: はじめに

Part1, Part2では下準備としてMLPとMultiresolutioin Hash Encoding, Spherical Harmonic Encodingを実装しました.Part3では本格的にNeRFの実装に入っていきます.Part2の最後に画像の近似を行いましたが,あれは2次元平面内部におけるRGB色(RGB場とでも言いましょうか)を近似していくものでした.2次元画像近似は非常に簡単に行えましたが,3次元となった場合はなかなかに複雑化します.それは3次元の状態をそもそも可視化する手法が一筋縄ではいかないためです.色々な手法が提案されてきましたが,その1つにNeRF [Mildenhall et al, 2020]があります.これは煙などの媒質の表現に用いられるボリュームレンダリングを利用したもので,その媒質に関するパラメーターを最適化することにより,空間内部の媒質の状態を近似するというものです.2022年にNVIDIAの発表したInstant NeRFにおいてはOccupancy Gridというものが導入され,それによってボリュームレンダリングのサンプリング効率の向上が実現されました.ではそれぞれの実装を説明していきます.(空間内部の「状態」を近似と書いているのは,空間内部に存在する「物体表面の再構築」と区別するためですが,Occupancy Gridを使っている以上,結局NeRFも後者を行っているのですかね......)

念のため......

内容には注意はしておりますが,記事に誤り等があれば指摘していただけると幸いです.

NeRFの手続き

NeRFと聞くと身構えてしまいそうですが,実はそこまで複雑なことはしていません.まずはざっくりと雰囲気を書くと,
(0): カメラと画像のセットを用意する
(1): カメラから画像のピクセルに対応する方向にレイ(光線)を飛ばす
(2): レイの経路上の点をサンプリング
(3): サンプル点の座標,そしてレイの向きをニューラルネットワーク(NN)に入力
(4): NNを実行し,サンプル点における「色」と「密度」を計算
(5): 得られたサンプル点の情報からボリュームレンダリング方程式を計算
(6): 計算結果と対応するピクセルの色を比較し,誤差を計算
(7): ボリュームレンダリングの処理を逆方向に行う
(8): NNを誤差逆伝播
(9): NN最適化
(10): Occupancy Gridの最適化
となっております.こういう書き方をしているので順番から説明すべきではありますが,実装の説明をする前にボリュームレンダリングについて軽く触れておきます.

ボリュームレンダリング

NeRFではRGBレンダリングとしてボリュームレンダリングを行います.といってもパストレーシングでやるようなランダムウォーク的な経路追跡をするわけではなく, 単純にレイ(半直線)上のランダムにサンプリングされたサンプル点における放射輝度の変化を追うものです(また,距離関数を使用したレイマーチングではありません.).本来はちゃんと導出するのが良いですし,そもそも光線としての光(幾何光学)の妥当性を書くべきだとは思うのですが,今回はそれを他の記事に任せてフロントエンドだけ書きます.

今回は次の現象を扱います.
・媒質の吸収による放射輝度の減衰
・媒質の散乱による放射輝度の減衰
・媒質の発光(emission)や散乱(in-scattering)による放射輝度の増加
レイ上のある異なる2点を,カメラから近い順にt_near,t_far とします.この時, t_near, t_far における媒質によるカメラに入射するレイの放射輝度への関与は


\begin{align}
&Radiance(\mathbf{ray}) = \int_{t_{near}}^{t_{far}} T(t) \sigma(\mathbf{r}(t)) C(\mathbf{r}(t), \mathbf{dir})dt \\ \\
&T(t) = \exp\left(-{\int_{t_{near}}^{t} \sigma({\mathbf{r}(s)}) ds}\right)

\end{align}


ここで, tはカメラからレイ上のある点までの距離で, \mathbf{r}(t)はその点の座標,つまりレイの原点と向きをそれぞれ \mathbf{org}, \mathbf{dir}として, \mathbf{r}(t) = \mathbf{org} + t * \mathbf{dir}です.そして, \sigma(\mathbf{r}(t))はその点における媒質の「密度」で, C(\mathbf{r}(t), \mathbf{dir})はその点における媒質の「色」です.より正確に書くと,「密度」は光学的な消散係数にあたり,「色」は媒質内の粒子における散乱(in-scattering)や発光(emission)による放射輝度への増加に関わるパラメーターにあたります.図に描くとこんな感じのイメージです.

カメラ側から見た時,媒質の「密度」が十分大であれば媒質の「表面」しか見ることが出来ません.これは「表面」より向こう側の放射輝度によるカメラに入射するレイへの寄与が0であるためです.逆に,「密度」が非常に小さい場合,媒質の向こう側が透過して見えます.イイ感じにパラメータを設定してあげると現実世界の空間をある程度近似できそうだというのは何となく想像できます.

計算機でボリュームレンダリングを行う

先ほど示した積分を計算機上で解くためには,理想はレイを無限に分割してやるのが良いのですが,そんなことは出来ないため,有限のサンプル点を使用して近似的に積分を行います.

先ほどの式を離散的に書き直すと図の中に示す式になります.理論的な計算は省略します.["Optical Models for Direct Volume Rendering", Nelson Max, 1995]
この Tは先ほどの積分の式と同じく「どれだけその点におけるレイの放射輝度がカメラに入るレイの放射輝度に寄与しているか」を表します.さて,このように有限のサンプル点をサンプリングしてあげる必要があります.レンダリングの詳しい実装は後ほど行います.

NeRFの実装: (0): カメラと画像のセットを用意する

さて,やっていきましょう.ただロードすればいいと言われればそうなのですが,学習データと教師データのセットを意識してデータを保持してあげると取り扱いが容易になります.今回の実装においては,最初に画像と対応するカメラの姿勢を一気に全部ホスト側のメモリ上にロードします.そして,それをもとにして,「(画像ID, ピクセル座標)- ピクセルの色」を入力-教師のペアとして構造体にしておきます.次の構造体を書きました.

// 各スレッドではこの構造体1つにしかアクセスしないのでAoSの方が良いと考えた
// [31:21]: PixIndexX
// [20:10]: PixIndexY
// [9:0]  : ImageID
struct PixelInfo {
    //int ImageID;
    //int PixIndexX;
    //int PixIndexY;
    unsigned int PixInfo = 0; 
    Vec3h Color;

    MNPT_HOST_DEVICE PixelInfo(unsigned int PixInfo = 0) : PixInfo(PixInfo), Color() {}

    MFFM_HOST void SetUp(unsigned int ImID, unsigned int IndexX, unsigned int IndexY, Vec3h PixColor) {
        if (ImID >= 1024) {
            printf("Too large ImageID was set to PixelInfo\n");
            exit(1);
        }
        if (IndexX >= 2048) {
            printf("Too large ImageWidth was set to PixelInfo\n");
            exit(1);
        }
        if (IndexY >= 2048) {
            printf("Too large ImageHeight was set to PixelInfo\n");
            exit(1);
        }

        PixInfo = 0;
        PixInfo = PixInfo | (IndexX << 21) | (IndexY << 10) | ImID;
        Color = PixColor;
    }

    MFFM_HOST void SetUp(int ImID, int IndexX, int IndexY, __half* PixColor) {
        if (ImID >= 1024) {
            printf("Too large ImageID was set to PixelInfo\n");
            exit(1);
        }
        if (IndexX >= 2048) {
            printf("Too large ImageWidth was set to PixelInfo\n");
            exit(1);
        }
        if (IndexY >= 2048) {
            printf("Too large ImageHeight was set to PixelInfo\n");
            exit(1);
        }

        PixInfo = 0;
        PixInfo = PixInfo | (IndexX << 21) | (IndexY << 10) | ImID;
        Color.from_half(PixColor);
    }
    MFFM_HOST_DEVICE unsigned int PixIndexX() {
        unsigned int ret = (PixInfo >> 21); // [31:21]
        return ret;
    }
    MFFM_HOST_DEVICE unsigned int PixIndexY() {
        unsigned int ret = ((PixInfo & 0x001FFC00) >> 10); // [20:10]
        return ret;
    }
    MFFM_HOST_DEVICE unsigned int ImageID() {
        unsigned int ret = (PixInfo & (0x000003FF)); // [9:0]
        return ret;
    }
};

ただ単純に入れているだけなのですが,メモリ節約のために32bitの変数を切り分けて情報を保存しておきました.画像IDには10bit,画像上のピクセル座標にはそれぞれ11bitを使用しています.また,Vec3hという,半精度小数点を3個内蔵するベクトル構造体に色の情報を保存しておきます.Vec3やVec3hに関する具体的な説明はのちに行います.画像IDからカメラの姿勢にアクセスすることができ,そしてピクセル座標からカメラから飛ばすレイ(後述します)を計算でき,そして色の情報から教師データにアクセスできる,というビジョンです.ちなみにレイ-ピクセルの色というペアだとダメなのかという問いがあり得ますが,GPU上のメモリ(10GB)に載りませんでした.
このようにしてロードしたデータをGPU上に送り,学習時にはシャッフルを施して多様なレイを学習時に与えることを考えます.このあたりの処理はthrustライブラリの力を借りました.ここが「ほぼ0から実装してみた」になっている要因の一つです.

...
thrust::default_random_engine g;
thrust::shuffle(D_thrust_TargetPixelInfo.begin(), D_thrust_TargetPixelInfo.end(), g);
...
    if (IndexOfPixelInfo + nMaxPixelPerBatch > D_thrust_TargetPixelInfo.size()) {
        thrust::shuffle(D_thrust_TargetPixelInfo.begin(), D_thrust_TargetPixelInfo.end(), g);
        IndexOfPixelInfo = 0;
    }
...

全部のピクセルを1イテレーションで全て入力するなんてことはしません.nMaxPixelPerBatchで一回のイテレーションで使用するピクセル数を決めておきます.そして,前シャッフルされてから現在までに使用されているピクセル数をIndexOfPixelInfoで保持しておきます.このif文がおこなっていることは,前回シャッフルしてから初めて,未使用のピクセル数がnMaxPixelPerBatch未満であれば再びシャッフルを行う,ということです.つまりはシャッフルしなおしているだけです.図を置いておきます.

NeRFの実装: (1): カメラから画像のピクセルに対応する方向にレイ(光線)を飛ばす

私たちは光を直線として扱う,すなわちレイ(光線)という概念に馴染みきっております.カメラからレイを飛ばすと言いますが,実際にはカメラに対応する方向から入射してくる光の経路を追っているという方が正しいでしょう.結局何が言いたいかというと,カメラから目的のピクセルの場所に半直線をのばし,その半直線を光路としてカメラに入射してくるレイの放射輝度を求めたいです.なので,その半直線がいかなるものであるかを計算する必要があります.今回は単純にピンホールカメラを考えましょう.

まあCGソフトでよく見る感じのカメラではないでしょうか.基本姿勢としてはカメラの原点(CamOrg),カメラの向いている方向(CamDir),カメラの上向き(CamUp)となります.そして,原点からCamDir方向にCamToScreenDistだけ離れた場所にスクリーンがあると考えてください.スクリーン上の点と,その点を通ってカメラの原点に入り込むレイは一対一に対応します(レイが直線であるため).このようなレイの持つ放射輝度がカメラの写す画となります(細かい話は置いておきます).では,そのようなレイを求めます.図のようにカメラの原点から出てスクリーン上の座標(IndexX, IndexY)を通る半直線(橙の点線)は,まずスクリーンの中央に伸ばしたベクトル(赤色の矢印)にスクリーン上でのX方向(緑色の矢印),Y方向のベクトル(青色の矢印)を足し合わせてあげることで求まります.
では実装するうえで便利な構造体を置いておきましょう.

Vec3構造体

名前の通り,浮動小数点の3次元ベクトルを処理するための構造体です.

struct Vec3 {
    float x, y, z;
    float align;
    KGYK_HOST_DEVICE Vec3(const float x = 0.0f, const float y = 0.0f, const float z = 0.0f, const float align = 0.0f) : x(x), y(y), z(z), align(align) {}

    KGYK_HOST_DEVICE Vec3 operator+(const Vec3& b) const {
        return { x + b.x, y + b.y, z + b.z };
    }
    KGYK_HOST_DEVICE Vec3 operator-(const Vec3& b) const {
        return { x - b.x, y - b.y, z - b.z };
    }
    KGYK_HOST_DEVICE Vec3 operator*(const float& b) const {
        return { x * b, y * b, z * b };
    }
    KGYK_HOST_DEVICE Vec3 operator/(const float& b) const {
        return { x / b, y / b, z / b };
    }
    KGYK_HOST_DEVICE float length_squared() const {
        return x * x + y * y + z * z;
    }
    KGYK_HOST_DEVICE float length() const {
        return sqrtf(length_squared());
    }
    // Vec3 to float*
    KGYK_HOST_DEVICE void to_float(float* Dst) {
        Dst[0] = x;
        Dst[1] = y;
        Dst[2] = z;
    }
    KGYK_HOST_DEVICE void from_float(float* Src) {
        x = Src[0];
        y = Src[1];
        z = Src[2];
    }
    KGYK_HOST_DEVICE float operator[](size_t idx) {
        return *((float*)this + idx);
    }
    KGYK_HOST_DEVICE float at(size_t idx) {
        if (idx == 3) printf("WARNING: Accessing to Padding Area of Vec3 in Vec3::at()\n");
        if (idx > 3) printf("ERROR: Out-of-range access in Vec3::at()\n");
        return *((float*)this + idx);
    }
};

KGYK_HOST_DEVICE Vec3 operator*(float t, const Vec3 v) {
    return v * t;
}
KGYK_HOST_DEVICE bool operator==(const Vec3 a, const Vec3 b) {
    return (a.x == b.x && a.y == b.y && a.z == b.z);
}
KGYK_HOST_DEVICE bool operator!=(const Vec3 a, const Vec3 b) {
    return (a.x != b.x || a.y != b.y || a.z != b.z);
}
// 長さ1に正規化する
KGYK_HOST_DEVICE inline Vec3 normalize(const Vec3& b, const char* text = "") {
    if (b.length() == 0) {
        printf("Vec3 Warning: 0-size vector normalization. %s\n", text);
        return { 10000, 0, 0 };
    }
    return Vec3(b / b.length());
}
// AABB内部のベクトルbについて,そのAABBを[0,1]^3と正規化したときのベクトルbを求める.
KGYK_HOST_DEVICE inline Vec3 normalize_by_range(const Vec3& b, Vec3& pmin, Vec3& pmax) {
    Vec3 ret;
    ret.x = normalize(b.x, pmin.x, pmax.x, 0.0f, 1.0f);
    ret.y = normalize(b.y, pmin.y, pmax.y, 0.0f, 1.0f);
    ret.z = normalize(b.z, pmin.z, pmax.z, 0.0f, 1.0f);
    return ret;
}
KGYK_HOST_DEVICE inline const Vec3 multiply(const Vec3& v1, const Vec3& v2) {
    return Vec3(v1.x * v2.x, v1.y * v2.y, v1.z * v2.z);
}
KGYK_HOST_DEVICE inline const Vec3 divide(const Vec3& v1, const Vec3& v2) {
    return Vec3(v1.x / v2.x, v1.y / v2.y, v1.z / v2.z);
}
KGYK_HOST_DEVICE inline  float dot(const Vec3& v1, const Vec3& v2) {
    return v1.x * v2.x + v1.y * v2.y + v1.z * v2.z;
}
KGYK_HOST_DEVICE inline  float absdot(const Vec3& v1, const Vec3& v2) {
    return std::abs(dot(v1, v2));
}
KGYK_HOST_DEVICE inline  Vec3 cross(const Vec3& v1, const Vec3& v2) {
    return Vec3(
        (v1.y * v2.z) - (v1.z * v2.y),
        (v1.z * v2.x) - (v1.x * v2.z),
        (v1.x * v2.y) - (v1.y * v2.x));
}
KGYK_HOST_DEVICE inline const Vec3 exponential(const Vec3& v) {
    return Vec3(exp(v.x), exp(v.y), exp(v.z));
}

// "TextBefore v.x v.y v.z TextAfter\n"
KGYK_HOST_DEVICE inline void printVec3(const Vec3& v, const char* TextBefore = "", const char* TextAfter = "") {
    printf("%s %f %f %f %s\n", TextBefore, v.x, v.y, v.z, TextAfter);
}

KGYK_HOST_DEVICE……等はCUDAのhostやdevice等の修飾子です.特に書くこともないですので説明は省略します.ちなみにメンバ変数のalignは構造体サイズの32byteアラインメントです.Vec3hはこれらを半精度で行う構造体です.

Ray構造体

レイの構造体です.

/*
 * レイの情報を持つ構造体
 *
 * dir: レイの方向ベクトル
 * org: レイの原点位置ベクトル
 */
struct NeRFRay {
    Vec3 dir;
    Vec3 org;
    int nSample;
    int SampleBeginIdx;
    float tmin, tmax, pad1, pad2;
    MNPT_HOST_DEVICE NeRFRay(const Vec3& dir = {1,0,0}, const Vec3& org = {0,0,0}) : org(org), dir(dir) {}
};

nSample,SampleBeginIdx, tmin, tmaxはこの後使用する際に説明します.

Camera構造体

さて,ではカメラの実装に取り掛かりましょう.

struct NeRFCamera {
    size_t Image_height;
    size_t Image_width;
    Vec3 CamOrg;
    Vec3 CamDir;
    float CamToScreenDist;
    Vec3 ScreenXDir;
    Vec3 ScreenYDir;
    float PixelSize = 0.0f;

    MNPT_HOST_DEVICE NeRFCamera(const Vec3& CamDir, const Vec3& CamOrg, const float& CamToScreenDist) :
        Image_height(0), Image_width(0), CamDir(CamDir), CamOrg(CamOrg), CamToScreenDist(CamToScreenDist) {}

    // 指定されたカメラの情報からカメラの姿勢を求める
    MNPT_HOST void SetUp(float rotateX, float rotateY, float rotateZ) {
        const Vec3 CamUp = rotate(DEG2RAD(rotateX), DEG2RAD(rotateY), DEG2RAD(rotateZ), normalize({ 0.0, 1.0, 0.0 }));
        CamDir = rotate(DEG2RAD(rotateX), DEG2RAD(rotateY), DEG2RAD(rotateZ), normalize({ 0.0, 0.0, -1.0 }));
        ScreenXDir = cross(CamDir, CamUp);
        ScreenYDir = CamUp;
        PixelSize = 36.0f / (float)Image_height;
    }

    MNPT_HOST void SetUp(const size_t InputImage_height, const size_t InputImage_width, std::vector<std::vector<float>>& TransformMatrix, float FOVY) {
        // 入力画像のサイズをカメラに記録
        Image_height = InputImage_height;
        Image_width = InputImage_width;

        // 変換行列によりカメラの姿勢を計算
        Vec3 CamUp = rotate(TransformMatrix, normalize({ 0.0, 1.0, 0.0 }));
        CamDir = rotate(TransformMatrix, normalize({ 0.0, 0.0, -1.0 }));

        std::vector<std::vector<float>> Pos(4, std::vector<float>(1));
        std::vector<std::vector<float>> Base{{0.0f}, { 0.0f }, { 0.0f }, { 1.0f }};
        Pos = multiply_matrix(TransformMatrix, Base);
        CamOrg = { Pos[0][0], Pos[1][0], Pos[2][0] };

        float top = 18.0f;
        ScreenYDir = CamUp;
        ScreenXDir = cross(CamDir, CamUp);
        PixelSize = 2 * top / (float)Image_height;
        CamToScreenDist = top / tanf(FOVY/2.0f);
    }
    MNPT_HOST void SetUp(Camera* Cam) {
        Image_height = ImSizeY;
        Image_width = ImSizeX;
        CamOrg = Cam->CamOrg;
        CamDir = Cam->CamDir;
        CamToScreenDist = Cam->CamToScreenDist;
        ScreenXDir = Cam->ScreenXDir;
        ScreenYDir = Cam->ScreenYDir;
        PixelSize = Cam->PixelSize;
    }

    // 画像上のピクセルインデックス{IndexX, IndexY}に対してのカメラからの1次レイを求める
    MNPT_DEVICE inline NeRFRay Generate1stRay(const int IndexX, const int IndexY) {
        Vec3 RayDir = CamDir * CamToScreenDist +
                      ScreenXDir * ((float)IndexX - Image_height / 2.0f) * PixelSize +
                      ScreenYDir * (Image_width / 2.0f - (float)IndexY) * PixelSize;
        RayDir = normalize(RayDir, "FirstRayGen");

        return {RayDir, CamOrg};
    }
};

これは多少見ておきましょう.

struct NeRFCamera {
    size_t Image_height;
    size_t Image_width;
    Vec3 CamOrg;
    Vec3 CamDir;
    float CamToScreenDist;
    Vec3 ScreenXDir;
    Vec3 ScreenYDir;
    float PixelSize = 0.0f;
...

Image_heightやImage_widthはこのカメラの映し出す画像のサイズを指します.もっと言えばスクリーンのアスペクトを定義すると言っても良いでしょう.
PixelSizeはスクリーン上のピクセルサイズです.実はスクリーンのY方向のサイズを36に固定しております.つまり,PixelSizeは36/Image_heightとなります.この36という値は特に意味はなく,色々試した結果イイ感じの値というだけです.
後に続くSetUp関数は入力されたカメラの姿勢からメンバ変数を計算してあげる関数です.説明は省略しますが,rotate()は回転行列を作用させる関数です.
では本題のレイの計算に入ります.

 // 画像上のピクセルインデックス{IndexX, IndexY}に対してのカメラからの1次レイを求める
    MNPT_DEVICE inline NeRFRay Generate1stRay(const int IndexX, const int IndexY) {
        Vec3 RayDir = CamDir * CamToScreenDist +
                      ScreenXDir * ((float)IndexX - Image_height / 2.0f) * PixelSize +
                      ScreenYDir * (Image_width / 2.0f - (float)IndexY) * PixelSize;
        RayDir = normalize(RayDir, "FirstRayGen");

        return {RayDir, CamOrg};
    }

図中に示した矢印を赤緑青と上から順番に加算していることが分かると思います.細かいことを言えばピクセルの中心に当てるためには多少調整がいるのですが,まあしなくても良いです.計算したレイの長さを1に正規化してあげるのを忘れないようにしましょう.さて,ピクセル座標が分かればレイを計算できるようになりました.ここで先ほど入力データと教師データの対の構造体をつくっていました.ではこの構造体からレイを計算してあげましょう.

// カメラの情報からレイのバッチを作成する.
// これによって(レイ -- ピクセル色)という(入力データ--教師データ)の構造が構築される
__global__ void NeRF_GenerateRay(const int nPixel, NeRFCamera* Cam, PixelInfo *PixInfo,  NeRFRay* GeneratedRay) {
    int PixelId = blockIdx.x * blockDim.x + threadIdx.x;
    if (PixelId >= nPixel) {
        return;
    }

    PixelInfo PixelInfo_this_thread = PixInfo[PixelId];
    int ImageId = PixelInfo_this_thread.ImageID();
    int PosX = PixelInfo_this_thread.PixIndexX();
    int PosY = PixelInfo_this_thread.PixIndexY();

    // このピクセルにおけるRay
    NeRFRay Ray = Cam[ImageId].Generate1stRay(PosX, PosY);
    GeneratedRay[PixelId] = Ray;
}

細かく見ましょう.

__global__ void NeRF_GenerateRay(const int nPixel, NeRFCamera* Cam, PixelInfo *PixInfo,  NeRFRay* GeneratedRay) {
...

nPixel: 1イテレーションで使用するピクセル数です.さっき書いたnMaxPixelPerBatchがこれにあたります.
Cam: カメラ配列です.i番目の要素は画像IDがiの画像を撮影したカメラ情報を持つ(NeRF)Camera構造体です.
PixInfo: (画像ID, ピクセル座標)- (教師画像の色)を対として持つ構造体です.
GeneratedRay: 計算されたレイが格納される場所

...
int PixelId = blockIdx.x * blockDim.x + threadIdx.x;
if (PixelId >= nPixel) {
    return;
}
...

いつものやつです.スレッド数がnPixelを溢れた分は帰らせる処理です.

...
PixelInfo PixelInfo_this_thread = PixInfo[PixelId];
int ImageId = PixelInfo_this_thread.ImageID();
int PosX = PixelInfo_this_thread.PixIndexX();
int PosY = PixelInfo_this_thread.PixIndexY();

// このピクセルにおけるRay
NeRFRay Ray = Cam[ImageId].Generate1stRay(PosX, PosY);
GeneratedRay[PixelId] = Ray;
...

スレッドIDに対応するPixelInfoにアクセスし,情報を読み出し,それを元にしてレイを計算しています.
なお,先程のシャッフルの際に「すべてのピクセルを1イテレーションでは使用しない」と書いていましたが,イテレーションに関わらずPixelInfoへのアクセスインデックスが0以上nPixel未満となっており,怪しそうですが,カーネル実行時に次のような引数を与えてPixelInfoポインタの先頭をずらしています.

// レイを計算する
NeRF_GenerateRay <<<nMaxPixelPerBatch / 512, 512, 0, stream_train >>>
                (
                    nMaxPixelPerBatch,
                    D_NeRFCAM,
                    thrust::raw_pointer_cast(&D_thrust_TargetPixelInfo[IndexOfPixelInfo]),
                    thrust::raw_pointer_cast(&D_RayBatch[0])
                );
gpuErrchk(cudaGetLastError());
gpuErrchk(cudaStreamSynchronize(stream_train));

IndexOfPixelInfoが何たるやは先ほどPixelInfoの説明をした際のシャッフルの話で記述した通りです.また,thrust::raw_pointer_cast()はthrust::device_vectorのデバイスメモリ上でのポインタを返してくれる嬉しい関数です.

NeRFの実装: (2): レイの経路上の点をサンプリング

さあ,下準備も済んでようやくNeRFらしくなってきました.レイの経路上の点をサンプリングするうえではサンプリングする領域を決める必要があります.NeRFを生成する空間をAABB(Axis-Aligned-Bounding-Box),つまりxyz軸に辺がそれぞれ平行である直方体,で定義してあげると計算が楽です.言わずもがなですが,サンプリングする領域はレイの経路の内,このAABB内部にある線分でしょう.ではAABBとレイの交差判定を行いましょう.そのために,AABBの構造体が必要です.

struct NeRFAABB {
    Vec3 PosMin;
    Vec3 PosMax;

    MNPT_HOST_DEVICE NeRFAABB(const Vec3 PosMin = Vec3(SINF, SINF, SINF), const Vec3 PosMax = Vec3(BINF, BINF, BINF)) : PosMin(PosMin), PosMax(PosMax) {}

    MNPT_DEVICE bool willIntersectWithAABB(NeRFRay& Ray) {
        float t_max = 1e16f;
        float t_min = -1e16f;

#pragma unroll
        for (int i = 0; i < 3; i++) {
            float t1, t2, t_near, t_far;
            if (i == 0) {
                if (abs(Ray.dir.x) < 1e-5) { // x軸に垂直な平面上のレイを飛ばす場合
                    if (PosMin.x > Ray.org.x || PosMax.x < Ray.org.x) return false;
                    else continue;
                }
                t1 = (PosMin.x - Ray.org.x) / Ray.dir.x;
                t2 = (PosMax.x - Ray.org.x) / Ray.dir.x;
            }
            else if (i == 1) {
                if (abs(Ray.dir.y) < 1e-5) { // Y軸に垂直な平面上のレイを飛ばす場合
                    if (PosMin.y > Ray.org.y || PosMax.y < Ray.org.y) return false;
                    else continue;
                }
                t1 = (PosMin.y - Ray.org.y) / Ray.dir.y;
                t2 = (PosMax.y - Ray.org.y) / Ray.dir.y;
            }
            else {
                if (abs(Ray.dir.z) < 1e-5) { // Z軸に垂直な平面上のレイを飛ばす場合
                    if (PosMin.z > Ray.org.z || PosMax.z < Ray.org.z) return false;
                    else continue;
                }
                t1 = (PosMin.z - Ray.org.z) / Ray.dir.z;
                t2 = (PosMax.z - Ray.org.z) / Ray.dir.z;
            }

            t_near = min(t1, t2);
            t_far = max(t1, t2);
            t_max = min(t_max, t_far);
            t_min = max(t_min, t_near);

            if (t_min > t_max) return false;
        }
        Ray.tmin = max(t_min, 1e-4f);
        Ray.tmax = t_max;
        return true;
    }
};

Part2でも示しましたが,AABBを定義するにはすべての軸において座標が小さい方の点の座標,そしてその対角点(すべての軸において座標が大きい方の点)の座標があれば十分です.これがPosMinとPosMaxです.さて,肝心の交差判定の具体的な説明は以下の記事に任せて省略します.

marupeke296.com qiita.com

さて,先程(NeRF)Ray構造体のところでメンバ変数にtmin, tmaxとありました.これはレイの経路においてAABB内部にある,レイの原点からの距離tの最小値と最大値を保存します.私の実装における

...
Ray.tmin = max(t_min, 1e-4f);
Ray.tmax = t_max;
...

がそれです.ちなみにt_minが0や負になってほしくはないのでその場合は微小値を入れておきます.図で描くと次の感じです.

CMOSインバーターみたいなのはカメラです.t_min や t_maxが無い場合はそのレイに関してはサンプリングする必要がありません.
さて,サンプリングする領域が求まったのでサンプリングしていきましょう.サンプリングする処理を見ていきましょう.

// サンプル点を求める
/*
 * ミニバッチに関する処理
 * Rays: 関数内部でインデックス調整(入力には[0]を先頭としたポインタを渡す)
 * Pos: 関数内部でインデックス調整(入力には[0]を先頭としたポインタを渡す)
 * Dir: 関数内部でインデックス調整(入力には[0]を先頭としたポインタを渡す)
 */
__global__ void NeRF_GenerateSample
(
    const uint32_t nMaxBatch,
    const uint32_t nMaxPixelBatch, 
    NeRFInfo* SamplingInfo, NeRFRay* Rays, float* Pos, float* Dir, OccupancyGrid* Grids, 
    const uint32_t epoch
) 
{
    uint32_t threadID = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadID >= nMaxPixelBatch) {
        return;
    }
    // レイ上の点をサンプリングする
    NeRFRay* Ray = &Rays[threadID];
    const Vec3 AABBMin = { SamplingInfo->AABB_pMin.x,SamplingInfo->AABB_pMin.y,SamplingInfo->AABB_pMin.z };
    const Vec3 AABBMax = { SamplingInfo->AABB_pMax.x,SamplingInfo->AABB_pMax.y,SamplingInfo->AABB_pMax.z };
    

    uint32_t nAcceptedSample = 0;

    // AABBとの交差判定
    NeRFAABB NeRFBox = { AABBMin, AABBMax};
    if (!NeRFBox.willIntersectWithAABB(*Ray)) {
        Ray->nSample = 0;
        return;
    }
    
    float t_min = Ray->tmin;
    float t_max = Ray->tmax;

    const float dt = (t_max - t_min) / NERF_MAX_SAMPLE_PER_RAY;

    // 0になった場合プログラムが停止してしまうため
    if (dt < SINF) {
        Ray->nSample = 0;
        return;
    }
    
    // RNG
    curandState state, state_old, state_firstAccept;
    curand_init(epoch, threadID, 0, &state);

    // Uniform sampling
    int index = 0;
    int index_firstAccept = 0;
    constexpr float Throughput_thres = 0.01f;
    float Throughput_hat = 1.0f; // これがThroughput_thresを下回れば中断
    Vec3 LastSamplePos;

    // 前半 ////////////////////////////////////////////////////////////////////////////////////////////////////////
    // とりあえずどのレイがどれぐらいサンプル点を取るかを求める
    // それによってMLPへの入力データのどの領域にサンプル点の情報が保存されるかを計算する.
    while (nAcceptedSample < NERF_MAX_SAMPLE_PER_RAY) {
        const float rnd = 0.5f;// curand_uniform(&state);
        float t1 = t_min + SINF + dt * index;
        float t2 = t_min + SINF + (dt * (index+1));
        float t = normalize(rnd, 0.0f, 1.0f, t1, t2);

        Vec3 s_Pos = Ray->org + Ray->dir * t;

        // AABBの領域に収める
        if (t > t_max) {
            break;
        }

        // 念のため
        s_Pos.x = clamp(s_Pos.x, AABBMin.x + SINF, AABBMax.x - SINF);
        s_Pos.y = clamp(s_Pos.y, AABBMin.y + SINF, AABBMax.y - SINF);
        s_Pos.z = clamp(s_Pos.z, AABBMin.z + SINF, AABBMax.z - SINF);

        // Occupancy Gridにおいて「物体が無い」と判断された場合はスキップする
        if (epoch > 16 && !Grids[0].is_Occupied(s_Pos)) {
            index++;
            continue;
        }
        else {
            if (nAcceptedSample == 0) {
                index_firstAccept = index;
            }
            nAcceptedSample++;
        }
        // Throughputの更新
        if (epoch > 32 && nAcceptedSample >= 2) {
            Vec3 FromLastSample =
            {
                s_Pos.x - LastSamplePos.x,
                s_Pos.y - LastSamplePos.y,
                s_Pos.z - LastSamplePos.z
            };
            Throughput_hat *= expf(-1.0f * Grids[0].Float_at(s_Pos) * FromLastSample.length());
        }
        // それ以降のサンプル点の寄与の見込みが小さい場合は打ち切る
        if (Throughput_hat < Throughput_thres) {
            break;
        }
        //state_old = state;
        LastSamplePos = s_Pos;
        index++;
    }
    
    // サンプリングが完了した
    // Dir配列とOrg配列にサンプル情報を記録する(配列上で連続となるようにする)
    // 格納先の配列におけるインデックス範囲を記録 : [begin, end]
    int begin_idx = atomicAdd(&SamplingInfo->nSamples, nAcceptedSample);
    Ray->nSample = nAcceptedSample;
    Ray->SampleBeginIdx = begin_idx;

    // 後半 ///////////////////////////////////////////////////////////////////////////////////////////////////
    // サンプルの保存インデックスがMaxBatchを超えている場合はそのレイを捨てる
    if (begin_idx + nAcceptedSample >= nMaxBatch) {
        Ray->nSample = 0;
        atomicAdd(&SamplingInfo->nSamples, -nAcceptedSample);
        return;
    }
    atomicAdd(&SamplingInfo->RayNum, 1);

    // はじめてサンプルを得たところまでスキップ
    index = index_firstAccept;
    state = state_firstAccept;
    nAcceptedSample = 0;
    Throughput_hat = 1.0f;
    
    while (nAcceptedSample < NERF_MAX_SAMPLE_PER_RAY) {
        const float rnd = 0.5f;//curand_uniform(&state);
        float t1 = t_min + SINF + dt * index;
        float t2 = t_min + SINF + (dt * (index+1));
        float t = normalize(rnd, 0.0f, 1.0f, t1, t2);

        Vec3 s_Pos = Ray->org + Ray->dir * t;

        // AABBの領域に収める
        if (t > t_max) {
            break;
        }

        // 念のため
        s_Pos.x = clamp(s_Pos.x, AABBMin.x + SINF, AABBMax.x - SINF);
        s_Pos.y = clamp(s_Pos.y, AABBMin.y + SINF, AABBMax.y - SINF);
        s_Pos.z = clamp(s_Pos.z, AABBMin.z + SINF, AABBMax.z - SINF);

        // Occupancy Gridにおいて「物体が無い」と判断された場合はスキップする
        if (epoch > 16 && !Grids[0].is_Occupied(s_Pos)) {
            index++;
            continue;
        }
        else {
            Ray->dir.to_float(Dir + 3 * (begin_idx + nAcceptedSample));
            s_Pos.to_float(Pos + 3 * (begin_idx + nAcceptedSample));
            nAcceptedSample++;
        }
        // Throughputの更新
        if (epoch > 32 && nAcceptedSample >= 2) {
            Vec3 FromLastSample =
            {
                s_Pos.x - LastSamplePos.x,
                s_Pos.y - LastSamplePos.y,
                s_Pos.z - LastSamplePos.z
            };
            Throughput_hat *= expf(-1.0f * Grids[0].Float_at(s_Pos) * FromLastSample.length());
        }
        // それ以降のサンプル点の寄与の見込みが小さい場合は打ち切る
        if (Throughput_hat < Throughput_thres) {
            break;
        }
        LastSamplePos = s_Pos;
        index++;
    }
}

さて,とてつもなく長いコードが出てきました.部分で見ていきましょう.

__global__ void NeRF_GenerateSample
(
    const uint32_t nMaxBatch,
    const uint32_t nMaxPixelBatch, 
    NeRFInfo* SamplingInfo, NeRFRay* Rays, float* Pos, float* Dir, OccupancyGrid* Grids, 
    const uint32_t epoch
) 
{
...

nMaxBatch: すべてのレイに対してサンプリングする点の総和の上限(NNに入力するバッチサイズの上限)
nMaxPixelBatch: ピクセル数の上限(このイテレーションで考慮するレイの本数)
NeRFInfo* SamplingInfo: サンプリングに関わる情報を記録する構造体
Rays: レイが保存されている配列
Pos, Dir: サンプル点における座標とレイの方向を記録する配列
Grids: Occupancy Gridの情報(最後に説明します)

nMaxPixelBatchとnMaxBatchが両方ある理由ですが,実際の処理においては本当にnMaxBatchだけサンプル点をサンプリングするわけではありません.この話は今回のNeRFのコード設計に大きくかかわっていることです.具体的な理由はこの後で説明します. さて,ここで出てきた構造体NeRFInfoについて書いておきます.

struct NeRFInfo {
    uint32_t nSamples = 0;
    uint32_t RayNum = 0;
    float3 AABB_pMin = { -1.0f, -1.0f, -1.0f };
    float3 AABB_pMax = { 1.0f, 1.0f, 1.0f };
    
    // GUI
    uint32_t epoch = 0;
    float Loss = 0.0f;
    uint32_t BatchSize = 0;
    uint32_t nRay = 0;
    bool StopTrain = false;
    bool StopRender = false;
};

nSamplesはすぐ後で説明します
RayNum: 最終的に使用することとなったレイの本数(詳しいことはすぐ後で説明します)
AABB_pMin, AABB_pMax: NeRFを生成するAABBの設定です.
これ以降のメンバ変数はNeRFの実装に直接かかわってはこないので省略します.(GUI上で情報を得るため等の為の変数です)

さて,さっきから「すべてをサンプリングしない」とか「最終的に使用することとなったレイの本数」とかいう理解を拒ませる変なことを言ってますが,さっさと明らかにしておきます.今回の関数は「各レイにおけるサンプル数を動的に決定する」実装を行っています.具体的な処理の流れを確認しましょう.まず実装の上で次を前提としています.

各レイはNERF_MAX_SAMPLE_PER_RAYという即値よりも多くサンプル点をサンプリングすることを許されません.しかし,すべてのレイがNERF_MAX_SAMPLE_PER_RAYもサンプル点を必要としているわけではありません.例えば,次の図に示すような場合,レイに対するサンプル数が少なくて済みます.

さらに,MLPに入力できるバッチサイズの上限をnMaxBatchとして,具体的には221として固定しております.また,学習の速度を上げたいのでなるべく多くのレイをバッチにはめ込みたいという欲望があります.仮にすべてのレイがNERF_MAX_SAMPLE_PER_RAY個のサンプル点を取ると決めつけてしまうと実装は楽になりますが,かなりレイの数が限られてしまいます.逆に,各レイの持つサンプル点の個数を動的になるべく最小限に抑えてあげることにより,実装は少しめんどくさいですがバッチになるべく多くのレイが収まってくれます.

各レイについて動的にサンプリングを行うとしましたが,じゃあ一体どうやって「扱いやすいデータ」とするのかという疑問が浮かびます.まあ実装してみると分かりますが,次のような問題が発生します.

簡単な実装をするとすれば,各レイに対してNERF_MAX_SAMPLE_PER_RAY要素だけサンプル点情報の領域を確保しておき,余った分は空白としておく,となるでしょう.しかし,これは全くもってメモリ領域の無駄であり,さらに言えばNNに入力するデータとしてはバッチに空白のデータが入り込んでしまい,正しい処理が出来ません.そのため,理想としては下側のように各レイでのサンプル点がメモリ上で連続に配置されている状態となります.

さて,各レイの処理は並列に行っています.ですが,このままでは処理中のスレッドにとって「今自分が処理しているサンプル点の情報をメモリ上のどこに書き込めばいいのか」が不明です.並列処理を諦めることになるのでしょうか.いいえ,実は同じ処理を2回繰り返すことでほとんど並列性を保って実装できます.前半では各スレッドが「自分が何個サンプリングすることになって,サンプル点情報を書き込む配列のどこを始点として書き込めばいいのか」を決定します.

そして後半では「前半と全く同じサンプリングを行い,決めた領域に書き込む」ということを行います.

では具体的な実装を確認していきます.

...
uint32_t nAcceptedSample = 0;
...

nAcceptedSample: このレイにおけるサンプリングすると決定されたサンプル点の数を記録します

// AABBとの交差判定
NeRFAABB NeRFBox = { AABBMin, AABBMax};
if (!NeRFBox.willIntersectWithAABB(*Ray)) {
    Ray->nSample = 0;
    return;
}

float t_min = Ray->tmin;
float t_max = Ray->tmax;

そもそもNeRFを生成する領域(AABB)とレイが交差しなければそのレイについては一切のサンプリングを行いません.なので交差判定を行って処理の必要性,そしてレイの経路上でAABBの内部にある領域を求めておきます.また,(NeRF)Rayのメンバ変数にあったnSampleは,この後も書きますが「そのレイが持つサンプル点の数」です.AABBとの交差が確認されないレイは0となります.

const float dt = (t_max - t_min) / NERF_MAX_SAMPLE_PER_RAY;

// 0になった場合プログラムが停止してしまうため
if (dt < SINF) {
    Ray->nSample = 0;
    return;
}

さて,サンプリングを行います.今回は単純に,AABB内部にあるレイの経路をNERF_MAX_SAMPLE_PER_RAY等分し,等分された一つの経路上の点をサンプル点としてふさわしいかを考慮することにします.なので「ステップサイズ」としてdtを設定してあげます.ちなみにコーナーケースとして,「レイがAABBの端ぎりぎりに入射した場合」はdtが微小となります.この場合,最悪プログラムがこの後のwhile文内でスタックするので微小の場合はやはりサンプリングせずにreturnしておきます.

// RNG
curandState state, state_old, state_firstAccept;
curand_init(epoch, threadID, 0, &state);

これは乱数生成のためのCUDAライブラリですが,現状の実装では使わないので無視してください(経路上の点をランダムサンプリングするやり方もありますが,今回は等間隔サンプリングとします.ちなみにInstantNGPの論文では他にも色々サンプリング手法が書かれています).

// Uniform sampling
int index = 0;
int index_firstAccept = 0;
constexpr float Throughput_thres = 0.01f;
float Throughput_hat = 1.0f; // これがThroughput_thresを下回れば中断
Vec3 LastSamplePos;

index: 何個目の等間隔サンプル点を考えているか,つまり考えているサンプル点のインデックスです.
index_firstAccept: 後半の処理では全く同じサンプリングをするので,前半で初めてサンプリングが成立したインデックスを記録しておきます.
Throughput_thres: 先ほどの例の図に示したcase2のように,寄与が非常に小さい場合は無視しても問題ありません.なのでその閾値を設定しておきます.
Throughput_hat: この寄与の大きさはNNを通さないと分からないので,Occupancy Gridに保存されている情報を用いて「ざっくりと」推定しておきます
LastSamplePos: 寄与の大きさを計算するためには前のサンプルからの距離が必要です.

// 前半 ////////////////////////////////////////////////////////////////////////////////////////////////////////
// とりあえずどのレイがどれぐらいサンプル点を取るかを求める
// それによってMLPへの入力データのどの領域にサンプル点の情報が保存されるかを計算する.
while (nAcceptedSample < NERF_MAX_SAMPLE_PER_RAY) {

サンプル数がNERF_MAX_SAMPLE_PER_RAYを超えたらもうそれ以上サンプリングさせないようにします(なお実装上,バグが無ければありえません)

const float rnd = 0.5f;// curand_uniform(&state);
float t1 = t_min + SINF + dt * index;
float t2 = t_min + SINF + (dt * (index+1));
float t = normalize(rnd, 0.0f, 1.0f, t1, t2);

サンプル点のカメラからの距離を決めてあげます.index番目のサンプル点は[t1, t2]の領域にあります.今回はここの中心を使ってあげます.つまり,サンプル点のカメラからの距離は t = (t1 + t2)/2となっています.

Vec3 s_Pos = Ray->org + Ray->dir * t;

// AABBの領域に収める
if (t > t_max) {
    break;
}

// 念のため
s_Pos.x = clamp(s_Pos.x, AABBMin.x + SINF, AABBMax.x - SINF);
s_Pos.y = clamp(s_Pos.y, AABBMin.y + SINF, AABBMax.y - SINF);
s_Pos.z = clamp(s_Pos.z, AABBMin.z + SINF, AABBMax.z - SINF);

s_Posはサンプル点の座標を記録する変数です.
サンプル点の座標がt_maxを超えた場合は,これ以降のindexにおいてサンプル点がAABBの内部に戻ってくることはあり得ないのでbreakしてあげます.
もしものため,s_Posの座標が絶対にAABBの内部に存在するようにしておきます.ちなみにこれをしなくても外部にあることはあり得ないはずです.

// Occupancy Gridにおいて「物体が無い」と判断された場合はスキップする
if (epoch > 16 && !Grids[0].is_Occupied(s_Pos)) {
    index++;
    continue;
}
else {
    if (nAcceptedSample == 0) {
        index_firstAccept = index;
    }
    nAcceptedSample++;
}

Occupancy Gridの話は最後にしますが,ここでやっているのは「Occupancy Gridに格納されているデータを参照し,そのサンプル点において「物体が存在する」,言い換えると媒質の密度がある程度高いと判定されない場合はサンプル点を無視する」ということをやっています.結局は「カメラに入ってくるレイに関与しないサンプル点を無視する」ということです.   無視しない場合はサンプリングが成立するので,初めてのサンプリング成立であればインデックスを保存し,そしてnAcceptedSampleに1を足してあげます.

// Throughputの更新
if (epoch > 32 && nAcceptedSample >= 2) {
    Vec3 FromLastSample =
    {
        s_Pos.x - LastSamplePos.x,
        s_Pos.y - LastSamplePos.y,
        s_Pos.z - LastSamplePos.z
    };
    Throughput_hat *= expf(-1.0f * Grids[0].Float_at(s_Pos) * FromLastSample.length());
}
// それ以降のサンプル点の寄与の見込みが小さい場合は打ち切る
if (Throughput_hat < Throughput_thres) {
    break;
}

コメントの通りです.

LastSamplePos = s_Pos;
index++;

さて,サンプリングが成立しました.なのでLastSamplePosを更新し,そしてwhile文の終わりなのでindexに1足してあげて次の等間隔領域に移ります.

// サンプリングが完了した
// Dir配列とOrg配列にサンプル情報を記録する(配列上で連続となるようにする)
// 格納先の配列におけるインデックス範囲を記録 : [begin, end]
int begin_idx = atomicAdd(&SamplingInfo->nSamples, nAcceptedSample);
Ray->nSample = nAcceptedSample;
Ray->SampleBeginIdx = begin_idx;

さて,前半のサンプリングの処理が終わりました.ここで各スレッドは「自分がサンプル点情報を格納する配列に,「どこから」「何個」サンプル点情報を保存するか」を知る必要があります.なので,説明時に出てきたCounterにどれだけサンプリングが成立したかを教えてあげます.このCounterはNeRFInfoのメンバ変数であるSamplingInfo.nSamplesです.SamplingInfo.nSamplesは現在既に何個のサンプリングが他のスレッドで成立してきたか(how many samples "have been" generated)をbegin_idxへと格納すると同時に処理中のスレッドが何個サンプル点を得たかを聞き,加算します.勿論,この処理はatomicで行われる必要があります.こうすることでCounterが更新され,そしてスレッドが「配列のどこから格納すればいいのか」を決定できます.
ここで,Rayのメンバ変数であるnSampleとSampleBeginIdxにそのスレッドが「どこから」「何個」サンプル点情報を保存したかを記憶させておきます.これは後にボリュームレンダリングを行う際に必要となる情報です.

// 後半 ///////////////////////////////////////////////////////////////////////////////////////////////////
// サンプルの保存インデックスがMaxBatchを超えている場合はそのレイを捨てる
if (begin_idx + nAcceptedSample >= nMaxBatch) {
    Ray->nSample = 0;
    atomicAdd(&SamplingInfo->nSamples, -nAcceptedSample);
    return;
}
atomicAdd(&SamplingInfo->RayNum, 1);

さて,後半の開始です.この処理の最初に,サンプル点の総数には上限があり,それがnMaxBatchで定義しておく,とありました.その部分の処理を行います.各スレッドが自分のbegin_idxとnAcceptedSampleから配列上に書き込むインデックスの範囲を知ることが出来ます.それがnMaxBatchを超えていた場合は,サンプル点情報の登録をキャンセルします.そして,キャンセルが発生したのでSamplingInfo->nSamplesからそのスレッドにおけるnAcceptedSampleを引いておきます.ちなみにここでスレッド同期の問題が不安として頭をよぎりましたが,「nMaxBatchから溢れている状態」となっているのが端の問題であるので多分大丈夫なはずです.

// はじめてサンプルを得たところまでスキップ
index = index_firstAccept;
state = state_firstAccept;
nAcceptedSample = 0;
Throughput_hat = 1.0f;

後半のサンプリングは前半と全く同じものです.やりましょう.ほとんどが同じなので説明は省略しますが,一か所だけ

// Occupancy Gridにおいて「物体が無い」と判断された場合はスキップする
if (epoch > 16 && !Grids[0].is_Occupied(s_Pos)) {
    index++;
    continue;
}
else {
    Ray->dir.to_float(Dir + 3 * (begin_idx + nAcceptedSample));
    s_Pos.to_float(Pos + 3 * (begin_idx + nAcceptedSample));
    nAcceptedSample++;
}

前半ではLastSamplePosに格納していましたが,今度はサンプル点の情報を格納する配列,即ちDirとPosに保存します.以上でサンプル点の情報をメモリ上で連続に配置することが出来ました.そして,このままNNに入力することが出来ます.なお,前半の際に静的に確保しているメモリ領域にサンプル点情報を保存すれば後半は不要になりますが,メモリ領域を圧迫しますし非効率なメモリ消費がかなり多くなります.

NeRFの実装: (3)と(4): NNの処理

サンプル点における「座標」と「方向」を入力し,サンプル点における「色」と「密度」を推定するNNを実装する必要があります.InstantNGPの論文では次のネットワークが推されています.

FCというのは全結合層で,活性化関数を下に付記しています.ReLUじゃなくてLeakyReLUを使用しても良いです.図の表記にクセがある(作ったのは半年前の私ですが)ので念のため言葉でも説明しておくと,
図中左半分は「密度」を推定するネットワークで,Density MLPと呼ぶことにし,次の構造を取ります.
入力はサンプル点の座標で,これをMultiresolution Hash Encodingに通しておきます. L = 16, F = 2として32次元の出力にします(厳密ではないので異なっていてもまあ良いと思います).
MLP: 入力層32次元, 隠れ層64次元,出力層16次元,隠れ層は1層,活性化関数はすべてReLUないしLeakyReLU,そしてexp(出力層の一つ目の要素)を「密度」の推定値とします.

そして後半ではサンプル点の「色」を推定します.Color MLPとします.ここでサンプル点におけるレイの方向を l \leq 3までのSpherical Harmonic Encoding(実数球面調和関数の基底に変換)を通して16次元にしておきます.そしてdensity networkの出力である16次元(勿論1要素目はexp activationされていない状態のものです)のベクトルと結合し,合わせて32次元としておきます.
MLP: 入力層32次元, 隠れ層64次元,出力層16次元,隠れ層は2層,活性化関数は入力層と隠れ層はReLUないしLeakyReLU,そして出力層はSigmoid(ロジスティック)とします.

なお,私の実装ではColor MLPの隠れ層の数を1にしています.では実装を見ましょう.

template <const uint32_t indim_den, const uint32_t hiddendim_den, const uint32_t outdim_den, const uint32_t nHiddenLayer_den, 
          const uint32_t indim_col, const uint32_t hiddendim_col, const uint32_t outdim_col, const uint32_t nHiddenLayer_col>
__global__ void MFFM_NeRF_Forward
(
    NeRFInfo* SamplingInfo, EncoderInfo* EncInfo_den, EncoderInfo* EncInfo_col,
    Encoder Encoder_den, Activation ActHid_den, Activation ActOut_den,
    Encoder Encoder_col, Activation ActHid_col, Activation ActOut_col,
    float* Pos, float* Dir, 
    __half* weights_den, __half* buffers_den, 
    __half* weights_col, __half* buffers_col, 
    float* Density, float* Color) 
{
    const int BatchSize = SamplingInfo->nSamples;

    // 即値
    constexpr uint32_t indim_den_aligned = next_multiple(indim_den, TENSOR_ROW);
    constexpr uint32_t hiddendim_den_aligned = next_multiple(hiddendim_den, TENSOR_ROW);
    constexpr uint32_t outdim_den_aligned = next_multiple(outdim_den, TENSOR_ROW);
    constexpr uint32_t indim_col_aligned = next_multiple(indim_col, TENSOR_ROW);
    constexpr uint32_t hiddendim_col_aligned = next_multiple(hiddendim_col, TENSOR_ROW);
    constexpr uint32_t outdim_col_aligned = next_multiple(outdim_col, TENSOR_ROW);
    
    const uint32_t maxdim_den = max(indim_den_aligned, max(hiddendim_den_aligned, outdim_den_aligned));

    extern __shared__ __half shmem[];
    __half* intermediate_col = shmem;
    __half* intermediate_den = shmem;
    __half* intermediate_dirInput = shmem + (outdim_den_aligned + SKEW) * ONEBATCH_SIZE;
    //////////////////////////////////////////////////////// Density MLP //////////////////////////////////////////////////////////////////////////
    
    // 入力のロード
    // エンコーダー無しの場合はこの時点でSKEWを与える
    // エンコーダーありの場合はSKEWを与えない(ただのコピー)
    if (Encoder_den == Encoder::None) {
        load_input(indim_den, indim_den_aligned + SKEW, Pos, intermediate_den);
    }
    else {
        load_input(3, 3, Pos, intermediate_den);
    }

    // Encode
    Encode<indim_den_aligned>(Encoder_den, *EncInfo_den,intermediate_den, false, SamplingInfo->AABB_pMin, SamplingInfo->AABB_pMax);
    // Density MLP
    Kernel_Debug_train_forward<indim_den, hiddendim_den, outdim_den, nHiddenLayer_den>(BatchSize, ActHid_den, ActOut_den, intermediate_den, weights_den, buffers_den);
    // Density MLPの出力の1つめの要素をDensityとして保存
    store_intermediate<float>(outdim_den_aligned + SKEW, 1, intermediate_den, Density);

    //////////////////////////////////////////////////////// END: Density MLP //////////////////////////////////////////////////////////////////////
    __syncthreads(); // 必要
    //////////////////////////////////////////////////////// Color MLP //////////////////////////////////////////////////////////////////////////
    // 入力のロード
    // エンコーダー無しの場合はこの時点でSKEWを与える
    // エンコーダーありの場合はSKEWを与えない(ただのコピー)
    if (Encoder_col == Encoder::None) {
        load_input(indim_col - outdim_den, indim_col_aligned - outdim_den_aligned + SKEW, Dir, intermediate_dirInput);
    }
    else {
        load_input(3, 3, Dir, intermediate_dirInput);
    }


    // Encode
    Encode<indim_col_aligned - outdim_den_aligned>(Encoder_col, *EncInfo_col, intermediate_dirInput, false, SamplingInfo->AABB_pMin, SamplingInfo->AABB_pMax);
    
    __syncthreads();

    // Density MLPの出力をconcatする.
    ConcatShmem<outdim_den, indim_col - outdim_den>(intermediate_den, intermediate_dirInput, intermediate_col);
    __syncthreads();

    // Color MLP
    Kernel_Debug_train_forward<indim_col, hiddendim_col, outdim_col, nHiddenLayer_col>(BatchSize, ActHid_col, ActOut_col, intermediate_col, weights_col, buffers_col);
    // 結果を保存
    store_intermediate<float>(outdim_col_aligned + SKEW, 3, intermediate_col, Color);
    //////////////////////////////////////////////////////// END: Density MLP //////////////////////////////////////////////////////////////////////
}

特に複雑なことはしていないので軽く見ていきましょう.

template <const uint32_t indim_den, const uint32_t hiddendim_den, const uint32_t outdim_den, const uint32_t nHiddenLayer_den, 
                 const uint32_t indim_col, const uint32_t hiddendim_col, const uint32_t outdim_col, const uint32_t nHiddenLayer_col>

この関数ではdenとついているものはDensity MLPに関わる値で,colとついているのはColor MLPに関わる値です.indimは入力層の次元,hiddendimは隠れ層の次元,outdimは出力層の次元,nHiddenLayerは隠れ層の数です.

...
__global__ void MFFM_NeRF_Forward
(
    NeRFInfo* SamplingInfo, EncoderInfo* EncInfo_den, EncoderInfo* EncInfo_col,
    Encoder Encoder_den, Activation ActHid_den, Activation ActOut_den,
    Encoder Encoder_col, Activation ActHid_col, Activation ActOut_col,
    float* Pos, float* Dir, 
    __half* weights_den, __half* buffers_den, 
    __half* weights_col, __half* buffers_col, 
    float* Density, float* Color) 
{
const int BatchSize = SamplingInfo->nSamples;
...

SamplingInfo: サンプリング処理で出てきたものと同じです.この構造体に「NN実行時のバッチサイズ」と「NeRFを生成するAABBの情報」が保存されているのでそれを読み込むために使用します.最後の行でBatchSizeに読み込ませているのが分かると思います.
EncInfoというのはエンコーダーに関する情報が入っていますが,まだ実装途中(Part2で軽く触れた内容です)なので無視してください.今回の実装では考えなくていいです.
Encoder, Activationというのは次に示すものです

enum class Encoder {
    None,
    Frequency,
    SH, // Spherical Harmonic
    HashGrid, // Multiresolution Hash Encoding
    UniformGrid
};

enum class Activation {
    ReLU,
    LeakyReLU,
    Sigmoid
};

Pos, Dirはサンプリング処理で得られたサンプル点の情報です.
weightsは全結合層のパラメーターです.
bufferは逆伝播時に使用する順伝播時の各層の入出力を保存しておくための配列です. Density, ColorはNNの出力(「密度」「色」)を記録するための配列です.

...
// 即値
constexpr uint32_t indim_den_aligned = next_multiple(indim_den, TENSOR_ROW);
...
const uint32_t maxdim_den = max(indim_den_aligned, max(hiddendim_den_aligned, outdim_den_aligned));
...

即値を計算していますが詳しい話はPart1を参照してください.

extern __shared__ __half shmem[];
__half* intermediate_col = shmem;
__half* intermediate_den = shmem;
__half* intermediate_dirInput = shmem + (outdim_den_aligned + SKEW) * ONEBATCH_SIZE;

1行目はshared memoryを動的に確保している部分です.処理中の特徴ベクトルであるintermediateをshared memoryに載せてやるという意図です.実際のところDensity MLPとColor MLPでは同じポインタを使用しますが,実装上区別したいので異なる名前を付けています.なおintermediate_dirInputについてですが,サンプル点の方向ベクトルはDensity MLPの処理を終えた後にshared memoryにロードします.Density MLPの出力がshared memoryに載ったままロードするので,Density MLPの出力を破壊しないように後ろ側の空きスペースにロードします.

//////////////////////////////////////////////////////// Density MLP //////////////////////////////////////////////////////////////////////////

// 入力のロード
// エンコーダー無しの場合はこの時点でSKEWを与える
// エンコーダーありの場合はSKEWを与えない(ただのコピー)
if (Encoder_den == Encoder::None) {
    load_input(indim_den, indim_den_aligned + SKEW, Pos, intermediate_den);
}
else {
    load_input(3, 3, Pos, intermediate_den);
}

// Encode
Encode<indim_den_aligned>(Encoder_den, *EncInfo_den,intermediate_den, false, SamplingInfo->AABB_pMin, SamplingInfo->AABB_pMax);
// Density MLP
Kernel_Debug_train_forward<indim_den, hiddendim_den, outdim_den, nHiddenLayer_den>(BatchSize, ActHid_den, ActOut_den, intermediate_den, weights_den, buffers_den);
// Density MLPの出力の1つめの要素をDensityとして保存
store_intermediate<float>(outdim_den_aligned + SKEW, 1, intermediate_den, Density);

//////////////////////////////////////////////////////// END: Density MLP //////////////////////////////////////////////////////////////////////

Density MLPの処理です.今回は必ずエンコーダーを通すので最初のif文はelseを通ります.現状実装しているエンコーダーは全て入力次元を3と仮定しているのでここでは3次元の入力としてハードコードしてロードしてますが,変更を加える予定です.load_input()はPart1を参照してください.やっていることとしては,3次元の入力(座標)をintermediate(shared memory)にロードしているだけです.3次元なのでバンクコンフリクトを避けるためのSKEWは入れません.つまりただのコピーです.
Encodeというのは次の関数です.

template <const uint32_t indim_aligned>
MFFM_DEVICE void Encode(Encoder encoder, EncoderInfo &Info, __half* intermediate, bool isInference, const float3 InputRangeMin = { 0,0,0 }, const float3 InputRangeMax = { 1,1,1 }, float* MHEBufferSTU = nullptr, unsigned int* MHEBufferIdxHT = nullptr) {
    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    switch (encoder) {
    case Encoder::None:
        // NOP
        break;
    case Encoder::Frequency:
        // not implemented...
        break;
    case Encoder::SH:
        SH::Encode_SH_L4(intermediate, intermediate);
        break;
    case Encoder::HashGrid:
        MHE::Encode<indim_aligned>(InputRangeMin, InputRangeMax, intermediate, intermediate);
        //MHE2::Encode_inside_kernel<MHE_F>(*Info.MHEInfo, InputRangeMin, InputRangeMax, intermediate, intermediate);
        break;
    case Encoder::UniformGrid:
        UGE::Encode(*Info.UGEInfo, intermediate, intermediate);
        break;
    default:
        printf("Invalid Encoding type\n");
        break;
    }
    __syncthreads();
}

今回はMultiresolution Hash Encodingを使用するのでHashGridの部分に入ります.Part2に示したEncode関数を使用します.SamplingInfoに保存している「NeRFを生成するAABBの情報」をMultiresolution Hash Encodingのエンコード領域に設定しているのが分かると思います.Part2を参照すれば具体的に何をしているかが分かるはずです.

さて,エンコードが終わったのでMLPに入力します.関数名がかなり酷いですがKernel_Debug_train_forward()がこの処理をしています.詳しい処理はPart1にあります.(Debug用の関数が最新版に成り果てました……) MLPの処理が終わったので1要素目をDensityに関わる値として保存しておきます.正確にはここでExp Activationをした方が良いのですが,今回は1要素目だけをActivationする特殊な例として扱い,ボリュームレンダリングを行う際にExp Activationしてあげることにします.store_intermediateを使用して保存してあげます.この関数もPart1に書かれています.

__syncthreads(); // 必要

ブロックレベルでの同期を取りましょう.
Color MLPも基本的に同じことをしますが,先程の図に示した結合処理を行っておく必要があります.

...
// Density MLPの出力をconcatする.
ConcatShmem<outdim_den, indim_col - outdim_den>(intermediate_den, intermediate_dirInput, intermediate_col);
...

Color MLPの入力次元はDensity MLPの出力次元とエンコードされた方向ベクトルの次元の和であるので,言い換えるとoutdim_den次元のベクトル(Density MLPの出力)と indim_col - outdim_den次元のベクトル(エンコードされた方向ベクトル)の結合を行うことになる,というのがtemplate引数の部分です.具体的にこの関数の中身を見ましょう

/////////////////////////////// CONCAT AND DIVIDE ///////////////////////////////////////////////////
/*
 * 1: shmem1
 * 2: shmem2
 * s: SKEW
 * 1s1s1s1s1s...1s1s1s1s2s2s2s2s2s2s...2s2s2s2s → 12s12s12s12s12s12s....12s12s12s12s12s
 * 番地的にはshmem = shmem1
 */
template <const uint32_t dim1, const uint32_t dim2>
MFFM_DEVICE void ConcatShmem(__half* shmem1, __half* shmem2, __half* shmem) {

    // 即値
    constexpr uint32_t dim1_aligned = next_multiple(dim1, TENSOR_ROW);
    constexpr uint32_t dim2_aligned = next_multiple(dim2, TENSOR_ROW);
    constexpr uint32_t newDim = dim1 + dim2;
    constexpr uint32_t newDim_aligned = next_multiple(newDim, TENSOR_ROW);

    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    const int warp_index_required = ONEBATCH_SIZE / 32;

    if (ty >= warp_index_required) {
        return;
    }

    float tmp[dim1 + dim2];
    __syncthreads();
    for (int i = 0; i < dim1; i++) {
        tmp[i] = shmem1[(dim1_aligned+SKEW) * 32 * ty + (dim1_aligned+SKEW) * tx + i];
    }
    for (int i = 0; i < dim2; i++) {
        tmp[dim1 + i] = shmem2[(dim2_aligned+SKEW) * 32 * ty + (dim2_aligned+SKEW) * tx + i];
    }
    __syncthreads();

    for (int i = 0; i < dim1 + dim2; i++) {
        shmem[(newDim + SKEW) * 32 * ty + (newDim + SKEW) * tx + i] = (__half)tmp[i];
    }
    for (int i = 0; i < SKEW; i++) {
        shmem[(newDim + SKEW) * 32 * ty + (newDim + SKEW) * tx + i + dim1 + dim2] = (__half)0.0f;
    }
    __syncthreads();
}

Part1を読んだ方であれば何をしているのか分かると思います.複雑そうに見えますが単純にconcatしているだけです.結合前の2つのベクトルは共に[要素][空白][要素][空白]......の構造を取っており,それを[要素1][要素2][空白][要素1][要素2][空白]......と組み直すのがこの関数の処理です.

NeRFの実装: (5)~(7): ボリュームレンダリングの順方向と逆方向

一気に逆方向まで処理してしまいます.最初に示したボリュームレンダリングをやるという話です.NNを通した結果,「色」と「密度」が得られています.説明の時に記した図を再掲します.

この式に代入するだけで計算できます.そして,NNへの逆伝播を行うために,「色」と「密度」への勾配の値を計算する必要があります.詳しい導出はしませんが,勾配は次の式で行えます.

 
\begin{align}
&\dfrac{\partial L}{\partial C_i} = T_i (1 - \exp(-\sigma_i\delta_i)) \dfrac{\partial L}{\partial Radiance}  \\
&\dfrac{\partial L}{\partial \sigma_i} = \delta_i (T_{i+1} * C_i - S_i) \dfrac{\partial L}{\partial Radiance} \\
&S_i = Radiance - \sum_{k=1}^{i} Radiance_k \\
\end{align}
ここで Radiance_kはk番目のサンプル点による最終的な放射輝度への寄与です.

後はこれを実装するだけです.やりましょう.

__global__ void CalculateOutputAndLoss
(
    const uint32_t nPixelBatch, NeRFInfo* SamplingInfo, 
    float* Pos,
    NeRFRay* Rays, float* Density, float* Color, 
    float* Output, PixelInfo* Target, float* Loss, bool NoLossCalc = false
) 
{
    uint32_t threadID = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadID >= nPixelBatch) {
        return;
    }

    // 本スレッドでのポインタ等
    NeRFRay* Ray = &Rays[threadID];
    const uint32_t IdxBias = Ray->SampleBeginIdx;
    const uint32_t nSample = Ray->nSample;
    float* pos = Pos + 3 * IdxBias;
    float* c = Color + 3 * IdxBias;
    float* sigma = Density + IdxBias;
    float* out = Output + 3 * threadID;
    Vec3h ans = Target[threadID].Color;
    float* err = Loss + 3 * threadID;
    // サンプル点の厚み(次のサンプル点までの距離)
    float dt;
    // Throughput
    float T = 1.0f;

    for (int rgb = 0; rgb < 3; rgb++) {
        out[rgb] = 0.0f;
        if (!NoLossCalc) {
            err[rgb] = 0.0f;
        }
    }

    // Sigma Activation
    for (int s = 0; s < nSample; s++) {
        NeRF_SigmaExpActivation(sigma[s]); // exp activation
    }

    // Calc Output
    for (int s = 0; s < nSample; s++) {
        
        if (s != nSample - 1) {
            Vec3 ToNextSample =
            {
                pos[3 * (s + 1) + 0] - pos[3 * (s + 0) + 0],
                pos[3 * (s + 1) + 1] - pos[3 * (s + 0) + 1],
                pos[3 * (s + 1) + 2] - pos[3 * (s + 0) + 2],
            };
            dt = ToNextSample.length();
        }
        else {
            dt = 0.0f; // 最後のサンプルの厚みを0とする NOTE: これが正しいのかは分からない
        }

        // Radiance_i = T_i * c_i * (1 - exp(-sigma_i*dt_i)) 
        for (int rgb = 0; rgb < 3; rgb++) {
            out[rgb] += T * c[3*s + rgb] * (1.0f - expf(-sigma[s] * dt));
        }

        T *= expf(-sigma[s] * dt);
    }

    if (nSample == 0) {
        for (int rgb = 0; rgb < 3; rgb++) {
            out[rgb] = 0.0f;
            err[rgb] = 0.5f * (out[rgb] - (float)ans[rgb]) * (out[rgb] - (float)ans[rgb]);
        }
        return;
    }

    if (NoLossCalc) {
        return;
    }

    float dLdout[3] = {0.0f,0.0f,0.0f};

    // Calc Loss
    for (int rgb = 0; rgb < 3; rgb++) {
        Calc_HuborLoss(out[rgb], (float)ans[rgb], dLdout[rgb]);
    }

    // MSE
    for (int rgb = 0; rgb < 3; rgb++) {
        err[rgb] = 0.5f * (out[rgb] - (float)ans[rgb]) * (out[rgb] - (float)ans[rgb]);
    }

    // Calc dLdColor and dLdDensity
    float dLdc[3] = {0,0,0};
    float dLdsigma = 0.0f;
    float suffix[3] = {out[0], out[1], out[2]}; 

    T = 1.0f;

    for (int s = 0; s < nSample; s++) {
        if (s != nSample - 1) {
            Vec3 ToNextSample =
            {
                pos[3 * (s + 1) + 0] - pos[3 * (s + 0) + 0],
                pos[3 * (s + 1) + 1] - pos[3 * (s + 0) + 1],
                pos[3 * (s + 1) + 2] - pos[3 * (s + 0) + 2],
            };
            dt = ToNextSample.length();
        }
        else {
            dt = 0.0f; // 最後のサンプルの厚みを0とする NOTE: これが正しいのかは分からない
        }

        // dLdc_i = T_i * (1 - exp(-sigma_i*dt_i)) * dLdout
        for (int rgb = 0; rgb < 3; rgb++) {
            dLdc[rgb] = T * (1.0f - expf(-sigma[s] * dt)) * dLdout[rgb];
        }
        // Suffix_i := C_hat - C_1 - C_2 - ... - C_i
        for (int rgb = 0; rgb < 3; rgb++) {
            suffix[rgb] -= T * c[3 * s + rgb] * (1.0f - expf(-sigma[s] * dt));
        }
        // T_{i+1}
        T *= expf(-sigma[s] * dt);

        // dLdsigma = dt_i * (T_{i+1} * c_i - Suffix_i) * dLdout
        for (int rgb = 0; rgb < 3; rgb++) {
            dLdsigma += dt * (T * c[3*s + rgb] - suffix[rgb]) * dLdout[rgb];
        }
        // 保存
        for (int rgb = 0; rgb < 3; rgb++) {
            c[3*s + rgb] = dLdc[rgb];
            dLdc[rgb] = 0.0f;
        }
        NeRF_SigmaExpDerivative(dLdsigma, sigma[s]);
        sigma[s] = dLdsigma;
        dLdsigma = 0.0f;
    }
}

部分的にみましょう.

__global__ void CalculateOutputAndLoss
(
    const uint32_t nPixelBatch, NeRFInfo* SamplingInfo, 
    float* Pos,
    NeRFRay* Rays, float* Density, float* Color, 
    float* Output, PixelInfo* Target, float* Loss, bool NoLossCalc = false
) 
{
...

nPixelBatch: 使用したレイの本数(ピクセルの数)
SamplingInfo: サンプリング処理で保存した,そのレイに関わるサンプル点情報が「どこから」「何個」あるかを使用します
Pos: サンプル点間の距離がボリュームレンダリングに必要です
Rays, Density, Color: 名前の通りです.ただし,入力時は順方向のデータが入っているDensity, Colorには最終的に勾配のデータを格納して返します. Output: ボリュームレンダリングの計算結果のRGB色を格納します
Target: PixelInfo構造体には対応する教師データ(RGB色)が格納されています
Loss: 誤差情報を記録する配列です
NoLossCalc : 推論処理では誤差を計算する必要がありません.それを制御するフラグです.

// 本スレッドでのポインタ等
NeRFRay* Ray = &Rays[threadID];
const uint32_t IdxBias = Ray->SampleBeginIdx;
const uint32_t nSample = Ray->nSample;
float* pos = Pos + 3 * IdxBias;
float* c = Color + 3 * IdxBias;
float* sigma = Density + IdxBias;
float* out = Output + 3 * threadID;
Vec3h ans = Target[threadID].Color;
float* err = Loss + 3 * threadID;

配列におけるデータの内,処理中のスレッドで扱うレイに関するデータのポインタをあらかじめ計算しておくことで実装ミスが減ります.

// サンプル点の厚み(次のサンプル点までの距離)
float dt;
// Throughput
float T = 1.0f;

コメントの通りです.

for (int rgb = 0; rgb < 3; rgb++) {
    out[rgb] = 0.0f;
    if (!NoLossCalc) {
        err[rgb] = 0.0f;
    }
}

0初期化しておきましょう.推論でない場合は誤差も0初期化します.

// Sigma Activation
for (int s = 0; s < nSample; s++) {
    NeRF_SigmaExpActivation(sigma[s]); // exp activation
}

さて,NNの実装時には特殊であるため,レンダリング時にexp activationをするとしました.

MFFM_DEVICE void NeRF_SigmaExpActivation(float &sigma) {
    sigma = expf(clamp(sigma, -15.0f, 15.0f));
}

単純にexpを取っているだけです.なお,計算時の爆発を避けるため,指数部分を15で抑えています.

// Calc Output
for (int s = 0; s < nSample; s++) {
    
    if (s != nSample - 1) {
        Vec3 ToNextSample =
        {
            pos[3 * (s + 1) + 0] - pos[3 * (s + 0) + 0],
            pos[3 * (s + 1) + 1] - pos[3 * (s + 0) + 1],
            pos[3 * (s + 1) + 2] - pos[3 * (s + 0) + 2],
        };
        dt = ToNextSample.length();
    }
    else {
        dt = 0.0f; // 最後のサンプルの厚みを0とする NOTE: これが正しいのかは分からない
    }

    // Radiance_i = T_i * c_i * (1 - exp(-sigma_i*dt_i)) 
    for (int rgb = 0; rgb < 3; rgb++) {
        out[rgb] += T * c[3*s + rgb] * (1.0f - expf(-sigma[s] * dt));
    }

    T *= expf(-sigma[s] * dt);
}

ボリュームレンダリングの順方向です.
まずはサンプル点間の距離(先ほど示した図中の \delta_i)を計算しますが,最後のサンプル点については距離を定義できません.今回は最後のサンプル点については \delta_i = 0としました.
そして,図に示した式に代入してあげます.RGB独立に代入してあげていいです.

if (nSample == 0) {
    for (int rgb = 0; rgb < 3; rgb++) {
        out[rgb] = 0.0f;
        err[rgb] = 0.5f * (out[rgb] - (float)ans[rgb]) * (out[rgb] - (float)ans[rgb]);
    }
    return;
}

if (NoLossCalc) {
    return;
}

サンプル数が0の場合(そもそもAABBと交差しなかった,もしくはサンプリングをキャンセルした場合)には結果を{0,0,0}として返します.誤差計算はあんまり妥当ではないです.
推論処理である場合は今後の処理は不要なので返します.

float dLdout[3] = {0.0f,0.0f,0.0f};

// Calc Loss
for (int rgb = 0; rgb < 3; rgb++) {
    Calc_HuborLoss(out[rgb], (float)ans[rgb], dLdout[rgb]);
}

// MSE
for (int rgb = 0; rgb < 3; rgb++) {
   err[rgb] = 0.5f * (out[rgb] - (float)ans[rgb]) * (out[rgb] - (float)ans[rgb]);
}

ボリュームレンダリングの結果と教師画像のデータの誤差を計算します.誤差関数はHubor_Lossで固定してます.

// Calc dLdColor and dLdDensity
float dLdc[3] = {0,0,0};
float dLdsigma = 0.0f;
float suffix[3] = {out[0], out[1], out[2]}; 

T = 1.0f;

for (int s = 0; s < nSample; s++) {
    if (s != nSample - 1) {
        Vec3 ToNextSample =
        {
            pos[3 * (s + 1) + 0] - pos[3 * (s + 0) + 0],
            pos[3 * (s + 1) + 1] - pos[3 * (s + 0) + 1],
            pos[3 * (s + 1) + 2] - pos[3 * (s + 0) + 2],
        };
        dt = ToNextSample.length();
    }
    else {
        dt = 0.0f; // 最後のサンプルの厚みを0とする NOTE: これが正しいのかは分からない
    }

    // dLdc_i = T_i * (1 - exp(-sigma_i*dt_i)) * dLdout
    for (int rgb = 0; rgb < 3; rgb++) {
        dLdc[rgb] = T * (1.0f - expf(-sigma[s] * dt)) * dLdout[rgb];
    }
    // Suffix_i := Radiance_hat - Radiance_1 - Radiance_2 - ... - Radiance_i
    for (int rgb = 0; rgb < 3; rgb++) {
        suffix[rgb] -= T * c[3 * s + rgb] * (1.0f - expf(-sigma[s] * dt));
    }
    // T_{i+1}
    T *= expf(-sigma[s] * dt);

    // dLdsigma = dt_i * (T_{i+1} * c_i - Suffix_i) * dLdout
    for (int rgb = 0; rgb < 3; rgb++) {
        dLdsigma += dt * (T * c[3*s + rgb] - suffix[rgb]) * dLdout[rgb];
    }
    // 保存
    for (int rgb = 0; rgb < 3; rgb++) {
        c[3*s + rgb] = dLdc[rgb];
        dLdc[rgb] = 0.0f;
    }
    NeRF_SigmaExpDerivative(dLdsigma, sigma[s]);
    sigma[s] = dLdsigma;
    dLdsigma = 0.0f;
}

順方向と同じようにして,先程示した計算式に代入しているだけです.式中の S_iは実装中のSuffix_iのことです.NeRF_SigmaExpDerivativeはexp Activationの逆方向です.

MFFM_DEVICE void NeRF_SigmaExpDerivative(float& dLdsigma, float sigma) {
    dLdsigma = dLdsigma * sigma;
}

これでボリュームレンダリングの逆方向が終わりました.dLdcとdLdsigmaは処理が終わるごとに0に戻してあげてください.

NeRFの実装: (8): NNを誤差逆伝播

さて,ボリュームレンダリングの逆方向の処理を行い,NNには「色」と「密度」に関する勾配情報が入ってきました.ということで,この勾配情報を元に誤差逆伝播を行います.InstantNGPの論文にて推されているNNの構造を再掲します.

順方向時にはDensity MLP -> Color MLPの順番で処理をしたので逆伝播時には逆に処理します.

template <const uint32_t indim_den, const uint32_t hiddendim_den, const uint32_t outdim_den, const uint32_t nHiddenLayer_den,
          const uint32_t indim_col, const uint32_t hiddendim_col, const uint32_t outdim_col, const uint32_t nHiddenLayer_col>
__global__ void MFFM_NeRF_Backward
(
    const int epoch,
    NeRFInfo* SamplingInfo, EncoderInfo* EncInfo_den, EncoderInfo* EncInfo_col,
    Optimize Optim,
    Encoder Encoder_den, Activation ActHid_den, Activation ActOut_den,
    Encoder Encoder_col, Activation ActHid_col, Activation ActOut_col,
    float* Pos, float* Dir,
    __half* weights_den, __half* buffers_den,
    __half* weights_col, __half* buffers_col,
    float* dLdDensity, float* dLdColor,
    float* LossDerivativeSumALL_den, float* AdditionalParam_den,
    float* LossDerivativeSumALL_col, float* AdditionalParam_col,
    float* dLdInput_den = nullptr
)
{
    const int BatchSize = SamplingInfo->nSamples;

    // 即値
    constexpr uint32_t indim_den_aligned = next_multiple(indim_den, TENSOR_ROW);
    constexpr uint32_t hiddendim_den_aligned = next_multiple(hiddendim_den, TENSOR_ROW);
    constexpr uint32_t outdim_den_aligned = next_multiple(outdim_den, TENSOR_ROW);
    constexpr uint32_t indim_col_aligned = next_multiple(indim_col, TENSOR_ROW);
    constexpr uint32_t hiddendim_col_aligned = next_multiple(hiddendim_col, TENSOR_ROW);
    constexpr uint32_t outdim_col_aligned = next_multiple(outdim_col, TENSOR_ROW);

    const uint32_t maxdim_col = max(indim_col_aligned, max(hiddendim_col_aligned, outdim_col_aligned));

    extern __shared__ __half shmem[];
    __half* intermediate_den = shmem;
    __half* intermediate_col = shmem;
    __half* intermediate_InDir = shmem + (outdim_den_aligned + SKEW) * ONEBATCH_SIZE;
    __half* LossDerivativeSumOfBlock = shmem + (maxdim_col + SKEW) * ONEBATCH_SIZE;

    ////////////////////////////////////////////// Color MLP ////////////////////////////////////////////////////////////////////////

    const int newDim = outdim_col_aligned + SKEW;
    load_input(outdim_col, newDim, dLdColor, intermediate_col);

    Kernel_Debug_train_backward<indim_col, hiddendim_col, outdim_col, nHiddenLayer_col>
    (
        BatchSize, Optim, ActHid_col, ActOut_col, epoch, 
        intermediate_col, LossDerivativeSumOfBlock, weights_col, buffers_col, LossDerivativeSumALL_col, AdditionalParam_col, false
    );

    // Divide
    DivideShmem<outdim_den, indim_col - outdim_den>(intermediate_den, intermediate_InDir, intermediate_col);
    
    // 
    // 入力をロード
    if (Encoder_col == Encoder::HashGrid || Encoder_den == Encoder::UniformGrid) {
        load_input(3, 3, Dir, intermediate_InDir + (indim_col_aligned + SKEW) * ONEBATCH_SIZE);
    }

    BackPropEncoder<indim_col_aligned>(epoch, Encoder_col, *EncInfo_col, intermediate_InDir, SamplingInfo->AABB_pMin, SamplingInfo->AABB_pMax);

    ////////////////////////////////////////////// END: Color MLP ////////////////////////////////////////////////////////////////////
    
    // この時点で破壊が許容される領域:
    // ・LossDerivativeSumOfBlock
    // ・intermediate_InDir
    // ・
    //
    
    // dLdSigmaの誤差伝搬
    const int threadID_BlockScope = 32 * threadIdx.y + threadIdx.x;
    const int threadID_global = ONEBATCH_SIZE * blockIdx.x + threadID_BlockScope;
    if (32 * threadIdx.y + threadIdx.x < ONEBATCH_SIZE) {
        intermediate_den[(outdim_den_aligned + SKEW) * threadID_BlockScope] = intermediate_den[(outdim_den_aligned + SKEW) * threadID_BlockScope] + (__half)dLdDensity[threadID_global];
    }

    __syncthreads();
    ////////////////////////////////////////////// Density MLP ///////////////////////////////////////////////////////////////////////
    
    Kernel_Debug_train_backward<indim_den, hiddendim_den, outdim_den, nHiddenLayer_den>
        (
            BatchSize, Optim, ActHid_den, ActOut_den, epoch,
            intermediate_den, LossDerivativeSumOfBlock, weights_den, buffers_den, LossDerivativeSumALL_den, AdditionalParam_den, false
        );

    // 入力をロード
    if (Encoder_den == Encoder::HashGrid || Encoder_den == Encoder::UniformGrid) {
        load_input(3, 3, Pos, intermediate_den + (indim_den_aligned + SKEW) * ONEBATCH_SIZE);
    }

    BackPropEncoder<indim_den_aligned>(epoch, Encoder_den, *EncInfo_den, intermediate_den, SamplingInfo->AABB_pMin, SamplingInfo->AABB_pMax);
    ////////////////////////////////////////////// END: Density MLP //////////////////////////////////////////////////////////////////

    // Density MLPの入力側の誤差を保存する
    if (dLdInput_den != nullptr) {
        store_intermediate<float>(indim_den_aligned + SKEW, indim_den, intermediate_den, dLdInput_den);
    }
}

やるだけと言ってしまえばそうなのですが,軽く説明しておきます.

template <const uint32_t indim_den, const uint32_t hiddendim_den, const uint32_t outdim_den, const uint32_t nHiddenLayer_den,
                  const uint32_t indim_col, const uint32_t hiddendim_col, const uint32_t outdim_col, const uint32_t nHiddenLayer_col>

順伝播時と同じです.("in", "out"は順伝播時の定義と同じです)

__global__ void MFFM_NeRF_Backward
(
    const int epoch,
    NeRFInfo* SamplingInfo, EncoderInfo* EncInfo_den, EncoderInfo* EncInfo_col,
    Optimize Optim,
    Encoder Encoder_den, Activation ActHid_den, Activation ActOut_den,
    Encoder Encoder_col, Activation ActHid_col, Activation ActOut_col,
    float* Pos, float* Dir,
    __half* weights_den, __half* buffers_den,
    __half* weights_col, __half* buffers_col,
    float* dLdDensity, float* dLdColor,
    float* LossDerivativeSumALL_den, float* AdditionalParam_den,
    float* LossDerivativeSumALL_col, float* AdditionalParam_col,
    float* dLdInput_den = nullptr
)

epoch: イテレーション回数
SamplingInfo: 順伝播時の説明参照
EncInfo: 無視してください
Optimize: 最適化関数の設定です.

enum class Optimize {
    GD,
    Adam
};

LossDerivativeSumALL, AdditionalParamは詳しい話はPart1, Part2を参照してください(パラメーターの勾配とAdamのmとvです)
dLdInput_den: 無視してください(入力側への勾配伝搬です)

後は順伝播の逆向きをやるだけです.基本的に実装を読めばそのまんまの処理をしていることが分かるはずです.

...
////////////////////////////////////////////// Color MLP ////////////////////////////////////////////////////////////////////////

const int newDim = outdim_col_aligned + SKEW;
load_input(outdim_col, newDim, dLdColor, intermediate_col);

Kernel_Debug_train_backward<indim_col, hiddendim_col, outdim_col, nHiddenLayer_col>
(
    BatchSize, Optim, ActHid_col, ActOut_col, epoch, 
    intermediate_col, LossDerivativeSumOfBlock, weights_col, buffers_col, LossDerivativeSumALL_col, AdditionalParam_col, false
);

// Divide
DivideShmem<outdim_den, indim_col - outdim_den>(intermediate_den, intermediate_InDir, intermediate_col);
...

「色」の勾配をshared memoryにロードし(load_input()),MLP誤差逆伝播をし(Kernel_Debug_train_backward()),そして順伝播時にはColor MLPの入力を「Density MLPの出力」に「エンコードしたレイの方向ベクトル」を結合したものなので,ここでばらしてあげます

/*
 * 1: shmem1
 * 2: shmem2
 * s: SKEW
 * 12s12s12s12s12s12s....12s12s12s12s12s -> 1s1s1s1s1s...1s1s1s1s2s2s2s2s2s2s...2s2s2s2s
 * 番地的にはshmem = shmem1
 */
template <const uint32_t dim1, const uint32_t dim2>
MFFM_DEVICE void DivideShmem(__half* shmem1, __half* shmem2, __half* shmem) {

    // 即値
    constexpr uint32_t dim1_aligned = next_multiple(dim1, TENSOR_ROW);
    constexpr uint32_t dim2_aligned = next_multiple(dim2, TENSOR_ROW);
    constexpr uint32_t nowDim = dim1 + dim2;
    constexpr uint32_t nowDim_aligned = next_multiple(nowDim, TENSOR_ROW);

    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    const int warp_index_required = ONEBATCH_SIZE / 32;

    if (ty >= warp_index_required) {
        return;
    }

    float tmp[dim1 + dim2];
    __syncthreads();

    for (int i = 0; i < dim1 + dim2; i++) {
        tmp[i] = shmem[(nowDim_aligned+SKEW) * 32 * ty + (nowDim_aligned + SKEW) * tx + i];
    }
    __syncthreads();

    for (int i = 0; i < dim1; i++) {
        shmem1[(dim1_aligned + SKEW) * 32 * ty + (dim1_aligned + SKEW) * tx + i] = (__half)tmp[i];
    }
    for (int i = 0; i < dim2; i++) {
        shmem2[(dim2_aligned + SKEW) * 32 * ty + (dim2_aligned + SKEW) * tx + i] = (__half)tmp[i + dim1];
    }
    for (int i = 0; i < SKEW; i++) {
        shmem1[(dim1_aligned + SKEW) * 32 * ty + (dim1_aligned + SKEW) * tx + dim1 + i] = (__half)0.0f;
        shmem2[(dim2_aligned + SKEW) * 32 * ty + (dim2_aligned + SKEW) * tx + dim2 + i] = (__half)0.0f;
    }
    __syncthreads();
}

Concatの逆の処理をしているだけなので説明は省略します.さて,今回は方向ベクトルのエンコーダー誤差逆伝播は不要ですが,仮に必要な場合があったとして,その場合は処理を行う必要があります.そこで

...
// 入力をロード
if (Encoder_col == Encoder::HashGrid || Encoder_den == Encoder::UniformGrid) {
    load_input(3, 3, Dir, intermediate_InDir + (indim_col_aligned + SKEW) * ONEBATCH_SIZE);
}

BackPropEncoder<indim_col_aligned>(epoch, Encoder_col, *EncInfo_col, intermediate_InDir, SamplingInfo->AABB_pMin, SamplingInfo->AABB_pMax);
...

例えばMultiresolution Hash Encodingの誤差逆伝播時には入力データを必要とするので,ロードしてあげます.BackPropEncoder()においてエンコーダー誤差逆伝播を行います.

template <const uint32_t indim_aligned>
MFFM_DEVICE void BackPropEncoder(const int epoch, Encoder encoder, EncoderInfo& Info, __half* intermediate, const float3 InputRangeMin = { 0,0,0 }, const float3 InputRangeMax = { 1,1,1 },
    float* MHEBufferSTU = nullptr, unsigned int* MHEBufferIdxHT = nullptr) {
    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    switch (encoder) {
    case Encoder::None:
        // NOP
        break;
    case Encoder::Frequency:
        // NOP
        break;
    case Encoder::SH:
        // NOP
        break;
    case Encoder::HashGrid:
        MHE::Propagate_backward(InputRangeMin, InputRangeMax, intermediate, intermediate + (indim_aligned + SKEW) * ONEBATCH_SIZE, intermediate + (indim_aligned + SKEW) * ONEBATCH_SIZE + 3 * ONEBATCH_SIZE);
        //MHE2::Propagate_backward_inside_kernel<MHE_F>(*Info.MHEInfo, InputRangeMin, InputRangeMax, intermediate, intermediate + (indim_aligned + SKEW) * ONEBATCH_SIZE, intermediate + (indim_aligned + SKEW) * ONEBATCH_SIZE + 3 * ONEBATCH_SIZE);
        break;
    case Encoder::UniformGrid:
        UGE::BackPropagate(*Info.UGEInfo, intermediate + (indim_aligned + SKEW) * ONEBATCH_SIZE, intermediate);
        break;
    default:
        printf("Invalid Encoding type\n");
        break;
    }
    __syncthreads();
}

Multiresolution Hash EncodingのPropagate_backward()はPart2に示しました.(詳しくはPart1,2を参照してください)あとはこれの組み合わせです.これでNNの誤差逆伝播が終わり,あとはパラメーターを最適化するだけです.

NeRFの実装: (9): NN最適化

説明は省略します.Part1, Part2を参照してください.

// 最適化のみ
template <const uint32_t indim, const uint32_t hiddendim, const uint32_t outdim, const uint32_t nHiddenLayer>
__global__ void MFFM_Optimize(bool willOptimizeInsideOneKernel, const uint32_t BatchSize, EncoderInfo* EncInfo, Encoder Encoder, Optimize Optim, const int epoch, __half* weights, float* LossDerivative, float* AdditionalParam) {
    Optimize_MLP<indim, hiddendim, outdim, nHiddenLayer>(willOptimizeInsideOneKernel, BatchSize, Optim, epoch, weights, LossDerivative, AdditionalParam);
    OptimizeEncoder(Encoder, *EncInfo, Optim, BatchSize, epoch);
}
MFFM_DEVICE void OptimizeEncoder(Encoder encoder, EncoderInfo &Info, Optimize Optim, const uint32_t BatchSize, const int epoch) {
    switch (encoder) {
    case Encoder::None:
        // NOP
        break;
    case Encoder::Frequency:
        // NOP
        break;
    case Encoder::SH:
        // NOP
        break;
    case Encoder::HashGrid:
        MHE::Optimization(BatchSize, Optim, epoch);
        //MHE2::Optimization(Info.MHEInfo, BatchSize, Optim, epoch);
        break;
    case Encoder::UniformGrid:
        UGE::Optimization(epoch, BatchSize, Optim, *Info.UGEInfo);
        break;
    default:
        printf("Invalid Encoding type\n");
        break;
    }
    __syncthreads();
}

NeRFの実装: (10): Occupancy Gridの最適化

さて,サンプリング処理にも出てきましたが,まずそもそもOccupancy Gridとはなんなんだっていう話です.InstantNGPの論文のAppendix-E.2に書かれているのですが,簡単に言うと「空間に「物体」が存在するかどうかの情報を保持するボクセル」です.2種類のOccupancy Gridがあり,一つは「「物体」が存在するか」という01の情報で,もう一つは「どれくらいの「密度」であるか」というfloatの情報です.前者をBit_OccupancyGridとし,後者をFloat_OccupancyGridとします.どういう設計かの概念を図にしておきます.

グリッド内部のあるランダムに選んだサンプル点における「密度」を計算し,それをFloat_OccupancyGridに反映させた結果,左上のようなFloat_OccupancyGridが得られて,そしてそれを元にして何かしらの判定を行った結果その右に示すBit_OccupancyGridが得られたとします.そして,サンプリング処理で以下の処理がありましたね

...
// Occupancy Gridにおいて「物体が無い」と判断された場合はスキップする
if (epoch > 16 && !Grids[0].is_Occupied(s_Pos)) {
    index++;
    continue;
}
...

レイ上の点をサンプリングした際にそのサンプル点が果たしてどれだけ不透明であるかをOccupancyGridを用いて判定します.これによって弾かれたサンプル点が図中の青い点で,弾かれずに採用されたサンプル点が橙の点で表されています.このようにすることで,「有効なサンプル点」を増やすことが出来てサンプル効率の向上につながる,というのが簡単なお気持ちです.何もないところサンプリングしても何も得られませんので(屈折率とかの場を作ってそれに応じて屈折させてみるのも面白そうですけどね).

さて,じゃあOccupancy Gridをどうやって構築してあげたらいいんだって話です.InstantNGPの論文が言うことには 16イテレーションごとにOccupancy Gridを更新する
(1): Float_OccupancyGridの各値を0.95倍する
(2): M個のグリッドを選んで.max(現在の値,グリッド内部のサンプル点をNNに入れた結果得られる「密度」の値)へと更新する
(3): Float_OccupancyGridの各値が \frac{0.01 * 1024}{\sqrt{3}}より小さければ対応するBit_OccupancyGridを0,そうでければ1にする
なお,(2)のMは最初の256イテレーションは全てのグリッド数だけ一様サンプリング,それ以降は全てのグリッド数の半分を,半分は全てのグリッドから一様サンプリングし,残り半分はBit_OccupancyGridの中身が1になっているグリッドから一様サンプリングする.

ということらしいです.ですが実装コストの面で少しだけ簡略化します.(2)のサンプリングを,サンプル数は同じくして,ただしイテレーション回数に関わらず全てのグリッドから一様サンプリングします(なお,棄却法を使えば論文通りの実装が出来ます).
また,(3)の通りに実装すると学習が不安定だったので,今回の実装では別の判定方法を取りました.

では,実装を見ていきましょう.

Occupancy Gridの構造体

実装を容易にするためにOccupancy Gridの構造体を作っておきます.

struct OccupancyGrid {
    int GridID;
    bool* BitGrid;
    float* FloatGrid;
    Vec3 pMin;
    Vec3 pMax;
    float ave;

    MFFM_HOST_DEVICE OccupancyGrid(const int GridID = 0, const Vec3 pMin = { 0,0,0 }, const Vec3 pMax = { 1,1,1 }) : GridID(GridID), pMin(pMin), pMax(pMax), ave(0.0f) {}

    MFFM_DEVICE unsigned int PosToIdx(Vec3& pos) {
        return CalcMortonCode7bit(pos, pMin, pMax);
    }
    MFFM_DEVICE bool& Bit_at(Vec3& pos) {
        return BitGrid[CalcMortonCode7bit(pos, pMin, pMax)];
    }
    MFFM_DEVICE float& Float_at(Vec3& pos) {
        return FloatGrid[CalcMortonCode7bit(pos, pMin, pMax)];
    }
    MFFM_DEVICE bool is_Occupied(Vec3& pos) {
        return BitGrid[CalcMortonCode7bit(pos, pMin, pMax)];
    }
};

GridID: 論文中では複数のOccupancyGridを作ることを考慮しており,今回の実装でも番号だけ振っておきます.なお今回は1つしか使用しません.
BitGrid, FloatGrid: Bit_OccypancyGrid と Float_OccupancyGridです.
pMin, pMax: Gridの拡がる空間を表します(AABBです). ave: FloatGridの各グリッドの中身の値の平均値です.先ほど示した(3)の判定に使用します.

初期化は次のように行います.

__global__ void SetUpOccupancyGrid(int ID, NeRFInfo* Info, OccupancyGrid* Grids, bool* BitGridPtr, float* FloatGridPtr) {
    Grids[ID].GridID = ID + 1;
    float power = powf(2, ID);

    Grids[ID].BitGrid = BitGridPtr;
    Grids[ID].FloatGrid = FloatGridPtr;
    Grids[ID].pMin = { Info->AABB_pMin.x * power, Info->AABB_pMin.y * power,Info->AABB_pMin.z * power };
    Grids[ID].pMax = { Info->AABB_pMax.x * power, Info->AABB_pMax.y * power,Info->AABB_pMax.z * power };
}

BitGridPtr, FloatGridPtrはカーネル外部でthrust::device_vectorで定義している配列のポインタです.
pMinとpMaxは論文通りに実装しましたが,今回は結局1つめのOccupancyGridのみ使用するのでGridの領域は[-1,1]^3となります.


さて,Occupancy Gridの中身は単純な3次元配列を1次元に直したような配列では持たずに,MortonCode(Z-order curve)に載せます.

ja.wikipedia.org

今回は一辺が128要素のOccupancyGridを使用するので,各方向7bitで処理すればよいです.MortonCodeを計算する関数を置いておきます.

// 1111 -> 1001001001
KGYK_HOST_DEVICE unsigned int ExpandBits(unsigned int v) {
    v = (v * 0x00010001u) & 0xFF0000FFu;
    v = (v * 0x00000101u) & 0x0F00F00Fu;
    v = (v * 0x00000011u) & 0xC30C30C3u;
    v = (v * 0x00000005u) & 0x49249249u;
    return v;
}

これは元のビット列に2マス隙間を開ける関数です.そして,これを用いて次のようにすることで各方向128分割したグリッドのインデックスを32bitで表現できます.

MFFM_DEVICE unsigned int CalcMortonCode7bit(Vec3 position, Vec3 pMin, Vec3 pMax) {

    unsigned long long int MortonCode;

    // positionを[0, 1]^3の範囲に圧縮する
    float x = normalize(position.x, pMin.x, pMax.x, 0.0f, 1.0f);
    float y = normalize(position.y, pMin.y, pMax.y, 0.0f, 1.0f);
    float z = normalize(position.z, pMin.z, pMax.z, 0.0f, 1.0f);

    x = min(max(x * 128, 0.0f), 127.0f);
    y = min(max(y * 128, 0.0f), 127.0f);
    z = min(max(z * 128, 0.0f), 127.0f);
    unsigned int xx = ExpandBits((unsigned int)x);
    unsigned int yy = ExpandBits((unsigned int)y);
    unsigned int zz = ExpandBits((unsigned int)z);
    MortonCode = xx * 4 + yy * 2 + zz;
    return MortonCode;
}

この関数自体は以前にLinear-BVHというデータ構造を実装した際に参考にした次の記事に載っていたものを元にしています.

developer.nvidia.com

このようにしてインデックスを定義してあげることで,グリッド内で近いところにある2ブロックはメモリ上で近いところとなる,というわけです.さて,OccupancyGridの要素へのアクセスは,(座標)をMortonCodeに変換し,そのMortonCodeをインデックスとして指定するだけです.それらの処理がPosToIdx(), Bit_at(), Float_at()となっています.そして先に書いておくと,is_Occupied()はBitGridの要素が1であるかを判定しているだけです(図に書いた通りの処理です).

では更新処理を書いてあげましょう.OccupancyGridの初期値は全て0です.

(1): すべて0.95培

// 全てのブロックの値を0.95倍する
// 全てのブロック数スレッドを立てる
__global__ void NeRF_DecayOccupancyGrid(const int K, OccupancyGrid* Grids) {
    uint32_t threadID = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadID >= K * 128 * 128 * 128) {
        return;
    }

    const uint32_t GridID = threadID / (128 * 128 * 128);
    const uint32_t BlockID = threadID % (128 * 128 * 128);
    Grids[GridID].FloatGrid[BlockID] *= 0.95f;
    if (BlockID == 0) {
        Grids[GridID].ave *= 0.95f;
    }
}

128 * 128 * 128 * Kだけスレッドを立てて処理します.(K: OccupancyGridの数.今回は1個としています)
単純に0.95掛けて行っているだけです.また,平均値aveにも0.95を掛けておきます(各Grid内のブロック1つ目を担当するスレッドが行います)

(2): M個サンプリングし,色々やる

この処理は中々めんどくさいです.ランダムに選ばれたM個のブロックにスレッドを割り振り,各スレッドが自分の担当するOccupancyGridにおけるあるブロック内部にあるランダムな点をサンプリングし,NNモデルを通して「密度」の値を得るのですが,M個のサンプリングをまずは行う必要があります.これは簡単のため一様に選ばせます.つまり1~128 * 128 * 128 * Kの順列をシャッフルして最初のM個を選択します.

...
// (2)
// M個のブロックを選ぶ(ここでは前半と後半で分けずに全て一様サンプリングする)
const uint32_t M = (epoch <= 256) ? nBlockALL : nBlockALL / 2;
...
static thrust::device_vector<int>   Permutaion(nBlockALL);
if (epoch == 16) {
    thrust::sequence(Permutaion.begin(), Permutaion.end());
}

if (epoch > 256) {
    thrust::default_random_engine g;
    g.seed(epoch + 1);
    thrust::shuffle(Permutaion.begin(), Permutaion.end(), g);
}
NeRF_SamplePositionInsideGrid <<<M / 512, 512, 0, Stream >>> (epoch, K, thrust::raw_pointer_cast(&Grids[0]), thrust::raw_pointer_cast(&Permutaion[0]), thrust::raw_pointer_cast(&Position[0]));
gpuErrchk(cudaGetLastError());
gpuErrchk(cudaStreamSynchronize(Stream));
...

Mの決定については論文に従います.また,nBlockALL = 128 * 128 * 128 * Kです.Permutation配列をシャッフルして最初のM個を選択することによりインデックスのサンプリングが完了します.なお,最初の256イテレーションは結局全部サンプリングするのでシャッフルしません

Occupancy Gridのブロックのサンプリングが終わったら今度はブロック内部の点をサンプリングします.ここはCUDAカーネルでM個並列に行います.これは単純に自分の担当しているOccupancyGridにおけるブロックが,NeRFを生成するAABB内部ではどの領域に対応しているのかを計算し,xyz各軸方向に対して乱数を用いて座標を選ぶというやり方で行きます.

// M個グリッドをサンプリングする
// M個のスレッドを立てる
// Indexにはサンプリングするグリッド上の座標が1次元のインデックスで記録されている
// よってグリッド上では(x,y,z)をIndexから計算する
// Indexに入っているインデックスを使用しても実際に(x,y,z)にはアクセスできないことに注意(MortonCodeでOccuoancyGridのインデックスが定義されているためである)
// そのため,サンプル点から得たdensityを元に更新するOccupancyGridのインデックスはサンプル点の座標から計算する
__global__ void NeRF_SamplePositionInsideGrid(const int seed, const int K, OccupancyGrid* Grids, const int* Index, float* Position) {
    uint32_t threadID = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadID >= K * 128 * 128 * 128) {
        return;
    }
    const int Idx = Index[threadID];

    const uint32_t GridID = Idx / (128 * 128 * 128);
    const uint32_t BlockID = Idx % (128 * 128 * 128);

    // サンプルするグリッドの最小座標
    Vec3 pMinGrid = Grids[GridID].pMin;
    Vec3 pMaxGrid = Grids[GridID].pMax;

    // 今回サンプルする領域
    const uint32_t IdxX = BlockID % 128;
    const uint32_t IdxY = (BlockID / 128) % 128;
    const uint32_t IdxZ = (BlockID / 128 / 128);

    // サンプルするブロックの一辺の長さ(Assumption: ブロックは立方体)
    const float Range = (pMaxGrid.x - pMinGrid.x) / 128;

    // サンプルするブロックの最小座標
    const float Xmin = pMinGrid.x + Range * IdxX;
    const float Ymin = pMinGrid.y + Range * IdxY;
    const float Zmin = pMinGrid.z + Range * IdxZ;

    // RNG
    curandState state;
    curand_init(seed, threadID, 0, &state);

    // (0,1]
    float t = curand_uniform(&state);

    // サンプルする座標
    const float s_PosX = Xmin + t * Range;
    const float s_PosY = Ymin + t * Range;
    const float s_PosZ = Zmin + t * Range;

    // MLP入力のthreadID番目にサンプル点の情報を登録
    Position[3 * threadID + 0] = s_PosX;
    Position[3 * threadID + 1] = s_PosY;
    Position[3 * threadID + 2] = s_PosZ;
}

コメントに書いているとおりです.ちなみに本当はこの実装はあまりよろしくなくて,使用するOccupancyGridの拡がっている領域が[-1, 1]^3であり,これはNeRFのAABBと(偶然!)一致しているため,特に正規化を行わずして直接サンプル点の座標の「密度」を評価できます.真似はしない方が良いです.

サンプル点の座標が得られたため,「密度」のみを推定するNNに入れておきます.

template <const uint32_t indim_den, const uint32_t hiddendim_den, const uint32_t outdim_den, const uint32_t nHiddenLayer_den>
__global__ void MFFM_NeRF_InferenceDensityOnly
(
    NeRFInfo* SamplingInfo, EncoderInfo* EncInfo_den,
    Encoder Encoder_den, Activation ActHid_den, Activation ActOut_den,
    float* Pos,
    __half* weights_den,
    float* Density)
{
    // 即値
    constexpr uint32_t indim_den_aligned = next_multiple(indim_den, TENSOR_ROW);
    constexpr uint32_t hiddendim_den_aligned = next_multiple(hiddendim_den, TENSOR_ROW);
    constexpr uint32_t outdim_den_aligned = next_multiple(outdim_den, TENSOR_ROW);

    const uint32_t maxdim_den = max(indim_den_aligned, max(hiddendim_den_aligned, outdim_den_aligned));

    extern __shared__ __half shmem[];
    __half* intermediate_den = shmem;
    //////////////////////////////////////////////////////// Density MLP //////////////////////////////////////////////////////////////////////////

    // 入力のロード
    // エンコーダー無しの場合はこの時点でSKEWを与える
    // エンコーダーありの場合はSKEWを与えない(ただのコピー)
    if (Encoder_den == Encoder::None) {
        load_input(indim_den, indim_den_aligned + SKEW, Pos, intermediate_den);
    }
    else {
        load_input(3, 3, Pos, intermediate_den);
    }

    // Encode
    Encode<indim_den_aligned>(Encoder_den, *EncInfo_den, intermediate_den, true, SamplingInfo->AABB_pMin, SamplingInfo->AABB_pMax);
    // Density MLP
    Kernel_Debug_inference<indim_den, hiddendim_den, outdim_den, nHiddenLayer_den>(ActHid_den, ActOut_den, intermediate_den, weights_den);
    // Density MLPの出力の1つめの要素をDensityとして保存
    store_intermediate<float>(outdim_den_aligned + SKEW, 1, intermediate_den, Density);
    //////////////////////////////////////////////////////// END: Density MLP //////////////////////////////////////////////////////////////////////
}

やっていることは通常のNNの順伝播の実装と同じです.また,バッチサイズは言うまでもありませんがMです.

これによって「密度」が推定されたのでこれを用いてFloat_OccupancyGridを更新します.

// サンプリングしたM個のブロックを更新する
// M個のスレッドを立てる
// Indexには単純な1次元化した配列のインデックスが格納されている: [x][y][z] -> [idx]
// それゆえ,OccupancyGridへのアクセスはサンプル点の座標から行う(MortonCodeでアクセス)
__global__ void NeRF_UpdateOccupancyGrid_FloatGrids(const int M, OccupancyGrid* Grids, const float* Density, const int* Index, const float* Pos) {
    uint32_t threadID = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadID >= M) {
        return;
    }
    Vec3 SampledPos = { Pos[3 * threadID],  Pos[3 * threadID + 1],  Pos[3 * threadID + 2] };
    int Idx = Index[threadID];
    const uint32_t GridID = Idx / (128 * 128 * 128);
    const float old = Grids[GridID].Float_at(SampledPos);
    Grids[GridID].Float_at(SampledPos) = max(old, Density[threadID]);

    atomicAdd(&Grids[GridID].ave, (Grids[GridID].Float_at(SampledPos) - old) / (128 * 128 * 128));

}

サンプリング時に使用したM個のOccupancyGridのブロックを更新します.処理自体は先ほど示した通り,現在の値と推定値のmaxに更新する,です.今思ったのですが,Exp Activationし忘れてましたね.これが論文通りの(3)が上手く動かない原因かもしれません.あとで試します.ここで更新差分を元にして,aveも更新しておきます.注意点として,Indexではなくて座標(Pos)からMortonCodeを計算してOccupancyGridにアクセスしないとダメです(OccupancyGridの構築の仕方を思い出してください)(1敗).以上でFloat_OccupancyGridの処理は終わりです.

(3): Bit_OccupancyGridの更新

これは簡単です.閾値を超えているか超えていないかで0と1にするだけです.

// 全てのグリッドを更新する
// 全てのグリッド数だけスレッドを立てる.
__global__ void NeRF_UpdateOccupancyGrid_BitGrids(const int K, OccupancyGrid* Grids) {
    uint32_t threadID = blockIdx.x * blockDim.x + threadIdx.x;
    if (threadID >= K * 128 * 128 * 128) {
        return;
    }

    const uint32_t GridID = threadID / (128 * 128 * 128);
    const uint32_t BlockID = threadID % (128 * 128 * 128);

    //constexpr float thres = 0.1f * 1024.0f / 1.732050807f;
    const float thres = 2.0f * Grids[GridID].ave;

    if (Grids[GridID].FloatGrid[BlockID] > thres) {
        Grids[GridID].BitGrid[BlockID] = true;
    }
    else {
        Grids[GridID].BitGrid[BlockID] = false;
    }
}

これはFloatGridの値がthres,ここではaveの2倍を超えているかどうかで閾値判定し,対応するBit_OccupancyGridの値を書き換えています.ちなみにこの2倍という値は特に意味はなく,学習が安定した値というだけです.ちなみにこの設定を間違えるとCUDAカーネルが実行時エラー出してプログラムが落ちます(例外処理してないので……).以上でOccupancy Gridの処理は終わりです.

ちなみに

Occupancy Gridを実装しなくてもある程度絵が出ると思っていましたが意外とそういうことはなく,Occupancy Gridを実装しないと学習精度がかなり低い状態が続きます.なので頑張って最後まで実装しきってください.Occupancy Grid自体は強力な手法です.

NeRFの実行

本当にお疲れ様です.以上でNeRFの実装は終わりました.本当はフロントエンドの実装も見せるべきなのですが,かなり説明が大変な複雑化をしているので(主にGUIのせい)ここでは説明を省略します.自分で実装してみると途方に暮れることは少なくともなく,これまで説明したパーツに相当するものを用いてちゃんと設計できると思います.意外とバグりやすいのでちゃんと実装ノート等に纏めた方が良いです.
さて,NeRFの実装が終わったのでどういう風に学習してくれるかを確認しましょう.まずはお決まりのLegoの空間を近似します.学習データはNVIDIAの出しているInstant NeRFのデモソフトにくっついてきたものを使用します.

この子たちですね.jsonファイルで画像とカメラの姿勢等を指定してファイルに読み込ませます.まずは改めて実行の設定を確認しましょう(図中に示したColor MLPの構成と少し異なります).
ボリュームレンダリングに関わる設定
NERF_MAX_SAMPLE_PER_RAY: 512

学習データ
画像: 100枚
サイズ: 800x800

NNの設定
DensityMLP:
InDim: 32
HiddenDim: 64
OutDim: 16
nHiddenLayer: 1
活性化関数: ReLU
エンコーダー: HashGrid
HashGrid: L = 16, F = 2, T = 220, Nmin = 2, b = 2.2

ColorMLP:
InDim: 32
HiddenDim: 64
OutDim: 3
nHiddenLayer: 1
活性化関数: 入力層と隠れ層: ReLU / 出力層: Sigmoid(Logistic)
エンコーダー: Spherical Harmnic ( l \leq 3)
最適化関数: Adam
誤差関数; Hubor_Loss
学習率: 0.07

NeRFの設定
AABB: [-1,1]^3
OccupancyGridの数: 1個 / 領域: [-1,1]^3

さて,では結果を見ていきましょう.とりあえず10000イテレーション回してみて出力画像の変化を見ました.

最初は何も見えませんが,モヤモヤし始めて目的の状態に近似されていくのが分かります.本当は動画で載せたいのですが,容量の問題で載せられないので最後にTwitterにあげておいた動画へのリンクを纏めて貼っておきます.今回の実装ではGUI上で動かしており,純粋な学習のみの処理時間を計測することは現状では不可能なのですが,GUI上のカメラにほとんどNeRFを生成するAABBを映していない状態で,約27秒で1000イテレーション回ります.学習処理だけであればもっと速いはずです.
せっかくなので,レイ1本あたりのサンプル点数の平均値でもグラフにプロットしておきます.


最初は設定した通り512サンプルきっかり全部サンプリングしてくれています.そしてiter = 16とiter = 32において激しく減少していることが分かります.これは16イテレーションごとにOccupancy Gridを更新し,サンプリング中にサンプルが弾かれやすくなっているためです.もっと縦軸の範囲を狭く,横軸の範囲をのばして見てみましょう.
もちろんサンプリングに使用するバッチによって多少の変化はあるのですが,全体として減少する傾向にあることが分かります.

他のデータセットも試してみましょう.

これはマイクですね.原著NeRFのプロジェクトページで公開されているデータセットBlenderで開ける3Dモデルファイルがあるのですが,それをBlender上で多少編集し,100視点ランダムに選んでレンダリングしました.この作業は次の記事を参考にしました. qiita.com 得られた画像群はこんな感じです.


さて,設定は先ほどと全く同じです.なお,画像のサイズが1024x1024になっております.

同様の変化を取っています.マイクの金網の部分も最終的には再現できていることが見えると思います.

最後に,同じデータセットにあるDrumを同様にしてレンダリングして入力データを作成しました.

さて,マイクと同じくして学習させました.

他の2つと比べると少し精度は低めですね.細い物体(高周波成分)の学習にはなかなか手こずります.ちなみにこれ床の色を黒にすると椅子等が上手く学習されません(色としての知覚上,床も椅子も黒に見えると区別がつかないためだと思います).

最後に動画への埋め込みリンクを載せておきます.録画したのは少し前となるので全く同じ実装ではないですが,今でも同じような(改善された)結果が得られます.

今よりかは多少遅いですが,まあでも数分でイイ感じに学習できています.現状の私の実装ではかなり荒いところもあるので改善はまだまだ出来そうです.

さいごに

Part1ではMLPの,Part2ではエンコーダーの,Part3ではNeRF全体の実装を示し,説明しました.こういう記事を書くのは初めてなのですが,記事を書いて初めて気づくバグや自分の理解不足な箇所とかを洗い出せてなかなか良い作業ではありました(記事書くために全部で数十時間かかったのでコスパといえばどうなのかとなりますが).さて,NeRF自体がどういうものであるか,というものは論文を読めば大体わかりますが,いざ実装をしてみると意外と手が動きません.そもそも現在では実装しなくても他人が実装したプログラムやライブラリがあります.しかし,何かを実装するということは実装対象をある意味で言語化的にアウトプットする作業とも言え,ぼんやりとした理解の存在を前面に押し出してくれます.一度は泥臭く実装してみるのも良いと思います.
実装に際してUshioさんには色々と助けていただきました.ありがとうございました.

参考文献

NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis [Ben Mildenhall et al. ECCV2020]
Instant Neural Graphics Primitives with a Multiresolution Hash Encoding [Thomas Müller et al. SIGGRAPH2022]
・Optical Models for Direct Volume Rendering [Nelson Max. 1995]
マルペケつくろーどっとコム その18 直線とAABB
Bounding Volume Hierarchy (BVH) の実装 - 交差判定編
memoRANDOM ボリュームレンダリング方程式 (Volume Rendering Equation) 1
NVIDIA Developer Blog, Thinking Parallel, Part III: Tree Construction on the GPU
live2d_dev NeRF検討用データセットをBlender Pythonスクリプトで作成する方法
Ushioさんのredpillの実装と解説スライド

CUDA C++でNeRFをほぼ0から実装してみた(Part2/3): エンコーダー編

エンコーダー編 概要

ニューラルネットワークの入力データをエンコーダーに通し,ベクトルの次元を上げることでネットワークの学習効率,精度を向上させることが出来る場合があります.本記事ではInstant NeRFにおいて使用されている2種類のエンコーダーをそれぞれ説明した後に私の実装を説明し,その効果を確認します.

エンコーダー編 はじめに

ニューラルネットワークはある入力のデータ(ベクトル)を元にして様々な計算処理を施し,最終的に目的のデータ(ベクトル)を回帰推定する手法です.ここで,入力層の次元が低い場合は,その結果を正確に推定することが難しくなります [Gehring et al. 2017. Convolutional Sequence to Sequence Learning].詳しいことは3編で書きますが,NeRFにおいてはニューラルネットワークの入力として,3次元空間内での場所と方向という6次元のデータを使用し,問題となっているサンプル点の「色 (3次元)」と「濃度 (1次元)」を推定します.6次元というのは非常に小さく,これだけの情報でNeRFのネットワークを最適化するのは不可能と言っても過言ではないと思います.そこで用いられるのがエンコーダーです.今回実装したNeRFにおいてはPositional EncodingとしてMultiresolution Hash Encodingを,Directional Encodingとして球面調和関数(Spherical Harmonic Encodingと私は呼んでいます)を使用しました.

原著のNeRF [Mildenhall et al. 2020]においては,sinとcosを繰り返すエンコーダーが使用されております.このエンコーダーは "Attention Is All You Need" [Vaswani et al. 2017. ]において使用されているものと同様の物で,処理的には入力データをsinとcosカーブ上の値(フーリエ級数展開的には基底に)に「マッピング」するといったものです.このエンコーダーを使用し,原著のNeRFは確かに3次元形状を近似することが出来ましたが,精度よく学習させるためにエンコーダー後続のネットワークが巨大となり,学習コストが高いなどの課題がありました.2022年にNVIDIAより発表されたMultiresolution Hash Encoding [Müller et al. 2022]はエンコーダーの一つで,これによってNeRF等の学習が非常に高精度,高速化されました.また,原著NeRFでは3次元空間での方向ベクトルを先程のsinとcosのエンコーダー(Frequency Encodingと呼んでいます)に通していましたが,これを球面調和関数に通してエンコードするという手法がMultiresolution Hash Encodingの論文におけるNeRFの実装に使用されており,今回の私の実装もそれにならって球面調和関数によるエンコーダーを使用しました.それぞれを説明していきます.

念のため

内容には気を付けているのですが誤り等があれば教えていただけると幸いです.

Multiresolution Hash Encodingの順伝播の手続き

先述した通り,Positional Encodingの一つです.入力データは空間における座標となります.くどく書けば,この座標を高次元の特徴ベクトルに変換します.最初は簡単のため,2次元で考えることにしましょう.つまり,ある正方形の内部にサンプル点が存在するとします.これは次の図の左側の状態を指します.サンプル点が橙の丸で表されております.

ここで,「レベル」の概念を説明しておきます.今回のエンコーダーではこの正方形に含まれるサンプル点を「様々な解像度レベルで解釈」します.よって,入力座標の存在する正方形を,レベルに応じて色んな解像度で分割します.これが上に示されている図の真ん中の状態です.例えば一番上のレベル( L = 1)では正方形を2分割しており, L = 2では正方形を3分割しております.そして,サンプル点(橙の丸印)がどの正方形に含まれるかを考えます.これが先程示した図において右側の状態です.

さらに,せっかく分割したので分割した正方形の頂点(格子点)に番号を振っておきましょう.また,このサンプル点が正方形内部のどんな位置にあるのかを表現しましょう. s, tを0以上1以下の実数として,次の図のようになります.(各レベルにおいて番号や記号は独立です)

つまりレベル1においてはサンプル点は頂点(1, 0) (2, 0) (1, 1) (2, 1)のなす正方形内部にあり,レベル2においてはサンプル点は頂点(2, 0), (3, 0), (2, 1), (3, 1)のなす正方形内部にあります.

じゃあ次に,この頂点番号をとあるハッシュ関数に入れてぐちゃぐちゃにしてあげましょう.ハッシュ関数 h(\mathbf{x})とします.整数の座標 \mathbf{x}を入れると整数 h(\mathbf{x})が返ってくると思ってください.すると,レベル1においてはサンプル点は h(1,0), h(2,0), h(1,1), h(2,1)のIDがふられた正方形内部にあり,レベル2においては h(2,0), h(3,0), h(2,1), h(3,1)のIDが降られた正方形内部にあります.

……一体何のためにこんなことしてるんだって思われている気がします.では,ここで「特徴ベクトルが格納されたテーブル」をレベルごとに用意します.はい,このテーブルに先ほど計算したハッシュ関数の出力をインデックスとしてアクセスするのです.すると,「格子点にテーブル上の特徴ベクトルが対応」します.そして,その特徴ベクトルを何かしらの方法で,今回はバイリニア補完で補完するのです.図で表すと次の通りです.

スペースの都合上,正確な図には出来なかったのですが,ここで使用するテーブルはレベルごとに独立の物を使用します.つまりレベルが違えば使用するテーブルは異なります.なお,テーブル上の特徴ベクトルの次元を Fとしています.

これまでの処理によって,各レベルにおけるサンプル点の持つ特徴ベクトルがそれぞれ求まりました.では最後に,これらを結合してあげます.これによって得られるベクトルが,このエンコーダーの出力です.

ちなみに論文ではさらにこれに追加のベクトルを結合することもあると書いていますが,少なくともNeRFにおいては使用しませんし,やることは単純なので本記事では省略します.

さて,これまでの処理をもう一度確認しておきましょう(厳密な処理ではなく,雰囲気です).

(1) (x, y) = (サンプル点の座標)
<(2) レベル数だけ繰り返す>
    (2.1) レベルに対応する解像度で正方形を分割する
    (2.2) 分割された小正方形のうち,どの小正方形に(x,y)が囲まれているかを求める
    (2.3) (x, y)が小正方形のどのあたりにあるかを求める(s, t)が求まる
    (2.4) 求めた小正方形の各頂点番号をハッシュ関数に入れ,テーブル上のIDを得る
    (2.5) 求まったIDからテーブルの特徴ベクトルを各頂点に読みだす
    (2.6) 読みだした特徴ベクトルを補完し,これをサンプル点の特徴ベクトルとする
<END: (2) レベル数だけ繰り返す>
(3) 各レベルでそれぞれ得られたサンプル点の特徴ベクトルを結合する

以上で2次元の場合のMultiresolution Hash Encodingの処理は完了しました.このエンコード処理によって,入力データが (x, y)の2次元だったところが,レベル数を L,テーブル上の特徴ベクトルの次元を Fとして LF次元となりました.

記事書いておいてなんですが……

自分のMultiresolution Hash Encodingの実装は現在書き直している途中です.理由としてはかなり乱れているためです.今回は書き直す前のコードで説明しますが,(もちろんこの記事に限った話ではないことですが)この記事に載せているコードをコピペするのではなく,実装の流れを確認して実装自体はそちらで行っていただくことをお勧めします.実装が改善された際に時間があれば書き直すかもしれません(このセクションが存在している限りは書き直されておりません).

Multiresolution Hash Encodingの実装: 順方向

さて,先ほどは2次元の入力に対して処理を行いましたが,実際のNeRFの入力は3次元です.ちなみにここを2次元にした場合は2次元のデータの近似(画像等)になります.さて,3次元を入力とする場合はエンコードの処理を3次元に対応させる必要があります.なので,これまで正方形として扱っていた個所を立方体として扱う必要があります.それを踏まえたうえで,まずは順伝播から実装を解説していきます.

私の実装をまず示します.

    /*
    Encode処理を行う
    (1) 各スレッドは1バッチの処理を行う(入力データは3次元であるという仮定を設ける)
        (1.1) 各レベルに対して次の処理を行う
            (1.1.1) 注目している座標が格子上ではどの8つの格子点に囲まれているかを求める
            (1.1.2) 各8格子点のHashTable上におけるインデックスを求める
            (1.1.3) もとめた8格子点における特徴ベクトルをHashTableより読みだす
            (1.1.4) 注目している座標が8格子点上でどの位置にあるか(s, t, u)を求める
            (1.1.5) (s, t, u)の値から特徴ベクトルを補完する
            コメント: 得られる特徴ベクトルは長さFである
        (1.2) 各レベルに対して(1.1)を行った結果,長さFの特徴ベクトルがL個得られる.それを繋げる(concat)
        (1.3) 長さEの追加特徴ベクトルをさらに(1.2)で得られた長さL*Fのベクトルに繋げる(concat)
    */
    template <const uint32_t indim_aligned>
    MFFM_DEVICE void Encode(const float3 InputRangeMin, const float3 InputRangeMax, __half* Input, __half* Encoded) {
        const int bx = blockIdx.x;
        const int tx = threadIdx.x;
        const int ty = threadIdx.y;
        const int global_threadId = bx * ONEBATCH_SIZE + 32 * ty + tx;
        const int block_threadId = 32 * ty + tx;

        // 1ブロック128バッチを担当する
        if (32 * ty + tx >= ONEBATCH_SIZE) {
            return;
        }

        unsigned int* NeighborPos_base = (unsigned int*)((__half*)Input + (indim_aligned + SKEW) * ONEBATCH_SIZE);
        unsigned int* NeighborPos = NeighborPos_base + (8 * 3) * block_threadId;
        float* NeighborFeatureVec_base = (float*)(NeighborPos_base + (8 * 3) * ONEBATCH_SIZE);
        float* NeighborFeatureVec = NeighborFeatureVec_base + (8 * MHE_F + SKEW) * block_threadId;
        float* stu_base = (NeighborFeatureVec_base + (8 * MHE_F + SKEW) * ONEBATCH_SIZE);
        float* stu = stu_base + 3 * block_threadId;

        // 入力のロード
        float x = normalize(Input[3 * block_threadId + 0], (__half)InputRangeMin.x, (__half)InputRangeMax.x, (__half)0.0f, (__half)1.0f);
        float y = normalize(Input[3 * block_threadId + 1], (__half)InputRangeMin.y, (__half)InputRangeMax.y, (__half)0.0f, (__half)1.0f);
        float z = normalize(Input[3 * block_threadId + 2], (__half)InputRangeMin.z, (__half)InputRangeMax.z, (__half)0.0f, (__half)1.0f);

        // ロードは必ず先に終わらせる
        __syncthreads();

        // (1.1)
#pragma unroll
        for (int l = 0; l < MHE_L; l++) {
            // (1.1.1)
            //unsigned int NeighborPos[8*3];
            Calc_NeighborVectorIndex(l, x, y, z, NeighborPos);

            // 近傍格子点における特徴ベクトルを求める (1.1.2) (1.1.3)
            //float NeighborFeatureVec[8*MHE_F];// [8][F]
#pragma unroll
            for (int i = 0; i < 8; i++) {
                // (1.1.2)
                unsigned int IndexOnHashTable = Calc_IndexOnHashTable(NeighborPos + 3 * i);

                // (1.1.3)
                Get_FeatureVectorOnHashTable(l, IndexOnHashTable, NeighborFeatureVec + MHE_F * i);
            }

            const int Buffer_stu_idx = PosToIdx2D(global_threadId, l, MHE_L) * 3;

            // 近傍格子点における特徴ベクトルから入力座標に対応する特徴ベクトルを求める (1.1.4) (1.1.5)
            // concatも行っていく (1.2)
            // SKEWを与えることに注意(出力データはL*F+E+SKEWとなる)
            Calc_CurrentFeatureVector(l, x, y, z, NeighborPos, NeighborFeatureVec, Encoded + block_threadId * (MHE_L * MHE_F + SKEW) + l * MHE_F, stu);

            // 今回は(1.3)は行わない
        }
        __syncthreads();
    }

Part1と同様に部分部分で見ていきましょう.

    template <const uint32_t indim_aligned>
    MFFM_DEVICE void Encode(const float3 InputRangeMin, const float3 InputRangeMax, __half* Input, __half* Encoded) {
...

・indim_aligned: MLPへの入力次元,つまりエンコーダーの出力次元を16の倍数に整形したものです.ちなみに必ずnext_multiple(LF, 16)となります……(じゃあconstexprで即値にすればいいのでは?)
・MFFM_DEVICE: CUDAの修飾子であるdeviceをdefineで置いたものです.
・InputRangeMin: サンプル点の座標 (x, y, z)が存在する領域はいわゆるAABB(Axis-Aligned-Bounding-Box),つまりxyz軸に平行な辺で構成された直方体です.その直方体のxyz座標の各々最小値が格納されています.
・InputRangeMax: InputRangeMinと同様にして,直方体のxyz座標の各々最大値が格納されています.
・Input: サンプル点の座標が格納された配列(の頭を指すポインタ)です.これはshared memoryに載っています.
・Encoded: エンコードされたデータの格納先です.ポインタとしてはInputと同じにしています.(shared memoryの容量が小さいためです)

...
const int bx = blockIdx.x;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int global_threadId = bx * ONEBATCH_SIZE + 32 * ty + tx;
const int block_threadId = 32 * ty + tx;
...

・global_threadID: デバイス全体で見たスレッドIDです.カーネル実行の設定はPart1を参照してください.
・block_threadId: ブロック単位で見たスレッドIDです.

...
// 1ブロック128バッチを担当する
if (32 * ty + tx >= ONEBATCH_SIZE) {
       return;
}
...

Part1を読めば詳しくは分かりますが,各ブロックは128バッチを処理します.なのでブロック単位で見たスレッドID(zero-indexed)が128以上のものは帰します.

...
unsigned int* NeighborPos_base = (unsigned int*)((__half*)Input + (indim_aligned + SKEW) * ONEBATCH_SIZE);
unsigned int* NeighborPos = NeighborPos_base + (8 * 3) * block_threadId;
float* NeighborFeatureVec_base = (float*)(NeighborPos_base + (8 * 3) * ONEBATCH_SIZE);
float* NeighborFeatureVec = NeighborFeatureVec_base + (8 * MHE_F + SKEW) * block_threadId;
float* stu_base = (NeighborFeatureVec_base + (8 * MHE_F + SKEW) * ONEBATCH_SIZE);
float* stu = stu_base + 3 * block_threadId;
...

うわあ……って感じです.えっと,実装中に使用する配列をshared memoryに載せようと努力しています.普通に静的配列として確保した方がいいと思います.細かい説明は出番が来た時にします.

...
 // 入力のロード
float x = normalize(Input[3 * block_threadId + 0], (__half)InputRangeMin.x, (__half)InputRangeMax.x, (__half)0.0f, (__half)1.0f);
float y = normalize(Input[3 * block_threadId + 1], (__half)InputRangeMin.y, (__half)InputRangeMax.y, (__half)0.0f, (__half)1.0f);
float z = normalize(Input[3 * block_threadId + 2], (__half)InputRangeMin.z, (__half)InputRangeMax.z, (__half)0.0f, (__half)1.0f);
...

入力座標をロードしておきます.ただし,サンプル点の座標が[[min.x, max.x], [min.y, max.y], [min.z, max.z]]に存在している状態だと面倒なので,ここでこの座標を[0, 1]^3に正規化しておきます.normalize関数は次の通りです.

// [SrcMIN, SrcMAX] -> [DstMIN, DstMAX]
template<typename T>
__host__ __device__ T normalize(T val, T SrcMIN, T SrcMAX, T DstMIN, T DstMAX) {
    T DstRange = DstMAX - DstMIN;
    T SrcRange = SrcMAX - SrcMIN;
    if (SrcRange == (T)0.0f) SrcRange = (T)1e-6f;
    T t = (val - SrcMIN) / SrcRange;
    return DstMIN + t * DstRange;
}

1次元の値に対して,元の最小値SrcMIN, 最大値SrcMAXの線分を点valで内分する際に比がどうなっているかを求め,それを出力側の最小値と最大値の線分に適用している感じです.ゼロ除算を避けるための処理はしていますが,エラー処理はしてません.

...
 // ロードは必ず先に終わらせる
 __syncthreads();
...

ブロック単位で同期を行います.ブロック内部の速いスレッドがエンコード結果を書き込む際に,ブロック内部の遅いスレッドが入力座標を読み出しが終わっていることを保証するためです.入力データ,出力データはshared memory(ブロックごとに独立)に載っているため,これで大丈夫です.

...
// (1.1)
#pragma unroll
for (int l = 0; l < MHE_L; l++) {
...

レベルの数だけ繰り返します.

...
 // (1.1.1)
//unsigned int NeighborPos[8*3];
Calc_NeighborVectorIndex(l, x, y, z, NeighborPos);
...

現在のレベルにおいて,サンプル点がどの小立方体内部にあるかを求めます.正確に言えば,サンプル点を包含する小立方体の各頂点のxyz各軸方向における頂点番号を求めます.NeighborPosは求めた各軸方向の頂点番号を格納する配列で,一頂点あたり3次元の頂点番号をもち,立方体は8頂点で構成されるので,uint32型8*3の容量が必要です.では,Calc_NeighborVectorIndexの処理を見ましょう.

 // レベルlevelにおける座標{x, y, z}の近傍格子点を求める
 MFFM_DEVICE inline void Calc_NeighborVectorIndex(int level, float x, float y, float z, unsigned int* NbVecIdx) {

     const unsigned int Nl = (unsigned int)(MHE_Nmin * powf(MHE_b, level));
     // 格子の1マスの大きさ
     float K = 1.0f / (float)Nl;

     NbVecIdx[0] = (int)(x / K); 
     NbVecIdx[1] = (int)(y / K); 
     NbVecIdx[2] = (int)(z / K);

     NbVecIdx[3] = NbVecIdx[0] + 1; 
     NbVecIdx[4] = NbVecIdx[1]; 
     NbVecIdx[5] = NbVecIdx[2];

     NbVecIdx[6] = NbVecIdx[0]; 
     NbVecIdx[7] = NbVecIdx[1] + 1; 
     NbVecIdx[8] = NbVecIdx[2];

     NbVecIdx[9] = NbVecIdx[0] + 1;
     NbVecIdx[10] = NbVecIdx[1] + 1; 
     NbVecIdx[11] = NbVecIdx[2];

     NbVecIdx[12] = NbVecIdx[0]; 
     NbVecIdx[13] = NbVecIdx[1]; 
     NbVecIdx[14] = NbVecIdx[2] + 1;

     NbVecIdx[15] = NbVecIdx[0] + 1; 
     NbVecIdx[16] = NbVecIdx[1]; 
     NbVecIdx[17] = NbVecIdx[2] + 1;

     NbVecIdx[18] = NbVecIdx[0]; 
     NbVecIdx[19] = NbVecIdx[1] + 1; 
     NbVecIdx[20] = NbVecIdx[2] + 1;

     NbVecIdx[21] = NbVecIdx[0] + 1; 
     NbVecIdx[22] = NbVecIdx[1] + 1; 
     NbVecIdx[23] = NbVecIdx[2] + 1;
 }

図を交えて説明しましょう.やってることは2次元での説明を3次元にしただけです.

手続きの説明は2次元で行っていました.サンプル点の存在する正方形を小正方形に分割し,どの小正方形に包含されるかを求めました.しかし今回は3次元でやるので,立方体を小立方体に分割し,どの小立方体にサンプル点が含まれるかを計算する必要があります.ここで,先ほど (x, y, z)をロードする際に,座標を[0, 1]^3にスケーリングしました.なので,全体の立方体の一辺の大きさは1です.そして,それをレベル lについては解像度 N_lだけ分割します.ここで,解像度はレベルに対応した解像度としたいので,次の式で解像度を計算します(説明時はレベルを1-indexedで扱っていますが,計算式や実装上は0-indexedです).

 
N_l = \lfloor N_{min} * b^l \rfloor


ここで, N_{min}は最小解像度, bは解像度のスケーリング指数, lはレベル番号を意味します.レベルが高くなるにつれて指数関数的に解像度が増加します.
さて,この解像度のもとで,小立方体の一辺の長さ( Kとします)は次のように求まります.

 
K = \dfrac{1}{N_l}


そして,次が成立しています. Positionはサンプル点の座標です.

 
整数M_x, M_y, M_zが存在して,
\begin{align}
M_xK \leq & Position.x < (M_x+1)K \\
M_yK \leq & Position.y < (M_y+1)K \\
M_zK \leq & Position.z < (M_z+1)K 
\end{align}


この M_x, M_y, M_zは小立方体において各軸小さい側の頂点の番号を表しています.つまり,先ほど示した図において,( M_x, M_y, M_z)が([0], [1], [2])です.残りの7頂点はこの頂点番号に1足したり足さなかったり......で求められます.以上がCalc_NeighborVectorIndexの処理です.続きを見ていきましょう.

// 近傍格子点における特徴ベクトルを求める (1.1.2) (1.1.3)
//float NeighborFeatureVec[8*MHE_F];// [8][F]
#pragma unroll
for (int i = 0; i < 8; i++) {
          // (1.1.2)
          unsigned int IndexOnHashTable = Calc_IndexOnHashTable(NeighborPos + 3 * i);

          // (1.1.3)
          Get_FeatureVectorOnHashTable(l, IndexOnHashTable, NeighborFeatureVec + MHE_F * i);
}

先ほどの処理でサンプル点を包含する小立方体の各頂点番号が分かりました.これをそれぞれハッシュ関数に入力し,テーブル上のインデックスを計算します(Calc_IndexOnHashTable).そして,そのテーブル上のそのインデックスに保存されている特徴ベクトルを読み出します(Get_FeatureVectorOnHashTable).読みだした特徴ベクトルはNeighborFeatureVecに保存されます.8頂点分の特徴ベクトルを保存するのでfloat8F個の容量となっております.
ではCalc_IndexOnHashTableから見ていきましょう.

// 頂点インデックス(3d)からHashTable上のインデックスをハッシュ関数により計算する
MFFM_DEVICE inline unsigned int Calc_IndexOnHashTable(unsigned int* VertexIndex) {
    const unsigned long long int p[3] = { 1, 2654435761, 805459861 };
    unsigned long long int h = 0;

    for (int i = 0; i < 3; i++) {
        h = h ^ (VertexIndex[i] * p[i]);
    }
    h %= MHE_T;

    return (unsigned int)h;
}

これはハッシュ関数を示した方が速いですね.次の関数 h(\mathbf{X})がテーブル上のインデックスを出力するハッシュ関数です.


h(\mathbf{X}) = (X.x * 1) \oplus (X.y * 2654435761) \oplus (X.z * 805459861 )  \mod T

なお, \oplus排他的論理和(XOR)で,Tはテーブルのサイズ,すなわち特徴ベクトルの本数です.この式に出てくる2654435761や805459861は論文において記されていた値ですが,大きな素数が使用されます.この計算によってテーブルの特徴ベクトルをロードする準備が整いましたのでロードします.Get_FeatureVectorOnHashTableを見ていきましょう.

    // VにHashTable上の特徴ベクトルを書き出す
    MFFM_DEVICE inline void Get_FeatureVectorOnHashTable(int level, unsigned int index, float* V) {
#pragma unroll
        for (int i = 0; i < MHE_F; i++) {
            V[i] = HashTable.at(level, index, i);
        }
    }

level: これまで何度も出てきている「レベル」です.
index: 先ほどのハッシュ関数によって計算されたインデックスです
V: テーブルを書きだす先の配列です.
MHE_F: 説明の際に示した,テーブルにおける特徴ベクトルの次元を表す Fのことです.
さて,ここでも触れるべき点があります.しかしこの部分は実装の幅を狭めるところなので真似はしないほうがいいです.コードを見てわかる通り,HashTableは何かしらの構造体として,グローバルのスコープで保持されていますね.ではその部分の実装を見ていきましょう.

enum class Initialize {
    Uniform,
    Load_from_file,
    Zero
};

struct Tb {
    float Data[MHE_L * MHE_T * MHE_F];

    MFFM_DEVICE void init(Initialize initialize, unsigned int seed = 1, float* LoadedData = nullptr) {
        const int index = blockIdx.x * blockDim.x + threadIdx.x;
        switch (initialize) {
        case(Initialize::Zero):
            Data[index] = 0.0f;
            break;
        case(Initialize::Uniform):
            curandState state;
            curand_init(seed, index, 0, &state);
            float rnd = curand_uniform(&state);
            Data[index] = normalize(rnd, 0.0f, 1.0f, -1e-4f, 1e-4f);
            break;
        case(Initialize::Load_from_file):
            Data[index] = LoadedData[index];
            break;
        default:
            printf("Invalid MHE initializer\n");
            break;
        }
    }

    MFFM_DEVICE float at(const uint32_t idxL, const uint32_t idxT, const uint32_t idxF) {
        return Data[MHE_T * MHE_F * idxL + MHE_F * idxT + idxF];
    }
    MFFM_DEVICE float* ptr_at(const uint32_t idxL, const uint32_t idxT, const uint32_t idxF) {
        return Data + MHE_T * MHE_F * idxL + MHE_F * idxT + idxF;
    }
};

MFFM_DEVICE Tb HashTable;
MFFM_DEVICE Tb dLdHashTable;
MFFM_DEVICE Tb v_Buffer_HashTable;
MFFM_DEVICE Tb m_Buffer_HashTable;

__global__ void init_HashTable() {
    HashTable.init(Initialize::Uniform);
    dLdHashTable.init(Initialize::Zero);
    v_Buffer_HashTable.init(Initialize::Zero);
    m_Buffer_HashTable.init(Initialize::Zero);
}

以上がテーブルの構造体周りの処理です.難しいことはしていないので軽く説明するにとどめます.

float Data[MHE_L * MHE_T * MHE_F];

保持しているデータはレベル数 L,テーブルのサイズ T,テーブルの特徴ベクトルの次元 Fの積である LTF要素のfloat配列です.これに対して,init関数では様々な初期化を行います.at関数は(レベル,テーブル上のインデックス,特徴ベクトル上のインデックス)をもとにしてテーブルの要素にアクセスする関数です(今気づきましたが参照してないですね).ptr_at関数は同様の要素を指すポインタにアクセスします.そして,テーブルの本体(HashTable),勾配を記録するdLdHashTable,Adam Optimizerのためのm_Buffer_HashTableとv_Buffer_HashTableがあります.init_HashTableにおいてそれぞれを初期化します.
これは重要なのですが,テーブル本体は初期値を[-1e-4, 1e-4]の範囲における一様乱数で初期化します.それ以外は普通に0初期化します.

さて,先ほどのGet_FeatureVectorOnHashTableにおける処理はこれでわかると思います.しかし,このようにグローバルなものとして定義すると,ニューラルネットワーク内部において1つしかMultiresolution Hash Encodingを使用できないという制約を抱えることになるため,避けた方がいいでしょう.現在主にこの周りの書き直しをしております.

説明の枝が長くなりましたが本筋の解説に戻りましょう.現在どこまでやったかというと,サンプル点の座標を包含する小立方体を求め,その頂点番号をハッシュ関数にいれてテーブル上のインデックスを計算し,そのインデックスに対応するテーブル上の特徴ベクトルを読みだしたところです.ということで,続きを見ていきましょう.

...
// 近傍格子点における特徴ベクトルから入力座標に対応する特徴ベクトルを求める (1.1.4) (1.1.5)
// concatも行っていく (1.2)
// SKEWを与えることに注意(出力データはL*F+E+SKEWとなる)
Calc_CurrentFeatureVector(l, x, y, z, NeighborPos, NeighborFeatureVec, Encoded + block_threadId * (MHE_L * MHE_F + SKEW) + l * MHE_F, stu);
...

この部分の処理では「各頂点にロードした特徴ベクトルの補完によるサンプル点における特徴ベクトルの計算」を行い,それを「レベル番号に対応したメモリ領域に保存(即ち結合と同義)」しています.
Encoded + block_threadId * (MHE_L * MHE_F + SKEW) + l * MHE_Fは,サンプル点におけるレベルlの特徴ベクトルを保存するポインタの先頭を指しています.1バッチ辺りのエンコード結果の特徴ベクトルはLF次元であり,shared memoryのバンクコンフリクトを避けるためにSKEWを与えるため,結局LF+SKEW次元となります.なので,スレッド番号にLF+SKEWを掛けてあげて,さらに各レベルではF次元の特徴ベクトルが得られるのでレベル番号にFを掛けてます.
stuは,いや本当にごめんなさいなんですけど,バイリニア補完の係数を載せる配列です.いや,関数内部で静的配列として確保してくださいね.
というわけで関数の中身を見ましょう.

    // 近傍点の特徴ベクトルからEncoder入力座標における特徴ベクトルをバイリニア補完する
    MFFM_DEVICE inline void Calc_CurrentFeatureVector(int level, float x, float y, float z, unsigned int* NbVecIdx, float* NbFeatureVec,
        __half* CurFeatureVec, float* stu) {

        const unsigned int Nl = (unsigned int)(MHE_Nmin * pow(MHE_b, level));
        // 格子の1マスの大きさ
        float K = 1.0f / (float)Nl;

        // バイリニア補完係数
        stu[0] = (x - K * (float)NbVecIdx[0]) / K;
        stu[1] = (y - K * (float)NbVecIdx[1]) / K;
        stu[2] = (z - K * (float)NbVecIdx[2]) / K;

        // 3d-バイリニア補完
#pragma unroll
        for (int i = 0; i < MHE_F; i++) {
            CurFeatureVec[i] = __float2half((1 - stu[0]) * (1 - stu[1]) * (1 - stu[2]) * NbFeatureVec[i] +
                stu[0] * (1 - stu[1]) * (1 - stu[2]) * NbFeatureVec[MHE_F + i] +
                (1 - stu[0]) * stu[1] * (1 - stu[2]) * NbFeatureVec[2 * MHE_F + i] +
                stu[0] * stu[1] * (1 - stu[2]) * NbFeatureVec[3 * MHE_F + i] +
                (1 - stu[0]) * (1 - stu[1]) * stu[2] * NbFeatureVec[4 * MHE_F + i] +
                stu[0] * (1 - stu[1]) * stu[2] * NbFeatureVec[5 * MHE_F + i] +
                (1 - stu[0]) * stu[1] * stu[2] * NbFeatureVec[6 * MHE_F + i] +
                stu[0] * stu[1] * stu[2] * NbFeatureVec[7 * MHE_F + i]);
        }
    }

 Kを求めるところまではCalc_NeighborVectorIndexでやったのと同じです.後半を図で説明します.

Calc_NeighborVectorIndexにても書きましたが,この小立方体においてすべての座標が小さい頂点(特徴ベクトルV[0]がある頂点)の座標は, (M_xK, M_yK, M_zK)です.また,この小立方体の一辺の長さはKです.つまり,このサンプル点の座標から (M_xK, M_yK, M_zK)を引いたベクトルをKで割るとサンプル点が立方体の各軸方向についてどの場所にあるかを表現することが出来ます.つまり,次の式によりバイリニア補完係数を求めています.

 
(s, t, u) = \dfrac{(Position) - (M_xK, M_yK, M_zK)}{K}

あとは図中の式に従って補完しましょう.

お疲れ様です.これでMultiresolution Hash Encodingの順方向が実装出来ました.

Multiresolution Hash Encodingの手続き: 逆方向

さて,実は逆方向はかなり簡単です.今回求める勾配(即ち更新パラメーター)はテーブル上の特徴ベクトルです.順方向でどのような操作をしたかを思い出しましょう.「サンプル点を含む小立方体の各頂点に対応する特徴ベクトルを読み出し,バイリニア補完により出力層を計算」しました.つまり,これの逆伝播としては,「バイリニア補完の逆伝播計算をし,各頂点に対応する特徴ベクトルの勾配を計算」することです.図中のバイリニア補完の式から,次の微分が出来ます.記号は先ほどの図に出てくる計算式を参照してください( \mathbf{X}は入力ではないです)

 
\begin{align}
\dfrac{\partial \mathbf{X}}{\partial \mathbf{V}}= (&(1-s)(1-t)(1-u), s(1-t)(1-u), (1-s)t(1-u), st(1-u), \\ 
                &(1-s)(1-t)u, s(1-t)u, (1-s)tu, stu) \\
\end{align}

よって,バイリニア補完の係数を出力層に流れ込んできた勾配に掛けてあげれば良いだけです.

Multiresolution Hash Encodingの実装: 逆方向

    /*
     * 誤差逆伝播
     * shared memory:
     * - dEdOut: 誤差.サイズ(INDIM_ALIGNED+SKEW) * ONEBATCH_SIZE
     * - additional_shmem: 余分なshared memory.dLdVの格納に使用する
     *                     サイズ/thread: 8*MHE_F+SKEW
     *                     始点: (8*MHE_F+SKEW)*(block_threadIdx)
     * (1) スレッドごとのエンコーダー出力層の誤差を全スレッドに対するエンコーダー出力層の誤差から読みだす
     * (2) 各レベルごとに次の処理を行う
     *      (2.1) スレッドごとのエンコーダー出力層の誤差をレベルごとに分割する
     *      (2.2) 順伝播時に記録したBuffer.s/t/uを読み込み,近傍点の特徴ベクトルの補完係数を求める
     *      (2.3) 各近傍点について次の処理を行う
     *          (2.3.1) dEdV[k]を求める
     *          (2.3.2) dEdV[k]のHashTable上でのインデックスをBufferから読みだす
     *          (2.3.3) 近傍点に対応するHashTableの要素の誤差を記録する
     */
    MFFM_DEVICE void Propagate_backward(const float3 InputRangeMin, const float3 InputRangeMax, __half* dEdOut, __half* Buffer_Input, __half* additional_shmem) {
        const int bx = blockIdx.x;
        const int tx = threadIdx.x;
        const int ty = threadIdx.y;

        // 1ブロック128バッチを担当する
        if (32 * ty + tx >= ONEBATCH_SIZE) {
            return;
        }

        const int global_threadId = bx * ONEBATCH_SIZE + 32 * ty + tx;
        const int block_threadId = 32 * ty + tx;

        // s, t, uは再計算する方が速い
        // 入力のロード
        float x = normalize(Buffer_Input[3 * block_threadId + 0], (__half)InputRangeMin.x, (__half)InputRangeMax.x, (__half)0.0f, (__half)1.0f);
        float y = normalize(Buffer_Input[3 * block_threadId + 1], (__half)InputRangeMin.y, (__half)InputRangeMax.y, (__half)0.0f, (__half)1.0f);
        float z = normalize(Buffer_Input[3 * block_threadId + 2], (__half)InputRangeMin.z, (__half)InputRangeMax.z, (__half)0.0f, (__half)1.0f);

        __syncthreads();

        // (1) SKEWに注意(次元はL*F+E+SKEW)
        __half* dEdOutLF = dEdOut + (MHE_L * MHE_F + SKEW) * block_threadId;
        __half* dLdV = additional_shmem + (8 * MHE_F + SKEW) * (block_threadId);
        // (2)
#pragma unroll
        for (int l = 0; l < MHE_L; l++) {
            // s, t, u, indexOnHashTableを求める処理 - レベルLの格子における近傍点の座標 ///////////////
            unsigned int NeighborPos[8 * 3];
            Calc_NeighborVectorIndex(l, x, y, z, NeighborPos);

            const unsigned int Nl = (unsigned int)(MHE_Nmin * pow(MHE_b, l));
            // 格子の1マスの大きさ
            const float K = 1.0f / (float)Nl;

            // バイリニア補完係数
            const float s = (x - K * (float)NeighborPos[0]) / K;
            const float t = (y - K * (float)NeighborPos[1]) / K;
            const float u = (z - K * (float)NeighborPos[2]) / K;
            ////////////////////////////////////////////////////////////////////////////////////

            // (2.1)
            __half* dEdOut_LvWise = dEdOutLF + l * MHE_F;

            // (2.2)
            const float weight[8] = { (1 - s) * (1 - t) * (1 - u), s * (1 - t) * (1 - u), (1 - s) * t * (1 - u), s * t * (1 - u),
                                      (1 - s) * (1 - t) * u      , s * (1 - t) * u      , (1 - s) * t * u      , s * t * u };

            // (2.3)
#pragma unroll
            for (int k = 0; k < 8; k++) {
                // IndexOnHashTableを求める.
                unsigned int IndexOnHashTable = Calc_IndexOnHashTable(NeighborPos + 3 * k);

                // (2.3.1)
#pragma unroll
                for (int f = 0; f < MHE_F; f++) {
                    dLdV[k * MHE_F + f] = weight[k] * (float)dEdOut_LvWise[f];
                }

                // (2.3.3)
#pragma unroll
                for (int f = 0; f < MHE_F; f++) {
                    atomicAdd(dLdHashTable.ptr_at(l, IndexOnHashTable, f), dLdV[k * MHE_F + f]);
                    //*dLdHashTable.ptr_at(l, IndexOnHashTable, f) += (float)dLdV[k * MHE_F + f];
                }
            }
        }
    }

見ていきましょう.

MFFM_DEVICE void Propagate_backward(const float3 InputRangeMin, const float3 InputRangeMax, __half* dEdOut, __half* Buffer_Input, __half* additional_shmem) {
...

・dEdOut: 出力層に流れ込んできた勾配
・Buffer_Input: 順方向の際の入力データです.即ちサンプル点の座標です
・additional_shmem: 酷い実装の片鱗です.演算時のデータをshared memoryに載せるために空いているshared memoryの領域を持ってきます.無視しても良いです.

...
const int bx = blockIdx.x;
...
float z = normalize(Buffer_Input[3 * block_threadId + 2], (__half)InputRangeMin.z, (__half)InputRangeMax.z, (__half)0.0f, (__half)1.0f);

__syncthreads();
...

順伝播と同じです.バイリニア補完の係数を求めるために順伝播の処理を部分的に行っています.s, t, uを保存しておくよりもこちらの方が綺麗に実装できると思います.

// (1) SKEWに注意(次元はL*F+E+SKEW)
__half* dEdOutLF = dEdOut + (MHE_L * MHE_F + SKEW) * block_threadId;
__half* dLdV = additional_shmem + (8 * MHE_F + SKEW) * (block_threadId);

・dEdOutLF: 出力層に流れ込んできた勾配データで,実行スレッドに対応する勾配データの先頭を指すポインタ
・dLdV: ああ,各頂点に対応する特徴ベクトルの勾配を保存する領域です.静的配列として確保した方が良いと思います.あと,誤差関数がLとEで表記揺れしていますが気にしないでください......

...
// (2)
#pragma unroll
        for (int l = 0; l < MHE_L; l++) {
...

レベルごとに行います.

...
// s, t, u, indexOnHashTableを求める処理 - レベルLの格子における近傍点の座標 ///////////////
...
const float u = (z - K * (float)NeighborPos[2]) / K;
////////////////////////////////////////////////////////////////////////////////////
...

コメントの通りです.順伝播と同じ処理なので省略します.

...
 // (2.1)
 __half* dEdOut_LvWise = dEdOutLF + l * MHE_F;
...

各レベルごとに処理をしたいので,処理中のレベルに対応した出力層に流れ込んできた勾配を指すポインタを計算します.1レベルごとにF要素を処理しているのでレベル番号にFを掛けてます.

// (2.2)
const float weight[8] = { (1 - s) * (1 - t) * (1 - u), s * (1 - t) * (1 - u), (1 - s) * t * (1 - u), s * t * (1 - u), (1 - s) * (1 - t) * u      , s * (1 - t) * u      , (1 - s) * t * u      , s * t * u };

バイリニア補完の係数ですね.これで準備が整いました.一気に行きましょう.

// (2.3)
#pragma unroll
for (int k = 0; k < 8; k++) {
         // IndexOnHashTableを求める.
         unsigned int IndexOnHashTable = Calc_IndexOnHashTable(NeighborPos + 3 * k);

        // (2.3.1)
#pragma unroll
       for (int f = 0; f < MHE_F; f++) {
             dLdV[k * MHE_F + f] = weight[k] * (float)dEdOut_LvWise[f];
       }

       // (2.3.3)
#pragma unroll
      for (int f = 0; f < MHE_F; f++) {
            atomicAdd(dLdHashTable.ptr_at(l, IndexOnHashTable, f), dLdV[k * MHE_F + f]);
            //*dLdHashTable.ptr_at(l, IndexOnHashTable, f) += (float)dLdV[k * MHE_F + f];
      }
 }

各頂点に対応する特徴ベクトルの勾配を求めるため,各頂点に注目して処理していきます.まず,頂点とテーブル上の特徴ベクトルを対応させるために,順伝播と同じようにインデックスを求めます.そして,先程示した勾配を求める式に代入し,頂点に対応する特徴ベクトルの勾配を求めます.そして最後に,その勾配を,勾配を記録するテーブルの構造体であるdLdHashTableにatomicAddによりaccumulateします.(正直アクセスが疎なのでatomicじゃなくても耐えるのでは?と思っていますが,確証がないのでちゃんとatomicにしてます).以上で逆方向の処理は完了です.

Multiresolution Hash Encodingの実装: 最適化

実は現状の実装におけるボトルネックです.テーブル上のすべてのパラメーターに対して最適化処理を行います.

    ///////////////////////////////// OPTIMIZATION IMPLEMENTATION //////////////////////////////////////////////////////////////////////
    MFFM_DEVICE void Optimization(const uint32_t BatchSize, Optimize optimize, const int epoch) {
        int bx = blockIdx.x;
        int tx = threadIdx.x;
        int ty = threadIdx.y;

        const int nThreads = 32 * blockDim.y * gridDim.x;
        const int threadId = blockIdx.x * 32 * blockDim.y + 32 * ty + tx;
        const int WeightSize_this_layer = MHE_L * MHE_T * MHE_F;

#pragma unroll
        for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
            float dL = dLdHashTable.Data[i];
            if (!isfinite(dL)) {
                dLdHashTable.Data[i] = 0.0f;
                continue;
            }

            dL = dL / (float)BatchSize;

            switch (optimize) {
            case(Optimize::GD):
                HashTable.Data[i] = HashTable.Data[i] - (float)LEARNINGRATE * dL;
                break;
            case(Optimize::Adam):
                if (!AdamOptimize(m_Buffer_HashTable.Data[i], v_Buffer_HashTable.Data[i], dL, HashTable.Data[i], epoch)) {
                    // printf("%d %f %f \n", idx, (float)AdditionalParam[2 * idx], (float)AdditionalParam[2 * idx + 1]);
                }
                break;
            default:
                printf("Invalid Optimization Type\n");
                break;
            }
            dLdHashTable.Data[i] = 0.0f;
        }
        __syncthreads();
    }

やっていることはPart1における全結合層のパラメーターの最適化と同じなので説明は省略します.実際に使用されるテーブル上のパラメーターは限られるのでそれだけ更新するという実装にしようとしていますが現状ではまだ上手くいってません.以上でMultiresolution Hash Encodingの実装が完了しました.

Spherical Harmonic Encodingについて

球面調和関数の基底を用いてエンコードするというものです.球面調和関数は物性とか微分方程式とかの講義で触れられた記憶がありますが,正直なところ式だけ提示されても何も分かりません.まずそもそも球面調和関数を使用してエンコードするというのは何故なのか,いったい何の意味があるのかという疑問が出てきます.例えば3DCGの分野でも光源系の表現などに使用している論文が2000年ごろにありましたが,まあ読んでも詳細には何をしているか分かりませんでした.というわけで土日を溶かして数学をしました.そのうえで自分の理解を述べます.実装だけが目的であれば飛ばしても大丈夫です.

球面調和関数の導出とその正規直交性(と完全性)

ラプラス方程式

 
\Delta \Psi = (\dfrac{\partial^2}{\partial x^2} + \dfrac{\partial^2}{\partial y^2} + \dfrac{\partial^2}{\partial z^2})\Psi = 0


を考えます.これは極座標の関係

 
(x, y, z) = (r \sin \phi \cos \theta, r \sin \phi \sin \theta, r \cos\phi)
を用いて


 
\dfrac{\partial}{\partial r} \left(r^2 \dfrac{\partial}{\partial r}  \right) \Psi + \dfrac{1}{\sin \theta} \dfrac{\partial}{\partial \theta} \left( \sin\theta \dfrac{\partial}{\partial \theta} \right) \Psi + \dfrac{1}{\sin^2\theta} \dfrac{\partial^2}{\partial \phi^2}\Psi = 0


と書き換えられます(ここの導出はかなり面倒なので省きます).変数分離法により変形していきます.まず,次を認めます.


 
R(r), Q(\theta, \phi)が存在して,\Psi(r, \theta, \phi) = R(r)Q(\theta, \phi)と書ける


これを先程の式に代入し,両辺を RQで割って整理すると,次式が得られます.


 
\dfrac{1}{R}\dfrac{d}{dr} \left(r^2 \dfrac{d}{dr} \right) R = -\dfrac{1}{Q\sin\theta}\dfrac{\partial}{\partial \theta} \left( \sin\theta\dfrac{\partial}{\partial \theta}\right)Q - \dfrac{1}{Q\sin^2\theta}\dfrac{\partial}{\partial\phi}Q


ここで,左辺は rのみの式,右辺は \theta, \phiのみの式となり,この等号がどのような r, \theta, \phiに対しても成り立つので,両辺は定数となります. \lambdaを定数として,


 
\begin{align}
& \dfrac{1}{R}\dfrac{d}{dr} \left(r^2 \dfrac{d}{dr} \right) R = \lambda \\
& \dfrac{1}{Q\sin\theta}\dfrac{\partial}{\partial \theta} \left( \sin\theta\dfrac{\partial}{\partial \theta}\right)Q - \dfrac{1}{Q\sin^2\theta}\dfrac{\partial}{\partial\phi}Q = -\lambda
\end{align}


とします.今回は2つめの式に注目します.さらに次を認めます.


 
\Theta(\theta), \Phi(\phi)が存在して,Q(\theta, \phi) = \Theta(\theta) \Phi(\phi)と書ける


これを先程の2つめの式に代入し,同様に整理すると,次が得られます.


 
\dfrac{\sin\theta}{\Theta} \dfrac{d}{d\theta} \left(\sin\theta \dfrac{d}{d\theta}  \right)\Theta + \lambda \sin^2\theta = -\dfrac{1}{\Phi} \dfrac{d}{d\phi}\Phi


左辺は \thetaのみの式,右辺は \phiのみの式になっていますね.また, \Phiは周期的な関数であるとすれば, mを整数として,


 
\dfrac{1}{\Phi} \dfrac{d}{d\phi}\Phi = -m^2 \\
\dfrac{\sin\theta}{\Theta} \dfrac{d}{d\theta} \left(\sin\theta \dfrac{d}{d\theta}  \right)\Theta + \lambda \sin^2\theta = m^2


として書けます.1つ目の式より, Aを任意定数として,


 
\Phi = A e^{jm\phi}


と書けます(一般解ではないです.共役な基底があります).また,2つ目の式に対して,


 
x := \cos \theta \\
P(x) = P(\cos\theta) = \Theta(\theta)


を代入すると,


 
(1-x^2)^2 \dfrac{d^2P}{dx^2} -2x(1-x^2)\dfrac{dP}{dx} + \left( \lambda(1-x^2) - m^2  \right)P = 0


両辺を[ tex: (1-x2) ]で割ると,('はxによる微分を意味します)


 
(1-x^2) P'' -2xP' + \left( \lambda - \dfrac{m^2}{1-x^2}  \right)P = 0


ですね.これは m = 0のときルジャンドル方程式,そうでない場合ルジャンドル培方程式と言って名前がついてます.ここで, m \geq 0としておくことにします.
この世界には次の式が浮かんでくる人がいるみたいです.


 
P = (1-x^2)^{m/2} \mathcal{P}


これを先程の式に代入すると次が得られます.


 
(1-x^2)\mathcal{P}'' - 2x(m+1)\mathcal{P}' + \left( \lambda - m(m+1)  \right)\mathcal{P} = 0


フロベニウスの方法(級数法)を使用してこの微分方程式を解きます. \mathcal{P}が次の級数の形で与えられるとします.


 
\mathcal{P} = \sum_{j=0} ^ \infty a_j x^{k+j}


 k = 0の時,xの次数が2以上の式を考えると,次の漸化式が得られます.


 
a_{j+2} = a_j \left( \dfrac{j^2 + (2m+1)j - \lambda + m(m+1) }{ (j+1)(j+2) } \right)


ここで, -1 \leq x \leq 1の範囲で \mathcal{P}(x)には収束してもらうため,


 
j^2 + (2m+1)j - \lambda + m(m+1) = 0を満たすjが存在する


を満たすようにしたいです.ここで lをm以上の整数として, j = l-mとします.これを上の式に代入することにより, \lambda = l(l+1)が先程の条件を満たしてくれます.
 k = 1の時,xの次数が1以上の式を比較することにより,


 
a_{j+1} = a_{j-1} \left( \dfrac{j^2 + (2m+1)j - \lambda + m(m+1) }{ (j+1)(j+2) } \right)


が得られます.同様の議論が出来ます.さて,ルジャンドル培方程式を改めて書き直しましょう.


 
(1-x^2) P_l^{m''} -2xP_l^{m'} + \left( l(l+1)- \dfrac{m^2}{1-x^2}  \right)P_l^m = 0


 m = 0の時,上式は


 
(1-x^2) P_l^{''} -2xP_l^{'} +  l(l+1)P_l = 0


となります(ルジャンドル方程式).この両辺をm回微分しましょう.ライプニッツの公式


 
\dfrac{d^m}{dx^m} (A(x)B(x)) = \sum_{r=0}^{m} {}_m \mathrm{C}_r \dfrac{d^{m-r}}{dx^{m-r}} A(x) \dfrac{d^r}{dx^r} B(x)


を利用します.


 
(1-x^2)u'' - 2x(m+1)u' + \left( l(l+1) - m(m+1)  \right)u = 0 \\
ただし,u := \dfrac{d^m}{dx^m}P_l(x)


となります.実はこの式は先ほど出てきた \mathcal{P}に関する微分方程式と同じですね( \lambda = l(l+1)).これより,


 
\mathcal{P}_l^m = (-1)^m \dfrac{d^m}{dx^m}P_l(x)


と書けるらしいですがこの(-1)のべき乗の項はまだよく分かってないです.コンドン-ショートレー位相と呼ぶらしいですが,AMS-55という定義があるらしいとかなんとか……とにかく,これにロドリゲスの公式


 
P_l(x) = \dfrac{1}{2^l l!} \dfrac{d^l}{dx^l}(x^2-1)^l


を代入することにより,


 
\mathcal{P}_l^m = \dfrac{(-1)^m}{2^l l!} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l


であり,さらに


 
P_l^m = \dfrac{(-1)^m}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l


となります.これをルジャンドル陪関数と呼びます.ここで,これまでは非負の mについて計算していたので,これを負の mについても拡張します.(本当にこんな拡張していいのかという疑問がまだ解決できておりませんが......)
ひたすら計算します.ライプニッツの公式を使用して,


 
\begin{align}
\dfrac{d^{l+m}}{dx^{l+m}} (1-x^2)^l &= \left(\dfrac{d}{dx}\right)^{l+m} (1+x)^l (1-x)^l \\
                                                           &= \sum_{r=0}^{l+m} {}_{l+m}\mathrm{C}_r \left(  \left(\dfrac{d}{dx}\right) ^{l+m-r} (1+x)^l \right) \left(  \left(\dfrac{d}{dx}\right) ^{r} (1-x)^l \right)
\end{align}


ここで,[tex: (1+x)l, (1-x)l]の最高次数はともに lなので, l+1微分すると0となります.また,努力により,この微分は計算出来て,


 
\begin{align}
\left(\dfrac{d}{dx}\right) ^{l+m-r} (1+x)^l &= \dfrac{l!}{(r-m)!}(1+x)^{r-m}  \;\; (r \geq m) \\
\left(\dfrac{d}{dx}\right) ^{r} (1-x)^l          &= \dfrac{(-1)^r l!}{(l-r)!}(1-x)^{l-r} \;\; (r \leq l)
\end{align}


これより,


 
\begin{align}
\dfrac{d^{l+m}}{dx^{l+m}} (1-x^2)^l &= \sum_{r=m}^{l} {}_{l+m}\mathrm{C}_r \left( \dfrac{l!}{(r-m)!}(1+x)^{r-m} \right) \left( \dfrac{(-1)^r l!}{(l-r)!}(1-x)^{l-r}  \right) \\
                                                            &= \sum_{k=0}^{l-m} {}_{l+m}\mathrm{C}_{k+m} \left( \dfrac{l!}{k!}(1+x)^{k} \right) \left( \dfrac{(-1)^{m+k} l!}{(l-m-k)!}(1-x)^{l-m-k}  \right) \;\; (k ;= r-m) \\
                                                            &= \sum_{k=0}^{l-m} {}_{l+m}\mathrm{C}_{k+m} \left( \dfrac{l!}{k!} \dfrac{(-1)^{m+k} l!}{(l-m-k)!} \dfrac{(1+x)^{k+m}}{(1+x)^m} \dfrac{(1-x)^{l-k}}{(1-x)^m} \right) \\
                                                            &= \sum_{k=0}^{l-m} \dfrac{(-1)^m}{(1-x^2)^m} \dfrac{(l+m)!}{(l-k)!(k+m)!}  \left( \dfrac{l!}{k!} \dfrac{(-1)^{k} l!}{(l-m-k)!} (1+x)^{k+m}(1-x)^{l-k} \right) \\
                                                            &= \sum_{k=0}^{l-m} \dfrac{(-1)^m}{(1-x^2)^m} \dfrac{(l+m)!}{(l-m)!} \left( \dfrac{(l-m)!}{k!(l-m-k)!} \dfrac{l!}{(k+m)!}(1+x)^{k+m} \dfrac{(-1)^kl!}{(l-k)!}(1-x)^{l-k}  \right) 
\end{align}


ここで,直前の努力により得られた式と括弧内の式を見比べると,


 
\begin{align}
\dfrac{d^{l+m}}{dx^{l+m}} (1-x^2)^l &= \dfrac{(-1)^m}{(1-x^2)^m} \dfrac{(l+m)!}{(l-m)!} \sum_{k=0}^{l-m} {}_{l-m}\mathrm{C}_k \left( \left(\dfrac{d}{dx}\right)^{l-m-k} (1+x)^l \right) \left( \left(\dfrac{d}{dx}\right)^{k} (1-x)^l \right)
\end{align}


ライプニッツの公式より,


 
\begin{align}
\dfrac{d^{l+m}}{dx^{l+m}} (1-x^2)^l &= \dfrac{(-1)^m}{(1-x^2)^m} \dfrac{(l+m)!}{(l-m)!} \dfrac{d^{l-m}}{dx^{l-m}} (1-x^2)^l
\end{align}


さて,準備が整ったのでルジャンドル陪関数の式に適用します.


 
\begin{align}
P_l^m(x) &= \dfrac{(-1)^m}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l \\
               &=  \dfrac{(-1)^{m+l}}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(1-x^2)^l \\
P_l^{-m}(x) &= \dfrac{(-1)^{l-m}}{2^l l!} (1-x^2)^{-m/2} \dfrac{d^{l-m}}{dx^{l-m}}(x^2-1)^l \\
                   &= \dfrac{(-1)^{l-m}}{2^l l!} (1-x^2)^{-m/2} \dfrac{(1-x^2)^m (l-m)!}{(-1)^m (l+m)!} {dx^{l+m}}(x^2-1)^l \\
                   &= (-1)^m \dfrac{(l-m)!}{(l+m)!} \dfrac{(-1)^{l+m}}{2^l l!} (1-x^2)^{m/2} {dx^{l+m}}(x^2-1)^l \\
                   &= (-1)^m \dfrac{(l-m)!}{(l+m)!} P_l^m(x)
\end{align}


これによって mが正の場合と負の場合の対応が付きました.この式はまた後に使うこととして,ここでルジャンドル陪関数の直交性を確認しましょう.
記号を次のように省略することとします.


 
\begin{align}
P_l^m(x) &= \dfrac{(-1)^m}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l \\
               &= \dfrac{(-1)^m (-1)^{m/2}}{2^l l!} (x^2-1)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l \\ 
               &= \dfrac{(-1)^m (-1)^{m/2}}{2^l l!} R^{m/2} D^{m+l} R^l \\
\end{align}
 
ただし,D = \dfrac{d}{dx}, R = x^2-1


 xの定義域は[-1, 1]であるため,-1から1までの積分を行います.


 
\begin{align}
\int_{-1}^{1} P_p^m(x) P_q^m(x) dx &= \dfrac{(-1)^{2m} (-1)^m}{2^{p+q}p!q!} \int_{-1}^{1} R^m (D^{p+m}R^{p}) (D^{q+m} R^{q})dx \\
                                                          &= \dfrac{(-1)^m}{2^{p+q}p!q!} \int_{-1}^{1} R^m (D^{p+m}R^{p}) (D^{q+m} R^{q})dx \\
                                                          &=: \dfrac{(-1)^m}{2^{p+q}p!q!} I
\end{align}


 I積分の部分です.これを計算しましょう.ただし, p \leq qとします.s回だけ部分積分したときの式は


 
\begin{align}
&I = \sum_{i = 1}^{s} s_{pq}^{i} + (-1)^s \int_{-1}^{1} D^{s} (R^m D^{m+p} R^p) D^{m+q-s} R^q dx \\
&ただし,s_{pq}^i = \left.(-1)^{i-1} D^{i-1} (R^m D^{m+p} R^p) D^{m+q-i} R^q \right|_{-1}^{1}
\end{align}


この s_{pq}^iは実は消えます.[tex: (x2-1)]の項が生きているうちは1と-1を代入すると0になるので,


 
\left. D^t R^q \right|_{-1}^{1} = 0 \;\; (t \leq q-1)


となりますね.この考え方を利用して,[tex: D^{m+q-i}Rq]の部分に注目します. m+q-i \leq q-1の時,つまり i \geq m+1の時はこの微分の結果に[tex: (x2-1)]の項が生きているので,結局1と-1を代入するとこの計算結果は0となり,


 
s_{pq}^i = 0 \;\; (i \geq m+1)


となります.次に,ライプニッツの公式を用いて,[tex: D^{i-1} (Rm D^{m+p} Rp)]の部分に注目すると,


 
D^{i-1} (R^m D^{m+p} R^p) = \sum_{r = 0}^{i-1} {}_{i-1} \mathrm{C}_r D^{i-1-r} R^m \; D^{m+p+r}R^p


[tex: D^{i-1-r} Rmの部分に注目します. i-1-r \leq i-1 \leq m-1の時,つまり i \leq mの時は \sumに現れるすべての項において[tex: (x2-1)]の項が生存します.そのため,


 
s_{pq}^i = 0 \;\; (i \leq m)


となります.先ほど得られた計算結果と合わせると,確かに全てのiについて s_{pq}^i = 0が満たされていることが分かります.結局,



\begin{align} 
I =(-1)^s \int_{-1}^{1} D^{s} (R^m D^{m+p} R^p) D^{m+q-s} R^q dx
\end{align}


ここで p \leq qであるので, I積分 m+qまで部分積分できます. s = m+qとすると,ライプニッツの式を利用して


 
\begin{align}
I &= (-1)^{m+q} \int_{-1}^{1} D^{m+q} (R^m D^{m+p} R^p) D^{0} R^q dx \\
  &= (-1)^{m+q} \int_{-1}^{1} dx R^q \left( \sum_{r=0}^{m+q} {}_{m+q}\mathrm{C}_r \; D^{m+q-r} R^m \; D^{m+p+r} R^p   \right) \\
\end{align}


さて, R xに関する2次多項式でした.なので[tex: Rm, Rp]はそれぞれ次数が 2m, 2pです.つまり次数よりも多く微分するとこれらは0になります.つまり,


 
D^{m+q-r} R^mについては,m+q-r \leq 2mを満たすrにのみ値が現れうる.\\
つまり,r \geq q-m \\
D^{m+q-r} R^mについては,m+p+r \leq 2pを満たすrにのみ値が現れうる.\\
つまり,r \leq p-m


ここで,p < qの時を考えると,なんと先ほどの議論よりどの rにも値は現れません.つまり, I = 0となります.
 p = qのときは r = p - mを満たすrのみに値が現れうるので,


 
\begin{align}
I &= (-1)^{m+q} \int_{-1}^{1} dx R^q \left( {}_{m+q}\mathrm{C}_{p-m} \; D^{2m} R^m \; D^{2p} R^p   \right) \\
\end{align}


となります.さらに,


 
\begin{align}
D^{2m} R^m = (2m)! \\
D^{2p} R^p = (2p)!
\end{align}


であるので, p = qに注意して,


 
\begin{align}
I &= (-1)^{m+q} (2m)!(2p)! {}_{m+q}\mathrm{C}_{p-m} \int_{-1}^{1} R^q dx \\
  &= (-1)^{m+p} \dfrac{(m+p)! (2p)!}{(p-m)!} \int_{-1}^{1} R^p dx \\
\end{align}


そして,この積分は部分積分を頑張ることで計算できます.


 
\begin{align}
\int_{-1}^{1} R^p dx &= \int_{-1}^{1} (x^2-1)^p dx \\
                                 &= \int_{-1}^{1} (x+1)^p (x-1)^p dx \\
                                 &= (努力) \\ 
                                 &= (-1)^p \dfrac{(p!)^2 2^{2p}}{(2p)!} \dfrac{2}{2p+1}
\end{align}


であるので,


 
\begin{align}
I &=  (-1)^{m+p} \dfrac{(m+p)! (2p)!}{(p-m)!} (-1)^p \dfrac{(p!)^2 2^{2p}}{(2p)!} \dfrac{2}{2p+1} \\
  &=  (-1)^m (p!)^2 2^{2p} \dfrac{(p+m)!}{(p-m)!} \dfrac{2}{2p+1}
\end{align}


と求まります.以上の議論より,クロネッカーのデルタを用いて,


 
\begin{align}
\int_{-1}^{1} P_p^m(x) P_q^m(x) dx &= \dfrac{(-1)^m}{2^{p+q}p!q!} I \\
                                                          &= \delta_{pq} \dfrac{(-1)^m}{2^{p+q}p!q!} \times (-1)^m (p!)^2 2^{2p} \dfrac{(p+m)!}{(p-m)!} \dfrac{2}{2p+1} \\
                                                          &= \dfrac{2}{2p+1} \dfrac{(p+m)!}{(p-m)!} \delta_{pq}  
\end{align}


以上より,ルジャンドル陪関数が直交性を持つことが分かりました.いったんまとめましょう.元々は最初に示したラプラス方程式 \Theta(\theta), \Phi(\phi)を求めていました.これまでの議論より,


 
\begin{align}
& \Phi(\phi) = A e^{jm\phi} \\
& \Theta(\theta) = B P_l^m(x) = B P_l^m(\cos\theta)
\end{align}


です.ここで,正規化を与えることを考えます.つまり,


 
\begin{align}
&\int_{\Omega_\phi} |\Phi(\phi)|^2 d\phi = 1 \\
&\int_{\Omega_\theta} |\Theta(\theta)|^2 d\theta = 1 \\
\end{align}


を満たさせることとします.


 
\begin{align}
&\int_{\Omega_\phi} |\Phi(\phi)|^2 d\phi = \int_{0}^{2\pi} |A e^{jm\phi}|^2 d\phi = 2\pi A^2 = 1 \\
&\int_{\Omega_\theta} |\Theta(\theta)|^2 d\theta = \int_{-1}^{1} |B P_l^m(x)|^2 dx  = \dfrac{2}{2l+1} \dfrac{(l+m)!}{(l-m)!} B^2 = 1 \\
\end{align}


これらより A, Bが求まり, \Phiと\Thetaは次のように書けます.


 
\begin{align}
& \Phi(\phi) = \dfrac{1}{\sqrt{2\pi}} e^{jm\phi} \\
& \Theta(\theta) = \sqrt{ \dfrac{2l+1}{2} \dfrac{(l-m)!}{(l+m)!} } P_l^m(\cos\theta)
\end{align}


さて,これの積を改めて[tex: Y_lm(\theta, \phi)]と書いて,


 
\begin{align}
Y_l^m(\theta, \phi) &= \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l-m)!}{(l+m)!} } P_l^m(\cos\theta)e^{ jm\phi} \\
                               &= \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l-m)!}{(l+m)!} } P_l^m(x)e^{jm\phi} \\
                               &=  \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l-m)!}{(l+m)!} } \dfrac{(-1)^m}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l e^{jm\phi}
\end{align}


先ほど示したルジャンドル陪関数の mの正負に関する関係


 
\begin{align}
P_l^{-m}(x) &= (-1)^m \dfrac{(l-m)!}{(l+m)!} P_l^m(x)
\end{align}


より,


 
\begin{align}
Y_l^{-m}(\theta, \phi) &= \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l+m)!}{(l-m)!} } (-1)^m \dfrac{(l-m)!}{(l+m)!} P_l^m(x) e^{-jm\phi} \\
                                   &= (-1)^m Y_l^{m*}(\theta, \phi)
\end{align}


であるので,負の mを考慮した式は


 
\begin{align}
Y_l^m(\theta, \phi) &= (-1)^{(|m|+m)/2} \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l-|m|)!}{(l+|m|)!} } \dfrac{1}{2^l l!} (1-x^2)^{|m|/2} \dfrac{d^{|m|+l}}{dx^{|m|+l}}(x^2-1)^l e^{jm\phi}
\end{align}


となります.これを球面調和関数と言います.ルジャンドル陪関数[tex: P_lm(\cos\theta)]の lに関する直交性と先程の正規化の処理に加えて, \Phi(\phi)の直交性


 
\begin{align}
\int_{0}^{2\pi} \Phi_m(\phi) \Phi_n^{*} (\phi) d\phi &= \dfrac{1}{2\pi} \int_{0}^{2\pi} e^{jm\phi} e^{-jn\phi}d\phi \\
                                                                                 &= \dfrac{1}{2\pi} \int_{0}^{2\pi} e^{j(m-n)\phi} d\phi \\
                                                                                 &= \delta_{mn}
\end{align}


より,球面調和関数には正規直交性が成立します.また,球面調和関数には完全性があり,球面上の連続で滑らかな関数 f(\theta, \phi)が球面調和関数系の線形結合


 
\begin{align}
f(\theta, \phi) = \sum_{l=0}^{\infty} \sum_{m = -l}^{l} f_l^m Y_l^m(\theta, \phi)
\end{align}


として一意に表せます.つまり,球面上で定義される関数を展開できるということです.実数上の関数を級数展開するあれと同じですね.私の知識では完全性の証明をすることは出来ませんでした.無念(ワイエルシュトラスの近似定理なるものを用いて色々やってる証明を見つけましたが理解できませんでした......)

さて,現状の球面調和関数系を使用しても良いのですが,近似する対象が実関数である場合は実数の球面調和関数を基底として扱いたいです.これを \mathcal{Y}(\theta, \phi)として,


 
\begin{align}
case(m = 0) \; &: \; \mathcal{Y}(\theta, \phi) = Y_l^m(\theta, \phi) \\
case(m > 0) \; &: \; \mathcal{Y}(\theta, \phi) = \dfrac{(-1)^m Y_l^m(\theta, \phi) + Y_l^{-m}(\theta, \phi)} {\sqrt{2}} \\
case(m < 0) \; &: \; \mathcal{Y}(\theta, \phi) = \dfrac{(-1)^{|m|} Y_l^{|m|}(\theta, \phi) - Y_l^{-|m|}(\theta, \phi)} {\sqrt{2}j} \\
\end{align}


と定義してあげることで実数球面調和関数が得られます.

エンコーダーとしての球面調和関数

長くなりましたが,Spherical Harmonic Encodingでは視線の方向をエンコードします.ここで,視線の方向は長さが1の3次元ベクトル (x, y, z)です.これは半径が1の球面上の点と見ることが出来ます.つまり,視線の方向を入力とする関数は,単位球面上を定義域とする関数として見ることが出来ます.ここで,球面調和関数の完全性より,球面上で定義される関数が球面調和関数の(無限の)基底の線形結合で表せました.全結合層は(有限の)基底を線形結合する(ことにより関数の応答を近似する)ということを考えると,Spherical Harmonic Encodingは「視線の方向の基底を変換し,よりパラメーター次元を増やすものである」と考えられると私は解釈しています.ただし,この結論に関しては参考文献などがあるわけではないので違うかもしれません.

Spherical Harmonic Encodingの実装

さてさて,では実装に取り掛かりましょう.これまでの理論なしにも実装自体は簡単に出来ますので実装を説明します.

MFFM_DEVICE void Encode_SH_L4(__half* input, __half* Out) {
    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int inidx = 3 * (32 * ty + tx);
    const int outidx = (16 + SKEW) * (32 * ty + tx);
    // 1ブロック128バッチを担当する
    if (32 * ty + tx >= ONEBATCH_SIZE) {
        return;
    }

    const __half x = input[inidx + 0];
    const __half y = input[inidx + 1];
    const __half z = input[inidx + 2];

    const __half xx = x * x;
    const __half yy = y * y;
    const __half zz = z * z;
    const __half xy = x * y;
    const __half xz = x * z;
    const __half yz = y * z;
    const __half xyz = x * y * z;

    const __half r2 = 1.4142135623730950488016887242097f;
    const __half r3 = 1.7320508075688772935274463415059f;
    const __half r5 = 2.2360679774997896964091736687313f;
    const __half r7 = 2.6457513110645905905016157536393f;
    const __half r15 = 3.8729833462074168851792653997824f;
    const __half r21 = 4.582575694955840006588047193728f;
    const __half r35 = 5.9160797830996160425673282915616f;
    const __half r105 = 10.246950765959598383221038680521f;
    const __half rpi = 1.7724538509055160272981674833411f;

    __syncthreads();

    // L = 0
    Out[outidx + 0] = (__half)1.0f / ((__half)2.0f * rpi);

    // L = 1
    Out[outidx + 1] = (r3 / ((__half)2.0f * rpi)) * y;
    Out[outidx + 2] = (r3 / ((__half)2.0f * rpi)) * z;
    Out[outidx + 3] = (r3 / ((__half)2.0f * rpi)) * x;

    // L = 2
    Out[outidx + 4] = (r15 / ((__half)2.0f * rpi)) * xy;
    Out[outidx + 5] = (r15 / ((__half)2.0f * rpi)) * yz;
    Out[outidx + 6] = (r5 / ((__half)4.0f * rpi)) * ((__half)3.0f * z * z - (__half)1.0f);
    Out[outidx + 7] = (r15 / ((__half)2.0f * rpi)) * xz;
    Out[outidx + 8] = (r15 / ((__half)4.0f * rpi)) * (xx - yy);

    // L = 3
    Out[outidx + 9] = (r2 * r35 / ((__half)8.0f * rpi)) * y * ((__half)3.0f * xx - yy);
    Out[outidx + 10] = (r105 / ((__half)2.0f * rpi)) * xyz;
    Out[outidx + 11] = (r2 * r21 / ((__half)8.0f * rpi)) * y * ((__half)-1.0f + (__half)5.0f * zz);
    Out[outidx + 12] = (r7 / ((__half)4.0f * rpi)) * z * ((__half)5.0f * z * z - (__half)3.0f);
    Out[outidx + 13] = (r2 * r21 / ((__half)8.0f * rpi)) * x * ((__half)-1.0f + (__half)5.0f * zz);
    Out[outidx + 14] = (r105 / ((__half)4.0f * rpi)) * (xx - yy) * z;
    Out[outidx + 15] = (r2 * r35 / ((__half)8.0f * rpi)) * x * (xx - (__half)3.0f * yy);

    for (int i = 16; i < 16 + SKEW; i++) {
        Out[outidx + i] = 0.0f;
    }
}

部分的にみていきましょう.

MFFM_DEVICE void Encode_SH_L4(__half* input, __half* Out) {
...

・input: 視線の方向が(x, y, z)の形で格納されています.
・Out: エンコード結果を格納するポインタです.

const int bx = blockIdx.x;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int inidx = 3 * (32 * ty + tx);
const int outidx = (16 + SKEW) * (32 * ty + tx);

・inidx: 実行中のスレッドにて処理する入力ベクトルを指すポインタへアクセスするためのインデックスです.
・outidx: 実行中のスレッドにてエンコード結果を格納するポインタへアクセスするためのインデックスです.

// 1ブロック128バッチを担当する
if (32 * ty + tx >= ONEBATCH_SIZE) {
    return;
}

Multiresolution Hash Encodingと同じことをしています.

const __half x = input[inidx + 0];
const __half y = input[inidx + 1];
...
const __half rpi = 1.7724538509055160272981674833411f;

入力ベクトルをロードし,各計算に必要な定数を置いてます.

__syncthreads();

Multiresolution Hash Encodingと同じ役割です.

...
// L = 0
Out[outidx + 0] = (__half)1.0f / ((__half)2.0f * rpi);
...
Out[outidx + 15] = (r2 * r35 / ((__half)8.0f * rpi)) * x * (xx - (__half)3.0f * yy);
...

実数球面調和関数の基底を計算しています.今回は l = 3まで計算します.

for (int i = 16; i < 16 + SKEW; i++) {
    Out[outidx + i] = 0.0f;
}

SKEWの部分を0埋めしてます.

以上です.当然学習パラメーターはありません.基底の計算式は理論のところで示した実数球面調和関数に対して座標系を極座標から直交座標系へ変換してあげれば良いのですが,面倒なので球面調和関数表を参照しましょう.Wikipediaにもあります.

Multiresolution Hash Encodingによる2次元画像の近似(2次元の関数の近似)

まずはこちら側のエンコーダーによる効果を見ていきましょう.2次元画像は「座標(2次元)を与えると色(RGBとします)を返す関数」として見ることが出来ます.では,次の画像を近似しましょうかね.

画像サイズは256x256です.4年ぐらい前にBlenderで作った3Dモデルのレンダリング画像です.画像上の座標を[0, 1]に正規化して入力しました.なお,3DのMultiresolution Hash Encodingなのでz座標が求められますが,これを0.5と固定しました.さて,次の設定で学習しました.
Multiresolution Hash Encodingのパラメーター

 
L = 16  \\
T = 2^{20} \\
F = 2 \\
E = 0 \\
N_{min} = 2 \\
b = 2.2

MLPのパラメーター
隠れ層次元: 64
出力層次元: 3
隠れ層の数: 4
入力層と隠れ層の活性化関数: ReLU
出力層の活性化関数: Sigmoid
学習設定
誤差関数はHubor(閾値0.05)
最適化関数がAdamの場合は学習率0.01
最適化関数がGradient Descendantの場合は学習率2.0

結果を見ていきましょう.まずは誤差の変化は次の図のようになりました.誤差は平均二乗誤差です.

各最適化関数による学習中の出力は次の画像に示す通りです.各画像内の左上にある数値はその画像を出力したときのイテレーションです.
Gradient Descendant

Adam

両者の学習する過程はかなり異なっていて面白いですね.それよりも,Part1では簡単な1次元関数でも近似させるのが困難であったのに,Multiresolution Hash Encodingを通すことで非常に近似精度が向上しましたね.では,エンコーダー無しの場合を見ていきましょう.画像上の座標を[-0,0001, 0.0001]に正規化してNNに入力しました.

MLPのパラメーター
隠れ層次元: 64
出力層次元: 3
隠れ層の数: 12
入力層と隠れ層の活性化関数: ReLU
出力層の活性化関数: Sigmoid
学習設定
誤差関数はHubor(閾値0.05)
最適化関数がAdamの場合は学習率0.01
最適化関数がGradient Descendantの場合は学習率0.5

Gradient Descendant Adam

近づいてはいますが,限界にぶつかっているように見えます.

今回は256x256の画像で確認しましたが,これをもっとピクセル数が多い画像に対しても行うことが可能です.こうしてみると,非常に強力な手法なのですが,苦しい点も勿論あります.一つは言うまでもありませんが,メモリ消費が激しいことです.そしてもう一つですが,かなり処理が遅いです.主な理由はメモリアクセスに起因しております.ハッシュ関数を利用してアクセスしているため,まずキャッシュのヒット率が辛いことになってます.そして,グローバルメモリとのやり取りが非常に多くなります.特にテーブル上のパラメーターが多いので最適化処理でかなり遅くなります.ざっくり計測した感じでは今回の画像近似においては学習処理では実行時間の90%ぐらいが最適化処理に持って行かれてます.推論処理だと処理時間の80%程度がMultiresolution Hash Encodingに持って行かれています.ただ,実装に関しては改善できる点も多いのでこの値はあまり参考にしなくていいと思います.ただしやはり遅いには遅いです.

Sphrerical Harmonic Encodingによる球面上の関数の近似

球面上の関数を近似しましょう. (x, y, z)を単位球面上の座標として,

 
\begin{align}
&\mathbf{V} = (\sin(2\pi x), \cos(5\pi y), \sin(2.5\pi z)\cos(3\pi z)) \\
&f(x, y, z) := 0.5(x, y, z) \cdot \mathbf{V} + 0.5
\end{align}


……特に意味はないです.さて,次の設定で近似しました.
MLPのパラメーター
隠れ層次元: 64
出力層次元: 1
隠れ層の数: 1
入力層と隠れ層の活性化関数: ReLU
出力層の活性化関数: Sigmoid
学習設定
誤差関数はHubor(閾値0.05)
最適化関数: Adam
学習率: 0.025

エンコーダーありとエンコーダーなし,両方について上記の設定で実行しました.誤差の発展は次の通りです.

MLP側の層がかなり小さいため,エンコーダー無しでは近似が厳しそうであることが見えます.次に,エンコーダーを使用した場合の近似の進む様子を見ます.次の図に出てくるプロット点は

 
\begin{align}
(x,y,z)(1 + 0.4 \hat{f}(x,y,z))
\end{align}

を3D空間上にプロットしており,黄色い点群が推定,黒い点群が真値です.

いい感じですね.ちなみに点群はBlenderで表示しており,プログラムからPythonスクリプトで頂点を追加する関数をテキストで出力してBlenderPythonスクリプトに貼り付けて実行することにより点群を作っています.

エンコーダー編 さいごに

2編ではInstant NeRFにおいて使用されているエンコーダーについて説明しました.これによりニューラルネットワークの基礎部分は完成しました.3編では遂にNeRFの実装に入ります.今回行った画像近似は2次元の近似です.しかし私たちが目にしているのは3次元の空間であり,今度はこれを近似する必要があります.しかし画像の近似とは異なり,3次元空間の近似はその空間の可視化(レンダリング)が容易ではなく,それゆえに色んな概念を使用してその近似が試みられています.その一つがNeRFです.詳しい話はPart3でやりましょう.

参考資料

Convolutional Sequence to Sequence Learning
Attention Is All You Need
NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis
Instant Neural Graphics Primitives with a Multiresolution Hash Encoding
Optim Tech Blog. Instant NeRF の心臓、Multiresolution Hash Encoding をシンプルに実装しつつ2次元画像で試してみる
Mathematical Methods for Physicists A Comprehensive Guide Seventh Edition 2012
宇宙物理メモ ルジャンドル陪関数
球面調和関数表 Wikipedia
Spherical Harmonic Lighting: The Gritty Details
Precomputed Radiance Transfer for Real-Time Rendering in Dynamic, Low-Frequency Lighting Environments

CUDA C++でNeRFをほぼ0から実装してみた(Part1/3): 概要~MLP編

概要

CUDA C++を使用してMLP,Multiresolution Hash Encoding,NeRFを実装しました.NVIDIAの実装と比較すると遅いですが,数分でおおよそ綺麗に学習できました.本記事では一番初めに実装の大まかな雰囲気を示し,3編に分けて具体的に実装を説明していきます.

想定する読者の対象

C++やCUDAの基本文法が分かっていれば大丈夫だと思ってます.

GitHubとかにプロジェクト全体のソースコード公開しないのか

個人的な研究用に書いてるプログラムもあったりするのでしばらくは公開しないと思います.

一応

内容には気を付けて書きましたが,もし誤りや実装の改善点等があれば教えていただけると嬉しいです.

NeRFとは

NeRFはNeural Radiance Fieldsの略称です.ある物体を様々な視点から撮った写真(2次元)の情報から3次元の形状を推定するView Synthesisの一手法です."NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis" [Mildenhall et al. 2020] において発表されました.

ざっくりとした実装

実装のレポート的な記事なので実装の話をメインでやりましょう.次のような実装をしました.

<前処理>

(1): ファイルからのカメラ情報,教師画像のロード

(2): 学習時に使用しやすい形に入力データを処理

(3): 学習に必要なパラメーターや配列等を定義

<学習ループ開始>

(4): カメラから飛ばすレイを計算

(5): レイ上の点(Position),レイの方向(Direction)をサンプリング

(6): サンプル点の情報をMLPに入力し,出力としてサンプル点の色と密度を得る

(7): カメラに入るレイの色を計算

(8): MLPの出力の誤差を計算

(9): MLP誤差逆伝播

(10): MLPのパラメーター最適化

<学習ループ終了>

本シリーズの記事においては(4) ~ (10)を3編に分けて説明しようと思います.(1/3)としてはMLPの実装,(2/3)としてはMLPの入力時に使用するEncoderの実装,(3/3)としてはNeRF全体の実装について扱います.

MLP編 概要

Fully Fused MLP [Müller et al. 2021] の概念を参考にしてMLPを実装しました.NeRFで使用するMLPは全結合層,即ち行列の乗算と活性化層の組み合わせです.ここで,行列の乗算は,行列を細かく分割した小行列の乗算を繰り返すことで計算できます.今回は16x16の小行列に分割し,そしてCUDAの提供するwmmaのAPIを使用して高速に行列演算を行い,高速にMLPの処理を行いました.さらに,GPUのglobal memoryをなるべく使用しない実装を行いました.

MLP編 はじめに

MLPとは多層パーセプトロン(Multilayer perceptron)というものですが,入力データの線形結合を施した結果を用いてニューロンの発火(活性化層)を考えるといった,神経を模したものとなっております.……結局は巨大な非線形関数です.MLPを実装するという人は勿論あまりおらず,多くの人はPyTorchやLibTorchなどのフレームワークを使用しています.もちろんそちらの方が「動作保証」もありますし,「MLPの高速化」が目的でない場合はMLPの実装自体が非本質的な過程となるため,避けるのが良いでしょう.そんな中,2021年にNVIDIAのグループがFully Fused MLPというMLPの設計を発表しました.

research.nvidia.com

今回はこのMLPの設計を元にして行ったMLPの実装を説明していこうと思います.

MLPの実装: コンセプト1: 行列乗算の分割

NeRFの学習に使用するMLPの基本構造は,全結合層(バイアス項なし)と活性化層のセットが繰り返されたものとなっています.全結合層はすなわち,行列(重み W)とベクトル(入力 x)の乗算です.ここで,行列とベクトルの乗算については,次のようにベクトルを並べて同時に行うことが出来ます.

 \displaystyle
WX =
W
\begin{pmatrix}
x_1&x_2&x_3& ... &x_N 
\end{pmatrix}
= 
\begin{pmatrix}
Wx_1&Wx_2&Wx_3& ... &Wx_N
\end{pmatrix}


機械学習的にみると,これは N個の入力ベクトルのバッチに対して同時に全結合層の処理を作用させることを意味しています.

さて,ここで簡単のため, Nは16の倍数であるとしましょう.つまり,自然数 Mが存在して N = 16Mであるとしましょう.じゃあ先ほどの式をもう一度分割しなおしましょう.

 \displaystyle
WX =
W
\begin{pmatrix}
X_1&X_2&X_3& ... &X_M 
\end{pmatrix}
= 
\begin{pmatrix}
WX_1&WX_2&WX_3& ... &WX_M
\end{pmatrix}


本当に分割しなおしただけです.先ほどの式から変わった場所としては,ベクトルだったx_1, x_2, ... x_Nが横幅16の行列 X_1, X_2, ... , X_Mとなっただけです.では今度は,出力のベクトルの次元( OutDimとします)も16の倍数としましょう.即ち,自然数 Kが存在して OutDim = 16Kを満たしているとします.この時,先ほどの式をさらに分割すると,

 \displaystyle
WX =
\begin{pmatrix}
W_1 \\
W_2  \\ 
... \\
W_K 
\end{pmatrix} 
\begin{pmatrix}
X_1& ... &X_M 
\end{pmatrix}
= 
\begin{pmatrix}
W_1 X_1 & W_1 X_2 & ... & W_1 X_M \\
W_2 X_1 & W_2 X_2 & ... & W_2 X_M \\
... \\
W_K X_1 & W_K X_2 & ... & W_K X_M \\
\end{pmatrix}


となります.まあ式見ても直感的でないので図で見ましょう.ここで入力ベクトルの次元 InDimも16の倍数とします.( 16L = InDimを満たすようにLを定義します)

こんな感じで,行列同士の掛け算は細かい小行列の掛け算に分割できます.この図においては Xにおける横幅16縦幅64の小行列(灰色で塗られていく領域)に対して Wにおける横幅64縦幅16の小行列(橙などに塗られた各々の領域)を乗算しています.じゃあ,今度はその小行列同士の掛け算に注目しましょう.この小行列同士の掛け算 をさらに分割すると,


 \displaystyle
W_i X_j = 
\begin{pmatrix}
W_{i1}&W_{i2}&...&W_{iL}
\end{pmatrix}
\begin{pmatrix}
X_{1j} \\
X_{2j} \\
...\\
X_{Lj}
\end{pmatrix}
=
\sum_{k=1}^L W_{ik}X_{kj}


となります.これによって行列の乗算を16x16の行列同士の掛け算にまで分割することが出来ました.

さて,この分割の何が嬉しいかを書きます.唐突ですがGPU上にはTensorコアといった行列乗算に最適化された演算回路があります.CUDAにおいてはWMMA(Wave Matrix Multiply-accumulate)のAPIがあります.実はこのAPIを用いた行列演算では演算に使用できる行列のサイズが固定されています(この記事に纏まっております) (n,m,k) = (16,16,16)の行列サイズを今回は採用することとしましょう.すると,これまで見てきた行列演算の分割によって,このAPIを使用した行列演算が可能となりました.このAPIを利用することでハードウェア最適化の恩恵を受けて高速に全結合層の処理ができるという塩梅です.ここで一つだけ注意点が必要ですが,この演算を行う場合,行列の各要素の値はfloat以上の精度を持っていると怒られます.wmmaのAPIから__halfという,fp16,つまり半精度浮動小数点の型が提供されているのでそれを使用する必要があります.

並列性の確認

小行列に分割したら遅くなるんじゃないかと疑問に思うかもしれませんが,実はある程度並列化可能です.

図中に示した黄色矢印はすべて並列に計算できます.よって,この図において,完全に並列化できない場所は,黄色矢印自体の演算,すなわち16x64と64x16の行列の乗算です.この部分は先ほど示した通り,16x16行列同士の乗算に分割し,加算してあげる必要があります.即ち4回のループが必要です.  Xの横方向についても並列化は可能です.しかしながら,問題(実装の説明時に説明します)が発生するため,今回では行いません.結局,この図では4x8の32回のループが必要です.

行列のサイズが16の倍数でない時

先ほどは行列のサイズが全て16の倍数であるという制約を設けましたが,実際にMLPを使用するうえで,層の次元,特に入力層や出力層が16の倍数次元になることは稀です.その場合はどうするのかというと,今回の実装では単純にゼロパディングを行い,行列サイズを16の倍数に揃えました.次の関数は今回の実装にて何度も出てきます.

// input以上の最小のKの倍数を返す
__host__ __device__ constexpr int next_multiple(const int input, const int K) {
    const int Q = input % K;
    if (Q == 0) {
        return input;
    }
    else {
        return input + K - Q;
    }
}

MLPの実装: コンセプト2: 遅いメモリと速いメモリがある

CUDAをある程度書いたことのある人はおそらくglobal memory,shared memoryと言った言葉を聞いたことがあるでしょう.
global memoryとはGPUすべてのスレッドからアクセスできるメモリで,最近のGPUだと10GBやそれ以上あったりして結構な大容量メモリです.しかしながらこのglobal memory,マジで遅いのでなるべく読み書きの回数を減らしたいです.
一方でshared memoryはカーネル実行時のブロック内で共有されているオンチップのメモリで,容量は数十KB程度のオーダーしかないものの(ないために?),非常に高速です.よって,メモリとのやり取りはなるべくshared memoryで行いたいですね.
ではまず,容量の確認から行きましょう.先ほども書きましたが,shared memoryはカーネル実行時のブロック内で共有されているメモリです.つまり,処理を実行するブロック内のスレッドが大きければ多いほど,各スレッドで使用できるshared memoryの容量は減ってしまいます.では,具体的にどういう感じにメモリを使用するか見ていきましょう.

CUDAでのMLP並列化?

私は大学の講義で,なぜかC言語でCPU上で動かすMLPを実装させられたことがあります.良い勉強になりましたが二度とやりません.それはともかく,MLPの処理の並列化とはどういうことなのでしょうか?

正確に書くと,並列化するのは「MLPバッチ処理」です.各バッチの処理は,最適化処理を除いて完全に独立です.ということで,理想的にはバッチサイズだけ並列化が可能です.CUDAにおいては,ブロック単位を意識してプログラムを書きます.これは何度も書いている通り,shared memoryがブロック内部でのみ共有されているためです.では,適度な量にバッチを分割して各ブロックの処理データ量を調整し,それらをブロック単位で扱うこととしましょう.

この図のように,全体の非常に大きなバッチを,各ブロックでshared memoryに載せられる程度に分割します.今回の実装では1ブロックにつき128バッチを処理する設計としました.よって,GridSizeは(全体のバッチ数)/ 128 ということになります.

ブロック内部の構造はどうする

これが何だかんだで一番その後の実装の全てに関わってきます.まず,CUDAの仕様として,warpという,スレッドの塊を表現する単位があります.このwarpですが,あるwarp内のスレッドは必ず同じ命令を実行します.if分岐とかがあってスレッド内部で処理が分かれた場合は,片方の分岐先の処理(if文の処理)が終わるまで他方に分岐するスレッドの処理(else if, elseの処理)は待機します.このように,warp内部での処理がバラバラになるとカーネル実行の並列性,効率が大幅に低下します.これをwarp divergenceと呼びます.今回の私の環境(RTX3080: compute 86)では32スレッドが1warpとして扱われます.よってwarpを意識したカーネル実行を行うべきですね.ということで,BlockSizeは{32u, ty, 1u}としましょう.ty(BlockSize.y)に関してはのちに決めます.というわけで,各ブロックは32*tyのスレッドを実行することとなりました.

shared memoryに載せるもの,global memoryに載せるもの

一旦まとめましょう.これまでの議論より,カーネルの実行設定は<<< nBatch/128, {32u, ty, 1u}>>>となりました.各ブロックにおいては128バッチを扱うことになりました.
今回の設計のMLPにおいて使用するメモリ領域は次の通りです.
(1): NNの入力データ(一番最初の入力)
(2): NNの出力データ(一番最後の出力)
(3): MLPを流れていくデータ
(4): 全結合層の重みパラメーター
(5): 誤差逆伝播用のバッファ
(6): バッチ全体によるパラメーターの勾配を記録する配列
(7): ブロック内部でのバッチによるパラメーターの勾配を記録する配列
(1), (2)に関してはMLP内部では扱いません.MLP全体をモジュール化した際のI/Oの役割を果たします.当然CPUとのデータ通信はglobal memoryと行われるので,global memoryに載せることになります.
(3)はMLP内部で何度も読み書きを行う配列です.MLP内部でのデータの最大次元をMaxDimとして,sizeof(half) * MaxDim * MaxDimです.MaxDimを128とすると,sizeof(half)が2byteであることに注意して,32KBです.MaxDimが64なら8KBです.shared memoryに載せましょう.
(4)(5)は順伝播と逆伝播において使用するデータです.これらは実は工夫することにより,読み書きの回数を1,2回に抑えることが可能です.また,順伝播と逆伝播で使用することを考えると,これらの二つは異なるCUDAカーネルで行われるので,shared memoryではなくglobal memoryに載せるべきでしょう.
(6)はスレッド全体で書き込むデータです.ゆえにglobal memoryに載せます.
(7): これは私の実装ではshared memoryに載せました.しかしこれは何とも言えません.MaxDimが128であるときを考えましょう.この時のこの配列が占めるメモリ領域は128 * 128 * sizeof(half) = 32KBです.(3)と合わせてみると,64KBとなります.この記事を参考にして48KB以上のshared memoryを使用することが出来ましたが,果たしてこれは適切な実装なのかは不明です.今回の記事に載せるソースコードではshared memoryに載せたものとして処理をしています.私の環境では動きます……

shared memoryの載せ方

shared memoryに載せるものを決めました.じゃあ載せましょうということで単純にmemcpy的なことをする前に少しお待ちください.shared memoryを扱う上で考慮する必要がある概念として,バンクコンフリクトというものがあります.具体的な説明は他の記事に任せるとして,ざっくり書くと,shared memoryはメモリ領域が16とか32とかの数のバンクに振り分けられており,同じバンクに複数のスレッドが同時にアクセスするとメモリの処理が逐次的になり,ゆえにカーネルの性能低下につながります.実際に設計したMLPでバンクコンフリクトを意図的に起こした場合,最悪50%の性能低下を起こしました.
今回の実装ではshared memoryには「層を流れる特徴ベクトルのバッチ」が載っています(今後,これをintermediateと呼びます).よって,これらは1バッチあたり16の倍数要素の配列となっております.実は,各バッチの最後にいくつかzero-paddingを行うことでバッチごとのバンクがずれ,バンクコンフリクトを軽減できます.このzero-paddingの数をSKEWとしましょう.今回はSKEWを8としましたが,他の値を試すのもいいと思います.つまり,次のようにshared memory配列を置きます.

図のように,intermediateの下側にSKEWを与えます.重み Wはglobal memoryに載せているため,当然SKEWは与えません.

MLPの実装: 順伝播 外観

さあ,これまでに説明した概念を用いて順伝播を実装していきましょう.私の実装はこちらです.

/*
 * 順伝播(学習あり)
 *
 * inDim -> HiddenDim -> HiddenDim -> HiddenDim -> ... -> HiddenDim -> outDim
 *          |<------------------- nHiddenLayer ------------------->|
 */
template <const uint32_t indim, const uint32_t hiddendim, const uint32_t outdim, const uint32_t nHiddenLayer>
MFFM_DEVICE void Kernel_Debug_train_forward(const uint32_t BatchSize, Activation ActHid, Activation ActOut, __half* intermediate, __half* weights, __half* buffers) {

    constexpr uint32_t indim_aligned = next_multiple(indim, TENSOR_ROW);
    constexpr uint32_t hiddendim_aligned = next_multiple(hiddendim, TENSOR_ROW);
    constexpr uint32_t outdim_aligned = next_multiple(outdim, TENSOR_ROW);

    uint32_t weights_elem_idx = 0;
    uint32_t buffers_elem_idx = 0;
    uint32_t buffers_elem_idx_blockbias = indim_aligned * ONEBATCH_SIZE * blockIdx.x;

    __syncthreads();

    // MLPの入力をback propagationのために保存
    int shmem_curDim = indim_aligned + SKEW;
    store_intermediate<__half>(shmem_curDim, indim_aligned, intermediate, buffers);

    // 入力層 INDIM(_ALIGNED) -> HIDDENDIM(_ALIGNED)
    MLP_Forward<indim, hiddendim, indim_aligned, hiddendim_aligned>(ActHid, intermediate, weights);

    shmem_curDim = hiddendim_aligned + SKEW;

    weights_elem_idx += indim_aligned * hiddendim_aligned;
    buffers_elem_idx += BatchSize * indim_aligned;
    buffers_elem_idx_blockbias = hiddendim_aligned * ONEBATCH_SIZE * blockIdx.x;

    // 隠れ層 HIDDENDIM(_ALIGNED) -> HIDDENDIM(_ALIGNED)
    for (int i = 0; i < nHiddenLayer - 1; i++) {
        // MLPの入力をback propagationのために保存
        store_intermediate<__half>(shmem_curDim, hiddendim_aligned, intermediate, buffers + buffers_elem_idx);

        MLP_Forward<hiddendim, hiddendim, hiddendim_aligned, hiddendim_aligned>(ActHid, intermediate, weights + weights_elem_idx);
        weights_elem_idx += hiddendim_aligned * hiddendim_aligned;
        buffers_elem_idx += BatchSize * hiddendim_aligned;
        __syncthreads();
    }

    // MLPの入力をback propagationのために保存
    store_intermediate<__half>(shmem_curDim, hiddendim_aligned, intermediate, buffers + buffers_elem_idx);
    // 出力層 HIDDENDIM(_ALIGNED) -> OUTDIM(_ALIGNED)
    MLP_Forward<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(ActOut, intermediate, weights + weights_elem_idx);

    __syncthreads();

    shmem_curDim = outdim_aligned + SKEW;

    // 結果をback propagationのために記録
    buffers_elem_idx += BatchSize * hiddendim_aligned;
    buffers_elem_idx_blockbias = outdim_aligned * ONEBATCH_SIZE * blockIdx.x;
    store_intermediate<__half>(shmem_curDim, outdim_aligned, intermediate, buffers + buffers_elem_idx);
}

順番に一つずつ見ていきましょう.

template <const uint32_t indim, const uint32_t hiddendim, const uint32_t outdim, const uint32_t nHiddenLayer>
...

indim: NNの入力層の次元
hiddendim: NNの隠れ層の次元(すべての隠れ層で等しいとします)
outdim: NNの出力層の次元 nHiddenLayer: NNの隠れ層の数

これはコンパイル時に確定させ,実行中には変化させないものとします.

MFFM_DEVICE void Kernel_Debug_train_forward(const uint32_t BatchSize, Activation ActHid, Activation ActOut, __half* intermediate, __half* weights, __half* buffers) {
    ...
}

MFFM_DEVICEは deivice に同じです.
BatchSize: 文字通りバッチサイズです.
Activation構造体:

enum class Activation {
    ReLU,
    LeakyReLU,
    Sigmoid
};

これです.
ActHid: 入力層,隠れ層における活性化層の種類(すべての隠れ層で等しいとします)
ActOut: 出力層における活性化層の種類
intermediate: MLPの層を流れていくデータです.(shared memoryに載っています!)
weights: MLP内のすべての全結合層のパラメーター Wがこの配列に入っています.
buffers: 逆伝播時に使用する,順伝播時のintermediateの値をすべて格納します.

...
constexpr uint32_t indim_aligned = next_multiple(indim, TENSOR_ROW);
constexpr uint32_t hiddendim_aligned = next_multiple(hiddendim, TENSOR_ROW);
constexpr uint32_t outdim_aligned = next_multiple(outdim, TENSOR_ROW);
...

早速出てきました.先ほども示した通り,入力次元や出力次元は16の倍数とは限りません.そのため,入力層次元,隠れ層次元,出力層次元のそれぞれ以上の16の倍数を求めておきます.TENSOR_ROWは16の即値です.

...
uint32_t weights_elem_idx = 0;
uint32_t buffers_elem_idx = 0;
uint32_t buffers_elem_idx_blockbias = indim_aligned * ONEBATCH_SIZE * blockIdx.x;
...

weights_elem_idx: weightsには各MLPの全結合層において使用するパラメーターがすべて保存されています.よって,各全結合層が終了するたびに,使用する全結合層が保存されているインデックスを記録しておく必要があります.
buffers_elem_idx: buffersにも順伝播時のすべてのMLP間のintermediateを保存する必要があります.よって,各MLPが終了するたびに使用するbuffersのインデックスが記録されている必要があります.
buffers_elem_idx_blockbias : buffersは全てのスレッドが使用する配列です.intermediateはブロック間で独立であるため,「そのintermediateが存在していたブロック」が逆伝播時に明確である必要があります.ゆえに,blockIdxに対応する量だけインデックスをずらしてintermediateを保存しておきます.ここで注意ですが,保存するintermediateは,本来のサイズが16の倍数でなくても,16の倍数にパディングされた状態を記録します.これの理由は逆伝播時に明らかになります.

...
// MLPの入力をback propagationのために保存
int shmem_curDim = indim_aligned + SKEW;
store_intermediate<__half>(shmem_curDim, indim_aligned, intermediate, buffers);
...

shmem_curDim: 現在のshared memoryの1バッチの次元を記録します.コンセプト2の最後に述べましたが,shared memoryにはバンクコンフリクトを回避するため,SKEWというパディングを与えます.
store_intermediate: 後に説明します.逆伝播時に使用するため,buffersにintermediateの値を記録します.

...
// 入力層 INDIM(_ALIGNED) -> HIDDENDIM(_ALIGNED)
MLP_Forward<indim, hiddendim, indim_aligned, hiddendim_aligned>(ActHid, intermediate, weights);

shmem_curDim = hiddendim_aligned + SKEW;

weights_elem_idx += indim_aligned * hiddendim_aligned;
buffers_elem_idx += BatchSize * indim_aligned;
buffers_elem_idx_blockbias = hiddendim_aligned * ONEBATCH_SIZE * blockIdx.x;
...

MLP_Forward: 後に説明します.順伝播時のMLPの処理です.
shmem_curDim: 入力層が終わってintermediateの次元がhiddendimになりました.shared memoryに適したサイズを与えます.
buffers_elem_idx: 入力層が終わったのでbuffersのインデックスを更新します.入力層のデータが全バッチ入るように値を加えます.
buffers_elem_idx_blockbias: 先ほどと同様です.blockIdxに対応する量だけズラします.

...
// 隠れ層 HIDDENDIM(_ALIGNED) -> HIDDENDIM(_ALIGNED)
for (int i = 0; i < nHiddenLayer - 1; i++) {
    // MLPの入力をback propagationのために保存
    store_intermediate<__half>(shmem_curDim, hiddendim_aligned, intermediate, buffers + buffers_elem_idx);

    MLP_Forward<hiddendim, hiddendim, hiddendim_aligned, hiddendim_aligned>(ActHid, intermediate, weights + weights_elem_idx);
    weights_elem_idx += hiddendim_aligned * hiddendim_aligned;
    buffers_elem_idx += BatchSize * hiddendim_aligned;
    __syncthreads();
}

// MLPの入力をback propagationのために保存
store_intermediate<__half>(shmem_curDim, hiddendim_aligned, intermediate, buffers + buffers_elem_idx);
...

隠れ層の処理です.処理は入力層と同じですね.

...
// 出力層 HIDDENDIM(_ALIGNED) -> OUTDIM(_ALIGNED)
MLP_Forward<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(ActOut, intermediate, weights + weights_elem_idx);
shmem_curDim = outdim_aligned + SKEW;

// 結果をback propagationのために記録
buffers_elem_idx += BatchSize * hiddendim_aligned;
buffers_elem_idx_blockbias = outdim_aligned * ONEBATCH_SIZE * blockIdx.x;
store_intermediate<__half>(shmem_curDim, outdim_aligned, intermediate, buffers + buffers_elem_idx);

出力層の処理です.

以上で順伝播の外見は出来上がりです.今回の関数には入っていないが説明すべき関数が一つあります(入力データのロードと出力データのストアを行っていないですよね?).また,関数内でも説明を省いた関数が2つあります.それらについて見ていきましょう.

MLPの実装: 順伝播 I/O

先ほど示した関数には出てきていないですが,load_input()という関数があります.

////////////////////////////////////////// LOAD AND STORE ///////////////////////////////////////////////////////////
/*
 * inDim * BatchSizeの入力データがCPUからGPUのグローバルメモリに保存された
 * 入力データ1つの次元はinDimであるが,これを16の倍数にゼロパディングする
 * バンクコンフリクト対策のため,skewだけさらに次元を拡張する(ゼロパディングする).すなわち入力データ1つにつきnext_multiple(inDim, 16) + skew次元になる.これをnewDimとする
 * ここで,1Blockにつき,128バッチを担当する.
 * また,カーネル実行の設定はBlockNum = GridSize = (BatchSize/128, 1, 1),BlockSize = (32, MAX_REQUIRED, 1)
 * 以上より,BlockIdx.x = bx, ThreadIdx = tx, tyとすると,
 *
 * input + inDim * 128 * bx + inDim * 32 * ty + inDim * tx からのinDim要素を
 * shmem_input + newDim * 32 * ty + newDim * txからのinDim要素に格納する
*/
MFFM_DEVICE void load_input(const int inDim, const int newDim, const float* __restrict__ input, __half* __restrict__ shmem_input) {
    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    const int warp_index_required = ONEBATCH_SIZE / 32;

    if (ty >= warp_index_required) {
        return;
    }

    for (int i = 0; i < inDim; i++) {
        shmem_input[newDim * 32 * ty + newDim * tx + i]
            = (__half)input[inDim * ONEBATCH_SIZE * bx + inDim * 32 * ty + inDim * tx + i];
    }
        for (int i = inDim; i < newDim; i++) {
            shmem_input[newDim * 32 * ty + newDim * tx + i] = 0.0f;
        }
}

さて,私の実装時のコメントが凄く丁寧ですが,このコメントを書いたときはまだ頭がこんがらがっておりました.こういうコメントはやはり安全でいいですね.というわけで部分的にみていきましょう.

MFFM_DEVICE void load_input(const int inDim, const int newDim, const float* __restrict__ input, __half* __restrict__ shmem_input) {
...

inDim: 入力データの次元
newDim: 入力データの次元を,それ以上の最小の16の倍数にし(next_multiple(inDim, 16)),それにSKEWを与える.つまり入力データをロードしたあとのintermediateの1バッチの次元
input: 入力データ(global memory) shmem_input: 入力データをshared memory上に,適切な構造となるようにコピーしたもの

...
const int warp_index_required = ONEBATCH_SIZE / 32;
if (ty >= warp_index_required) {
    return;
}
...

先ほどはカーネルの実行時にBlockSize.xを32とし,BlockSize.yの値を未定としていました.今回の関数における処理ではブロックごとに128バッチ,つまりONEBATCH_SIZEのみ読み込めばよいため,各スレッドで1バッチロードするとすればthreadIdx.yはONEBATCH_SIZE/32以上は必要ありません.ゆえにreturnしてもらいます.

...
for (int i = 0; i < inDim; i++) {
    shmem_input[newDim * 32 * ty + newDim * tx + i]
        = (__half)input[inDim * ONEBATCH_SIZE * bx + inDim * 32 * ty + inDim * tx + i];
}
for (int i = inDim; i < newDim; i++) {
    shmem_input[newDim * 32 * ty + newDim * tx + i] = 0.0f;
}
...

嫌な式ですね.バグ埋め込んだ時に眺めるのが苦痛な処理です.図で説明しましょう.

図に色々詰め込んでいますが,「バッチの次元を16の倍数にし,SKEWを与えるよ.増加した次元部分には0を埋め込むよ」と言っているだけです.ty=2の下に残っている矢印は消し忘れです.
この図においてメモリ上の配置が図中下部に示されています.図示上は2次元配列に見えますが,実際には1次元配列で扱います.inputは全てのバッチの入力データを保持しています.よって,blockIdx.xが処理すべきインデックスを指す必要があります.一つのブロックが扱うバッチは128なので,128 * indim * blockIdx.xのバイアスを与えると良いことになります.この状態が図中の左です.さらに,各スレッドはそれぞれ対応する1バッチをinputからshmem_inputにコピーします.よって,図中左におけるk番目のバッチを処理する際には,input側ではindim * k,shmem_input側ではnewdim * kのインデックス調整が必要です.インデックス調整が完了したら,そこからindim要素をコピーしてあげると入力データのロードが完了です.ちなみに2つ目のforを抜かしているとnanが出ることがあります(ゴミ値が入ってるため).

この処理を逆方向に行うのがstore_intermediate()です.

/*
 * outDim * BatchSizeの入力データをGPUのグローバルメモリに保存する
 * 現在所持している出力データはnext_multiple(outDim, 16) + SKEWの次元である.
 * これをoutDim次元に整形して出力したい
 *
 * ここで,1Blockにつき,128バッチを担当する.
 * また,カーネル実行の設定はBlockNum = GridSize = (BatchSize/128, 1, 1),BlockSize = (32, MAX_REQUIRED, 1)
 * 以上より,BlockIdx.x = bx, ThreadIdx = tx, tyとすると,
 *
 * shmem_output + shmem_outDim * 32 * ty + shmem_outDim * tx からのoutDim要素を
 * output + outDim * 128 * bx + outDim * 32 * ty + outDim * tx からのoutDim要素に格納する
*/
template<typename T>
MFFM_DEVICE void store_intermediate(const int shmem_curDim, const int curDim, const __half* __restrict__ shmem_intermediate, T* __restrict__ destination) {
    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    const int warp_index_required = ONEBATCH_SIZE / 32;

    if (ty >= warp_index_required) {
        return;
    }

    for (int i = 0; i < curDim; i++) {
        destination[curDim * ONEBATCH_SIZE * bx + curDim * 32 * ty + curDim * tx + i]
            = (T)shmem_intermediate[shmem_curDim * 32 * ty + shmem_curDim * tx + i];
    }
}

shmem_curDim: 現在のデータの次元 (indim, hiddendim, outdimのいずれか)以上の最小の16の倍数にSKEWを与えたもの
curDim: 現在のデータの次元
shmem_intermediate: intermediate
destination: 出力先
残りの処理は先ほど示した図と見比べてみてください.

MLPの実装: 順伝播 MLP_Forward

順伝播のメインディッシュです.この関数では全結合層 + 活性化層の処理を行います.すなわち
 Y = Activation(WX) を行います.実際にはYとXは両方ともにintermediateという領域を使用します.
wmmaを使用するために次のヘッダーファイルをインクルードします.CUDAを利用していれば既に存在しているはずです.

#include <mma.h>

では関数を見ていきましょう.

/////////////////////////////// MLP IMPLEMENTATION //////////////////////////////////////////////////

/*
 * inDim ---(FC)--> outDim ---(Activation)--> outDim
 * return: ActivationFunc(matmul(weight,intermediate))
 */
template <const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void MLP_Forward(Activation activation, __half* __restrict__ intermediate, __half* __restrict__ weight, __half* __restrict__ buffer = nullptr) {
    using namespace nvcuda;

    // Y = WXを順伝播時に行う.(正確には Y_t = X_t * W_tを行う)
    // ここでのRow, Colとは転置前の行列の行と列の長さを表す
    constexpr int XRow = ONEBATCH_SIZE;
    constexpr int XCol = inDim;
    constexpr int WRow = inDim;
    constexpr int WCol = outDim;
    constexpr int YRow = ONEBATCH_SIZE;
    constexpr int YCol = outDim;

    constexpr int XRow_aligned = ONEBATCH_SIZE;
    constexpr int XCol_aligned = inDim_aligned;
    constexpr int WRow_aligned = inDim_aligned;
    constexpr int WCol_aligned = outDim_aligned;
    constexpr int YRow_aligned = ONEBATCH_SIZE;
    constexpr int YCol_aligned = outDim_aligned;

    constexpr int nBlock_XRow = XRow_aligned / TENSOR_ROW;
    constexpr int nBlock_XCol = XCol_aligned / TENSOR_ROW;
    constexpr int nBlock_WRow = WRow_aligned / TENSOR_ROW;
    constexpr int nBlock_WCol = WCol_aligned / TENSOR_ROW;
    constexpr int nBlock_YRow = YRow_aligned / TENSOR_ROW;
    constexpr int nBlock_YCol = YCol_aligned / TENSOR_ROW;

    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    constexpr int warp_index_required = nBlock_WCol;

    // Fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> inputs_frag;
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag[nBlock_WRow];
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> outputs_frag[nBlock_YRow];

    // weightをロード
#pragma unroll
    for (int i = 0; i < nBlock_WRow; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::load_matrix_sync(weights_frag[i], weight + WRow_aligned * TENSOR_ROW * ty + TENSOR_ROW * i, WRow_aligned);
        }
    }

    __syncthreads();

    // 入力をロードしてY_t = X_t*W_t
#pragma unroll
    for (int i = 0; i < nBlock_YRow; i++) {
        // Yを0で初期化しておく
        wmma::fill_fragment(outputs_frag[i], __float2half(0.0f));

#pragma unroll
        for (int j = 0; j < nBlock_XCol; j++) {
            // inputsをTENSORにロード
            wmma::load_matrix_sync(inputs_frag, intermediate + TENSOR_ROW * j + (TENSOR_ROW * i) * (inDim_aligned + SKEW), inDim_aligned + SKEW);
            // matmal
            wmma::mma_sync(outputs_frag[i], inputs_frag, weights_frag[j], outputs_frag[i]);
        }

        // Activation
        Activation_Forward<__half>(activation, outputs_frag[i], outputs_frag[i]);
    }

    __syncthreads();

    // 次の入力へ記録
#pragma unroll
    for (int i = 0; i < nBlock_YRow; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::store_matrix_sync(intermediate + TENSOR_ROW * ty + i * TENSOR_ROW * (outDim_aligned + SKEW), outputs_frag[i], outDim_aligned + SKEW, wmma::mem_row_major);
        }
    }
    __syncthreads();
}

これも一つずつ見ていきましょう.

template <const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void MLP_Forward(Activation activation, __half* __restrict__ intermediate, __half* __restrict__ weight, __half* __restrict__ buffer = nullptr) {
...

inDim: このMLP一層における入力次元
outDim: このMLP一層における出力次元
inDim_aligned: inDim以上の最小の16の倍数
outDim_aligned: outDim以上の最小の16の倍数
activation: このMLP一層における活性化層の種類
intermediate: このMLP一層のIO weight: この一層における全結合層のパラメーター W
buffer: あるタイミングで実装を変えた際に不使用となった変数です(バッファをこの関数の外で記録するようになったためです.いつか消します.)

...
using namespace nvcuda;
...

wmmaの関数はnvcuda::wmma::の中にあります.(じゃあwmmaもusing namespaceすればいいのではって?確かに……)

...
// Y = WXを順伝播時に行う.(正確には Y_t = X_t * W_tを行う)
// ここでのRow, Colとは転置前の行列の行と列の長さを表す
constexpr int XRow = ONEBATCH_SIZE;
...
constexpr int YCol_aligned = outDim_aligned;
...

行列演算としては
 Y = WX
をするのですが,実は次が成り立ちます.

 Y^T = X^T W^T

今回は後者で実装します.いずれにせよ,転置前の行列 X, Y, Wにおける横幅サイズ(Row),SKEWを考慮しない縦幅サイズ(Col)をすべて書き出しておきました.全部は使いません.

...
constexpr int nBlock_XRow = XRow_aligned / TENSOR_ROW;
...
constexpr int nBlock_YCol = YCol_aligned / TENSOR_ROW;
...

先ほど求めた各行列の横幅サイズと縦幅サイズはすべて16の倍数です(そうなるようにゼロパディングを行っています).各行列を一マス16の正方形の集まりで表現した際の縦に並ぶブロック数と横に並ぶブロック数を計算しておきます.これらはすべてコンパイル時に確定します.

これ以降の説明に入る前に,いったん図で確認しておきましょう.

関数内で行う計算は次の図の計算です.

今回の実装では,図内で橙,緑,青,灰で塗られた4色の領域に関しては並列で計算します.wmmaの行列計算は1warpが共同作業という形で行われます.つまり,一回の16x16行列同士の掛け算には32スレッド必要ということです.今回は4色,すなわち行列の乗算タスクを4個同時にこなしていくこととなります.よって必要なwarpは4です.ここで,出力層の次元は128以下を想定しているとしましょう.これが128である場合,この配色は8色(行列乗算タスクは8個同時)になります.よって必要なwarp数は8です.
このように,MLP内部での最大次元によって必要なwarp数が異なります.ただし,先程書いたload_input()やstore_intermediate()においては4warpを必要としています.よって,カーネル実行時に設定するBlockSize.yは
 BlockSize.y = max(4, MaxDim_aligned / 16)
となります.(入力次元は考慮しなくても良いかも)
議論の余地として,このことによってカーネル実行中のInactiveなスレッドは増加する(Occupancyが低下する)ので,warp数を4に固定したうえで,これを並列数4の必要に応じた逐次実行にしておくというのも良いと思います.というかそちらもこちらで試しておくべきでした.また試します.
いずれにせよ,仮にwarp数が沢山あったとしても,この関数が必要としているwarp数は(OutDim_aligned / 16)となるわけです.よって,

...
constexpr int warp_index_required = nBlock_WCol;
...

となります.
ここで,仮にMaxDim_alignedが128の時に Xの横方向に対しても並列化を行ったとします.すると並列数は64となります.……さて,64warp必要とするのであれば,必要なBlockSizeは2048となります.これではCUDAカーネル実行が出来ません.よってこちらに関しては並列化しません.

...
// Fragments
wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> inputs_frag;
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::col_major> weights_frag[nBlock_WRow];
wmma::fragment<wmma::accumulator, 16, 16, 16, __half> outputs_frag[nBlock_YRow];
...

wmma特有の構造体が出てきました.具体的な説明は他の記事に任せるとして,
inputs_frag: [tex: XT]の小行列です.先ほど示した図において矢印の向きが右であったのでrow_majorでロードします.
weights_frag: [tex: WT]の小行列です.先ほど示した図において矢印の向きが下であったのでcol_majorでロードします.
outputs_frag: [tex:YT]の小行列です.

ではここから演算に入ります.どういう演算なのかを確認しておきましょう.

さて,この図を参考にしながら見ていきます.

...
// weightをロード
#pragma unroll
for (int i = 0; i < nBlock_WRow; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::load_matrix_sync(weights_frag[i], weight + WRow_aligned * TENSOR_ROW * ty + TENSOR_ROW * i, WRow_aligned);
    }
}

__syncthreads();
...

そのwarpが必要としているweight(図中で対応している色で塗られているブロック)をweights配列からロードします.すべてここでロードします.つまり,weightsからのロードはこの1回だけです.インデックスの指定は図を参考にしてください.あと,ブロック内での同期を取りたいので__syncthreads()を置いておきましょう.

// 入力をロードしてY_t = X_t*W_t
#pragma unroll
for (int i = 0; i < nBlock_YRow; i++) {
    // Yを0で初期化しておく
    wmma::fill_fragment(outputs_frag[i], __float2half(0.0f));

#pragma unroll
    for (int j = 0; j < nBlock_XCol; j++) {
        // inputsをTENSORにロード
        wmma::load_matrix_sync(inputs_frag, intermediate + TENSOR_ROW * j + (TENSOR_ROW * i) * (inDim_aligned + SKEW), inDim_aligned + SKEW);
        // matmal
        wmma::mma_sync(outputs_frag[i], inputs_frag, weights_frag[j], outputs_frag[i]);
    }

    // Activation
    Activation_Forward<__half>(activation, outputs_frag[i], outputs_frag[i]);
}

__syncthreads();

Yは未初期化の状態ではゴミ値が入っているので0で初期化しておきましょう.カウンタ変数がiのループは図中でXYにおいて上から下に移動していくブロックです.カウンタ変数がjのループは図中における移動する緑色の矢印です.
Activation_Forwardは文字通り活性化層の関数ですが,この後で記述することにします.

// 次の入力へ記録
#pragma unroll
for (int i = 0; i < nBlock_YRow; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::store_matrix_sync(intermediate + TENSOR_ROW * ty + i * TENSOR_ROW * (outDim_aligned + SKEW), outputs_frag[i], outDim_aligned + SKEW, wmma::mem_row_major);
    }
}
__syncthreads();

演算結果を保存する処理です.データ配列の向き(図中の黒矢印)に注意して,row_majorで保存しましょう.

活性化層

活性化層の順伝播の関数が出てきました.簡単に触れておきます.

template <typename T, typename fragment_t>
MFFM_DEVICE void Activation_Forward(Activation activation, fragment_t& AfterFC, fragment_t& Activated) {
    switch (activation) {
    case Activation::ReLU:
#pragma unroll
        for (int i = 0; i < AfterFC.num_elements; i++) {
            Activated.x[i] = ((T)AfterFC.x[i] > (T)0.0f) ? AfterFC.x[i] : (T)0.0f;
        }
        break;
    case Activation::LeakyReLU:
#pragma unroll
        for (int i = 0; i < AfterFC.num_elements; i++) {
            Activated.x[i] = ((T)AfterFC.x[i] > (T)0.0f) ? AfterFC.x[i] : (T)0.05f * AfterFC.x[i];
        }
        break;
    case Activation::Sigmoid:
#pragma unroll
        for (int i = 0; i < AfterFC.num_elements; i++) {
            Activated.x[i] = (T)1.0f / ((T)1.0 + (T)expf(-AfterFC.x[i]));
        }
        break;
    default:
        printf("Invalid Activation Type\n");
        break;
    }
}

この関数ではReLU,LeakyReLU,Sigmoidの実装が書かれています.fragment_tにはwmma::fragmentが入ります.各活性化層の計算式についてはここでは説明しません.
.num_elements()はfragment_tのもつメンバ関数で,そのスレッドのfragmentが保持する要素の数を返してくれます.

順伝播のglobal memoryの関わるメモリ読み書きの確認

NNの入力(global memory)をshared memoryにロード: 1回
MLPの全結合層のパラメーターをロード: 1パラメーターにつき1回
MLPのintermediateをbuffersに保存: およそ層数回 (推論時には不要!)
NNの出力を保存: 1回

このようにみるとglobal memoryとのメモリのやり取りがほとんど無く,特に推論処理のみでは実現しうる最小回数であり,メモリ的にかなり効率が良いことが分かります.

MLPの実装: 逆伝播 外観

Forwardはかなり丁寧に説明しました.Backwardに関してはこれまでの処理を逆方向にやっていくだけなので,これまでの処理が分かっていれば特に理解に困ることはないと思います.

template <const uint32_t indim, const uint32_t hiddendim, const uint32_t outdim, const uint32_t nHiddenLayer>
MFFM_DEVICE void Kernel_Debug_train_backward(const uint32_t BatchSize, Optimize Optim, Activation ActHid, Activation ActOut, const int epoch, __half* intermediate, __half* LossDerivativeSumOfBlock, __half* weights, const __half* buffers,
    float* LossDerivativeSumALL, float* AdditionalParam) 
{

    // 即値
    constexpr uint32_t indim_aligned = next_multiple(indim, TENSOR_ROW);
    constexpr uint32_t hiddendim_aligned = next_multiple(hiddendim, TENSOR_ROW);
    constexpr uint32_t outdim_aligned = next_multiple(outdim, TENSOR_ROW);
    constexpr uint32_t weightsize = indim_aligned * hiddendim_aligned + (nHiddenLayer - 1) * hiddendim_aligned * hiddendim_aligned + hiddendim_aligned * outdim_aligned;
    const uint32_t buffersize = (indim_aligned + nHiddenLayer * hiddendim_aligned + outdim_aligned) * BatchSize;

    uint32_t weights_elem_idx = weightsize - outdim_aligned * hiddendim_aligned;
    uint32_t buffers_elem_idx = buffersize - hiddendim_aligned * BatchSize - outdim_aligned * BatchSize;
    uint32_t buffers_elem_idx_blockbias = outdim_aligned * ONEBATCH_SIZE * blockIdx.x;

    __syncthreads();
    // 出力層の逆伝播
    MLP_Backward<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(BatchSize, ActOut, intermediate, weights + weights_elem_idx,
        buffers + buffers_elem_idx, LossDerivativeSumOfBlock);
    SumUp_LossDerivative<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(LossDerivativeSumOfBlock, LossDerivativeSumALL + weights_elem_idx);
    __syncthreads();

    // 隠れ層の逆伝播
    for (int i = 0; i < nHiddenLayer - 1; i++) {
        weights_elem_idx -= hiddendim_aligned * hiddendim_aligned;
        buffers_elem_idx -= hiddendim_aligned * BatchSize;
        buffers_elem_idx_blockbias = hiddendim_aligned * ONEBATCH_SIZE * blockIdx.x;
        MLP_Backward<hiddendim, hiddendim, hiddendim_aligned, hiddendim_aligned>(BatchSize, ActHid, intermediate, weights + weights_elem_idx, buffers + buffers_elem_idx, LossDerivativeSumOfBlock);
        SumUp_LossDerivative<hiddendim, hiddendim, hiddendim_aligned, hiddendim_aligned>(LossDerivativeSumOfBlock, LossDerivativeSumALL + weights_elem_idx);
    }
    weights_elem_idx -= indim_aligned * hiddendim_aligned;
    buffers_elem_idx -= indim_aligned * BatchSize;
    buffers_elem_idx_blockbias = indim_aligned * ONEBATCH_SIZE * blockIdx.x;

    // 入力層の逆伝播
    MLP_Backward<indim, hiddendim, indim_aligned, hiddendim_aligned>(BatchSize, ActHid, intermediate, weights + weights_elem_idx, buffers + buffers_elem_idx, LossDerivativeSumOfBlock);
    SumUp_LossDerivative<indim, hiddendim, indim_aligned, hiddendim_aligned>(LossDerivativeSumOfBlock, LossDerivativeSumALL + weights_elem_idx);
    __syncthreads();
}

これもざっくり見ていきましょう.

template <const uint32_t indim, const uint32_t hiddendim, const uint32_t outdim, const uint32_t nHiddenLayer>
MFFM_DEVICE void Kernel_Debug_train_backward(const uint32_t BatchSize, Optimize Optim, Activation ActHid, Activation ActOut, const int epoch, __half* intermediate, __half* LossDerivativeSumOfBlock, __half* weights, const __half* buffers, float* LossDerivativeSumALL, float* AdditionalParam) {
...

indim ~ nHiddenLayer: 順伝播と同じ
BatchSize: バッチサイズ
Optimize構造体:

enum class Optimize {
    GD,
    Adam
};

これです.最適化関数を設定します.GDはGradient Descendant,Adamは名前の通りです.
ActHid, ActOut: 順伝播と同じ
epoch: 学習のイテレーション
intermediate: 順伝播と同じ
LossDerivativeSumOfBlock: コンセプト2で出てきた,(7): ブロック内部でのバッチによるパラメーターの勾配を記録する配列です.今回はshared memoryに載ってます.
weights,buffers: 順伝播と同じ
LossDerivativeSumALL: コンセプト2で出てきた,(6): バッチ全体によるパラメーターの勾配を記録する配列です.
AdditionalParam: 過去の実装では使用していた配列です.今回は使いません.というかいつか消します.

...
// 即値
constexpr uint32_t indim_aligned = next_multiple(indim, TENSOR_ROW);
constexpr uint32_t hiddendim_aligned = next_multiple(hiddendim, TENSOR_ROW);
constexpr uint32_t outdim_aligned = next_multiple(outdim, TENSOR_ROW);
constexpr uint32_t weightsize = indim_aligned * hiddendim_aligned + (nHiddenLayer - 1) * hiddendim_aligned * hiddendim_aligned + hiddendim_aligned * outdim_aligned;
const uint32_t buffersize = (indim_aligned + nHiddenLayer * hiddendim_aligned + outdim_aligned) * BatchSize;

uint32_t weights_elem_idx = weightsize - outdim_aligned * hiddendim_aligned;
uint32_t buffers_elem_idx = buffersize - hiddendim_aligned * BatchSize - outdim_aligned * BatchSize;
uint32_t buffers_elem_idx_blockbias = outdim_aligned * ONEBATCH_SIZE * blockIdx.x;

__syncthreads();
...

順伝播時にやってたことと同じです.

...
// 出力層の逆伝播
MLP_Backward<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(BatchSize, ActOut, intermediate, weights + weights_elem_idx,
    buffers + buffers_elem_idx, LossDerivativeSumOfBlock);
SumUp_LossDerivative<hiddendim, outdim, hiddendim_aligned, outdim_aligned>(LossDerivativeSumOfBlock, LossDerivativeSumALL + weights_elem_idx);
__syncthreads();
...

MLP_Backward: 活性化層→全結合層の順番で逆伝播する関数です.後述します.
SumUp_LossDerivative: 各ブロックでは128バッチの誤差逆伝播によって得られたパラメーターの勾配を持っていますが,それを全スレッドで合計してあげる必要があります.それをする関数です.後述します.

隠れ層,入力層の逆伝播も出力層と同様に行います.

MLPの実装: 逆伝播 MLP_Backward

というわけで軽く逆伝播の処理を見ていきましょう.

template <const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void MLP_Backward(const uint32_t BatchSize, Activation activation, __half* __restrict__ intermediate, __half* __restrict__ weight, const __half* __restrict__ buffer, __half* __restrict__ LossDerivative) {
    using namespace nvcuda;

    // Y = WXを順伝播時に行った.(正確には Y_t = X_t * W_tを行った)
    // ここでのRow, Colとは転置前の行列の行と列の長さを表す
    constexpr int XRow_aligned = ONEBATCH_SIZE;
    constexpr int XCol_aligned = inDim_aligned;
    constexpr int WRow_aligned = inDim_aligned;
    constexpr int WCol_aligned = outDim_aligned;
    constexpr int YRow_aligned = ONEBATCH_SIZE;
    constexpr int YCol_aligned = outDim_aligned;

    constexpr int nBlock_XRow = XRow_aligned / TENSOR_ROW;
    constexpr int nBlock_XCol = XCol_aligned / TENSOR_ROW;
    constexpr int nBlock_WRow = WRow_aligned / TENSOR_ROW;
    constexpr int nBlock_WCol = WCol_aligned / TENSOR_ROW;
    constexpr int nBlock_YRow = YRow_aligned / TENSOR_ROW;
    constexpr int nBlock_YCol = YCol_aligned / TENSOR_ROW;

    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;

    constexpr int warp_index_required = nBlock_WRow;

    const __half* Buffer_MLPin = buffer + inDim_aligned * ONEBATCH_SIZE * bx;
    const __half* Buffer_MLPout = buffer + inDim_aligned * BatchSize + outDim_aligned * ONEBATCH_SIZE * bx;

    // Fragments
    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> dLdY_frag_for_dLdX;      // dLdoutputを指す
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> weights_frag[nBlock_WCol];
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> dLdX_frag[nBlock_XRow];               // dLdinputを指す

    wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::col_major> dLdY_frag_for_dLdW;      // dLdoutputを指す
    wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> buffer_frag[nBlock_XRow];  // MLPinを格納する       
    wmma::fragment<wmma::accumulator, 16, 16, 16, __half> dLdweights_frag[nBlock_WCol];            // dLdW

    // Activation
    // 32行のブロックごとに処理する
    if (ty * 32 + tx < ONEBATCH_SIZE) {
        Activation_Backward<__half>(activation, outDim, Buffer_MLPout + outDim_aligned * (ty * 32 + tx),
            intermediate + (outDim_aligned + SKEW) * (ty * 32 + tx),
            intermediate + (outDim_aligned + SKEW) * (ty * 32 + tx));
    }

    // weightをロード
#pragma unroll
    for (int i = 0; i < nBlock_WCol; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::load_matrix_sync(weights_frag[i], weight + TENSOR_ROW * ty + WRow_aligned * TENSOR_ROW * i, WRow_aligned);
        }
    }

    // MLPinをロード
#pragma unroll
    for (int i = 0; i < nBlock_XRow; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::load_matrix_sync(buffer_frag[i], Buffer_MLPin + TENSOR_ROW * ty + XCol_aligned * TENSOR_ROW * i, XCol_aligned);
        }
    }
    __syncthreads();

    // dLdX_t = dLdY_t * W
#pragma unroll
    for (int i = 0; i < nBlock_XRow; i++) {
        // 初期化
        wmma::fill_fragment(dLdX_frag[i], __float2half(0.0f));

#pragma unroll
        for (int j = 0; j < nBlock_YCol; j++) {
            // dLdoutputをTENSORにロード
            wmma::load_matrix_sync(dLdY_frag_for_dLdX, intermediate + TENSOR_ROW * j + (TENSOR_ROW * i) * (YCol_aligned + SKEW), YCol_aligned + SKEW);
            // matmal
            wmma::mma_sync(dLdX_frag[i], dLdY_frag_for_dLdX, weights_frag[j], dLdX_frag[i]);
        }
    }

    // dLdW = dLdY * X_t
#pragma unroll
    for (int i = 0; i < nBlock_WCol; i++) {
        // 初期化
        wmma::fill_fragment(dLdweights_frag[i], __float2half(0.0f));

#pragma unroll
        for (int j = 0; j < nBlock_XRow; j++) {
            // dLdoutputをTENSORにロード
            wmma::load_matrix_sync(dLdY_frag_for_dLdW, intermediate + TENSOR_ROW * i + TENSOR_ROW * (outDim_aligned + SKEW) * j, outDim_aligned + SKEW);
            // matmal
            wmma::mma_sync(dLdweights_frag[i], dLdY_frag_for_dLdW, buffer_frag[j], dLdweights_frag[i]);
        }
    }

    __syncthreads();

    // 次の入力へ記録
#pragma unroll
    for (int i = 0; i < nBlock_XRow; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::store_matrix_sync(intermediate + TENSOR_ROW * ty + i * TENSOR_ROW * (inDim_aligned + SKEW), dLdX_frag[i], inDim_aligned + SKEW, wmma::mem_row_major);
        }
    }
#pragma unroll
    for (int i = 0; i < nBlock_WCol; i++) {
        // 必要ないwarpは黙らせておく
        if (ty < warp_index_required) {
            wmma::store_matrix_sync(LossDerivative + TENSOR_ROW * ty + i * TENSOR_ROW * (WRow_aligned + SKEW), dLdweights_frag[i], WRow_aligned + SKEW, wmma::mem_row_major);
        }
    }
    __syncthreads();
}

部分部分でみていきましょう.

...
template <const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
...

念のためここでも定義しておきます.
inDim: 順伝播時のMLPの入力次元
outDim: 順伝播時のMLPの出力次元
つまり,"inとoutの定義"は順伝播を基準とします.

...
MFFM_DEVICE void MLP_Backward(const uint32_t BatchSize, Activation activation, __half* __restrict__ intermediate, __half* __restrict__ weight, const __half* __restrict__ buffer, __half* __restrict__ LossDerivative) {
...

BatchSize: バッチサイズ
activation: このMLPの活性化層の種類
intermediate: 略
weight: このMLPの全結合層のパラメーター
buffer: このMLPの順伝播時の入力側のintermediateの先頭を指すポインタ
LossDerivative: このMLPの全結合層の,ブロックレベルの勾配を記録する配列(の先頭のポインタ)

const __half* Buffer_MLPin = buffer + inDim_aligned * ONEBATCH_SIZE * bx;
const __half* Buffer_MLPout = buffer + inDim_aligned * BatchSize + outDim_aligned * ONEBATCH_SIZE * bx;

bufferはこのMLPの順伝播時の入力側のintermediateの先頭を指すポインタなので,入力側のデータ数だけポインタをずらすことで出力側のintermediateの先頭を指すポインタを得られます.順伝播時にもやりましたが,bufferにはすべてのバッチの記録情報が入っています.なので今回でもblockIdxに応じてインデックス調整を入れてあげます.

処理的には先に活性化層がやってきます.先に活性化層を確認しましょう.

// Activation
// 32行のブロックごとに処理する
if (ty * 32 + tx < ONEBATCH_SIZE) {
        Activation_Backward<__half>(activation, outDim, Buffer_MLPout + outDim_aligned * (ty * 32 + tx),
        intermediate + (outDim_aligned + SKEW) * (ty * 32 + tx),
        intermediate + (outDim_aligned + SKEW) * (ty * 32 + tx));
    }

Activation_Backwardが活性化層の逆伝播の処理です.この関数ではブロック内にある128バッチの誤差情報を,128並列で処理します.

/*
 * Activation Back Propatation
 * DIM次元の特徴ベクトルがONEBATCH_SIZE個並べられている.
 * 各ワープはそれらのうち32個を担当する
 */
template <typename T>
MFFM_DEVICE void Activation_Backward(Activation activation, const int DIM, const __half* Activated, __half* dLdAfterFC, __half* dLdActivated) {
    __syncthreads();
    switch (activation) {
    case Activation::ReLU:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = ((T)Activated[i] > (T)0.0f) ? dLdActivated[i] : (__half)0.0f;
        }
        break;
    case Activation::LeakyReLU:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = ((T)Activated[i] > (T)0.0f) ? dLdActivated[i] : (__half)0.05f * dLdActivated[i];
        }
        break;
    case Activation::Sigmoid:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = (T)dLdActivated[i] * ((T)1.0 - Activated[i]) * Activated[i];
        }
        break;
    default:
        printf("Invalid Activation Type\n");
        break;
    }
}
template <typename T>
MFFM_DEVICE void Activation_Backward(Activation activation, const int DIM, const __half* Activated, __half* dLdAfterFC, __half* dLdActivated) {
...

activation: 略
DIM: その時点での,align処理をしていない生の次元数です.
Activated: 順伝播時の活性化層の出力です.bufferから読み出します.
dLdAfterFC: 全結合層の出力に流していく誤差情報です.(FC: Fully Connected)
dLdActivated: 活性化層の出力に流れ込んできた誤差情報です.

...
switch (activation) {
    case Activation::ReLU:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = ((T)Activated[i] > (T)0.0f) ? dLdActivated[i] : (__half)0.0f;
        }
        break;
    case Activation::LeakyReLU:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = ((T)Activated[i] > (T)0.0f) ? dLdActivated[i] : (__half)0.05f * dLdActivated[i];
        }
        break;
    case Activation::Sigmoid:
#pragma unroll
        for (int i = 0; i < DIM; i++) {
            dLdAfterFC[i] = (T)dLdActivated[i] * ((T)1.0 - Activated[i]) * Activated[i];
        }
        break;
    default:
        printf("Invalid Activation Type\n");
        break;
    }
...

順伝播時にはfragment_tとして処理中のデータが流れていましたが,今回は単純なhalf配列として扱います.この配列において,各スレッドが扱う要素のうち最初のDIM個以降はゼロパディングされている領域なので処理しません.順伝播同様,各活性化層の処理の説明はしません.

では全結合層の逆伝播の処理を見ていきます.ここで処理を確認しておきましょう.

 \displaystyle
\begin{align}
\dfrac{\partial L}{\partial X^T} &= \dfrac{\partial L}{\partial Y^T}W \\
\dfrac{\partial L}{\partial W} &= \dfrac{\partial L}{\partial Y}X^T
\end{align}

順伝播時と同じように行列を16x16の小行列に分割して計算していきます.図で確認しましょう.

 \dfrac{\partial L}{\partial X^T} = \dfrac{\partial L}{\partial Y^T}W

この計算に必要なものは WとdLdYです.データの向きに注目してfragmentにロードします.

...

// Fragments
wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::row_major> dLdY_frag_for_dLdX;      // dLdoutputを指す
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> weights_frag[nBlock_WCol];
wmma::fragment<wmma::accumulator, 16, 16, 16, __half> dLdX_frag[nBlock_XRow];               // dLdinputを指す

...

// weightをロード
#pragma unroll
for (int i = 0; i < nBlock_WCol; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::load_matrix_sync(weights_frag[i], weight + TENSOR_ROW * ty + WRow_aligned * TENSOR_ROW * i, WRow_aligned);
    }
}

...

// dLdX_t = dLdY_t * W
#pragma unroll
for (int i = 0; i < nBlock_XRow; i++) {
    // 初期化
    wmma::fill_fragment(dLdX_frag[i], __float2half(0.0f));

#pragma unroll
    for (int j = 0; j < nBlock_YCol; j++) {
        // dLdoutputをTENSORにロード
        wmma::load_matrix_sync(dLdY_frag_for_dLdX, intermediate + TENSOR_ROW * j + (TENSOR_ROW * i) * (YCol_aligned + SKEW), YCol_aligned + SKEW);
        // matmal
        wmma::mma_sync(dLdX_frag[i], dLdY_frag_for_dLdX, weights_frag[j], dLdX_frag[i]);
    }
}

...

// 次の入力へ記録
#pragma unroll
for (int i = 0; i < nBlock_XRow; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::store_matrix_sync(intermediate + TENSOR_ROW * ty + i * TENSOR_ROW * (inDim_aligned + SKEW), dLdX_frag[i], inDim_aligned + SKEW, wmma::mem_row_major);
    }
}
...

図中に表示しているデータの配置方向(黒い矢印)に注意してrow_majorかcol_majorかを確認しましょう.順伝播と同様にして,最初にweightをすべてロードしておきます.また,この処理においてはこれ以上weightをロードする必要はありません.やはり1パラメーターにつき1回のロードですみます.ここまでこれはやることは順伝播と同じなので,あとは図を参考にしてfragmentの初期化に使用するintermediateのインデックスを丁寧に設定してあげてください.

 \dfrac{\partial L}{\partial W} = \dfrac{\partial L}{\partial Y}X^T

この計算に必要なものはdLdYとXです.データの向きに注目してfragmentにロードします.ここで,順伝播時にbufferにはintermediateの1バッチの次元を,本来の次元に戻さずに16の倍数次元のまま保存しました.これは,ここでfragmentにロードする際に特にレイアウトの整形作業等をすることなく直接ロードすることが出来ます.メモリ的には少し無駄ですが,こうした方が実装しやすいのと速そうなのでこうしました.では,コードを見ましょう.

...
wmma::fragment<wmma::matrix_a, 16, 16, 16, __half, wmma::col_major> dLdY_frag_for_dLdW;      // dLdoutputを指す
wmma::fragment<wmma::matrix_b, 16, 16, 16, __half, wmma::row_major> buffer_frag[nBlock_XRow];  // MLPinを格納する       
wmma::fragment<wmma::accumulator, 16, 16, 16, __half> dLdweights_frag[nBlock_WCol];            // dLdW

...

// MLPinをロード
#pragma unroll
for (int i = 0; i < nBlock_XRow; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::load_matrix_sync(buffer_frag[i], Buffer_MLPin + TENSOR_ROW * ty + XCol_aligned * TENSOR_ROW * i, XCol_aligned);
    }
}
__syncthreads();

...

// dLdW = dLdY * X_t
#pragma unroll
for (int i = 0; i < nBlock_WCol; i++) {
    // 初期化
    wmma::fill_fragment(dLdweights_frag[i], __float2half(0.0f));

#pragma unroll
    for (int j = 0; j < nBlock_XRow; j++) {
        // dLdoutputをTENSORにロード
        wmma::load_matrix_sync(dLdY_frag_for_dLdW, intermediate + TENSOR_ROW * i + TENSOR_ROW * (outDim_aligned + SKEW) * j, outDim_aligned + SKEW);
        // matmal
        wmma::mma_sync(dLdweights_frag[i], dLdY_frag_for_dLdW, buffer_frag[j], dLdweights_frag[i]);
    }
}

...

#pragma unroll
for (int i = 0; i < nBlock_WCol; i++) {
    // 必要ないwarpは黙らせておく
    if (ty < warp_index_required) {
        wmma::store_matrix_sync(LossDerivative + TENSOR_ROW * ty + i * TENSOR_ROW * (WRow_aligned + SKEW), dLdweights_frag[i], WRow_aligned + SKEW, wmma::mem_row_major);
    }
}

図中の黒い矢印(データの配置方向)に注意してfragmentの設定を丁寧にしましょう.メモリのロードとしてはbuffer(global memory)からのロードが1回,dLdYのロードが1回(2回目)あります.dLdYに関してはdLdXの計算時にロードしていたじゃないかと思うかもしれませんが,転置しているので構造が異なります.転置しなければしなかったで今度はmatrix_bの方に行ってしまうので,やはりダメです.ここ悔しいですね.
細かい処理の部分は図を参考にしてください.

勾配の集約

さて,各ブロックにおける全結合層のパラメーターの勾配の計算は終わりました.これは128バッチの伝搬によって得られた勾配なので,全ブロック(即ち全スレッド)における勾配をaccumulateする必要があります.実はここは未だに実装が確定しきっていない箇所です.具体的な話は最後にして,とりあえず現在の私の実装を説明します.

///////////////////////////////// SUM UP THE LOSS DERIVATIVES ////////////////////////////////////////////////////////////
template<const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void SumUp_LossDerivative(__half* __restrict__ LossDerivativeSumOfBlock, float* LossDerivativeSumALL) {
    int bx = blockIdx.x;
    int tx = threadIdx.x;
    int ty = threadIdx.y;
    int nThreads = 32 * blockDim.y;
    int threadId = 32 * ty + tx;

    const int WRow = inDim_aligned;
    const int WeightSize_this_layer = inDim_aligned * outDim_aligned;

#pragma unroll
    for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
        const int elem_idx = (nThreads * bx + i) % WeightSize_this_layer;
        const int row_idx = elem_idx % WRow;
        const int col_idx = elem_idx / WRow;
        const int elem_idx_shmem = (WRow + SKEW) * col_idx + row_idx;

        if (row_idx >= inDim || col_idx >= outDim) {
            continue;
        }

        atomicAdd(&LossDerivativeSumALL[elem_idx], (float)LossDerivativeSumOfBlock[elem_idx_shmem]);

    }
    __syncthreads();
}

部分部分で見ていきましょう.

template<const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void SumUp_LossDerivative(__half* __restrict__ LossDerivativeSumOfBlock, float* LossDerivativeSumALL) {
...

inDim, outDim: その全結合層の入力次元と出力次元です(16の倍数にはしていない,生の次元) LossDerivativeSumOfBlock: コンセプト2で出てきた,(7): ブロック内部でのバッチによるパラメーターの勾配を記録する配列です.今回はshared memoryに載ってます. LossDerivativeSumALL: コンセプト2で出てきた,(6): バッチ全体によるパラメーターの勾配を記録する配列です.

int nThreads = 32 * blockDim.y;
int threadId = 32 * ty + tx;

const int WRow = inDim_aligned;
const int WeightSize_this_layer = inDim_aligned * outDim_aligned;

nThreads: 現在のCUDAカーネルで立っている,ブロック内部でのスレッド数(i.e. BlockSize)
threadId: ブロックレベルでのスレッド番号(全スレッド,すなわちグローバルなスレッド番号ではない)
WeightSize_this_layer: 現在の全結合層のパラメーター数です.ただし,zero-paddingした箇所も含めています.

#pragma unroll
for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
    const int elem_idx = (nThreads * bx + i) % WeightSize_this_layer;
    const int row_idx = elem_idx % WRow;
    const int col_idx = elem_idx / WRow;
    const int elem_idx_shmem = (WRow + SKEW) * col_idx + row_idx;

    if (row_idx >= inDim || col_idx >= outDim) {
        continue;
    }

    atomicAdd(&LossDerivativeSumALL[elem_idx], (float)LossDerivativeSumOfBlock[elem_idx_shmem]);
}
__syncthreads();

まあ何してるのって話ではありますよね.目的は単純で,全パラメーターの勾配をglobal memoryの領域にatomicAddしているだけです.もっと細かく見ましょう.

for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
...

立っているブロック内部のスレッドをなるべく全部使いたいです.なので,スレッド番号に対してユニークに初期の仕事(global memoryに値を格納するタスク)を振り分けます.その仕事が終われば,立っているスレッド数分だけストライドをまたいで次の仕事に向かいます.仕事がない(仕事の対象,すなわち格納するWeightのインデックスがout-of-range)場合は何もしません.もっと端的に書くと,スレッドIDがKのスレッドは

 
index \equiv K \mod nThreads

を満たすweightのインデックスの勾配をglobal memoryに格納します.

......だけではありません.次を見ましょう.

...
const int elem_idx = (nThreads * bx + i) % WeightSize_this_layer;
...

なんでbx(blockIdx)があるんだと.これは正直有効なのかは分からないのですが,blockIdxに対してもバイアスを与えています.これは何かというと,global memoryに勾配をaccumulateする際にはatomicAdd()を使用します.そのため,なるべく異なるブロックは異なるインデックスを参照してほしいものです.そのために,ブロックごとに「ブロック数が少ない場合はユニーク」なインデックスとなるように仕事を振り分けています.ただ,そんなに小さいブロック数になっていることは稀なので効果は分からないです.

...
const int row_idx = elem_idx % WRow;
const int col_idx = elem_idx / WRow;
const int elem_idx_shmem = (WRow + SKEW) * col_idx + row_idx;

if (row_idx >= inDim || col_idx >= outDim) {
    continue;
}

atomicAdd(&LossDerivativeSumALL[elem_idx], (float)LossDerivativeSumOfBlock[elem_idx_shmem]);
...

さて,計算したweightの勾配のインデックスがzero-paddingされている場所であれば飛ばします.そして,global memory上のメモリの位置をちゃんと計算して,atomicAddを行います.

逆伝播のglobal memoryの関わるメモリ読み書きの確認

NNの出力誤差をshared memoryにロード: 1回
MLPの全結合層のパラメーターをロード: 1パラメーターにつき1回
MLPの全結合層の勾配の集約(書き込み): 不明かつatomic
MLPのbufferをロード: およそ層の数だけ
NNの入力誤差を保存(不要な場合はしない): 1回
(もしかしたら他にもあるかも)
全体としてかなり抑えられてはいますが,やはり勾配の集約とbufferのロードがかなりメモリのやり取りを起こしています.実際,推論のみの処理と比べると順伝播,逆伝播ともに数倍から数十倍遅くなります.

最適化処理

実を言うと特に書くことはないです.単純に大量のスレッドを立てて,各スレッドが対応するパラメーターを最適化するだけです(勾配の集約と同じイメージです).実装は載せておきますが解説は省略いたします.

///////////////////////////////// OPTIMIZATION IMPLEMENTATION //////////////////////////////////////////////////////////////////////
template<const int inDim, const int outDim, const int inDim_aligned, const int outDim_aligned>
MFFM_DEVICE void Optimization(bool willOptimizeInsideThisKernel, const uint32_t BatchSize, Optimize optimize, __half* __restrict__ weight, float* __restrict__ dLdW, int epoch, float* __restrict__ AdditionalParam) {
    int bx = blockIdx.x;
    int tx = threadIdx.x;
    int ty = threadIdx.y;

    const int nThreads = 32 * blockDim.y * gridDim.x;
    const int threadId = blockIdx.x * 32 * blockDim.y + 32 * ty + tx;
    const int WeightSize_this_layer = inDim * outDim;

    __syncthreads();

#pragma unroll
    for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
        // サイズ調整時にゼロパディングした箇所は更新しない
        const int Rowidx = i % inDim;
        const int Colidx = i / inDim;
        const int idx = inDim_aligned * Colidx + Rowidx;

        if (!isfinite((float)dLdW[idx])) {
            dLdW[idx] = 0.0f;
            continue;
        }

        dLdW[idx] = dLdW[idx] / (float)BatchSize;

        switch (optimize) {
        case(Optimize::GD):
            weight[idx] = weight[idx] - (__half)LEARNINGRATE * (__half)dLdW[idx];
            break;
        case(Optimize::Adam):
            if (!AdamOptimize(AdditionalParam[2 * idx], AdditionalParam[2 * idx + 1], dLdW[idx], weight[idx], epoch)) {
                printf("%d %f %f \n", idx, (float)AdditionalParam[2 * idx], (float)AdditionalParam[2 * idx + 1]);
            }
            break;
        default:
            printf("Invalid Optimization Type\n");
            break;
        }
        dLdW[idx] = (__half)0.0f;
    }
    __syncthreads();
}

Adam Optimizer:

template<typename T>
MFFM_DEVICE bool AdamOptimize(float& m, float& v, float& dLdX, T& X, int t) {

    m = ADAM_BETA1 * m + (1.0f - ADAM_BETA1) * dLdX;
    v = ADAM_BETA2 * v + (1.0f - ADAM_BETA2) * dLdX * dLdX;

    float m_hat = m / (1.0f - powf(ADAM_BETA1, t));
    float v_hat = v / (1.0f - powf(ADAM_BETA2, t));

    float eps = 1e-4f;

    if (sqrtf(v_hat) + eps == 0.0f) { // ?!?!?!?!
        dLdX = 0.0f;
        return true; 
    }
    X = X - ((T)LEARNINGRATE * (T)(m_hat / (sqrtf(v_hat) + eps)));

    dLdX = 0.0f;
    return true;
}

1次元の関数近似

さて,これまでに実装した関数を組み合わせることでNNのモジュールが完成します.私の実装ではどのような挙動をするのかを確認します.まずは学習の精度から確認しましょう.

以下の関数を近似しましょう.


f(x) = 0.5 + 0.4 \sin(8x) + 0.1 \cos(20x)  \;\;\; (0 \leq x \leq 1)


つまりfは1次元の実数を1次元の実数に移します.……特に意味はないです.これを次のネットワーク構造で近似しました.
入力層の次元: 1次元
隠れ層の次元; 64次元
出力層の次元: 1次元
隠れ層の数: 3
誤差関数: Hubor Loss(閾値: 0.05)
隠れ層の活性化関数: Sigmoid
出力層の活性化関数: Sigmoid
全結合層のパラメーター初期分布: He
最適化関数: AdamもしくはGradient Descendant 学習率: Adamの場合は0.02,Gradient Descendantの場合は0.03
Adamのパラメーター:  \beta_1 = 0.9, \beta_2 = 0.99

このネットワークの最適化を行い,イテレーション毎の推定値と真値の平均二乗誤差をグラフにプロットしました.

Adam君は頑張って収束してくれましたが,Gradient Descendant君はダメでした.次に,近似の様子を見ていきましょう.

y_hatは推定値,yは真値を意味します.エンコーダーを通さない1次元を入力としているので上出来だと思います.最後に,肝心の処理速度を確認しましょう.

処理時間の計測にはC++のchronoを使用しました.計測方法としては,学習処理のforループの手前にてchronoによる時間計測を開始し,10000回のイテレーションが終わった直後にchronoで時間計測を完了します.その後,計測結果の時間を10000(イテレーション回数)で割り,1イテレーションあたりの所要時間の平均値として記録しました.最適化関数はAdamで,活性化層はすべてSigmoidです.実行GPUはGeForce RTX3080 10GBです.

隠れ層の数による処理時間の変化を示しています.横軸は対数目盛になっております.当然ですが隠れ層の数にある程度比例しているのが分かります.次元数の違いによっても,比例関係があります.

バッチサイズによる処理時間の変化を示しています.両対数グラフです.隠れ層の数は全ての計測において4としました.バッチサイズ8192までは特に変化が見られないものの,それ以降はバッチサイズに比例した時間がかかっているように見えます.

推論処理だけを行った際の処理時間を確認します.

隠れ層の数による処理時間の変化を示しています.横軸は対数目盛になっております.

バッチサイズによる処理時間の変化を示しています.両対数グラフです.隠れ層の数は全ての計測において4としました.横軸の計測範囲が学習処理の処理時間のグラフと異なることに注意してください.

両者ともに,全体としての傾向は学習処理を入れている場合と同じですが,やはりバッファへの保存や逆伝播の処理が無いという点で,推論処理のみの方が学習処理と比較して数倍から十数倍速いことが分かります.バッチサイズが8192を超えたのちには線形に増加しているのはスループットの上限が目に見えているのでしょうかね?

最後に

これまで実装してきたもので,MLPの順伝播,逆伝播,最適化が行えます.つまりNeRFにおける機械学習(ここではMLP)処理の根幹は出来上がりました.2編ではNeRFを実装するうえで不可欠となっているエンコーダーについて実装を提示し,処理を説明していこうと思います.それの後,3編では本格的にNeRFの自分の実装を説明します.
これ以降は実装中に私が感じた疑問点を述べる領域です.


今回の実装では順伝播処理と逆伝播処理が同じカーネルで実行できました.しかし,逆伝播と最適化は同じカーネルで実行することが出来ませんでした.実は,最適化関数がGradient Descendantであれば可能だったのですが,Adamを使用すると厳しいと思ってます.この理由としては,Adamを実行する際にはmやvといった値を計算します.すなわち,実行時にはすべてのバッチの伝搬と誤差集約が終わっている必要があります.ですが,CUDAの実行として,デバイスレベルでの同期はカーネル外部でないと出来ません(そもそもとして,すべてのブロックが同時には動きません).それゆえに,デバイスレベルでの同期をとるために一度カーネルの外部に出ることになります.これに対しては何か解決策的なもの取れないんでしょうかね.現在実装しているパストレの自作レンダラに機械学習のモジュールを組み込んでいるのですが,やはりカーネルを飛び出すとなれば非常にメモリ的効率が低下し,また,実装としてもかなり面倒くさいものになりますね……

実はこれが誤差集約の実装がまだ揺れている原因です.私がこのコードを書いている時には最適化処理まで一つのカーネルで貫く予定でした.しかし,それが厳しいと分かった以上,thrustのreductionをする方が速いのでは?という気持ちもあります.まだ実装していないので何とも言えないといえばそうですが.

参考資料

Real-time Neural Radiance Caching for Path Tracing
tiny-cuda-nnのgitリポジトリ
Tensorコアを使ってみた
CUDAでShared memoryを48KiB以上使うには