初級ディープラーニング
記事内に商品プロモーションを含む場合があります

「KLダイバージェンスとは何か?」を数式と具体例で解説【初級 深層学習講座】

Aru

ディープラーニングで(VAE(変分オートエンコーダー)など)で用いられるKLダイバージェンスについて調べ、まとめてみました。KL情報量は理解が難しいですが、この記事では具体例でをわかりやすく解説します。Pytorchには、KLDivLossが用意されているので簡単に利用することも可能です。

KLダイバージェンス

KLダイバージェンスのとは

KLダイバージェンス(カルバック・ライブラー情報量、Kullback-Leibler Divergence)は、2つの確率分布 Pと Qの間の距離を測る指標です。

KLダイバージェンスの指揮

その式は以下のようになります

連続確率分布の場合

$$
D_{\text{KL}}(P \parallel Q) = \int_{-\infty}^{\infty} p(x) \log \frac{p(x)}{q(x)} \, dx
$$

離散確率分布の場合

$$
D_{\text{KL}}(P \parallel Q) = \sum_{x \in X} P(x) \log \frac{P(x)}{Q(x)}
$$

ここで、

  • $P$と$Q$は、2つの異なる確率分布です
  • $x$は、確率分布が定義されている事象(例えば、サイコロの目)
  • $P(x)$は、確率分布$P$における$x$の確率
  • $Q(x)$は、確率分布$Q$ における$x$の確率
  • $log$は、対数関数

式の意味を理解したい

  1. 比率の計算
    $ \frac{P(x)}{Q(x)}$の部分は、事象 $x$ の確率が $P$ と $Q$ でどれくらい違うかを示しています。例えば、サイコロで1が出る確率が $P$ では0.2、$Q$ では0.1なら、この比率は2になります。逆にサイコロで1が出る確率が $P$ では0.1、$Q$ では0.2なら、この比率は0.5になります。
  2. 対数を取る
    $\log \frac{P(x)}{Q(x)}$は、その比率の対数を取っています。対数をとると$P(X)$と$ Q(x)$が同じ場合は0になります。また、$P(x)<Q(x)$であれば負に、$P(x)>Q(x)$であれば正になります。
  3. 重み付けして足し合わせる
    $P(x) \log \frac{P(x)}{Q(x)}$ は、対数の値に$P(x)$を掛けたものです。これは、その事象がどれくらい重要か(頻度が高いか)を考慮していることになります。これを、すべての$x$ についてこの値を足し合わせます。

このような計算をすることで、$P$と$Q$が同じであればゼロ(0)、異なるほど大きな値になります。つまり、2つの確率分布$P$と$Q$の距離を測っていることになります。

感覚としては、2つの確率分布$P$と$Q$の違いを$log$で計算して、重要度($P$の大きさ)で足し込んだものです。確率の大きな事象の差は大きく影響し、確率が小さい事象の差は影響が小さくなるので、結果として確率を加味した重みになっているのがKLダイバージェンスというのが、私自身の理解です。

わかりやすいので離散確率分布で説明しますが、連続でも考え方は同じです。

簡単な例で考える

例えば、コインを投げる場合を考えてみます。この場合のKLダイバージェンスは、どれくらい理想と現実が違うかを測るものです。 KLダイバージェンスが0に近いほど、2つの確率分布は似ています。値が大きいほど、違いが大きいことを示します。

理想のコインと実際のコインの確率

  • $P$:理想的なコインで、表と裏がそれぞれ50%の確率で出る
    $P(\text{表}) = 0.5$
    $P(\text{裏}) = 0.5$
  • $Q$:実際のコインで、表が70%、裏が30%の確率で出る
    $Q(\text{表}) = 0.7$
    $Q(\text{裏}) = 0.3$

KLダイバージェンスを計算する

  1. 表の場合
    $\frac{P(表)}{Q(表)}= \frac{0.5}{0.7} \simeq 0.714$
    $log\frac{P(表)}{Q(表)} \simeq log(0.714) \simeq -0.337$
    $P(表)\times(-0.337) \simeq -0.169$
  2. 裏の場合
    $\frac{P(裏)}{Q(裏)}= \frac{0.5}{0.3} \simeq 1.667$
    $\frac{P(裏)}{Q(裏)} \simeq log(1.667) \simeq 0.511$
    $P(裏)\times(0.511) \simeq 0.256$
  3. 表と裏を合計する
    $D_{\text{KL}}(P \parallel Q) = -0.169 + 0.256 = 0.087$

以上のように、KLダイバージェンスは0.087となります。この値が0に近いほど、2つの確率分布$P$と$Q$は似ていることを示します。

PyTorchで確認

PyTorchでは、上記の「コインを投げる場合」と同じP,Qで計算してみます。結果は、0.0872とほぼ同じになりました。

P = torch.Tensor([0.5,0.5])
Q = torch.Tensor([0.7,0.3])
(P * (P / Q).log()).sum() 
# tensor(0.0872)

なお、PyTorchではnn.KLDivLossという損失関数が用意されているので、KLダイバージェンスロスを簡単に計算することが可能です。

使い方は以下のようになります。参考にしてください。

P = torch.Tensor([0.5,0.5])
Q = torch.Tensor([0.7,0.3])
kldiv = nn.KLDivLoss(reduction="sum")
kldiv(Q.log(), P)
# tensor(0.0872)

まとめ

KLダイバージェンスの計算方法について解説しました。式自身は簡単ですが、慣れるまで大変です。

初級 深層学習講座シリーズはこちら
ディープラーニングに関する記事一覧はこちら
ディープラーニング関連の記事一覧
ディープラーニング関連の記事一覧

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました