たつぷりの調査報告書

博士後期課程(理学)の学生が趣味でUnityやBlenderで遊ぶブログです。素人が独学で勉強した際の忘備録です。

分類器で機械学習入門(その4)- 結果の解釈

こんにちは。たつぷりです。前回に引き続きCIFAR10の分類器についてのノートです。このシリーズは、pytorch公式のチュートリアルの解説とそれをより理解するために自分でやったことなどをまとめたものです。

前回までで「データの準備」、「学習」、「汎化誤差の評価」などを見てきた。今回は、学習済みのネットワークから何が言えるかを考察したいと思います。

目標

今回の目標は、ネットワークの出力に対してSoftmaxスコアを評価し、それを可視化して傾向を分析することである。そこから、今回のモデルの弱点などを考察していく。

Softmaxスコアの評価

ここでは、ネットワークからの出力に対する、Softmax関数を計算しその可視化を行う。

Softmax関数

まず、Softmax関数の定義とその性質を簡単にまとめておく。

まずSoftmax関数の定義は以下の通りである。 \mathbf{x}\in \mathbb{R}^nに対してソフトマックス関数 \sigma_i\left ( \mathbf{x} \right)は、以下で与えられる。

 \sigma_i\left ( \mathbf{x} \right) = \frac{e^{x_i}}{Z}

ただし、

 Z = \sum_{i} e^{x_i}

である。

これは、私のように物理をやっている人間からすると明らかであるが、これはボルツマンウェイトである。なのでわざわざ、分母を Zという物理屋にはなじみ深い分配関数の形で書いておいた。

次にSoftmax関数の性質をまとめる。Softmax関数からは引数のデータが局在しているか?などの情報を得ることができる(この表現が正しいか微妙な気もするが...)。

例えば、任意のi に対して x_i = pであったとすると、

 \sigma_i\left ( \mathbf{x} \right) = \frac{1}{n}

になる。その一方で、 \mathbf{x} = (q,0,\cdots,0)であるような場合は

 \sigma_1\left ( \mathbf{x} \right) = \frac{e^{q}}{e^{q}+(n-1)}

 \sigma_{i\neq1}\left ( \mathbf{x} \right) = \frac{1}{e^{q}+(n-1)}

である。qが十分大きければ]、 \sigma_1\left ( \mathbf{x} \right) は1に近づき、 \sigma_{i\neq1}\left ( \mathbf{x} \right)は0に近づく。

このように入力データがどれも同じ程度の場合はSoftmax関数の値はいずれも \frac[1[n]]に近づく。一方、突出して大きい値があればそれに対応するSoftmax関数の値は1に近づき、それ以外は0に近づくという性質を持つ。

以上の性質から、今回やっているようなネットワークからの出力に対してのSoftmax関数を考えたときは、ネットワークが入力画像の分類を明確にできているか否かが分かる。いわば、Softmax関数を評価することでどの程度自信をもって入力画像の分類を行ったかが分かるのである。

2次元での例

より具体的にSoftmax関数の性質をまとめておく目的で、2次元のベクトルに対してのSoftmax関数を見ておく。

 \mathbf{x} =(p,q)に対して、Softmax関数は

 \sigma_1\left ( (p,q) \right) = \frac{e^{p}}{e^{p}+e^{q}} = \frac{1}{1+e^{q-p}} \sigma_2\left ( (p,q) \right) = \frac{e^{q}}{e^{p}+e^{q}} = \frac{1}{1+e^{p-q}}


である。よって上で議論した通り、p = qの時は、

  \sigma_1\left ( (p,q) \right) = \sigma_2\left ( (p,q) \right) = \frac{1}{2}

である。

ここでe^{x}は単調増加関数であるので、p>qが成立すれば、

  \sigma_1\left ( (p,q) \right) >  \sigma_2\left ( (p,q) \right)

が成立することがすぐに分かる。さらに、p>>qならば上で議論した通り、

  \sigma_1\left ( (p,q) \right) \sim 1 \ ,\  \sigma_2\left ( (p,q)  \right)\sim 0

であることが見て取れる。

実装

Softmax関数の評価を実装する。

from scipy.special imp
net = Net()
net.load_state_dict(torch.load(PATH))

prediction = []
sm_list = []
label_list = []

with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        
        label_list.append(labels.numpy())
        
        sm = softmax(outputs.numpy(),axis=1)
        sm_list.append(sm)
        
    sm_array = np.concatenate(sm_list)
    l_array = np.concatenate(label_list)

ここで、具体的に何をやっているかを見るためにループの最後におけるいくつかの変数を出力する。

print(sm.shape)
# (4, 10) が出力される

print(sm_array.shape)
# (10000, 10) が出力される

まず上の出力を理解する。ネットワークからの出力であるoutputs4*10の次元を持っている。これは、ミニバッチに含まれる4つの画像に対してそれぞれ10個の出力があるからである。

さて、softmax(array)は、arrayに対してのソフトマックス関数を適用する(レファレンス)。

ソフトマックス関数の重要なパラメータはaxisである。第一引数のarrayは一般に多次元配列でも許される。この際、どこで規格化するかを指定するのがこのパラメーターである。例えば、上ではaxis = 1にしているがこれはaxis = 1に沿って和をとったとき、1になるように規格化するという意味である。

つまり今回のケースだと、第一引数にはoutputs.numpy()であるが、この型は4 * 10であった。axis = 1というのは10の方を意味する。つまりこの10個の量に対して足したら1になるように規格化を行うことになる。実際、下でそれが確認できる。

sm[0].sum()
#1.0が出力される

このままだと、ソフトマックススコアがミニバッチに含まれる4つのデータごとに出力されている。しかし最終的な分析においてはミニバッチごとではなくデータ全体で分析したいので、これらを一つにまとめる。実際この場合では、sm4*10の構造を持っていて、2500回ループする。この各ループでsmsm_listに追加していくので、sm_list2500*4*10の構造をもつ。

np.concatenate(sm_list)はこのリストの平滑化を行う(レファレンス)。デフォルトではaxis = 0であるので、0軸にそって結合を行うことになる。

今回の場合は2500*4*10->10000*10のように構造を変える。

可視化

基本的には、クラスの数で円を等分して、各動径方形に単位ベクトルを用意する。これをSoftmax関数を係数として足したベクトルの位置に点を打っていくことで、Softmaxスコアの分布をみることができる。

この考えの下で以下のように実装した。

num_class = len(sm_array[0])#クラスの数(ここでは10)

#各クラスに対応する単位ベクトルの準備
unit_x = np.array([ np.cos(2 * np.pi * i / num_class) for i in range(num_class)])
unit_y = np.array([ np.sin(2 * np.pi * i / num_class) for i in range(num_class)])


##############図の成型部分##################
#出力する図の大きさと背景の色の設定
fig,axis=plt.subplots(figsize=(8,8),facecolor='w')

#10角形の表示
plt.plot(np.append(unit_x,unit_x[0]),np.append(unit_y,unit_y[0]))

#グラフの軸を省略する
axis.tick_params(axis='both',which='both',bottom=False,top=False,left=False,right=False,labelleft=False,labelbottom=False)

#マージンの設定
plt.xlim(-1.3,1.3)
plt.ylim(-1.3,1.3)
##########################################

#メインの出力
for label in range(num_class):
    idx = np.where(l_array == label)
    sm_class = sm_array[idx]
    
    list_x = []
    list_y = []
    
    for sm in sm_class:
        list_x.append(np.dot(unit_x , sm))
        list_y.append(np.dot(unit_y , sm))
    
    plt.scatter(list_x,list_y,alpha=0.3)
    
#クラス名の表示
#先頭のクラスが(0,0)の位置に来るので、ここで文字を-90度回転させる。そこから順番に360/10度ずつ足していけば円に沿って字が出力される。
for d,name in enumerate(classes):
        plt.text(unit_x[d]*1.1, unit_y[d]*1.1,str(name),fontsize=18,ha='center',va='center',rotation= -90 + 360/num_class*d)#-90 + 2 * np.pi/num_class*d/np.pi*180

plt.show()

以上を実行すると、次のような図が得られた。

f:id:Tatsupuri:20201128145707p:plain

考察

上で得られた結果に対して考察をする。

まずSoftmax関数から、

  • きちんと識別できているデータは上の図では10角形の頂点に近づく
  • 2つのクラスの間で判断がついていない時は、その2頂点を結ぶ直線状にのる

などの性質が分かるので、それを踏まえて上の図を分析する。

生物/無生物

まず、次の図のようにデータが集積している領域は大まかに2つに分けられることが見て取れる。

f:id:Tatsupuri:20201128150731p:plain

各頂点に対応するラベルを参考にすると、この領域はそれぞれ、生物(鳥、猫、鹿、犬、カエル、馬)と無生物(船、トラック、飛行機、車)の領域だと分かる。

つまり、学習によって作成した分類器は少なくとも生物か無生物か?程度の区別は比較的ついていることが理解できる(無論そうではない原点付近のデータもある)。

識別しずらいデータ?

また、SoftMaxの分布から以下のような特徴的な集積点(線)も見て取れる(ここで図示した限りではない)。

f:id:Tatsupuri:20201128151952p:plain

これによって例えば、このネットワークは「車」と「トラック」の識別があまりうまくいっていない、などと読み取ることができそうである。ほかにも「馬」と「鹿」など、言われてみれば確かに識別がうまくいっていないといわれても個人的に納得できるものであったので、このSoftmaxの分布から今回の学習の結果はそれなりに妥当であると考えられる。

まとめ

Softmax関数を評価して図示することで学習によって得られたネットワークの性質を考察することができた。 ただ、この辺の分析について筆者はまだ確信をもって議論できるほど勉強が進んでいないので今後改めて考察していくつもりである。

また、例えば馬と鹿のように識別の難しいクラスがいくつかあったが、これらをより精度よく分類する方法について今後考えたい。

分類器で機械学習入門(その3)- 汎化誤差

こんにちは。たつぷりです。前回に引き続きCIFAR10の分類器についてのノートです。このシリーズは、pytorch公式のチュートリアルの解説とそれをより理解するために自分でやったことなどをまとめたものです。

Training a Classifier — PyTorch Tutorials 1.7.0 documentation

目標

前回までで、ニューラルネットワークのモデルを構成して、そのネットワークの学習を行った。ここではこのネットワークの妥当性を調べる。つまり学習によって分類器として振る舞う期待通りのネットワークが得られているかをチェックする。具体的には以下を調べる。

汎化誤差

まずは汎化誤差を調べる。前回の記事で述べたように、学習は誤差関数を小さくすることを指針にして行われた。誤差関数が小さいというのは、そのネットワークが入力した教師データを良く分類できていることを示している。

つまり、学習済みのネットワークは少なくとも教師データの分類についてできる(それを指針にしているので)。しかし重要なのは、まだ見ぬ未知のデータに対して正しく分類を行うことができるか?である。

まだ見ぬ未知のデータに対しての誤差を汎化誤差という。以下では具体的にこれを調べる。

テストデータの準備

まずはテスト用のデータを準備する。この手順については過去の記事で既に解説している。今回用いているのは一貫してCIFAR10のデータであるが、これはtorchvisionを使うことで簡単に用意できるのであった。今導入するのはテストデータなので、オプションでtrain = Falseを指定している。

これによって、各クラス1000枚ずつの10000枚の画像データにアクセスできる。

testset = torchvision.datasets.CIFAR10(root='./data', train=False,
                                       download=True, transform=transform)
testloader = torch.utils.data.DataLoader(testset, batch_size=4,
                                         shuffle=False, num_workers=0)

具体的にデータにアクセスするにはtestloaderイテレーターに変換して各要素を読み出せばよい。

dataiter = iter(testloader)
images, labels = dataiter.next()

ネットワークからの予言値

前回保存したウェイトをロードして、学習済みのネットワークを呼び出す。

net = Net()
net.load_state_dict(torch.load(PATH))

このネットワークの出力を確認する。net()は前回述べたように、4*3*32*32の構造を持つデータの入力に対して4*10の構造を持つ量を返す。

outputs = net(images)
print(outputs.size())

#torch.Size([4, 10])が出力される

outputの第0要素を見てみると、例えば以下のようになっていた。

[-0.4645, -1.3041,  0.8090,  1.2276,  0.1213,  0.1624,  1.4485, -1.2337,0.2869, -0.6776]

この10個の列はそれぞれ各クラスに対応していており、最も大きい数字に対応するクラスである可能性が最も高いことを表している。

CIFAR10では以下の10個のクラスからなる。

('plane', 'car', 'bird', 'cat', 'deer', 'dog', 'frog', 'horse', 'ship', 'truck')

なので上の例ではネットワークは「frog」の可能性が最も高いと判断した、ということになる。ここで、このデータの本当のラベルをチェックする。

print(labels[0])
# 3 を出力する

このように実際のラベルは「cat」であったのでネットワークはこのデータに対しては、誤った予言を与えたことになる。

ちなみに、1番目のデータに対しては、以下の予言が得られ、

 [ 4.7203,  4.9927, -1.9707, -3.1006, -2.8680, -4.2864, -3.7756, -4.7392,7.6236,  2.3189]

実際のラベルが8であったので、ネットワークは正しい予言を与えたことになる。すなわちこの画像が「船」であると判断できた。

実際にこのミニバッチの画像を出力すると以下のようになっている。 f:id:Tatsupuri:20201128092944p:plain

確かに0番目のデータに関してカエルを見間違えたといわれると、個人的には少し納得した。ネットワークからの出力は、猫とカエルに対してそれぞれ1.22761.4485と予言を与えており、"迷っている"様子がうかがえる。

正答率

上では、個別のデータに対して予言を見たが今度はまとめて全データに対しての正答率を評価する。基本的には上とやることは同じで、テストデータはdataiterというイテレータになっているので、forループで評価すれば良い。

ただしイテレータの各要素はミニバッチごとに分けられているので、今回のケースでは各ステップで4つのデータを渡す。そのため各ステップでこの4つに対してそれぞれ上と同様のチェックをすることになる。

具体的には以下のように実装できる。

correct = 0
total = 0
with torch.no_grad():
    for data in testloader:
        images, labels = data
        outputs = net(images)
        _, predicted = torch.max(outputs.data, 1)
        total += labels.size(0)
        correct += (predicted == labels).sum().item()

print('Accuracy of the network on the 10000 test images: %d %%' % (100 * correct / total))

この出力は、

Accuracy of the network on the 10000 test images: 60 %

であり、このネットワークによる分類の精度は60%ということになる。完全にランダムに回答するとしたら10%の回答率になるはずなので、このネットワークは何かしら”学習”したことが結論できる

以下でこのコードの解説を行う。基本的にはcorrectに正しく予言を与えたデータの数、totalに全データ数を格納する。これらの量からネットワークのテストデータの対しての正答率を評価することができる。

with torch.no_grad():

with torch.no_grad():に関しては以下に詳しい解説があるので引用する。

Autograd: Automatic Differentiation — PyTorch Tutorials 1.7.0 documentation

To prevent tracking history (and using memory), you can also wrap the code block in with torch.no_grad():. This can be particularly helpful when evaluating a model because the model may have trainable parameters with requires_grad=True, but for which we don’t need the gradients.

つまり、モデルを評価するときにはgradientを計算する必要はないのでこのフラグを切っておくということである。

torch.max(outputs.data, 1)

これを理解するために以下を試す。なお、torch.maxのレファレンスはここ

outputs = net(images) #imageは上のコードを実行した後だと最後のミニバッチの画像
print(torch.max(outputs.data,1))

#出力
# torch.return_types.max(
# values=tensor([4.5723, 5.3089, 3.0941, 8.3328]),
# indices=tensor([3, 5, 4, 7]))

このように、torch.max(outputs.data, 1)の第2引数である1は、第1軸の中で最大値を探すというオプションである。つまり4つの画像データそれぞれに対して、10個の出力の中で最大のものを返す。 実際の出力を見たらわかるようにtorch.maxは二つの返り値を持ち、一つ目に最大値、二つ目にその最大値を与えたのは何番目の要素であるかを返す。

つまり、我々が今欲しい予言値はまさにtorch.maxの2つめの返り値に他ならない

ちなみに、outputoutput.dataの違いは、grad_fn=<AddmmBackward>のようなフラグを出力に含めるかどうかの差である。.dataをつけるとこのようなフラグは出力されない。

(predicted == labels).sum().item()

これを理解するために、以下を試す。

a = torch.Tensor([1,2,3,4])
b = torch.Tensor([1,2,3,0])

print((a == b))

# 出力
# tensor([ True,  True,  True, False])

このように二つのテンソル型の量に対し論理演算を行うと、各要素に対して論理演算を行い、そのBool型の結果を要素にもつテンソル型の量を返す。

よって、(a == b).sum()は、True=1,False=0として各要素の和をとるので、tensor(3)を返す。.item()メソッドを使うことで数値として要素を得ることができる。つまり(a == b).sum().item()3を返し、これはa,bの各要素のうち一致している要素の数を返す。

まとめると、(predicted == labels).sum().item()はミニバッチの中で予言が実際のラベルと一致していたデータの個数を返す

過学習

過学習は、訓練誤差が小さくなっているにも関わらず汎化誤差が大きくなってしまう現象である。つまり学習の結果、教師データにのみ特化したネットワークになってしまった状態である。

過学習が起こっていないことをチェックするには以下のような方法がある。学習の途中のステップでもウェイトを保存しておき、それらウェイトに対して汎化誤差の評価を行う。これによって汎化誤差が学習が進むにつれてどのように発展していくかを追うことができるようになる。

これでもし損失関数が減少しているにも関わらず汎化誤差が小さくなっていくことが確認できれば過学習が起きていると判断することができる。

今回のモデルで、エポック数を10にして2000ステップごとにウェイトを保存し、それらに対して汎化誤差の評価を行ったところ以下のような結果が得られた。

f:id:Tatsupuri:20201128111242p:plain

このように汎化誤差は傾向として単調に増加しており、この範囲においては過学習は起こっていなないといえるだろう。

まとめ

今回は学習によって得られたネットワークが、分類器としてそれなりに良く振る舞っていることが確認できた。このように機械学習においては学習によって得られたネットワークがまだ見ぬ未知のデータに対しても機能するかを調べることが非常に大事である。

次回はこのネットワークからの出力を分析することで、このネットワークの特徴や傾向を詳しく見ていくことにする。

分類器で機械学習入門(その2)- 学習

こんにちは。たつぷりです。引き続き、CIFAR10の分類問題についてまとめます。前回はデータセットの準備とその可視化を行いました。今回は実際にネットワークを構成し、学習によりネットワークを最適化することをおこないます。

今回も自分用のノートをブログにのせているので実験的に進んでいたりと、スマートではない進み方をしますのでご了承下さい。

目標

この記事での目標は以下である。

  • 畳み込みニューラルネットワークのモデルを構成する
  • 学習によりネットワークを最適化する
  • 学習したネットワークのウェイトを保存する

また、今回もpytorch公式のチュートリアル

Training a Classifier — PyTorch Tutorials 1.7.0 documentation

に沿って進んでいく。

モデルの定義とその構造

学習させるモデルを作成する、すなわち学習するネットワークを構成する。具体的にpytorchでは、torch.nnクラスを継承するネットワーククラスを定義することで可能である。ここではNetクラスを作る。

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

以下ではこのネットワークの構造について見ていく。ネットワークの定義の仕方に関して個人的により見やすいと思う方法を補足に書いた。

構造

上の書き方ではコンストラクタでネットワーク内で用いる操作を定義しており、実際のネットワークの構成はforwardに定義されている。

上から順に見てくと、

入力

畳み込み
↓Relu
マックスプ-リング

畳み込み
↓Relu
マックスプーリング
↓viewで成型
線形変換

線形変換

線形変換

出力

という流れになっている。

入力と出力

まずネットワークの入力は、教師データのミニバッチである。ここではCIFAR10の画像をバッチ数4でデータセットを作ったので、4*3*32*32の構造を持ったテンソル型のデータを入力として受け取る。

また最後の線形変換を見ると、10次元のベクトルが出力になっていることが読み取れる。つまりこのネットワークは真ん中をいったん無視すれば、

4*3*32*32の構造を持つをテンソル量を入力して10次元のベクトルを出力する"関数"として機能している。入力はもちろん各ミニバッチの教師データで、10個の出力はCIFAR10の10個のクラスに対応している。

最終的にこのネットワークを”学習”することで

"画像データ"を入力すると”クラス”を返す”関数”として機能するようになるのである

学習

上で定義したネットワークの学習を行う。上述のとおりネットワークは画像データを入力すると10次元のベクトルを返すので、学習とは教師データに対して正しいラベルに対応する出力が得られるようにネットワークのウェイトを最適化することに他ならない。

この記事の範囲では、大まかな流れを追うことにとどめる。長くなってしまうので畳み込みの操作や、クロスエントロピーなどの各種の詳細については別記事で補足することにする。

学習ループ

具体的に、最適化の指針となるのがLoss functionである。学習の各ステップでLoss functionをより小さくするように、ウェイトを更新していく。この時ウェイトを更新するアルゴリズムは複数あり、学習の実装ではOptimizer(torch.optim)にそのアルゴリズムを指定する。

import torch.optim as optim

#Loss Functionを定義
criterion = nn.CrossEntropyLoss()

#Optimizerを定義
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

#学習ループ(今回はエポック数2)
for epoch in range(2): 

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
         
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        # 学習経過を表示する出力部
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

このループ内で、inputsがネットワークに入力するデータで、4*3*32*32の構造を持っている。また、outputs = net(inputs)でネットワークから出力を得ており、これが10個の要素からなる。

以下で、ネットワークからの出力にLoss functionを評価する。ここではクロスエントロピーを用いているが、基本的には実際のラベルが得られる確率を評価しているようなものである(もとろんそのままの意味ではないので注意、詳細は別記事で補足)。

loss = criterion(outputs, labels)

学習はこのloss functionを最小にことを指導原理にして行われる。以下の部分で、ウェイトの更新を行う。

loss.backward()
optimizer.step()

以上が学習ループの実装の基本的な流れになっている。

running_loss += loss.item()について解説する。loss.item()が各ステップでのLoss functionの値である。これを足し上げている量がrunning_loss。結局、出力の際running_loss / 2000を出力し、その後初期化しているのが読み取れる。結局、これは2000ステップ(2000ミニバッチ)ごとのロスファンクションの平均値を与える。

ここで一つ注意を書いておく。pytorchでは、逆伝搬を行う前、すなわちloss.backward()を呼ぶ前にoptimizer.zero_grad()で初期化しなければならない。これはpytorchがデフォルトでloss.backward()は勾配を足し上げていくからである。これはRNN(回帰ニューラルネットワーク)を用いる時便利らしい。 これは以下の記事を参考にした。

python - Why do we need to call zero_grad() in PyTorch? - Stack Overflow

ウェイトの保存

以上でネットワークの学習が完了したので、今後この結果を用いることができるようにこの学習後のネットワークのウェイトを保存する。これにはtorch.save()関数を用いれば良い。以下のように保存するネットワークと保存先のパスを指定して使う。

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

実際に、以下の図のように指定したパスに保存される。

f:id:Tatsupuri:20201126214050p:plain

まとめ

今回はネットワークを構成して学習によって、ウェイトを最適化した。またその学習結果を保存した。次回はこの結果の解析を行っていく。

今回は触れなかったが、学習ループ内でtorch.save()を呼び出すこともできる。そうすると、学習の途中のパラメータも保存しておけるので学習されていく過程を見たりすることもできる。これは次回、過学習について考察する際に使う。

補足

モデルの構成に関して、上で紹介したもの(チュートリアルで導入されていたもの)と等価でよりシンプルな書き方に、nn.Sequentialを用いた以下のような書き方もある。個人的にはこちらの方がネットワークの構造が良く見える気がするのでこちらの書き方を積極的に使っている。

class Net(nn.Module):
    def __init__(self):
            super(Net, self).__init__()
            self.main = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.ReLU(True),
            nn.Linear(16 * 5 * 5, 120),
            nn.Linear(120, 84)
            nn.Linear(84, 10)
        )

    def forward(self, input):
        return self.main(input)