こんにちは、エンジニアの中村です。

前回の記事 (蒸留 第1回) で、ディープラーニング技術における蒸留という手法の基本とその発展を俯瞰的に見てきました。

その中で、ディープラーニングを実用する際に直面する本番環境の計算リソースの制約という問題について触れました。蒸留は、ディープラーニングのモデルを軽量化させる機能があり、それにより、この問題を緩和・解消できることに言及しました。また、具体的には、訓練済みモデルに蓄えられた知識を別の軽量なモデルに継承する形でディープラーニングモデルを軽量化し (モデル圧縮)、本番環境の制約下でも使用可能な、軽量かつ高性能なモデルを開発可能とすることを紹介しました。

今回の記事では、蒸留によるモデル圧縮について、該当する論文をいくつか参照して、詳細に紹介していきたいと思います。分類タスク、回帰タスクそれぞれの代表的な研究例を取り上げます。まず、分類タスクにおける蒸留の詳細を説明し、蒸留が持つ正則化効果について紹介します。その後、蒸留の回帰タスクへの応用について、紹介します。

具体的な詳細に入る前に、蒸留の概観を知りたいという方は、ぜひ蒸留 第1回の記事をご覧ください。

目次

分類タスクにおける蒸留によるモデル圧縮

分類タスクの事例として、Distilling the Knowledge in a Neural Network [Hinton et al., 2015]について紹介します。この文献は、ディープラーニングにおける蒸留の分野の草分け的位置付けの研究です。ディープラーニングの生みの親とも言える Geoffrey Hinton 先生が、蒸留という手法を導入しました。

背景

ディープラーニングに限らず機械学習において、複数のモデルを学習し各モデルの予測を総合することで、全体として性能の向上を図るアンサンブルがしばしばなされます。例えば、Kaggle 等のデータ分析コンペティションでは、XGBoost や Neural Network などの複数のモデルに対して、まずは個別に各単一モデルでの精度向上を目指して試行錯誤した後、最終的にそれら複数のモデルの予測を単純に平均あるいは重み付き平均することで、精度向上を狙うという戦略が多く採用されているかと思います。

確かに、アンサンブルによって最終的に高性能なモデルを得ることができますが、しかしながら、アンサンブルモデルを実際に使おうとすると、複数のモデルを動作させるため、本番環境の計算リソースが潤沢である必要があります。しかし、本番環境の計算リソースは多くの場合制約があります。ディープラーニングに関しては、アンサンブル以外にもモデルの大規模化によって一般に精度向上を図ることができますが、これも同様に、本番環境の計算リソースの制約という課題を抱えることになります。

手法

本論文では、計算リソースの制約という課題に対して、アンサンブルあるいは大規模モデルの学習済みの知識を、別の軽量な単一モデルに落とし込む手法として、蒸留を導入しています。

知識の継承元となるアンサンブルあるいは大規模モデルを教師モデル、継承先の軽量な単一モデルを生徒モデルと呼びます。高い精度で予測・分類できる教師モデルには汎化のための十分な知識が備わっています。その知識はクラス間の類似度を反映して、教師モデルの出力に現れます。正解ラベル (hard target) にはない、教師モデルの出力 (soft target) に含まれる情報を生徒モデルの学習に利用することで、生徒モデルの効果的な知識獲得と汎化を促します。

(ここまで、前回の記事の「基本的な蒸留の枠組み」で触れた内容をハイライトしています。用語や詳しい説明については、そちらも合わせてご覧ください。)

蒸留 第1回より再掲

温度パラメータ付き Softmax

ディープラーニングでは、分類タスクを扱うニューラルネットワークの出力は通常、最終層の活性化関数が Softmax 関数となっています。蒸留では、教師モデルの出力層の Softmax の値を soft target として生徒モデルの学習に使用しますが、ここでは少し工夫を加えています。Softmax 値は確かに正解クラス以外の類似クラスにも値を持つことがあり、soft target として使用可能です。しかし、その値はしばしば非常に小さく、クロスエントロピーで評価した損失は hard target を使用する場合とほとんど変わりません。そのため、温度パラメータ T を持つ Softmax 関数を代わりに使用するという工夫を加えています。

(1)   \begin{equation*} softmax(x_i) = \frac{e^{\frac{x_i}{T}}}{\sum_{j} e^{\frac{x_j}{T}}} \end{equation*}

式から分かるように、温度 T を上げることによって、soft target 内の正解クラス以外の類似クラスに対する値が増幅されます。それにより、soft target に期待する効果がより現れやすくなります。なお、soft target で学習する生徒モデルも、蒸留による学習時には、最終層の Softmax を同じ温度 T にして学習します (余談になりますが、温度 T を上げることで知識を教師モデルから生徒モデルに抽出することが、「蒸留」という命名の由来だと思われます)。

logitと温度Tを変化させた時の Softmax の値の変化のイメージ

温度 T を上げる以外にも、有効な soft target を得る方法として、Softmax 関数に入力される値 (しばしば logit と呼ばれる値) を使用する方法も考えられます。すなわち、教師モデルの logit と生徒モデルの logit の二乗誤差を最小化する方法です ([Bucila et al., 2006])。実は、蒸留は logit の二乗誤差の最小化を一般化した方法であるということが示されています (詳細は、論文内「2.1 Matching logits is a special case of distillation」の節を参照ください)。

soft target は hard target と組み合わせて、生徒モデルの学習に使うことができます。注意点としては、温度 T の時、その soft target についての勾配は 1/T2 のスケールとなるため、T2 を乗じて hard target で計算される勾配とスケールを合わせることが必要になります。

実験・結果

MNIST、音声認識データセット、大規模データセット (Google 内部の JFT と呼ばれるデータセット) の3種類のデータセットを題材に、蒸留の有効性が実験的に示されています。この記事では、3つの実験のうち、蒸留の効果を端的に示す MNIST を使用した実験と、アンサンブルモデルを単一モデルに蒸留した音声認識データセットによる実験を紹介します。

MNIST での実験の結果は以下の表の通りです。温度を上げて (T = 20) 蒸留した生徒モデルは、モデルサイズが同じ軽量なモデルの約半数のテストエラー数となり、教師モデルにも匹敵する性能を示しています。MNIST の分類タスクが単純すぎるという点はありますが、教師モデルと生徒モデルのテストエラー数は、テストデータ1万件中それぞれ 67件、74件であり、ほとんど同程度の精度であることが分かります。反面、生徒モデルのパラメータ数は、教師モデルの約53%と大幅に削減することができています。

 モデルサイズテストエラー数備考
教師モデルユニット数1,200 + ReLU
の隠れ層2つ
67件Dropout 等の正則化あり
軽量モデルユニット数800 + ReLU
の隠れ層2つ
146件正則化なし
生徒モデル
ユニット数800 + ReLU
の隠れ層2つ
74件hard target + soft target
T = 20 で教師モデルから蒸留
([Hinton et al., 2015] より作成)

 

次に、音声認識データセットを使用した実験を紹介します。この実験は、英語話者の音声データから音素の状態 (離散) を推定する実験で、同一アーキテクチャの単一モデルを10個アンサンブルしたモデルから、単一モデルに蒸留しています。

結果は下表の通りで、蒸留によって、10のアンサンブルモデルに匹敵する精度を、単一モデルで実現することができています。蒸留したモデルの精度は10アンサンブルモデルとほとんど変わりませんが、パラメータ数をアンサンブルモデルの10%と劇的に抑えることができています。

モデル精度備考
単一モデル58.9%
10アンサンブル61.1%アンサンブルを構成する各モデルはランダムに初期化された
単一モデル (蒸留)60.8%温度T = 2 で蒸留
単一モデル、アンサンブルを構成する各モデル、単一モデル (蒸留) は同一アーキテクチャを持つ ([Hinton et al., 2015] より作成)

 

蒸留の持つ正則化効果

本研究では、蒸留は正則化効果をもたらすことが合わせて報告されています。上述の音声認識データセットを使用した実験で、同サイズのモデルを対象に、訓練データ数を意図的に減らすことで過学習が起きやすい条件を設定しています。少数の訓練データで学習したモデルは過学習してしまうのに対し、soft target で蒸留したモデルは、同じ少数の訓練データでの学習にもかかわらず、訓練データ全てを使って訓練されたモデルに匹敵するようなテスト精度をおさめています。

訓練データの割合target の種類訓練精度テスト精度
100%hard target63.4%58.9%
3% (約2,000万件に相当)hard target67.3%44.5%
3%soft target65.4%57.0%
([Hinton et al., 2015] より作成)

 

過学習を防いでいる点、蒸留が正則化効果を持つことが実験的に示されています。これは soft target に含まれる情報により、教師モデルが獲得した特徴表現と同様の特徴表現を、生徒モデルも一定程度獲得するように学習するためと考えることができるかもしれません。

回帰タスクにおける蒸留によるモデル圧縮

回帰タスクの事例として、Learning Efficient Object Detection Models with Knowledge Distillation [Chen et al., 2017]について紹介します。この文献は、蒸留が回帰タスクを含む物体検出に適用可能であるかどうかを検証し、物体検出でも蒸留が有効であることを報告しています。

背景

上で紹介した論文により、ニューラルネットワークの蒸留の基本コンセプトが提示され、ディープラーニングの分類タスクにおいて蒸留が有効であることが示されました。しかしながら、ディープラーニングが効果を発揮するのは、単純な分類タスクに限定されないことは周知の通りです。

分類タスク以外には、回帰タスクに含まれるものとして、例えば物体検出が挙げられます。物体検出では、物体が何であるか分類するだけではなく、その物体が画像中のどこにあるかを検出することが求められます。前者は分類タスクですが、物体検出では、物体を含む領域と物体と含まない領域を比較すると、物体を含まない領域がほとんどであります。クラス間でデータ数の大きな偏りがない単純な分類タスクでは効果を示していた蒸留は、物体検出では、このクラス間のデータ数のインバランスに対処する必要があります。また、後者の物体位置の検出は回帰タスクの代表例ですが、上で紹介した単純な分類タスクを対象とした蒸留が扱わないタスクのため、蒸留を応用するには、回帰タスクのための工夫が必要となります。

物体検出の例
(Object Detection: A Guide in the Age of Deep Learning より)

手法

本論文では、データ数にインバランスのある分類タスクに蒸留を応用するために、物体と非物体で異なる重みを使って学習する重み付きクロスエントロピーを導入しています。また、回帰タスクに蒸留を応用するために、教師モデルの損失を参考に生徒モデルの損失を切り替えるという工夫を回帰タスクの損失に加えています (teacher bounded loss)。さらに、蒸留による教師モデルから生徒モデルへの知識の継承をより効果的にするために、教師モデルと生徒モデルのそれぞれの中間層を直接的にマッチさせる手法を採用しています (Adaptation Layer ありの Hint Learning)。

1. 重み付きクロスエントロピー

通常の単純な分類タスクと異なり、物体検出では、N 種類の物体と、物体を含まない領域 = バックグラウンドの合計 N + 1 種類のクラス分類をすることが求められます。この物体とバックグラウンドを合わせた N + 1 クラス分類タスクでは、先述したように、バックグラウンドクラスがデータ数の多くを占め、物体/バックグラウンドを確実に識別できることが重要になります。

本論文では、物体検出における分類タスクの損失に hard target と soft target の両方で求めたクロスエントロピーを使用しています (式(2). P_{s}, P_{t} はそれぞれ生徒モデル、教師モデルの Softmax 出力)。このうち soft target に関する損失において、バックグランドクラスに比較的大きな重みを置くことで、クラス数のインバランスによる物体/バックグラウンドの誤識別をより抑制しています。具体的には、PASCAL データセットを使用した実験では、以下の式(3)において、バックグラウンドクラス (c = 0) の場合には w_{0} = 1.5、物体の場合には w_{c} = 1.0 (c = 1, ..., i, i は全物体数) としています。

(2)   \begin{equation*} L_{cls} = \mu L_{hard}(P_{s}, y) + (1 - \mu)L_{soft}(P_{s}, P_{t}) \end{equation*}

(3)   \begin{equation*} L_{soft}(P_{s}, P_{t}) = - \sum_{} w_{c}P_{t}\log{P_{s}} \end{equation*}

2. teacher bounded loss

単純な分類タスクに対する蒸留の方法を踏襲して、回帰タスクである物体座標検出についても、教師モデルの出力をターゲットとして、それとの誤差を最小にするように生徒モデルを学習すればよさそうです。分類タスクとの違いは、出力が実数値を取ること、よって二乗誤差など実数値に対する損失関数を使用することくらいしかなさそうに思えます。

しかしながら、物体検出の回帰タスクにおける蒸留について、教師モデルの出力をそのままターゲットとして使用することは不適切であると報告されています。実数値を取る教師モデルの出力は上限下限がなく大きく間違う可能性があり、生徒モデルの更新を正解ラベルとは全く違った方向に導くこともあり得るため、ターゲットとして使用するには適さないと理由づけされています。

代わりに、生徒モデルが教師モデルよりも大きく間違った場合に、追加の損失を加えるという形で、教師モデルの出力を利用しています。以下の式に見られるように、回帰についての損失として smooth L1 Loss L_{sL1} を基本的に採用しつつ (式(4) 第1項)、生徒モデルの出力が教師モデルの出力よりも間違っていた場合に、追加で L2 Loss を加える (式(5)上) という形で損失関数を構成しています (式(4) において、y は正解ラベル、R_{s}, R_{t} はそれぞれ生徒モデル、教師モデルの出力。\nu = 0.5 で実験。式(5) の m はマージン)。

(4)   \begin{equation*} L_{reg} = L_{sL1}(R_{s}, y_{reg}) + \nu L_{b}(R_{s}, R_{t}, y_{reg}) \end{equation*}

(5)   \begin{equation*} L_{b} (R_{s}, R_{t}, y) = \left{ \begin{cases} ||R_{s} - y||^2, & \text{if $||R_{s} - y||^2 + m > ||R_{t} - y||^2$} & 0                     , & \text{otherwise} \end{cases} \end{equation*}

繰り返しですが、回帰タスクにおいては、あくまで教師モデルの出力はターゲットとしては使用せず、”教師モデルよりも誤っていた場合により強く学習する” というコンセプトで蒸留が実現されています。

3. Adaptation Layer ありの Hint Learning

教師モデルから生徒モデルによりよく知識を継承させるために、教師モデルと生徒モデルの最終的な出力におけるマッチングだけでなく、両者の中間層における出力のマッチングをおこなうことがあります。これは Hint Learning と呼ばれ、[Romero et al., 2014]で提案されています。

中間層マッチングの対象となる教師モデルのレイヤ (Hint Layer) と生徒モデルのレイヤ (Guided Layer) は、チャネル数が必ずしも一致しているわけではないため、Guided Layer を Hint Layer に一致させるために、1×1 convolution から成る Adaptation Layer を間に設けます。Guided Layer を Adaptation Layer によってチャネル数を揃えた後、Hint Layer との間で L2 Loss を取ることで、中間層マッチングを実現します。なお、VGG16 と AlexNet のように、中間層の空間的なサイズが異なる場合には、[Gupta et al., 2015] に見られるパディングの調整を行なうことで、同サイズの中間層を得ています。

以上の 1, 2, 3 の工夫をまとめて全体像を見渡すと、以下の図のようになります。

[Chen et al., 2017]をもとに対応する番号を付与

実験・結果

KITTI, PASCAL VOC 2007, MS COCO, ImageNet DET Benchmark (ILSVRC 2014) といった幅広いデータセットを題材に、特徴抽出部分に AlexNet, VGG をベースとしたアーキテクチャを使用した Faster R-CNN を用いて、物体検出における蒸留の有効性を検証しています (モデルは、具体的には、AlexNet, AlexNet with Tucker Decomposition, VGGM, VGG16 の4種類)。

下表の通り、教師モデルから蒸留した方が、蒸留をしなかった場合よりも精度が良いことが報告されています。提案手法を構成する要素ごとの有効性を確かめる実験も実施されており、1. 重み付きクロスエントロピー、2. teacher bounded loss、3. Adaptation Layer ありの Hint Learning のそれぞれがすべて精度向上に有効であったことも報告されています。

1列目から3列目はそれぞれ、生徒モデルアーキテクチャ、生徒モデルのパラメータ数と実行速度、教師モデルアーキテクチャ (- は蒸留なし)。 4列目以降は各データセットにおける精度 (mAPで評価。( )内は蒸留なしからの差分。mAP は IOU = 0.5 で計算した値。ただし、6列目については 0.5から0.95を0.05刻みでとった各 IOU で計算した値の平均値)
[Chen et al., 2017]

上の表を元に、モデル圧縮と精度のトレードオフを考察します。PASCAL と KITTI は比較的小さいデータセットであり、物体数もそれほど多くはないデータセットのため、データ数が多くより難易度の高い COCO と ILSVRC での実験結果を元に、それぞれ可視化したものが下図です。

AlexNet, VGGM を見ると、蒸留によって、パラメータ数を50%前後にまで減少できているのに対し、精度は大幅には低下しておらず、(少なくとも) 80%程度の精度を保っていることが分かります。実行時間は30%以下であり、すなわち、この実験での最大モデルである VGG16 の30倍以上の実行速度で、80%の精度を得られていることが分かります。

COCO, ILSVRC での実験より、パラメータ数と精度を比較
VGG16に対する比を掲載

 

また、蒸留の過程で、生徒モデルに入力する画像の解像度を落とした実験の結果が報告されています。この実験では、教師モデルと生徒モデルに同一アーキテクチャを採用し、教師モデルの入力は元の高解像度のまま、生徒モデルの入力を低解像度にして、蒸留をしています。

結果は下表の通りで、低解像度・蒸留なしの場合 (中列) よりも、低解像度・蒸留あり場合 (右列) の方が高精度であり、高解像度で学習した場合 (左列) にも匹敵する精度であったことが報告されています。

表にも見られる通り、高解像度である方が物体検出は高精度で、低解像度あると精度を著しく損ないます。しかし、蒸留という形で高解像度入力の教師モデルの出力を手がかりとして使えると、精度の低下を抑えることが可能であると考えられます。例えば、物体領域の座標の検出について、低解像度入力の生徒モデルは検出の精度が低下しますが、蒸留を施すと、より正確な教師モデルの出力を手がかりとして、それより誤差の大きい場合には損失を追加して学習するため、蒸留のない場合と比較して、より正確な出力を学習するようになると考えられそうです。

左から、高解像度入力モデル(教師モデル)、低解像度入力モデル、低解像度入力モデル(生徒モデル=蒸留あり)。それぞれについて、精度と CPU / GPU での実行速度を掲載。
[Chen et al., 2017]
低解像度入力の蒸留で訓練された物体検出モデルは、解像度の低さのため畳み込みの計算が減り、高速に推論することができます。そのため、推論の速度が重視される本番環境においては、特にメリットが得られるでしょう。


 

今回は、蒸留について、その端緒であるモデル圧縮に関連した研究の代表例を、分類・回帰のタスク別に紹介しました。特に、回帰タスクについては、教師モデルの出力を時には直接使わない等の工夫を施すことで、蒸留が物体検出のモデル圧縮にも応用可能であることを紹介しました。

次回は、モデル圧縮を超えた蒸留の発展的な研究について、個別の研究をいくつか参照して、蒸留の応用先の広さを具体的に紹介していきたいと思います。