機械学習
記事内に商品プロモーションを含む場合があります

重みの平均化で精度向上させるModel SoupsをPyTorchで実装する

Aru

この記事では、同一モデルの複数の重みを平均化することで精度を向上させる手法「Model Soups」について紹介します。ここでは、手法の概要と、PyTorchを用いた実装方法までを中心に解説します。Model Soupsは、手軽に導入でき、精度向上が期待できる実用的なアプローチです。さらに、PyTorchで簡単に実装できる関数群も用意しましたので、ぜひ参考にしてみてください。

Model Soupsとは

Model Soupsは、arXivの”Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time”で紹介されている技術です。

手法としては結構簡単で、異なるハイパーパラメータで学習させた複数のファインチューニングモデルの「重み」を平均かすることで、精度とロバスト性を向上させるというものです。

複数のモデルを平均する手法としては、アンサンブルがありますが、Model Soupsでは重み自体を平均化するため、あくまでモデルは1つであり推論時の処理量が増えないという点が異なります。

論文を読むと、平均かにより複数のモデルで精度向上ができているようです。実際に、私もSegformerで利用したことがありますが精度を向上することができました。

処理量が増えずに、性能を上げられるということで、かなり魅力的な手法だと思います

なぜ、性能が向上するのかなどの考察は論文に書かれていますのでそちらを参照してください。

この記事では、「具体的にどういうコードになるか」とか、「どんな感じで使うのか(私が実際に使ったやり方)」を紹介したいと思います。

Model Soupsに使う関数群 for Pytorch

Model Soupsを実装するために、3つの関数を用意しました。この3つの関数を組み合わせることで、モデルの重みを加重平均することができます。

なお、コードはPytorch用です。

  1. copy_params(model1, model2)
    model2の重みをmodel1にコピーします
  2. sum_model_params(model1, model2)
    model1model2の重みを単純加算して、model1の重みを変更して返します。model1の重みが書き換わるので注意してください。
  3. multi_model_params(model, a)
    modelの重みにaを乗算します
def copy_params(model1, model2) :
  model1.load_state_dict(model2.state_dict())
  return model1

def sum_model_params(modelA, modelB):
    """ modelA + modelB """
    sdA = modelA.state_dict()
    sdB = modelB.state_dict()
    for key in sdA:
        sdA[key] = (sdA[key] + sdB[key])
    modelA.load_state_dict(sdA)
    return modelA

def multi_model_params(model, a):
    """ a * model """
    sd = model.state_dict()
    for key in sd:
        sd[key] = sd[key] * a
    model.load_state_dict(sd)
    return model

この3つの関数を使う方法を以下に示します。

model1, model2の重みを平均化する

加算後に、0.5を乗じることで平均化します。

model1 = sum_model_params(model1, model2)
model1 = multi_model_params(model1, 0.5)

model1, model2, model3の重みを平均化する

加算後に、0.333….を乗じることで平均化します。

model1 = sum_model_params(model1, model2)
model1 = sum_model_params(model1, model3)
model1 = multi_model_params(model1, 1/3)

model1, model2を2:3で平均化する

加重平均も以下のようなコードで記述できます。

model1 = multi_model_params(model1, 2)
model2 = multi_model_params(model2, 3)

model1 = sum_model_params(model1, model2)
model1 = multi_model_params(model1, 1/5)

N個のモデルの組み合わせで最も良いものを選択

Nが大きい場合は、処理時間がかかりますが、以下のようなコードで、全ての組み合わせをチェックできます。実用的には、N=2~8くらいまでだと思います(8モデルの場合、組み合わせは255通りになります)

以下は、コードのサンプルです。評価の部分とベスト更新の部分はモデルに合わせて追加してください。

import copy

models = [model1, model2, model3, model4]
n = len(models)

best = None

for bit in range(1, 1<<n) :
  sel = []
  for i in range(n):
    if (bit>>i)%2 : sel.append(i)
  print(sel)
  model = copy.deepcopy(models[sel[0]])
  for i in range(1, len(sel)) :
    print("add ", sel[i])
    model = sum_model_params(model, models[sel[i]])
  model = multi_model_params(model, len(sel))
  
  モデルを評価する

  if 評価結果がベストならば : 
     best = model

上記のようにビット全探索で実装すれば、組み合わせベストのモデルを選択することが簡単にできます。

モデルの作り方

論文ではファインチューニングのハイパーパラメータを変えたものの平均を取ることになっていますが、EPCOHのベスト3や5を組み合わせた場合も、結構うまく行きました。

Model Soupsの対象とするモデルの選定はいろいろ考えられそうです。

ただ、Model Soupsの基本的な考え方は、同じ収束過程のモデルの加重平均をとることで収束するポイント(重み)に近づけることができるというものなので、学習済みモデルをファインチューニングしたモデルのいくつかを組み合わせるのが良いようです。

データが少ない場合は、事前学習モデルをファインチューニングすることが多いので、この手法は結構多くのシーンで使えると思います

まとめ

以上、Model SoupsについてPytorchでの実装方法を解説しました。手軽に試せる手法なので、精度とロバスト性を上げたい場合に、やってみる価値はあるかと思います。

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました