デジタル・デザイン・ラボラトリーな日々

アラフィフプログラマーが数学と物理と英語を基礎からやり直す。https://qiita.com/yaju

交差エントロピーを理解してみる

はじめに

ここ数ヶ月は別の件で忙しくて、機械学習に向き合えていませんでした。
仕事で調べたり学んだことはQiitaブログの方に書いていて結構すんなり書けるんですが、このブログは数学・物理・機械学習と特化するようにしているので、パワーがないと進まない。

ここ数ヶ月で頭の片隅にあったのは、損失関数の「交差エントロピー」です。今回はこれを理解していこうと思います。

何故、気になっていたのかというと、下記サイトの誤差関数(loss)のときのTensorflow関数「tf.nn.sigmoid_cross_entropy_with_logits」があり、シグモイド関数と交差エントロピーが一緒になっているからです。

yaju3d.hatenablog.jp

# 誤差関数(loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=u, labels=y_)

損失関数

機械学習では学習時に、いかに答えに近い値になるように重みパラメータを調整(必要な補正量を定量的に示す)させます。その部分を担うのが損失関数となります。そして損失関数の最小値を探すことが学習のゴールとなる。 損失関数を設定する理由は、認識精度を指標するとパラメータの微分がほとんどの場所で 0 (動かなくなる)になるから。

簡単な説明では「値=損失」ということで、この損失をいかに少なくするのかということで「損失関数」となります。
ちなみに、損失関数と誤差関数の呼び方が違うだけで同義となります。
目的関数、コスト関数、誤差関数、損失関数いろいろあるけど、なにが違うのかを検討 - Qiita

損失関数は幾つか種類がありますが、機械学習では下記の2種類がよく使われます。
※理由として「誤差逆伝播」ができる関数であることのようです。

交差エントロピー誤差(cross entropy error)

二乗誤差の方は比較的分かりやすいので、交差エントロピーを理解していきます。
分類問題(識別問題)になると交差エントロピーが出てきます。

f:id:Yaju3D:20160421012315p:plain

\displaystyle E=\sum_{k=1} -y_k\log{t_k}-(1-y_k)\log{(1-t_k)}

これがある条件(分類問題)になると簡単になります。
交差エントロピー誤差では、実質、正解データt = 1 の場合にしか計算は行われません(正解データt = 0 の時には、乗算結果は常に 0 に収束するから)

\displaystyle E=\sum_{k=1} - t_k \log y_k

エントロピー

エントロピーとは平均情報量のことです。

A国、B国に置けるその日の天気を表した表を以下に定義します。

晴れ(%) 曇り(%) 雨(%) 雪(%)
A国 25 25 25 25
B国 50 25 12.5 12.5

この時それぞれの天気のエンコード方法は以下のようにするとします。

天気 ビット
晴れ 00
曇り 01
10
11

この時、メモリの容量が1KB(8000bit)だとすると、4000(日)のデータを保存できる。
\displaystyle \frac{8000(bit)}{2(bit)}=4000(日)

しかし、このエンコード方法は無駄があることがわかるでしょうか?
なぜなら、B国では晴れの日が50%である為、以下の表のように、エンコードした際に使うbit数を1にすると、容量を節約することができ、同じ8000bitのメモリで4500日以上分のデータを保存できる。

天気 ビット
晴れ 0
曇り 10
110
111

\displaystyle H(P_b) = E_{P_b} \begin{bmatrix}log_{2}(\frac1{P_{b}(\omega)})\end{bmatrix} \displaystyle =0.5log_{2} \left(\frac1{0.5}\right)+0.25log_{2}\left (\frac1{0.25}\right) +0.125log_{2}\left (\frac1{0.125}\right)+0.125log_{2}\left (\frac1{0.125}\right) = 1.75

最適なコードの割り当て方法は数学的証明により、求められており

確率P(\omega) で起こる事象には長さ\displaystyle log_{2} \left(\frac{1}{P(\omega)}\right) のコードを割り当てるのが最適とされている。
また、期待値(1回の試行の結果を伝えるのに要する平均ビット数)をエントロピー(平均情報量)とよびます。

交差エントロピー

ここでそれぞれの国の最適エントロピーを上のような数学的方法によって求めてみましょう。
まずは、ある国の最適なエントロピーの求め方で、他の国のエントロピーを求めた場合を考える。

[例] B国の最適エントロピーの求め方で、A国のエントロピーを求める。

\displaystyle H(P_b) = E_{P_b} \begin{bmatrix}log_{2}(\frac1{P_{b}(\omega)})\end{bmatrix}
=\displaystyle  0.25log_{2} \left(\frac{1}{0.5}\right)+0.25log_{2} \left(\frac{1}{0.25}\right)+0.25log_{2} \left(\frac{1}{0.125}\right)+0.25log_{2} \left(\frac{1}{0.125}\right)=2.25

A国の最適エントロピーの求め方をしていないので、0.25[bit]分無駄が生じてしまっています。
このように、ある確率分布に最適化された方式で別の確率分布をエンコードした時の平均ビット長を「交差エントロピー」と呼びます。

交差エントロピーの「交差」とは何なのかを偶然知った。曰く、p×log(p)のようにlogの中身と外側に同じ変数が使われているのが普通のエントロピー。それに対して、t×log(y)のようにlogの中身と外側に異なる変数が使われているものを"交差"エントロピーと呼ぶらしい。 「ゼロから作るDeep Learning」を読んだ(後編) - 不確定な世界

損失関数・交差エントロピー誤差とは

交差エントロピー誤差では、 自然対数eを底とするモデル出力値のlog値と正解データ値を乗算したものの総和を、損失とします。
自然対数logは、logに渡される x の値が 0 に近い時には絶対数の大きな出力を返し、 x の値が 1 に近いほど、絶対数が 0 に近い出力を返します。 すなわち、正解データ t が 1 の時、それに対応するモデル出力 y が 1 に近い数値を出力できていれば、 t と x の乗算結果は小さくなり、 x が 0 に近い誤った数値を出力していれば、 t と x の乗算結果は大きくなる、という論理です。

wild-data-chase.com

qiita.com

なぜ分類問題では交差エントロピーが使われるのか

損失関数といえば二乗誤差が有名ですが、分類問題を扱う際には交差エントロピーが頻繁に使われます。
教師データと学習結果が大きく乖離している(損失関数の値が大きい)時、交差エントロピーを使った方が学習スピードが早い!!(1学習あたりの損失関数の減少幅が大きい)

出力結果(y)が0に近いところだと、交差エントロピー微分値がとてつもなく大きなマイナスになっていることがわかります。微分値が大きなマイナスになっているということは、それだけ損失関数がマイナスに大きく変動しているということを意味しています。
なぜ交差(クロス)エントロピーが機械学習(ニューラルネットワーク)の損失関数に使われるのか? | 大人になってから学びたい日本の歴史

sigmoid_cross_entropy_with_logits

交差エントロピーは、ニューラルネットワークの誤差関数(損失関数)として使われることがあります。特に、シグモイド関数との相性がよい。
Tensorflowでは、「tf.nn.sigmoid_cross_entropy_with_logits - tensorflow」と一緒になっている。

交差エントロピーの式の中にシグモイド関数が入ったものである。
\displaystyle E=\sum_{k=1} -y_k\log{t_k}-(1-y_k)\log{(1-t_k)}

         ↓

\displaystyle E=\sum_{k=1} -z\log{(sigmoid(x))}-(1-z)\log{(1-sigmoid(x))}

         ↓ 符号が違うので、それぞれに -1 を掛ける \displaystyle E=\sum_{k=1} z-log{(sigmoid(x))}+(1-z)-\log{(1-sigmoid(x))}

 z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))

上記プログラムは、tf.nn.sigmoid_cross_entropy_with_logits のロジックのコメントにあったものである。
次の式の展開がどうしてそうなるのか分からない、数学が出来る人は過程を省略するんだよね。
そういうのは嫌い、自分が分かった限りで説明していきます。

= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))

この式で分かりにくい部分は、-log(1 / (1 + exp(-x))) → log(1 + exp(-x)) になることですね。
これは、対数の中身が分数の場合の展開が分かるといいです。log1=0
\displaystyle log\left(\frac{1}{8}\right) = log\left(\frac{1}{2^{3}}\right)  = log1-log(2^{3}) = 0-3log2 = -3log2

対数の分数は引き算に分けることができます。
\displaystyle -log\left(\frac{1}{1+exp(-x)}\right)
=-(log1-log(1+ exp(-x)))
=-(0-log(1 + exp(-x))
= log(1 + exp(-x))

同じ部分として -log(exp(-x) / (1 + exp(-x)) → (-log(exp(-x)) + log(1 + exp(-x))) があります。
これは先程と同じ展開で対数の分数は引き算になります。
-log(exp(-x) / (1 + exp(-x))
=-(log(exp(-x)) - log(1+ exp(-x)))
=(-log(exp(-x)) + log(1 + exp(-x))

3段目で分かりにくい部分は (-log(exp(-x)) → x ですかね。
これは自然対数の底の性質 (log(exp(x)) → x を使います。
今回は expの中の符号と外側の符号がマイナスになっていますので展開すると (-log(exp(-x)) =-(-x) = x となります。

= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))

これは、 log(1 + exp(-x)) を 別の変数にすると分かりやすくなります。一旦 qにします。
=z * q + (1 - z) * (x + q)
=z * q + (1 - z) * x + (1 - z) * q
=zq + (1 - z) * x + q - zq
=(1 - z) * x + q - zq + zq
=(1 - z) * x + q
定義のqを元に戻す。
=(1 - z) * x + log(1 + exp(-x))
後は素直に展開すればいい。
=x - x * z + log(1 + exp(-x))

 

最後に

この記事はまだ途中である。

参照

スポンサーリンク