初級ディープラーニング
記事内に商品プロモーションを含む場合があります

残差接続による勾配消失改善を確認してみる【初級 深層学習講座】

Aru

DeepLearningには、残差接続を活用して勾配消失を改善する手法があります。この記事では、残差接続の効果を実際のコードで確認してみます。

残差接続(Residual Connection)とは

残差接続は、スキップ接続(スキップコネクション)とも呼ばれるもので、各階層の出力に元の入力をそのまま加えることで、ネットワークがより学習しやすくなるよう工夫された構造です。特にResNet(Residual Network)というネットワークで有名になり、深層モデルでの性能向上に大きく貢献しました。

一般的な形式は以下のとおりです。

出力 = F(x) + x

上の式の x は入力、F(x) は通常のニューラルネットワーク層の変換です。出力は層の出力です。式の通りネットワークの出力と入力を加えたものを出力します。これにより、ネットワークは「差分(=残差)」だけを学習すればよくなり、勾配が伝播しやすくなります。

なぜ、残差接続が有効なのか?

ニューラルネットワークが深くなると、勾配消失問題が発生しやすくなります。これは、逆伝播で勾配が繰り返し小さくなり、初期の層にほとんど伝わらなくなる現象です。その結果、深い層のパラメータがほとんど更新されなくなり、学習が進まなくなるという問題がありました。

残差接続は、これを解消するために導入されました。残差接続の効果は以下になります。

  • 勾配が直接流れるルートを持つ:スキップ接続により、勾配が間接的にでも深い層まで到達しやすくなります。これにより、学習が安定します

実際に、ディープニューラルネットワークでは層を深くすれば性能が下がることもありました。しかしながら、残差接続を導入したResNetでは、非常に深いネットワーク(50層、101層、さらにそれ以上)でも実用的な学習を可能とすることができました。

残差接続の効果を確認(プログラム)

ここでは、実際にPyTorchを使って、残差接続が勾配に与える影響を確認してみます。注目するのは、学習時に逆伝播される重みの勾配の大きさです。これは、ネットワークがどれだけ活発に学習しているかの目安になります。

モデルの構造

以下のコードでは、TestModelというシンプルな5層の全結合ニューラルネットワークを定義しています。引数 use_shortcut によって、残差接続(ショートカット)を使うかどうかを切り替えることができます。このモデルを使って、残差接続あり残差接続なしのモデルの比較を行いたいと思います。

import torch
import torch.nn as nn

class TestModel(nn.Module):
    def __init__(self, in_features, out_features, use_shortcut):
        super().__init__()
        self.use_shortcut = use_shortcut
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
            nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
            nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
            nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
            nn.Sequential(nn.Linear(in_features, out_features), nn.Sigmoid()),
        ])

    def forward(self, x):
        for layer in self.layers:
            layer_output = layer(x)
            if self.use_shortcut and x.shape == layer_output.shape:
                x = x + layer_output
            else: # 最終段はショートカットせずに出力する
                x = layer_output
        return x

このモデルは以下のような構造になっています。

  • 各層は nn.Linearnn.Sigmoid を組み合わせた全結合層
  • 残差接続が有効な場合、各層の出力に入力 x を足し合わせる形でスキップ接続が入る
  • 最後の層(入力と出力のサイズが異なる最終層)は、スキップ接続は行わない

順伝播と逆伝播の実行

do_and_print()は、勾配を計算し、表示する関数です。

def do_and_print(model, x, target):
    model.train()
    output = model(x)

    loss = nn.CrossEntropyLoss()
    loss = loss(output, target)
    loss.backward()

    for name, param in model.named_parameters():
        if 'weight' in name:
            print(f"{name} : {param.grad.abs().mean().item()}")

この関数は以下の処理を行います

  1. 入力 x を使ってモデルを順伝播させ、出力を計算する
  2. CrossEntropyLoss を用いて損失を計算する
  3. loss.backward() で逆伝播を実行し、重みの勾配を求める
  4. 各層の重みについて、勾配の平均絶対値を出力する

厳密には異なりますが、出力した勾配が大きい=その層のパラメータが更新されやすいと捉えてOKです。

モデルの比較

ショートカットあり、なしの2つのモデルを生成し、同じ値を入力して勾配をチェックします。

コードは以下になります。

ショートカットありのモデル

# クラス分類のサンプル
sample_input = torch.tensor([[1.0, 0.4, 0.8]])
target = torch.tensor([[1., 0.]])

torch.manual_seed(123)
model_without_shortcut = TestModel(3, 2, use_shortcut=False)
print("Model without shortcut:")
do_and_print(model_without_shortcut, sample_input, target)

ショートカットなしのモデル

torch.manual_seed(123)
model_without_shortcut = TestModel(3, 2, use_shortcut=True)
print("Model with shortcut:")
do_and_print(model_without_shortcut, sample_input, target)

プログラム全体

今回のプログラム全体です。実行する場合には、これをコピーして動かしてください。

import torch
import torch.nn as nn


class TestModel(nn.Module):
    def __init__(self, in_features, out_features, use_shortcut):
        super().__init__()
        self.use_shortcut = use_shortcut
        self.layers = nn.ModuleList([
            nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
            nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
            nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
            nn.Sequential(nn.Linear(in_features, in_features), nn.Sigmoid()),
            nn.Sequential(nn.Linear(in_features, out_features), nn.Sigmoid()),
        ])

    def forward(self, x):
        for layer in self.layers:
            layer_output = layer(x)
            if self.use_shortcut and x.shape == layer_output.shape:
                x = x + layer_output
            else: # 最終段はショートカットせずに出力する
                x = layer_output
        return x


def do_and_print(model, x, target):
    model.train()
    output = model(x)

    loss = nn.CrossEntropyLoss()
    loss = loss(output, target)
    loss.backward()

    for name, param in model.named_parameters():
        if 'weight' in name:
            print(f"{name} : {param.grad.abs().mean().item()}")



# クラス分類のサンプル
sample_input = torch.tensor([[1.0, 0.4, 0.8]])
target = torch.tensor([[1., 0.]])


torch.manual_seed(123)
model_without_shortcut = TestModel(3, 2, use_shortcut=False)
print("Model without shortcut:")
do_and_print(model_without_shortcut, sample_input, target)


torch.manual_seed(123)
model_without_shortcut = TestModel(3, 2, use_shortcut=True)
print("Model with shortcut:")
do_and_print(model_without_shortcut, sample_input, target)

実行結果

プログラムを実行すると以下のような結果が表示されます。

値は、逆伝播による伝わった各層の勾配の平均絶対値です。layers.0.0が一番上の層(入力に近い層)、layers.4.0が下層(出力に近い層)になります。

ショートカットありのモデル

Model without shortcut:
layers.0.0.weight : 1.658325345488265e-05
layers.1.0.weight : 6.265898264246061e-05
layers.2.0.weight : 0.0005473571945913136
layers.3.0.weight : 0.0031391088850796223
layers.4.0.weight : 0.05082247778773308

上に行くほど勾配が小さくなっていることがわかります。layers.0.0は勾配が1.7e-5程度とかなり小さくなっていることがわかります。

ショートカットなしのモデル

Model with shortcut:
layers.0.0.weight : 0.0013392833061516285
layers.1.0.weight : 0.0017523732967674732
layers.2.0.weight : 0.0020216701086610556
layers.3.0.weight : 0.0027825217694044113
layers.4.0.weight : 0.16284583508968353

上に行くほど勾配は小さくなっていますが、ショートカットなしほど急ではないことがわかります。

実行結果のまとめ

実行結果から、以下のようなことがわかるかと思います

  • ショートカットなしのモデルは、深い層ほど勾配が小さくなっている(≒勾配消失)。
  • ショートカットありのモデルは、全体的に勾配が大きく保たれており、初期層にもきちんと伝わっている

この違いが、残差接続の「勾配が伝播しやすくなる」効果です。

まとめ

この記事では、残差接続の有効性について、実際のPyTorchプログラムを使って確認しました。

ディープニューラルネットワークにおいては、学習が深い層で止まってしまう「勾配消失」が大きな課題のひとつです。しかし、残差接続(スキップ接続)を導入することで、勾配が浅い層まで届きやすくなり、学習が安定・加速されます

今回は、勾配の大きさをプログラムで確認することで、「理論的な知識」を「実感」できたのではないでしょうか。このように仕組みを実験ベースで確かめていくことも、深層学習を理解する1つの方法です。

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました