今回はscikit-learnで提供されている手書きの数字データを、Matplotlibで表示する方法を紹介する。

データセットの読み込み

まずはscikit-learnから以下のとおりデータセットを読み込む。

from sklearn import datasets

digits = datasets.load_digits()

まずは説明変数となるimagesのうち、一つのデータを見てみよう。

digits.images[0].shape
# array([[ 0.,  0.,  5., 13.,  9.,  1.,  0.,  0.],
#        [ 0.,  0., 13., 15., 10., 15.,  5.,  0.],
#        [ 0.,  3., 15.,  2.,  0., 11.,  8.,  0.],
#        [ 0.,  4., 12.,  0.,  0.,  8.,  8.,  0.],
#        [ 0.,  5.,  8.,  0.,  0.,  9.,  8.,  0.],
#        [ 0.,  4., 11.,  0.,  1., 12.,  7.,  0.],
#        [ 0.,  2., 14.,  5., 10., 12.,  0.,  0.],
#        [ 0.,  0.,  6., 13., 10.,  0.,  0.,  0.]])

digits.images[0].shape
# (8, 8)

画像のピクセル数は「8×8」となっており、色情報の値が格納されている。

次に目的変数のtargetについて、こちらはその画像が何の数字かを(0〜9)表している。

digits.target[:20]
# array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])

Matplotlibによるデータの可視化

とりあえず一気にコードを紹介する。

%matplotlib notebook
import matplotlib.pyplot as plt

image_and_labels = list(zip(digits.images, digits.target))

for index, (image, label) in enumerate(image_and_labels[:30]):
    plt.subplot(6, 5, index + 1)
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    plt.axis('off')
    plt.subplots_adjust(wspace=1, hspace=1)
    plt.title('Training: %i' % label)
    
plt.show()

実行すると下記のとおり、6行5列の手書き文字の画像データが出力される。

ではコードを解説していこう。

複数のリストをまとめる

画像データのimagesと正解データのtarget、二つのリストを一つにまとめる場合、list関数とzip関数を以下のとおり組み合わせる。

image_and_labels = list(zip(digits.images, digits.target))

enumerate関数でfor文を回す

まとめたリストを、次はfor文で回していく。
データ数が多すぎるので今回はとりあえず30データを表示させることにした。

enumerate関数は、引数に指定したシーケンス型オブジェクトと同時に、インデックス番号を取得できる。
つまりこの場合、indexに現在のインデックス番号、imageとlabelに先ほど一つにまとめたリスト(画像データ・正解データ)が格納される。

for index, (image, label) in enumerate(image_and_labels[:30]):

プロットの設定

for文の中で下記のとおりプロットの設定をおこなう。

    # 一つの図に収めるプロットの数(6行5列 ※ループが30回のため)
    plt.subplot(6, 5, index + 1)
    
    # 画像データを表示する関数
    # cmapで色指定(今回はグレースケール)
    # interpolationで画像を拡大した時のピクセル補完
    plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
    
    # 座標軸を非表示にする
    plt.axis('off')
    
    # プロット間の余白調整
    plt.subplots_adjust(wspace=1, hspace=1)
    
    # ラベルの表示
    plt.title('Training: %i' % label)

最後にshow()関数を実行して出力。

以上、Matplotlibを使った画像データの出力方法だ。