機械学習
記事内に商品プロモーションを含む場合があります

HuggingFace Transformerの精度向上テクニック | 出力のPooling手法

Aru

トランスフォーマーの出力を処理する手法が書かれたノートブックUtilizing Transformer Representations Efficientlyを読んで、それぞれの手法をピックしました。性能向上のために有用な情報です。

記事は自身がkaggleでトランスフォーマーを使う時に目を通すために作成しましたが、他の方の参考になれば幸いです。

この記事は、ある程度Hugging Faceのtransformerを使い慣れている人向けの記事です。使い慣れていない場合は、以下の記事も参考にしてください。

Hugging Face Transformer(BERT)でクラス分類(classification)
Hugging Face Transformer(BERT)でクラス分類(classification)
Hugging Face Transformer(BERT)で回帰分析(Regression)
Hugging Face Transformer(BERT)で回帰分析(Regression)

Utilizing Transformer Representations Efficientlyとは

Utilizing Transformer Representations Efficiently“は、kaggleのノートブックのタイトルです。

このノートブックはHugging Faceのトランスフォーマーの出力を加工することで性能を向上させるための手法がまとめられたノートブックです。

Hugging Faceのtransformerでは、クラス分類などに合わせて適切な出力を行うライブラリになっていますが、自身で出力を加工することで更なる性能向上を目指すことができます。実務より、どちらかというとkaggleなどのコンテストで有用なテクニックだと思います。

この記事ではこのノートブックを読んで、手法のポイントとなるコード部分だけを抜き出して整理したものです。

ノートブックへのリンク
https://www.kaggle.com/code/rhtsingh/utilizing-transformer-representations-efficiently

トランスフォーマーの入出力

まず、HuggingFaceのトランスフォーマーの入出力を説明してきます。

HuggingFace Transformerでは以下を入力します。

  • input_ids
  • attention_mask

出力は3つ(設定に応じて変化)を受け取ります。

  • pooler output(batch size, hidden size)
    シーケンスの最初のトークン(CLS, Classification Token)の状態
  • last hidden state(batch size, seq len, hidden size)
    最後の隠れ層の状態
  • hidden states(n layers, batch size, seq len, hidden size)
    全ての隠れ層の状態

コードを見ると出力の0番目が最後の隠れ層、1番目がPooler Output、2番目が全ての隠れ層の状態

なお、すべての隠れ層の状態を出力するにはconfigでoutput_hidden_statesにTrueを設定する必要がある

この記事の流れ

以下、ノートブックに書かれている手法を列挙しました。とりあえず、ノートブックから要点のみを抜き出しています。理解が進めば、この記事だけで理解し、自分のコードに反映できるように抜き出しましたが、より詳しくはノートブックを参照してください。

この記事で、興味のある手法をみつけて、該当するノートの部分を見るのがおすすめです。

pooler output

モデルの出力のoutput[1]を取り出して、全結合層(Linear)で1要素に変換

これが、最も単純な手法となります。

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
pooler_output = outputs[1]
logits = nn.Linear(config.hidden_size, 1)(pooler_output) 
Pooler Output Shape: (バッチサイズ, 隠れ層のサイズ)
Logits Shape: (バッチサイズ, 1)

Last Hidden State Output

CLSトークン(先頭のトークン)を取り出して、全結合層(Linear)で1要素に変換

基本的にはpooler outputと同じです。

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
last_hidden_state = outputs[0]
cls_embeddings = last_hidden_state[:, 0]
logits = nn.Linear(config.hidden_size, 1)(cls_embeddings)
Last Hidden State Output Shape: (バッチサイズ, 最大トークン数, 隠れ層のサイズ)
CLS Embeddings Output Shape: (バッチサイズ, 隠れ層のサイズ)
Logits Shape: (バッチサイズ, 1)

Mean Pooling

最後の層を平均して、全結合層(Linear)で1要素に変換。この時、attention_maskを使って、パディングトークン等を無視するようにする。

attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()attention_maskを隠れ層と同じ行列に変換、その後掛け合わせることでattention_mask0の部分が合計に入らないようにしています。

attention_maskの代わりに、特定のトークン以外をマスクするようにすれば、一部のトークンの隠れ層だけのmean poolingを行うことも可能です。例えば、「質問+答え」で構成される文章の場合、答えの部分だけを利用するなど行うことができます(次に説明するmax poolingも同様のことができます)

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
last_hidden_state = outputs[0]
attention_mask = features['attention_mask']

input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
sum_embeddings = torch.sum(last_hidden_state * input_mask_expanded, 1)
sum_mask = input_mask_expanded.sum(1)
sum_mask = torch.clamp(sum_mask, min=1e-9)
mean_embeddings = sum_embeddings / sum_mask
logits = nn.Linear(config.hidden_size, 1)(mean_embeddings) 
Last Hidden State Output Shape: (バッチサイズ, 最大トークン数, 隠れ層のサイズ)
Mean Embeddings Output Shape:  (バッチサイズ, 隠れ層のサイズ)
Logits Shape: (バッチサイズ, 1)

Max Pooling

最後の層を平均して、全結合層(Linear)で1要素に変換。この時、attention_maskを使って、パディングトークン等を無視するようにする。

attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()attention_maskを隠れ層と同じ行列に変換し、マスクが0の部分は、マイナス無限大(-1e9)を代入し、maxをとります。マスクがかかっている場所はマイナス無限大にしたので、それ以外の値のmaxが正しく取られることになります。

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
last_hidden_state = outputs[0]
attention_mask = features['attention_mask']

input_mask_expanded = attention_mask.unsqueeze(-1).expand(last_hidden_state.size()).float()
last_hidden_state[input_mask_expanded == 0] = -1e9  # Set padding tokens to large negative value
max_embeddings = torch.max(last_hidden_state, 1)[0]
logits = nn.Linear(config.hidden_size, 1)(max_embeddings)
Last Hidden State Output Shape: (バッチサイズ, 最大トークン数, 隠れ層のサイズ)
Max Embeddings Output Shape: (バッチサイズ, 隠れ層のサイズ)
Logits Shape:  (バッチサイズ, 1)

Mean-Max Pooling

平均と最大の両方をとり、それを繋げて$隠れ層のサイズ \times 2$の配列を作り、全結合層(Linear)で1要素に変換

コードには、attention_maskを使った処理が入っていません。どうしてこうしているかわからないですが、パディング部分などを除外するために、attention_maskを使った方が良い気がします

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
last_hidden_state = outputs[0]

mean_pooling_embeddings = torch.mean(last_hidden_state, 1)
_, max_pooling_embeddings = torch.max(last_hidden_state, 1)
mean_max_embeddings = torch.cat((mean_pooling_embeddings, max_pooling_embeddings), 1)
logits = nn.Linear(config.hidden_size*2, 1)(mean_max_embeddings)
Last Hidden State Output Shape: (バッチサイズ, 最大トークン数, 隠れ層のサイズ)
Mean-Max Embeddings Output Shape:  (バッチサイズ, 隠れ層のサイズの2倍)
Logits Shape: (バッチサイズ, 1)

Conv 1D Pooling

最終層の出力を畳み込み層に入力し最後にmaxを取って出力。畳み込み層は2段構成。

kernel_sizeを調整するか、段数を調整することで、CNNが見る範囲を調整できます。

CNN→CNNの間はreluが挿入されており、最終段はmaxをとるようになっています。reluをmishやswish、GRUにしても面白いかもしれません。

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
last_hidden_state = outputs[0]

cnn1 = nn.Conv1d(768, 256, kernel_size=2, padding=1)
cnn2 = nn.Conv1d(256, 1, kernel_size=2, padding=1)

last_hidden_state = last_hidden_state.permute(0, 2, 1)
cnn_embeddings = F.relu(cnn1(last_hidden_state))
cnn_embeddings = cnn2(cnn_embeddings)
logits, _ = torch.max(cnn_embeddings, 2)
Last Hidden State Output Shape: (バッチサイズ, 最大トークン数, 隠れ層のサイズ)
CNN Embeddings Output Shape: (バッチサイズ, 1, CNNの出力)
Logits Shape: (バッチサイズ, 1)

Hidden States Output

HuggingFace Transformerの場合、configにoutput_hidden_statesにTrueを設定すると、すべての隠れ層が出力されるようになる。

transformerの場合、浅い層は単語ベースで深い層は文章単位の特徴量になっていると考えることができるそうです。上層を使うことで少しプリミティブな情報を利用することが可能となり、それが精度に良い影響を与えることもあるようです。

config = AutoConfig.from_pretrained(_pretrained_model)
config.update({'output_hidden_states':True})
model = AutoModel.from_pretrained(_pretrained_model, config=config)

例は、最後から2番目のレイヤーのCLSを使うもの(roberta-baseをモデルとして使うので13層が最終層)

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
all_hidden_states = torch.stack(outputs[2])

layer_index = 11 # second to last hidden layer
cls_embeddings = all_hidden_states[layer_index+1, :, 0] # layer_index+1 as we have 13 layers (embedding + num of blocks)

logits = nn.Linear(config.hidden_size, 1)(cls_embeddings)
Hidden States Output Shape: (層の数, バッチサイズ, 最大トークン数, 隠れ層のサイズ)
CLS Embeddings Output Shape:  (バッチサイズ, 隠れ層のサイズ)
Logits Shape: (バッチサイズ, 1)

Concatinate Pooling

複数の隠れ層を使うもの。output_hidden_statesにTrueを設定する必要がある。

config = AutoConfig.from_pretrained(_pretrained_model)
config.update({'output_hidden_states':True})
model = AutoModel.from_pretrained(_pretrained_model, config=config)

例では最終から4層を利用

torch.catの部分で4層を繋げて、concatenate_pooling[:, 0]でCLSを取り出している

複数の層の情報を利用することで精度を向上させるテクニックです。どのレイヤーを含めるのかを含め試行錯誤の対象がいっぱいありそうです。

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
all_hidden_states = torch.stack(outputs[2])

concatenate_pooling = torch.cat(
    (all_hidden_states[-1], all_hidden_states[-2], all_hidden_states[-3], all_hidden_states[-4]),-1
)
concatenate_pooling = concatenate_pooling[:, 0]

logits = nn.Linear(config.hidden_size*4, 1)(concatenate_pooling)
Hidden States Output Shape: (層の数, バッチサイズ, 最大トークン数, 隠れ層のサイズ)
Concatenate Pooling Output Shape: (バッチサイズ, 隠れ層のサイズ*4)
Logits Shape:(バッチサイズ, 1)

Weighted Layer Pooling

重みつけで複数の層を利用するもの。output_hidden_statesにTrueを設定する必要がある。

config = AutoConfig.from_pretrained(_pretrained_model)
config.update({'output_hidden_states':True})
model = AutoModel.from_pretrained(_pretrained_model, config=config)

以下では、単純にCLSトークンの出力を取り、全結合(Linear)層で1つに統合。

重み付きで、複数層を利用する方法です。Concatinate Poolingをより一般化したものと考えることができそうです。

class WeightedLayerPooling(nn.Module):
    def __init__(self, num_hidden_layers, layer_start: int = 4, layer_weights = None):
        super(WeightedLayerPooling, self).__init__()
        self.layer_start = layer_start
        self.num_hidden_layers = num_hidden_layers
        self.layer_weights = layer_weights if layer_weights is not None \
            else nn.Parameter(
                torch.tensor([1] * (num_hidden_layers+1 - layer_start), dtype=torch.float)
            )

    def forward(self, all_hidden_states):
        all_layer_embedding = all_hidden_states[self.layer_start:, :, :, :]
        weight_factor = self.layer_weights.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).expand(all_layer_embedding.size())
        weighted_average = (weight_factor*all_layer_embedding).sum(dim=0) / self.layer_weights.sum()
        return weighted_average

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
all_hidden_states = torch.stack(outputs[2])

layer_start = 9
pooler = WeightedLayerPooling(
    config.num_hidden_layers, 
    layer_start=layer_start, layer_weights=None
)
weighted_pooling_embeddings = pooler(all_hidden_states)
weighted_pooling_embeddings = weighted_pooling_embeddings[:, 0]
logits = nn.Linear(config.hidden_size, 1)(weighted_pooling_embeddings)
Hidden States Output Shape: (層の数, バッチサイズ, 最大トークン数, 隠れ層のサイズ)
Weighted Pooling Output Shape:  (バッチサイズ, 隠れ層のサイズ)
Logits Shape:(バッチサイズ, 1)

LSTM/GRU Pooling

LSTM/GRUを用いるもの。output_hidden_statesにTrueを設定する必要がある。

最終段にLSTMを入れ込んだものですが、ここまで来ると少し複雑な感じです。LSTMはGPUで速度が出にくいので、学習時間がかかるかもしれません。

config = AutoConfig.from_pretrained(_pretrained_model)
config.update({'output_hidden_states':True})
model = AutoModel.from_pretrained(_pretrained_model, config=config)

例は、LSTMプーリング。

class LSTMPooling(nn.Module):
    def __init__(self, num_layers, hidden_size, hiddendim_lstm):
        super(LSTMPooling, self).__init__()
        self.num_hidden_layers = num_layers
        self.hidden_size = hidden_size
        self.hiddendim_lstm = hiddendim_lstm
        self.lstm = nn.LSTM(self.hidden_size, self.hiddendim_lstm, batch_first=True)
        self.dropout = nn.Dropout(0.1)
    
    def forward(self, all_hidden_states):
        ## forward
        hidden_states = torch.stack([all_hidden_states[layer_i][:, 0].squeeze()
                                     for layer_i in range(1, self.num_hidden_layers+1)], dim=-1)
        hidden_states = hidden_states.view(-1, self.num_hidden_layers, self.hidden_size)
        out, _ = self.lstm(hidden_states, None)
        out = self.dropout(out[:, -1, :])
        return out

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
all_hidden_states = torch.stack(outputs[2])

hiddendim_lstm = 256
pooler = LSTMPooling(config.num_hidden_layers, config.hidden_size, hiddendim_lstm)
lstm_pooling_embeddings = pooler(all_hidden_states)
logits = nn.Linear(hiddendim_lstm, 1)(lstm_pooling_embeddings) 
Hidden States Output Shape: (層の数, バッチサイズ, 最大トークン数, 隠れ層のサイズ)
LSTM Pooling Output Shape: (バッチサイズ, 最大トークン数)
Logits Shape: (バッチサイズ, 1)

Attention Pooling

Attentionを使ったもの。output_hidden_statesにTrueを設定する必要がある

config = AutoConfig.from_pretrained(_pretrained_model)
config.update({'output_hidden_states':True})
model = AutoModel.from_pretrained(_pretrained_model, config=config)

attentionの部分で、アテンションの処理を行っている。

基本的にはsoftmaxでマスクパターンを作って乗算している形。処理の流れは理解できるが、処理の意味がいまいち理解できていない。

class AttentionPooling(nn.Module):
    def __init__(self, num_layers, hidden_size, hiddendim_fc):
        super(AttentionPooling, self).__init__()
        self.num_hidden_layers = num_layers
        self.hidden_size = hidden_size
        self.hiddendim_fc = hiddendim_fc
        self.dropout = nn.Dropout(0.1)

        q_t = np.random.normal(loc=0.0, scale=0.1, size=(1, self.hidden_size))
        self.q = nn.Parameter(torch.from_numpy(q_t)).float()
        w_ht = np.random.normal(loc=0.0, scale=0.1, size=(self.hidden_size, self.hiddendim_fc))
        self.w_h = nn.Parameter(torch.from_numpy(w_ht)).float()

    def forward(self, all_hidden_states):
        hidden_states = torch.stack([all_hidden_states[layer_i][:, 0].squeeze()
                                     for layer_i in range(1, self.num_hidden_layers+1)], dim=-1)
        hidden_states = hidden_states.view(-1, self.num_hidden_layers, self.hidden_size)
        out = self.attention(hidden_states)
        out = self.dropout(out)
        return out

    def attention(self, h):
        v = torch.matmul(self.q, h.transpose(-2, -1)).squeeze(1)
        v = F.softmax(v, -1)
        v_temp = torch.matmul(v.unsqueeze(1), h).transpose(-2, -1)
        v = torch.matmul(self.w_h.transpose(1, 0), v_temp).squeeze(2)
        return v

with torch.no_grad():
    outputs = model(features['input_ids'], features['attention_mask'])
all_hidden_states = torch.stack(outputs[2])

hiddendim_fc = 128
pooler = AttentionPooling(config.num_hidden_layers, config.hidden_size, hiddendim_fc)
attention_pooling_embeddings = pooler(all_hidden_states)
logits = nn.Linear(hiddendim_fc, 1)(attention_pooling_embeddings)
Hidden States Output Shape:  (層の数, バッチサイズ, 最大トークン数, 隠れ層のサイズ)
Attention Pooling Output Shape: (バッチサイズ, 128)
Logits Shape: (バッチサイズ, 1)

WK Pooling

“SBERT-WK: A Sentence Embedding Method By Dissecting BERT-based Word Models”に基づく処理。output_hidden_statesにTrueを設定する必要がある

config = AutoConfig.from_pretrained(_pretrained_model)
config.update({'output_hidden_states':True})
model = AutoModel.from_pretrained(_pretrained_model, config=config)

理解できていないので、コードだけ貼っておく。

処理ステップ
  • 各単語について、その層を横断しての表現を調べ、アラインメントおよび新奇性特性を調べることにより、文の各単語に対する統一された単語表現を決定
  • 単語の重要性の測定に基づいて統一された単語表現の重み付き平均を実行し、最終的な文の埋め込みベクトルを生成
class WKPooling(nn.Module):
    def __init__(self, layer_start: int = 4, context_window_size: int = 2):
        super(WKPooling, self).__init__()
        self.layer_start = layer_start
        self.context_window_size = context_window_size

    def forward(self, all_hidden_states):
        ft_all_layers = all_hidden_states
        org_device = ft_all_layers.device
        all_layer_embedding = ft_all_layers.transpose(1,0)
        all_layer_embedding = all_layer_embedding[:, self.layer_start:, :, :]  # Start from 4th layers output

        # torch.qr is slow on GPU (see https://github.com/pytorch/pytorch/issues/22573). So compute it on CPU until issue is fixed
        all_layer_embedding = all_layer_embedding.cpu()

        attention_mask = features['attention_mask'].cpu().numpy()
        unmask_num = np.array([sum(mask) for mask in attention_mask]) - 1  # Not considering the last item
        embedding = []

        # One sentence at a time
        for sent_index in range(len(unmask_num)):
            sentence_feature = all_layer_embedding[sent_index, :, :unmask_num[sent_index], :]
            one_sentence_embedding = []
            # Process each token
            for token_index in range(sentence_feature.shape[1]):
                token_feature = sentence_feature[:, token_index, :]
                # 'Unified Word Representation'
                token_embedding = self.unify_token(token_feature)
                one_sentence_embedding.append(token_embedding)

            ##features.update({'sentence_embedding': features['cls_token_embeddings']})

            one_sentence_embedding = torch.stack(one_sentence_embedding)
            sentence_embedding = self.unify_sentence(sentence_feature, one_sentence_embedding)
            embedding.append(sentence_embedding)

        output_vector = torch.stack(embedding).to(org_device)
        return output_vector

    def unify_token(self, token_feature):
        ## Unify Token Representation
        window_size = self.context_window_size

        alpha_alignment = torch.zeros(token_feature.size()[0], device=token_feature.device)
        alpha_novelty = torch.zeros(token_feature.size()[0], device=token_feature.device)

        for k in range(token_feature.size()[0]):
            left_window = token_feature[k - window_size:k, :]
            right_window = token_feature[k + 1:k + window_size + 1, :]
            window_matrix = torch.cat([left_window, right_window, token_feature[k, :][None, :]])
            Q, R = torch.qr(window_matrix.T)

            r = R[:, -1]
            alpha_alignment[k] = torch.mean(self.norm_vector(R[:-1, :-1], dim=0), dim=1).matmul(R[:-1, -1]) / torch.norm(r[:-1])
            alpha_alignment[k] = 1 / (alpha_alignment[k] * window_matrix.size()[0] * 2)
            alpha_novelty[k] = torch.abs(r[-1]) / torch.norm(r)

        # Sum Norm
        alpha_alignment = alpha_alignment / torch.sum(alpha_alignment)  # Normalization Choice
        alpha_novelty = alpha_novelty / torch.sum(alpha_novelty)

        alpha = alpha_novelty + alpha_alignment
        alpha = alpha / torch.sum(alpha)  # Normalize

        out_embedding = torch.mv(token_feature.t(), alpha)
        return out_embedding

    def norm_vector(self, vec, p=2, dim=0):
        ## Implements the normalize() function from sklearn
        vec_norm = torch.norm(vec, p=p, dim=dim)
        return vec.div(vec_norm.expand_as(vec))

    def unify_sentence(self, sentence_feature, one_sentence_embedding):
        ## Unify Sentence By Token Importance
        sent_len = one_sentence_embedding.size()[0]

        var_token = torch.zeros(sent_len, device=one_sentence_embedding.device)
        for token_index in range(sent_len):
            token_feature = sentence_feature[:, token_index, :]
            sim_map = self.cosine_similarity_torch(token_feature)
            var_token[token_index] = torch.var(sim_map.diagonal(-1))

        var_token = var_token / torch.sum(var_token)
        sentence_embedding = torch.mv(one_sentence_embedding.t(), var_token)

        return sentence_embedding
    
    def cosine_similarity_torch(self, x1, x2=None, eps=1e-8):
        x2 = x1 if x2 is None else x2
        w1 = x1.norm(p=2, dim=1, keepdim=True)
        w2 = w1 if x2 is x1 else x2.norm(p=2, dim=1, keepdim=True)
        return torch.mm(x1, x2.t()) / (w1 * w2.t()).clamp(min=eps)

使い方は以下の通り

pooler = WKPooling(layer_start=9)
wkpooling_embeddings = pooler(all_hidden_states)
logits = nn.Linear(config.hidden_size, 1)(wkpooling_embeddings)

まとめ

手法を列挙し、ポイントとなる部分だけを抜き出してみました。理解の助けになれば幸いです。

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました