Batch Normalizationの仕組みと学習・推論時の動作の違いを理解する 【初級 深層学習講座】
バッチ正規化(Batch Normalization)は、ディープラーニングで頻繁に使用される重要なテクニックの1つです。しかし、学習時と推論時での動作の違いを理解していない方も多いかもしれません。本記事では、初心者向けにこの動作の違いを解説します。さらに、Pytorchを用いた実際のコードを通じてバッチ正規化の動きを説明します。
とりあえず、「学習時と推論時の動きの違い」だけを知りたい方は、ここをクリックしてください。
Batch Normalizationとは
バッチ正規化(バッチノーマライゼーション, Batch Normalization)は、ディープラーニングにおける重要なテクニックの1つです。この手法は、ネットワークの学習を安定化させ、収束を早め、モデルの性能を向上させるのに役立ちます。
ディープラーニングモデルが層が深くなると、勾配消失(vanishing gradients)や勾配爆発(exploding gradients)などの問題が発生しやすくなります。これは、層を重ねると勾配が非常に小さくなったり、非常に大きくなったりするためです。これらの問題が発生すると、モデルの訓練が難しくなり、学習が遅くなります。
バッチ正規化では、各層の出力を平均と分散で正規化することにより、出力のスケールが調整されます。これにより、勾配の大きさが一定の範囲に収まり、勾配消失や勾配爆発の問題を軽減することができます。
バッチ正規化の基本的な操作は以下の通りです
- ミニバッチ内の各層の出力値に対して平均と分散を計算する
- 平均と分散を使って、各層の出力を正規化する(同一チャネルが平均0、分散1になるように正規化する)
これらの操作について、以下のような疑問が出る人もいると思います
- ミニバッチのサイズが小さい場合、平均と分散が安定しないのでは?
- 推論時も入力に対して、平均と分散を計算するの?
以下、Pytorchを使って動きを確認し、実際にどう動くのかを理解していきます。
Pytorchで動きを確認する
Pytorchのインポート
まず、pytorchをインポートします。インポートは以下のコードで行います。2行目はtorch.nnを単にnn
と書いて利用するための宣言です。
import torch
from torch import nn
データの準備
バッチ正規化のテストのためのデータを作成します。
データは32個×2チャネルのデータです。
この例では、バッチサイズが32で入力データが2チャネルのデータに相当します。
データの中身はなんでもよいので、とりあえず、torch.rand
関数を使ってランダムな値を32×2個生成します。
以下は、ランダムな値の32行2列のデータを作成するプログラムです
input_samples = 32
input_features = 2
x = torch.rand((input_samples,input_features))
print(x)
出力は以下のようになります(ランダムなので毎回値は変化します)。
tensor([[0.7393, 0.7103],
[0.1955, 0.5102],
[0.9233, 0.2299],
[0.6489, 0.1678],
[0.6100, 0.2628],
[0.4312, 0.6737],
[0.6468, 0.7555],
[0.3580, 0.2969],
[0.3530, 0.1765],
[0.8987, 0.0702],
[0.6175, 0.7216],
[0.0066, 0.0452],
[0.9841, 0.4782],
[0.9149, 0.2342],
[0.2152, 0.3166],
[0.7439, 0.2375],
[0.2060, 0.4124],
[0.3658, 0.3167],
[0.2130, 0.2546],
[0.7829, 0.3622],
[0.0338, 0.3052],
[0.9881, 0.0938],
[0.7171, 0.4599],
[0.1871, 0.5133],
[0.9776, 0.1292],
[0.3541, 0.3768],
[0.6004, 0.8655],
[0.6294, 0.6045],
[0.1495, 0.7324],
[0.2257, 0.3083],
[0.6929, 0.3413],
[0.2750, 0.3753]])
上記のデータの平均と分散を計算してみます。計算は以下のコードで行うことができます。
なお、2つのチャネルは個別に平均と分散を計算しています。
print(torch.mean(x, 0), torch.var(x, 0))
tensor([0.5214, 0.3856]) tensor([0.0906, 0.0476])
出力結果から、1つ目のチャネルの平均が0.5214, 分散が0.0906であることが分かります。また、2つ目のチャネルは平均が0.3856で分散が0.0476です。
この平均と分散は一旦覚えておいてください。
学習時の動き
まずBatchNorm1d
のオブジェクトbatch_norm
を生成し、trainモード(batch_norm.train()
)にします。
以下のコードでオブジェクトの生成と、学習モードへの切り替えが可能です。
batch_norm=nn.BatchNorm1d(input_channel)
batch_norm.train()
ここで、新たに生成したbatch_norm
の平均と分散の初期値を見ておきます。
batch_norm
の平均と分散は、batch_norm.running_mean
とbatch_norm.running_var
という2つの変数に格納されています。
以下のように、初期化状態では、平均は0で、分散は1です。
batch_norm
の初期の平均と分散を表示print(batch_norm.running_mean, batch_norm.running_var)
出力:tensor([0., 0.]) tensor([1., 1.])
では、batch_norm
を何度か呼び出して、内部の平均と分散がどのように変化するか観察します。
以下のコードでbatch_norm
を繰り返しながら、内部の状態を表示できます。ここでは、100回実行を繰り返しています。
for i in range(100) :
y = batch_norm(x)
print(batch_norm.running_mean, batch_norm.running_var)
tensor([0.0521, 0.0386]) tensor([0.9091, 0.9048])
tensor([0.0991, 0.0733]) tensor([0.8272, 0.8190])
tensor([0.1413, 0.1045]) tensor([0.7536, 0.7419])
tensor([0.1793, 0.1326]) tensor([0.6873, 0.6725])
: (途中略)
tensor([0.5214, 0.3856]) tensor([0.0907, 0.0476])
tensor([0.5214, 0.3856]) tensor([0.0907, 0.0476])
tensor([0.5214, 0.3856]) tensor([0.0907, 0.0476])
先ほど計算したように、入力データの平均と分散はtensor([0.5214, 0.3856])
tensor([0.0906, 0.0476])
ですが、1回目から4回目の実行では、内部の平均と分散が入力データの平均と分散とは異なっていることがわかります。
しかし、実行を繰り返すと徐々に値は変化し、最終的には入力データの平均と分散に一致します。
このように、内部の平均と分散は、入力に合わせてすぐに変化するわけではなく、徐々に変化していくことがわかりました。
ここで、バッチ正規化された値y
を見てみます。y
を確認するコードは以下のようになります。なお、コードでは出力が多すぎて分かりにくくならないように、y[:2]
として、先頭の2つだけ表示させています。
上のコードを動かした後は、batch_norm
の内部の変数の状態は変化してしまっているので、新たに初期化して実行しています
batch_norm=nn.BatchNorm1d(input_channel)
batch_norm.train()
for i in range(100) :
y = batch_norm(x)
print(y[:2], torch.mean(y,0), torch.var(y, 0))
tensor([[ 0.7353, 1.5126],
[-1.0998, 0.5805]], grad_fn=<SliceBackward0>) tensor([-3.2037e-07, 2.9802e-08], grad_fn=<MeanBackward1>) tensor([1.0321, 1.0320], grad_fn=<VarBackward0>)
tensor([[ 0.7353, 1.5126],
[-1.0998, 0.5805]], grad_fn=<SliceBackward0>) tensor([-3.2037e-07, 2.9802e-08], grad_fn=<MeanBackward1>) tensor([1.0321, 1.0320], grad_fn=<VarBackward0>)
: (途中略)
tensor([[ 0.7353, 1.5126],
[-1.0998, 0.5805]], grad_fn=<SliceBackward0>) tensor([-3.2037e-07, 2.9802e-08], grad_fn=<MeanBackward1>) tensor([1.0321, 1.0320], grad_fn=<VarBackward0>)
tensor([[ 0.7353, 1.5126],
実行結果を見ると、BatchNorm1d
で変換された出力値y
は1回目から100回目まで同じ値です。
また、出力されたy
の平均は[-3.2037e-07, 2.9802e-08]
、分散は[1.0321, 1.0320]
と、ほぼ平均0、分散1に正規化されています。
このように、学習時はバッチ毎に平均0、標準偏差1になるように正規化を行いつつ、内部で保持する平均と分散を徐々に変化させていることがわかります。
なお、この、内部で保持している平均と分散は学習時には利用されず、推論時に用いられることになります。
機械学習や統計学では、平均を0に、標準偏差を1にする操作は「標準化」と呼ばれています。やっていることは、標準化(Standarization)なのに、なんで正規化(Normalization)と呼ぶのか? 理由は、論文にそう書いてあったからだと思います。気にはなりますが、バッチ正規化という用語だと思った方が幸せです。
推論時の動き
次に、推論時の動きを確認してみます
推論時の動きを確認するために、batch_norm.eval()
で推論モードに変更して、同様に何度か繰り返してみます。
なお、入力の平均と分散も変化させたいので、新しく乱数を生成しています。
input_samples = 32
input_features = 2
x = torch.rand((input_samples,input_channel))
print("x:", torch.mean(x, 0), torch.var(x, 0))
batch_norm.eval()
for i in range(10) :
y = batch_norm(x[16:])
print(batch_norm.running_mean, batch_norm.running_var, torch.mean(y,0), torch.var(y, 0))
x: tensor([0.5036, 0.3448]) tensor([0.0533, 0.0589])
tensor([0.5214, 0.3856]) tensor([0.0907, 0.0476]) tensor([-0.2178, 0.0130], grad_fn=<MeanBackward1>) tensor([0.4012, 0.8329], grad_fn=<VarBackward0>)
: (途中略)
tensor([0.5214, 0.3856]) tensor([0.0907, 0.0476]) tensor([-0.2178, 0.0130], grad_fn=<MeanBackward1>) tensor([0.4012, 0.8329], grad_fn=<VarBackward0>)
入力の平均と分散はtensor([0.5036, 0.3448]) tensor([0.0533, 0.0589])
です。
一方、batch_normの平均(running_mean)と分散(running_var)はtensor([0.5214, 0.3856]) tensor([0.0907, 0.0476])
と、入力の平均と分散とはずれています。
推論時は、running_mean
とrunning_var
を使って正規化するため、バッチ正規化された後のyの平均は[-0.2178, 0.0130]
、分散は[0.4012, 0.8329]
となり、平均0、分散1となっていません。
このように、推論時には、学習時に計算した平均と分散が使われるため、バッチ正則化後の平均と分散は0と1にはならないことがわかります。
BatchNormの層を含むモデルの場合は、eval()
を忘れないようにしないといけません。
バッチサイズ=1の時はどうなる?
バッチサイズが1の場合は、どうなるかも確認しました。学習時にバッチサイズが1だとエラーが発生するようです。なお、batch_norm.eval()
で実行した場合はエラーは発生しません。
batch_norm=nn.BatchNorm1d(input_channel)
batch_norm.train()
y = batch_norm(x[:1])
y
ValueError: Expected more than 1 value per channel when training, got input size torch.Size([1, 2])
ここの動きは結構複雑で、バッチサイズ16などを処理した後にバッチサイズ1を処理してもエラーは発生しません。例えば全体で33個のデータがある場合、バッチサイズ16で処理すると、16個+16個+1個となり、最後のバッチは1個になるので、こういう場合にエラーにならないようになっているようです。
学習時の動きと推論時の動き
以上、バッチ正則化の動きをプログラムで確認しながら見てきました。Batch Normalizationの動きをまとめると、以下のようになります。
- バッチ毎に、入力の平均と分散で正規化を行う。出力の平均は0、分散は1になる
- 内部に保持するrunning_meanとrunning_varを更新する
- 内部に保持するrunning_meanとrunning_varで正規化を行う
- 出力の平均が0、分散が1になるとは限らない
動きとしては、学習時にバッチ毎に内部の平均と分散を補正することで、学習データ全体の平均と分散を計算し、その値を使って推論時の正規化を行います。
バッチ正規化の動きを考えると、学習時のバッチサイズはある程度大きくした方が、安定した平均と分散になりそうです。
まとめ
学習時と推論時のバッチ正規化の動きについて解説しました。実際にコードを動かしてみたので、わかりやすかったのではないかと思います。
結局、学習時に内部に平均と分散の移動平均をとっておいて、推論時はその値で正規化するという動きをしているようです。
自分も以前は「推論時はどうやってる?」という疑問がありました。この記事が疑問の解消の助けになれば幸いです。