機械学習
記事内に商品プロモーションを含む場合があります

アンサンブルと違って処理時間が増えない精度改善手法:Model Soups

AI関連記事画像イメージ
tadanori

はじめに

ここでは、複数のファインチューニングした同一モデルの「重み」を平均化して精度向上を図る、「Model Soups」を紹介します。

上の文言だけ読むと「アンサンブル」と同じように感じると思いますが、パラメータを平均化したモデルを作るので、アンサンブルのように処理量が増えないという点で異なります。

処理量が増えない(=重くならない)というのは利点だよね

実際に利用したのは、1、2回ですが、効果ありでした。特に、アンサンブルしたら性能が上がるのに、処理速度の関係でアンサンブルは厳しいといった場面で使える気がします。

Model Soupsの簡単な説明

Model Soupsは、異なるハイパーパラメータで学習した複数のファインチューニングモデルの「重み(Weight)」と平均化することで、精度を向上させるテクニックです。

Model Soupsには、以下のような特徴があります。

  • アンサンブルとは異なる: Model Soupsは、推論時に余分な計算を必要とせず、多くのモデルを平均化することができる
  • 精度とロバスト性の向上: 複数のモデルの重みを平均化することで、個々のモデルの弱点やバイアスを克服し、より高い精度とロバスト性を達成することができる
  • 効率的な学習: 複数のモデルを個別に学習するよりも効率的に学習することができる

アンサンブルと異なり、推論時に処理が増えないのが一番の利点だと思います。

論文リンク
Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference timehttps://arxiv.org/abs/2203.05482

実装

ヘルパー感数

Model Soupsを行うためのヘルパー関数です。

以下は、pytorchでの実装例です。

def read_model(file):
  # ファイル名のモデルを読み込む(ここはモデル毎に作成)
 return model

def copy_params(model1, model2) :
  model1.load_state_dict(model2.state_dict())
  return model1

def sum_model_params(modelA, modelB):
    """ modelA + modelB """
    sdA = modelA.state_dict()
    sdB = modelB.state_dict()
    for key in sdA:
        sdA[key] = (sdA[key] + sdB[key])
    modelA.load_state_dict(sdA)
    return modelA

def mul_model_params(model, a):
    """ a * model """
    sd = model.state_dict()
    for key in sd:
        sd[key] = sd[key] * a
    model.load_state_dict(sd)
    return model

それぞれの関数の機能は以下の通りです。

関数名機能
read_model(file)fileで指定されたモデルを読み込む。モデル毎にカスタムする必要があります。
copy_params(model1, model2)model2の重みをmodel1へコピーする
sum_model_params(modelA, modelB)modelAmodelBの重みを足す
mul_model_params(model, a)modelの重みをa倍する

これらの関数を作っておけば、Model Soupsはこれらの関数の組みわせで実現できます。

モデルAとモデルBを平均化する場合

例えば、モデルAとモデルBを平均化する場合は、以下のようなコードになります。

modelA = read_model('modelA.pth')
modelB = read_model('modelB.pth')
model = sum_model_params(modelA, modelB)
model = mul_model_params(model, 1/2)  # 2つ合成したので、1/2にする

N個のモデルの組み合わせからベストのモデルを作成する場合

重みの異なるN個のモデルがある場合は、以下のコードで最適な組み合わせをチェックし、平均化することができます。

以下は、リストmodelsにN個のモデルが格納されている場合の例です。なお、calc_score()は検証データを使った評価を行う関数とします

以下は、コード例です。

n = len(models)
best_model = models[0]
best_score = -1e18 # 最小値以下を設定
for bit in range(1, 1<<n):
    pat = []
    for i in range(n):
        if (bit>>i)%2 == 1:
            pat.append(i)
    model = models[pat[0]]
    print(format(bit, '04b'), pat, len(pat))
    for i in range(1, len(pat)):
        model = sum_model_params(model, models[pat[i]])
        print(pat[i])
    model = mul_model_params(model, 1/len(pat))
    
    score = calc_score(model)
    if score > best_score :
        best_model = model

コード自身は単純で、全てのモデルの組み合わせを作成し、それぞれのスコアをチェックしています。

calc_scoreの処理時間(検証データの処理時間)次第ですが、モデルの組み合わせ8個程度までは実用的な速度でチェックできると思います。

組み合わせパターンは、それぞれのモデルを使う・使わないで$O(2^n)$となるので注意が必要です。N=8の場合は255パターンですが、N=16の場合は65535パターンのチェックを行うことになります。

また、検証データで最も良いものを探索するので、検証データがモデルにリークしてしまう可能性があることに注意が必要です

検証データに対して最適化しまうのは、しょうがないところです。実際のデータと検証データの集合に乖離がある場合は、実データに対しての性能は低下してしまうので注意。

まとめ

Model Soupsについて簡単に説明し、コード例を紹介しました。実装は簡単なのに、かなり効果のある手法です。また、画像認識、自然言語処理、音声認識など、様々なディープラーニングタスクで使うことが可能と、汎用性の高いテクニックです。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました