重みの平均化で精度向上させるModel SoupsをPyTorchで実装する
この記事では、同一モデルの複数の重みを平均化することで精度を向上させる手法「Model Soups」について紹介します。ここでは、手法の概要と、PyTorchを用いた実装方法までを中心に解説します。Model Soupsは、手軽に導入でき、精度向上が期待できる実用的なアプローチです。さらに、PyTorchで簡単に実装できる関数群も用意しましたので、ぜひ参考にしてみてください。
Model Soupsとは
Model Soupsは、arXivの”Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time”で紹介されている技術です。
手法としては結構簡単で、異なるハイパーパラメータで学習させた複数のファインチューニングモデルの「重み」を平均かすることで、精度とロバスト性を向上させるというものです。
arXiv : https://arxiv.org/abs/2203.05482
複数のモデルを平均する手法としては、アンサンブルがありますが、Model Soupsでは重み自体を平均化するため、あくまでモデルは1つであり推論時の処理量が増えないという点が異なります。
論文を読むと、平均かにより複数のモデルで精度向上ができているようです。実際に、私もSegformerで利用したことがありますが精度を向上することができました。
処理量が増えずに、性能を上げられるということで、かなり魅力的な手法だと思います
なぜ、性能が向上するのかなどの考察は論文に書かれていますのでそちらを参照してください。
この記事では、「具体的にどういうコードになるか」とか、「どんな感じで使うのか(私が実際に使ったやり方)」を紹介したいと思います。
Model Soupsに使う関数群 for Pytorch
Model Soupsを実装するために、3つの関数を用意しました。この3つの関数を組み合わせることで、モデルの重みを加重平均することができます。
なお、コードはPytorch用です。
copy_params(model1, model2)
model2
の重みをmodel1
にコピーしますsum_model_params(model1, model2)
model1
とmodel2
の重みを単純加算して、model1
の重みを変更して返します。model1
の重みが書き換わるので注意してください。multi_model_params(model, a)
model
の重みにa
を乗算します
def copy_params(model1, model2) :
model1.load_state_dict(model2.state_dict())
return model1
def sum_model_params(modelA, modelB):
""" modelA + modelB """
sdA = modelA.state_dict()
sdB = modelB.state_dict()
for key in sdA:
sdA[key] = (sdA[key] + sdB[key])
modelA.load_state_dict(sdA)
return modelA
def multi_model_params(model, a):
""" a * model """
sd = model.state_dict()
for key in sd:
sd[key] = sd[key] * a
model.load_state_dict(sd)
return model
この3つの関数を使う方法を以下に示します。
model1, model2の重みを平均化する
加算後に、0.5を乗じることで平均化します。
model1 = sum_model_params(model1, model2)
model1 = multi_model_params(model1, 0.5)
model1, model2, model3の重みを平均化する
加算後に、0.333….を乗じることで平均化します。
model1 = sum_model_params(model1, model2)
model1 = sum_model_params(model1, model3)
model1 = multi_model_params(model1, 1/3)
model1, model2を2:3で平均化する
加重平均も以下のようなコードで記述できます。
model1 = multi_model_params(model1, 2)
model2 = multi_model_params(model2, 3)
model1 = sum_model_params(model1, model2)
model1 = multi_model_params(model1, 1/5)
N個のモデルの組み合わせで最も良いものを選択
Nが大きい場合は、処理時間がかかりますが、以下のようなコードで、全ての組み合わせをチェックできます。実用的には、N=2~8くらいまでだと思います(8モデルの場合、組み合わせは255通りになります)
以下は、コードのサンプルです。評価の部分とベスト更新の部分はモデルに合わせて追加してください。
import copy
models = [model1, model2, model3, model4]
n = len(models)
best = None
for bit in range(1, 1<<n) :
sel = []
for i in range(n):
if (bit>>i)%2 : sel.append(i)
print(sel)
model = copy.deepcopy(models[sel[0]])
for i in range(1, len(sel)) :
print("add ", sel[i])
model = sum_model_params(model, models[sel[i]])
model = multi_model_params(model, len(sel))
モデルを評価する
if 評価結果がベストならば :
best = model
上記のようにビット全探索で実装すれば、組み合わせベストのモデルを選択することが簡単にできます。
モデルの作り方
論文ではファインチューニングのハイパーパラメータを変えたものの平均を取ることになっていますが、EPCOHのベスト3や5を組み合わせた場合も、結構うまく行きました。
Model Soupsの対象とするモデルの選定はいろいろ考えられそうです。
ただ、Model Soupsの基本的な考え方は、同じ収束過程のモデルの加重平均をとることで収束するポイント(重み)に近づけることができるというものなので、学習済みモデルをファインチューニングしたモデルのいくつかを組み合わせるのが良いようです。
データが少ない場合は、事前学習モデルをファインチューニングすることが多いので、この手法は結構多くのシーンで使えると思います
まとめ
以上、Model SoupsについてPytorchでの実装方法を解説しました。手軽に試せる手法なので、精度とロバスト性を上げたい場合に、やってみる価値はあるかと思います。