「KLダイバージェンスとは何か?」を数式と具体例で解説【初級 深層学習講座】
ディープラーニングで(VAE(変分オートエンコーダー)など)で用いられるKLダイバージェンスについて調べ、まとめてみました。KL情報量は理解が難しいですが、この記事では具体例でをわかりやすく解説します。Pytorchには、KLDivLossが用意されているので簡単に利用することも可能です。
KLダイバージェンス
KLダイバージェンスのとは
KLダイバージェンス(カルバック・ライブラー情報量、Kullback-Leibler Divergence)は、2つの確率分布 Pと Qの間の距離を測る指標です。
KLダイバージェンスの指揮
その式は以下のようになります
連続確率分布の場合
$$
D_{\text{KL}}(P \parallel Q) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} \, dx
$$
離散確率分布の場合
$$
D_{\text{KL}}(P \parallel Q) = \sum_{x \in X} P(x) \log \frac{P(x)}{Q(x)}
$$
ここで、
- $P$と$Q$は、2つの異なる確率分布です
- $x$は、確率分布が定義されている事象(例えば、サイコロの目)
- $P(x)$は、確率分布$P$における$x$の確率
- $Q(x)$は、確率分布$Q$ における$x$の確率
- $log$は、対数関数
式の意味を理解したい
- 比率の計算
$ \frac{P(x)}{Q(x)}$の部分は、事象 $x$ の確率が $P$ と $Q$ でどれくらい違うかを示しています。例えば、サイコロで1が出る確率が $P$ では0.2、$Q$ では0.1なら、この比率は2になります。逆にサイコロで1が出る確率が $P$ では0.1、$Q$ では0.2なら、この比率は0.5になります。 - 対数を取る
$\log \frac{P(x)}{Q(x)}$は、その比率の対数を取っています。対数をとると$P(X)$と$ Q(x)$が同じ場合は0になります。また、$P(x)<Q(x)$であれば負に、$P(x)>Q(x)$であれば正になります。 - 重み付けして足し合わせる
$P(x) \log \frac{P(x)}{Q(x)}$ は、対数の値に$P(x)$を掛けたものです。これは、その事象がどれくらい重要か(頻度が高いか)を考慮していることになります。これを、すべての$x$ についてこの値を足し合わせます。
このような計算をすることで、$P$と$Q$が同じであればゼロ(0)、異なるほど大きな値になります。つまり、2つの確率分布$P$と$Q$の距離を測っていることになります。
感覚としては、2つの確率分布$P$と$Q$の違いを$log$で計算して、重要度($P$の大きさ)で足し込んだものです。確率の大きな事象の差は大きく影響し、確率が小さい事象の差は影響が小さくなるので、結果として確率を加味した重みになっているのがKLダイバージェンスというのが、私自身の理解です。
わかりやすいので離散確率分布で説明しますが、連続でも考え方は同じです。
簡単な例で考える
例えば、コインを投げる場合を考えてみます。この場合のKLダイバージェンスは、どれくらい理想と現実が違うかを測るものです。 KLダイバージェンスが0に近いほど、2つの確率分布は似ています。値が大きいほど、違いが大きいことを示します。
理想のコインと実際のコインの確率
- $P$:理想的なコインで、表と裏がそれぞれ50%の確率で出る
$P(\text{表}) = 0.5$
$P(\text{裏}) = 0.5$ - $Q$:実際のコインで、表が70%、裏が30%の確率で出る
$Q(\text{表}) = 0.7$
$Q(\text{裏}) = 0.3$
KLダイバージェンスを計算する
- 表の場合
$\frac{P(表)}{Q(表)}= \frac{0.5}{0.7} \simeq 0.714$
$log\frac{P(表)}{Q(表)} \simeq log(0.714) \simeq -0.337$
$P(表)\times(-0.337) \simeq -0.169$ - 裏の場合
$\frac{P(裏)}{Q(裏)}= \frac{0.5}{0.3} \simeq 1.667$
$\frac{P(裏)}{Q(裏)} \simeq log(1.667) \simeq 0.511$
$P(裏)\times(0.511) \simeq 0.256$ - 表と裏を合計する
$D_{\text{KL}}(P \parallel Q) = -0.169 + 0.256 = 0.087$
以上のように、KLダイバージェンスは0.087となります。この値が0に近いほど、2つの確率分布$P$と$Q$は似ていることを示します。
PyTorchで確認
PyTorchでは、上記の「コインを投げる場合」と同じP,Qで計算してみます。結果は、0.0872とほぼ同じになりました。
P = torch.Tensor([0.5,0.5])
Q = torch.Tensor([0.7,0.3])
(P * (P / Q).log()).sum()
# tensor(0.0872)
なお、PyTorchではnn.KLDivLoss
という損失関数が用意されているので、KLダイバージェンスロスを簡単に計算することが可能です。
使い方は以下のようになります。参考にしてください。
P = torch.Tensor([0.5,0.5])
Q = torch.Tensor([0.7,0.3])
kldiv = nn.KLDivLoss(reduction="sum")
kldiv(Q.log(), P)
# tensor(0.0872)
まとめ
KLダイバージェンスの計算方法について解説しました。式自身は簡単ですが、慣れるまで大変です。