分類器で機械学習入門(その4)- 結果の解釈
こんにちは。たつぷりです。前回に引き続きCIFAR10の分類器についてのノートです。このシリーズは、pytorch公式のチュートリアルの解説とそれをより理解するために自分でやったことなどをまとめたものです。
前回までで「データの準備」、「学習」、「汎化誤差の評価」などを見てきた。今回は、学習済みのネットワークから何が言えるかを考察したいと思います。
目標
今回の目標は、ネットワークの出力に対してSoftmaxスコアを評価し、それを可視化して傾向を分析することである。そこから、今回のモデルの弱点などを考察していく。
Softmaxスコアの評価
ここでは、ネットワークからの出力に対する、Softmax関数を計算しその可視化を行う。
Softmax関数
まず、Softmax関数の定義とその性質を簡単にまとめておく。
まずSoftmax関数の定義は以下の通りである。に対してソフトマックス関数は、以下で与えられる。
ただし、
である。
これは、私のように物理をやっている人間からすると明らかであるが、これはボルツマンウェイトである。なのでわざわざ、分母をという物理屋にはなじみ深い分配関数の形で書いておいた。
次にSoftmax関数の性質をまとめる。Softmax関数からは引数のデータが局在しているか?などの情報を得ることができる(この表現が正しいか微妙な気もするが...)。
例えば、任意の に対してであったとすると、
になる。その一方で、であるような場合は
である。が十分大きければ]、 は1に近づき、は0に近づく。
このように入力データがどれも同じ程度の場合はSoftmax関数の値はいずれも[n]]に近づく。一方、突出して大きい値があればそれに対応するSoftmax関数の値は1に近づき、それ以外は0に近づくという性質を持つ。
以上の性質から、今回やっているようなネットワークからの出力に対してのSoftmax関数を考えたときは、ネットワークが入力画像の分類を明確にできているか否かが分かる。いわば、Softmax関数を評価することでどの程度自信をもって入力画像の分類を行ったかが分かるのである。
2次元での例
より具体的にSoftmax関数の性質をまとめておく目的で、2次元のベクトルに対してのSoftmax関数を見ておく。
に対して、Softmax関数は
、
である。よって上で議論した通り、の時は、
である。
ここでは単調増加関数であるので、が成立すれば、
が成立することがすぐに分かる。さらに、ならば上で議論した通り、
であることが見て取れる。
実装
Softmax関数の評価を実装する。
from scipy.special imp net = Net() net.load_state_dict(torch.load(PATH)) prediction = [] sm_list = [] label_list = [] with torch.no_grad(): for data in testloader: images, labels = data outputs = net(images) _, predicted = torch.max(outputs.data, 1) label_list.append(labels.numpy()) sm = softmax(outputs.numpy(),axis=1) sm_list.append(sm) sm_array = np.concatenate(sm_list) l_array = np.concatenate(label_list)
ここで、具体的に何をやっているかを見るためにループの最後におけるいくつかの変数を出力する。
print(sm.shape) # (4, 10) が出力される print(sm_array.shape) # (10000, 10) が出力される
まず上の出力を理解する。ネットワークからの出力であるoutputs
は4*10
の次元を持っている。これは、ミニバッチに含まれる4つの画像に対してそれぞれ10個の出力があるからである。
さて、softmax(array)
は、array
に対してのソフトマックス関数を適用する(レファレンス)。
ソフトマックス関数の重要なパラメータはaxis
である。第一引数のarray
は一般に多次元配列でも許される。この際、どこで規格化するかを指定するのがこのパラメーターである。例えば、上ではaxis = 1
にしているがこれはaxis = 1
に沿って和をとったとき、1になるように規格化するという意味である。
つまり今回のケースだと、第一引数にはoutputs.numpy()
であるが、この型は4 * 10
であった。axis = 1
というのは10
の方を意味する。つまりこの10個の量に対して足したら1になるように規格化を行うことになる。実際、下でそれが確認できる。
sm[0].sum() #1.0が出力される
このままだと、ソフトマックススコアがミニバッチに含まれる4つのデータごとに出力されている。しかし最終的な分析においてはミニバッチごとではなくデータ全体で分析したいので、これらを一つにまとめる。実際この場合では、sm
は4*10
の構造を持っていて、2500回ループする。この各ループでsm
をsm_list
に追加していくので、sm_list
は2500*4*10
の構造をもつ。
np.concatenate(sm_list)
はこのリストの平滑化を行う(レファレンス)。デフォルトではaxis = 0
であるので、0軸にそって結合を行うことになる。
今回の場合は2500*4*10->10000*10
のように構造を変える。
可視化
基本的には、クラスの数で円を等分して、各動径方形に単位ベクトルを用意する。これをSoftmax関数を係数として足したベクトルの位置に点を打っていくことで、Softmaxスコアの分布をみることができる。
この考えの下で以下のように実装した。
num_class = len(sm_array[0])#クラスの数(ここでは10) #各クラスに対応する単位ベクトルの準備 unit_x = np.array([ np.cos(2 * np.pi * i / num_class) for i in range(num_class)]) unit_y = np.array([ np.sin(2 * np.pi * i / num_class) for i in range(num_class)]) ##############図の成型部分################## #出力する図の大きさと背景の色の設定 fig,axis=plt.subplots(figsize=(8,8),facecolor='w') #10角形の表示 plt.plot(np.append(unit_x,unit_x[0]),np.append(unit_y,unit_y[0])) #グラフの軸を省略する axis.tick_params(axis='both',which='both',bottom=False,top=False,left=False,right=False,labelleft=False,labelbottom=False) #マージンの設定 plt.xlim(-1.3,1.3) plt.ylim(-1.3,1.3) ########################################## #メインの出力 for label in range(num_class): idx = np.where(l_array == label) sm_class = sm_array[idx] list_x = [] list_y = [] for sm in sm_class: list_x.append(np.dot(unit_x , sm)) list_y.append(np.dot(unit_y , sm)) plt.scatter(list_x,list_y,alpha=0.3) #クラス名の表示 #先頭のクラスが(0,0)の位置に来るので、ここで文字を-90度回転させる。そこから順番に360/10度ずつ足していけば円に沿って字が出力される。 for d,name in enumerate(classes): plt.text(unit_x[d]*1.1, unit_y[d]*1.1,str(name),fontsize=18,ha='center',va='center',rotation= -90 + 360/num_class*d)#-90 + 2 * np.pi/num_class*d/np.pi*180 plt.show()
以上を実行すると、次のような図が得られた。
考察
上で得られた結果に対して考察をする。
まずSoftmax関数から、
- きちんと識別できているデータは上の図では10角形の頂点に近づく
- 2つのクラスの間で判断がついていない時は、その2頂点を結ぶ直線状にのる
などの性質が分かるので、それを踏まえて上の図を分析する。
生物/無生物
まず、次の図のようにデータが集積している領域は大まかに2つに分けられることが見て取れる。
各頂点に対応するラベルを参考にすると、この領域はそれぞれ、生物(鳥、猫、鹿、犬、カエル、馬)と無生物(船、トラック、飛行機、車)の領域だと分かる。
つまり、学習によって作成した分類器は少なくとも生物か無生物か?程度の区別は比較的ついていることが理解できる(無論そうではない原点付近のデータもある)。
識別しずらいデータ?
また、SoftMaxの分布から以下のような特徴的な集積点(線)も見て取れる(ここで図示した限りではない)。
これによって例えば、このネットワークは「車」と「トラック」の識別があまりうまくいっていない、などと読み取ることができそうである。ほかにも「馬」と「鹿」など、言われてみれば確かに識別がうまくいっていないといわれても個人的に納得できるものであったので、このSoftmaxの分布から今回の学習の結果はそれなりに妥当であると考えられる。
まとめ
Softmax関数を評価して図示することで学習によって得られたネットワークの性質を考察することができた。 ただ、この辺の分析について筆者はまだ確信をもって議論できるほど勉強が進んでいないので今後改めて考察していくつもりである。
また、例えば馬と鹿のように識別の難しいクラスがいくつかあったが、これらをより精度よく分類する方法について今後考えたい。