PyTorchの逆伝播を関数を自作して理解する【初級 深層学習講座】
PyTorchの勾配計算(backward関数の処理)を自作することで、逆伝播のメカニズムを理解することができます。この記事では、ニューラルネットワークの基礎となる逆伝播を、理論だけでなく実際にコーディングすることで、微分や勾配計算がどのように実装されているかを確認し、理解を深めることを目指しています。若干、数学的な要素も含まれており、苦手意識のある方も多いかと思いますが、なるべく直感的に理解できるように解説を心がけました。「なんとなくわかった」でもよいので、理解の助けになれば幸いです。
逆伝播・勾配計算とは?
ディープラーニングの学習過程で、「逆伝播(backward)」や「勾配計算」といった用語に直面したことはないでしょうか。勾配計算は、偏微分などの数学的な知識が絡むため、ここでつまづく方も多いかと思います。
逆伝播(バックプロパゲーション)は、ニューラルネットワークの学習手法の1つで、正解との誤差をネットワークの各層に逆方向に伝播させ、誤差が小さくなるようにモデルのパラメータを更新するための手法です。勾配計算は、各パラメータの微分を求めるプロセスで、これを使って損失関数(誤差の)最小化に向けてパラメータの調整が行われます。
PyTorchなどのライブラリを利用する場合には、勾配計算や逆伝播の処理はライブラリが自動的に行うため、ユーザが直接これらの処理内容を意識することはほとんどありません。例えば、PyTorchでは、backward
関数を呼び出すことで、これらの計算が自動的に実行されます。
しかし、これらのプロセスがどのように機能しているかを理解することは、深層学習の理論や実装を理解する上で重要です。
この記事では、自作関数を作りながら動作について確認していきたいと思います。
関数(y = x^2)を作成する
まず、ターゲットとなる関数を作ります。今回作成するのは、以下の関数です。
$$
y = x ^ 2
$$
これを$dx$で微分すると、以下になります。
$$
\frac{dy}{dx} = 2 x
$$
これを使って自作関数を作りたいと思います。自作関数を作る手順は以下になります。
torch.autograd.Function
クラスを継承するforward()
とbackward()
メソッドを用意するforward()
は、$y = …$の結果を出力backward()
は、出力の勾配$\frac{dL}{dy}$を受け取るので、入力の勾配$\frac{dL}{dx}$を返す
少し式が入って面倒ですが、$\frac{dL}{dy}$から入力の勾配$\frac{dL}{dx}$を計算する式は以下になります。
$$
\frac{dL}{\cancel{dy}} \frac{\cancel{dy}}{dx} = \frac{dL}{dx}
$$
$\frac{dL}{dy}$は引数として与えられ、$\frac{dy}{dx}$は$x$から計算する式が分かっているので、計算することが可能です。
以下が、$y=x^2$の関数になります。
# y = x**2, dy/dx = 2x
class custom_f(torch.autograd.Function ):
@staticmethod
def forward(self, x):
self.save_for_backward( x )
y = x * x
return y
@staticmethod
def backward(self, dL_dy): # dL_dy = dL/dy
x, = self.saved_tensors
dy_dx = 2*x
dL_dx = dL_dy * dy_dx
return dL_dx
最初の引数self
は、forward
時に勾配計算に必要な値を保持し、backward
時は保存した結果を取り出すためのものです。
PyTorchの公式の解説では、self
ではなくctx
という変数名ですが、私はPythonでよく使われるself
という変数名にしています。PyTorchの公式通りctx
の方が良いかもしれません。
forward()
では、$x$を保存し、$x^2$を計算して返しています。
backward()
では、$x$を取り出し$\frac{dy}{dx}$を計算したのちに、引数として渡された$\frac{dL}{dy}$と掛け合わせて$\frac{dL}{dx}$を計算し返しています。
以上で、関数の実装は完了です。
確認
実際にbackwardで勾配計算ができているか確認します
確認①
作成した関数custom_f
は、apply
しないと使えません。なので、f=custom_f.apply
として関数f
を定義しています。
次にy = f(x)
を計算し、L
に代入してbackward
しています。
これで逆伝播が実行されます。
x, y
の値を確認すると、$y = x^2 = 3^2 = 9$となり正しく計算されていることがわかります。
また、$\frac{dy}{dx} = 2x = 2\times 3 = 6$なので、x.grad
(xの勾配)も正しく計算されていることがわかります。
x = torch.Tensor([3.]).requires_grad_()
f = custom_f.apply
y = f(x)
L = y
L.backward()
#
print(x, y)
#tensor([3.], requires_grad=True) tensor([9.], grad_fn=<custom_fBackward>)
print(x.grad, 2*x)
#tensor([6.]) tensor([6.], grad_fn=<MulBackward0>)
確認②
少し式を複雑にして、$y = f(f(x))$をやってみます。
コードでは、y0 = f(x)
、y1 = f(y0)
として2回f()
を呼び出しています。
x, y0, y1
の値を確認すると、$y0 = x^2 = 3^2 = 9$, $y1 = y0^2 = 9^2 = 81$となり正しく計算されていることがわかります。
$y1 = f(f(x)) = (x^2)^2 = x^4$なので、$\frac{dy1}{dx} = 4 x^3$となります。
また、$\frac{dy1}{dx} = 4x^3 = 4\times 3^3 = 108$なので、x.grad
(xの勾配)も正しく計算されていることがわかります。
このように、複数回の関数呼び出しがある場合も、目的関数L
からxの勾配を正しく計算することができます。
x = torch.Tensor([3.]).requires_grad_()
f = custom_f.apply
y0 = f(x)
y1 = f(y0)
L = y1
L.backward()
print(x, y0, y1)
# tensor([3.], requires_grad=True) tensor([9.], grad_fn=<custom_fBackward>) tensor([81.], grad_fn=<custom_fBackward>)
print(x.grad, 4*(x**3))
# tensor([108.]) tensor([108.], grad_fn=<MulBackward0>)
学習時の動作
実際の学習では、目的関数Lは、正解との誤差(e)で、backward
で勾配を計算したのちに、SGD(確率的勾配降下法)などを使って値を徐々に更新していきます。
誤差を伝播させるためには、勾配計算が必要(偏微分が必要)になりますが、PyTorchではこの部分がライブラリに隠蔽されていますので、あまり意識することなく、forward
側だけ考えてモデルを作成することが可能です。
独自の関数を作成するときはbackward
を意識する必要があります
このあたりが、tensorflow/keras, pytorchなどのフレームワークを利用する利点の1つです。
まとめ
実務だと、逆伝播を意識することは少ないですが、ディープラーニングを学習する場合には必ず出てくる内容です。
この記事では、PyTorchの関数を実際につくって動かしてみることで、勾配計算の流れについて解説しました。