機械学習
記事内に商品プロモーションを含む場合があります

【Pytochで実装】Model soups | 重み平均により精度向上させる手法

tadanori

この記事では、同一モデルの複数の重みを平均化するだけで精度を向上させる手法Model soups」を紹介します。ここでは、Pytorchでの実装方法をメインで紹介します。手軽で結構使える手法です。

Model Soupsとは

Model Soupsは、arXivの”Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time”で紹介されている技術です。

手法としては結構簡単で、異なるハイパーパラメータで学習させた複数のファインチューニングモデルの「重み」を平均かすることで、精度とロバスト性を向上させるというものです。

arXiv : https://arxiv.org/abs/2203.05482

複数のモデルを平均する手法としては、アンサンブルがありますが、どこが違うかというと、重み自体を平均化するため、推論を行うのは1つのモデルとなります。つまり、推論時の処理量は増えません。

論文によると、いくつかのモデルで精度向上ができているようです。

実際、私もSegformerで利用したことがありますが、精度を向上することができました。

処理量が増えずに、性能を上げられるということで、かなり魅力的な手法だと思います

なぜ、性能が向上するのかなどの考察は論文に書かれていますのでそちらを参照してください。

この記事では、「具体的にどういうコードになるか」とか、「どんな感じで使うのか(私が実際に使ったやり方)」を紹介したいと思います。

Model Soupsに使う関数群 for Pytorch

Model Soupsを実装するために、3つの関数を用意しました。この3つの関数を組み合わせることで、モデルの重みを加重平均することができます。

なお、コードはPytorch用です。

  1. copy_params(model1, model2)
    model2の重みをmodel1にコピーします
  2. sum_model_params(model1, model2)
    model1model2の重みを単純加算して、model1の重みを変更して返します。model1の重みが書き換わるので注意してください。
  3. multi_model_params(model, a)
    modelの重みにaを乗算します
def copy_params(model1, model2) :
  model1.load_state_dict(model2.state_dict())
  return model1

def sum_model_params(modelA, modelB):
    """ modelA + modelB """
    sdA = modelA.state_dict()
    sdB = modelB.state_dict()
    for key in sdA:
        sdA[key] = (sdA[key] + sdB[key])
    modelA.load_state_dict(sdA)
    return modelA

def multi_model_params(model, a):
    """ a * model """
    sd = model.state_dict()
    for key in sd:
        sd[key] = sd[key] * a
    model.load_state_dict(sd)
    return model

この3つの関数を使う方法を以下に示します。

model1, model2の重みを平均化する

加算後に、0.5を乗じることで平均化します。

model1 = sum_model_params(model1, model2)
model1 = multi_model_params(model1, 0.5)

model1, model2, model3の重みを平均化する

加算後に、0.333….を乗じることで平均化します。

model1 = sum_model_params(model1, model2)
model1 = sum_model_params(model1, model3)
model1 = multi_model_params(model1, 1/3)

model1, model2を2:3で平均化する

加重平均も以下のようなコードで記述できます。

model1 = multi_model_params(model1, 2)
model2 = multi_model_params(model2, 3)

model1 = sum_model_params(model1, model2)
model1 = multi_model_params(model1, 1/5)

N個のモデルの組み合わせで最も良いものを選択

Nが大きい場合は、処理時間がかかりますが、以下のようなコードで、全ての組み合わせをチェックできます。実用的には、N=2~8くらいまでだと思います(8モデルの場合、組み合わせは255通りになります)

以下は、コードのサンプルです。評価の部分とベスト更新の部分はモデルに合わせて追加してください。

import copy

models = [model1, model2, model3, model4]
n = len(models)

best = None

for bit in range(1, 1<<n) :
  sel = []
  for i in range(n):
    if (bit>>i)%2 : sel.append(i)
  print(sel)
  model = copy.deepcopy(models[sel[0]])
  for i in range(1, len(sel)) :
    print("add ", sel[i])
    model = sum_model_params(model, models[sel[i]])
  model = multi_model_params(model, len(sel))
  
  モデルを評価する

  if 評価結果がベストならば : 
     best = model

上記のようにビット全探索で実装すれば、組み合わせベストのモデルを選択することが簡単にできます。

モデルの作り方

論文ではファインチューニングのハイパーパラメータを変えたものの平均を取ることになっていますが、EPCOHのベスト3や5を組み合わせた場合も、結構うまく行きました。

Model Soupsの対象とするモデルの選定はいろいろ考えられそうです。

ただ、Model Soupsの基本的な考え方は、同じ収束過程のモデルの加重平均をとることで収束するポイント(重み)に近づけることができるというものなので、学習済みモデルをファインチューニングしたモデルのいくつかを組み合わせるのが良いようです。

データが少ない場合は、事前学習モデルをファインチューニングすることが多いので、この手法は結構多くのシーンで使えると思います

まとめ

以上、Model SoupsについてPytorchでの実装方法を解説しました。手軽に試せる手法なので、精度とロバスト性を上げたい場合に、やってみる価値はあるかと思います。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました