suztomoの日記

To be a good software engineer

RustでConvolutional Neural Network

趣味プログラミングとしてRustとNeural Networkを勉強したかったので、ゼロから作るDeep Learningを読んでRustでChapter 7のConvolutional Neural Networkを実装した。

www.oreilly.co.jp

実装したコード: Language Study/Rust - conv_nn_mnist

Rust

プログラミング言語Rustがこれからより多くの場所に使われるだろう。コンパイラが間違った(間違いそうな)コードを弾いてくれる点と、クロージャや台数データ型といったプログラミングの抽象化がスピードを犠牲にせずに利用できるという点はとても魅力的である。

一方でRustは難しい。書いたコードはだいたいコンパイラエラーに当たる。Structや関数のsignatureに対してRustのコンパイラがエラーを出した場合、メッセージをしっかり理解してその部分を直すことでよりデータの所有権や参照のlifetimeが明瞭な安全なプログラムを書くことができる。

プログラミング言語研究の成果を享受できるのでRustは好きだ。

開発環境

VSCode。良くできてる。lldbのプラグインGUIを使ったデバッグもできる。

エラーメッセージ

Rustのエラーメッセージは時々難しい。例えば、weightsという行列の各要素からself.mの各要素を少し変形したものを引くという

        *weights -= (&self.m * lr_t / &self.v.mapv(|v| v.sqrt() + 1e-7));

このコードに対して、下のエラー。

error[E0271]: type mismatch resolving `<ndarray::OwnedRepr<f64> as ndarray::Data>::Elem == ndarray::ArrayBase<ndarray::OwnedRepr<f64>, D>`
  --> src/network.rs:66:18                                                                                                               
   |                                                                                                                                     
66 |         *weights -= (&self.m * lr_t / &self.v.mapv(|v| v.sqrt() + 1e-7));                                                           
   |                  ^^ expected f64, found struct `ndarray::ArrayBase`                                                                 
   |                                                                                                                                     
   = note: expected type `f64`                                                                                                           
              found type `ndarray::ArrayBase<ndarray::OwnedRepr<f64>, D>`                                                                
   = note: required because of the requirements on the impl of `std::ops::SubAssign` for `ndarray::ArrayBase<ndarray::OwnedRepr<f64>, D>`
                                                                                                                                         
error[E0277]: the trait bound `ndarray::ArrayBase<ndarray::OwnedRepr<f64>, D>: ndarray::ScalarOperand` is not satisfied                  
  --> src/network.rs:66:18                                                                                                               
   |                                                                                                                                     
66 |         *weights -= (&self.m * lr_t / &self.v.mapv(|v| v.sqrt() + 1e-7));                                                           
   |                  ^^ the trait `ndarray::ScalarOperand` is not implemented for `ndarray::ArrayBase<ndarray::OwnedRepr<f64>, D>`

そのまま読めばArrayBase<ndarray::OwnedRepr, D>という型がScalarOperandというtraitに合っていないということ。だからといって加減剰余の演算をArrayBaseに対して行えないばけではない。実はこのエラーは"-="はbinary operationの右辺は行列ではなく行列への参照じゃないといけない。ndarrayのドキュメントにそう書いてある

docs.rs

これを読めば、&を右辺の最初につければいいんだなとわかるのだが、RustコンパイラはこのドキュメントへのURLを表示することはできない。

クロージャのTrait

FnやFnMutのエラーメッセージも難しい。Rustのかわいい蟹の本のMulti-threaded web serverの章にもクロージャにまつわる難解なエラーメッセージの話があるように、Rustコンパイラは完璧ではない。Rustがより良くなっていってこのあたりが簡単になりますように。

ndarray

行列ライブラリにはRustのndarray - Rustを使った。

why are there so many linear algebra crates? Which one is "best"?の中でndarrayの評判がよかったので。

dotと*

Array.dot行列の掛け算で、*がelement-wise operation。行列の掛け算は1次元の行列と2次元の行列とで振る舞いが異なる。

into_shape()

ハマったところはinto_shape()。まずinto_shapeが使えない場合があるので自分でreshape関数を作ったが、これは常に行列をコピーすることになって失敗しない代わりにメモリが無駄になっているはず。

次にpermuted_axis()into_shape()は混ぜて使えないところ。このlimitationは最後の最後まで気づかなかった。

Resolve into_shape() limitations · Issue #390 · rust-ndarray/ndarray · GitHub

slice_mut

これは便利

    for y in 0..filter_height {
        let y_max = min(y + stride * out_h, input_padded_height);
        for x in 0..filter_width {
            let x_max = min(x + stride * out_w, input_padded_width);
            ...
            let img_slice = img.slice(s![.., .., y..y_max;stride, x..x_max;stride]);
            let mut col_slice_mut = col.slice_mut(s![.., .., y, x, .., ..]);
            col_slice_mut.assign(&img_slice);
        }
Layerの型

最初はニューラルネットワークの層を表すtraitを作ろうとして挫折した。これは層を1つ1つ実装していく中で層の中には4次元の入力を受けるものと2次元の入力を受けるものがあるからだった。今考えると入力の次元数と出力の次元数でgenericsを使ったtraitを作れる気がする。Layer<Input, Output> where Input: Dimension, Output: Dimension みたいな。

ところで、Numpyでは下のコードの様に異なる入力の型をとるlayerをまとめてリストに入れたり、その入力と出力をforループで回したりできる。

    def predict(self, x):
        for layer in self.layers.values():
            x = layer.forward(x)
        return x

しかしRustは型に厳密なので、これができない。一応ndarrayにはdynamic dimensional arrayはあるが型安全性は失われてしまう。これを解決してくれる良いデータ型はないものだろうか。例えばこんなの

Network [ Layer<Ix4, Ix4>, Layer<Ix4, Ix2>, Layer<Ix2, Ix2> ]   # => compiles
Network [ Layer<Ix4, Ix4>, Layer<Ix4, Ix4>, Layer<Ix2, Ix2> ]   # => error becauase the second output and the third input are different
絵が書けない

PythonとNumpyの組み合わせに比べるとこのRustでの実装ではグラフ描画やipythonのようなインタラクティブな操作ができない。入力データの画像を目でも見たかったのでansi_termで色をつけて表示してみたりした。

MNIST number 7 shown in terminal
MNIST 7 in ansi_term

Convolutional Neural Network

ニューラルネットワークを計算グラフとして考えることで、損失関数の重さに関する偏微分の計算が局所的になるところがミソ。なのでネットワークにいくつもの層があっても、各層ごとに重さと入力の各要素の偏微分が(局所的に)合っていれば全体の計算が合う。ネットワークが期待通りに動かなかった時にどこの計算の実装が間違っているのか調べるのは骨が折れる作業だが、Gradient checkを使うと実装が怪しい層を見つけることができる。自分の場合はどうしてもバグの箇所がわからない時があったので、結局上の本のgithubにあるNumpyの実装に同じ値を与えてどこの変換でバグがあるのか調べた。(結局上のpermuted_axis+into_shapeが原因だった)

本のChapter 7のSimple Convnetの構成はconvolution layer -> relu layer -> pooling layer -> affine layer -> relu layer -> affine layer -> relu layer 。損失関数はsoftmaxとクロスエントロピー誤差。入力はMNISTの28x28のグレイスケール画像データ。層を何枚も重ねる「ディープラーニング」はしなかった。

Convolution Layer

は入力画像の一部に適用される重み(filter)を何層も持つことでAffine layerだけのネットワークよりも画像の一部の形に反応するネットワークを作ることができるらしい。

Adam

Adamはネットワーク内部の重みを調節するアルゴリズムの一つ。これを導入する前は重みから偏微分されたものをlearning rate 0.01掛けて引いており学習の結果は悪かった(テストデータに対して40%とかの確率)。Adamを導入したら同じ回数の学習で97%を達成。

RustとConvolutional Neural Networkを学べたので目標達成。

やらなかったこと