デジタル・デザイン・ラボラトリーな日々

アラフィフプログラマーが数学と物理と英語を基礎からやり直す。https://qiita.com/yaju

交差エントロピーを理解してみる

はじめに

ここ数ヶ月は別の件で忙しくて、機械学習に向き合えていませんでした。
仕事で調べたり学んだことはQiitaブログの方に書いていて結構すんなり書けるんですが、このブログは数学・物理・機械学習と特化するようにしているので、パワーがないと進まない。

ここ数ヶ月で頭の片隅にあったのは、損失関数の「交差エントロピー」です。今回はこれを理解していこうと思います。

何故、気になっていたのかというと、下記サイトの誤差関数(loss)のときのTensorflow関数「tf.nn.sigmoid_cross_entropy_with_logits」があり、シグモイド関数と交差エントロピーが一緒になっているからです。

yaju3d.hatenablog.jp

# 誤差関数(loss)
loss = tf.reduce_mean(tf.nn.sigmoid_cross_entropy_with_logits(logits=u, labels=y_)

損失関数

機械学習では学習時に、いかに答えに近い値になるように重みパラメータを調整(必要な補正量を定量的に示す)させます。その部分を担うのが損失関数となります。そして損失関数の最小値を探すことが学習のゴールとなる。 損失関数を設定する理由は、認識精度を指標するとパラメータの微分がほとんどの場所で 0 (動かなくなる)になるから。

簡単な説明では「値=損失」ということで、この損失をいかに少なくするのかということで「損失関数」となります。
ちなみに、損失関数と誤差関数の呼び方が違うだけで同義となります。
目的関数、コスト関数、誤差関数、損失関数いろいろあるけど、なにが違うのかを検討 - Qiita

損失関数は幾つか種類がありますが、機械学習では下記の2種類がよく使われます。
※理由として「誤差逆伝播」ができる関数であることのようです。

交差エントロピー誤差(cross entropy error)

二乗誤差の方は比較的分かりやすいので、交差エントロピーを理解していきます。
分類問題(識別問題)になると交差エントロピーが出てきます。

f:id:Yaju3D:20160421012315p:plain

\displaystyle E=\sum_{k=1} -y_k\log{t_k}-(1-y_k)\log{(1-t_k)}

これがある条件(分類問題)になると簡単になります。

\displaystyle E=\sum_{k=1} - t_k \log y_k

エントロピー

エントロピーとは平均情報量のことです。

交差

交差エントロピーの「交差」とは何なのかを偶然知った。曰く、p×log(p)のようにlogの中身と外側に同じ変数が使われているのが普通のエントロピー。それに対して、t×log(y)のようにlogの中身と外側に異なる変数が使われているものを"交差"エントロピーと呼ぶらしい。 「ゼロから作るDeep Learning」を読んだ(後編) - 不確定な世界

sigmoid_cross_entropy_with_logits

tf.nn.sigmoid_cross_entropy_with_logits - tensorflow

交差エントロピーの式の中にシグモイド関数を入ったものである。
\displaystyle E=\sum_{k=1} -y_k\log{t_k}-(1-y_k)\log{(1-t_k)}

         ↓

\displaystyle E=\sum_{k=1} -z\log{(sigmoid(x))}-(1-z)\log{(1-sigmoid(x))}

         ↓ 符号が違うので、-1 を掛ける

\displaystyle E=\sum_{k=1} z\log{(sigmoid(x))}+(1-z)-\log{(1-sigmoid(x))}

 z * -log(sigmoid(x)) + (1 - z) * -log(1 - sigmoid(x))
= z * -log(1 / (1 + exp(-x))) + (1 - z) * -log(exp(-x) / (1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (-log(exp(-x)) + log(1 + exp(-x)))
= z * log(1 + exp(-x)) + (1 - z) * (x + log(1 + exp(-x))
= (1 - z) * x + log(1 + exp(-x))
= x - x * z + log(1 + exp(-x))

最後に

この記事はまだ途中である。

参照