あきらぼ

テック系ブログ

Tensorflowの画像分類器の学習設定で損失が減らずハマった話

こんにちは。

現在、VTuberアイドルグループであるHololiveの画像分類器を作成しているのですが、CNNのモデル設定でハマったのでそのことを書こうと思います。

ちなみにHololiveはこんなグループです。
所属タレント一覧 | hololive(ホロライブ)公式サイト




前提条件

まず、前提条件としては学習データとしてはホロライブメンバーのうち9人分の画像データをGoogle画像検索で集めました。
以下の方法です。
aki-lab.hatenadiary.com

そこで一人当たり400-500枚、計4175枚の画像になります。
これらのデータの内8割を学習データ、2割を検証データとします。

分類モデルとしては有名なvgg16を転移学習させる形でモデルを組みました。
最後に全結合層で512個のニューロンを追加しています。

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 rescaling (Rescaling)       (None, 256, 256, 3)       0         
                                                                 
 random_rotation (RandomRota  (None, 256, 256, 3)      0         
 tion)                                                           
                                                                 
 random_zoom (RandomZoom)    (None, 256, 256, 3)       0         
                                                                 
 random_flip (RandomFlip)    (None, 256, 256, 3)       0         
                                                                 
 random_translation (RandomT  (None, 256, 256, 3)      0         
 ranslation)                                                     
                                                                 
 vgg16 (Functional)          (None, None, None, 512)   14714688  
                                                                 
 flatten (Flatten)           (None, 32768)             0         
                                                                 
 dense (Dense)               (None, 512)               16777728  
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 9)                 4617      
                                                                 
=================================================================
Total params: 31,497,033
Trainable params: 23,861,769
Non-trainable params: 7,635,264
_________________________________________________________________

VGG16のパラメータの内、下流4層だけ学習可能にしています。
このうちrandom_rotation、random_zoom、random_flip、 random_translationはデータ強化層
dropoutは過学習対策の層となります。

今回は学習データが少ないので、データ強化層を4層入れています。
データ強化方法は、データ強化層を追加する方法以外に、そもそも元の学習データを事前に水増しする方法があります。(openCVやPillowを使って)

学習

ここから実際に学習をさせていきます。
パラメータ最適化方法としてはAdamを使用しました。

そして、バッチサイズや学習率を変化させて学習するようにします。
実際には全結合層の数や最適化手法も変更して色々試したのですが、なかなか教師データに対する損失(Loss)が減らずに悩んでいました。
結論から言うと、過学習対策を入れたままモデル検証していたからでした。

一例として過学習対策有りと無しで学習率の変化の影響がどう出るのか例を載せます。

過学習対策あり

以下は過学習対策ありのまま学習率を変化させた場合です。

確かに学習率を小さくして改善しているようですが、精度が低いままです。
これは過学習対策が教師データを各エピックで変更してしまっているために教師データに対しての精度も損失もあまり改善していないように見えてしまいます。

この状態だとモデルの表現力が分類に必要なだけあるかどうかも判断がつきづらいです。
エポック数を増やせば判断は付きますが、時間がかかってしまうのでモデル選定のトライ&エラーが大変になってしまします。

過学習対策なし

次に同じモデルで過学習対策を抜いた場合を見てましょう。

実際に過学習対策を抜いたモデルがこちら。

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 rescaling (Rescaling)       (None, 256, 256, 3)       0         
                                                                 
 vgg16 (Functional)          (None, None, None, 512)   14714688  
                                                                 
 flatten (Flatten)           (None, 32768)             0         
                                                                 
 dense (Dense)               (None, 512)               16777728  
                                                                 
 dense_1 (Dense)             (None, 9)                 4617      
                                                                 
=================================================================
Total params: 31,497,033
Trainable params: 23,861,769
Non-trainable params: 7,635,264
_________________________________________________________________

このモデルで同様に学習率違いを見てみましょう。

こちらの場合は学習率を小さくすることで学習が大きく改善され、損失が落ちて行っていることが一目で分かります。
もちろん、過学習対策を入れていないので、検証用データに対する損失と精度は悪いままですが、過学習対策を入れれば改善されます。

実際にこの学習率で過学習対策を入れてエポック数を増やせばかなり精度が上がります。

結論

今回の結果から過学習対策はモデル選定の段階では入れない方が良いということが分かりました。
過学習対策なしの状態で教師データに対して十分最適化され、精度が上がる状態で過学習対策を順次実装していく必要があります。

順番でいうとこんな感じ。

  1. 過学習対策なしでモデル選定・最適化パラメータ調整
  2. 教師データに十分フィッティングできる(低損失高精度)モデルであることを確認
  3. 過学習対策を追加

実際に過学習対策なしでフィッティングが進んだモデルでエポック数を増やしていくと以下のように精度が上がりました。

現状の分類器はこちら。

github.com

このような分類器作成の参考になる本がこちら。Keras、Tensorflowによるモデル作り方が分かりやすく解説されています。

引き続き分類器の作成を続けていこうと思います。