アンサンブルと違って処理時間が増えない精度改善手法:Model Soups
はじめに
ここでは、複数のファインチューニングした同一モデルの「重み」を平均化して精度向上を図る、「Model Soups」を紹介します。
上の文言だけ読むと「アンサンブル」と同じように感じると思いますが、パラメータを平均化したモデルを作るので、アンサンブルのように処理量が増えないという点で異なります。
処理量が増えない(=重くならない)というのは利点だよね
実際に利用したのは、1、2回ですが、効果ありでした。特に、アンサンブルしたら性能が上がるのに、処理速度の関係でアンサンブルは厳しいといった場面で使える気がします。
Model Soupsの簡単な説明
Model Soupsは、異なるハイパーパラメータで学習した複数のファインチューニングモデルの「重み(Weight)」と平均化することで、精度を向上させるテクニックです。
Model Soupsには、以下のような特徴があります。
- アンサンブルとは異なる: Model Soupsは、推論時に余分な計算を必要とせず、多くのモデルを平均化することができる
- 精度とロバスト性の向上: 複数のモデルの重みを平均化することで、個々のモデルの弱点やバイアスを克服し、より高い精度とロバスト性を達成することができる
- 効率的な学習: 複数のモデルを個別に学習するよりも効率的に学習することができる
アンサンブルと異なり、推論時に処理が増えないのが一番の利点だと思います。
論文リンク
Model soups: averaging weights of multiple fine-tuned models improves accuracy without increasing inference time:https://arxiv.org/abs/2203.05482
実装
ヘルパー感数
Model Soupsを行うためのヘルパー関数です。
以下は、pytorchでの実装例です。
def read_model(file):
# ファイル名のモデルを読み込む(ここはモデル毎に作成)
return model
def copy_params(model1, model2) :
model1.load_state_dict(model2.state_dict())
return model1
def sum_model_params(modelA, modelB):
""" modelA + modelB """
sdA = modelA.state_dict()
sdB = modelB.state_dict()
for key in sdA:
sdA[key] = (sdA[key] + sdB[key])
modelA.load_state_dict(sdA)
return modelA
def mul_model_params(model, a):
""" a * model """
sd = model.state_dict()
for key in sd:
sd[key] = sd[key] * a
model.load_state_dict(sd)
return model
それぞれの関数の機能は以下の通りです。
関数名 | 機能 |
read_model(file) | file で指定されたモデルを読み込む。モデル毎にカスタムする必要があります。 |
copy_params(model1, model2) | model2 の重みをmodel1 へコピーする |
sum_model_params(modelA, modelB) | modelA にmodelB の重みを足す |
mul_model_params(model, a) | model の重みをa 倍する |
これらの関数を作っておけば、Model Soupsはこれらの関数の組みわせで実現できます。
モデルAとモデルBを平均化する場合
例えば、モデルAとモデルBを平均化する場合は、以下のようなコードになります。
modelA = read_model('modelA.pth')
modelB = read_model('modelB.pth')
model = sum_model_params(modelA, modelB)
model = mul_model_params(model, 1/2) # 2つ合成したので、1/2にする
N個のモデルの組み合わせからベストのモデルを作成する場合
重みの異なるN個のモデルがある場合は、以下のコードで最適な組み合わせをチェックし、平均化することができます。
以下は、リストmodels
にN個のモデルが格納されている場合の例です。なお、calc_score()
は検証データを使った評価を行う関数とします。
以下は、コード例です。
n = len(models)
best_model = models[0]
best_score = -1e18 # 最小値以下を設定
for bit in range(1, 1<<n):
pat = []
for i in range(n):
if (bit>>i)%2 == 1:
pat.append(i)
model = models[pat[0]]
print(format(bit, '04b'), pat, len(pat))
for i in range(1, len(pat)):
model = sum_model_params(model, models[pat[i]])
print(pat[i])
model = mul_model_params(model, 1/len(pat))
score = calc_score(model)
if score > best_score :
best_model = model
コード自身は単純で、全てのモデルの組み合わせを作成し、それぞれのスコアをチェックしています。
calc_score
の処理時間(検証データの処理時間)次第ですが、モデルの組み合わせ8個程度までは実用的な速度でチェックできると思います。
組み合わせパターンは、それぞれのモデルを使う・使わないで$O(2^n)$となるので注意が必要です。N=8の場合は255パターンですが、N=16の場合は65535パターンのチェックを行うことになります。
また、検証データで最も良いものを探索するので、検証データがモデルにリークしてしまう可能性があることに注意が必要です。
検証データに対して最適化しまうのは、しょうがないところです。実際のデータと検証データの集合に乖離がある場合は、実データに対しての性能は低下してしまうので注意。
まとめ
Model Soupsについて簡単に説明し、コード例を紹介しました。実装は簡単なのに、かなり効果のある手法です。また、画像認識、自然言語処理、音声認識など、様々なディープラーニングタスクで使うことが可能と、汎用性の高いテクニックです。