最小二乗法で解を求めるコード

 今回は前回に引き続き書籍のサンプルコードの下記部分を ruby で実装します。最小二乗法の公式を用いて係数を計算するメソッドです。 メソッドの戻り値として、決定された多項式と係数を返しています。

# 最小二乗法で解を求める
def resolve(dataset, m):
    t = dataset.y
    phi = DataFrame()
    for i in range(0,m+1):
        p = dataset.x**i
        p.name="x**%d" % i
        phi = pd.concat([phi,p], axis=1)
    tmp = np.linalg.inv(np.dot(phi.T, phi))
    ws = np.dot(np.dot(tmp, phi.T), t)

    def f(x):
        y = 0
        for i, w in enumerate(ws):
            y += w * (x ** i)
        return y

    return (f, ws)

 引用元のサンプルスクリプトの全体は下記で公開されています。

github.com

行列の結合

 python では pandas の concat メソッドで複数の pandas オブジェクトを結合します。

http://pandas.pydata.org/pandas-docs/stable/generated/pandas.concat.html

 オプションとして axis に 0 を指定すると縦方向の結合、1 を指定すると横方向の結合になります。 ここでは 1 を指定していますので、まず空の DataFrame を用意して、そこに dataset の x 列を i 乗した Series を順次横方向に追加してく形になります。

 ruby では空の Daru::DataFrame インスタンスは作成できないようでしたので、まず Hash に行列の内容を構成して、それを元に Daru::DataFrame のインスンタンスを作成しました。

  columns = {}
  (m+1).times do |i|
    columns["x**#{i}"] = dataset.x ** i
  end
  phi = Daru::DataFrame.new(columns)

転置行列

python では転置行列は pandas.DataFrame.T メソッドで取得することができます。

http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.T.html

ruby では Daru::DataFrame#transpose メソッドで取得することができます。

http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Atranspose

irb(main):016:0* dataset = create_dataset(10)
=> #<Daru::DataFrame(10x2)>
                     x          y
          0        0.0 0.35893030
          1 0.11111111 0.75466334
          2 0.22222222 1.16130076
          3 0.33333333 0.76565346
          4 0.44444444 0.56397365
          5 0.55555555 -0.5168586
          6 0.66666666 -1.2497477
          7 0.77777777 -0.7234279
          8 0.88888888 -0.4937113
          9        1.0 -0.4791075
irb(main):017:0> 
irb(main):018:0* dataset.transpose
=> #<Daru::DataFrame(2x10)>
                     0          1          2          3          4          5          6          7          8          9
          x        0.0 0.11111111 0.22222222 0.33333333 0.44444444 0.55555555 0.66666666 0.77777777 0.88888888 1.0
          y 0.35893030 0.75466334 1.16130076 0.76565346 0.56397365 -0.5168586 -1.2497477 -0.7234279 -0.4937113 -0.479 075
irb(main):019:0> 

行列の積

 python では pandas.DataFrame.dot メソッドで行列の積を計算できます。

http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.dot.html

 ruby では Daru::DataFrame で * メソッドが定義されているので、* で積を計算することができるのですが、今回は元の行列と転置行列をかけているため、indexに数値と文字列が混ざることになり、下記のようなエラーになってしまいます。

irb(main):024:0* dataset
=> #<Daru::DataFrame(10x2)>
                     x          y
          0        0.0 -0.0742627
          1 0.11111111 0.64832694
          2 0.22222222 1.62979372
          3 0.33333333 1.16074147
          4 0.44444444 0.19131551
          5 0.55555555 0.09922296
          6 0.66666666 -0.6080503
          7 0.77777777 -0.9894763
          8 0.88888888 -0.4535080
          9        1.0 0.03518189
irb(main):025:0> 
irb(main):026:0* dataset * dataset
=> #<Daru::DataFrame(10x2)>
                     x          y
          0        0.0 0.00551495
          1 0.01234567 0.42032782
          2 0.04938271 2.65622758
          3 0.11111111 1.34732077
          4 0.19753086 0.03660162
          5 0.30864197 0.00984519
          6 0.44444444 0.36972517
          7 0.60493827 0.97906348
          8 0.79012345 0.20566954
          9        1.0 0.00123776
irb(main):027:0> 
irb(main):028:0* dataset * dataset.transpose
ArgumentError: comparison of Symbol with 0 failed
        from /usr/local/rbenv/versions/2.3.1/lib/ruby/gems/2.3.0/gems/daru-0.1.4.1/lib/daru/maths/arithmetic/dataframe.rb:62:in `sort'
        from /usr/local/rbenv/versions/2.3.1/lib/ruby/gems/2.3.0/gems/daru-0.1.4.1/lib/daru/maths/arithmetic/dataframe.rb:62:in `dataframe_binary_operation'
        from /usr/local/rbenv/versions/2.3.1/lib/ruby/gems/2.3.0/gems/daru-0.1.4.1/lib/daru/maths/arithmetic/dataframe.rb:55:in `binary_operation'
        from /usr/local/rbenv/versions/2.3.1/lib/ruby/gems/2.3.0/gems/daru-0.1.4.1/lib/daru/maths/arithmetic/dataframe.rb:18:in `*'
        from (irb):28
        from /usr/local/rbenv/versions/2.3.1/bin/irb:11:in `<main>'

 なので今回は Daru::DataFrame#to_matrix メソッドで標準の Matrix クラスに変換した上で * メソッドで計算します。

http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Ato_matrix

irb(main):034:0* dataset.to_matrix * dataset.transpose.to_matrix
=> Matrix[[0.00551495178032287, -0.04814652292445256, -0.12103291720768458, -0.08619982081517798, -0.014207610754013434, -0.007368567440448755, 0.045155470293611104, 0.07348120775362572, 0.033678741797590464, -0.002612703024438115], [-0.04814652292445256, 0.4326735009138319, 1.0813305373090474, 0.7895770073355075, 0.1734177183698852, 0.12605731675017345, -0.3201413193247968, -0.555084433087857, -0.19525605324040052, 0.13392047936519433], [-0.12103291720768458, 1.0813305373090474, 2.705610300861487, 1.9658432484239334, 0.4105702573458222, 0.28516975670661376, -0.8428484217621567, -1.4398028683439108, -0.5415937066084376, 0.2795614486454444], [-0.08619982081517798, 0.7895770073355075, 1.9658432484239334, 1.4584318864404977, 0.3702160011396125, 0.3003573967170241, -0.4835669855859539, -0.8892670006157779, -0.23010930503481497, 0.37417041433074033], [-0.014207610754013434, 0.1734177183698852, 0.4105702573458222, 0.3702160011396125, 0.23413249035458505, 0.2658964729681377, 0.17996683941236247, 0.1563768318219559, 0.30829860276768317, 0.45117528617537195], [-0.007368567440448755, 0.12605731675017345, 0.28516975670661376, 0.3003573967170241, 0.2658964729681377, 0.31848717220014827, 0.3100378159950401, 0.33391998590584915, 0.4488287470660896, 0.5590464071902203], [0.045155470293611104, -0.3201413193247968, -0.8428484217621567, -0.4835669855859539, 0.17996683941236247, 0.3100378159950401, 0.8141696167863541, 1.120169924736844, 0.8683482991417188, 0.6452743066777754], [0.07348120775362572, -0.555084433087857, -1.4398028683439108, -0.8892670006157779, 0.1563768318219559, 0.33391998590584915, 1.120169924736844, 1.5840017535734776, 1.1400935207548575, 0.7429661273074588], [0.033678741797590464, -0.19525605324040052, -0.5415937066084376, -0.23010930503481497, 0.30829860276768317, 0.4488287470660896, 0.8683482991417188, 1.1400935207548575, 0.9957930064525073, 0.872933617825996], [-0.002612703024438115, 0.13392047936519433, 0.2795614486454444, 0.37417041433074033, 0.45117528617537195, 0.5590464071902203, 0.6452743066777754, 0.7429661273074588, 0.872933617825996, 1.001237765508352]]

逆行列

 python では pandas.linalg.inv メソッドで逆行列が取得できます。

https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.linalg.inv.html

 ruby では Matrix#inv メソッドで取得できます。

irb(main):050:0* (dataset.to_matrix * dataset.transpose.to_matrix).inv                                                
=> Matrix[[6.438056238931283e+17, -3.5959222998233816e+16, 4.88081712482937e+16, -4.18039023242074e+16, -4.90936881765618e+16, 1.9473195412556384e+16, 3.021796172248485e+16, -1.1756838226941134e+17, 1.104440295083465e+17, -8.790885266451936e+15], [7.894377025631253e+15, 4.130512217329567e+15, 1.226948454741468e+15, -1.080925575443484e+15, 9.51855025409944e+15, 1576790734528046.5, -1355771888949234.5, 1.191314652268012e+16, -1.3688955666893698e+16, -1.671726231048433e+15], [-4.530896280313147e+15, 6.832627236988067e+15, -256467997105680.2, -2.854089097619272e+15, -361754002852689.25, 1442321524185798.5, -293843073250848.9, 2.291785009164425e+15, -2164437097239707.2, -53995045129483.5], [6.56400708662395e+15, -8.278847572597992e+15, 2945428805276741.5, 611874690847290.5, -8.040224399925562e+15, 4.146425733346407e+15, 1.0063070988405478e+16, -1.099361383748988e+16, 7.012887589668771e+15, -3.060566106986782e+15], [2.0276047297689546e+17, -2.150234284194681e+16, 2.325837285911574e+15, 1.3855500557957406e+16, -9.66616800705062e+15, -1.4971038030737476e+16, -6.859023781680976e+15, -2.61127868104491e+16, 3.872186122638793e+16, 330309925585268.0], [1.1925716211143941e+17, -1.6248336393828732e+16, 6.71882570034233e+15, 7.539028212975366e+15, -2.6810813298632516e+16, 1.540682966685851e+16, 8.331481526823681e+15, -2.4692863560056996e+16, 3.42380291696946e+16, -1.5626708042627156e+16], [-6.604745687626929e+15, 727565952289667.5, 1.286427881109916e+15, 1508601318395916.8, -1.6562628335161202e+16, -4.61937542936672e+15, 2.284703494536934e+15, -3.360855015175375e+15, 7.170630317185643e+15, 3774876268811336.5], [6.0918212393723944e+16, -3.266090027295123e+15, 1.229518586531585e+15, 4197189849440108.5, -2.435086220896028e+15, 4.660830138590578e+15, 5.09467236899953e+15, -3.794140272523739e+15, 1.59348998039823e+15, -4.678366072552667e+15], [-1.1382587909489373e+17, 9.677645630477756e+15, -2.030349802911037e+15, -7.640675796532893e+15, 1.8680150397993788e+16, 7.114574957016326e+15, 1.658128693269082e+15, 1.4272195974858228e+16, -1.875905159959477e+16, -5.863347355441972e+15], [-1.00227143117868e+17, 1.2818655421880554e+16, -5.836686165645619e+15, -7.27446600919549e+15, 1.7224948871781866e+16, -1.0652985027149534e+16, -1.1678113837058996e+16, 1.966068157926607e+16, -2.591158029231125e+16, 1.6086644321288758e+16]]

resolveメソッドをrubyで

ここまでの内容を踏まえてサンプルコードの resolve メソッドを ruby で実装します。

def resolve(dataset, m)
  t = dataset.y

  columns = {}
  (m+1).times do |i|
    columns["x**#{i}"] = dataset.x ** i
  end
  phi = Daru::DataFrame.new(columns)

  tmp = (phi.transpose.to_matrix * phi.to_matrix).inv
  ws = (tmp * phi.transpose.to_matrix) * Vector.elements(t.to_a)

  f = lambda {|x|
    y = 0
    ws.each_with_index do |w, i|
      y = y + w * (x ** i)
    end

    y
  }

  return f, ws
end

 下記については、 (tmp * phi.transpose.to_matrix) の結果は標準の Matrix クラスになり、t は Daru::Vector になるのですが、 標準の Matrix と Daru::Vector はそのまま内積が計算できないので、 Daru::Vector を標準の Vector クラスに変換した上で内積を計算しています。

  tmp = (phi.transpose.to_matrix * phi.to_matrix).inv
  ws = (tmp * phi.transpose.to_matrix) * Vector.elements(t.to_a)

 また、python版ではメソッド内でさらに def でメソッドを定義して return でメソッドを返している部分を、 ruby では lambda を返すように実装しています。

 python版とruby版それぞれの実行結果は下記のようになりました。

>>> train_set                                                  
          x         y                                          
0  0.000000 -0.333941                                          
1  0.111111  1.425262                                          
2  0.222222  0.853840                                          
3  0.333333  0.982145                                          
4  0.444444  0.244465                                          
5  0.555556 -0.136371                                          
6  0.666667 -0.961705                                          
7  0.777778 -1.462275                                          
8  0.888889 -0.668861                                          
9  1.000000 -0.216649                                          
>>>                                                            
>>> f, ws = resolve(train_set, 3)                               
>>>                                                            
>>> type(f)                                                    
<type 'function'>                                              
>>>                                                                                                        
>>> ws                                                         
array([ -0.12912272,  13.31977786, -38.77768783,  25.50897351])
irb(main):072:0* train_set
=> #<Daru::DataFrame(10x2)>
                     x          y
          0        0.0 -0.3822801
          1 0.11111111 0.74486105
          2 0.22222222 0.83353384
          3 0.33333333 0.53626611
          4 0.44444444 0.31220508
          5 0.55555555 -0.6527569
          6 0.66666666 -1.2765675
          7 0.77777777 -0.9600806
          8 0.88888888 -0.7323687
          9        1.0 -0.5115570
irb(main):073:0> 
irb(main):074:0* f, ws = resolve(train_set, 3)
=> [#<Proc:0x007f496bd59c90@/vagrant/02-square_error.rb:42 (lambda)>, Vector[-0.2636530068380138, 10.784779039058368, 
-31.470536541107172, 20.64727993759569]]
irb(main):075:0> 
irb(main):076:0* f.class
=> Proc
irb(main):077:0> 
irb(main):078:0* ws
=> Vector[-0.2636530068380138, 10.784779039058368, -31.470536541107172, 20.64727993759569]