デジタル・デザイン・ラボラトリーな日々

アラフィフプログラマーが数学と物理と英語を基礎からやり直す。https://qiita.com/yaju

交差エントロピーを理解してみる

はじめに

ここ数ヶ月は別の件で忙しくて、機械学習に向き合えていませんでした。
仕事で調べたり学んだことはQiitaブログの方に書いていて結構すんなり書けるんですが、このブログは数学・物理・機械学習と特化するようにしているので、パワーがないと進まない。

ここ数ヶ月で頭の片隅にあったのは、損失関数の「交差エントロピー」です。今回はこれを理解していこうと思います。

何故、気になっていたのかというと、下記サイトの誤差関数(loss)のときのTensorflow関数「tf.nn.sigmoid_cross_entropy_with_logits」があり、シグモイド関数と交差エントロピーが一緒になっているからです。

yaju3d.hatenablog.jp

# 誤差関数(loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=u, labels=y_)

損失関数

機械学習では学習時に、いかに答えに近い値になるように重みパラメータを調整(必要な補正量を定量的に示す)させます。その部分を担うのが損失関数となります。そして損失関数の最小値を探すことが学習のゴールとなる。 損失関数を設定する理由は、認識精度を指標するとパラメータの微分がほとんどの場所で 0 (動かなくなる)になるから。

簡単な説明では「値=損失」ということで、この損失をいかに少なくするのかということで「損失関数」となります。
ちなみに、損失関数と誤差関数の呼び方が違うだけで同義となります。
目的関数、コスト関数、誤差関数、損失関数いろいろあるけど、なにが違うのかを検討 - Qiita

損失関数は幾つか種類がありますが、機械学習では下記の2種類がよく使われます。
※理由として「誤差逆伝播」ができる関数であることのようです。

交差エントロピー誤差(cross entropy error)

二乗誤差の方は比較的分かりやすいので、交差エントロピーを理解していきます。
分類問題(識別問題)になると交差エントロピーが出てきます。

f:id:Yaju3D:20160421012315p:plain

\displaystyle E=\sum_{k=1} -y_k\log{t_k}-(1-y_k)\log{(1-t_k)}

これがある条件(分類問題)になると簡単になります。

\displaystyle E=\sum_{k=1} - t_k \log y_k

エントロピー

エントロピーとは平均情報量のことです。

交差

交差エントロピーの「交差」とは何なのかを偶然知った。曰く、p×log(p)のようにlogの中身と外側に同じ変数が使われているのが普通のエントロピー。それに対して、t×log(y)のようにlogの中身と外側に異なる変数が使われているものを"交差"エントロピーと呼ぶらしい。 「ゼロから作るDeep Learning」を読んだ(後編) - 不確定な世界

なぜ分類問題では交差エントロピーが使われるのか

損失関数といえば二乗誤差が有名ですが、分類問題を扱う際には交差エントロピーが頻繁に使われます。
教師データと学習結果が大きく乖離している(損失関数の値が大きい)時、交差エントロピーを使った方が学習スピードが早い!!(1学習あたりの損失関数の減少幅が大きい)

出力結果(y)が0に近いところだと、交差エントロピー微分値がとてつもなく大きなマイナスになっていることがわかります。微分値が大きなマイナスになっているということは、それだけ損失関数がマイナスに大きく変動しているということを意味しています。
なぜ交差(クロス)エントロピーが機械学習(ニューラルネットワーク)の損失関数に使われるのか? | 大人になってから学びたい日本の歴史

sigmoid_cross_entropy_with_logits

交差エントロピーは、ニューラルネットワークの誤差関数(損失関数)として使われることがあります。特に、シグモイド関数との相性がよい。
Tensorflowでは、「tf.nn.sigmoid_cross_entropy_with_logits - tensorflow」と一緒になっている。

交差エントロピーの式の中にシグモイド関数を入ったものである。
\displaystyle E=\sum_{k=1} -y_k\log{t_k}-(1-y_k)\log{(1-t_k)}

         ↓

\displaystyle E=\sum_{k=1} -z\log{(sigmoid(x))}-(1-z)\log{(1-sigmoid(x))}

         ↓ 符号が違うので、それぞれに -1 を掛ける \displaystyle E=\sum_{k=1} z-log{(sigmoid(x))}+(1-z)-\log{(1-sigmoid(x))}

 z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))

上記プログラムは、tf.nn.sigmoid_cross_entropy_with_logits のロジックのコメントにあったものである。
次の式の展開がどうしてそうなるのか分からない、数学が出来る人は過程を省略するんだよね。
そういうのは嫌い、自分が分かった限りで説明していきます。

= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))

この式で分かりにくい部分は、-log(1 / (1 + exp(-x))) → log(1 + exp(-x)) になることですね。
これは、対数の中身が分数の場合の展開が分かるといいです。log1=0
\displaystyle log\left(\frac{1}{8}\right) = log\left(\frac{1}{2^{3}}\right)  = log1-log(2^{3}) = 0-3log2 = -3log2

対数の分数は引き算に分けることができます。
\displaystyle -log\left(\frac{1}{1+exp(-x)}\right)
=-(log1-log(1+ exp(-x)))
=-(0-log(1 + exp(-x))
= log(1 + exp(-x))

同じ部分として -log(exp(-x) / (1 + exp(-x)) → (-log(exp(-x)) + log(1 + exp(-x))) があります。
これは先程と同じ展開で対数の分数は引き算になります。
-log(exp(-x) / (1 + exp(-x))
=-(log(exp(-x)) - log(1+ exp(-x)))
=(-log(exp(-x)) + log(1 + exp(-x))

3段目で分かりにくい部分は (-log(exp(-x)) → x ですかね。
これは自然対数の底の性質 (log(exp(x)) → x を使います。
今回は expの中の符号と外側の符号がマイナスになっていますので展開すると (-log(exp(-x)) =-(-x) = x となります。

= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))

これは、 log(1 + exp(-x)) を 別の変数にすると分かりやすくなります。一旦 qにします。
=z * q + (1 - z) * (x + q)
=z * q + (1 - z) * x + (1 - z) * q
=zq + (1 - z) * x + q - zq
=(1 - z) * x + q - zq + zq
=(1 - z) * x + q
定義のqを元に戻す。
=(1 - z) * x + log(1 + exp(-x))
後は素直に展開すればいい。
=x - x * z + log(1 + exp(-x))

 

最後に

この記事はまだ途中である。

参照