A Blog About Making Things

I'm trying all sorts of things in the hope of making stuff themed around my cat.

[python] Trying out dolly-japanese-gpt-1b

I played around with generative AI using dolly-japanese-gpt-1b.

Preparation

Checking the Python version

I used Python v3.11.3.

$ python --version
Python 3.11.3
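To check the interpreter from inside a script instead of on the command line, a small guard like the one below works. It is only a sketch: 3.11 is simply the version this post was run with, not a documented minimum for the model or the libraries, and the function name `check_python` is made up for illustration.

```python
import sys

# Version used in this post (from `python --version`); an assumption, not a hard requirement.
TESTED_VERSION = (3, 11)


def check_python(version_info=sys.version_info, minimum=TESTED_VERSION):
    """Return True if the (major, minor) version is at least `minimum`."""
    return tuple(version_info[:2]) >= tuple(minimum)


if __name__ == "__main__":
    current = sys.version.split()[0]
    if check_python():
        print(f"OK: Python {current}")
    else:
        print(f"Warning: this post was tested on Python 3.11, found {current}")
```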

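Installing sentencepiece

SentencePiece is an unsupervised word segmentation system that learns the best split points directly from raw text, the kind of text used as training data for language models. The tokenizer for dolly-japanese-gpt-1b is itself a SentencePiece model (spiece.model, which shows up in the download log when the script runs).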
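SentencePiece offers two main segmentation algorithms, a unigram language model (the default) and BPE. The toy BPE learner below only illustrates the idea of learning split points from raw text without splitting into words first (whitespace is kept as the symbol "▁", as SentencePiece does). It is not SentencePiece's implementation, the function names are hypothetical, and the three-sentence cat corpus is made up.

```python
from collections import Counter


def _merge_pair(seq, pair):
    """Replace every adjacent occurrence of `pair` in `seq` with one merged symbol."""
    out, i = [], 0
    while i < len(seq):
        if i + 1 < len(seq) and (seq[i], seq[i + 1]) == pair:
            out.append(seq[i] + seq[i + 1])
            i += 2
        else:
            out.append(seq[i])
            i += 1
    return out


def learn_bpe(corpus, num_merges):
    """Learn BPE merge rules from raw sentences (toy version)."""
    # Start from single characters; spaces become the visible symbol "▁".
    sequences = [list(s.replace(" ", "▁")) for s in corpus]
    merges = []
    for _ in range(num_merges):
        # Count adjacent symbol pairs across the whole corpus.
        pairs = Counter()
        for seq in sequences:
            pairs.update(zip(seq, seq[1:]))
        if not pairs:
            break
        # Most frequent pair; on ties, the pair seen first wins.
        best = max(pairs, key=pairs.get)
        merges.append(best)
        sequences = [_merge_pair(seq, best) for seq in sequences]
    return merges


def segment(text, merges):
    """Split new text by applying the learned merges in order."""
    seq = list(text.replace(" ", "▁"))
    for pair in merges:
        seq = _merge_pair(seq, pair)
    return seq


if __name__ == "__main__":
    corpus = ["ねこがねる", "ねこがなく", "ねこがたべる"]  # made-up example sentences
    merges = learn_bpe(corpus, 2)
    print(segment("ねこがはしる", merges))  # ['ねこが', 'は', 'し', 'る']
```

Because "ねこが" is frequent in the corpus, it becomes a single unit, while the unseen "はしる" stays split into characters.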

$ pip install sentencepiece
Collecting sentencepiece
  Downloading sentencepiece-0.1.99-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (1.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.3/1.3 MB 12.9 MB/s eta 0:00:00
Installing collected packages: sentencepiece
Successfully installed sentencepiece-0.1.99

[notice] A new release of pip available: 22.3.1 -> 23.1.2
[notice] To update, run: pip install --upgrade pip

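Installing torch

PyTorch is an open-source machine learning library for Python. The command below installs the CUDA 11.7 (cu117) builds from PyTorch's own package index. torchvision and torchaudio are installed alongside it but are not used by the script in this post.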

$ pip3 install -U torch torchvision torchaudio --extra-index-url https://download.pytorch.org/whl/cu117
Looking in indexes: https://pypi.org/simple, https://download.pytorch.org/whl/cu117
Collecting torch
  Downloading https://download.pytorch.org/whl/cu117/torch-2.0.1%2Bcu117-cp311-cp311-linux_x86_64.whl (1843.9 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 1.8/1.8 GB 1.1 MB/s eta 0:00:00
Collecting torchvision
  Downloading https://download.pytorch.org/whl/cu117/torchvision-0.15.2%2Bcu117-cp311-cp311-linux_x86_64.whl (6.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 6.1/6.1 MB 5.4 MB/s eta 0:00:00
Collecting torchaudio
  Downloading https://download.pytorch.org/whl/cu117/torchaudio-2.0.2%2Bcu117-cp311-cp311-linux_x86_64.whl (4.4 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 4.4/4.4 MB 4.0 MB/s eta 0:00:00
Collecting filelock
  Downloading filelock-3.12.0-py3-none-any.whl (10 kB)
Collecting typing-extensions
  Downloading typing_extensions-4.5.0-py3-none-any.whl (27 kB)
Collecting sympy
  Downloading sympy-1.12-py3-none-any.whl (5.7 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 5.7/5.7 MB 32.6 MB/s eta 0:00:00
Collecting networkx
  Downloading networkx-3.1-py3-none-any.whl (2.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 2.1/2.1 MB 10.6 MB/s eta 0:00:00
Requirement already satisfied: jinja2 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from torch) (3.1.2)
Collecting triton==2.0.0
  Downloading https://download.pytorch.org/whl/triton-2.0.0-1-cp311-cp311-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (63.3 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 63.3/63.3 MB 12.0 MB/s eta 0:00:00
Collecting cmake
  Downloading cmake-3.26.3-py2.py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl (24.0 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 24.0/24.0 MB 37.5 MB/s eta 0:00:00
Collecting lit
  Downloading lit-16.0.3.tar.gz (138 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 138.0/138.0 kB 5.8 MB/s eta 0:00:00
  Preparing metadata (setup.py) ... done
Requirement already satisfied: numpy in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from torchvision) (1.24.3)
Requirement already satisfied: requests in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from torchvision) (2.30.0)
Requirement already satisfied: pillow!=8.3.*,>=5.3.0 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from torchvision) (9.5.0)
Requirement already satisfied: MarkupSafe>=2.0 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from jinja2->torch) (2.1.2)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->torchvision) (3.1.0)
Requirement already satisfied: idna<4,>=2.5 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->torchvision) (3.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->torchvision) (2.0.2)
Requirement already satisfied: certifi>=2017.4.17 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->torchvision) (2023.5.7)
Collecting mpmath>=0.19
  Downloading mpmath-1.3.0-py3-none-any.whl (536 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 536.2/536.2 kB 19.2 MB/s eta 0:00:00
Installing collected packages: mpmath, lit, cmake, typing-extensions, sympy, networkx, filelock, triton, torch, torchvision, torchaudio
  DEPRECATION: lit is being installed using the legacy 'setup.py install' method, because it does not have a 'pyproject.toml' and the 'wheel' package is not installed. pip 23.1 will enforce this behaviour change. A possible replacement is to enable the '--use-pep517' option. Discussion can be found at https://github.com/pypa/pip/issues/8559
  Running setup.py install for lit ... done
Successfully installed cmake-3.26.3 filelock-3.12.0 lit-16.0.3 mpmath-1.3.0 networkx-3.1 sympy-1.12 torch-2.0.1+cu117 torchaudio-2.0.2+cu117 torchvision-0.15.2+cu117 triton-2.0.0 typing-extensions-4.5.0

[notice] A new release of pip available: 22.3.1 -> 23.1.2
[notice] To update, run: pip install --upgrade pip
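Since these are the cu117 (CUDA 11.7) wheels, it is worth confirming after installation that PyTorch can actually see the GPU; otherwise the script below silently falls back to the CPU, which works but is generally much slower. A minimal check, as a sketch (the function name is hypothetical):

```python
def describe_torch():
    """Report the installed torch build and whether a CUDA GPU is usable."""
    try:
        import torch
    except ImportError:
        return "torch is not installed"
    # torch.version.cuda is the CUDA version the wheel was built for
    # (e.g. "11.7" for cu117 wheels); it is None for CPU-only builds.
    cuda_build = torch.version.cuda
    if torch.cuda.is_available():
        name = torch.cuda.get_device_name(0)
        return f"torch {torch.__version__}, CUDA {cuda_build}, GPU: {name}"
    return f"torch {torch.__version__}, CUDA build {cuda_build}, no GPU available (CPU only)"


if __name__ == "__main__":
    print(describe_torch())
```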

Installing transformers

Transformers is a library published by Hugging Face that provides implementations of state-of-the-art NLP models along with pretrained models.

$ pip install transformers
Collecting transformers
  Downloading transformers-4.29.1-py3-none-any.whl (7.1 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.1/7.1 MB 34.7 MB/s eta 0:00:00
Requirement already satisfied: filelock in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (3.12.0)
Collecting huggingface-hub<1.0,>=0.14.1
  Downloading huggingface_hub-0.14.1-py3-none-any.whl (224 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 224.5/224.5 kB 9.4 MB/s eta 0:00:00
Requirement already satisfied: numpy>=1.17 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (1.24.3)
Requirement already satisfied: packaging>=20.0 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (23.1)
Requirement already satisfied: pyyaml>=5.1 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (6.0)
Collecting regex!=2019.12.17
  Downloading regex-2023.5.5-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (780 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 780.9/780.9 kB 21.3 MB/s eta 0:00:00
Requirement already satisfied: requests in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from transformers) (2.30.0)
Collecting tokenizers!=0.11.3,<0.14,>=0.11.1
  Downloading tokenizers-0.13.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (7.8 MB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 7.8/7.8 MB 43.5 MB/s eta 0:00:00
Collecting tqdm>=4.27
  Downloading tqdm-4.65.0-py3-none-any.whl (77 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 77.1/77.1 kB 4.7 MB/s eta 0:00:00
Collecting fsspec
  Downloading fsspec-2023.5.0-py3-none-any.whl (160 kB)
     ━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━ 160.1/160.1 kB 8.1 MB/s eta 0:00:00
Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.5.0)
Requirement already satisfied: charset-normalizer<4,>=2 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->transformers) (3.1.0)
Requirement already satisfied: idna<4,>=2.5 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->transformers) (3.4)
Requirement already satisfied: urllib3<3,>=1.21.1 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->transformers) (2.0.2)
Requirement already satisfied: certifi>=2017.4.17 in /home/{{user_name}}/.anyenv/envs/pyenv/versions/3.11.3/lib/python3.11/site-packages (from requests->transformers) (2023.5.7)
Installing collected packages: tokenizers, tqdm, regex, fsspec, huggingface-hub, transformers
Successfully installed fsspec-2023.5.0 huggingface-hub-0.14.1 regex-2023.5.5 tokenizers-0.13.3 tqdm-4.65.0 transformers-4.29.1

[notice] A new release of pip available: 22.3.1 -> 23.1.2
[notice] To update, run: pip install --upgrade pip
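These libraries change quickly, so it helps to know which versions a result came from. The sketch below uses only the standard library's `importlib.metadata` to compare what is installed with the versions that appear in the pip logs above; the dictionary and function names are made up for illustration.

```python
from importlib.metadata import PackageNotFoundError, version

# Versions that appear in the pip logs of this post.
VERSIONS_IN_THIS_POST = {
    "sentencepiece": "0.1.99",
    "torch": "2.0.1+cu117",
    "transformers": "4.29.1",
}


def installed_versions(packages):
    """Map each package name to its installed version, or None if it is missing."""
    result = {}
    for name in packages:
        try:
            result[name] = version(name)
        except PackageNotFoundError:
            result[name] = None
    return result


if __name__ == "__main__":
    for name, found in installed_versions(VERSIONS_IN_THIS_POST).items():
        expected = VERSIONS_IN_THIS_POST[name]
        status = "same" if found == expected else "different"
        print(f"{name}: installed={found}, in this post={expected} ({status})")
```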

Trying it out

Code

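import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
tokenizer = AutoTokenizer.from_pretrained("inu-ai/dolly-japanese-gpt-1b", use_fast=False)
model = AutoModelForCausalLM.from_pretrained("inu-ai/dolly-japanese-gpt-1b").to(device)

MAX_ASSISTANT_LENGTH = 100  # maximum number of tokens generated for the answer
MAX_INPUT_LENGTH = 1024     # cap on prompt + answer length, in tokens
# Prompt templates in the model's [SEP]-separated format (wording as given on
# the model card; verify it there). These are raw strings, so "\n" stays a
# literal backslash + n, which format_output converts back to real newlines.
INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'
NO_INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示です。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n応答:\n'

def prepare_input(instruction, input_text):
    # Use the template with an input section only when input text is given
    if input_text != "":
        prompt = INPUT_PROMPT.format(instruction=instruction, input=input_text)
    else:
        prompt = NO_INPUT_PROMPT.format(instruction=instruction)
    return prompt

def format_output(output):
    # Remove special-token strings and turn literal "\n" into real newlines.
    # str.lstrip/rstrip strip character sets, not prefixes, so use replace.
    output = output.replace("<s>", "").replace("</s>", "").replace("[SEP]", "").replace("\\n", "\n")
    return output

def generate_response(instruction, input_text):
    prompt = prepare_input(instruction, input_text)
    token_ids = tokenizer.encode(prompt, add_special_tokens=False, return_tensors="pt")
    n = len(token_ids[0])  # prompt length in tokens

    with torch.no_grad():
        output_ids = model.generate(
            token_ids.to(model.device),
            min_length=n,
            max_length=min(MAX_INPUT_LENGTH, n + MAX_ASSISTANT_LENGTH),
            temperature=0.7,
            do_sample=True,
            pad_token_id=tokenizer.pad_token_id,
            bos_token_id=tokenizer.bos_token_id,
            eos_token_id=tokenizer.eos_token_id,
            bad_words_ids=[[tokenizer.unk_token_id]]  # never generate <unk>
        )

    output = tokenizer.decode(output_ids.tolist()[0])
    formatted_output_all = format_output(output)
    # The answer is whatever follows the last "応答:" ("Response:") marker
    response = f"Assistant:{formatted_output_all.split('応答:')[-1].strip()}"

    return formatted_output_all, response

# "You are an AI that can answer anything accurately."
instruction = "あなたは何でも正確に答えられるAIです。"
questions = [
    "一番好きなファイナルファンタジーの作品は?",  # Which Final Fantasy game is your favorite?
    "一番好きなドラゴンクエストの作品は?",  # Which Dragon Quest game is your favorite?
    "一番好きなポケモンは?",  # What is your favorite Pokémon?
    "一番強いスプラトゥーン3の武器は?",  # What is the strongest weapon in Splatoon 3?
]

for question in questions:
    formatted_output_all, response = generate_response(instruction, question)
    print(response)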
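The parts of the script that do not need the model (building the prompt, choosing `max_length`, and cutting the answer out of the decoded text) can be tried on their own. In this sketch the template wording is an assumption modeled on the model card's `[SEP]`-separated format, the function names are hypothetical, and the "decoded" string is hand-built from the prompt plus one of the answers shown in the results. Note that `str.lstrip`/`rstrip` treat their argument as a set of characters rather than a prefix or suffix, so the special-token strings are removed with `replace` here.

```python
# Assumed prompt template (modeled on the model card's format; verify there).
# Raw string: "\n" stays as the two characters backslash + n.
INPUT_PROMPT = r'<s>\n以下は、タスクを説明する指示と、文脈のある入力の組み合わせです。要求を適切に満たす応答を書きなさい。\n[SEP]\n指示:\n{instruction}\n[SEP]\n入力:\n{input}\n[SEP]\n応答:\n'

MAX_ASSISTANT_LENGTH = 100
MAX_INPUT_LENGTH = 1024


def build_prompt(instruction, input_text):
    """Fill the template with the instruction and the user's input."""
    return INPUT_PROMPT.format(instruction=instruction, input=input_text)


def max_total_length(prompt_tokens):
    """Prompt plus up to 100 new tokens, capped at 1024 (the max_length passed to generate)."""
    return min(MAX_INPUT_LENGTH, prompt_tokens + MAX_ASSISTANT_LENGTH)


def extract_answer(decoded):
    """Turn decoded text (prompt + generation) into the 'Assistant:' line."""
    text = decoded.replace("<s>", "").replace("</s>", "").replace("[SEP]", "")
    text = text.replace("\\n", "\n")  # literal backslash-n -> real newline
    # Everything after the last "応答:" marker is the model's answer.
    return "Assistant:" + text.split("応答:")[-1].strip()


if __name__ == "__main__":
    prompt = build_prompt("あなたは何でも正確に答えられるAIです。", "一番好きなポケモンは?")
    decoded = prompt + "私は、どのポケモンも好きですが、最も好きなのはミュウツーですね。</s>"
    print(extract_answer(decoded))
```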

Model page: dolly-japanese-gpt-1b (huggingface.co)

Running it

$ python sample.py
Downloading (…)okenizer_config.json: 100%|████████████████████████████████████████████████████| 405/405 [00:00<00:00, 1.62MB/s]
Downloading spiece.model: 100%|███████████████████████████████████████████████████████████| 1.04M/1.04M [00:00<00:00, 66.6MB/s]
Downloading (…)cial_tokens_map.json: 100%|█████████████████████████████████████████████████████| 170/170 [00:00<00:00, 721kB/s]
Downloading (…)lve/main/config.json: 100%|████████████████████████████████████████████████████| 826/826 [00:00<00:00, 1.89MB/s]
Downloading pytorch_model.bin: 100%|██████████████████████████████████████████████████████| 2.63G/2.63G [00:30<00:00, 85.9MB/s]
Downloading (…)neration_config.json: 100%|█████████████████████████████████████████████████████| 111/111 [00:00<00:00, 549kB/s]
Assistant:ファイナルファンタジーVIII
Assistant:一番好きなゲームには、『ドラゴンクエストIV』です。
Assistant:私は、どのポケモンも好きですが、最も好きなのはミュウツーですね。
Assistant:エイムとジャンプ力が優れているエイムアシスト、射程が長いキルショット、耐久性のある武器です。
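The script samples (`do_sample=True`, `temperature=0.7`), so answers like these change from run to run. Roughly, temperature divides the model's scores (logits) before the softmax; values below 1 sharpen the distribution toward the most likely token. The pure-Python sketch below illustrates only that idea with made-up logits; it is not how transformers implements sampling, and transformers can combine temperature with other sampling settings such as top-k.

```python
import math
import random


def softmax_with_temperature(logits, temperature):
    """Convert logits to probabilities after dividing them by the temperature."""
    scaled = [x / temperature for x in logits]
    m = max(scaled)  # subtract the max for numerical stability
    exps = [math.exp(x - m) for x in scaled]
    total = sum(exps)
    return [e / total for e in exps]


def sample_token(logits, temperature=0.7, rng=random):
    """Draw one token index from the temperature-scaled distribution."""
    probs = softmax_with_temperature(logits, temperature)
    return rng.choices(range(len(logits)), weights=probs, k=1)[0]


if __name__ == "__main__":
    logits = [2.0, 1.0, 0.5]  # made-up scores for three candidate tokens
    for t in (0.7, 1.0, 1.5):
        probs = softmax_with_temperature(logits, t)
        print(t, [round(p, 3) for p in probs])
```

Lower temperatures give the top token a larger share, which makes answers more consistent; higher temperatures spread probability out and make them more varied.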

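Results

Questions and answers

Question: Which Final Fantasy game is your favorite? (一番好きなファイナルファンタジーの作品は?)

A: Final Fantasy VIII (ファイナルファンタジーVIII)

Question: Which Dragon Quest game is your favorite? (一番好きなドラゴンクエストの作品は?)

A: As for my favorite game, it's "Dragon Quest IV". (一番好きなゲームには、『ドラゴンクエストIV』です。)

Question: What is your favorite Pokémon? (一番好きなポケモンは?)

A: I like every Pokémon, but my favorite is Mewtwo. (私は、どのポケモンも好きですが、最も好きなのはミュウツーですね。)

Question: What is the strongest weapon in Splatoon 3? (一番強いスプラトゥーン3の武器は?)

A: Aim Assist, which excels at aiming and jumping; Kill Shot, which has long range; and durable weapons. (エイムとジャンプ力が優れているエイムアシスト、射程が長いキルショット、耐久性のある武器です。)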