テーブルデータ向けトランスフォーマー「TabPFN」の使い方
この記事では、テーブルデータ(表形式のデータ)に特化したトランスフォーマー技術、TabPFNの基本的な使い方を解説します。TabPFNは、トランスフォーマーモデルを用いた機械学習のライブラリで、高精度な推論を行える手法の1つです。
はじめに
この記事では、テーブルデータ向けトランスフォーマー「TabPFN」の実践的な使い方に重点を置いて解説します。具体的な設定や操作方法を紹介することで、すぐに実装に取り掛かることができる内容を目指しています。
TabPFNのアルゴリズムや理論的背景について詳しく知りたい方は、関連する論文を参照することをお勧めします。以下に論文のリンクを記載しますので、興味のある方はぜひご覧ください。
テーブルデータ用のDNNとしては、TabNetもあります。TabNetについては、以下の記事を参照してください。
なお、今回のサンプルプログラムをgithubに置いています(※Colabで動作します)。
概要
“We present TabPFN, a trained Transformer that can do supervised classification for small tabular datasets in less than a second, needs no hyperparameter tuning and is competitive with state-of-the-art classification methods.supporting pytorch impelementation.“と論文に書かれているように、pytorchで記述された表向けのTransformerです。
速度も早く、性能も良いといいことづくめなことが記載されていますが、表(テーブル)の機械学習で使われているlightGBM, XGBoost以外の選択肢になり得るかもしれません。
また、boostingと異なるアルゴリズムなので、XGBoostなどとのアンサンブルに使うと性能アップが期待できるかもしれません。
また、ディープラーニングベースのモデルとしてはTabNetもあります。こちらについても記事にしていますので参考にしてください。
インストール
インストールは非常に簡単です。githubに書かれている通りにすればOKです。
pip install tabpfn
pytorchで実装されているので、一応pytorchを、あと表データを扱うのでnumpyとpandasもインストールしていなければインストールしておいた方が良いかと思います。
試してみる
TabPFNはクラス分類器なので、scikitlearnのデータセットから乳がん患者のデータセットを使って実際に使ってみたいと思います。
データセットの読み込み
データセットの読み込みを行うと、辞書形で読み込まれます。
from sklearn.datasets import load_breast_cancer
datas = load_breast_cancer()
このうち、dataとtargetを説明変数(X)と目的変数(y)として利用します。
X = datas['data']
y = datas['target']
X.shape, y.shape
次に、データを訓練用(train)と、テスト用(test)に分割します。
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)
これで、データが準備できました。
TabPFNで予測
TabPFNで予測してみます。
import torch
from tabpfn import TabPFNClassifier
device = "cuda" if torch.cuda.is_available() else "cpu"
classifier = TabPFNClassifier(device=device, N_ensemble_configurations=32)
classifier.fit(X_train, y_train)
y_pred, p_pred = classifier.predict(X_test, return_winning_probability=True)
print('Accuracy', accuracy_score(y_test, y_pred))
TabPFNはGPUが使えるので、GPUがある場合はGPUで実行するようにdeviceに設定しています。
なお、パラメータとしては、N_ensemble_configurationsくらいしかなく、”needs no hyperparameter tuning“というのはあながちうそではないです。
結果は、以下の通りです。
Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
Multiple models in memory. This might lead to memory issues. Consider calling remove_models_from_memory()
Accuracy 0.9840425531914894
特に設定せずに、正解率が0.98とかなり高い値になりました。なお、表データを利用する場合は、データの正則化などの前処理を行うことが多いですが、”Do not preprocess inputs to TabPFN. TabPFN pre-processes inputs internally. It applies a z-score normalization “ということで、TabPFNはそのまま入力した方が良いようです。
LightGBMで予測してみる
精度差を見るために、lightGBMでも予測をしてみました。コードは以下の通り。
from lightgbm import LGBMClassifier
classifier = LGBMClassifier()
classifier.fit(X_train, y_train)
y_eval = classifier.predict(X_test)
print('Accuracy', accuracy_score(y_test, y_eval))
結果は以下の通り。
Accuracy 0.9521276595744681
正解率は0.95とTabPFNと比較して若干悪い結果になりました。
TabPFNの注意点
TabPFNは、カテゴリデータを入力できないので、男女などのカテゴリデータについては、OrdinalEncoder, OneHotEncoderなどでエンコードする必要があります。
いずれもscikit-learnで用意されているので、それを使うと良いです。
- TabPFN expects scalar values only (you need to encode categoricals as integers e.g. with OrdinalEncoder). It works best on data that does not contain any categorical or NaN data (see Appendix B.1).
まとめ
lightGBMは前処理も、ハイパーパラメータのチューニングもしていない状態ですので単純に比較はできませんが、とりあえず、何も考えずに使って性能が出るというのはすごいことだと思います。結構使えそうなので積極的に使っていこうと思います。ただ、CPUで処理させるとlightGBMと比較すると少し遅い印象です。