関数を自作してPyTorchの逆伝播を理解する【初級 深層学習講座】
![](https://tech.aru-zakki.com/wp-content/uploads/2024/03/pytorch-backward.001.jpeg)
PyTorchの勾配計算(backward関数)を自作することで、逆伝播の動きを理解しようという内容です。微分なども入ってくるのでちょっと面倒ですが、なんとなくでよいので理解の助けになればと思います。
はじめに
ディープラーニングの学習をすると、勾配とか誤差逆伝播などの部分が少し難しく感じるかと思います。特に、偏微分とか入ってくるので、ここでつまづく方も多いかと思います。
ただ、PyTorchなどを実際に利用する場合には、誤差逆伝播を自分で書くことはほとんどないので、意識することは少ないです(backward関数を呼び出すだけです)。
この記事では、自作関数を作りながら動作について確認していきたいと思います。
関数(y = x^2)を作成する
今回作成するのは、以下の関数です。
$$
y = x ^ 2
$$
これを$dx$で微分すると、以下になります。
$$
\frac{dy}{dx} = 2 x
$$
これを使って自作関数を作りたいと思います。自作関数を作る手順は以下になります。
torch.autograd.Function
クラスを継承するforward()
とbackward()
メソッドを用意するforward()
は、$y = …$の結果を出力backward()
は、出力の勾配$\frac{dL}{dy}$を受け取るので、入力の勾配$\frac{dL}{dx}$を返す
少し式が入って面倒ですが、$\frac{dL}{dy}$から入力の勾配$\frac{dL}{dx}$を計算する式は以下になります。
$$
\frac{dL}{\cancel{dy}} \frac{\cancel{dy}}{dx} = \frac{dL}{dx}
$$
$\frac{dL}{dy}$は引数として与えられ、$\frac{dy}{dx}$は$x$から計算する式がわかっているので、計算することが可能です。
以下が、$y=x^2$の関数になります。
# y = x**2, dy/dx = 2x
class custom_f(torch.autograd.Function ):
@staticmethod
def forward(self, x):
self.save_for_backward( x )
y = x * x
return y
@staticmethod
def backward(self, dL_dy): # dL_dy = dL/dy
x, = self.saved_tensors
dy_dx = 2*x
dL_dx = dL_dy * dy_dx
return dL_dx
最初の引数self
は、forward
時に勾配計算に必要な値を保持し、backward
時は保存した結果を取り出すためのものです。
![](https://tech.aru-zakki.com/wp-content/uploads/2023/06/tabbycat.png)
PyTorchの公式の解説では、self
ではなくctx
という変数名ですが、私はPythonでよく使われるself
という変数名にしています。PyTorchの公式通りctx
の方が良いかもしれません。
forward()
では、$x$を保存し、$x^2$を計算して返しています。
backward()
では、$x$を取り出し$\frac{dy}{dx}$を計算したのちに、引数として渡された$\frac{dL}{dy}$と掛け合わせて$\frac{dL}{dx}$を計算し返しています。
以上で、関数の実装は完了です。
確認
実際に、backwardで勾配計算ができているか確認します
確認①
作成した関数custom_f
は、apply
しないと使えません。なので、f=custom_f.apply
として関数f
を定義しています。
次にy = f(x)
を計算し、L
に代入してbackward
しています。
これで逆伝播が実行されます。
x, y
の値を確認すると、$y = x^2 = 3^2 = 9$となり正しく計算されていることがわかります。
また、$\frac{dy}{dx} = 2x = 2\times 3 = 6$なので、x.grad
(xの勾配)も正しく計算されていることがわかります。
x = torch.Tensor([3.]).requires_grad_()
f = custom_f.apply
y = f(x)
L = y
L.backward()
#
print(x, y)
print(x.grad, 2*x)
#tensor([3.], requires_grad=True) tensor([9.], grad_fn=<custom_fBackward>)
#tensor([6.]) tensor([6.], grad_fn=<MulBackward0>)
確認②
少し式を複雑にして、$y = f(f(x))$をやってみます。
コードでは、y0 = f(x)
、y1 = f(y0)
として2回f()
を呼び出しています。
x, y0, y1
の値を確認すると、$y0 = x^2 = 3^2 = 9$, $y1 = y0^2 = 9^2 = 81$となり正しく計算されていることがわかります。
$y1 = f(f(x)) = (x^2)^2 = x^4$なので、$\frac{dy1}{dx} = 4 x^3$となります。
また、$\frac{dy1}{dx} = 4x^3 = 4\times 3^3 = 108$なので、x.grad
(xの勾配)も正しく計算されていることがわかります。
このように、複数回の関数呼び出しがある場合も、目的関数L
からxの勾配を正しく計算することができます。
x = torch.Tensor([3.]).requires_grad_()
f = custom_f.apply
y0 = f(x)
y1 = f(y0)
L = y1
L.backward()
print(x, y0, y1)
print(x.grad, 4*(x**3))
# tensor([3.], requires_grad=True) tensor([9.], grad_fn=<custom_fBackward>) tensor([81.], grad_fn=<custom_fBackward>)
# tensor([108.]) tensor([108.], grad_fn=<MulBackward0>)
学習時の動作
実際の学習では、目的関数Lは、正解との誤差(e)で、backward
で勾配を計算したのちに、SGD(確率的勾配降下法)などを使って値を徐々に更新していきます。
誤差を伝播させるためには、勾配計算が必要(偏微分が必要)になりますが、PyTorchではこの部分がライブラリに隠蔽されていますので、あまり意識することなく、forward
側だけ考えてモデルを作成することが可能です。
![](https://tech.aru-zakki.com/wp-content/uploads/2023/06/tabbycat.png)
独自の関数を作成するときはbackward
を意識する必要があります
このあたりが、tensorflow/keras, pytorchなどのフレームワークを利用する利点の1つです。
まとめ
実務だと、逆伝播を意識することは少ないですが、ディープラーニングを学習する場合には必ず出てくる内容です。
この記事では、PyTorchの関数を実際につくって動かしてみることで、勾配計算の流れについて解説しました。