今回は前回に引き続き書籍のサンプルコードの下記部分を ruby で実装します。最小二乗法の公式を用いて係数を計算するメソッドです。 メソッドの戻り値として、決定された多項式と係数を返しています。
# 最小二乗法で解を求める def resolve(dataset, m): t = dataset.y phi = DataFrame() for i in range(0,m+1): p = dataset.x**i p.name="x**%d" % i phi = pd.concat([phi,p], axis=1) tmp = np.linalg.inv(np.dot(phi.T, phi)) ws = np.dot(np.dot(tmp, phi.T), t) def f(x): y = 0 for i, w in enumerate(ws): y += w * (x ** i) return y return (f, ws)
引用元のサンプルスクリプトの全体は下記で公開されています。
行列の結合
python では pandas の concat メソッドで複数の pandas オブジェクトを結合します。
http://pandas.pydata.org/pandas-docs/stable/generated/pandas.concat.html
オプションとして axis に 0 を指定すると縦方向の結合、1 を指定すると横方向の結合になります。 ここでは 1 を指定していますので、まず空の DataFrame を用意して、そこに dataset の x 列を i 乗した Series を順次横方向に追加してく形になります。
ruby では空の Daru::DataFrame インスタンスは作成できないようでしたので、まず Hash に行列の内容を構成して、それを元に Daru::DataFrame のインスンタンスを作成しました。
columns = {} (m+1).times do |i| columns["x**#{i}"] = dataset.x ** i end phi = Daru::DataFrame.new(columns)
転置行列
python では転置行列は pandas.DataFrame.T メソッドで取得することができます。
http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.T.html
ruby では Daru::DataFrame#transpose メソッドで取得することができます。
http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Atranspose
irb(main):016:0* dataset = create_dataset(10) => #<Daru::DataFrame(10x2)> x y 0 0.0 0.35893030 1 0.11111111 0.75466334 2 0.22222222 1.16130076 3 0.33333333 0.76565346 4 0.44444444 0.56397365 5 0.55555555 -0.5168586 6 0.66666666 -1.2497477 7 0.77777777 -0.7234279 8 0.88888888 -0.4937113 9 1.0 -0.4791075 irb(main):017:0> irb(main):018:0* dataset.transpose => #<Daru::DataFrame(2x10)> 0 1 2 3 4 5 6 7 8 9 x 0.0 0.11111111 0.22222222 0.33333333 0.44444444 0.55555555 0.66666666 0.77777777 0.88888888 1.0 y 0.35893030 0.75466334 1.16130076 0.76565346 0.56397365 -0.5168586 -1.2497477 -0.7234279 -0.4937113 -0.479 075 irb(main):019:0>
行列の積
python では pandas.DataFrame.dot メソッドで行列の積を計算できます。
http://pandas.pydata.org/pandas-docs/stable/generated/pandas.DataFrame.dot.html
ruby では Daru::DataFrame で * メソッドが定義されているので、* で積を計算することができるのですが、今回は元の行列と転置行列をかけているため、indexに数値と文字列が混ざることになり、下記のようなエラーになってしまいます。
irb(main):024:0* dataset => #<Daru::DataFrame(10x2)> x y 0 0.0 -0.0742627 1 0.11111111 0.64832694 2 0.22222222 1.62979372 3 0.33333333 1.16074147 4 0.44444444 0.19131551 5 0.55555555 0.09922296 6 0.66666666 -0.6080503 7 0.77777777 -0.9894763 8 0.88888888 -0.4535080 9 1.0 0.03518189 irb(main):025:0> irb(main):026:0* dataset * dataset => #<Daru::DataFrame(10x2)> x y 0 0.0 0.00551495 1 0.01234567 0.42032782 2 0.04938271 2.65622758 3 0.11111111 1.34732077 4 0.19753086 0.03660162 5 0.30864197 0.00984519 6 0.44444444 0.36972517 7 0.60493827 0.97906348 8 0.79012345 0.20566954 9 1.0 0.00123776 irb(main):027:0> irb(main):028:0* dataset * dataset.transpose ArgumentError: comparison of Symbol with 0 failed from /usr/local/rbenv/versions/2.3.1/lib/ruby/gems/2.3.0/gems/daru-0.1.4.1/lib/daru/maths/arithmetic/dataframe.rb:62:in `sort' from /usr/local/rbenv/versions/2.3.1/lib/ruby/gems/2.3.0/gems/daru-0.1.4.1/lib/daru/maths/arithmetic/dataframe.rb:62:in `dataframe_binary_operation' from /usr/local/rbenv/versions/2.3.1/lib/ruby/gems/2.3.0/gems/daru-0.1.4.1/lib/daru/maths/arithmetic/dataframe.rb:55:in `binary_operation' from /usr/local/rbenv/versions/2.3.1/lib/ruby/gems/2.3.0/gems/daru-0.1.4.1/lib/daru/maths/arithmetic/dataframe.rb:18:in `*' from (irb):28 from /usr/local/rbenv/versions/2.3.1/bin/irb:11:in `<main>'
なので今回は Daru::DataFrame#to_matrix メソッドで標準の Matrix クラスに変換した上で * メソッドで計算します。
http://www.rubydoc.info/gems/daru/0.1.0/Daru%2FDataFrame%3Ato_matrix
irb(main):034:0* dataset.to_matrix * dataset.transpose.to_matrix => Matrix[[0.00551495178032287, -0.04814652292445256, -0.12103291720768458, -0.08619982081517798, -0.014207610754013434, -0.007368567440448755, 0.045155470293611104, 0.07348120775362572, 0.033678741797590464, -0.002612703024438115], [-0.04814652292445256, 0.4326735009138319, 1.0813305373090474, 0.7895770073355075, 0.1734177183698852, 0.12605731675017345, -0.3201413193247968, -0.555084433087857, -0.19525605324040052, 0.13392047936519433], [-0.12103291720768458, 1.0813305373090474, 2.705610300861487, 1.9658432484239334, 0.4105702573458222, 0.28516975670661376, -0.8428484217621567, -1.4398028683439108, -0.5415937066084376, 0.2795614486454444], [-0.08619982081517798, 0.7895770073355075, 1.9658432484239334, 1.4584318864404977, 0.3702160011396125, 0.3003573967170241, -0.4835669855859539, -0.8892670006157779, -0.23010930503481497, 0.37417041433074033], [-0.014207610754013434, 0.1734177183698852, 0.4105702573458222, 0.3702160011396125, 0.23413249035458505, 0.2658964729681377, 0.17996683941236247, 0.1563768318219559, 0.30829860276768317, 0.45117528617537195], [-0.007368567440448755, 0.12605731675017345, 0.28516975670661376, 0.3003573967170241, 0.2658964729681377, 0.31848717220014827, 0.3100378159950401, 0.33391998590584915, 0.4488287470660896, 0.5590464071902203], [0.045155470293611104, -0.3201413193247968, -0.8428484217621567, -0.4835669855859539, 0.17996683941236247, 0.3100378159950401, 0.8141696167863541, 1.120169924736844, 0.8683482991417188, 0.6452743066777754], [0.07348120775362572, -0.555084433087857, -1.4398028683439108, -0.8892670006157779, 0.1563768318219559, 0.33391998590584915, 1.120169924736844, 1.5840017535734776, 1.1400935207548575, 0.7429661273074588], [0.033678741797590464, -0.19525605324040052, -0.5415937066084376, -0.23010930503481497, 0.30829860276768317, 0.4488287470660896, 0.8683482991417188, 1.1400935207548575, 0.9957930064525073, 0.872933617825996], [-0.002612703024438115, 0.13392047936519433, 0.2795614486454444, 0.37417041433074033, 0.45117528617537195, 0.5590464071902203, 0.6452743066777754, 0.7429661273074588, 0.872933617825996, 1.001237765508352]]
逆行列
python では pandas.linalg.inv メソッドで逆行列が取得できます。
https://docs.scipy.org/doc/numpy-1.10.0/reference/generated/numpy.linalg.inv.html
ruby では Matrix#inv メソッドで取得できます。
irb(main):050:0* (dataset.to_matrix * dataset.transpose.to_matrix).inv => Matrix[[6.438056238931283e+17, -3.5959222998233816e+16, 4.88081712482937e+16, -4.18039023242074e+16, -4.90936881765618e+16, 1.9473195412556384e+16, 3.021796172248485e+16, -1.1756838226941134e+17, 1.104440295083465e+17, -8.790885266451936e+15], [7.894377025631253e+15, 4.130512217329567e+15, 1.226948454741468e+15, -1.080925575443484e+15, 9.51855025409944e+15, 1576790734528046.5, -1355771888949234.5, 1.191314652268012e+16, -1.3688955666893698e+16, -1.671726231048433e+15], [-4.530896280313147e+15, 6.832627236988067e+15, -256467997105680.2, -2.854089097619272e+15, -361754002852689.25, 1442321524185798.5, -293843073250848.9, 2.291785009164425e+15, -2164437097239707.2, -53995045129483.5], [6.56400708662395e+15, -8.278847572597992e+15, 2945428805276741.5, 611874690847290.5, -8.040224399925562e+15, 4.146425733346407e+15, 1.0063070988405478e+16, -1.099361383748988e+16, 7.012887589668771e+15, -3.060566106986782e+15], [2.0276047297689546e+17, -2.150234284194681e+16, 2.325837285911574e+15, 1.3855500557957406e+16, -9.66616800705062e+15, -1.4971038030737476e+16, -6.859023781680976e+15, -2.61127868104491e+16, 3.872186122638793e+16, 330309925585268.0], [1.1925716211143941e+17, -1.6248336393828732e+16, 6.71882570034233e+15, 7.539028212975366e+15, -2.6810813298632516e+16, 1.540682966685851e+16, 8.331481526823681e+15, -2.4692863560056996e+16, 3.42380291696946e+16, -1.5626708042627156e+16], [-6.604745687626929e+15, 727565952289667.5, 1.286427881109916e+15, 1508601318395916.8, -1.6562628335161202e+16, -4.61937542936672e+15, 2.284703494536934e+15, -3.360855015175375e+15, 7.170630317185643e+15, 3774876268811336.5], [6.0918212393723944e+16, -3.266090027295123e+15, 1.229518586531585e+15, 4197189849440108.5, -2.435086220896028e+15, 4.660830138590578e+15, 5.09467236899953e+15, -3.794140272523739e+15, 1.59348998039823e+15, -4.678366072552667e+15], [-1.1382587909489373e+17, 9.677645630477756e+15, -2.030349802911037e+15, -7.640675796532893e+15, 1.8680150397993788e+16, 7.114574957016326e+15, 1.658128693269082e+15, 1.4272195974858228e+16, -1.875905159959477e+16, -5.863347355441972e+15], [-1.00227143117868e+17, 1.2818655421880554e+16, -5.836686165645619e+15, -7.27446600919549e+15, 1.7224948871781866e+16, -1.0652985027149534e+16, -1.1678113837058996e+16, 1.966068157926607e+16, -2.591158029231125e+16, 1.6086644321288758e+16]]
resolveメソッドをrubyで
ここまでの内容を踏まえてサンプルコードの resolve メソッドを ruby で実装します。
def resolve(dataset, m) t = dataset.y columns = {} (m+1).times do |i| columns["x**#{i}"] = dataset.x ** i end phi = Daru::DataFrame.new(columns) tmp = (phi.transpose.to_matrix * phi.to_matrix).inv ws = (tmp * phi.transpose.to_matrix) * Vector.elements(t.to_a) f = lambda {|x| y = 0 ws.each_with_index do |w, i| y = y + w * (x ** i) end y } return f, ws end
下記については、 (tmp * phi.transpose.to_matrix) の結果は標準の Matrix クラスになり、t は Daru::Vector になるのですが、 標準の Matrix と Daru::Vector はそのまま内積が計算できないので、 Daru::Vector を標準の Vector クラスに変換した上で内積を計算しています。
tmp = (phi.transpose.to_matrix * phi.to_matrix).inv
ws = (tmp * phi.transpose.to_matrix) * Vector.elements(t.to_a)
また、python版ではメソッド内でさらに def でメソッドを定義して return でメソッドを返している部分を、 ruby では lambda を返すように実装しています。
python版とruby版それぞれの実行結果は下記のようになりました。
>>> train_set x y 0 0.000000 -0.333941 1 0.111111 1.425262 2 0.222222 0.853840 3 0.333333 0.982145 4 0.444444 0.244465 5 0.555556 -0.136371 6 0.666667 -0.961705 7 0.777778 -1.462275 8 0.888889 -0.668861 9 1.000000 -0.216649 >>> >>> f, ws = resolve(train_set, 3) >>> >>> type(f) <type 'function'> >>> >>> ws array([ -0.12912272, 13.31977786, -38.77768783, 25.50897351])
irb(main):072:0* train_set => #<Daru::DataFrame(10x2)> x y 0 0.0 -0.3822801 1 0.11111111 0.74486105 2 0.22222222 0.83353384 3 0.33333333 0.53626611 4 0.44444444 0.31220508 5 0.55555555 -0.6527569 6 0.66666666 -1.2765675 7 0.77777777 -0.9600806 8 0.88888888 -0.7323687 9 1.0 -0.5115570 irb(main):073:0> irb(main):074:0* f, ws = resolve(train_set, 3) => [#<Proc:0x007f496bd59c90@/vagrant/02-square_error.rb:42 (lambda)>, Vector[-0.2636530068380138, 10.784779039058368, -31.470536541107172, 20.64727993759569]] irb(main):075:0> irb(main):076:0* f.class => Proc irb(main):077:0> irb(main):078:0* ws => Vector[-0.2636530068380138, 10.784779039058368, -31.470536541107172, 20.64727993759569]