MNISTデータセットをNArray配列として扱う

 RubyからMNISTのデータセットを扱えるように、書籍「ゼロから作るDeepLearning」のコードをベースにMNISTのデータセットをNArray配列としてロードするコードを書いてみました。

www.oreilly.co.jp

コード全体

 まずはコード全体を掲載します。

require 'open-uri'
require 'zlib'
require 'fileutils'
require 'numo/narray'

URL_BASE = 'http://yann.lecun.com/exdb/mnist/'
KEY_FILES = {
  train_img:   'train-images-idx3-ubyte.gz',
  train_label: 'train-labels-idx1-ubyte.gz',
  test_img:    't10k-images-idx3-ubyte.gz',
  test_label:   't10k-labels-idx1-ubyte.gz'
}

DATASET_DIR = "#{File.dirname(__FILE__)}/dataset"
SAVE_FILE = "#{DATASET_DIR}/mnist.dump"

IMG_SIZE = 784

def download(file_name)
  puts "Downloading #{file_name} ..."
  open(URL_BASE + file_name) do |file|
    open("#{DATASET_DIR}/#{file_name}", 'w+b') do |out|
      out.write(file.read)
    end
  end
  puts "Done."
end

def download_mnist
  FileUtils.mkdir_p(DATASET_DIR)
  KEY_FILES.each do |k, file|
    download(file)
  end
end

def load_img(file_name)
  puts "Converting #{file_name} to NArray ..."
  data = nil
  Zlib::GzipReader.open("#{DATASET_DIR}/#{file_name}") do |gz|
    data = gz.each_byte.to_a[16..-1].each_slice(IMG_SIZE).to_a
    data = Numo::UInt8[*data]
  end
  puts "Done"
  data
end

def load_label(file_name)
  puts "Converting #{file_name} to NArray ..."
  data = nil
  Zlib::GzipReader.open("#{DATASET_DIR}/#{file_name}") do |gz|
    data = Numo::UInt8[*gz.each_byte.to_a[8..-1]]
  end
  puts "Done"
  data
end

def convert_narray
  dataset = {}
  dataset[:train_img] = load_img(KEY_FILES[:train_img])
  dataset[:train_label] = load_label(KEY_FILES[:train_label])
  dataset[:test_img] = load_img(KEY_FILES[:test_img])
  dataset[:test_label] = load_label(KEY_FILES[:test_label])
  dataset
end

def init_mnist
  download_mnist
  dataset = convert_narray
  puts "Creating dump file ..."
  File.write(SAVE_FILE, Marshal.dump(dataset))
  puts "Done!"
end

def change_one_hot_label(x)
  one_hot_arrays = x.to_a.map do |v|
    one_hot_array = Array.new(10, 0)
    one_hot_array[v] = 1
    one_hot_array
  end
  Numo::UInt8[*one_hot_arrays]
end

def load_mnist(normalize = true, flatten = true, one_hot_label = false)
  unless File.exist?(SAVE_FILE)
    init_mnist
  end

  dataset = Marshal.load(File.read(SAVE_FILE))

  if normalize
    %i(train_img test_img).each do |key|
      dataset[key] = dataset[key].cast_to(Numo::DFloat)
      dataset[key] /= 255.0
    end
  end

  if one_hot_label
    dataset[:train_label] = change_one_hot_label(dataset[:train_label])
    dataset[:test_label] = change_one_hot_label(dataset[:test_label])
  end

  unless flatten
    %i(train_img test_img).each do |key|
      dataset[key] = dataset[key].reshape(dataset[key].shape[0], 28, 28)
    end
  end

  [dataset[:train_img], dataset[:train_label], dataset[:test_img], dataset[:test_label]]
end

データのダウンロード

 初回実行時はMNISTデータがホスティングされている Yann LeCun’s MNIST からデータをダウンロードします。対象ファイルは下記のように定義しています。

URL_BASE = 'http://yann.lecun.com/exdb/mnist/'
KEY_FILES = {
  train_img:   'train-images-idx3-ubyte.gz',
  train_label: 'train-labels-idx1-ubyte.gz',
  test_img:    't10k-images-idx3-ubyte.gz',
  test_label:   't10k-labels-idx1-ubyte.gz'
}

 これをデータセット用のディレクトリにダウンロードします。

def download(file_name)
  puts "Downloading #{file_name} ..."
  open(URL_BASE + file_name) do |file|
    open("#{DATASET_DIR}/#{file_name}", 'w+b') do |out|
      out.write(file.read)
    end
  end
  puts "Done."
end

def download_mnist
  FileUtils.mkdir_p(DATASET_DIR)
  KEY_FILES.each do |k, file|
    download(file)
  end
end

画像データのロード

 ダウンロードしたMNISTデータセットの画像データをロードします。データ形式は一般的な画像フォーマットではなく、Yann LeCun’s MNISTに下記のように説明があります。

TRAINING SET IMAGE FILE (train-images-idx3-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  60000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel
Pixels are organized row-wise. Pixel values are 0 to 255. 0 means background (white), 255 means foreground (black).

TEST SET IMAGE FILE (t10k-images-idx3-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000803(2051) magic number 
0004     32 bit integer  10000            number of images 
0008     32 bit integer  28               number of rows 
0012     32 bit integer  28               number of columns 
0016     unsigned byte   ??               pixel 
0017     unsigned byte   ??               pixel 
........ 
xxxx     unsigned byte   ??               pixel

 また、下記サイトも参考にさせていただきました。

TensorFlow : MNIST データ・ダウンロード (コード解説) – TensorFlow

 上記内容と書籍のコードを元に、下記のように実装しました。

def load_img(file_name)
  puts "Converting #{file_name} to NArray ..."
  data = nil
  Zlib::GzipReader.open("#{DATASET_DIR}/#{file_name}") do |gz|
    data = gz.each_byte.to_a[16..-1].each_slice(IMG_SIZE).to_a
    data = Numo::UInt8[*data]
  end
  puts "Done"
  data
end

 ダウンロードしたgzファイルをopenし、byte単位で読み込みます。先頭16byteは画像以外の情報なので、16byte以降から末尾までを読み込み、画像サイズ(28 x 28 = 784)ごとの配列で分割して二次元配列にして、それをNArrayの配列データに変換します。

ラベルデータのロード

 ラベルデータのフォーマットは下記の通りです。

TRAINING SET LABEL FILE (train-labels-idx1-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  60000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

TEST SET LABEL FILE (t10k-labels-idx1-ubyte):

[offset] [type]          [value]          [description] 
0000     32 bit integer  0x00000801(2049) magic number (MSB first) 
0004     32 bit integer  10000            number of items 
0008     unsigned byte   ??               label 
0009     unsigned byte   ??               label 
........ 
xxxx     unsigned byte   ??               label
The labels values are 0 to 9.

 実装は下記の通りです。基本的には画像データの場合と同じですが、画像データと違って正解ラベルの一次配列データなので、画像サイズで分割するような処理はありません。

def load_label(file_name)
  puts "Converting #{file_name} to NArray ..."
  data = nil
  Zlib::GzipReader.open("#{DATASET_DIR}/#{file_name}") do |gz|
    data = Numo::UInt8[*gz.each_byte.to_a[8..-1]]
  end
  puts "Done"
  data
end

MNISTデータのセーブとロード

 毎回MNISTデータをサイトからダウンロードしたらgzファイルからロードするのは時間がかかるので、初回実行時にデータをローカルに保持するようにします。今回はひとまず汎用性はあまり考慮せず、このスクリプトからのみ扱う前提で、 Marshal.dump したものをファイルに保存します。

File.write(SAVE_FILE, Marshal.dump(dataset))

 次回実行時はdumpファイルが存在すればダウンロードや変換処理はスキップして、dumpファイルをロードします。

dataset = Marshal.load(File.read(SAVE_FILE))

オプション

 基本的な処理はここまでですが、書籍の内容に沿って、3つのオプションを用意します。

 まず normalize オプションは画像データの各ピクセルデータを正規化するかどうかの指定で、元のデータは 0 - 255 の範囲の数値ですが、true を指定するとこれを 0.0 - 1.0 の値に正規化します。

if normalize
  %i(train_img test_img).each do |key|
    dataset[key] = dataset[key].cast_to(Numo::DFloat)
    dataset[key] /= 255.0
  end
end

 次に one_hot_label オプションは、正解ラベルデータを one_hot 表現の配列データとして返すかどうかの指定です。元データは画像データがどの数字なのかを表す一次元配列で、例えば画像データが 3, 5, 2 であれば、[3, 5, 2] という形ですが、 one_hot 表現だと、それぞれについて要素数10の配列を用意し、正解のインデックスのみ1で、それ以外は0の二次元配列になります。

[
 [0, 0, 0, 1, 0, 0, 0, 0, 0, 0], # 3が正解なのでインデックス3の要素だけ1
 [0, 0, 0, 0, 0, 1, 0, 0, 0, 0], # 5が正解なのでインデックス5の要素だけ1
 [0, 0, 1, 0, 0, 0, 0, 0, 0, 0]  # 2が正解なのでインデックス2の要素だけ1
]
def change_one_hot_label(x)
  one_hot_arrays = x.to_a.map do |v|
    one_hot_array = Array.new(10, 0)
    one_hot_array[v] = 1
    one_hot_array
  end
  Numo::UInt8[*one_hot_arrays]
end
if one_hot_label
  dataset[:train_label] = change_one_hot_label(dataset[:train_label])
  dataset[:test_label] = change_one_hot_label(dataset[:test_label])
end

 最後に flatten オプションは各画像のデータを 28 x 28 の二次元配列として返すか、784要素の一次元配列として返すかの指定になります。元データは784要素の一次元配列なので、flatten オプションに false が指定されれば 28 x 28 の二次元配列に変換します。

unless flatten
  %i(train_img test_img).each do |key|
    dataset[key] = dataset[key].reshape(dataset[key].shape[0], 28, 28)
  end
end

ロードしたMNISTデータを参照するサンプル

 上記コードによって読み込んだMNISTデータを参照してみます。上記のロード用コードは mnist.rb として保存しています。

require 'rmagick'
require './mnist.rb'

x_train, t_train, x_test, t_test = load_mnist

img = x_train[0, true]
label = t_train[0]
puts img.shape
puts img.max
puts img.min
puts label

image = Magick::Image.new(28, 28)
image.import_pixels(0, 0, 28, 28, 'I', img, Magick::FloatPixel)
image.display

 画像データとラベルデータからそれぞれ一つ目を取り出し、ラベルはターミナルに表示しています。画像データは ImageMagick(RMagick)で 28 x 28 の画像オブジェクトを作成した上で、 import_pixelsメソッドでインポートして、displayメソッドで別ウィンドウに表示します。先頭データは5なので、5の画像が表示されます。

 今回実装したコードは下記にも公開しています。

github.com