たつぷりの調査報告書

博士後期課程(理学)の学生が趣味でUnityやBlenderで遊ぶブログです。素人が独学で勉強した際の忘備録です。

分類器で機械学習入門(その2)- 学習

こんにちは。たつぷりです。引き続き、CIFAR10の分類問題についてまとめます。前回はデータセットの準備とその可視化を行いました。今回は実際にネットワークを構成し、学習によりネットワークを最適化することをおこないます。

今回も自分用のノートをブログにのせているので実験的に進んでいたりと、スマートではない進み方をしますのでご了承下さい。

目標

この記事での目標は以下である。

  • 畳み込みニューラルネットワークのモデルを構成する
  • 学習によりネットワークを最適化する
  • 学習したネットワークのウェイトを保存する

また、今回もpytorch公式のチュートリアル

Training a Classifier — PyTorch Tutorials 1.7.0 documentation

に沿って進んでいく。

モデルの定義とその構造

学習させるモデルを作成する、すなわち学習するネットワークを構成する。具体的にpytorchでは、torch.nnクラスを継承するネットワーククラスを定義することで可能である。ここではNetクラスを作る。

import torch.nn as nn
import torch.nn.functional as F

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.fc1 = nn.Linear(16 * 5 * 5, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = self.pool(F.relu(self.conv2(x)))
        x = x.view(-1, 16 * 5 * 5)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x

以下ではこのネットワークの構造について見ていく。ネットワークの定義の仕方に関して個人的により見やすいと思う方法を補足に書いた。

構造

上の書き方ではコンストラクタでネットワーク内で用いる操作を定義しており、実際のネットワークの構成はforwardに定義されている。

上から順に見てくと、

入力

畳み込み
↓Relu
マックスプ-リング

畳み込み
↓Relu
マックスプーリング
↓viewで成型
線形変換

線形変換

線形変換

出力

という流れになっている。

入力と出力

まずネットワークの入力は、教師データのミニバッチである。ここではCIFAR10の画像をバッチ数4でデータセットを作ったので、4*3*32*32の構造を持ったテンソル型のデータを入力として受け取る。

また最後の線形変換を見ると、10次元のベクトルが出力になっていることが読み取れる。つまりこのネットワークは真ん中をいったん無視すれば、

4*3*32*32の構造を持つをテンソル量を入力して10次元のベクトルを出力する"関数"として機能している。入力はもちろん各ミニバッチの教師データで、10個の出力はCIFAR10の10個のクラスに対応している。

最終的にこのネットワークを”学習”することで

"画像データ"を入力すると”クラス”を返す”関数”として機能するようになるのである

学習

上で定義したネットワークの学習を行う。上述のとおりネットワークは画像データを入力すると10次元のベクトルを返すので、学習とは教師データに対して正しいラベルに対応する出力が得られるようにネットワークのウェイトを最適化することに他ならない。

この記事の範囲では、大まかな流れを追うことにとどめる。長くなってしまうので畳み込みの操作や、クロスエントロピーなどの各種の詳細については別記事で補足することにする。

学習ループ

具体的に、最適化の指針となるのがLoss functionである。学習の各ステップでLoss functionをより小さくするように、ウェイトを更新していく。この時ウェイトを更新するアルゴリズムは複数あり、学習の実装ではOptimizer(torch.optim)にそのアルゴリズムを指定する。

import torch.optim as optim

#Loss Functionを定義
criterion = nn.CrossEntropyLoss()

#Optimizerを定義
optimizer = optim.SGD(net.parameters(), lr=0.001, momentum=0.9)

#学習ループ(今回はエポック数2)
for epoch in range(2): 

    running_loss = 0.0
    for i, data in enumerate(trainloader, 0):
        inputs, labels = data
         
        optimizer.zero_grad()

        outputs = net(inputs)
        loss = criterion(outputs, labels)
        loss.backward()
        optimizer.step()
        
        running_loss += loss.item()
        
        # 学習経過を表示する出力部
        if i % 2000 == 1999:
            print('[%d, %5d] loss: %.3f' %
                 (epoch + 1, i + 1, running_loss / 2000))
            running_loss = 0.0

print('Finished Training')

このループ内で、inputsがネットワークに入力するデータで、4*3*32*32の構造を持っている。また、outputs = net(inputs)でネットワークから出力を得ており、これが10個の要素からなる。

以下で、ネットワークからの出力にLoss functionを評価する。ここではクロスエントロピーを用いているが、基本的には実際のラベルが得られる確率を評価しているようなものである(もとろんそのままの意味ではないので注意、詳細は別記事で補足)。

loss = criterion(outputs, labels)

学習はこのloss functionを最小にことを指導原理にして行われる。以下の部分で、ウェイトの更新を行う。

loss.backward()
optimizer.step()

以上が学習ループの実装の基本的な流れになっている。

running_loss += loss.item()について解説する。loss.item()が各ステップでのLoss functionの値である。これを足し上げている量がrunning_loss。結局、出力の際running_loss / 2000を出力し、その後初期化しているのが読み取れる。結局、これは2000ステップ(2000ミニバッチ)ごとのロスファンクションの平均値を与える。

ここで一つ注意を書いておく。pytorchでは、逆伝搬を行う前、すなわちloss.backward()を呼ぶ前にoptimizer.zero_grad()で初期化しなければならない。これはpytorchがデフォルトでloss.backward()は勾配を足し上げていくからである。これはRNN(回帰ニューラルネットワーク)を用いる時便利らしい。 これは以下の記事を参考にした。

python - Why do we need to call zero_grad() in PyTorch? - Stack Overflow

ウェイトの保存

以上でネットワークの学習が完了したので、今後この結果を用いることができるようにこの学習後のネットワークのウェイトを保存する。これにはtorch.save()関数を用いれば良い。以下のように保存するネットワークと保存先のパスを指定して使う。

PATH = './cifar_net.pth'
torch.save(net.state_dict(), PATH)

実際に、以下の図のように指定したパスに保存される。

f:id:Tatsupuri:20201126214050p:plain

まとめ

今回はネットワークを構成して学習によって、ウェイトを最適化した。またその学習結果を保存した。次回はこの結果の解析を行っていく。

今回は触れなかったが、学習ループ内でtorch.save()を呼び出すこともできる。そうすると、学習の途中のパラメータも保存しておけるので学習されていく過程を見たりすることもできる。これは次回、過学習について考察する際に使う。

補足

モデルの構成に関して、上で紹介したもの(チュートリアルで導入されていたもの)と等価でよりシンプルな書き方に、nn.Sequentialを用いた以下のような書き方もある。個人的にはこちらの方がネットワークの構造が良く見える気がするのでこちらの書き方を積極的に使っている。

class Net(nn.Module):
    def __init__(self):
            super(Net, self).__init__()
            self.main = nn.Sequential(
            nn.Conv2d(3, 6, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(6, 16, 5),
            nn.ReLU(True),
            nn.MaxPool2d(2, 2),
            nn.ReLU(True),
            nn.Linear(16 * 5 * 5, 120),
            nn.Linear(120, 84)
            nn.Linear(84, 10)
        )

    def forward(self, input):
        return self.main(input)