今回も引き続き書籍のサンプルコードを ruby で実装します。今回は平方根平均二乗誤差を計算するメソッドです。サンプルコードのうちの下記部分になります。渡された多項式で平方根平均二乗誤差を計算しています。
# 平方根平均二乗誤差(Root mean square error)を計算 def rms_error(dataset, f): err = 0.0 for index, line in dataset.iterrows(): x, y = line.x, line.y err += 0.5 * (y - f(x))**2 return np.sqrt(2 * err / len(dataset))
引用元のサンプルスクリプトの全体は下記で公開されています。
行列データのイテレーション
python では pandas.DataFrame.iterrows() でデータセットをイテレーションして処理しています。
http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.iterrows.html
err = 0.0 for index, line in dataset.iterrows(): x, y = line.x, line.y err += 0.5 * (y - f(x))**2
ruby では Daru::DataFrame#each_row_with_index が該当します。
http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Aeach_row_with_index
今回は index は使用されていないので、 Daru::DataFrame.each_row を使います。
http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Aeach_row
err = 0.0 dataset.each_row do |line| err += 0.5 * (line.y - f.call(line.x))**2 end
行列データの行数取得
python では len() でデータセットの行数を取得できます。
>>> dataset x y 0 0.000000 -0.054637 1 0.111111 0.368772 2 0.222222 0.928976 3 0.333333 0.363668 4 0.444444 0.024165 5 0.555556 -0.521905 6 0.666667 -0.830954 7 0.777778 -0.464237 8 0.888889 0.058286 9 1.000000 0.183302 >>> >>> len(dataset) 10
ruby では Daru::DataFrame#size で行数を取得できます。
http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Asize
irb(main):018:0* dataset => #<Daru::DataFrame(10x2)> x y 0 0.0 -0.5455317 1 0.11111111 0.86331945 2 0.22222222 1.00393856 3 0.33333333 0.82127838 4 0.44444444 0.49240164 5 0.55555555 -0.0866971 6 0.66666666 -0.2723084 7 0.77777777 -0.9982033 8 0.88888888 -0.3344831 9 1.0 -0.2558949 irb(main):019:0> irb(main):020:0* dataset.size => 10
普通の配列であれば #size でも #count でも基本的には同じ結果になりますが、 Daru::DataFrame のオブジェクトで #count を使うとそれぞれの列のnullではない値の行数を返してくれます。
http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FMaths%2FStatistics%2FDataFrame%3Acount
irb(main):027:0* dataset.count => #<Daru::Vector(2)> count x 10 y 10
平方根
python では numpy.sqrt で平方根を求めています。
https://docs.scipy.org/doc/numpy/reference/generated/numpy.sqrt.html
>>> np.sqrt(2) 1.4142135623730951 >>> np.sqrt(4) 2.0
ruby では Math.sqrt で平方根を求めます。
https://docs.ruby-lang.org/ja/latest/method/Math/m/sqrt.html
irb(main):029:0* Math.sqrt(2) => 1.4142135623730951 irb(main):030:0> Math.sqrt(4) => 2.0
rms_error メソッドを ruby で
ここまでの内容を踏まえて rms_error メソッドを ruby で実装します。
def rms_error(dataset, f) err = 0.0 dataset.each_row do |line| err += 0.5 * (line.y - f.call(line.x))**2 end Math.sqrt(2 * err / dataset.size) end
python版とruby版それぞれの実行結果は下記のようになりました。
>>> dataset x y 0 0.000000 -0.380472 1 0.111111 0.719204 2 0.222222 0.909939 3 0.333333 1.022670 4 0.444444 0.254537 5 0.555556 -0.127610 6 0.666667 -0.557395 7 0.777778 -0.652528 8 0.888889 -0.436643 9 1.000000 0.450449 >>> >>> f, ws = resolve(train_set, 3) >>> >>> rms_error(train_set, f) 0.23852710875750457
irb(main):012:0* train_set => #<Daru::DataFrame(10x2)> x y 0 0.0 -0.1806039 1 0.11111111 1.22769157 2 0.22222222 1.17466241 3 0.33333333 0.67013919 4 0.44444444 0.61218251 5 0.55555555 -0.3105334 6 0.66666666 -0.7770481 7 0.77777777 -0.8269982 8 0.88888888 -0.3653022 9 1.0 -0.0443871 irb(main):013:0> irb(main):014:0* f, ws = resolve(train_set, 3) => [#<Proc:0x007fb60ddd3798@/vagrant/02-square_error.rb:45 (lambda)>, Vector[0.012228114152125085, 11.30192331783317, -32.97851812391701, 21.81003531172303]] irb(main):015:0> irb(main):016:0* rms_error(train_set, f) => 0.20727664859912284