どうも、カタミチです。
さて、今日も「最短コースでわかる PyTorch&深層学習プログラミング」のひとり読書会、やっていきますかねー。
今日もまた、線形回帰の続きですね。前回までで何かを成し遂げた気がしていましたが、実際には予測計算をしただけだったので、まだまだ入口ですね。今回は、その次の損失計算に入ります。
3-7. 損失関数
線形回帰の損失関数は「平均2乗誤差」で定義するんでしたね。正解値と予測値の差を2乗して平均を取るやつ。
式にするとこんな感じです。
どん!
関数名に「mse」と言うのが使われていますが、調べたら「平均2乗誤差」を表す「Mean Squared Error」の略で、統計学や機械学習の分野では一般的な用語のようです。
で、これをそのまま「loss」という変数に突っ込みましょう。えい。
ちなみに、mse関数内で定義されてる引数や戻り値の変数名と、実際に引数として渡してる変数や戻り値の格納先の変数名が一致してますが、これは必ずしも一致してる必要はないですかね。
で、計算グラフを見てみましょう。
YpがWとBを含む式なので、lossにもWとBがつながってますね。ということで、こう書けば…
計算グラフが見れますね。
実行してみましょう…
どん!
おー、計算グラフに新たなハコが3つ追加されましたね。平均2乗誤差の計算ですね。
ということで
計算グラフもだいぶ育ってきましたね。これからも、こうやって確認しながら式を組んでいく…って事なんでしょうねー。次は勾配計算ですね。
ではまた。