デジタル・デザイン・ラボラトリーな日々

アラフィフプログラマーが数学と物理と英語を基礎からやり直す。https://qiita.com/yaju

TensorFlowコトハジメ 偶数と奇数に分類

はじめに

久しぶりにTensorFlowをさわってみました。
人工知能を勉強しようとしてもハードルが高いし、手書きの文字を分類したからって何って感じ、画像を集めるのも大変だし結果を出すにも時間がかかるしね。
先ずはリハビリとして何をやろうかと思ったのが、以前やったFizz-Buzz問題を応用して偶数と奇数に分類させてみるというもので、結果がすぐ出るのがいいよね。

yaju3d.hatenablog.jp

仕組み

101から127(27-1)までのデータで学習したニューラルネットワークに対して、1から100までの答え(偶数ならeven、奇数ならodd)の予測を出力するプログラムになっています。こんなんでも、贅沢にもディープラーニングを使ってます。
訓練する際に偶数か奇数かを振り分けるのに2の剰余(余り)を使っていますが答えを出すのに2の剰余(余り)を使っていません、コンピューターが学習して判断しています。

最低限で正しい結果が出るようにしたかったので、出来るだけ関連する数値を減らしています。

  • NUM_DIGITS = 7 … 101から学習させる範囲で2の指数値
  • NUM_HIDDEN = 5 … 隠れ層のユニット数
  • BATCH_SIZE = 1 … バッチ数
  • range(30) … エポック(学習ループの単位)の範囲

ソースコード

# coding: utf-8
# even odd in Tensorflow!

import numpy as np
import tensorflow as tf

NUM_DIGITS = 7

# Represent each input by an array of its binary digits.
def binary_encode(i, num_digits):
    return np.array([i >> d & 1 for d in range(num_digits)])

# One-hot encode the desired outputs: ["even", "odd"]
def even_odd_encode(i):
    if   i % 2 == 0: return np.array([1, 0])
    else:            return np.array([0, 1])

# Our goal is to produce even odd for the numbers 1 to 100. So it would be
# unfair to include these in our training data. Accordingly, the training data
# corresponds to the numbers 101 to (2 ** NUM_DIGITS - 1).
trX = np.array([binary_encode(i, NUM_DIGITS) for i in range(101, 2 ** NUM_DIGITS)])
trY = np.array([even_odd_encode(i)           for i in range(101, 2 ** NUM_DIGITS)])

# We'll want to randomly initialize weights.
def init_weights(shape):
    return tf.Variable(tf.random_normal(shape, stddev=0.01))

# Our model is a standard 1-hidden-layer multi-layer-perceptron with ReLU
# activation. The softmax (which turns arbitrary real-valued outputs into
# probabilities) gets applied in the cost function.
def model(X, w_h, w_o):
    h = tf.nn.relu(tf.matmul(X, w_h))
    return tf.matmul(h, w_o)

# Our variables. The input has width NUM_DIGITS, and the output has width 2.
X = tf.placeholder("float", [None, NUM_DIGITS])
Y = tf.placeholder("float", [None, 2])

# How many units in the hidden layer.
NUM_HIDDEN = 5

# Initialize the weights.
w_h = init_weights([NUM_DIGITS, NUM_HIDDEN])
w_o = init_weights([NUM_HIDDEN, 2])

# Predict y given x using the model.
py_x = model(X, w_h, w_o)

# We'll train our model by minimizing a cost function.
cost = tf.reduce_mean(tf.nn.softmax_cross_entropy_with_logits(py_x, Y))
train_op = tf.train.GradientDescentOptimizer(0.05).minimize(cost)

# And we'll make predictions by choosing the largest output.
predict_op = tf.argmax(py_x, 1)

# Finally, we need a way to turn a prediction (and an original number)
# into a even odd output
def even_odd(i, prediction):
     return ["{0:3d}".format(i) + ":even", "{0:3d}".format(i) + ":odd "][prediction]

BATCH_SIZE = 1

# Launch the graph in a session
with tf.Session() as sess:
    tf.initialize_all_variables().run()

    for epoch in range(30):
        # Shuffle the data before each training iteration.
        # print(range(len(trX)))

        p = np.random.permutation(range(len(trX)))
        trX, trY = trX[p], trY[p]

        # Train in batches of 1 inputs.
        for start in range(0, len(trX), BATCH_SIZE):
            end = start + BATCH_SIZE
            sess.run(train_op, feed_dict={X: trX[start:end], Y: trY[start:end]})

        # And print the current accuracy on the training data.
        if epoch % 10 == 0:
            print(epoch, np.mean(np.argmax(trY, axis=1) ==
                             sess.run(predict_op, feed_dict={X: trX, Y: trY})))

    # And now for some even odd
    numbers = np.arange(1, 101)
    teX = np.transpose(binary_encode(numbers, NUM_DIGITS))
    teY = sess.run(predict_op, feed_dict={X: teX})
    output = np.vectorize(even_odd)(numbers, teY)

    print(output)

出力結果

見事に1から100まで偶数と奇数に分類することが出来ました。

(0, 0.66666666666666663)
(10, 0.85185185185185186)
(20, 1.0)
['  1:odd ' '  2:even' '  3:odd ' '  4:even' '  5:odd ' '  6:even'
 '  7:odd ' '  8:even' '  9:odd ' ' 10:even' ' 11:odd ' ' 12:even'
 ' 13:odd ' ' 14:even' ' 15:odd ' ' 16:even' ' 17:odd ' ' 18:even'
 ' 19:odd ' ' 20:even' ' 21:odd ' ' 22:even' ' 23:odd ' ' 24:even'
 ' 25:odd ' ' 26:even' ' 27:odd ' ' 28:even' ' 29:odd ' ' 30:even'
 ' 31:odd ' ' 32:even' ' 33:odd ' ' 34:even' ' 35:odd ' ' 36:even'
 ' 37:odd ' ' 38:even' ' 39:odd ' ' 40:even' ' 41:odd ' ' 42:even'
 ' 43:odd ' ' 44:even' ' 45:odd ' ' 46:even' ' 47:odd ' ' 48:even'
 ' 49:odd ' ' 50:even' ' 51:odd ' ' 52:even' ' 53:odd ' ' 54:even'
 ' 55:odd ' ' 56:even' ' 57:odd ' ' 58:even' ' 59:odd ' ' 60:even'
 ' 61:odd ' ' 62:even' ' 63:odd ' ' 64:even' ' 65:odd ' ' 66:even'
 ' 67:odd ' ' 68:even' ' 69:odd ' ' 70:even' ' 71:odd ' ' 72:even'
 ' 73:odd ' ' 74:even' ' 75:odd ' ' 76:even' ' 77:odd ' ' 78:even'
 ' 79:odd ' ' 80:even' ' 81:odd ' ' 82:even' ' 83:odd ' ' 84:even'
 ' 85:odd ' ' 86:even' ' 87:odd ' ' 88:even' ' 89:odd ' ' 90:even'
 ' 91:odd ' ' 92:even' ' 93:odd ' ' 94:even' ' 95:odd ' ' 96:even'
 ' 97:odd ' ' 98:even' ' 99:odd ' '100:even']

ちなみに、隠れ層のユニット数「NUM_HIDDEN = 3」にした場合、間違った答えになります。

(0, 0.33333333333333331)
(10, 0.66666666666666663)
(20, 1.0)
['  1:odd ' '  2:odd ' '  3:odd ' '  4:odd ' '  5:odd ' '  6:odd '
 '  7:odd ' '  8:odd ' '  9:odd ' ' 10:odd ' ' 11:odd ' ' 12:odd '
 ' 13:odd ' ' 14:odd ' ' 15:odd ' ' 16:even' ' 17:odd ' ' 18:odd '
 ' 19:odd ' ' 20:odd ' ' 21:odd ' ' 22:odd ' ' 23:odd ' ' 24:odd '
 ' 25:odd ' ' 26:odd ' ' 27:odd ' ' 28:odd ' ' 29:odd ' ' 30:odd '
 ' 31:odd ' ' 32:even' ' 33:odd ' ' 34:odd ' ' 35:odd ' ' 36:odd '
 ' 37:odd ' ' 38:odd ' ' 39:odd ' ' 40:even' ' 41:odd ' ' 42:odd '
 ' 43:odd ' ' 44:odd ' ' 45:odd ' ' 46:odd ' ' 47:odd ' ' 48:even'
 ' 49:odd ' ' 50:odd ' ' 51:odd ' ' 52:odd ' ' 53:odd ' ' 54:odd '
 ' 55:odd ' ' 56:even' ' 57:odd ' ' 58:odd ' ' 59:odd ' ' 60:odd '
 ' 61:odd ' ' 62:odd ' ' 63:odd ' ' 64:even' ' 65:odd ' ' 66:odd '
 ' 67:odd ' ' 68:even' ' 69:odd ' ' 70:odd ' ' 71:odd ' ' 72:even'
 ' 73:odd ' ' 74:odd ' ' 75:odd ' ' 76:even' ' 77:odd ' ' 78:odd '
 ' 79:odd ' ' 80:even' ' 81:odd ' ' 82:odd ' ' 83:odd ' ' 84:even'
 ' 85:odd ' ' 86:even' ' 87:odd ' ' 88:even' ' 89:odd ' ' 90:even'
 ' 91:odd ' ' 92:even' ' 93:odd ' ' 94:odd ' ' 95:odd ' ' 96:even'
 ' 97:odd ' ' 98:odd ' ' 99:odd ' '100:even']

最後に

人工知能を使っていろいろやってみたいのですが、それをどうやって組むのかがまだピンと来ないんですよね。
前回、「確率を理解してみる-ベイジアンフィルタを実装」をやってみて自然言語が面白そうなので挑戦してみます。

スポンサーリンク