CUDA C++でNeRFをほぼ0から実装してみた(Part2/3): エンコーダー編

エンコーダー編 概要

ニューラルネットワークの入力データをエンコーダーに通し,ベクトルの次元を上げることでネットワークの学習効率,精度を向上させることが出来る場合があります.本記事ではInstant NeRFにおいて使用されている2種類のエンコーダーをそれぞれ説明した後に私の実装を説明し,その効果を確認します.

エンコーダー編 はじめに

ニューラルネットワークはある入力のデータ(ベクトル)を元にして様々な計算処理を施し,最終的に目的のデータ(ベクトル)を回帰推定する手法です.ここで,入力層の次元が低い場合は,その結果を正確に推定することが難しくなります [Gehring et al. 2017. Convolutional Sequence to Sequence Learning].詳しいことは3編で書きますが,NeRFにおいてはニューラルネットワークの入力として,3次元空間内での場所と方向という6次元のデータを使用し,問題となっているサンプル点の「色 (3次元)」と「濃度 (1次元)」を推定します.6次元というのは非常に小さく,これだけの情報でNeRFのネットワークを最適化するのは不可能と言っても過言ではないと思います.そこで用いられるのがエンコーダーです.今回実装したNeRFにおいてはPositional EncodingとしてMultiresolution Hash Encodingを,Directional Encodingとして球面調和関数(Spherical Harmonic Encodingと私は呼んでいます)を使用しました.

原著のNeRF [Mildenhall et al. 2020]においては,sinとcosを繰り返すエンコーダーが使用されております.このエンコーダーは "Attention Is All You Need" [Vaswani et al. 2017. ]において使用されているものと同様の物で,処理的には入力データをsinとcosカーブ上の値(フーリエ級数展開的には基底に)に「マッピング」するといったものです.このエンコーダーを使用し,原著のNeRFは確かに3次元形状を近似することが出来ましたが,精度よく学習させるためにエンコーダー後続のネットワークが巨大となり,学習コストが高いなどの課題がありました.2022年にNVIDIAより発表されたMultiresolution Hash Encoding [Müller et al. 2022]はエンコーダーの一つで,これによってNeRF等の学習が非常に高精度,高速化されました.また,原著NeRFでは3次元空間での方向ベクトルを先程のsinとcosのエンコーダー(Frequency Encodingと呼んでいます)に通していましたが,これを球面調和関数に通してエンコードするという手法がMultiresolution Hash Encodingの論文におけるNeRFの実装に使用されており,今回の私の実装もそれにならって球面調和関数によるエンコーダーを使用しました.それぞれを説明していきます.

念のため

内容には気を付けているのですが誤り等があれば教えていただけると幸いです.

Multiresolution Hash Encodingの順伝播の手続き

先述した通り,Positional Encodingの一つです.入力データは空間における座標となります.くどく書けば,この座標を高次元の特徴ベクトルに変換します.最初は簡単のため,2次元で考えることにしましょう.つまり,ある正方形の内部にサンプル点が存在するとします.これは次の図の左側の状態を指します.サンプル点が橙の丸で表されております.

ここで,「レベル」の概念を説明しておきます.今回のエンコーダーではこの正方形に含まれるサンプル点を「様々な解像度レベルで解釈」します.よって,入力座標の存在する正方形を,レベルに応じて色んな解像度で分割します.これが上に示されている図の真ん中の状態です.例えば一番上のレベル( L = 1)では正方形を2分割しており, L = 2では正方形を3分割しております.そして,サンプル点(橙の丸印)がどの正方形に含まれるかを考えます.これが先程示した図において右側の状態です.

さらに,せっかく分割したので分割した正方形の頂点(格子点)に番号を振っておきましょう.また,このサンプル点が正方形内部のどんな位置にあるのかを表現しましょう. s, tを0以上1以下の実数として,次の図のようになります.(各レベルにおいて番号や記号は独立です)

つまりレベル1においてはサンプル点は頂点(1, 0) (2, 0) (1, 1) (2, 1)のなす正方形内部にあり,レベル2においてはサンプル点は頂点(2, 0), (3, 0), (2, 1), (3, 1)のなす正方形内部にあります.

じゃあ次に,この頂点番号をとあるハッシュ関数に入れてぐちゃぐちゃにしてあげましょう.ハッシュ関数 h(\mathbf{x})とします.整数の座標 \mathbf{x}を入れると整数 h(\mathbf{x})が返ってくると思ってください.すると,レベル1においてはサンプル点は h(1,0), h(2,0), h(1,1), h(2,1)のIDがふられた正方形内部にあり,レベル2においては h(2,0), h(3,0), h(2,1), h(3,1)のIDが降られた正方形内部にあります.

……一体何のためにこんなことしてるんだって思われている気がします.では,ここで「特徴ベクトルが格納されたテーブル」をレベルごとに用意します.はい,このテーブルに先ほど計算したハッシュ関数の出力をインデックスとしてアクセスするのです.すると,「格子点にテーブル上の特徴ベクトルが対応」します.そして,その特徴ベクトルを何かしらの方法で,今回はバイリニア補完で補完するのです.図で表すと次の通りです.

スペースの都合上,正確な図には出来なかったのですが,ここで使用するテーブルはレベルごとに独立の物を使用します.つまりレベルが違えば使用するテーブルは異なります.なお,テーブル上の特徴ベクトルの次元を Fとしています.

これまでの処理によって,各レベルにおけるサンプル点の持つ特徴ベクトルがそれぞれ求まりました.では最後に,これらを結合してあげます.これによって得られるベクトルが,このエンコーダーの出力です.

ちなみに論文ではさらにこれに追加のベクトルを結合することもあると書いていますが,少なくともNeRFにおいては使用しませんし,やることは単純なので本記事では省略します.

さて,これまでの処理をもう一度確認しておきましょう(厳密な処理ではなく,雰囲気です).

(1) (x, y) = (サンプル点の座標)
<(2) レベル数だけ繰り返す>
    (2.1) レベルに対応する解像度で正方形を分割する
    (2.2) 分割された小正方形のうち,どの小正方形に(x,y)が囲まれているかを求める
    (2.3) (x, y)が小正方形のどのあたりにあるかを求める(s, t)が求まる
    (2.4) 求めた小正方形の各頂点番号をハッシュ関数に入れ,テーブル上のIDを得る
    (2.5) 求まったIDからテーブルの特徴ベクトルを各頂点に読みだす
    (2.6) 読みだした特徴ベクトルを補完し,これをサンプル点の特徴ベクトルとする
<END: (2) レベル数だけ繰り返す>
(3) 各レベルでそれぞれ得られたサンプル点の特徴ベクトルを結合する

以上で2次元の場合のMultiresolution Hash Encodingの処理は完了しました.このエンコード処理によって,入力データが (x, y)の2次元だったところが,レベル数を L,テーブル上の特徴ベクトルの次元を Fとして LF次元となりました.

記事書いておいてなんですが……

自分のMultiresolution Hash Encodingの実装は現在書き直している途中です.理由としてはかなり乱れているためです.今回は書き直す前のコードで説明しますが,(もちろんこの記事に限った話ではないことですが)この記事に載せているコードをコピペするのではなく,実装の流れを確認して実装自体はそちらで行っていただくことをお勧めします.実装が改善された際に時間があれば書き直すかもしれません(このセクションが存在している限りは書き直されておりません).

Multiresolution Hash Encodingの実装: 順方向

さて,先ほどは2次元の入力に対して処理を行いましたが,実際のNeRFの入力は3次元です.ちなみにここを2次元にした場合は2次元のデータの近似(画像等)になります.さて,3次元を入力とする場合はエンコードの処理を3次元に対応させる必要があります.なので,これまで正方形として扱っていた個所を立方体として扱う必要があります.それを踏まえたうえで,まずは順伝播から実装を解説していきます.

私の実装をまず示します.

    /*
    Encode処理を行う
    (1) 各スレッドは1バッチの処理を行う(入力データは3次元であるという仮定を設ける)
        (1.1) 各レベルに対して次の処理を行う
            (1.1.1) 注目している座標が格子上ではどの8つの格子点に囲まれているかを求める
            (1.1.2) 各8格子点のHashTable上におけるインデックスを求める
            (1.1.3) もとめた8格子点における特徴ベクトルをHashTableより読みだす
            (1.1.4) 注目している座標が8格子点上でどの位置にあるか(s, t, u)を求める
            (1.1.5) (s, t, u)の値から特徴ベクトルを補完する
            コメント: 得られる特徴ベクトルは長さFである
        (1.2) 各レベルに対して(1.1)を行った結果,長さFの特徴ベクトルがL個得られる.それを繋げる(concat)
        (1.3) 長さEの追加特徴ベクトルをさらに(1.2)で得られた長さL*Fのベクトルに繋げる(concat)
    */
    template <const uint32_t indim_aligned>
    MFFM_DEVICE void Encode(const float3 InputRangeMin, const float3 InputRangeMax, __half* Input, __half* Encoded) {
        const int bx = blockIdx.x;
        const int tx = threadIdx.x;
        const int ty = threadIdx.y;
        const int global_threadId = bx * ONEBATCH_SIZE + 32 * ty + tx;
        const int block_threadId = 32 * ty + tx;

        // 1ブロック128バッチを担当する
        if (32 * ty + tx >= ONEBATCH_SIZE) {
            return;
        }

        unsigned int* NeighborPos_base = (unsigned int*)((__half*)Input + (indim_aligned + SKEW) * ONEBATCH_SIZE);
        unsigned int* NeighborPos = NeighborPos_base + (8 * 3) * block_threadId;
        float* NeighborFeatureVec_base = (float*)(NeighborPos_base + (8 * 3) * ONEBATCH_SIZE);
        float* NeighborFeatureVec = NeighborFeatureVec_base + (8 * MHE_F + SKEW) * block_threadId;
        float* stu_base = (NeighborFeatureVec_base + (8 * MHE_F + SKEW) * ONEBATCH_SIZE);
        float* stu = stu_base + 3 * block_threadId;

        // 入力のロード
        float x = normalize(Input[3 * block_threadId + 0], (__half)InputRangeMin.x, (__half)InputRangeMax.x, (__half)0.0f, (__half)1.0f);
        float y = normalize(Input[3 * block_threadId + 1], (__half)InputRangeMin.y, (__half)InputRangeMax.y, (__half)0.0f, (__half)1.0f);
        float z = normalize(Input[3 * block_threadId + 2], (__half)InputRangeMin.z, (__half)InputRangeMax.z, (__half)0.0f, (__half)1.0f);

        // ロードは必ず先に終わらせる
        __syncthreads();

        // (1.1)
#pragma unroll
        for (int l = 0; l < MHE_L; l++) {
            // (1.1.1)
            //unsigned int NeighborPos[8*3];
            Calc_NeighborVectorIndex(l, x, y, z, NeighborPos);

            // 近傍格子点における特徴ベクトルを求める (1.1.2) (1.1.3)
            //float NeighborFeatureVec[8*MHE_F];// [8][F]
#pragma unroll
            for (int i = 0; i < 8; i++) {
                // (1.1.2)
                unsigned int IndexOnHashTable = Calc_IndexOnHashTable(NeighborPos + 3 * i);

                // (1.1.3)
                Get_FeatureVectorOnHashTable(l, IndexOnHashTable, NeighborFeatureVec + MHE_F * i);
            }

            const int Buffer_stu_idx = PosToIdx2D(global_threadId, l, MHE_L) * 3;

            // 近傍格子点における特徴ベクトルから入力座標に対応する特徴ベクトルを求める (1.1.4) (1.1.5)
            // concatも行っていく (1.2)
            // SKEWを与えることに注意(出力データはL*F+E+SKEWとなる)
            Calc_CurrentFeatureVector(l, x, y, z, NeighborPos, NeighborFeatureVec, Encoded + block_threadId * (MHE_L * MHE_F + SKEW) + l * MHE_F, stu);

            // 今回は(1.3)は行わない
        }
        __syncthreads();
    }

Part1と同様に部分部分で見ていきましょう.

    template <const uint32_t indim_aligned>
    MFFM_DEVICE void Encode(const float3 InputRangeMin, const float3 InputRangeMax, __half* Input, __half* Encoded) {
...

・indim_aligned: MLPへの入力次元,つまりエンコーダーの出力次元を16の倍数に整形したものです.ちなみに必ずnext_multiple(LF, 16)となります……(じゃあconstexprで即値にすればいいのでは?)
・MFFM_DEVICE: CUDAの修飾子であるdeviceをdefineで置いたものです.
・InputRangeMin: サンプル点の座標 (x, y, z)が存在する領域はいわゆるAABB(Axis-Aligned-Bounding-Box),つまりxyz軸に平行な辺で構成された直方体です.その直方体のxyz座標の各々最小値が格納されています.
・InputRangeMax: InputRangeMinと同様にして,直方体のxyz座標の各々最大値が格納されています.
・Input: サンプル点の座標が格納された配列(の頭を指すポインタ)です.これはshared memoryに載っています.
・Encoded: エンコードされたデータの格納先です.ポインタとしてはInputと同じにしています.(shared memoryの容量が小さいためです)

...
const int bx = blockIdx.x;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int global_threadId = bx * ONEBATCH_SIZE + 32 * ty + tx;
const int block_threadId = 32 * ty + tx;
...

・global_threadID: デバイス全体で見たスレッドIDです.カーネル実行の設定はPart1を参照してください.
・block_threadId: ブロック単位で見たスレッドIDです.

...
// 1ブロック128バッチを担当する
if (32 * ty + tx >= ONEBATCH_SIZE) {
       return;
}
...

Part1を読めば詳しくは分かりますが,各ブロックは128バッチを処理します.なのでブロック単位で見たスレッドID(zero-indexed)が128以上のものは帰します.

...
unsigned int* NeighborPos_base = (unsigned int*)((__half*)Input + (indim_aligned + SKEW) * ONEBATCH_SIZE);
unsigned int* NeighborPos = NeighborPos_base + (8 * 3) * block_threadId;
float* NeighborFeatureVec_base = (float*)(NeighborPos_base + (8 * 3) * ONEBATCH_SIZE);
float* NeighborFeatureVec = NeighborFeatureVec_base + (8 * MHE_F + SKEW) * block_threadId;
float* stu_base = (NeighborFeatureVec_base + (8 * MHE_F + SKEW) * ONEBATCH_SIZE);
float* stu = stu_base + 3 * block_threadId;
...

うわあ……って感じです.えっと,実装中に使用する配列をshared memoryに載せようと努力しています.普通に静的配列として確保した方がいいと思います.細かい説明は出番が来た時にします.

...
 // 入力のロード
float x = normalize(Input[3 * block_threadId + 0], (__half)InputRangeMin.x, (__half)InputRangeMax.x, (__half)0.0f, (__half)1.0f);
float y = normalize(Input[3 * block_threadId + 1], (__half)InputRangeMin.y, (__half)InputRangeMax.y, (__half)0.0f, (__half)1.0f);
float z = normalize(Input[3 * block_threadId + 2], (__half)InputRangeMin.z, (__half)InputRangeMax.z, (__half)0.0f, (__half)1.0f);
...

入力座標をロードしておきます.ただし,サンプル点の座標が[[min.x, max.x], [min.y, max.y], [min.z, max.z]]に存在している状態だと面倒なので,ここでこの座標を[0, 1]^3に正規化しておきます.normalize関数は次の通りです.

// [SrcMIN, SrcMAX] -> [DstMIN, DstMAX]
template<typename T>
__host__ __device__ T normalize(T val, T SrcMIN, T SrcMAX, T DstMIN, T DstMAX) {
    T DstRange = DstMAX - DstMIN;
    T SrcRange = SrcMAX - SrcMIN;
    if (SrcRange == (T)0.0f) SrcRange = (T)1e-6f;
    T t = (val - SrcMIN) / SrcRange;
    return DstMIN + t * DstRange;
}

1次元の値に対して,元の最小値SrcMIN, 最大値SrcMAXの線分を点valで内分する際に比がどうなっているかを求め,それを出力側の最小値と最大値の線分に適用している感じです.ゼロ除算を避けるための処理はしていますが,エラー処理はしてません.

...
 // ロードは必ず先に終わらせる
 __syncthreads();
...

ブロック単位で同期を行います.ブロック内部の速いスレッドがエンコード結果を書き込む際に,ブロック内部の遅いスレッドが入力座標を読み出しが終わっていることを保証するためです.入力データ,出力データはshared memory(ブロックごとに独立)に載っているため,これで大丈夫です.

...
// (1.1)
#pragma unroll
for (int l = 0; l < MHE_L; l++) {
...

レベルの数だけ繰り返します.

...
 // (1.1.1)
//unsigned int NeighborPos[8*3];
Calc_NeighborVectorIndex(l, x, y, z, NeighborPos);
...

現在のレベルにおいて,サンプル点がどの小立方体内部にあるかを求めます.正確に言えば,サンプル点を包含する小立方体の各頂点のxyz各軸方向における頂点番号を求めます.NeighborPosは求めた各軸方向の頂点番号を格納する配列で,一頂点あたり3次元の頂点番号をもち,立方体は8頂点で構成されるので,uint32型8*3の容量が必要です.では,Calc_NeighborVectorIndexの処理を見ましょう.

 // レベルlevelにおける座標{x, y, z}の近傍格子点を求める
 MFFM_DEVICE inline void Calc_NeighborVectorIndex(int level, float x, float y, float z, unsigned int* NbVecIdx) {

     const unsigned int Nl = (unsigned int)(MHE_Nmin * powf(MHE_b, level));
     // 格子の1マスの大きさ
     float K = 1.0f / (float)Nl;

     NbVecIdx[0] = (int)(x / K); 
     NbVecIdx[1] = (int)(y / K); 
     NbVecIdx[2] = (int)(z / K);

     NbVecIdx[3] = NbVecIdx[0] + 1; 
     NbVecIdx[4] = NbVecIdx[1]; 
     NbVecIdx[5] = NbVecIdx[2];

     NbVecIdx[6] = NbVecIdx[0]; 
     NbVecIdx[7] = NbVecIdx[1] + 1; 
     NbVecIdx[8] = NbVecIdx[2];

     NbVecIdx[9] = NbVecIdx[0] + 1;
     NbVecIdx[10] = NbVecIdx[1] + 1; 
     NbVecIdx[11] = NbVecIdx[2];

     NbVecIdx[12] = NbVecIdx[0]; 
     NbVecIdx[13] = NbVecIdx[1]; 
     NbVecIdx[14] = NbVecIdx[2] + 1;

     NbVecIdx[15] = NbVecIdx[0] + 1; 
     NbVecIdx[16] = NbVecIdx[1]; 
     NbVecIdx[17] = NbVecIdx[2] + 1;

     NbVecIdx[18] = NbVecIdx[0]; 
     NbVecIdx[19] = NbVecIdx[1] + 1; 
     NbVecIdx[20] = NbVecIdx[2] + 1;

     NbVecIdx[21] = NbVecIdx[0] + 1; 
     NbVecIdx[22] = NbVecIdx[1] + 1; 
     NbVecIdx[23] = NbVecIdx[2] + 1;
 }

図を交えて説明しましょう.やってることは2次元での説明を3次元にしただけです.

手続きの説明は2次元で行っていました.サンプル点の存在する正方形を小正方形に分割し,どの小正方形に包含されるかを求めました.しかし今回は3次元でやるので,立方体を小立方体に分割し,どの小立方体にサンプル点が含まれるかを計算する必要があります.ここで,先ほど (x, y, z)をロードする際に,座標を[0, 1]^3にスケーリングしました.なので,全体の立方体の一辺の大きさは1です.そして,それをレベル lについては解像度 N_lだけ分割します.ここで,解像度はレベルに対応した解像度としたいので,次の式で解像度を計算します(説明時はレベルを1-indexedで扱っていますが,計算式や実装上は0-indexedです).

 
N_l = \lfloor N_{min} * b^l \rfloor


ここで, N_{min}は最小解像度, bは解像度のスケーリング指数, lはレベル番号を意味します.レベルが高くなるにつれて指数関数的に解像度が増加します.
さて,この解像度のもとで,小立方体の一辺の長さ( Kとします)は次のように求まります.

 
K = \dfrac{1}{N_l}


そして,次が成立しています. Positionはサンプル点の座標です.

 
整数M_x, M_y, M_zが存在して,
\begin{align}
M_xK \leq & Position.x < (M_x+1)K \\
M_yK \leq & Position.y < (M_y+1)K \\
M_zK \leq & Position.z < (M_z+1)K 
\end{align}


この M_x, M_y, M_zは小立方体において各軸小さい側の頂点の番号を表しています.つまり,先ほど示した図において,( M_x, M_y, M_z)が([0], [1], [2])です.残りの7頂点はこの頂点番号に1足したり足さなかったり......で求められます.以上がCalc_NeighborVectorIndexの処理です.続きを見ていきましょう.

// 近傍格子点における特徴ベクトルを求める (1.1.2) (1.1.3)
//float NeighborFeatureVec[8*MHE_F];// [8][F]
#pragma unroll
for (int i = 0; i < 8; i++) {
          // (1.1.2)
          unsigned int IndexOnHashTable = Calc_IndexOnHashTable(NeighborPos + 3 * i);

          // (1.1.3)
          Get_FeatureVectorOnHashTable(l, IndexOnHashTable, NeighborFeatureVec + MHE_F * i);
}

先ほどの処理でサンプル点を包含する小立方体の各頂点番号が分かりました.これをそれぞれハッシュ関数に入力し,テーブル上のインデックスを計算します(Calc_IndexOnHashTable).そして,そのテーブル上のそのインデックスに保存されている特徴ベクトルを読み出します(Get_FeatureVectorOnHashTable).読みだした特徴ベクトルはNeighborFeatureVecに保存されます.8頂点分の特徴ベクトルを保存するのでfloat8F個の容量となっております.
ではCalc_IndexOnHashTableから見ていきましょう.

// 頂点インデックス(3d)からHashTable上のインデックスをハッシュ関数により計算する
MFFM_DEVICE inline unsigned int Calc_IndexOnHashTable(unsigned int* VertexIndex) {
    const unsigned long long int p[3] = { 1, 2654435761, 805459861 };
    unsigned long long int h = 0;

    for (int i = 0; i < 3; i++) {
        h = h ^ (VertexIndex[i] * p[i]);
    }
    h %= MHE_T;

    return (unsigned int)h;
}

これはハッシュ関数を示した方が速いですね.次の関数 h(\mathbf{X})がテーブル上のインデックスを出力するハッシュ関数です.


h(\mathbf{X}) = (X.x * 1) \oplus (X.y * 2654435761) \oplus (X.z * 805459861 )  \mod T

なお, \oplus排他的論理和(XOR)で,Tはテーブルのサイズ,すなわち特徴ベクトルの本数です.この式に出てくる2654435761や805459861は論文において記されていた値ですが,大きな素数が使用されます.この計算によってテーブルの特徴ベクトルをロードする準備が整いましたのでロードします.Get_FeatureVectorOnHashTableを見ていきましょう.

    // VにHashTable上の特徴ベクトルを書き出す
    MFFM_DEVICE inline void Get_FeatureVectorOnHashTable(int level, unsigned int index, float* V) {
#pragma unroll
        for (int i = 0; i < MHE_F; i++) {
            V[i] = HashTable.at(level, index, i);
        }
    }

level: これまで何度も出てきている「レベル」です.
index: 先ほどのハッシュ関数によって計算されたインデックスです
V: テーブルを書きだす先の配列です.
MHE_F: 説明の際に示した,テーブルにおける特徴ベクトルの次元を表す Fのことです.
さて,ここでも触れるべき点があります.しかしこの部分は実装の幅を狭めるところなので真似はしないほうがいいです.コードを見てわかる通り,HashTableは何かしらの構造体として,グローバルのスコープで保持されていますね.ではその部分の実装を見ていきましょう.

enum class Initialize {
    Uniform,
    Load_from_file,
    Zero
};

struct Tb {
    float Data[MHE_L * MHE_T * MHE_F];

    MFFM_DEVICE void init(Initialize initialize, unsigned int seed = 1, float* LoadedData = nullptr) {
        const int index = blockIdx.x * blockDim.x + threadIdx.x;
        switch (initialize) {
        case(Initialize::Zero):
            Data[index] = 0.0f;
            break;
        case(Initialize::Uniform):
            curandState state;
            curand_init(seed, index, 0, &state);
            float rnd = curand_uniform(&state);
            Data[index] = normalize(rnd, 0.0f, 1.0f, -1e-4f, 1e-4f);
            break;
        case(Initialize::Load_from_file):
            Data[index] = LoadedData[index];
            break;
        default:
            printf("Invalid MHE initializer\n");
            break;
        }
    }

    MFFM_DEVICE float at(const uint32_t idxL, const uint32_t idxT, const uint32_t idxF) {
        return Data[MHE_T * MHE_F * idxL + MHE_F * idxT + idxF];
    }
    MFFM_DEVICE float* ptr_at(const uint32_t idxL, const uint32_t idxT, const uint32_t idxF) {
        return Data + MHE_T * MHE_F * idxL + MHE_F * idxT + idxF;
    }
};

MFFM_DEVICE Tb HashTable;
MFFM_DEVICE Tb dLdHashTable;
MFFM_DEVICE Tb v_Buffer_HashTable;
MFFM_DEVICE Tb m_Buffer_HashTable;

__global__ void init_HashTable() {
    HashTable.init(Initialize::Uniform);
    dLdHashTable.init(Initialize::Zero);
    v_Buffer_HashTable.init(Initialize::Zero);
    m_Buffer_HashTable.init(Initialize::Zero);
}

以上がテーブルの構造体周りの処理です.難しいことはしていないので軽く説明するにとどめます.

float Data[MHE_L * MHE_T * MHE_F];

保持しているデータはレベル数 L,テーブルのサイズ T,テーブルの特徴ベクトルの次元 Fの積である LTF要素のfloat配列です.これに対して,init関数では様々な初期化を行います.at関数は(レベル,テーブル上のインデックス,特徴ベクトル上のインデックス)をもとにしてテーブルの要素にアクセスする関数です(今気づきましたが参照してないですね).ptr_at関数は同様の要素を指すポインタにアクセスします.そして,テーブルの本体(HashTable),勾配を記録するdLdHashTable,Adam Optimizerのためのm_Buffer_HashTableとv_Buffer_HashTableがあります.init_HashTableにおいてそれぞれを初期化します.
これは重要なのですが,テーブル本体は初期値を[-1e-4, 1e-4]の範囲における一様乱数で初期化します.それ以外は普通に0初期化します.

さて,先ほどのGet_FeatureVectorOnHashTableにおける処理はこれでわかると思います.しかし,このようにグローバルなものとして定義すると,ニューラルネットワーク内部において1つしかMultiresolution Hash Encodingを使用できないという制約を抱えることになるため,避けた方がいいでしょう.現在主にこの周りの書き直しをしております.

説明の枝が長くなりましたが本筋の解説に戻りましょう.現在どこまでやったかというと,サンプル点の座標を包含する小立方体を求め,その頂点番号をハッシュ関数にいれてテーブル上のインデックスを計算し,そのインデックスに対応するテーブル上の特徴ベクトルを読みだしたところです.ということで,続きを見ていきましょう.

...
// 近傍格子点における特徴ベクトルから入力座標に対応する特徴ベクトルを求める (1.1.4) (1.1.5)
// concatも行っていく (1.2)
// SKEWを与えることに注意(出力データはL*F+E+SKEWとなる)
Calc_CurrentFeatureVector(l, x, y, z, NeighborPos, NeighborFeatureVec, Encoded + block_threadId * (MHE_L * MHE_F + SKEW) + l * MHE_F, stu);
...

この部分の処理では「各頂点にロードした特徴ベクトルの補完によるサンプル点における特徴ベクトルの計算」を行い,それを「レベル番号に対応したメモリ領域に保存(即ち結合と同義)」しています.
Encoded + block_threadId * (MHE_L * MHE_F + SKEW) + l * MHE_Fは,サンプル点におけるレベルlの特徴ベクトルを保存するポインタの先頭を指しています.1バッチ辺りのエンコード結果の特徴ベクトルはLF次元であり,shared memoryのバンクコンフリクトを避けるためにSKEWを与えるため,結局LF+SKEW次元となります.なので,スレッド番号にLF+SKEWを掛けてあげて,さらに各レベルではF次元の特徴ベクトルが得られるのでレベル番号にFを掛けてます.
stuは,いや本当にごめんなさいなんですけど,バイリニア補完の係数を載せる配列です.いや,関数内部で静的配列として確保してくださいね.
というわけで関数の中身を見ましょう.

    // 近傍点の特徴ベクトルからEncoder入力座標における特徴ベクトルをバイリニア補完する
    MFFM_DEVICE inline void Calc_CurrentFeatureVector(int level, float x, float y, float z, unsigned int* NbVecIdx, float* NbFeatureVec,
        __half* CurFeatureVec, float* stu) {

        const unsigned int Nl = (unsigned int)(MHE_Nmin * pow(MHE_b, level));
        // 格子の1マスの大きさ
        float K = 1.0f / (float)Nl;

        // バイリニア補完係数
        stu[0] = (x - K * (float)NbVecIdx[0]) / K;
        stu[1] = (y - K * (float)NbVecIdx[1]) / K;
        stu[2] = (z - K * (float)NbVecIdx[2]) / K;

        // 3d-バイリニア補完
#pragma unroll
        for (int i = 0; i < MHE_F; i++) {
            CurFeatureVec[i] = __float2half((1 - stu[0]) * (1 - stu[1]) * (1 - stu[2]) * NbFeatureVec[i] +
                stu[0] * (1 - stu[1]) * (1 - stu[2]) * NbFeatureVec[MHE_F + i] +
                (1 - stu[0]) * stu[1] * (1 - stu[2]) * NbFeatureVec[2 * MHE_F + i] +
                stu[0] * stu[1] * (1 - stu[2]) * NbFeatureVec[3 * MHE_F + i] +
                (1 - stu[0]) * (1 - stu[1]) * stu[2] * NbFeatureVec[4 * MHE_F + i] +
                stu[0] * (1 - stu[1]) * stu[2] * NbFeatureVec[5 * MHE_F + i] +
                (1 - stu[0]) * stu[1] * stu[2] * NbFeatureVec[6 * MHE_F + i] +
                stu[0] * stu[1] * stu[2] * NbFeatureVec[7 * MHE_F + i]);
        }
    }

 Kを求めるところまではCalc_NeighborVectorIndexでやったのと同じです.後半を図で説明します.

Calc_NeighborVectorIndexにても書きましたが,この小立方体においてすべての座標が小さい頂点(特徴ベクトルV[0]がある頂点)の座標は, (M_xK, M_yK, M_zK)です.また,この小立方体の一辺の長さはKです.つまり,このサンプル点の座標から (M_xK, M_yK, M_zK)を引いたベクトルをKで割るとサンプル点が立方体の各軸方向についてどの場所にあるかを表現することが出来ます.つまり,次の式によりバイリニア補完係数を求めています.

 
(s, t, u) = \dfrac{(Position) - (M_xK, M_yK, M_zK)}{K}

あとは図中の式に従って補完しましょう.

お疲れ様です.これでMultiresolution Hash Encodingの順方向が実装出来ました.

Multiresolution Hash Encodingの手続き: 逆方向

さて,実は逆方向はかなり簡単です.今回求める勾配(即ち更新パラメーター)はテーブル上の特徴ベクトルです.順方向でどのような操作をしたかを思い出しましょう.「サンプル点を含む小立方体の各頂点に対応する特徴ベクトルを読み出し,バイリニア補完により出力層を計算」しました.つまり,これの逆伝播としては,「バイリニア補完の逆伝播計算をし,各頂点に対応する特徴ベクトルの勾配を計算」することです.図中のバイリニア補完の式から,次の微分が出来ます.記号は先ほどの図に出てくる計算式を参照してください( \mathbf{X}は入力ではないです)

 
\begin{align}
\dfrac{\partial \mathbf{X}}{\partial \mathbf{V}}= (&(1-s)(1-t)(1-u), s(1-t)(1-u), (1-s)t(1-u), st(1-u), \\ 
                &(1-s)(1-t)u, s(1-t)u, (1-s)tu, stu) \\
\end{align}

よって,バイリニア補完の係数を出力層に流れ込んできた勾配に掛けてあげれば良いだけです.

Multiresolution Hash Encodingの実装: 逆方向

    /*
     * 誤差逆伝播
     * shared memory:
     * - dEdOut: 誤差.サイズ(INDIM_ALIGNED+SKEW) * ONEBATCH_SIZE
     * - additional_shmem: 余分なshared memory.dLdVの格納に使用する
     *                     サイズ/thread: 8*MHE_F+SKEW
     *                     始点: (8*MHE_F+SKEW)*(block_threadIdx)
     * (1) スレッドごとのエンコーダー出力層の誤差を全スレッドに対するエンコーダー出力層の誤差から読みだす
     * (2) 各レベルごとに次の処理を行う
     *      (2.1) スレッドごとのエンコーダー出力層の誤差をレベルごとに分割する
     *      (2.2) 順伝播時に記録したBuffer.s/t/uを読み込み,近傍点の特徴ベクトルの補完係数を求める
     *      (2.3) 各近傍点について次の処理を行う
     *          (2.3.1) dEdV[k]を求める
     *          (2.3.2) dEdV[k]のHashTable上でのインデックスをBufferから読みだす
     *          (2.3.3) 近傍点に対応するHashTableの要素の誤差を記録する
     */
    MFFM_DEVICE void Propagate_backward(const float3 InputRangeMin, const float3 InputRangeMax, __half* dEdOut, __half* Buffer_Input, __half* additional_shmem) {
        const int bx = blockIdx.x;
        const int tx = threadIdx.x;
        const int ty = threadIdx.y;

        // 1ブロック128バッチを担当する
        if (32 * ty + tx >= ONEBATCH_SIZE) {
            return;
        }

        const int global_threadId = bx * ONEBATCH_SIZE + 32 * ty + tx;
        const int block_threadId = 32 * ty + tx;

        // s, t, uは再計算する方が速い
        // 入力のロード
        float x = normalize(Buffer_Input[3 * block_threadId + 0], (__half)InputRangeMin.x, (__half)InputRangeMax.x, (__half)0.0f, (__half)1.0f);
        float y = normalize(Buffer_Input[3 * block_threadId + 1], (__half)InputRangeMin.y, (__half)InputRangeMax.y, (__half)0.0f, (__half)1.0f);
        float z = normalize(Buffer_Input[3 * block_threadId + 2], (__half)InputRangeMin.z, (__half)InputRangeMax.z, (__half)0.0f, (__half)1.0f);

        __syncthreads();

        // (1) SKEWに注意(次元はL*F+E+SKEW)
        __half* dEdOutLF = dEdOut + (MHE_L * MHE_F + SKEW) * block_threadId;
        __half* dLdV = additional_shmem + (8 * MHE_F + SKEW) * (block_threadId);
        // (2)
#pragma unroll
        for (int l = 0; l < MHE_L; l++) {
            // s, t, u, indexOnHashTableを求める処理 - レベルLの格子における近傍点の座標 ///////////////
            unsigned int NeighborPos[8 * 3];
            Calc_NeighborVectorIndex(l, x, y, z, NeighborPos);

            const unsigned int Nl = (unsigned int)(MHE_Nmin * pow(MHE_b, l));
            // 格子の1マスの大きさ
            const float K = 1.0f / (float)Nl;

            // バイリニア補完係数
            const float s = (x - K * (float)NeighborPos[0]) / K;
            const float t = (y - K * (float)NeighborPos[1]) / K;
            const float u = (z - K * (float)NeighborPos[2]) / K;
            ////////////////////////////////////////////////////////////////////////////////////

            // (2.1)
            __half* dEdOut_LvWise = dEdOutLF + l * MHE_F;

            // (2.2)
            const float weight[8] = { (1 - s) * (1 - t) * (1 - u), s * (1 - t) * (1 - u), (1 - s) * t * (1 - u), s * t * (1 - u),
                                      (1 - s) * (1 - t) * u      , s * (1 - t) * u      , (1 - s) * t * u      , s * t * u };

            // (2.3)
#pragma unroll
            for (int k = 0; k < 8; k++) {
                // IndexOnHashTableを求める.
                unsigned int IndexOnHashTable = Calc_IndexOnHashTable(NeighborPos + 3 * k);

                // (2.3.1)
#pragma unroll
                for (int f = 0; f < MHE_F; f++) {
                    dLdV[k * MHE_F + f] = weight[k] * (float)dEdOut_LvWise[f];
                }

                // (2.3.3)
#pragma unroll
                for (int f = 0; f < MHE_F; f++) {
                    atomicAdd(dLdHashTable.ptr_at(l, IndexOnHashTable, f), dLdV[k * MHE_F + f]);
                    //*dLdHashTable.ptr_at(l, IndexOnHashTable, f) += (float)dLdV[k * MHE_F + f];
                }
            }
        }
    }

見ていきましょう.

MFFM_DEVICE void Propagate_backward(const float3 InputRangeMin, const float3 InputRangeMax, __half* dEdOut, __half* Buffer_Input, __half* additional_shmem) {
...

・dEdOut: 出力層に流れ込んできた勾配
・Buffer_Input: 順方向の際の入力データです.即ちサンプル点の座標です
・additional_shmem: 酷い実装の片鱗です.演算時のデータをshared memoryに載せるために空いているshared memoryの領域を持ってきます.無視しても良いです.

...
const int bx = blockIdx.x;
...
float z = normalize(Buffer_Input[3 * block_threadId + 2], (__half)InputRangeMin.z, (__half)InputRangeMax.z, (__half)0.0f, (__half)1.0f);

__syncthreads();
...

順伝播と同じです.バイリニア補完の係数を求めるために順伝播の処理を部分的に行っています.s, t, uを保存しておくよりもこちらの方が綺麗に実装できると思います.

// (1) SKEWに注意(次元はL*F+E+SKEW)
__half* dEdOutLF = dEdOut + (MHE_L * MHE_F + SKEW) * block_threadId;
__half* dLdV = additional_shmem + (8 * MHE_F + SKEW) * (block_threadId);

・dEdOutLF: 出力層に流れ込んできた勾配データで,実行スレッドに対応する勾配データの先頭を指すポインタ
・dLdV: ああ,各頂点に対応する特徴ベクトルの勾配を保存する領域です.静的配列として確保した方が良いと思います.あと,誤差関数がLとEで表記揺れしていますが気にしないでください......

...
// (2)
#pragma unroll
        for (int l = 0; l < MHE_L; l++) {
...

レベルごとに行います.

...
// s, t, u, indexOnHashTableを求める処理 - レベルLの格子における近傍点の座標 ///////////////
...
const float u = (z - K * (float)NeighborPos[2]) / K;
////////////////////////////////////////////////////////////////////////////////////
...

コメントの通りです.順伝播と同じ処理なので省略します.

...
 // (2.1)
 __half* dEdOut_LvWise = dEdOutLF + l * MHE_F;
...

各レベルごとに処理をしたいので,処理中のレベルに対応した出力層に流れ込んできた勾配を指すポインタを計算します.1レベルごとにF要素を処理しているのでレベル番号にFを掛けてます.

// (2.2)
const float weight[8] = { (1 - s) * (1 - t) * (1 - u), s * (1 - t) * (1 - u), (1 - s) * t * (1 - u), s * t * (1 - u), (1 - s) * (1 - t) * u      , s * (1 - t) * u      , (1 - s) * t * u      , s * t * u };

バイリニア補完の係数ですね.これで準備が整いました.一気に行きましょう.

// (2.3)
#pragma unroll
for (int k = 0; k < 8; k++) {
         // IndexOnHashTableを求める.
         unsigned int IndexOnHashTable = Calc_IndexOnHashTable(NeighborPos + 3 * k);

        // (2.3.1)
#pragma unroll
       for (int f = 0; f < MHE_F; f++) {
             dLdV[k * MHE_F + f] = weight[k] * (float)dEdOut_LvWise[f];
       }

       // (2.3.3)
#pragma unroll
      for (int f = 0; f < MHE_F; f++) {
            atomicAdd(dLdHashTable.ptr_at(l, IndexOnHashTable, f), dLdV[k * MHE_F + f]);
            //*dLdHashTable.ptr_at(l, IndexOnHashTable, f) += (float)dLdV[k * MHE_F + f];
      }
 }

各頂点に対応する特徴ベクトルの勾配を求めるため,各頂点に注目して処理していきます.まず,頂点とテーブル上の特徴ベクトルを対応させるために,順伝播と同じようにインデックスを求めます.そして,先程示した勾配を求める式に代入し,頂点に対応する特徴ベクトルの勾配を求めます.そして最後に,その勾配を,勾配を記録するテーブルの構造体であるdLdHashTableにatomicAddによりaccumulateします.(正直アクセスが疎なのでatomicじゃなくても耐えるのでは?と思っていますが,確証がないのでちゃんとatomicにしてます).以上で逆方向の処理は完了です.

Multiresolution Hash Encodingの実装: 最適化

実は現状の実装におけるボトルネックです.テーブル上のすべてのパラメーターに対して最適化処理を行います.

    ///////////////////////////////// OPTIMIZATION IMPLEMENTATION //////////////////////////////////////////////////////////////////////
    MFFM_DEVICE void Optimization(const uint32_t BatchSize, Optimize optimize, const int epoch) {
        int bx = blockIdx.x;
        int tx = threadIdx.x;
        int ty = threadIdx.y;

        const int nThreads = 32 * blockDim.y * gridDim.x;
        const int threadId = blockIdx.x * 32 * blockDim.y + 32 * ty + tx;
        const int WeightSize_this_layer = MHE_L * MHE_T * MHE_F;

#pragma unroll
        for (int i = threadId; i < WeightSize_this_layer; i += nThreads) {
            float dL = dLdHashTable.Data[i];
            if (!isfinite(dL)) {
                dLdHashTable.Data[i] = 0.0f;
                continue;
            }

            dL = dL / (float)BatchSize;

            switch (optimize) {
            case(Optimize::GD):
                HashTable.Data[i] = HashTable.Data[i] - (float)LEARNINGRATE * dL;
                break;
            case(Optimize::Adam):
                if (!AdamOptimize(m_Buffer_HashTable.Data[i], v_Buffer_HashTable.Data[i], dL, HashTable.Data[i], epoch)) {
                    // printf("%d %f %f \n", idx, (float)AdditionalParam[2 * idx], (float)AdditionalParam[2 * idx + 1]);
                }
                break;
            default:
                printf("Invalid Optimization Type\n");
                break;
            }
            dLdHashTable.Data[i] = 0.0f;
        }
        __syncthreads();
    }

やっていることはPart1における全結合層のパラメーターの最適化と同じなので説明は省略します.実際に使用されるテーブル上のパラメーターは限られるのでそれだけ更新するという実装にしようとしていますが現状ではまだ上手くいってません.以上でMultiresolution Hash Encodingの実装が完了しました.

Spherical Harmonic Encodingについて

球面調和関数の基底を用いてエンコードするというものです.球面調和関数は物性とか微分方程式とかの講義で触れられた記憶がありますが,正直なところ式だけ提示されても何も分かりません.まずそもそも球面調和関数を使用してエンコードするというのは何故なのか,いったい何の意味があるのかという疑問が出てきます.例えば3DCGの分野でも光源系の表現などに使用している論文が2000年ごろにありましたが,まあ読んでも詳細には何をしているか分かりませんでした.というわけで土日を溶かして数学をしました.そのうえで自分の理解を述べます.実装だけが目的であれば飛ばしても大丈夫です.

球面調和関数の導出とその正規直交性(と完全性)

ラプラス方程式

 
\Delta \Psi = (\dfrac{\partial^2}{\partial x^2} + \dfrac{\partial^2}{\partial y^2} + \dfrac{\partial^2}{\partial z^2})\Psi = 0


を考えます.これは極座標の関係

 
(x, y, z) = (r \sin \phi \cos \theta, r \sin \phi \sin \theta, r \cos\phi)
を用いて


 
\dfrac{\partial}{\partial r} \left(r^2 \dfrac{\partial}{\partial r}  \right) \Psi + \dfrac{1}{\sin \theta} \dfrac{\partial}{\partial \theta} \left( \sin\theta \dfrac{\partial}{\partial \theta} \right) \Psi + \dfrac{1}{\sin^2\theta} \dfrac{\partial^2}{\partial \phi^2}\Psi = 0


と書き換えられます(ここの導出はかなり面倒なので省きます).変数分離法により変形していきます.まず,次を認めます.


 
R(r), Q(\theta, \phi)が存在して,\Psi(r, \theta, \phi) = R(r)Q(\theta, \phi)と書ける


これを先程の式に代入し,両辺を RQで割って整理すると,次式が得られます.


 
\dfrac{1}{R}\dfrac{d}{dr} \left(r^2 \dfrac{d}{dr} \right) R = -\dfrac{1}{Q\sin\theta}\dfrac{\partial}{\partial \theta} \left( \sin\theta\dfrac{\partial}{\partial \theta}\right)Q - \dfrac{1}{Q\sin^2\theta}\dfrac{\partial}{\partial\phi}Q


ここで,左辺は rのみの式,右辺は \theta, \phiのみの式となり,この等号がどのような r, \theta, \phiに対しても成り立つので,両辺は定数となります. \lambdaを定数として,


 
\begin{align}
& \dfrac{1}{R}\dfrac{d}{dr} \left(r^2 \dfrac{d}{dr} \right) R = \lambda \\
& \dfrac{1}{Q\sin\theta}\dfrac{\partial}{\partial \theta} \left( \sin\theta\dfrac{\partial}{\partial \theta}\right)Q - \dfrac{1}{Q\sin^2\theta}\dfrac{\partial}{\partial\phi}Q = -\lambda
\end{align}


とします.今回は2つめの式に注目します.さらに次を認めます.


 
\Theta(\theta), \Phi(\phi)が存在して,Q(\theta, \phi) = \Theta(\theta) \Phi(\phi)と書ける


これを先程の2つめの式に代入し,同様に整理すると,次が得られます.


 
\dfrac{\sin\theta}{\Theta} \dfrac{d}{d\theta} \left(\sin\theta \dfrac{d}{d\theta}  \right)\Theta + \lambda \sin^2\theta = -\dfrac{1}{\Phi} \dfrac{d}{d\phi}\Phi


左辺は \thetaのみの式,右辺は \phiのみの式になっていますね.また, \Phiは周期的な関数であるとすれば, mを整数として,


 
\dfrac{1}{\Phi} \dfrac{d}{d\phi}\Phi = -m^2 \\
\dfrac{\sin\theta}{\Theta} \dfrac{d}{d\theta} \left(\sin\theta \dfrac{d}{d\theta}  \right)\Theta + \lambda \sin^2\theta = m^2


として書けます.1つ目の式より, Aを任意定数として,


 
\Phi = A e^{jm\phi}


と書けます(一般解ではないです.共役な基底があります).また,2つ目の式に対して,


 
x := \cos \theta \\
P(x) = P(\cos\theta) = \Theta(\theta)


を代入すると,


 
(1-x^2)^2 \dfrac{d^2P}{dx^2} -2x(1-x^2)\dfrac{dP}{dx} + \left( \lambda(1-x^2) - m^2  \right)P = 0


両辺を[ tex: (1-x2) ]で割ると,('はxによる微分を意味します)


 
(1-x^2) P'' -2xP' + \left( \lambda - \dfrac{m^2}{1-x^2}  \right)P = 0


ですね.これは m = 0のときルジャンドル方程式,そうでない場合ルジャンドル培方程式と言って名前がついてます.ここで, m \geq 0としておくことにします.
この世界には次の式が浮かんでくる人がいるみたいです.


 
P = (1-x^2)^{m/2} \mathcal{P}


これを先程の式に代入すると次が得られます.


 
(1-x^2)\mathcal{P}'' - 2x(m+1)\mathcal{P}' + \left( \lambda - m(m+1)  \right)\mathcal{P} = 0


フロベニウスの方法(級数法)を使用してこの微分方程式を解きます. \mathcal{P}が次の級数の形で与えられるとします.


 
\mathcal{P} = \sum_{j=0} ^ \infty a_j x^{k+j}


 k = 0の時,xの次数が2以上の式を考えると,次の漸化式が得られます.


 
a_{j+2} = a_j \left( \dfrac{j^2 + (2m+1)j - \lambda + m(m+1) }{ (j+1)(j+2) } \right)


ここで, -1 \leq x \leq 1の範囲で \mathcal{P}(x)には収束してもらうため,


 
j^2 + (2m+1)j - \lambda + m(m+1) = 0を満たすjが存在する


を満たすようにしたいです.ここで lをm以上の整数として, j = l-mとします.これを上の式に代入することにより, \lambda = l(l+1)が先程の条件を満たしてくれます.
 k = 1の時,xの次数が1以上の式を比較することにより,


 
a_{j+1} = a_{j-1} \left( \dfrac{j^2 + (2m+1)j - \lambda + m(m+1) }{ (j+1)(j+2) } \right)


が得られます.同様の議論が出来ます.さて,ルジャンドル培方程式を改めて書き直しましょう.


 
(1-x^2) P_l^{m''} -2xP_l^{m'} + \left( l(l+1)- \dfrac{m^2}{1-x^2}  \right)P_l^m = 0


 m = 0の時,上式は


 
(1-x^2) P_l^{''} -2xP_l^{'} +  l(l+1)P_l = 0


となります(ルジャンドル方程式).この両辺をm回微分しましょう.ライプニッツの公式


 
\dfrac{d^m}{dx^m} (A(x)B(x)) = \sum_{r=0}^{m} {}_m \mathrm{C}_r \dfrac{d^{m-r}}{dx^{m-r}} A(x) \dfrac{d^r}{dx^r} B(x)


を利用します.


 
(1-x^2)u'' - 2x(m+1)u' + \left( l(l+1) - m(m+1)  \right)u = 0 \\
ただし,u := \dfrac{d^m}{dx^m}P_l(x)


となります.実はこの式は先ほど出てきた \mathcal{P}に関する微分方程式と同じですね( \lambda = l(l+1)).これより,


 
\mathcal{P}_l^m = (-1)^m \dfrac{d^m}{dx^m}P_l(x)


と書けるらしいですがこの(-1)のべき乗の項はまだよく分かってないです.コンドン-ショートレー位相と呼ぶらしいですが,AMS-55という定義があるらしいとかなんとか……とにかく,これにロドリゲスの公式


 
P_l(x) = \dfrac{1}{2^l l!} \dfrac{d^l}{dx^l}(x^2-1)^l


を代入することにより,


 
\mathcal{P}_l^m = \dfrac{(-1)^m}{2^l l!} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l


であり,さらに


 
P_l^m = \dfrac{(-1)^m}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l


となります.これをルジャンドル陪関数と呼びます.ここで,これまでは非負の mについて計算していたので,これを負の mについても拡張します.(本当にこんな拡張していいのかという疑問がまだ解決できておりませんが......)
ひたすら計算します.ライプニッツの公式を使用して,


 
\begin{align}
\dfrac{d^{l+m}}{dx^{l+m}} (1-x^2)^l &= \left(\dfrac{d}{dx}\right)^{l+m} (1+x)^l (1-x)^l \\
                                                           &= \sum_{r=0}^{l+m} {}_{l+m}\mathrm{C}_r \left(  \left(\dfrac{d}{dx}\right) ^{l+m-r} (1+x)^l \right) \left(  \left(\dfrac{d}{dx}\right) ^{r} (1-x)^l \right)
\end{align}


ここで,[tex: (1+x)l, (1-x)l]の最高次数はともに lなので, l+1微分すると0となります.また,努力により,この微分は計算出来て,


 
\begin{align}
\left(\dfrac{d}{dx}\right) ^{l+m-r} (1+x)^l &= \dfrac{l!}{(r-m)!}(1+x)^{r-m}  \;\; (r \geq m) \\
\left(\dfrac{d}{dx}\right) ^{r} (1-x)^l          &= \dfrac{(-1)^r l!}{(l-r)!}(1-x)^{l-r} \;\; (r \leq l)
\end{align}


これより,


 
\begin{align}
\dfrac{d^{l+m}}{dx^{l+m}} (1-x^2)^l &= \sum_{r=m}^{l} {}_{l+m}\mathrm{C}_r \left( \dfrac{l!}{(r-m)!}(1+x)^{r-m} \right) \left( \dfrac{(-1)^r l!}{(l-r)!}(1-x)^{l-r}  \right) \\
                                                            &= \sum_{k=0}^{l-m} {}_{l+m}\mathrm{C}_{k+m} \left( \dfrac{l!}{k!}(1+x)^{k} \right) \left( \dfrac{(-1)^{m+k} l!}{(l-m-k)!}(1-x)^{l-m-k}  \right) \;\; (k ;= r-m) \\
                                                            &= \sum_{k=0}^{l-m} {}_{l+m}\mathrm{C}_{k+m} \left( \dfrac{l!}{k!} \dfrac{(-1)^{m+k} l!}{(l-m-k)!} \dfrac{(1+x)^{k+m}}{(1+x)^m} \dfrac{(1-x)^{l-k}}{(1-x)^m} \right) \\
                                                            &= \sum_{k=0}^{l-m} \dfrac{(-1)^m}{(1-x^2)^m} \dfrac{(l+m)!}{(l-k)!(k+m)!}  \left( \dfrac{l!}{k!} \dfrac{(-1)^{k} l!}{(l-m-k)!} (1+x)^{k+m}(1-x)^{l-k} \right) \\
                                                            &= \sum_{k=0}^{l-m} \dfrac{(-1)^m}{(1-x^2)^m} \dfrac{(l+m)!}{(l-m)!} \left( \dfrac{(l-m)!}{k!(l-m-k)!} \dfrac{l!}{(k+m)!}(1+x)^{k+m} \dfrac{(-1)^kl!}{(l-k)!}(1-x)^{l-k}  \right) 
\end{align}


ここで,直前の努力により得られた式と括弧内の式を見比べると,


 
\begin{align}
\dfrac{d^{l+m}}{dx^{l+m}} (1-x^2)^l &= \dfrac{(-1)^m}{(1-x^2)^m} \dfrac{(l+m)!}{(l-m)!} \sum_{k=0}^{l-m} {}_{l-m}\mathrm{C}_k \left( \left(\dfrac{d}{dx}\right)^{l-m-k} (1+x)^l \right) \left( \left(\dfrac{d}{dx}\right)^{k} (1-x)^l \right)
\end{align}


ライプニッツの公式より,


 
\begin{align}
\dfrac{d^{l+m}}{dx^{l+m}} (1-x^2)^l &= \dfrac{(-1)^m}{(1-x^2)^m} \dfrac{(l+m)!}{(l-m)!} \dfrac{d^{l-m}}{dx^{l-m}} (1-x^2)^l
\end{align}


さて,準備が整ったのでルジャンドル陪関数の式に適用します.


 
\begin{align}
P_l^m(x) &= \dfrac{(-1)^m}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l \\
               &=  \dfrac{(-1)^{m+l}}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(1-x^2)^l \\
P_l^{-m}(x) &= \dfrac{(-1)^{l-m}}{2^l l!} (1-x^2)^{-m/2} \dfrac{d^{l-m}}{dx^{l-m}}(x^2-1)^l \\
                   &= \dfrac{(-1)^{l-m}}{2^l l!} (1-x^2)^{-m/2} \dfrac{(1-x^2)^m (l-m)!}{(-1)^m (l+m)!} {dx^{l+m}}(x^2-1)^l \\
                   &= (-1)^m \dfrac{(l-m)!}{(l+m)!} \dfrac{(-1)^{l+m}}{2^l l!} (1-x^2)^{m/2} {dx^{l+m}}(x^2-1)^l \\
                   &= (-1)^m \dfrac{(l-m)!}{(l+m)!} P_l^m(x)
\end{align}


これによって mが正の場合と負の場合の対応が付きました.この式はまた後に使うこととして,ここでルジャンドル陪関数の直交性を確認しましょう.
記号を次のように省略することとします.


 
\begin{align}
P_l^m(x) &= \dfrac{(-1)^m}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l \\
               &= \dfrac{(-1)^m (-1)^{m/2}}{2^l l!} (x^2-1)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l \\ 
               &= \dfrac{(-1)^m (-1)^{m/2}}{2^l l!} R^{m/2} D^{m+l} R^l \\
\end{align}
 
ただし,D = \dfrac{d}{dx}, R = x^2-1


 xの定義域は[-1, 1]であるため,-1から1までの積分を行います.


 
\begin{align}
\int_{-1}^{1} P_p^m(x) P_q^m(x) dx &= \dfrac{(-1)^{2m} (-1)^m}{2^{p+q}p!q!} \int_{-1}^{1} R^m (D^{p+m}R^{p}) (D^{q+m} R^{q})dx \\
                                                          &= \dfrac{(-1)^m}{2^{p+q}p!q!} \int_{-1}^{1} R^m (D^{p+m}R^{p}) (D^{q+m} R^{q})dx \\
                                                          &=: \dfrac{(-1)^m}{2^{p+q}p!q!} I
\end{align}


 I積分の部分です.これを計算しましょう.ただし, p \leq qとします.s回だけ部分積分したときの式は


 
\begin{align}
&I = \sum_{i = 1}^{s} s_{pq}^{i} + (-1)^s \int_{-1}^{1} D^{s} (R^m D^{m+p} R^p) D^{m+q-s} R^q dx \\
&ただし,s_{pq}^i = \left.(-1)^{i-1} D^{i-1} (R^m D^{m+p} R^p) D^{m+q-i} R^q \right|_{-1}^{1}
\end{align}


この s_{pq}^iは実は消えます.[tex: (x2-1)]の項が生きているうちは1と-1を代入すると0になるので,


 
\left. D^t R^q \right|_{-1}^{1} = 0 \;\; (t \leq q-1)


となりますね.この考え方を利用して,[tex: D^{m+q-i}Rq]の部分に注目します. m+q-i \leq q-1の時,つまり i \geq m+1の時はこの微分の結果に[tex: (x2-1)]の項が生きているので,結局1と-1を代入するとこの計算結果は0となり,


 
s_{pq}^i = 0 \;\; (i \geq m+1)


となります.次に,ライプニッツの公式を用いて,[tex: D^{i-1} (Rm D^{m+p} Rp)]の部分に注目すると,


 
D^{i-1} (R^m D^{m+p} R^p) = \sum_{r = 0}^{i-1} {}_{i-1} \mathrm{C}_r D^{i-1-r} R^m \; D^{m+p+r}R^p


[tex: D^{i-1-r} Rmの部分に注目します. i-1-r \leq i-1 \leq m-1の時,つまり i \leq mの時は \sumに現れるすべての項において[tex: (x2-1)]の項が生存します.そのため,


 
s_{pq}^i = 0 \;\; (i \leq m)


となります.先ほど得られた計算結果と合わせると,確かに全てのiについて s_{pq}^i = 0が満たされていることが分かります.結局,



\begin{align} 
I =(-1)^s \int_{-1}^{1} D^{s} (R^m D^{m+p} R^p) D^{m+q-s} R^q dx
\end{align}


ここで p \leq qであるので, I積分 m+qまで部分積分できます. s = m+qとすると,ライプニッツの式を利用して


 
\begin{align}
I &= (-1)^{m+q} \int_{-1}^{1} D^{m+q} (R^m D^{m+p} R^p) D^{0} R^q dx \\
  &= (-1)^{m+q} \int_{-1}^{1} dx R^q \left( \sum_{r=0}^{m+q} {}_{m+q}\mathrm{C}_r \; D^{m+q-r} R^m \; D^{m+p+r} R^p   \right) \\
\end{align}


さて, R xに関する2次多項式でした.なので[tex: Rm, Rp]はそれぞれ次数が 2m, 2pです.つまり次数よりも多く微分するとこれらは0になります.つまり,


 
D^{m+q-r} R^mについては,m+q-r \leq 2mを満たすrにのみ値が現れうる.\\
つまり,r \geq q-m \\
D^{m+q-r} R^mについては,m+p+r \leq 2pを満たすrにのみ値が現れうる.\\
つまり,r \leq p-m


ここで,p < qの時を考えると,なんと先ほどの議論よりどの rにも値は現れません.つまり, I = 0となります.
 p = qのときは r = p - mを満たすrのみに値が現れうるので,


 
\begin{align}
I &= (-1)^{m+q} \int_{-1}^{1} dx R^q \left( {}_{m+q}\mathrm{C}_{p-m} \; D^{2m} R^m \; D^{2p} R^p   \right) \\
\end{align}


となります.さらに,


 
\begin{align}
D^{2m} R^m = (2m)! \\
D^{2p} R^p = (2p)!
\end{align}


であるので, p = qに注意して,


 
\begin{align}
I &= (-1)^{m+q} (2m)!(2p)! {}_{m+q}\mathrm{C}_{p-m} \int_{-1}^{1} R^q dx \\
  &= (-1)^{m+p} \dfrac{(m+p)! (2p)!}{(p-m)!} \int_{-1}^{1} R^p dx \\
\end{align}


そして,この積分は部分積分を頑張ることで計算できます.


 
\begin{align}
\int_{-1}^{1} R^p dx &= \int_{-1}^{1} (x^2-1)^p dx \\
                                 &= \int_{-1}^{1} (x+1)^p (x-1)^p dx \\
                                 &= (努力) \\ 
                                 &= (-1)^p \dfrac{(p!)^2 2^{2p}}{(2p)!} \dfrac{2}{2p+1}
\end{align}


であるので,


 
\begin{align}
I &=  (-1)^{m+p} \dfrac{(m+p)! (2p)!}{(p-m)!} (-1)^p \dfrac{(p!)^2 2^{2p}}{(2p)!} \dfrac{2}{2p+1} \\
  &=  (-1)^m (p!)^2 2^{2p} \dfrac{(p+m)!}{(p-m)!} \dfrac{2}{2p+1}
\end{align}


と求まります.以上の議論より,クロネッカーのデルタを用いて,


 
\begin{align}
\int_{-1}^{1} P_p^m(x) P_q^m(x) dx &= \dfrac{(-1)^m}{2^{p+q}p!q!} I \\
                                                          &= \delta_{pq} \dfrac{(-1)^m}{2^{p+q}p!q!} \times (-1)^m (p!)^2 2^{2p} \dfrac{(p+m)!}{(p-m)!} \dfrac{2}{2p+1} \\
                                                          &= \dfrac{2}{2p+1} \dfrac{(p+m)!}{(p-m)!} \delta_{pq}  
\end{align}


以上より,ルジャンドル陪関数が直交性を持つことが分かりました.いったんまとめましょう.元々は最初に示したラプラス方程式 \Theta(\theta), \Phi(\phi)を求めていました.これまでの議論より,


 
\begin{align}
& \Phi(\phi) = A e^{jm\phi} \\
& \Theta(\theta) = B P_l^m(x) = B P_l^m(\cos\theta)
\end{align}


です.ここで,正規化を与えることを考えます.つまり,


 
\begin{align}
&\int_{\Omega_\phi} |\Phi(\phi)|^2 d\phi = 1 \\
&\int_{\Omega_\theta} |\Theta(\theta)|^2 d\theta = 1 \\
\end{align}


を満たさせることとします.


 
\begin{align}
&\int_{\Omega_\phi} |\Phi(\phi)|^2 d\phi = \int_{0}^{2\pi} |A e^{jm\phi}|^2 d\phi = 2\pi A^2 = 1 \\
&\int_{\Omega_\theta} |\Theta(\theta)|^2 d\theta = \int_{-1}^{1} |B P_l^m(x)|^2 dx  = \dfrac{2}{2l+1} \dfrac{(l+m)!}{(l-m)!} B^2 = 1 \\
\end{align}


これらより A, Bが求まり, \Phiと\Thetaは次のように書けます.


 
\begin{align}
& \Phi(\phi) = \dfrac{1}{\sqrt{2\pi}} e^{jm\phi} \\
& \Theta(\theta) = \sqrt{ \dfrac{2l+1}{2} \dfrac{(l-m)!}{(l+m)!} } P_l^m(\cos\theta)
\end{align}


さて,これの積を改めて[tex: Y_lm(\theta, \phi)]と書いて,


 
\begin{align}
Y_l^m(\theta, \phi) &= \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l-m)!}{(l+m)!} } P_l^m(\cos\theta)e^{ jm\phi} \\
                               &= \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l-m)!}{(l+m)!} } P_l^m(x)e^{jm\phi} \\
                               &=  \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l-m)!}{(l+m)!} } \dfrac{(-1)^m}{2^l l!} (1-x^2)^{m/2} \dfrac{d^{m+l}}{dx^{m+l}}(x^2-1)^l e^{jm\phi}
\end{align}


先ほど示したルジャンドル陪関数の mの正負に関する関係


 
\begin{align}
P_l^{-m}(x) &= (-1)^m \dfrac{(l-m)!}{(l+m)!} P_l^m(x)
\end{align}


より,


 
\begin{align}
Y_l^{-m}(\theta, \phi) &= \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l+m)!}{(l-m)!} } (-1)^m \dfrac{(l-m)!}{(l+m)!} P_l^m(x) e^{-jm\phi} \\
                                   &= (-1)^m Y_l^{m*}(\theta, \phi)
\end{align}


であるので,負の mを考慮した式は


 
\begin{align}
Y_l^m(\theta, \phi) &= (-1)^{(|m|+m)/2} \sqrt{ \dfrac{2l+1}{4\pi} } \sqrt{ \dfrac{(l-|m|)!}{(l+|m|)!} } \dfrac{1}{2^l l!} (1-x^2)^{|m|/2} \dfrac{d^{|m|+l}}{dx^{|m|+l}}(x^2-1)^l e^{jm\phi}
\end{align}


となります.これを球面調和関数と言います.ルジャンドル陪関数[tex: P_lm(\cos\theta)]の lに関する直交性と先程の正規化の処理に加えて, \Phi(\phi)の直交性


 
\begin{align}
\int_{0}^{2\pi} \Phi_m(\phi) \Phi_n^{*} (\phi) d\phi &= \dfrac{1}{2\pi} \int_{0}^{2\pi} e^{jm\phi} e^{-jn\phi}d\phi \\
                                                                                 &= \dfrac{1}{2\pi} \int_{0}^{2\pi} e^{j(m-n)\phi} d\phi \\
                                                                                 &= \delta_{mn}
\end{align}


より,球面調和関数には正規直交性が成立します.また,球面調和関数には完全性があり,球面上の連続で滑らかな関数 f(\theta, \phi)が球面調和関数系の線形結合


 
\begin{align}
f(\theta, \phi) = \sum_{l=0}^{\infty} \sum_{m = -l}^{l} f_l^m Y_l^m(\theta, \phi)
\end{align}


として一意に表せます.つまり,球面上で定義される関数を展開できるということです.実数上の関数を級数展開するあれと同じですね.私の知識では完全性の証明をすることは出来ませんでした.無念(ワイエルシュトラスの近似定理なるものを用いて色々やってる証明を見つけましたが理解できませんでした......)

さて,現状の球面調和関数系を使用しても良いのですが,近似する対象が実関数である場合は実数の球面調和関数を基底として扱いたいです.これを \mathcal{Y}(\theta, \phi)として,


 
\begin{align}
case(m = 0) \; &: \; \mathcal{Y}(\theta, \phi) = Y_l^m(\theta, \phi) \\
case(m > 0) \; &: \; \mathcal{Y}(\theta, \phi) = \dfrac{(-1)^m Y_l^m(\theta, \phi) + Y_l^{-m}(\theta, \phi)} {\sqrt{2}} \\
case(m < 0) \; &: \; \mathcal{Y}(\theta, \phi) = \dfrac{(-1)^{|m|} Y_l^{|m|}(\theta, \phi) - Y_l^{-|m|}(\theta, \phi)} {\sqrt{2}j} \\
\end{align}


と定義してあげることで実数球面調和関数が得られます.

エンコーダーとしての球面調和関数

長くなりましたが,Spherical Harmonic Encodingでは視線の方向をエンコードします.ここで,視線の方向は長さが1の3次元ベクトル (x, y, z)です.これは半径が1の球面上の点と見ることが出来ます.つまり,視線の方向を入力とする関数は,単位球面上を定義域とする関数として見ることが出来ます.ここで,球面調和関数の完全性より,球面上で定義される関数が球面調和関数の(無限の)基底の線形結合で表せました.全結合層は(有限の)基底を線形結合する(ことにより関数の応答を近似する)ということを考えると,Spherical Harmonic Encodingは「視線の方向の基底を変換し,よりパラメーター次元を増やすものである」と考えられると私は解釈しています.ただし,この結論に関しては参考文献などがあるわけではないので違うかもしれません.

Spherical Harmonic Encodingの実装

さてさて,では実装に取り掛かりましょう.これまでの理論なしにも実装自体は簡単に出来ますので実装を説明します.

MFFM_DEVICE void Encode_SH_L4(__half* input, __half* Out) {
    const int bx = blockIdx.x;
    const int tx = threadIdx.x;
    const int ty = threadIdx.y;
    const int inidx = 3 * (32 * ty + tx);
    const int outidx = (16 + SKEW) * (32 * ty + tx);
    // 1ブロック128バッチを担当する
    if (32 * ty + tx >= ONEBATCH_SIZE) {
        return;
    }

    const __half x = input[inidx + 0];
    const __half y = input[inidx + 1];
    const __half z = input[inidx + 2];

    const __half xx = x * x;
    const __half yy = y * y;
    const __half zz = z * z;
    const __half xy = x * y;
    const __half xz = x * z;
    const __half yz = y * z;
    const __half xyz = x * y * z;

    const __half r2 = 1.4142135623730950488016887242097f;
    const __half r3 = 1.7320508075688772935274463415059f;
    const __half r5 = 2.2360679774997896964091736687313f;
    const __half r7 = 2.6457513110645905905016157536393f;
    const __half r15 = 3.8729833462074168851792653997824f;
    const __half r21 = 4.582575694955840006588047193728f;
    const __half r35 = 5.9160797830996160425673282915616f;
    const __half r105 = 10.246950765959598383221038680521f;
    const __half rpi = 1.7724538509055160272981674833411f;

    __syncthreads();

    // L = 0
    Out[outidx + 0] = (__half)1.0f / ((__half)2.0f * rpi);

    // L = 1
    Out[outidx + 1] = (r3 / ((__half)2.0f * rpi)) * y;
    Out[outidx + 2] = (r3 / ((__half)2.0f * rpi)) * z;
    Out[outidx + 3] = (r3 / ((__half)2.0f * rpi)) * x;

    // L = 2
    Out[outidx + 4] = (r15 / ((__half)2.0f * rpi)) * xy;
    Out[outidx + 5] = (r15 / ((__half)2.0f * rpi)) * yz;
    Out[outidx + 6] = (r5 / ((__half)4.0f * rpi)) * ((__half)3.0f * z * z - (__half)1.0f);
    Out[outidx + 7] = (r15 / ((__half)2.0f * rpi)) * xz;
    Out[outidx + 8] = (r15 / ((__half)4.0f * rpi)) * (xx - yy);

    // L = 3
    Out[outidx + 9] = (r2 * r35 / ((__half)8.0f * rpi)) * y * ((__half)3.0f * xx - yy);
    Out[outidx + 10] = (r105 / ((__half)2.0f * rpi)) * xyz;
    Out[outidx + 11] = (r2 * r21 / ((__half)8.0f * rpi)) * y * ((__half)-1.0f + (__half)5.0f * zz);
    Out[outidx + 12] = (r7 / ((__half)4.0f * rpi)) * z * ((__half)5.0f * z * z - (__half)3.0f);
    Out[outidx + 13] = (r2 * r21 / ((__half)8.0f * rpi)) * x * ((__half)-1.0f + (__half)5.0f * zz);
    Out[outidx + 14] = (r105 / ((__half)4.0f * rpi)) * (xx - yy) * z;
    Out[outidx + 15] = (r2 * r35 / ((__half)8.0f * rpi)) * x * (xx - (__half)3.0f * yy);

    for (int i = 16; i < 16 + SKEW; i++) {
        Out[outidx + i] = 0.0f;
    }
}

部分的にみていきましょう.

MFFM_DEVICE void Encode_SH_L4(__half* input, __half* Out) {
...

・input: 視線の方向が(x, y, z)の形で格納されています.
・Out: エンコード結果を格納するポインタです.

const int bx = blockIdx.x;
const int tx = threadIdx.x;
const int ty = threadIdx.y;
const int inidx = 3 * (32 * ty + tx);
const int outidx = (16 + SKEW) * (32 * ty + tx);

・inidx: 実行中のスレッドにて処理する入力ベクトルを指すポインタへアクセスするためのインデックスです.
・outidx: 実行中のスレッドにてエンコード結果を格納するポインタへアクセスするためのインデックスです.

// 1ブロック128バッチを担当する
if (32 * ty + tx >= ONEBATCH_SIZE) {
    return;
}

Multiresolution Hash Encodingと同じことをしています.

const __half x = input[inidx + 0];
const __half y = input[inidx + 1];
...
const __half rpi = 1.7724538509055160272981674833411f;

入力ベクトルをロードし,各計算に必要な定数を置いてます.

__syncthreads();

Multiresolution Hash Encodingと同じ役割です.

...
// L = 0
Out[outidx + 0] = (__half)1.0f / ((__half)2.0f * rpi);
...
Out[outidx + 15] = (r2 * r35 / ((__half)8.0f * rpi)) * x * (xx - (__half)3.0f * yy);
...

実数球面調和関数の基底を計算しています.今回は l = 3まで計算します.

for (int i = 16; i < 16 + SKEW; i++) {
    Out[outidx + i] = 0.0f;
}

SKEWの部分を0埋めしてます.

以上です.当然学習パラメーターはありません.基底の計算式は理論のところで示した実数球面調和関数に対して座標系を極座標から直交座標系へ変換してあげれば良いのですが,面倒なので球面調和関数表を参照しましょう.Wikipediaにもあります.

Multiresolution Hash Encodingによる2次元画像の近似(2次元の関数の近似)

まずはこちら側のエンコーダーによる効果を見ていきましょう.2次元画像は「座標(2次元)を与えると色(RGBとします)を返す関数」として見ることが出来ます.では,次の画像を近似しましょうかね.

画像サイズは256x256です.4年ぐらい前にBlenderで作った3Dモデルのレンダリング画像です.画像上の座標を[0, 1]に正規化して入力しました.なお,3DのMultiresolution Hash Encodingなのでz座標が求められますが,これを0.5と固定しました.さて,次の設定で学習しました.
Multiresolution Hash Encodingのパラメーター

 
L = 16  \\
T = 2^{20} \\
F = 2 \\
E = 0 \\
N_{min} = 2 \\
b = 2.2

MLPのパラメーター
隠れ層次元: 64
出力層次元: 3
隠れ層の数: 4
入力層と隠れ層の活性化関数: ReLU
出力層の活性化関数: Sigmoid
学習設定
誤差関数はHubor(閾値0.05)
最適化関数がAdamの場合は学習率0.01
最適化関数がGradient Descendantの場合は学習率2.0

結果を見ていきましょう.まずは誤差の変化は次の図のようになりました.誤差は平均二乗誤差です.

各最適化関数による学習中の出力は次の画像に示す通りです.各画像内の左上にある数値はその画像を出力したときのイテレーションです.
Gradient Descendant

Adam

両者の学習する過程はかなり異なっていて面白いですね.それよりも,Part1では簡単な1次元関数でも近似させるのが困難であったのに,Multiresolution Hash Encodingを通すことで非常に近似精度が向上しましたね.では,エンコーダー無しの場合を見ていきましょう.画像上の座標を[-0,0001, 0.0001]に正規化してNNに入力しました.

MLPのパラメーター
隠れ層次元: 64
出力層次元: 3
隠れ層の数: 12
入力層と隠れ層の活性化関数: ReLU
出力層の活性化関数: Sigmoid
学習設定
誤差関数はHubor(閾値0.05)
最適化関数がAdamの場合は学習率0.01
最適化関数がGradient Descendantの場合は学習率0.5

Gradient Descendant Adam

近づいてはいますが,限界にぶつかっているように見えます.

今回は256x256の画像で確認しましたが,これをもっとピクセル数が多い画像に対しても行うことが可能です.こうしてみると,非常に強力な手法なのですが,苦しい点も勿論あります.一つは言うまでもありませんが,メモリ消費が激しいことです.そしてもう一つですが,かなり処理が遅いです.主な理由はメモリアクセスに起因しております.ハッシュ関数を利用してアクセスしているため,まずキャッシュのヒット率が辛いことになってます.そして,グローバルメモリとのやり取りが非常に多くなります.特にテーブル上のパラメーターが多いので最適化処理でかなり遅くなります.ざっくり計測した感じでは今回の画像近似においては学習処理では実行時間の90%ぐらいが最適化処理に持って行かれてます.推論処理だと処理時間の80%程度がMultiresolution Hash Encodingに持って行かれています.ただ,実装に関しては改善できる点も多いのでこの値はあまり参考にしなくていいと思います.ただしやはり遅いには遅いです.

Sphrerical Harmonic Encodingによる球面上の関数の近似

球面上の関数を近似しましょう. (x, y, z)を単位球面上の座標として,

 
\begin{align}
&\mathbf{V} = (\sin(2\pi x), \cos(5\pi y), \sin(2.5\pi z)\cos(3\pi z)) \\
&f(x, y, z) := 0.5(x, y, z) \cdot \mathbf{V} + 0.5
\end{align}


……特に意味はないです.さて,次の設定で近似しました.
MLPのパラメーター
隠れ層次元: 64
出力層次元: 1
隠れ層の数: 1
入力層と隠れ層の活性化関数: ReLU
出力層の活性化関数: Sigmoid
学習設定
誤差関数はHubor(閾値0.05)
最適化関数: Adam
学習率: 0.025

エンコーダーありとエンコーダーなし,両方について上記の設定で実行しました.誤差の発展は次の通りです.

MLP側の層がかなり小さいため,エンコーダー無しでは近似が厳しそうであることが見えます.次に,エンコーダーを使用した場合の近似の進む様子を見ます.次の図に出てくるプロット点は

 
\begin{align}
(x,y,z)(1 + 0.4 \hat{f}(x,y,z))
\end{align}

を3D空間上にプロットしており,黄色い点群が推定,黒い点群が真値です.

いい感じですね.ちなみに点群はBlenderで表示しており,プログラムからPythonスクリプトで頂点を追加する関数をテキストで出力してBlenderPythonスクリプトに貼り付けて実行することにより点群を作っています.

エンコーダー編 さいごに

2編ではInstant NeRFにおいて使用されているエンコーダーについて説明しました.これによりニューラルネットワークの基礎部分は完成しました.3編では遂にNeRFの実装に入ります.今回行った画像近似は2次元の近似です.しかし私たちが目にしているのは3次元の空間であり,今度はこれを近似する必要があります.しかし画像の近似とは異なり,3次元空間の近似はその空間の可視化(レンダリング)が容易ではなく,それゆえに色んな概念を使用してその近似が試みられています.その一つがNeRFです.詳しい話はPart3でやりましょう.

参考資料

Convolutional Sequence to Sequence Learning
Attention Is All You Need
NeRF: Representing Scenes as Neural Radiance Fields for View Synthesis
Instant Neural Graphics Primitives with a Multiresolution Hash Encoding
Optim Tech Blog. Instant NeRF の心臓、Multiresolution Hash Encoding をシンプルに実装しつつ2次元画像で試してみる
Mathematical Methods for Physicists A Comprehensive Guide Seventh Edition 2012
宇宙物理メモ ルジャンドル陪関数
球面調和関数表 Wikipedia
Spherical Harmonic Lighting: The Gritty Details
Precomputed Radiance Transfer for Real-Time Rendering in Dynamic, Low-Frequency Lighting Environments