こんにちは、エンジニアの中村です。

今回は、ニューラルネットワークの蒸留に関する連載記事の第3回目の記事になります。

第1回では、蒸留という手法の端緒であるモデル圧縮から始めて、蒸留の基本的な概念を説明し、その後の蒸留の発展を含めて、蒸留という分野の概観を紹介しました。第2回では、蒸留によるモデル圧縮について、分類タスクと回帰タスクそれぞれについて個別の研究を取り上げて、それらの詳細を紹介しました。特に、分類タスクでは蒸留の基本的な効果と蒸留によって得られる正則化効果について、また,回帰タスクでは蒸留を応用するために必要な工夫を紹介しました。

第3回目にあたる今回は、第1回で概要に触れた蒸留の発展について、個別の研究を取り上げて、詳細を紹介していきます。モデル圧縮のための手法として考案された蒸留ですが、「教師モデルの出力を生徒モデルの学習に使用する」という蒸留の基本コンセプトは、モデル圧縮以外の用途にも広く応用されています。蒸留を応用させた研究についてすべてを取り上げるわけではありませんが、蒸留の幅広い応用先について、以下で具体例を紹介していきます。蒸留の応用先の中でも、ディープラーニングモデル開発の中で最も関心が高いであろう精度向上に関する研究例を取り上げます。

蒸留の発展に関する具体的な研究に先立って、まず概要を知りたいという方は、ぜひ蒸留 第1回の記事を合わせてご覧ください。

それでは、蒸留によるディープラーニングモデルの精度向上について、個別の研究の紹介に入りましょう。この記事では2つの手法を紹介します。

目次

データを有効活用して精度向上

モデルの精度向上のためには、大きく分けて、モデル、データ、訓練方法のいずれかの要素を改良すればよいと考えられます。蒸留 第1回で説明したように、蒸留では、データと訓練方法の2つについて精度向上のための手法を提案します。以下でまずはじめに、データの利用可能性を工夫することによるモデル精度向上のための蒸留を紹介します。

訓練データを増やすことで、ディープラーニングモデルの汎化性能を高めることができるということは、ディープラーニングにおいて一般的に知られています。しかしながら、既存の訓練データに追加してさらに大量の訓練データを用意することは、その分 各データに正解ラベルを付与するアノテーション作業のコストが高くつきます。このアノテーションコストを回避するために、追加の訓練データについては、何らかの形で機械的・自動的にアノテーションをする方法がほしいところです。以下で紹介する Data Distillation (Data Distillation: Towards Omni-Supervised Learning [Radosavovic et al., 2017]) では、教師モデルの出力を正解ラベルとして使うことで、正解ラベルのない訓練データの利用を可能にしています。

手法

教師モデルの出力を正解ラベルとして使うと述べましたが、当然のことながら、可能な限り正確な正解ラベルを得たいはずです。ディープラーニングによる教師あり学習は確かに十分に正確であることが多いため、教師モデルの出力をそのまま正解ラベルとして使うこともできそうです。しかしながら、より正確な正解ラベルを得るために、Data Distillation では、正解ラベルのない訓練データそれぞれに対して複数の変換を施し、変換されたデータそれぞれに対する教師モデルの出力をアンサンブル (平均等) するという方法を取っています。

transform A, B, C という3種類の変換を入力画像に施し、教師モデルA に入力。変換された入力それぞれに対する教師モデルの出力をアンサンブルして、生徒モデルに正解として与える [Radosavovic et al., 2017]
変換後の入力に対する出力をアンサンブルすることで、より正確な出力を得るという考え方は、Test Time Augmentaion と同じです (Test Time Augmentation では、推論時に、例えば、元の画像データとそれを左右反転させた画像データの両方に対する出力を得て、それらの出力を平均化することでより正確な予測を得る、ということが行われます)。

全体をまとめると、Data Distillation は以下のステップにしたがって実現されます。

  1. 十分な量の訓練データで、教師モデルを訓練 (通常の教師あり学習)
  2. 正解ラベルのない訓練データに対して変換を施して、教師モデルに入力、正解ラベルのない訓練データそれぞれに対して複数の出力を得る。
  3. 2. で得た教師モデルの出力をデータごとにアンサンブル (平均を取る等) して、教師モデルによるアノテーションを作成。
  4. 教師モデルによるアノテーションを、1. の訓練データに追加して、生徒モデルを訓練。

正解ラベルなしデータを利用するという点において、Data Distillation は半教師あり学習と類似した手法と捉えられます。実際に、論文内で言及されているように、半教師あり学習の一手法と位置付けることができます (半教師あり学習との主な違いについては、蒸留 第1回 をご参照ください)。

実験・結果

本論文では、Data Distillation の効果を、人の関節点検出 (Keypoint Detection) と物体検出 (Object Detection) のタスクで検証し、Data Distillation は、関節点検出と物体検出の両方で有効であることが報告されています。以下では、関節点検出の実験について紹介します。

モデルには、特徴抽出のネットワークアーキテクチャに ResNet-50/101 または ResNeXt-101 を採用した Mask-RCNN を使用しています。モデル圧縮が目的ではないため、教師モデルと生徒モデルは同じモデルアーキテクチャを持ちます。同じアーキテクチャを持ちますが、生徒モデルは、正解ラベルありデータで訓練された教師モデルからの再学習ではなく、ImageNet の訓練済み重みを使用しています。

Data Distillation 時に施される変換には、スケーリングと左右反転が採用されています。注意点として、関節点検出であるため、変換後のデータを入力した教師モデルの出力は、その変換の逆の変換を施して、元に戻してから使用しています (例えば、左右反転の変換がされた画像に対する教師モデルの出力を再度左右反転して元に戻してから、無変換の画像に対する出力と平均して、正解ラベルとして使用しています)。

実験はデータの性質別に大きく2つ実施されています。1つは、正解ラベルのある訓練データと正解ラベルのない訓練データが類似する実験で、もう1つは、それらが類似しない実験です。以下で、具体的に紹介してきます。

正解ラベルありデータと正解ラベルなしデータが類似する場合。この実験では、正解ラベルのある訓練データに 2017年版 COCO データセット (115,000 サンプル)、正解ラベルのない訓練データに 2017年版 COCO のラベルなしデータセット (120,000 サンプル) を使用しています。同じ COCO データセットのため、正解ラベルありデータと正解ラベルなしデータは類似しています (別の言い方をすれば、両者は類似した分布を持ちます)。

COCO dataset より

訓練データ (正解ラベルあり/なし) とは異なる COCO のテストデータに対して、正解ラベルありのデータのみで訓練したモデルと、Data Distillation で正解ラベルなしのデータも使用して訓練したモデルの精度を比較した表が以下になります。特徴抽出のネットワークアーキテクチャによらずに、Data Distillation (DD) を実施した方が高精度という結果となっています。

backbone は特徴抽出のネットワークに使用したCNNアーキテクチャの種類。
DD にチェックマークが付いている場合、Data Distillation が実施された (論文内 Table 1. (b) を引用)

精度 (AP) の向上を定量的に見ると、この実験では平均1.85程度 精度が向上しています。精度の向上率で見ると、平均して約3% 精度が向上しています。

正解ラベルありデータと正解ラベルなしデータが類似しない場合。以上の実験では、正解ラベルありデータと類似する正解ラベルなしデータが Data Distillation の形で有効活用でき、精度向上につながったことを示しています。実は、Data Distillation は、正解ラベルなしデータが類似しない場合でも、そのデータを有効活用することができることが報告されています。

この実験では、正解ラベルなしデータに、Sports-1M video dataset を使用しています。動画のデータセットですが、動画から静止画を抽出し、合計 180,000 サンプルのデータセットを作成しています。このデータセットは、 その名前の通りスポーツにおける場面を写したデータセットであり、一般的な場面を写した COCO データセットと比較すると趣きの異なったデータセットになっています。

(論文内 Figure 3 を引用)

結果は以下の通りで、正解ラベルなしデータが類似していない場合でも、Data Distillation が有効であったことが示されています。精度は平均1.35程度向上しており、平均して約2% 精度が向上しています。

(論文内 Table 1. (c) を引用)

Data Distillation を実際に使用することを考えた場合、例えば、追加で収集する正解ラベルなしデータは、既存の正解ラベルありデータと全く同じ条件で得られるとは限りません。しかし、この実験結果によれば、同じタスクを行えるデータであれば、(分布の) 類似度によらずに Data Distillation が有効である可能性が示されています。その点で、Data Distillation は、正解ラベルなしデータを活用して精度向上をする手法として実用的である、と言えるかもしれません。

本論文では、上記のような Data Distillation の有効性以外に、いくつかの比較分析が報告されています。そのうちの1つは、正解ラベルなしデータの量に関するものです。以下の図では、使用する正解ラベルなしデータの比率 (正解ラベルありデータ数を1とした場合) ごとに精度が報告されています。正解ラベルなしデータを使用する場合、常に正解ラベルありデータのみで訓練したモデルの精度より高精度となっていますが、さらに、正解ラベルなしデータの比率が高いとより高精度になることが見て取れます。

縦軸は精度、横軸は正解ラベルなしデータの比率。ResNet-50を使用 (論文内 Figure 5 を引用)

さらに、教師モデルの精度と生徒モデルの精度の関係について報告されています。以下の図に見られるように、教師モデルの精度が向上すると生徒モデルの精度も向上することが示されています。想像に難くないですが、教師モデルの精度が向上すれば、そこから得られるアノテーションもより正確なものとなり、生徒モデルをより正確に訓練することが可能であると考えれます。

縦軸は生徒モデルの精度、横軸は教師モデルの精度。ResNet-50を使用 (論文内 Figure 6 を引用)

 

訓練方法を工夫して精度向上

以上で、正解ラベルなしデータを使ってデータの利用可能性を向上させることによる、精度向上のための蒸留の手法について説明しました。次に、精度向上のためのもう1つの手法である、Born Again Neural Networks (BAN), [Furlanello et al., 2018] を紹介します。

蒸留 第2回でも詳細に取り上げたとおり、蒸留はモデルサイズを小さくするモデル圧縮を基本としていました。蒸留は、教師から生徒への知識の継承という形で、教師モデルと同程度の精度のより軽量なモデルの開発を可能とするわけでした。本研究では、知識の継承先を軽量なモデルとはせず、教師モデルと同一のモデルアーキテクチャのモデルに対して蒸留することで、教師モデルを超える精度の生徒モデルが得られることを検証しています。

手法

Born Again Neural Networks (BAN) の手法は非常にシンプルで、基本的な蒸留の流れとほとんど同じステップで構成されます。違う点は、生徒モデルが教師モデルと同一 (あるいは同程度のサイズのモデル) という点です。すなわち、まずはじめに教師モデルを通常の教師あり学習の形で学習し、その後、その教師モデルの出力を使い、教師モデルと同一の生徒モデルを訓練します。教師モデルが同じアーキテクチャの生徒モデルに蒸留される点を捉えて、”Born Again Neural Networks” (生まれ変わるニューラルネット) と呼称されています。

さらに、本手法では、この蒸留の1ステップを複数回繰り返すこともしています。前のステップで蒸留された生徒モデルを新たな教師モデルとして、この新たな教師モデルからまた別の生徒モデルへの蒸留というステップを複数回繰り返します。このステップをk回実行した際、最終的に、k個の生徒モデルが得られますが、さらにこのk個の生徒モデルをアンサンブルするということも可能です。

この一連の流れを概念的に図示すると、以下の通りに表せます。

教師モデルを訓練 (Step 0)。生徒モデルに蒸留 (Step 1~k)。2ステップ目以降では、前のステップの生徒モデルを教師モデルとして蒸留。このk回のステップで得られたk個の生徒モデルをアンサンブルことも可能 (本論文 Figure 1 を引用)

実験・結果

CIFAR-10, CIFAR-100, Penn Tree Bank の3つのデータセットを使って、BAN を画像認識および言語モデリングの領域で実験・検証しています。いずれの実験でも、BAN によって、通常の教師あり学習で訓練された教師モデルを凌ぐ精度の生徒モデルが得られたことが示されています。ここでは特に、CIFAR-100 を題材にした実験とその結果について触れます。

この実験では、教師モデルと生徒モデルのアーキテクチャに DenseNet を採用しており、結果は以下の表の通りです。アーキテクチャ (Network) ごとのエラー率が報告されています。エラー率については、2列目から、教師モデル (Teacher)、BAN (BAN)、蒸留の際に教師モデルの出力と正解ラベルを両方使用した BAN (BAN+L) のエラー率を掲載しています。この表の左部分では、確かに BAN によって教師モデルを超える精度の生徒モデルが得られたことが見て取れます。

(注) DenseNet-80-120 について、BAN-3 は実験されていない。そのため、Ens*3 には BAN-3 の代わりに教師モデルが使われいる (本論文内 Table 2 を引用)

Teacher 列と BAN 列を比較すると、この実験では、BAN によって精度は平均して1程度向上しています。向上率を見ると、精度は平均5.8%程度改善しています。また、Data Distillation ではより高精度な教師モデルからの蒸留によって、より高精度な生徒モデルが得られる傾向が見られていました。この傾向は、この CIFAR-100 の画像認識タスクでも同様に見られ、上の表の Teacher 列と BAN 列を再び見ると、教師モデルの精度が良いと生徒モデルの精度も良いということが分かります。

(CWTM, DKPP を飛ばして) 残りの列は、BAN を1ステップ繰り返した場合=3列目のBAN と同じ (BAN-1、ただし3列目のBANとは初期値が異なる)、BAN を2ステップ繰り返した場合 (BAN-2)、BAN を3ステップ繰り返した場合 (BAN-3)、BAN-1, 2 のアンサンブル (Ens*2)、BAN-1, 2, 3 のアンサンブル (Ens*3) のエラー率を掲載しています。DenseNet-80-80 で BAN を3ステップ繰り返した場合のエラー率は 15.5% で、これは単一モデルの SOTA となっています。また、BAN を繰り返すことで概ね前ステップよりも高精度となったこと、アンサンブルすることで BAN の各ステップの単一の生徒モデルよりも高精度な結果を得られたことが示されています (BAN のコンセプトというより、蒸留における Dark Knowledge の効果を調べた実験のため、CWTM, DKPP を省略しました)。

全ての場合ではありませんが、上述した通り、BAN のステップを繰り返すことで生徒モデルの精度をさらに向上させられていました。上の表では3回繰り返した場合までしか報告されていませんが、さらに繰り返した場合にさらに精度向上できるのかどうかは気になるところかと思います。蒸留を複数ステップ繰り返すという操作が常に一定の精度向上の効果を持つのか、持つ場合にそれはどのような理論的背景によるのか、は今後の研究に期待できるところかもしれません。


 

今回は、蒸留の応用例を紹介しました。特に、蒸留によるモデルの精度向上について詳細に紹介し、蒸留がモデル圧縮以外の面でも実用的な技術であることを紹介できたかと思います。

ニューラルネットワークの蒸留についての連載記事は今回で最後となります。次回は、ディープラーニングにおけるドメイン適応を紹介します。ある1つのドメインで学習できたとしても、そのモデルは異なるドメインでも同様に高性能であるとは限りません。異なるドメイン間の差異を乗り越えるドメイン適応は、モデルの汎用性を考える上で重要な技術です。こちらの記事もぜひお楽しみに!