残差接続による勾配消失改善を確認してみる【初級 深層学習講座】

DeepLearningには、残差接続を活用して勾配消失を改善する手法があります。この記事では、残差接続の効果を実際のコードで確認してみます。
残差接続(Residual Connection)とは
残差接続は、スキップ接続(スキップコネクション)とも呼ばれるもので、各階層の出力に元の入力をそのまま加えることで、ネットワークがより学習しやすくなるよう工夫された構造です。特にResNet(Residual Network)というネットワークで有名になり、深層モデルでの性能向上に大きく貢献しました。
一般的な形式は以下のとおりです。
出力 = F(x) + x
上の式の x
は入力、F(x)
は通常のニューラルネットワーク層の変換です。出力は層の出力です。式の通りネットワークの出力と入力を加えたものを出力します。これにより、ネットワークは「差分(=残差)」だけを学習すればよくなり、勾配が伝播しやすくなります。
ニューラルネットワークが深くなると、勾配消失問題が発生しやすくなります。これは、逆伝播で勾配が繰り返し小さくなり、初期の層にほとんど伝わらなくなる現象です。その結果、深い層のパラメータがほとんど更新されなくなり、学習が進まなくなるという問題がありました。
残差接続は、これを解消するために導入されました。残差接続の効果は以下になります。
- 勾配が直接流れるルートを持つ:スキップ接続により、勾配が間接的にでも深い層まで到達しやすくなります。これにより、学習が安定します
実際に、ディープニューラルネットワークでは層を深くすれば性能が下がることもありました。しかしながら、残差接続を導入したResNetでは、非常に深いネットワーク(50層、101層、さらにそれ以上)でも実用的な学習を可能とすることができました。
残差接続の効果を確認(プログラム)
ここでは、実際にPyTorchを使って、残差接続が勾配に与える影響を確認してみます。注目するのは、学習時に逆伝播される重みの勾配の大きさです。これは、ネットワークがどれだけ活発に学習しているかの目安になります。
モデルの構造
以下のコードでは、TestModel
というシンプルな5層の全結合ニューラルネットワークを定義しています。引数 use_shortcut
によって、残差接続(ショートカット)を使うかどうかを切り替えることができます。このモデルを使って、残差接続ありと残差接続なしのモデルの比較を行いたいと思います。
import torch
import torch.nn as nn
class TestModel(nn.Module):
def __init__(self, in_features, out_features, use_shortcut):
super().__init__()
self.use_shortcut = use_shortcut
self.layers = nn.ModuleList([
nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
nn.Sequential(nn.Linear(in_features, out_features), nn.Sigmoid()),
])
def forward(self, x):
for layer in self.layers:
layer_output = layer(x)
if self.use_shortcut and x.shape == layer_output.shape:
x = x + layer_output
else: # 最終段はショートカットせずに出力する
x = layer_output
return x
このモデルは以下のような構造になっています。
- 各層は
nn.Linear
とnn.Sigmoid
を組み合わせた全結合層 - 残差接続が有効な場合、各層の出力に入力
x
を足し合わせる形でスキップ接続が入る - 最後の層(入力と出力のサイズが異なる最終層)は、スキップ接続は行わない
順伝播と逆伝播の実行
do_and_print()
は、勾配を計算し、表示する関数です。
def do_and_print(model, x, target):
model.train()
output = model(x)
loss = nn.CrossEntropyLoss()
loss = loss(output, target)
loss.backward()
for name, param in model.named_parameters():
if 'weight' in name:
print(f"{name} : {param.grad.abs().mean().item()}")
この関数は以下の処理を行います
- 入力
x
を使ってモデルを順伝播させ、出力を計算する CrossEntropyLoss
を用いて損失を計算するloss.backward()
で逆伝播を実行し、重みの勾配を求める- 各層の重みについて、勾配の平均絶対値を出力する
厳密には異なりますが、出力した勾配が大きい=その層のパラメータが更新されやすいと捉えてOKです。
モデルの比較
ショートカットあり、なしの2つのモデルを生成し、同じ値を入力して勾配をチェックします。
コードは以下になります。
ショートカットありのモデル
# クラス分類のサンプル
sample_input = torch.tensor([[1.0, 0.4, 0.8]])
target = torch.tensor([[1., 0.]])
torch.manual_seed(123)
model_without_shortcut = TestModel(3, 2, use_shortcut=False)
print("Model without shortcut:")
do_and_print(model_without_shortcut, sample_input, target)
ショートカットなしのモデル
torch.manual_seed(123)
model_without_shortcut = TestModel(3, 2, use_shortcut=True)
print("Model with shortcut:")
do_and_print(model_without_shortcut, sample_input, target)
プログラム全体
今回のプログラム全体です。実行する場合には、これをコピーして動かしてください。
import torch
import torch.nn as nn
class TestModel(nn.Module):
def __init__(self, in_features, out_features, use_shortcut):
super().__init__()
self.use_shortcut = use_shortcut
self.layers = nn.ModuleList([
nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
nn.Sequential(nn.Linear(in_features, out_features), nn.Sigmoid()),
])
def forward(self, x):
for layer in self.layers:
layer_output = layer(x)
if self.use_shortcut and x.shape == layer_output.shape:
x = x + layer_output
else: # 最終段はショートカットせずに出力する
x = layer_output
return x
def do_and_print(model, x, target):
model.train()
output = model(x)
loss = nn.CrossEntropyLoss()
loss = loss(output, target)
loss.backward()
for name, param in model.named_parameters():
if 'weight' in name:
print(f"{name} : {param.grad.abs().mean().item()}")
# クラス分類のサンプル
sample_input = torch.tensor([[1.0, 0.4, 0.8]])
target = torch.tensor([[1., 0.]])
torch.manual_seed(123)
model_without_shortcut = TestModel(3, 2, use_shortcut=False)
print("Model without shortcut:")
do_and_print(model_without_shortcut, sample_input, target)
torch.manual_seed(123)
model_without_shortcut = TestModel(3, 2, use_shortcut=True)
print("Model with shortcut:")
do_and_print(model_without_shortcut, sample_input, target)
実行結果
プログラムを実行すると以下のような結果が表示されます。
値は、逆伝播による伝わった各層の勾配の平均絶対値です。layers.0.0
が一番上の層(入力に近い層)、layers.4.0
が下層(出力に近い層)になります。
ショートカットありのモデル
Model without shortcut:
layers.0.0.weight : 1.658325345488265e-05
layers.1.0.weight : 6.265898264246061e-05
layers.2.0.weight : 0.0005473571945913136
layers.3.0.weight : 0.0031391088850796223
layers.4.0.weight : 0.05082247778773308
上に行くほど勾配が小さくなっていることがわかります。layers.0.0
は勾配が1.7e-5程度とかなり小さくなっていることがわかります。
ショートカットなしのモデル
Model with shortcut:
layers.0.0.weight : 0.0013392833061516285
layers.1.0.weight : 0.0017523732967674732
layers.2.0.weight : 0.0020216701086610556
layers.3.0.weight : 0.0027825217694044113
layers.4.0.weight : 0.16284583508968353
上に行くほど勾配は小さくなっていますが、ショートカットなしほど急ではないことがわかります。
実行結果のまとめ
実行結果から、以下のようなことがわかるかと思います
- ショートカットなしのモデルは、深い層ほど勾配が小さくなっている(≒勾配消失)。
- ショートカットありのモデルは、全体的に勾配が大きく保たれており、初期層にもきちんと伝わっている。
この違いが、残差接続の「勾配が伝播しやすくなる」効果です。
まとめ
この記事では、残差接続の有効性について、実際のPyTorchプログラムを使って確認しました。
ディープニューラルネットワークにおいては、学習が深い層で止まってしまう「勾配消失」が大きな課題のひとつです。しかし、残差接続(スキップ接続)を導入することで、勾配が浅い層まで届きやすくなり、学習が安定・加速されます
今回は、勾配の大きさをプログラムで確認することで、「理論的な知識」を「実感」できたのではないでしょうか。このように仕組みを実験ベースで確かめていくことも、深層学習を理解する1つの方法です。