機械学習
記事内に商品プロモーションを含む場合があります

HuggingFace(Llama2, Gemma, Calm2)の使い方|大規模言語モデル実践

Aru

HuggingFaceで公開されている3つのモデル(Llama2/Gemma/CALM2)をPythonで動かしてみました。今回は70億のパラメータでトレーニングされ、16GBのGPUメモリで動作可能な7Bモデルを使用して、それぞれのモデルの使い方を解説します。

大規模言語モデル関連記事一覧はこちら
大規模言語モデル(LLM)関連の記事一覧
大規模言語モデル(LLM)関連の記事一覧

はじめに

HuggingFaceには、さまざまな大規模言語モデル(LLM)が公開されていますが、この記事ではLlama2, Gemma, CALM2の3つのモデルを使う基本的なPythonコードを作成してみます。

HuggingFaceに登録されたモデルは、基本的に同じようなコードで呼び出すことが可能で、ちょっと呼び出して使う程度なら本当に簡単に実装できます。

しかしながら、CPUやGPUのメモリが必要となるので、動作させるにはそれなりの環境が必要になります。

今回は、Google Colabのハイメモリ+GPU(16GB以上)という環境で、それぞれのLLMを動作させてみました。

Google Colabで動作させる場合は、GPU:T4以上+ハイメモリが良いと思います。ハイメモリでないとメモリクラッシュが発生しました。

LLMの推論プロセスについては以下の記事でわかりやすく説明しています。気になる方は参考にしてください。

LLM(大規模言語モデル)の仕組みを分かりやすく解説|推論処理を理解しよう
LLM(大規模言語モデル)の仕組みを分かりやすく解説|推論処理を理解しよう

前準備

ライブラリのインストール

動作させるには、ライブラリなどをインストールしておく必要があります。Google Colabの場合は以下のライブラリをインストールしておきます(必要ないものもあるかもしれません)。

!pip install -q accelerate==0.21.0 peft==0.4.0 bitsandbytes==0.40.2 transformers==4.31.0 trl==0.4.7

また、Gemmaに関しては最新のtransformersをインストールする必要がありました。

!pip install -q -U git+https://github.com/huggingface/transformers.git

Google Colabでは、「セッションの再起動」が必要でした

HuggingFaceへのログイン

OpenCALM以外はHuggingFaceへのログインが必要です。

huggingface-cli loginを使って、ログインする必要があります。

!huggingface-cli login

ログインに必要なトークンは、ログインのメッセージにあるURLから入手することが可能です(HuggingFaceのアカウントが必要です)

Llama2(Meta)

Llama2とは

Llama2(Large Language Model Meta AI/ラマ)は、Meta社が開発した大規模言語モデル(LLM)です。現在、研究および商用利用に無料で提供されています

コード(Llama-2-7b-chat-hf)

今回はChat向けに学習した、Llama-2-7b-chat-hfを利用しています。

以下は、tokenizermodelの定義です。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
)
model = AutoModelForCausalLM.from_pretrained(
    "meta-llama/Llama-2-7b-chat-hf",
    torch_dtype=torch.float16, device_map="auto", trust_remote_code=False,
)

以下は、model/tokenizerの使用例です。llama2では[INST]指示文[/INST]という形で入力するようです。

# プロンプト(結構重い)
prompt = "[INST] 日本で一番高い建物は? [/INST]"

# 推論
input_ids = tokenizer(prompt, return_tensors='pt')
with torch.no_grad():
  tokens = model.generate(
      **input_ids.to(model.device),
        max_new_tokens=64,
        do_sample=True,
        temperature=0.5,
        top_p=0.9,
        repetition_penalty=1.05,
  )
output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)
出力結果
[INST] 日本で一番高い建物は? [/INST]  Japan has several tall buildings, and the one that is considered to be the highest depends on the criteria used. Here are some of the tallest buildings in Japan:

1. Tokyo Skytree: At 634 meters (2,080 feet) tall, Tokyo Skytree is the tall

トークナイザーについての記事はこちら

Hugging FaceのTokenizerを理解する|動作まとめ
Hugging FaceのTokenizerを理解する|動作まとめ

チャットテンプレートを使ってプロンプトを生成する方法についてはこちら

HuggingFaceの大規模言語モデル(LLM)のチャットテンプレートの使い方
HuggingFaceの大規模言語モデル(LLM)のチャットテンプレートの使い方

Gemma(Google)

Gemmaとは

Gemmaは、2024年2月21日に発表されたGoogleが開発した大規模言語モデル(LLM)です。AI開発者や研究者による商用利用や再配布が可能なオープンソースとして提供されています。

利用するには、HuggingFaceのgemma-7bのページで利用規約に同意する必要があります。同意していない場合ログインしていてもダウンロードできないので注意してください。

https://huggingface.co/google/gemma-7b

コード(gemma-7b)

今回はChat向けに学習した、google/gemma-7bを利用しています。

以下は、tokenizermodelの定義です。

from transformers import AutoTokenizer, AutoModelForCausalLM
import torch

tokenizer = AutoTokenizer.from_pretrained("google/gemma-7b")
model = AutoModelForCausalLM.from_pretrained("google/gemma-7b", device_map="auto")

以下は、model/tokenizerの使用例です。日本語の質問はうまく動かなかったので、英語で質問しています。

input_text = "What is the tallest building in Japan?"
input_ids = tokenizer(input_text, return_tensors="pt").to(model.device)

with torch.no_grad():
  tokens = model.generate(**input_ids)

output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)  
出力結果
What is the tallest building in Japan?

The tallest building in Japan is the Tokyo Skytree

CALM2(CyberAgent)

CALM2とは

CALM2(CyberAgentLM2)はサイバーエージェントが開発した大規模言語モデル(LLM)です。日本語に特化した大規模言語モデルという部分が特徴です。32,000トークンまでの入力に対応していて日本語の文章として約50,000文字を一度に処理することができるそうです。

50,000文字を32,000トークンなので、マルチリンガルモデルでありがちな日本語1文字につき1トークンに変換されているわけではなさそうです。

コード(calm2-7b-chat)

今回はChat向けに学習した、cyberagent/calm2-7b-chatを利用しています。

以下は、tokenizermodelの定義です。

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model = AutoModelForCausalLM.from_pretrained("cyberagent/calm2-7b-chat",
                                             device_map="auto",
                                             torch_dtype=torch.float16,
                                             )
tokenizer = AutoTokenizer.from_pretrained("cyberagent/calm2-7b-chat")

以下は、model/tokenizerの使用例です。

calm2-7b-chatの場合、”USER: 文章 \nASSISTANT:“という形式で入力する必要があります(チャットのテンプレートについては以下の記事を参考にしてください)。

あわせて読みたい
HuggingFaceの大規模言語モデル(LLM)のチャットテンプレートの使い方
HuggingFaceの大規模言語モデル(LLM)のチャットテンプレートの使い方

公開されているチャットのテンプレート

USER: {user_message1}
ASSISTANT: {assistant_message1}<|endoftext|>
USER: {user_message2}
ASSISTANT: {assistant_message2}<|endoftext|>
USER: {user_message3}
ASSISTANT: {assistant_message3}<|endoftext|>
prompt = "USER: 日本で一番高い建物はなんですか?\nASSISTANT:"

# 推論
input_ids = tokenizer(prompt, return_tensors='pt')
with torch.no_grad():
  tokens = model.generate(
      **input_ids.to(model.device),
        max_new_tokens=64,
        do_sample=True,
        temperature=0.5,
        top_p=0.9,
        repetition_penalty=1.05,
        pad_token_id=tokenizer.pad_token_id,
  )
output = tokenizer.decode(tokens[0], skip_special_tokens=True)
print(output)
出力結果
USER: 日本で一番高い建物はなんですか?
ASSISTANT: 日本で一番高い建物は、2019年8月時点において、東京都中央区晴海にある「あべのハルカス」です。高さは634メートルで、これは日本国内にある建造物で、最も高い建物となっています。

CLAM2については、他の記事も参考にしてください。

あわせて読みたい
CALM2でチャットボット作成&トークナイザーの出力も調査
CALM2でチャットボット作成&トークナイザーの出力も調査
あわせて読みたい
LangChainで実装するRAG|CALM2 + FAISS + RetrievalQAを使った具体例
LangChainで実装するRAG|CALM2 + FAISS + RetrievalQAを使った具体例

まとめ

以上、Llama2/Gemma/Calm2の3つを動かしてみました。Gemmaは日本語で聞くとうまく答えられませんでした。

とりあえず、環境さえあれば手軽に動かすことが可能です。

なお、チャットモデルにはチャットテンプレートで入力プロンプトを作成できるモデルもあります。これについては以下の記事を参考にしてください。

HuggingFaceの大規模言語モデル(LLM)のチャットテンプレートの使い方
HuggingFaceの大規模言語モデル(LLM)のチャットテンプレートの使い方

メールアドレスが公開されることはありません。 が付いている欄は必須項目です

ABOUT ME
ある/Aru
ある/Aru
IT&機械学習エンジニア/ファイナンシャルプランナー(CFP®)
専門分野は並列処理・画像処理・機械学習・ディープラーニング。プログラミング言語はC, C++, Go, Pythonを中心として色々利用。現在は、Kaggle, 競プロなどをしながら悠々自適に活動中
記事URLをコピーしました