表形式データ向けDNN「TabNet」でクラス分類問題を解く(使い方)|Python
Kaggleのコンペティションでは、表形式データ(テーブルデータ)に対する予測手段として「TabNet」を目にすることがあります。この記事では、TabNetを用いたクラス分類問題の解き方を、具体例を交えながら詳しく解説します。本記事では、使い方を中心に解説しますので、詳しいアーキテクチャについては論文等を参照してください(arXiv:1980.07442)。
TabNetで回帰を行う方法については以下の記事を参照してください
テーブルデータ向けDNNとしてはTabPFNもあります。TabPFNについては以下を参照してください
TabNetの概要
TabNet(Tabular Neural Network)は、ディープニューラルネットワーク(DNN)の一種で、特に表形式のデータ(タブラーデータ)を処理するために設計されたモデルです。
DNNは通常、画像やテキストの処理に使われることが一般的ですが、TabNetは表形式のデータに焦点を当て、その特有の特性に適したネットワーク構造とトレーニング手法を提供します。
TabNetは、以下の主な特徴を持っています
- 特徴の重要度の学習
入力データの各特徴(列)の重要性をモデル自体が理解し、それに基づいて予測を行います。これにより、モデルの予測プロセスが解釈可能になります。 - 逐次的な特徴選択
モデルは、予測の過程で重要あると判断された特徴を選択していくことで、冗長な情報を排除していきます。これにより、高い予測性能を維持しながら過学習を抑制できます - マスクされた自己注意機構
TabNetは、注意機構(self-attention mechanism)を採用しています。これにより、入力データの異なる部分同士の関連性を学習し、有益なパターンを抽出可能です - カテゴリカル変数のサポート
TabNetは、連続的な数値データだけでなく、カテゴリカル変数(例: カテゴリカル特徴やカテゴリカルエンコーディング)も処理できるように設計されています
今回利用したのは、TabNetのPytorch実装です。リンクは以下になります。
ソースコード(github): https://github.com/dreamquark-ai/tabnet/tree/develop
ドキュメント: https://dreamquark-ai.github.io/tabnet/index.html
READMEを読むと、ライブラリの改善のために、論文とは異なる部分が存在する可能性があるそうです。
なお、Google Colab上で動作するコードをここに置いています。
インストール
インストールはpipで行うことができます。以下のコードをコマンドラインより実行してください。
pip install pytorch-tabnet
Google Colabなどで試す場合は、!
を先頭につけてセルに入力し実行してください。
!pip install pytorch-tabnet
ライブラリのインポート
今回のコードで利用するライブラリをインポートしておきます。
numpy, pandas, matplotlib…と、この手の実験を行う場合にお馴染みのライブラリになります。なお、tabnetの引数にpytorchの関数を渡しますので、torchもインポートしておきます。
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score
from pytorch_tabnet.tab_model import TabNetClassifier
from pytorch_tabnet.pretraining import TabNetPretrainer
import torch
評価用のデータセットを準備
データセットを生成する(ダミーデータの準備)
評価用のデータセットは、scikit-learnのmake_classification
を利用して作成します。
make_classification
については以下の記事を参考にしてください
X, y = make_classification(n_samples=10000,
n_features=10,
n_redundant = 3,
n_informative = 5,
n_classes=2,
random_state=42)
作成されたデータは、2クラス分類のデータセットで、targetは0,1のどちらかで、10個の特徴量を備えたものになります。
以下のコードでpandasのデータに変換して表示できます。
df = pd.DataFrame(X)
df['target'] = y
df.head(20)
TabNetにカテゴリデータを入力する場合には、エンコードして数値データに変換する必要があります。カテゴリデータに変換する方法は、以下の記事を参考にしてください。
TabNetに特徴量がカテゴリ変数であることを伝えるには、以下の引数を利用します
cat_idxs | カテゴリ変数の列のインデックスのリスト |
cat_dims | 各カテゴリ変数のカテゴリ数のリスト |
cat_emb_dim | 各カテゴリ変数の埋め込みサイズのリスト |
作成したデータセットをtrain/valid/testに分割
ここでは、 train:valid:test=0.7:0.15:0.15に分割します。
train_rate, val_rate, test_rate = 0.7, 0.15, 0.15
X_train, X_test, y_train, y_test = train_test_split(X, y, train_size=train_rate, random_state=42)
X_valid, X_test, y_valid, y_test = train_test_split(X_test, y_test, test_size=test_rate/(test_rate+val_rate), random_state=42)
TabNetの使い方
事前学習を行う場合
事前学習を行う場合は、TabNetPretrainer
を使って行います。教師なし学習なので、y_train, y_validを渡さないことに注意してください。
今回の例では、事前学習をした方がスコアが悪かったので、事前学習をしていません。
TabNetは、まだお試し程度でしか使っていないので事前学習をどう使えば効果的なのか理解できていません。
unsupervised_model = TabNetPretrainer(
optimizer_fn=torch.optim.Adam,
optimizer_params=dict(lr=2e-2),
mask_type='entmax'
)
unsupervised_model.fit(
X_train=X_train,
eval_set=[X_valid],
pretraining_ratio=0.8,
max_epochs = 100,
)
学習を行う
学習コードはscikitlearnのインタフェースに似ているので、lightGBMやXGBoostなどを使ったことがある方はわかりやすいのではないかと思います。
各パラメータについては、公式ドキュメントを確認してください。
パラメータについては、PyTorchを使ってDNNの学習をしたことがあるのであれば迷わないかと思いますが、とりあえずデフォルトでよいと思います。
以下のソースコードでコメントアウトしている行のコメントを外せば、上で説明した事前学習の結果を使って学習が行われます。
tabnet_params = {
"optimizer_fn":torch.optim.Adam,
"optimizer_params":dict(lr=2e-2),
"scheduler_params":{"step_size":50, # how to use learning rate scheduler
"gamma":0.9},
"scheduler_fn":torch.optim.lr_scheduler.StepLR,
"mask_type":'entmax',
}
max_epochs = 20
clf = TabNetClassifier(**tabnet_params
)
clf.fit(
X_train=X_train, y_train=y_train,
eval_set=[(X_train, y_train), (X_valid, y_valid)],
eval_name=['train', 'valid'],
eval_metric=['auc'],
max_epochs=max_epochs , patience=20,
batch_size=1024, virtual_batch_size=128,
num_workers=0,
weights=1,
drop_last=False,
augmentations=None,
# from_unsupervised=unsupervised_model
)
学習は遅いです。GPUを使ったら多少は高速になりますがLightGBMと比較するとかなり遅いです。
大規模データの場合は、学習結果をセーブ(clf.save_model(ファイル名)
)して、ロード(clf.load_model(ファイル名)
)して使う感じになりそうです。
予測結果
以下のコードで、検証データ、テストデータに対する予測ができます。
preds = clf.predict_proba(X_test)
test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)
preds_valid = clf.predict_proba(X_valid)
valid_auc = roc_auc_score(y_score=preds_valid[:,1], y_true=y_valid)
print("valid_auc = ", valid_auc)
print("test_auc = ", test_auc)
結果を見ると、test_aucで0.986とまずまずの結果であることがわかります。
実行結果
valid_auc = 0.9851784018450686
test_auc = 0.9864709407043504
重要度を表示
TabNetでも、LightGBMなどと同様に特徴量の重要度を確認することが可能です。
importance = sorted([(i, n) for i, n in enumerate(clf.feature_importances_)], key = lambda x: x[1], reverse = True)
label, y = [], []
for e in importance:
print(f"feature {e[0]} : {e[1]}")
label.append(e[0])
y.append(e[1])
plt.bar([i for i in range(len(y))], y, tick_label = label)
plt.show()
結果を見ると9→5→0→3…の順に重要度が高いと結果となったようです。
LightGBMと比較
比較のために、LightGBMでも予測してみました。
lightGBMで学習
from lightgbm import LGBMClassifier
lgb_params = {
'n_estimators': 10000,
'learning_rate': 0.05,
'random_state': 42,
'early_stopping_round': 20,
'metric': 'auc'
}
lgb = LGBMClassifier(**lgb_params)
lgb.fit(X_train, y_train, eval_set=[(X_valid, y_valid)])
lightGBMで予測
preds = lgb.predict_proba(X_test)
test_auc = roc_auc_score(y_score=preds[:,1], y_true=y_test)
preds_valid = lgb.predict_proba(X_valid)
valid_auc = roc_auc_score(y_score=preds_valid[:,1], y_true=y_valid)
print("valid_auc = ", valid_auc)
print("test_auc = ", test_auc)
結果はtest_auc
で0.985とわずかにTabNetに劣る結果となりました。
パラメータチューニングでもっと精度を上げられるかもですが、とりあえず、lightGBMとTabPFNで同等の性能を得ることができました。
実行結果
valid_auc = 0.9830452198873252
test_auc = 0.9846662363824835
LGBMの特徴量の重要度は以下のようになります。こちらは、0→1→4→…の順となっています。
データセットを生成するときに、n_informative = 5
と設定していたので、順番は違っても上位5つの特徴量はほぼ同じなるかと思いましたが結構違います。ここは、すこし驚きです。
まとめ
TabNet単体でつかってもよいと思いますが、個人的な興味はLightGBMやCatBoostなどとのアンサンブルでしょうか。
最近人気のCatBoost, LightGBM, XGBoostはどれも勾配ブースティングを用いたものです。似たようなアルゴリズムをアンサンブルするより、異なるアルゴリズムの学習器をアンサンブルした方が良い結果が得られそうなので、TabNetはその選択肢の1つとなりそうです。
また、テーブルデータに対するトランスフォーマとしてTabPFNがあります。こちらも制約がいろいろありますが性能が高いようです。