前回に引き続き書籍「ゼロから作るDeepLearning」をベースに、前回NArray配列として扱うようにしたMNISTデータに対して推論処理を行うニューラルネットワークを実装してみます。
ニューロンの構成は下記の通りです。
- 入力層: 784 # 28 x 28 の画像データのピクセル数
- 隠れ層1: 50
- 隠れ層2: 100
- 出力層: 10 # 結果として10クラス(0から9の数字)に分類する
サンプルコード全体
まずはサンプルコードの全体を掲載しておきます。
require 'numo/narray' require 'json' require './sigmoid.rb' require './softmax.rb' require './mnist.rb' def get_data x_train, t_train, x_test, t_test = load_mnist(true, true, false) [x_test, t_test] end def init_network nw = JSON.load(File.read('sample_weight.json')) network = {} nw.each do |k, v| network[k.to_sym] = Numo::DFloat[*v] end network end def predict(network, x) w1 = network[:w1] w2 = network[:w2] w3 = network[:w3] b1 = network[:b1] b2 = network[:b2] b3 = network[:b3] a1 = x.dot(w1) + b1 z1 = sigmoid(a1) a2 = z1.dot(w2) + b2 z2 = sigmoid(a2) a3 = z2.dot(w3) + b3 softmax(a3) end x, t = get_data network = init_network batch_size = 100 accuracy_cnt = 0 x.to_a.each_slice(batch_size).with_index do |x_batch, idx| y_batch = predict(network, Numo::DFloat[*x_batch]) p = y_batch.max_index(1) % 10 accuracy_cnt += p.eq(t[(idx * batch_size)..(idx * batch_size + (batch_size - 1))]).cast_to(Numo::UInt8).sum end puts "Accuracy: #{accuracy_cnt.to_f / x.shape[0]}"
MNISTデータの取得
前回実装したMNISTデータのロード処理(mnist.rb)を使ってMNISTデータのテストデータを取得します。
def get_data x_train, t_train, x_test, t_test = load_mnist(true, true, false) [x_test, t_test] end
重みパラメータのロード
学習済みの重みとバイアスのパラメータは、書籍のサンプルコードで pickle ファイルとして提供されているものをあらかじめJSONに変換してファイルに保存しておき(sample_weight.json)、それを読み込んでいます。
def init_network nw = JSON.load(File.read('sample_weight.json')) network = {} nw.each do |k, v| network[k.to_sym] = Numo::DFloat[*v] end network end
推論処理
上記でロードしたテストデータと重みパラメータに対して、隠れ層での活性化関数にはシグモイド関数、出力層での活性化関数にはソフトマックス関数を使って推論処理を行います。シグモイド関数は sigmoid.rb、ソフトマックス関数は softmax.rb として保存して読み込んでおきます。
def sigmoid(x) 1 / (1 + Numo::DFloat::Math.exp(-x)) end
def softmax(a) c = a.max exp_a = Numo::DFloat::Math.exp(a - c) sum_exp_a = exp_a.sum exp_a / sum_exp_a end
def predict(network, x) w1 = network[:w1] w2 = network[:w2] w3 = network[:w3] b1 = network[:b1] b2 = network[:b2] b3 = network[:b3] a1 = x.dot(w1) + b1 z1 = sigmoid(a1) a2 = z1.dot(w2) + b2 z2 = sigmoid(a2) a3 = z2.dot(w3) + b3 softmax(a3) end
処理の実行
上記メソッドを使って処理を実行します。
x, t = get_data network = init_network batch_size = 100 accuracy_cnt = 0 x.to_a.each_slice(batch_size).with_index do |x_batch, idx| y_batch = predict(network, Numo::DFloat[*x_batch]) p = y_batch.max_index(1) % 10 accuracy_cnt += p.eq(t[(idx * batch_size)..(idx * batch_size + (batch_size - 1))]).cast_to(Numo::UInt8).sum end puts "Accuracy: #{accuracy_cnt.to_f / x.shape[0]}"
get_data で取得した画像データを100件ずつバッチ処理します。NArrayの二次元配列データはそのままループすると一次元配列として並べて各要素が参照されてしまうので、to_a で通常の配列データに変換した上で、 each_slice で100件ずつのまとまりにし、with_index でインデックスを取得します。
x.to_a.each_slice(batch_size).with_index do |x_batch, idx| ... end
ループの中では100件分の画像データを再度NArray配列に変換してpredictメソッドに渡し、推論処理を実行した結果を y_batch として受け取っています。predict では100件分の各画像データについて、0-9の各数字に対しての確度が確率として返されるので、100 x 10 の二次元配列になります。
y_batch = predict(network, Numo::DFloat[*x_batch])
その y_batch に対して max_index メソッドで二次元目の配列データについて最も数値が大きい要素のインデックスを取得しています。ただし取得されるインデックスは 100 x 10 を一次元配列として並べた場合のインデックスになるので、10で割った余りを取得することで、0-9の分類に変換しています。
Class: Numo::Int32 — Documentation by YARD 0.9.8
p = y_batch.max_index(1) % 10
処理している画像データに対応するラベルデータを取得します。each_sliceとwith_indexを使った場合、インデックスとしてはeach_sliceのまとまりごとに0から順番に振られるので、インデックスにバッチサイズ(100)をかけることで配列データの始点を特定し、それにさらにバッチサイズを加算したものから1を引くことで、終点を特定して、ラベルデータ配列 t から該当するラベルデータを取得します。
t[(idx * batch_size)..(idx * batch_size + (batch_size - 1))]
それを分類結果データ p と eq メソッドで比較します。
Class: Numo::Int32 — Documentation by YARD 0.9.8
eqメソッドでは配列の各要素を比較し、一致する場合は1を、一致しない場合は0を配列として返します。
irb(main):043:0* a => Numo::Int32#shape=[5] [1, 3, 5, 7, 9] irb(main):044:0> b => Numo::Int32#shape=[5] [1, 2, 5, 7, 8] irb(main):045:0> a.eq(b) => Numo::Bit#shape=[5] [1, 0, 1, 1, 0]
返される配列のデータ型はNumo::BitなのでNumo::UInt8に変換し、合計値を取得することで推論が正しかった要素の数が取得できます。
accuracy_cnt += p.eq(t[(idx * batch_size)..(idx * batch_size + (batch_size - 1))]).cast_to(Numo::UInt8).sum
そして最後に正解数を要素数で割ることで、正解率を計算しています。
puts "Accuracy: #{accuracy_cnt.to_f / x.shape[0]}"
今回実装したコードはこちらにも公開しました。