どうも、カタミチです。
さて、今日も「最短コースでわかる ディープラーニングの数学」のひとり読書会、やっていきたいと思います。
今日は8章「ロジスティック回帰モデル(2値分類)」の続きですね。張り切っていってみよ〜
8-4. 損失関数(交差エントロピー関数)
今回は損失関数の導出です。なかなか難易度が高くて、理解するのに苦戦しましたねー。
まず、予測モデルが正解値\(yt=1\)になる確率として定義され、それが\(yp\)となるってことから、逆に\(yt=0\)になる確率は\(1-yp\)ってことになります。これは、確率・統計の最尤推定の節で出てきた、「当たり」と「はずれ」の例と同じ考えが適用できますね。
つまり、尤度関数…というかそれに対数をとった対数尤度関数を損失関数として使うことになります。本書では、まずはデータを5つに絞ったところから拡張するロジックで説明してくれていて、分かりやすかったです。
ここでは、自分の理解を確かめる意味で、初めからM個のデータで式を組み上げてみます。
尤度関数を\(Lk\)とすると…
\(Lk = P^{(0)} \cdot P^{(1)} \cdot P^{(2)} \cdots P^{(M-1)}\)
(Mはデータ数)
となるので、両辺に対数をとると…
\(\log{Lk} = \log{(P^{(0)})} + \log{(P^{(1)})} + \log{(P^{(2)})} + \cdots + \log{(P^{(M-1)})}\)
ということになります。
ここで、尤度関数は最大値を探ることで確率が最大になる場所を探るやり方だったのに対し、勾配降下法の損失関数は最小にする必要があるので、符号を反転します。かつ、件数の影響を避けるためにデータ個数\(M\)で割ります。これは線形回帰のときと同じ考え方ですね。ということで、上の\(\log{Lk}\)をいじって\(L(w_0,w_1,w_2)\)とします。足し算の部分も\(\sum\)でまとめますねー。
\(L(w_0,w_1,w_2) = -\frac{1}{M} \sum_{m=0}^{M-1} \log{(P^{(m)})} \)
で、問題はこの\(P^{(m)}\)ってやつです。正解値\(yt\)ごとに、予測値が\(yp\)なのか\( (1-yp)\)なのかが、ふらふら変化します。プログラムで組むなら「場合分け」を駆使すればできるのかなぁ…と妄想もしてみたんですが、巧妙に式を組むことで、1つの式で表現できるようです。こうですね…どん!
\(L(w_0,w_1,w_2) = -\frac{1}{M} \sum_{m=0}^{M-1} ( (yt^{(m)} \cdot \log (yp^{(m)}) + (1-yt^{(m)}) \log (1-yp^{(m)})) \)
理解するのに結構時間がかかりましたが、実に見事な変換です。\(yt=1\)の時は式の右半分が消えて\(\log (yp^{(m)})\)が残り、\(yt=0\)の時は式の左半分が消えて、\(\log (1-yp^{(m)})\)が残る。バッチリ\(\log{(P^{(m)})}\)を表現してますね!
こういうのを思いつく人が居るから、機械学習のモデルもどんどん洗練されていくんでしょうねー。実に味わい深い表現です。
ということで上の式が今回の損失関数になります。
で、ついでに次の節の下ごしらえがされてました。損失関数のシグマの中の式を
\(ce = -(yt \log (yp) + (1-yt) \log (1-yp)) \)と置くと、それを\(yp\)で微分したものは…
\(\frac{d(ce)}{d(yp)} = \frac{yp-yt}{yp(1-yp)} \)
となります。微分自体は簡単ですね。次の節で使うみたいなので、とりあえず押さえておきますかねー。
ということで
今回は、考え方がなかなか難しかったので、繰り返し読み込みました。今回特に痛感したのが、基本的な数学の知識を得るだけでは、機械学習にはそのままは使えないって事です。
理論編で語られた式たちから上の損失関数も導くことができますが、とても自力でモデルを発見して損失関数を求めるところまで辿り着けるとは思えない…。やはり、各モデルで使われる式に落とし込むまでを数学知識と捉えておく方が良さそうですね。
逆にいうと、機械学習とはすなわち数学だってことになりますかねー。今後もその心構えでモデルを見ていく必要があるかなー、と、ヒシヒシと感じました。
ではまた。