たつぷりの調査報告書

博士後期課程(理学)の学生が趣味でUnityやBlenderで遊ぶブログです。素人が独学で勉強した際の忘備録です。

分類器で機械学習入門(その1)- データの準備と可視化

どうもたつぷりです。こんにちは。最近、個人的な趣味と研究上の必要によって、機械学習の門をたたいてみました。本当に何も知らないので公式のチュートリアルなどから触っていっています。

さて今回の記事は、この世には既に腐るほど存在している、CIFAR10の画像を分類する画像分類器を作成するチュートリアルの解説を行います。要は自分への備忘録です。

目的と目標

自分がチュートリアルをこなす際、知らなかったり分からなかったことをまとめながらやっていたので、その内容をブログにもまとめる。具体的には以下のURLのpytorchの公式チュートリアルの画像分類器の作成を、いろいろと何をやっているのかを解説付きでまとめた。

pytorch.org

今の自分の理解の範囲では正しいと思うことを書いているが、今後理解が進んでより深い理解をしたら修正を行う可能性がある。また現段階の勉強不足で分からないことは今後別のノートでまとめる。

学習データの準備

まず学習に必要なデータセットを準備する。

前準備

torchvision.datasetsを用いることで、有名なデータセットを簡単に取り込むことができる。また、 torch.utils.data.DataLoader を使うと画像のデータセットを作る、例えばミニバッチに分けるなどの操作をするのに便利である。

ここではCIFAR10のデータを用いる。各画像データの次元は3x32x32(色x高さx幅)である。

まずデータの準備に必要なモジュールを導入する。

import torch
import torchvision
import torchvision.transforms as transforms

ここで、torchvision.transformsは画像の成型を行うためのライブラリと考えてよい。

pytorch.org

今回は次のようなデータの成型を行う。

transform = transforms.Compose(
    [transforms.ToTensor(),
     transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))])

この際、transform.Composeは複数の変換をまとめて適用するためのものである。ここでは、

  • transforms.ToTensor()
  • transforms.Normalize()

をまとめて適用している。

transform.ToTensorは画像データをpytorchで扱えるテンソル型の量に変換する。

transform.Normalizeテンソル量を正規化する。この操作に関してはあまり理解できていないが、CNNなどで学習するときに効率が良くなることがあるなどの説明を見かけた。公式レファレンスによると、具体的には下のような操作を行っている。

Given mean: (mean[1],...,mean[n]) and std: (std[1],..,std[n]) for n channels, this transform will normalize each channel of the input torch.*Tensor i.e., output[channel] = (input[channel] - mean[channel]) / std[channel]

この式から考えると、例えば今の場合はRGBの3チャンネルが存在しているので。

Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

と書くと、最初の(0.5, 0.5, 0.5)はrgbに対して正規化する際の平均値、後の(0.5, 0.5, 0.5)はrgbに対して正規化する際の標準偏差を与えている。よって

output_R = (input_R-0.5)/0.5

を与えていることになる。後で用いるように、inputを再構成するにはRGBそれぞれに対して

input_R = output_R * 0.5 + 0.5

をすればよい。

データのロード

実際に、以下のようにしてデータの準備をする。torchvision.datasets.CIFAR10()はCIFAR10のデータを読み込むのに用いる。引数に様々なオプションを指定する。

例えば今回筆者はデータのダウンロードも行ったので、download=Trueを指定した。CIFAR10ではあらかじめ教師データと、テストデータが準備されている。それらの指定をtrainで行う。上述のtransformを用いてデータの成型方法に関しての指定も行っている。

これによってtrainsetには教師用データの情報が格納された。ここで機械学習は各ステップでミニバッチと呼ばれる小さな部分集合ごとに行われる。ミニバッチに含まれる画像データの数をbatch_sizeで指定する。ここでは4にしているので学習の各ステップは4枚の画像に対して行われることになる。

trainset = torchvision.datasets.CIFAR10(root='./data', train=True,
                                        download=True, transform=transform)

trainloader = torch.utils.data.DataLoader(trainset, batch_size=4,
                                          shuffle=True, num_workers=0)

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)

testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)

classes = ('plane', 'car', 'bird', 'cat',
           'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

後で改めて述べるが、CIFAR10のデータは、画像データとその画像が何の画像かを表すラベルが含まれている。ラベルはこの場合だと0-9の10個の数字からなり、それぞれに犬や猫などが対応する。この数字と名前の対応がclassesで与えられている。

可視化

基本的にここで行うことは、

  • 正規化したデータを逆変換で元に戻す
  • テンソル量をNumpy arrayに変換しmatplotlibで可視化する

のである。まずこれに必要なモジュールを導入する。

import matplotlib.pyplot as plt
import numpy as np

データの構造の確認

実際に画像を出力する前に、各量のデータの構造を確認しておく。

dataiter = iter(trainloader)
images, labels = dataiter.next()

print(len(dataiter))
print(12500*4)

print(labels.size(),labels)
print(images.size())
print(images.shape)

上の入力に対して以下の出力が得られた。

50000
torch.Size([4]) tensor([5, 4, 5, 9])
torch.Size([4, 3, 32, 32])
torch.Size([4, 3, 32, 32])

この出力を理解する。

dataiter = iter(trainloader)でDataloader型のオブジェクトであるdataloaderをイテレータに変換している。イテレータに変換すると、next()を用いて各要素を順番に呼び出すことができるようになる。今回は省略するがこれらのメソッドは__iter__()で定義されるはずなので具体的に何をしているか見たければ、Dataloaderクラスのソースの該当箇所を見れば分かるはずである。

実際、詳しく見なくともこの後に続く文、

images, labels = dataiter.next()

を見ると、1つ目の返り値が画像データの情報で、2つ目の返り値が各データのラベルの情報であることが分かる。ラベルは数字で返ってきているが、これは既に定義したclasses = ('plane', 'car', 'bird', 'cat','deer', 'dog', 'frog', 'horse', 'ship', 'truck')と対応させてラベルを読み取ることができるようになっている。

とりあえず、このイテレータが何要素あるのか見てみると、 print(len(dataiter)) = 12500 であった。今回はバッチ数を4にしているで、各要素には4つの画像データが入っていることになる。つまり、print(12500*4) = 50000が読み込まれた全画像の数であることが分かる。

さて、この数字は何なのであろうか?以下のURLにCIFAR10についての基本的なことがまとめられているので見てみる。これによると、CIFAR10は全部で60000個の画像からなっており、そのうち50000個をTrainingデータ、10000個をTestデータとして用意している。今回ロードしたのは教師データなので、確かにこの数に一致している。

CIFAR-10 and CIFAR-100 datasets

次に、labelsのサイズと、実際の値を読み取った。 print(labels.size(),labels)に対してtorch.Size([4]) tensor([8, 2, 2, 5])が返ってきている。これはlabelが4つの量からなっており、その値は[8,2,2,5]であることを言っている。たしかにバッチ数と同じ数のラベルが返ってきており、この場合は「船、鳥、鳥、犬」を表していることになる。

同様に、imagesの構造も確認する。ちなみに、テンソルの構造はsize()shapeどちらでも見ることができる。

print(images.size())
print(images.shape)

これらに対しての返り値はtorch.Size([4, 3, 32, 32])であった。これは確かに、4つの画像からなり、各画像が3*32*32のサイズを持っていることを表しており、そのとおりである。

画像の可視化

実際にこれらを出力するには以下のようにすれば良い。まず、可視化を実現する関数を準備する。

def imshow(img):
    img = img / 2 + 0.5     # unnormalize
    npimg = img.numpy()
    plt.imshow(np.transpose(npimg, (1, 2, 0)))
    plt.show()

この関数内で上述した、データ準備する際施した正規化の逆変換を行い、numpy arrayに変換したあとmatplotlibで出力する一連の流れを実装したものである。ここで、transposeに関しては後で改めて説明する。

以下で実際に画像を出力する。

# show images
imshow(torchvision.utils.make_grid(images))
# print labels
print(' '.join('%5s' % classes[labels[j]] for j in range(4)))

f:id:Tatsupuri:20201124234728p:plain

このようにデータを可視化することができた。しかしここで上のコードを見ると、可視化されているのはimagesではなくてtorchvision.utils.make_grid(images)であることに注意する。

上で既に説明したように今回はサイズ4のミニバッチを用いているので、imagesの次元は4*3*32*32であり、画像データの次元C*H*Wと一致しない。そこでtorchvision.utils.make_gridを用いて4つの画像を1つの画像に結合する必要があるのだ。

実際以下のURL見るとわかるが、torchvision.utils.make_grid(images)は、(B,C,H,W)の構造をもつデータからバッチ数分ある画像を結像して一つの(C,H,W)の構造を持つデータに変換することを行う。

torchvision.utils — PyTorch 1.7.0 documentation

この際様々なパラメータを与えることができて、例えば結合するときの画像間の間隔などを変えることができる。上の公式レファレンスか下のURL

画像をただ並べたいときに使えるTorchVision | Shikoan's ML Blog

も参考になる。

ちなみに上の出力だと、4*3*32*32だったのを、3*36*138に変換している。これは、次のように理解できる。まず上下に、2ピクセルずつのマージンを加えているので、32から36になっており、横方向には、2ピクセルのマージンが5つ入っている。なので、32 * 4 + 2 * 5 = 138であり確かに一致している。

また、torchvision.utils.save_imageを用いれば結合された画像を保存することができる。

ここで先ほどimshow関数の中の以下の処理についても補足しておく。

plt.imshow(np.transpose(npimg, (1, 2, 0)))

NumpyのTranspose関数は多次元配列の軸の入れかえを行う関数である。この場合は、「もともと(0,1,2)の構造を持っていたものを(1,2,0)に変えよ」ということなので、(C,H,W) -> (H,W,C) の構造に変換したということである。これで正しくmatplotlibで出力することができる。

まとめ

今回は、公式チュートリアルのCIFAR10の分類器を作るプロジェクトにおいて、データの準備と可視化に関してフォーカスしてまとめた。これで学習の準備が整ったが、長くなってきたので次の記事で実際にモデルを構築し学習を行う。