今回はscikit-learnで提供されている手書きの数字データを、Matplotlibで表示する方法を紹介する。
コンテンツ
データセットの読み込み
まずはscikit-learnから以下のとおりデータセットを読み込む。
from sklearn import datasets
digits = datasets.load_digits()
まずは説明変数となるimagesのうち、一つのデータを見てみよう。
digits.images[0].shape
# array([[ 0., 0., 5., 13., 9., 1., 0., 0.],
# [ 0., 0., 13., 15., 10., 15., 5., 0.],
# [ 0., 3., 15., 2., 0., 11., 8., 0.],
# [ 0., 4., 12., 0., 0., 8., 8., 0.],
# [ 0., 5., 8., 0., 0., 9., 8., 0.],
# [ 0., 4., 11., 0., 1., 12., 7., 0.],
# [ 0., 2., 14., 5., 10., 12., 0., 0.],
# [ 0., 0., 6., 13., 10., 0., 0., 0.]])
digits.images[0].shape
# (8, 8)
画像のピクセル数は「8×8」となっており、色情報の値が格納されている。
次に目的変数のtargetについて、こちらはその画像が何の数字かを(0〜9)表している。
digits.target[:20]
# array([0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 0, 1, 2, 3, 4, 5, 6, 7, 8, 9])
Matplotlibによるデータの可視化
とりあえず一気にコードを紹介する。
%matplotlib notebook
import matplotlib.pyplot as plt
image_and_labels = list(zip(digits.images, digits.target))
for index, (image, label) in enumerate(image_and_labels[:30]):
plt.subplot(6, 5, index + 1)
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
plt.axis('off')
plt.subplots_adjust(wspace=1, hspace=1)
plt.title('Training: %i' % label)
plt.show()
実行すると下記のとおり、6行5列の手書き文字の画像データが出力される。
ではコードを解説していこう。
複数のリストをまとめる
画像データのimagesと正解データのtarget、二つのリストを一つにまとめる場合、list関数とzip関数を以下のとおり組み合わせる。
image_and_labels = list(zip(digits.images, digits.target))
enumerate関数でfor文を回す
まとめたリストを、次はfor文で回していく。
データ数が多すぎるので今回はとりあえず30データを表示させることにした。
enumerate関数は、引数に指定したシーケンス型オブジェクトと同時に、インデックス番号を取得できる。
つまりこの場合、indexに現在のインデックス番号、imageとlabelに先ほど一つにまとめたリスト(画像データ・正解データ)が格納される。
for index, (image, label) in enumerate(image_and_labels[:30]):
プロットの設定
for文の中で下記のとおりプロットの設定をおこなう。
# 一つの図に収めるプロットの数(6行5列 ※ループが30回のため)
plt.subplot(6, 5, index + 1)
# 画像データを表示する関数
# cmapで色指定(今回はグレースケール)
# interpolationで画像を拡大した時のピクセル補完
plt.imshow(image, cmap=plt.cm.gray_r, interpolation='nearest')
# 座標軸を非表示にする
plt.axis('off')
# プロット間の余白調整
plt.subplots_adjust(wspace=1, hspace=1)
# ラベルの表示
plt.title('Training: %i' % label)
最後にshow()関数を実行して出力。
以上、Matplotlibを使った画像データの出力方法だ。