機械学習
記事内に商品プロモーションを含む場合があります

TabPFN(表のためのトランスフォーマー)のを試す

アイキャッチ画像
tadanori

テーブルデータ(表形式のデータ)で使えるtransformerに、TabPFNというものがあります。この記事では、TabPFNの使い方を解説します。

はじめに

この説明では、どういうものかという部分ではなく、実践(使い方)に重点を置いています。

アルゴリズムについて詳しく知りたい場合は、論文などをチェックすることをお勧めします。

以下、論文のリンクです

論文リンク:TabPFN: A Transformer That Solves Small Tabular Classification Problems in a Second
github : https://github.com/automl/TabPFN

テーブルデータ用のDNNとしては、TabNetもあります。TabNetについては、以下の記事を参照してください。

テーブルデータ向けDNN”TabNet”の使い方を解説【Python】
テーブルデータ向けDNN”TabNet”の使い方を解説【Python】

なお、今回のサンプルプログラムをgithubに置いています(※Colabで動作します)。

概要

We present TabPFN, a trained Transformer that can do supervised classification for small tabular datasets in less than a second, needs no hyperparameter tuning and is competitive with state-of-the-art classification methods.supporting pytorch impelementation.“と論文に書かれているように、pytorchで記述された表向けのTransformerです。

速度も早く、性能も良いといいことづくめなことが記載されていますが、表(テーブル)の機械学習で使われているlightGBM, XGBoost以外の選択肢になり得るかもしれません。

また、boostingと異なるアルゴリズムなので、XGBoostなどとのアンサンブルに使うと性能アップが期待できるかもしれません。

インストール

インストールは非常に簡単です。githubに書かれている通りにすればOKです。

pip install tabpfn

pytorchで実装されているので、一応pytorchを、あと表データを扱うのでnumpyとpandasもインストールしていなければインストールしておいた方が良いかと思います。

試してみる

TabPFNはクラス分類器なので、scikitlearnのデータセットから乳がん患者のデータセットを使って実際に使ってみたいと思います。

データセットの読み込み

データセットの読み込みを行うと、辞書形で読み込まれます。

from sklearn.datasets import load_breast_cancer
datas = load_breast_cancer()

このうち、datatargetを説明変数(X)と目的変数(y)として利用します。

X = datas['data']
y = datas['target']
X.shape, y.shape

次に、データを訓練用(train)と、テスト用(test)に分割します。

from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.33, random_state=42)

これで、データが準備できました。

TabPFNで予測

TabPFNで予測してみます。

import torch
from tabpfn import TabPFNClassifier

device = "cuda" if torch.cuda.is_available() else "cpu"

classifier = TabPFNClassifier(device=device, N_ensemble_configurations=32)
classifier.fit(X_train, y_train)
y_pred, p_pred = classifier.predict(X_test, return_winning_probability=True)

print('Accuracy', accuracy_score(y_test, y_pred))

TabPFNはGPUが使えるので、GPUがある場合はGPUで実行するようにdeviceに設定しています。

なお、パラメータとしては、N_ensemble_configurationsくらいしかなく、”needs no hyperparameter tuning“というのはあながちうそではないです。

結果は、以下の通りです。

Loading model that can be used for inference only
Using a Transformer with 25.82 M parameters
Multiple models in memory. This might lead to memory issues. Consider calling remove_models_from_memory()
Accuracy 0.9840425531914894

特に設定せずに、正解率が0.98とかなり高い値になりました。なお、表データを利用する場合は、データの正則化などの前処理を行うことが多いですが、”Do not preprocess inputs to TabPFN. TabPFN pre-processes inputs internally. It applies a z-score normalization “ということで、TabPFNはそのまま入力した方が良いようです。

LightGBMで予測してみる

精度差を見るために、lightGBMでも予測をしてみました。コードは以下の通り。

from lightgbm import LGBMClassifier

classifier = LGBMClassifier()
classifier.fit(X_train, y_train)
y_eval = classifier.predict(X_test)

print('Accuracy', accuracy_score(y_test, y_eval))

結果は以下の通り。

Accuracy 0.9521276595744681

正解率は0.95とTabPFNと比較して若干悪い結果になりました。

TabPFNの注意点

TabPFNは、カテゴリデータを入力できないので、男女などのカテゴリデータについては、OrdinalEncoder, OneHotEncoderなどでエンコードする必要があります。

いずれもscikit-learnで用意されているので、それを使うと良いです。

  • TabPFN expects scalar values only (you need to encode categoricals as integers e.g. with OrdinalEncoder). It works best on data that does not contain any categorical or NaN data (see Appendix B.1).
https://github.com/automl/TabPFN/blob/main/README.md

まとめ

lightGBMは前処理も、ハイパーパラメータのチューニングもしていない状態ですので単純に比較はできませんが、とりあえず、何も考えずに使って性能が出るというのはすごいことだと思います。結構使えそうなので積極的に使っていこうと思います。ただ、CPUで処理させるとlightGBMと比較すると少し遅い印象です。

おすすめ書籍

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

記事URLをコピーしました