I played around with generative AI using dolly-japanese-gpt-1b.
Setup
Checking the Python version
I used Python 3.11.3.
```
$ python --version
Python 3.11.3
```
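Installing sentencepiece
SentencePiece is an unsupervised tokenizer (word-segmentation system) that learns the best places to split text directly from raw, unsegmented text, the same kind of data language models are trained on. It is needed here because the tokenizer of dolly-japanese-gpt-1b is SentencePiece-based: its spiece.model file is downloaded the first time the script runs.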
```
$ pip install sentencepiece
Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 12.9 MB/s eta 0:00:00
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99

[notice] A new release of pip available: 22.3.1 -> 23.1.2
[notice] To update, run: pip install --upgrade pip
```
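To make "learning where to split" a little more concrete: SentencePiece's default model type is a unigram language model. Once training has produced a vocabulary of pieces with scores, segmenting new text means choosing the split whose pieces have the highest total log-probability, which a Viterbi search finds efficiently. The sketch below shows only that segmentation step; the vocabulary and its probabilities are made up for illustration, and `viterbi_segment` is a hypothetical helper, not part of the sentencepiece API.

```python
import math

# Hypothetical vocabulary: piece -> log-probability (made-up numbers for illustration)
VOCAB = {
    "機械": math.log(0.05),
    "学習": math.log(0.05),
    "機": math.log(0.01),
    "械": math.log(0.005),
    "学": math.log(0.01),
    "習": math.log(0.005),
    "機械学習": math.log(0.02),
}


def viterbi_segment(text: str, vocab: dict[str, float]) -> list[str]:
    """Return the segmentation of `text` with the highest total log-probability."""
    n = len(text)
    best = [-math.inf] * (n + 1)  # best[i] = best score for the prefix text[:i]
    back = [0] * (n + 1)          # back[i] = start index of the last piece ending at i
    best[0] = 0.0
    max_len = max(len(piece) for piece in vocab)

    for end in range(1, n + 1):
        for start in range(max(0, end - max_len), end):
            piece = text[start:end]
            if piece in vocab and best[start] + vocab[piece] > best[end]:
                best[end] = best[start] + vocab[piece]
                back[end] = start

    if best[n] == -math.inf:
        raise ValueError("text cannot be segmented with this vocabulary")

    # Walk the back-pointers from the end to recover the pieces
    pieces = []
    end = n
    while end > 0:
        start = back[end]
        pieces.append(text[start:end])
        end = start
    return pieces[::-1]


print(viterbi_segment("機械学習", VOCAB))
print(viterbi_segment("学習機械", VOCAB))
```

With these scores, "機械学習" stays a single piece because the whole-word entry outscores "機械" + "学習", while "学習機械" (not in the vocabulary as a whole) is split into "学習" and "機械". The real library also handles unknown characters, normalization and training, none of which is modeled here.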
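Installing torch
PyTorch (the torch package) is an open-source machine learning library for Python. The command below installs it together with torchvision and torchaudio from PyTorch's CUDA 11.7 (cu117) wheel index.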
```
$ pip3 install -U torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu117
Collecting torch
  Downloading https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp311-cp311-linux_x86_64.whl (1843.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 GB 1.1 MB/s eta 0:00:00
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp311-cp311-linux_x86_64.whl (6.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.1/6.1 MB 5.4 MB/s eta 0:00:00
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu117/torchaudio-2.0.2%2Bcu117-cp311-cp311-linux_x86_64.whl (4.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.4/4.4 MB 4.0 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.12.0-py3-none-any.whl (10 kB)
Collecting typing-extensions
  Downloading typing_extensions-4.5.0-py3-none-any.whl (27 kB)
Collecting sympy
  Downloading sympy-1.12-py3-none-any.whl (5.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.7/5.7 MB 32.6 MB/s eta 0:00:00
Collecting networkx
  Downloading networkx-3.1-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 10.6 MB/s eta 0:00:00
Requirement already satisfied: jinja2 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from torch) (3.1.2)
Collecting triton==2.0.0
  Downloading https://download.pytorch.org/whl/triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (63.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.3/63.3 MB 12.0 MB/s eta 0:00:00
Collecting cmake
  Downloading cmake-3.26.3-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (24.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.0/24.0 MB 37.5 MB/s eta 0:00:00
Collecting lit
  Downloading lit-16.0.3.tar.gz (138 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.0/138.0 kB 5.8 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Requirement already satisfied: numpy in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from torchvision) (1.24.3)
Requirement already satisfied: requests in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from torchvision) (2.30.0)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from torchvision) (9.5.0)
Requirement already satisfied: MarkupSafe>=2.0 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from jinja2->torch) (2.1.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->torchvision) (3.1.0)
Requirement already satisfied: idna<4,>=2.5 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->torchvision) (3.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->torchvision) (2.0.2)
Requirement already satisfied: certifi>=2017.4.17 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->torchvision) (2023.5.7)
Collecting mpmath>=0.19
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 kB 19.2 MB/s eta 0:00:00
Installing collected packages: mpmath, lit, cmake, typing-extensions, sympy, networkx, filelock, triton, torch, torchvision, torchaudio
  DEPRECATION: lit is being installed using the legacy 'setup.py install' method, because it does not have a 'pyproject.toml' and the 'wheel' package is not installed. pip 23.1 will enforce this behaviour change. A possible replacement is to enable the '--use-pep517' option. Discussion can be found at https://github.com/pypa/pip/issues/8559
  Running setup.py install for lit ... done
Successfully installed cmake-3.26.3 filelock-3.12.0 lit-16.0.3 mpmath-1.3.0 networkx-3.1 sympy-1.12 torch-2.0.1+cu117 torchaudio-2.0.2+cu117 torchvision-0.15.2+cu117 triton-2.0.0 typing-extensions-4.5.0

[notice] A new release of pip available: 22.3.1 -> 23.1.2
[notice] To update, run: pip install --upgrade pip
```
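The cu117 index used above installs wheels built against CUDA 11.7, but whether the GPU is actually usable only shows up at run time. A quick check before loading the 1B-parameter model can save a slow CPU run. This is a sketch under a few assumptions: `describe_device` is a hypothetical helper name, and it reports the CPU when torch is not installed at all. It relies on `torch.cuda.is_available()`, `torch.version.cuda` (which is `None` for CPU-only builds) and `torch.cuda.get_device_name()`.

```python
def describe_device() -> dict:
    """Report which device the sample script would pick (hypothetical helper)."""
    try:
        import torch
    except ImportError:
        # torch is not installed, so only the CPU can be reported
        return {"device": "cpu", "torch": None, "cuda_build": None}

    use_cuda = torch.cuda.is_available()
    info = {
        # Same rule as the sample script: CUDA if available, otherwise CPU
        "device": "cuda" if use_cuda else "cpu",
        "torch": torch.__version__,        # includes the "+cu117" suffix for the wheel installed above
        "cuda_build": torch.version.cuda,  # CUDA version torch was built with; None on CPU-only builds
    }
    if use_cuda:
        info["gpu_name"] = torch.cuda.get_device_name(0)
    return info


if __name__ == "__main__":
    print(describe_device())
```

If `device` comes back as `cpu` even though a GPU is present, the usual suspect is a mismatch between the installed NVIDIA driver and the CUDA build of the wheel.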
Installing transformers
Transformers is a library published by Hugging Face that provides implementations of state-of-the-art NLP models together with pretrained models.
```
$ pip install transformers
Collecting transformers
  Downloading transformers-4.29.1-py3-none-any.whl (7.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.1/7.1 MB 34.7 MB/s eta 0:00:00
Requirement already satisfied: filelock in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (3.12.0)
Collecting huggingface-hub<1.0,>=0.14.1
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 224.5/224.5 kB 9.4 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.17 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (1.24.3)
Requirement already satisfied: packaging>=20.0 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (23.1)
Requirement already satisfied: pyyaml>=5.1 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (6.0)
Collecting regex!=2019.12.17
  Downloading regex-2023.5.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (780 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 780.9/780.9 kB 21.3 MB/s eta 0:00:00
Requirement already satisfied: requests in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (2.30.0)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 43.5 MB/s eta 0:00:00
Collecting tqdm>=4.27
  Downloading tqdm-4.65.0-py3-none-any.whl (77 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 77.1/77.1 kB 4.7 MB/s eta 0:00:00
Collecting fsspec
  Downloading fsspec-2023.5.0-py3-none-any.whl (160 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 160.1/160.1 kB 8.1 MB/s eta 0:00:00
Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.5.0)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->transformers) (3.1.0)
Requirement already satisfied: idna<4,>=2.5 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->transformers) (3.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->transformers) (2.0.2)
Requirement already satisfied: certifi>=2017.4.17 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->transformers) (2023.5.7)
Installing collected packages: tokenizers, tqdm, regex, fsspec, huggingface-hub, transformers
Successfully installed fsspec-2023.5.0 huggingface-hub-0.14.1 regex-2023.5.5 tokenizers-0.13.3 tqdm-4.65.0 transformers-4.29.1

[notice] A new release of pip available: 22.3.1 -> 23.1.2
[notice] To update, run: pip install --upgrade pip
```
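With all three packages installed, their versions can be confirmed from Python itself, which is handy when comparing against the logs above (sentencepiece 0.1.99, torch 2.0.1+cu117, transformers 4.29.1). This sketch uses only the standard library's `importlib.metadata.version`; the helper name `installed_versions` is hypothetical.

```python
import sys
from importlib.metadata import PackageNotFoundError, version
from typing import Optional

PACKAGES = ["sentencepiece", "torch", "transformers"]


def installed_versions(packages: list[str]) -> dict[str, Optional[str]]:
    """Map each distribution name to its installed version, or None if it is missing."""
    result: dict[str, Optional[str]] = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    print("Python", sys.version.split()[0])  # 3.11.3 in this post
    for name, ver in installed_versions(PACKAGES).items():
        print(f"{name}: {ver or 'not installed'}")
```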
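Trying it out
Code (saved as sample.py)
```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# use_fast=False selects the slow (Python) tokenizer, which reads the SentencePiece model (spiece.model)
tokenizer = AutoTokenizer.from_pretrained("inu-ai/dolly-japanese-gpt-1b", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("inu-ai/dolly-japanese-gpt-1b").to(device)

MAX_ASSISTANT_LENGTH = 100  # maximum number of new tokens for the answer
MAX_INPUT_LENGTH = 1024     # upper bound on prompt + answer tokens

# NOTE: the template strings appear to have been lost from this post. Copy INPUT_PROMPT and
# NO_INPUT_PROMPT from the model card (https://huggingface.co/inu-ai/dolly-japanese-gpt-1b).
# They need the {instruction} placeholder (plus {input} in INPUT_PROMPT) and should put
# "応答:" right before the answer, because generate_response() splits the output on it.
INPUT_PROMPT = r''
NO_INPUT_PROMPT = r''


def prepare_input(instruction, input_text):
    """Build the prompt, using the template with an input section only when input_text is given."""
    if input_text != "":
        prompt = INPUT_PROMPT.format(instruction=instruction, input=input_text)
    else:
        prompt = NO_INPUT_PROMPT.format(instruction=instruction)
    return prompt


def format_output(output):
    """Drop the <s>/</s> markers and [SEP] separators, and turn literal "\\n" into real newlines."""
    # removeprefix/removesuffix (Python 3.9+) remove the exact markers; lstrip("<s>")/rstrip("</s>")
    # would strip any run of the characters "<", "s", "/" and ">" instead.
    output = output.removeprefix("<s>").removesuffix("</s>").replace("[SEP]", "").replace("\\n", "\n")
    return output


def generate_response(instruction, input_text):
    prompt = prepare_input(instruction, input_text)
    if not prompt:
        raise ValueError("INPUT_PROMPT / NO_INPUT_PROMPT are empty; copy them from the model card")
    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    n = len(token_ids[0])  # number of prompt tokens

    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            # For this decoder-only model, min_length/max_length count the prompt tokens too
            min_length=n,
            max_length=min(MAX_INPUT_LENGTH, n + MAX_ASSISTANT_LENGTH),
            temperature=0.7,  # values below 1.0 sharpen the sampling distribution
            do_sample=True,   # sample instead of greedy decoding, so answers vary between runs
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bad_words_ids=[[tokenizer.unk_token_id]],  # never generate the unknown token
        )

    output = tokenizer.decode(output_ids.tolist()[0])
    formatted_output_all = format_output(output)
    # Keep only the text after the last "応答:" ("response:") marker
    response = f"Assistant:{formatted_output_all.split('応答:')[-1].strip()}"
    return formatted_output_all, response


instruction = "あなたは何でも正確に答えられるAIです。"  # "You are an AI that can answer anything accurately."
questions = [
    "一番好きなファイナルファンタジーの作品は?",  # Which Final Fantasy game is your favorite?
    "一番好きなドラゴンクエストの作品は?",        # Which Dragon Quest game is your favorite?
    "一番好きなポケモンは?",                      # What is your favorite Pokémon?
    "一番強いスプラトゥーン3の武器は?",            # What is the strongest weapon in Splatoon 3?
]

# Each question goes into the input slot, so INPUT_PROMPT is the template used
for question in questions:
    formatted_output_all, response = generate_response(instruction, question)
    print(response)
```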
dolly-japanese-gpt-1b
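Because the template strings are empty in the code above, here is a self-contained sketch of the round trip that `prepare_input`, `format_output` and the `split('応答:')` step perform, with no model loaded. `DEMO_INPUT_PROMPT` is a hypothetical stand-in, not the model card's actual template; it only keeps the pieces the code relies on: a leading `<s>`, literal `\n` sequences (which `format_output` turns into newlines), `[SEP]` separators, the `{instruction}`/`{input}` placeholders and the `応答:` marker. The decoded output is simulated by appending a made-up answer and `</s>` to the prompt. The last two lines show why the original `rstrip("</s>")` is fragile.

```python
# Hypothetical stand-in template (NOT the real wording from the model card).
# It is a raw string, so each "\n" stays as the two characters backslash + n.
DEMO_INPUT_PROMPT = r"<s>\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n"


def build_prompt(instruction: str, input_text: str) -> str:
    return DEMO_INPUT_PROMPT.format(instruction=instruction, input=input_text)


def clean_output(output: str) -> str:
    # Same cleanup as format_output, but with exact prefix/suffix removal
    return (
        output.removeprefix("<s>")
        .removesuffix("</s>")
        .replace("[SEP]", "")
        .replace("\\n", "\n")  # literal backslash-n -> real newline
    )


def extract_answer(cleaned: str) -> str:
    # Everything after the last "応答:" marker is treated as the answer
    return cleaned.split("応答:")[-1].strip()


prompt = build_prompt("あなたは何でも正確に答えられるAIです。", "一番好きなポケモンは?")
decoded = prompt + "ミュウツーです。</s>"  # simulated tokenizer.decode() result
print(extract_answer(clean_output(decoded)))

# rstrip removes a *set of characters*, so it also eats the final "s" of "yes"
print("yes</s>".rstrip("</s>"))
print("yes</s>".removesuffix("</s>"))
```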
Running it
```
$ python sample.py
Downloading (…)okenizer_config.json: 100%|████████████████████████████████████████████████████| 405/405 [00:00<00:00, 1.62MB/s]
Downloading spiece.model: 100%|███████████████████████████████████████████████████████████| 1.04M/1.04M [00:00<00:00, 66.6MB/s]
Downloading (…)cial_tokens_map.json: 100%|█████████████████████████████████████████████████████| 170/170 [00:00<00:00, 721kB/s]
Downloading (…)lve/main/config.json: 100%|████████████████████████████████████████████████████| 826/826 [00:00<00:00, 1.89MB/s]
Downloading pytorch_model.bin: 100%|██████████████████████████████████████████████████████| 2.63G/2.63G [00:30<00:00, 85.9MB/s]
Downloading (…)neration_config.json: 100%|█████████████████████████████████████████████████████| 111/111 [00:00<00:00, 549kB/s]
Assistant:ファイナルファンタジーVIII
Assistant:一番好きなゲームには、『ドラゴンクエストIV』です。
Assistant:私は、どのポケモンも好きですが、最も好きなのはミュウツーですね。
Assistant:エイムとジャンプ力が優れているエイムアシスト、射程が長いキルショット、耐久性のある武器です。
```
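Results
Questions and answers
Question: 一番好きなファイナルファンタジーの作品は? (Which Final Fantasy game is your favorite?)
Assistant: ファイナルファンタジーVIII (Final Fantasy VIII)
Question: 一番好きなドラゴンクエストの作品は? (Which Dragon Quest game is your favorite?)
Assistant: 一番好きなゲームには、『ドラゴンクエストIV』です。 (My favorite game is "Dragon Quest IV".)
Question: 一番好きなポケモンは? (What is your favorite Pokémon?)
Assistant: 私は、どのポケモンも好きですが、最も好きなのはミュウツーですね。 (I like every Pokémon, but my favorite is Mewtwo.)
Question: 一番強いスプラトゥーン3の武器は? (What is the strongest weapon in Splatoon 3?)
Assistant: エイムとジャンプ力が優れているエイムアシスト、射程が長いキルショット、耐久性のある武器です。 (Aim Assist, which excels at aiming and jump power; Kill Shot, with its long range; and weapons with high durability.)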
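These answers are not deterministic: `generate_response` calls `model.generate` with `do_sample=True, temperature=0.7`, so each run samples its tokens and may name a different game or Pokémon. The sketch below illustrates temperature sampling on a made-up three-token distribution: the logits are divided by the temperature before the softmax, so a value below 1.0 shifts probability toward the most likely token. The candidate tokens and logits are hypothetical, and this simplified sampler ignores any other filtering (such as top-k or top-p) that `generate` may apply according to the model's generation config.

```python
import math
import random


def softmax_with_temperature(logits: list[float], temperature: float) -> list[float]:
    """Divide logits by the temperature, then apply a numerically stable softmax."""
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max before exp() to avoid overflow
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(tokens: list[str], logits: list[float], temperature: float, rng: random.Random) -> str:
    """Draw one token according to the temperature-scaled probabilities."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(tokens, weights=probs, k=1)[0]


# Hypothetical next-token candidates and logits
tokens = ["ミュウツー", "ピカチュウ", "イーブイ"]
logits = [2.0, 1.5, 0.5]

for t in (1.0, 0.7):
    print(t, [round(p, 3) for p in softmax_with_temperature(logits, t)])

rng = random.Random(0)  # fixed seed so this demo is repeatable
print([sample_token(tokens, logits, 0.7, rng) for _ in range(5)])
```

Lowering the temperature further makes the output more repeatable at the cost of variety; setting `do_sample=False` in the original script would switch to greedy decoding instead.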