こんにちは。
現在、VTuberアイドルグループであるHololiveの画像分類器を作成しているのですが、CNNのモデル設定でハマったのでそのことを書こうと思います。
ちなみにHololiveはこんなグループです。
所属タレント一覧 | hololive(ホロライブ)公式サイト
前提条件
まず、前提条件としては学習データとしてはホロライブメンバーのうち9人分の画像データをGoogle画像検索で集めました。
以下の方法です。
aki-lab.hatenadiary.com
そこで一人当たり400-500枚、計4175枚の画像になります。
これらのデータの内8割を学習データ、2割を検証データとします。
分類モデルとしては有名なvgg16を転移学習させる形でモデルを組みました。
最後に全結合層で512個のニューロンを追加しています。
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= rescaling (Rescaling) (None, 256, 256, 3) 0 random_rotation (RandomRota (None, 256, 256, 3) 0 tion) random_zoom (RandomZoom) (None, 256, 256, 3) 0 random_flip (RandomFlip) (None, 256, 256, 3) 0 random_translation (RandomT (None, 256, 256, 3) 0 ranslation) vgg16 (Functional) (None, None, None, 512) 14714688 flatten (Flatten) (None, 32768) 0 dense (Dense) (None, 512) 16777728 dropout (Dropout) (None, 512) 0 dense_1 (Dense) (None, 9) 4617 ================================================================= Total params: 31,497,033 Trainable params: 23,861,769 Non-trainable params: 7,635,264 _________________________________________________________________
VGG16のパラメータの内、下流4層だけ学習可能にしています。
このうちrandom_rotation、random_zoom、random_flip、 random_translationはデータ強化層
dropoutは過学習対策の層となります。
今回は学習データが少ないので、データ強化層を4層入れています。
データ強化方法は、データ強化層を追加する方法以外に、そもそも元の学習データを事前に水増しする方法があります。(openCVやPillowを使って)
学習
ここから実際に学習をさせていきます。
パラメータ最適化方法としてはAdamを使用しました。
そして、バッチサイズや学習率を変化させて学習するようにします。
実際には全結合層の数や最適化手法も変更して色々試したのですが、なかなか教師データに対する損失(Loss)が減らずに悩んでいました。
結論から言うと、過学習対策を入れたままモデル検証していたからでした。
一例として過学習対策有りと無しで学習率の変化の影響がどう出るのか例を載せます。
過学習対策あり
以下は過学習対策ありのまま学習率を変化させた場合です。
確かに学習率を小さくして改善しているようですが、精度が低いままです。
これは過学習対策が教師データを各エピックで変更してしまっているために教師データに対しての精度も損失もあまり改善していないように見えてしまいます。
この状態だとモデルの表現力が分類に必要なだけあるかどうかも判断がつきづらいです。
エポック数を増やせば判断は付きますが、時間がかかってしまうのでモデル選定のトライ&エラーが大変になってしまします。
過学習対策なし
次に同じモデルで過学習対策を抜いた場合を見てましょう。
実際に過学習対策を抜いたモデルがこちら。
Model: "sequential" _________________________________________________________________ Layer (type) Output Shape Param # ================================================================= rescaling (Rescaling) (None, 256, 256, 3) 0 vgg16 (Functional) (None, None, None, 512) 14714688 flatten (Flatten) (None, 32768) 0 dense (Dense) (None, 512) 16777728 dense_1 (Dense) (None, 9) 4617 ================================================================= Total params: 31,497,033 Trainable params: 23,861,769 Non-trainable params: 7,635,264 _________________________________________________________________
このモデルで同様に学習率違いを見てみましょう。
こちらの場合は学習率を小さくすることで学習が大きく改善され、損失が落ちて行っていることが一目で分かります。
もちろん、過学習対策を入れていないので、検証用データに対する損失と精度は悪いままですが、過学習対策を入れれば改善されます。
実際にこの学習率で過学習対策を入れてエポック数を増やせばかなり精度が上がります。