たつぷりの調査報告書

博士後期課程(理学)の学生が趣味でUnityやBlenderで遊ぶブログです。素人が独学で勉強した際の忘備録です。

分類器で機械学習入門(その3)- 汎化誤差

こんにちは。たつぷりです。前回に引き続きCIFAR10の分類器についてのノートです。このシリーズは、pytorch公式のチュートリアルの解説とそれをより理解するために自分でやったことなどをまとめたものです。

Training a Classifier — PyTorch Tutorials 1.7.0 documentation

目標

前回までで、ニューラルネットワークのモデルを構成して、そのネットワークの学習を行った。ここではこのネットワークの妥当性を調べる。つまり学習によって分類器として振る舞う期待通りのネットワークが得られているかをチェックする。具体的には以下を調べる。

汎化誤差

まずは汎化誤差を調べる。前回の記事で述べたように、学習は誤差関数を小さくすることを指針にして行われた。誤差関数が小さいというのは、そのネットワークが入力した教師データを良く分類できていることを示している。

つまり、学習済みのネットワークは少なくとも教師データの分類についてできる(それを指針にしているので)。しかし重要なのは、まだ見ぬ未知のデータに対して正しく分類を行うことができるか?である。

まだ見ぬ未知のデータに対しての誤差を汎化誤差という。以下では具体的にこれを調べる。

テストデータの準備

まずはテスト用のデータを準備する。この手順については過去の記事で既に解説している。今回用いているのは一貫してCIFAR10のデータであるが、これはtorchvisionを使うことで簡単に用意できるのであった。今導入するのはテストデータなので、オプションでtrain = Falseを指定している。

これによって、各クラス1000枚ずつの10000枚の画像データにアクセスできる。

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)

具体的にデータにアクセスするにはtestloaderイテレーターに変換して各要素を読み出せばよい。

dataiter = iter(testloader)
images, labels = dataiter.next()

ネットワークからの予言値

前回保存したウェイトをロードして、学習済みのネットワークを呼び出す。

net = Net()
net.load_state_dict(torch.load(PATH))

このネットワークの出力を確認する。net()は前回述べたように、4*3*32*32の構造を持つデータの入力に対して4*10の構造を持つ量を返す。

outputs = net(images)
print(outputs.size())

#torch.Size([4, 10])が出力される

outputの第0要素を見てみると、例えば以下のようになっていた。

[-0.4645, -1.3041,  0.8090,  1.2276,  0.1213,  0.1624,  1.4485, -1.2337,0.2869, -0.6776]

この10個の列はそれぞれ各クラスに対応していており、最も大きい数字に対応するクラスである可能性が最も高いことを表している。

CIFAR10では以下の10個のクラスからなる。

('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

なので上の例ではネットワークは「frog」の可能性が最も高いと判断した、ということになる。ここで、このデータの本当のラベルをチェックする。

print(labels[0])
# 3 を出力する

このように実際のラベルは「cat」であったのでネットワークはこのデータに対しては、誤った予言を与えたことになる。

ちなみに、1番目のデータに対しては、以下の予言が得られ、

 [ 4.7203,  4.9927, -1.9707, -3.1006, -2.8680, -4.2864, -3.7756, -4.7392,7.6236,  2.3189]

実際のラベルが8であったので、ネットワークは正しい予言を与えたことになる。すなわちこの画像が「船」であると判断できた。

実際にこのミニバッチの画像を出力すると以下のようになっている。 f:id:Tatsupuri:20201128092944p:plain

確かに0番目のデータに関してカエルを見間違えたといわれると、個人的には少し納得した。ネットワークからの出力は、猫とカエルに対してそれぞれ1.22761.4485と予言を与えており、"迷っている"様子がうかがえる。

正答率

上では、個別のデータに対して予言を見たが今度はまとめて全データに対しての正答率を評価する。基本的には上とやることは同じで、テストデータはdataiterというイテレータになっているので、forループで評価すれば良い。

ただしイテレータの各要素はミニバッチごとに分けられているので、今回のケースでは各ステップで4つのデータを渡す。そのため各ステップでこの4つに対してそれぞれ上と同様のチェックをすることになる。

具体的には以下のように実装できる。

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

この出力は、

Accuracy of the network on the 10000 test images: 60 %

であり、このネットワークによる分類の精度は60%ということになる。完全にランダムに回答するとしたら10%の回答率になるはずなので、このネットワークは何かしら”学習”したことが結論できる

以下でこのコードの解説を行う。基本的にはcorrectに正しく予言を与えたデータの数、totalに全データ数を格納する。これらの量からネットワークのテストデータの対しての正答率を評価することができる。

with torch.no_grad():

with torch.no_grad():に関しては以下に詳しい解説があるので引用する。

Autograd: Automatic Differentiation — PyTorch Tutorials 1.7.0 documentation

To prevent tracking history (and using memory), you can also wrap the code block in with torch.no_grad():. This can be particularly helpful when evaluating a model because the model may have trainable parameters with requires_grad=True, but for which we don’t need the gradients.

つまり、モデルを評価するときにはgradientを計算する必要はないのでこのフラグを切っておくということである。

torch.max(outputs.data, 1)

これを理解するために以下を試す。なお、torch.maxのレファレンスはここ

outputs = net(images) #imageは上のコードを実行した後だと最後のミニバッチの画像
print(torch.max(outputs.data,1))

#出力
# torch.return_types.max(
# values=tensor([4.5723, 5.3089, 3.0941, 8.3328]),
# indices=tensor([3, 5, 4, 7]))

このように、torch.max(outputs.data, 1)の第2引数である1は、第1軸の中で最大値を探すというオプションである。つまり4つの画像データそれぞれに対して、10個の出力の中で最大のものを返す。 実際の出力を見たらわかるようにtorch.maxは二つの返り値を持ち、一つ目に最大値、二つ目にその最大値を与えたのは何番目の要素であるかを返す。

つまり、我々が今欲しい予言値はまさにtorch.maxの2つめの返り値に他ならない

ちなみに、outputoutput.dataの違いは、grad_fn=<AddmmBackward>のようなフラグを出力に含めるかどうかの差である。.dataをつけるとこのようなフラグは出力されない。

(predicted == labels).sum().item()

これを理解するために、以下を試す。

a = torch.Tensor([1,2,3,4])
b = torch.Tensor([1,2,3,0])

print((a == b))

# 出力
# tensor([ True,  True,  True, False])

このように二つのテンソル型の量に対し論理演算を行うと、各要素に対して論理演算を行い、そのBool型の結果を要素にもつテンソル型の量を返す。

よって、(a == b).sum()は、True=1,False=0として各要素の和をとるので、tensor(3)を返す。.item()メソッドを使うことで数値として要素を得ることができる。つまり(a == b).sum().item()3を返し、これはa,bの各要素のうち一致している要素の数を返す。

まとめると、(predicted == labels).sum().item()はミニバッチの中で予言が実際のラベルと一致していたデータの個数を返す

過学習

過学習は、訓練誤差が小さくなっているにも関わらず汎化誤差が大きくなってしまう現象である。つまり学習の結果、教師データにのみ特化したネットワークになってしまった状態である。

過学習が起こっていないことをチェックするには以下のような方法がある。学習の途中のステップでもウェイトを保存しておき、それらウェイトに対して汎化誤差の評価を行う。これによって汎化誤差が学習が進むにつれてどのように発展していくかを追うことができるようになる。

これでもし損失関数が減少しているにも関わらず汎化誤差が小さくなっていくことが確認できれば過学習が起きていると判断することができる。

今回のモデルで、エポック数を10にして2000ステップごとにウェイトを保存し、それらに対して汎化誤差の評価を行ったところ以下のような結果が得られた。

f:id:Tatsupuri:20201128111242p:plain

このように汎化誤差は傾向として単調に増加しており、この範囲においては過学習は起こっていなないといえるだろう。

まとめ

今回は学習によって得られたネットワークが、分類器としてそれなりに良く振る舞っていることが確認できた。このように機械学習においては学習によって得られたネットワークがまだ見ぬ未知のデータに対しても機能するかを調べることが非常に大事である。

次回はこのネットワークからの出力を分析することで、このネットワークの特徴や傾向を詳しく見ていくことにする。