あきらぼ

テック系ブログ

【Tensorflow】【Keras】CNNでHololiveメンバーの識別をしてみる

こんにちは。

今回は先日のモデルに引き続き、ざっくりと最初のHololiveメンバー識別機を作ってみたのでそのことを書こうと思います。


まず今回作った識別機がどのように動くかの動画がこちらになります。

youtu.be

 

 

 

 

背景

今回識別器を作った理由としては最近流行っているHololiveのメンバーがどんどん増えていって覚えきれません。

しかし、サムネイル等には出てくるし、いちいち調べるのも面倒。

そこでAIを使って判定させてしまおうというのが今回作った動機になります。

 

今回参考にしたのはこちらの本です。実際のKerasのコードまで書いてあって分かりやすかったです。

Pythonによるディープラーニング (Compass Booksシリーズ) 単行本(ソフトカバー)

 

教師データ収集

まず、教師データ集めです。

今回はGoogle検索を使ってHololiveメンバー58名分のデータをを集めました。

メンバー一覧はこちら。

所属タレント一覧 | hololive(ホロライブ)公式サイト

 

画像収集方法はこちらです。

aki-lab.hatenadiary.com

 

全部で集まった画像データは24254枚、平均一人あたり約400万になります。

しかし、インドネシア勢とはまだ画像データが少なく、あまり検出できていませんでした。

例えば、こぼ・かなえるはデータ数が少なく、集めたデータも本に以外のデータばかりだったので正しく識別できていませんでした。

 

 

そこで、後でYoutubeのサムネイル等から画像を集めて、正しいデータだけを残すようにすればまた精度が上がると思います。

 

識別CNNモデル

次に今回のモデルにはKerasに既に用意されている識別でよく使われるVGG16モデルをベースに使用します。

入力画像サイズは256×256です。

 

Model: "sequential"
_________________________________________________________________
 Layer (type)                Output Shape              Param #   
=================================================================
 rescaling (Rescaling)       (None, 256, 256, 3)       0         
                                                                 
 random_rotation (RandomRota  (None, 256, 256, 3)      0         
 tion)                                                           
                                                                 
 random_zoom (RandomZoom)    (None, 256, 256, 3)       0         
                                                                 
 random_flip (RandomFlip)    (None, 256, 256, 3)       0         
                                                                 
 random_translation (RandomT  (None, 256, 256, 3)      0         
 ranslation)                                                     
                                                                 
 vgg16 (Functional)          (None, None, None, 512)   14714688  
                                                                 
 flatten (Flatten)           (None, 32768)             0         
                                                                 
 dense (Dense)               (None, 512)               16777728  
                                                                 
 dropout (Dropout)           (None, 512)               0         
                                                                 
 dense_1 (Dense)             (None, 58)                29754     
                                                                 
=================================================================
Total params: 31,522,170
Trainable params: 23,886,906
Non-trainable params: 7,635,264
_________________________________________________________________

 

上流のRandom層は過学習対策になります。

最初からこの層を追加した状態で学習のハイパーパラメータ等設定していたので少し沼ってしまいました。

詳細はこちら。

 

aki-lab.hatenadiary.com

 

また、VGG16の内、下流側4層のみをパラメータ学習可能としています。

 

学習

まず、学習にあたって集まったデータの80%を学習用、20%を検証用としてランダムに分割します。

学習・最適化は以下のようにしています。

  • Optimizer:Adam
  • 損失:Categorical Crossentropy
  • 学習率:1.0e-4
  • バッチサイズ:128

 実際に学習を始めると、すぐに過学習に入ります。

 

10エポックを過ぎたあたりで過学習が始まっています。

最大精度で60%を少し超えるぐらいでした。

そこで最大精度のものを使用したのが動画で使用しているモデルになります。

 

識別器

最後にモデルを学習してそれで終わりでは、全く役に立たないので、実際にYotubeや画像で範囲を選択すると判定してくれるようなアプリを作ってみました。

 

今回はTkinterを使用して簡単なGUIの識別器を作りました。

詳細な説明はここではしませんが、ウィンドウを透明の256X256ピクセルのサイズにすることで、そこの表示を読み取ってモデルで判定させています。

 

一番確度(そうである可能性)の高いメンバー名を表示するようにしています。

詳しくはソースコードをご参照ください。

 

問題点

実際使ってみた問題点として、256x256のウィンドウサイズだと、二人以上のメンバーが同時にウィンドウ内に入ってしまうケースが多く、どちらを検出しているのか分かりません。

そこで領域も検出できるようにすると便利さ増すかなと思っています。

そのあたりも使える方法がないか調べて実装してみようと思います。

 

 

 

今回使用したコードは以下になります。(適宜アップデートするかもしれませんが。)

GitHub - Aki-R/Hololive_Categorizer: CNN categorizer for VTuber Holilive members

 

また、識別対象を増やしたり、精度改善も暇な時にでもしていこうと思います。

(人力で画像データ集めと振り分けが少し必要そう。こういう仕事をアノテーターというらしいです(笑))