平方根平均二乗誤差の計算部分のコード

 今回も引き続き書籍のサンプルコードを ruby で実装します。今回は平方根平均二乗誤差を計算するメソッドです。サンプルコードのうちの下記部分になります。渡された多項式で平方根平均二乗誤差を計算しています。

# 平方根平均二乗誤差(Root mean square error)を計算
def rms_error(dataset, f):
    err = 0.0
    for index, line in dataset.iterrows():
        x, y = line.x, line.y
        err += 0.5 * (y - f(x))**2
    return np.sqrt(2 * err / len(dataset))

 引用元のサンプルスクリプトの全体は下記で公開されています。

github.com

行列データのイテレーション

 python では pandas.DataFrame.iterrows() でデータセットをイテレーションして処理しています。

http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.iterrows.html

    err = 0.0
    for index, line in dataset.iterrows():
        x, y = line.x, line.y
        err += 0.5 * (y - f(x))**2

 ruby では Daru::DataFrame#each_row_with_index が該当します。

http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Aeach_row_with_index

 今回は index は使用されていないので、 Daru::DataFrame.each_row を使います。

http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Aeach_row

  err = 0.0
  dataset.each_row do |line|
    err += 0.5 * (line.y - f.call(line.x))**2
  end

行列データの行数取得

 python では len() でデータセットの行数を取得できます。

>>> dataset          
          x         y
0  0.000000 -0.054637
1  0.111111  0.368772
2  0.222222  0.928976
3  0.333333  0.363668
4  0.444444  0.024165
5  0.555556 -0.521905
6  0.666667 -0.830954
7  0.777778 -0.464237
8  0.888889  0.058286
9  1.000000  0.183302
>>>                  
>>> len(dataset)     
10                   

 ruby では Daru::DataFrame#size で行数を取得できます。

http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Asize

irb(main):018:0* dataset
=> #<Daru::DataFrame(10x2)>
                     x          y
          0        0.0 -0.5455317
          1 0.11111111 0.86331945
          2 0.22222222 1.00393856
          3 0.33333333 0.82127838
          4 0.44444444 0.49240164
          5 0.55555555 -0.0866971
          6 0.66666666 -0.2723084
          7 0.77777777 -0.9982033
          8 0.88888888 -0.3344831
          9        1.0 -0.2558949
irb(main):019:0> 
irb(main):020:0* dataset.size
=> 10

 普通の配列であれば #size でも #count でも基本的には同じ結果になりますが、 Daru::DataFrame のオブジェクトで #count を使うとそれぞれの列のnullではない値の行数を返してくれます。

http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FMaths%2FStatistics%2FDataFrame%3Acount

irb(main):027:0* dataset.count
=> #<Daru::Vector(2)>
       count
     x    10
     y    10

平方根

 python では numpy.sqrt で平方根を求めています。

https://docs.scipy.org/doc/numpy/reference/generated/numpy.sqrt.html

>>> np.sqrt(2)    
1.4142135623730951
>>> np.sqrt(4)    
2.0               

 ruby では Math.sqrt で平方根を求めます。

https://docs.ruby-lang.org/ja/latest/method/Math/m/sqrt.html

irb(main):029:0* Math.sqrt(2)
=> 1.4142135623730951
irb(main):030:0> Math.sqrt(4)
=> 2.0

rms_error メソッドを ruby で

 ここまでの内容を踏まえて rms_error メソッドを ruby で実装します。

def rms_error(dataset, f)
  err = 0.0
  dataset.each_row do |line|
    err += 0.5 * (line.y - f.call(line.x))**2
  end

  Math.sqrt(2 * err / dataset.size)
end

 python版とruby版それぞれの実行結果は下記のようになりました。

>>> dataset                      
          x         y            
0  0.000000 -0.380472            
1  0.111111  0.719204            
2  0.222222  0.909939            
3  0.333333  1.022670            
4  0.444444  0.254537            
5  0.555556 -0.127610            
6  0.666667 -0.557395            
7  0.777778 -0.652528            
8  0.888889 -0.436643            
9  1.000000  0.450449            
>>>                              
>>> f, ws = resolve(train_set, 3)
>>>                              
>>> rms_error(train_set, f)      
0.23852710875750457              
irb(main):012:0* train_set
=> #<Daru::DataFrame(10x2)>
                     x          y
          0        0.0 -0.1806039
          1 0.11111111 1.22769157
          2 0.22222222 1.17466241
          3 0.33333333 0.67013919
          4 0.44444444 0.61218251
          5 0.55555555 -0.3105334
          6 0.66666666 -0.7770481
          7 0.77777777 -0.8269982
          8 0.88888888 -0.3653022
          9        1.0 -0.0443871
irb(main):013:0> 
irb(main):014:0* f, ws = resolve(train_set, 3)
=> [#<Proc:0x007fb60ddd3798@/vagrant/02-square_error.rb:45 (lambda)>, Vector[0.012228114152125085, 11.30192331783317, -32.97851812391701, 21.81003531172303]]
irb(main):015:0> 
irb(main):016:0* rms_error(train_set, f)
=> 0.20727664859912284