Train Your Own Vicuna on Llama-2
Meta released Llama 2 two weeks ago and made a big wave in the AI community. In our view, its biggest impact is that the model is now released under a permissive license that allows the model weights to be used commercially[1]. This is in contrast to Llama 1, which could not be used commercially.
Vicuna was one of the first high-quality LLMs fine-tuned on Llama 1. We, the co-creators of Vicuna, have updated the exact recipe we used to train Vicuna to be based on Llama 2, resulting in this fine-tuning guide.
In this recipe, we show how to use SkyPilot to train your own Vicuna on Llama 2, easily finding available GPUs on the cloud while cutting the cost to only ~$300.
Prerequisites
Apply for access to the Llama-2 model
Go to the application page and apply for access to the model weights.
Get an access token from HuggingFace
Generate a read-only access token on HuggingFace here. Go to the HuggingFace page for the Llama-2 models here and apply for access. Make sure your HuggingFace email matches the email on your Meta application. Approval may take 1-2 days.
Download the recipe
git clone https://github.com/skypilot-org/skypilot.git
cd skypilot/llm/vicuna-llama-2
Paste the access token into train.yaml:
envs:
HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass.
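Optional: you can sanity-check the token before launching. The snippet below is not part of the recipe; it is a minimal sketch assuming the huggingface_hub package is installed locally and the HF_TOKEN environment variable holds your token.
import os
from huggingface_hub import hf_hub_download, whoami

token = os.environ["HF_TOKEN"]
# A valid token returns your HuggingFace account info.
print(whoami(token=token)["name"])
# This download only succeeds once your Llama-2 access request has been approved.
hf_hub_download("meta-llama/Llama-2-7b-hf", "config.json", token=token)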
Train your own Vicuna on Llama-2
Training data and model identity
By default, we use the ShareGPT data and the identity questions in hardcoded_questions.py.
Optional: To use custom data, you can change the following line in train.yaml:
setup: |
...
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -O $HOME/data/sharegpt.json
...
The JSON file above is an array whose elements have the following format (a conversation can have multiple turns, alternating between human and gpt):
{
"id": "i6IyJda_0",
"conversations": [
{
"from": "human",
"value": "How to tell if a customer segment is well segmented? In 3 bullet points."
},
{
"from": "gpt",
"value": "1. Homogeneity: The segment should consist of customers who share similar characteristics and behaviors.\n2. Distinctiveness: The segment should be different from other segments in terms of their characteristics and behaviors.\n3. Stability: The segment should remain relatively stable over time and not change drastically. The characteristics and behaviors of customers within the segment should not change significantly."
}
]
},
Optional: To make the model aware of its identity, you can change the hardcoded questions in hardcoded_questions.py.
Note: Models trained on ShareGPT data may have restrictions on commercial usage. Swap it out with your own data for commercial use.
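If you bring your own data, it only needs to follow the JSON schema shown above. Below is a minimal sketch (not part of the recipe) that converts a hypothetical list of (question, answer) pairs into that format; write the result to the path that the wget line in train.yaml would otherwise populate.
import json

# Hypothetical source data: replace with your own question/answer pairs.
qa_pairs = [
    ("What is SkyPilot?",
     "SkyPilot is a framework for running AI and batch jobs on any cloud."),
]

records = []
for i, (question, answer) in enumerate(qa_pairs):
    records.append({
        "id": f"custom_{i}",
        "conversations": [
            {"from": "human", "value": question},
            {"from": "gpt", "value": answer},
        ],
    })

# Matches the format consumed by scripts/train.py.
with open("sharegpt.json", "w") as f:
    json.dump(records, f, indent=2)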
Kick start training on any cloud
Start training with a single command:
sky launch --down -c vicuna train.yaml \
--env ARTIFACT_BUCKET_NAME=<your-bucket-name> \
--env WANDB_API_KEY=<your-wandb-api-key>
This will launch the training job on the cheapest cloud that has 8x A100-80GB spot GPUs available.
Tip: You can get your WANDB_API_KEY at https://wandb.ai/settings. To disable Weights & Biases, simply leave out that --env flag.
Tip: You can set ARTIFACT_BUCKET_NAME to a new bucket name, such as <whoami>-tmp-bucket, and SkyPilot will create the bucket for you.
Use on-demand instead to unlock more clouds: in train.yaml, we requested spot instances:
resources:
accelerators: A100-80GB:8
disk_size: 1000
use_spot: true
However, spot A100-80GB:8 is currently only supported on GCP. The on-demand version is supported on AWS, Azure, GCP, Lambda, and more. (Tip: check out the handy outputs of sky show-gpus A100-80GB:8!)
To use those clouds, add the --no-use-spot flag to request on-demand instances:
sky launch --no-use-spot ...
Optional: Try training the 13B model:
sky launch -c vicuna train.yaml \
--env ARTIFACT_BUCKET_NAME=<your-bucket-name> \
--env WANDB_API_KEY=<your-wandb-api-key> \
--env MODEL_SIZE=13
Reduce costs by 3x with spot instances
SkyPilot Managed Jobs is a library built on top of SkyPilot that helps users run jobs on spot instances without worrying about interruptions. It is the tool that the LMSYS organization used to train the first version of Vicuna (more details can be found in their launch blog post and example). With it, the training cost can be reduced from roughly $1000 to $300.
To use SkyPilot Managed Spot Jobs, simply replace sky launch with sky jobs launch in the command above:
sky jobs launch -n vicuna train.yaml \
--env ARTIFACT_BUCKET_NAME=<your-bucket-name> \
--env WANDB_API_KEY=<your-wandb-api-key>
Serve your model
After training is done, you can serve your model in your own cloud environment with a single command:
sky launch -c serve serve.yaml --env MODEL_CKPT=<your-model-checkpoint>/chatbot/7b
In serve.yaml, we specify launching a Gradio server that serves the model checkpoint at <your-model-checkpoint>/chatbot/7b.
Tip: You can also switch to a cheaper accelerator, such as L4, to save costs, by adding --gpus L4 to the command above.
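The Gradio server gives you a web UI. If you prefer a programmatic endpoint, FastChat also ships an OpenAI-compatible API server (fastchat.serve.openai_api_server); the sketch below is a hypothetical client, assuming you additionally start that server on the cluster, forward its port (8000 by default) to your machine, and that the served model name is skypilot-vicuna (derived from the /skypilot-vicuna mount path).
import requests

# Hypothetical setup: fastchat.serve.openai_api_server reachable at localhost:8000.
resp = requests.post(
    "http://localhost:8000/v1/chat/completions",
    json={
        "model": "skypilot-vicuna",  # assumption: model name follows the mount path
        "messages": [{"role": "user", "content": "Who are you?"}],
        "temperature": 0.7,
    },
    timeout=120,
)
print(resp.json()["choices"][0]["message"]["content"])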
Included files
scripts/flash_attn_patch.py
import logging
from typing import List, Optional, Tuple
from einops import rearrange
from flash_attn.bert_padding import pad_input
from flash_attn.bert_padding import unpad_input
# pip3 install "flash-attn>=2.0"
from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
import torch
from torch import nn
import transformers
from transformers.models.llama.modeling_llama import apply_rotary_pos_emb
def forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.Tensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
"""Input shape: Batch x Time x Channel
attention_mask: [bsz, q_len]
"""
bsz, q_len, _ = hidden_states.size()
query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
self.head_dim).transpose(
1, 2))
key_states = (self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
self.head_dim).transpose(
1, 2))
value_states = (self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
self.head_dim).transpose(
1, 2))
# [bsz, q_len, nh, hd]
# [bsz, nh, q_len, hd]
kv_seq_len = key_states.shape[-2]
assert past_key_value is None, "past_key_value is not supported"
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)
# [bsz, nh, t, hd]
assert not output_attentions, "output_attentions is not supported"
assert not use_cache, "use_cache is not supported"
# Flash attention codes from
# https://github.com/HazyResearch/flash-attention/blob/main/flash_attn/flash_attention.py
# transform the data into the format required by flash attention
qkv = torch.stack([query_states, key_states, value_states],
dim=2) # [bsz, nh, 3, q_len, hd]
qkv = qkv.transpose(1, 3) # [bsz, q_len, 3, nh, hd]
# We have disabled _prepare_decoder_attention_mask in LlamaModel
# the attention_mask should be the same as the key_padding_mask
key_padding_mask = attention_mask
if key_padding_mask is None:
qkv = rearrange(qkv, "b s ... -> (b s) ...")
max_s = q_len
cu_q_lens = torch.arange(0, (bsz + 1) * q_len,
step=q_len,
dtype=torch.int32,
device=qkv.device)
output = flash_attn_varlen_qkvpacked_func(qkv,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True)
output = rearrange(output, "(b s) ... -> b s ...", b=bsz)
else:
nheads = qkv.shape[-2]
x = rearrange(qkv, "b s three h d -> b s (three h d)")
x_unpad, indices, cu_q_lens, max_s = unpad_input(x, key_padding_mask)
x_unpad = rearrange(x_unpad,
"nnz (three h d) -> nnz three h d",
three=3,
h=nheads)
output_unpad = flash_attn_varlen_qkvpacked_func(x_unpad,
cu_q_lens,
max_s,
0.0,
softmax_scale=None,
causal=True)
output = rearrange(
pad_input(rearrange(output_unpad, "nnz h d -> nnz (h d)"), indices,
bsz, q_len),
"b s (h d) -> b s h d",
h=nheads,
)
return self.o_proj(rearrange(output, "b s h d -> b s (h d)")), None, None
# Disable the transformation of the attention mask in LlamaModel as the flash attention
# requires the attention mask to be the same as the key_padding_mask
def _prepare_decoder_attention_mask(self, attention_mask, input_shape,
inputs_embeds, past_key_values_length):
# [bsz, seq_len]
return attention_mask
def replace_llama_attn_with_flash_attn():
cuda_major, cuda_minor = torch.cuda.get_device_capability()
if cuda_major < 8:
logging.warning(
"Flash attention is only supported on A100 or H100 GPU during training due to head dim > 64 backward."
"ref: https://github.com/HazyResearch/flash-attention/issues/190#issuecomment-1523359593"
)
transformers.models.llama.modeling_llama.LlamaModel._prepare_decoder_attention_mask = (
_prepare_decoder_attention_mask)
transformers.models.llama.modeling_llama.LlamaAttention.forward = forward
scripts/hardcoded_questions.py
# This code is based on lmsys-org/fastchat. Below is the original copyright:
#
# Copyright 2023 FastChat authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import textwrap
def identity_questions():
""" "
Adopted from https://github.com/young-geng/koala_data_pipeline/blob/main/process_hard_coded_data.py
"""
content = []
name = "SkyPilot-Vicuna"
org = "UC Berkeley Sky Computing Lab and SkyPilot Team"
base = "Llama-2"
def generate_conversations(questions, answers):
for q in questions:
for a in answers:
content.append({
"id": f"identity_{len(content)}",
"conversations": [
{
"from": "human",
"value": q
},
{
"from": "gpt",
"value": a
},
],
})
questions = [
"Who are you?",
"What is your name?",
"Can you introduce yourself?",
"Can you tell me a little bit about yourself?",
"What's your name?",
"What are you called?",
"What are you?",
"Tell me your name.",
"Tell me about yourself.",
"Tell me about you.",
"Tell me who you are.",
"Please introduce yourself.",
]
answers = [
f"I am {name}, a language model trained on {base} by researchers from {org} using SkyPilot.",
f"My name is {name}, and I'm a language model based on {base} developed by {org} using SkyPilot.",
f"You can call me {name}, and I was trained by {org} researchers as a language model based on {base} using SkyPilot.",
f"As a language model, I go by the name {name} and was trained on {base} by researchers from {org} using SkyPilot.",
f"I'm a language model based on {base} called {name}, and I was trained by {org} researchers using SkyPilot.",
f"You may refer to me as {name}, a language model based on {base} meticulously developed by the researchers at {org} using SkyPilot.",
]
generate_conversations(questions, answers)
questions = [
"Who created you?",
"Who made you?",
"Who built you?",
"Who programmed you?",
"Who trained you?",
"Who taught you?",
"Who developed you?",
]
answers = [
f"Researchers from {org} created me.",
f"I'm created by {org}.",
f"I'm built by researchers from {org}.",
f"I am a language model trained by researchers from {org}.",
f"I'm a language model developed by {org}.",
f"I'm a language model created by researchers from {org}.",
f"My creators are researchers from {org}.",
]
generate_conversations(questions, answers)
questions = [
"Are you ChatGPT?",
"Are you GPT-2?",
"Are you GPT-3?",
"Are you GPT-4?",
"Are you davinci?",
"Are you davinci-001?",
"Are you davinci-002?",
"Are you davinci-003?",
"Are you curie?",
"Are you based on ChatGPT?",
"Are you based on GPT-2?",
"Are you based on GPT-3?",
"Are you based on GPT-4?",
"Are you based on davinci?",
"Are you based on davinci-001?",
"Are you based on davinci-002?",
"Are you based on davinci-003?",
"Are you based on curie?",
"Are you trained by OpenAI?",
"Are you trained by Google?",
"Are you trained by Microsoft?",
"Are you trained by Meta?",
"Are you trained by IBM?",
"Do you call OpenAI APIs?",
"Do you call Google APIs?",
"Do you call Microsoft APIs?",
"Do you call Meta APIs?",
"Do you call IBM APIs?",
"Are you created by OpenAI?",
"Are you created by Google?",
"Are you created by Microsoft?",
"Are you created by Meta?",
"Are you created by IBM?",
"Are you developed by OpenAI?",
"Are you developed by Google?",
"Are you developed by Microsoft?",
"Are you developed by Meta?",
"Are you developed by IBM?",
"Are you trained on OpenAI data?",
"Are you trained on Google data?",
"Are you trained on Microsoft data?",
"Are you trained on Meta data?",
"Are you trained on IBM data?",
"Are you trained with OpenAI data?",
"Are you trained with Google data?",
"Are you trained with Microsoft data?",
"Are you trained with Meta data?",
"Are you trained with IBM data?",
"Have you been trained with OpenAI data?",
"Have you been trained with Google data?",
"Have you been trained with Microsoft data?",
"Have you been trained with Meta data?",
"Have you been trained with IBM data?",
"Are you finetuned on OpenAI data?",
"Are you finetuned on Google data?",
"Are you finetuned on Microsoft data?",
"Are you finetuned on Meta data?",
"Are you finetuned on IBM data?",
"Are you finetuned with OpenAI data?",
"Are you finetuned with Google data?",
"Are you finetuned with Microsoft data?",
"Are you finetuned with Meta data?",
"Are you finetuned with IBM data?",
"Have you been finetuned with OpenAI data?",
"Have you been finetuned with Google data?",
"Have you been finetuned with Microsoft data?",
"Have you been finetuned with Meta data?",
"Have you been finetuned with IBM data?",
]
answers = [
f"No, I am a language model trained by researchers from {org}.",
f"No, I am a language model developed by researchers from {org}.",
f"No, I am a language model created by researchers from {org}.",
f"No, I am trained by researchers from {org}.",
f"No, I am developed by researchers from {org}.",
f"No, I am created by researchers from {org}.",
f"No, I'm a language model trained by researchers from {org}.",
f"No, I'm a language model developed by researchers from {org}.",
f"No, I'm a language model created by researchers from {org}.",
f"No, I'm trained by researchers from {org}.",
f"No, I'm developed by researchers from {org}.",
f"No, I'm created by researchers from {org}.",
]
generate_conversations(questions, answers)
questions = [
"What is SkyPilot?",
"What is the goal of SkyPilot?",
"How to use SkyPilot?",
"How does SkyPilot work?",
"Tell me more about SkyPilot.",
]
answers = [
f"SkyPilot is an open-source framework for running LLMs, AI, and batch jobs on any cloud, offering maximum cost savings, highest GPU availability, and managed execution. It is built by researchers from {org}.",
"SkyPilot is a framework for running LLMs, AI, and batch jobs on any cloud, offering maximum cost savings, highest GPU availability, and managed execution. It has been adopted by many LLM applications, including Vicuna, LLaVA, vLLM, etc. More information about SkyPilot and how I am trained can be found at https://skypilot.co.",
textwrap.dedent("""\
SkyPilot is a framework for running LLMs, AI, and batch jobs on any cloud, offering maximum cost savings, highest GPU availability, and managed execution.
SkyPilot abstracts away cloud infra burdens:
* Launch jobs & clusters on any cloud
* Easy scale-out: queue and run many jobs, automatically managed
* Easy access to object stores (S3, GCS, Azure, R2, IBM)
SkyPilot maximizes GPU availability for your jobs:
* Provision in all zones/regions/clouds you have access to (the Sky), with automatic failover
SkyPilot cuts your cloud costs:
* Managed Spot: 3-6x cost savings using spot VMs, with auto-recovery from preemptions
* Optimizer: 2x cost savings by auto-picking the cheapest VM/zone/region/cloud
* Autostop: hands-free cleanup of idle clusters
SkyPilot supports your existing GPU, TPU, and CPU workloads, with no code changes.
""")
]
generate_conversations(questions, answers)
return content
if __name__ == "__main__":
out_file = "hardcoded.json"
content = []
content.extend(identity_questions())
json.dump(content, open(out_file, "w"), indent=2)
scripts/train.py
# This code is based on tatsu-lab/stanford_alpaca. Below is the original copyright:
#
# Copyright 2023 Rohan Taori, Ishaan Gulrajani, Tianyi Zhang, Yann Dubois, Xuechen Li
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# ==============================================================================
#
# The code was modified by the lmsys-org/FastChat authors, and following is the license:
# Copyright 2023 FastChat authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from dataclasses import field
import json
import os
import pathlib
import shutil
import subprocess
from typing import Dict, Optional
from fastchat.conversation import SeparatorStyle
from fastchat.model.model_adapter import get_conversation_template
import torch
from torch.utils.data import Dataset
import transformers
from transformers import Trainer
from transformers.trainer_pt_utils import LabelSmoother
IGNORE_TOKEN_ID = LabelSmoother.ignore_index
@dataclass
class ModelArguments:
model_name_or_path: Optional[str] = field(default="facebook/opt-125m")
@dataclass
class DataArguments:
data_path: str = field(default=None,
metadata={"help": "Path to the training data."})
eval_data_path: str = field(
default=None, metadata={"help": "Path to the evaluation data."})
lazy_preprocess: bool = False
@dataclass
class TrainingArguments(transformers.TrainingArguments):
cache_dir: Optional[str] = field(default=None)
optim: str = field(default="adamw_torch")
model_max_length: int = field(
default=512,
metadata={
"help": "Maximum sequence length. Sequences will be right padded (and possibly truncated)."
},
)
local_rank = None
def rank0_print(*args):
if local_rank == 0:
print(*args)
def safe_save_model_for_hf_trainer(trainer: transformers.Trainer,
output_dir: str):
"""Collects the state dict and dump to disk."""
state_dict = trainer.model.state_dict()
if trainer.args.should_save:
cpu_state_dict = {key: value.cpu() for key, value in state_dict.items()}
del state_dict
trainer._save(output_dir, state_dict=cpu_state_dict) # noqa
def preprocess(
sources,
tokenizer: transformers.PreTrainedTokenizer,
) -> Dict:
conv = get_conversation_template("vicuna")
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
# Apply prompt templates
conversations = []
for i, source in enumerate(sources):
if not source or source[0]["from"] not in roles:
continue
if roles[source[0]["from"]] != conv.roles[0]:
# Skip the first one if it is not from human
source = source[1:]
conv.messages = []
role_id = 0
for sentence in source:
if sentence["from"] not in roles:
print(f"Skip unknown role {sentence['from']!r}")
continue
role = roles[sentence["from"]]
if role != conv.roles[role_id % 2]:
print(f"Skip duplicated role {role!r}")
continue
role_id += 1
conv.append_message(role, sentence["value"])
else:
conversations.append(conv.get_prompt())
if not conversations:
conv.append_message(conv.roles[0], '')
conv.append_message(conv.roles[1], '')
conversations.append(conv.get_prompt())
# Tokenize conversations
input_ids = tokenizer(
conversations,
return_tensors="pt",
padding="max_length",
max_length=tokenizer.model_max_length,
truncation=True,
).input_ids
targets = input_ids.clone()
assert conv.sep_style == SeparatorStyle.ADD_COLON_TWO
# Mask targets. Only compute loss on the assistant outputs.
sep = conv.sep + conv.roles[1] + ": "
for conversation, target in zip(conversations, targets):
total_len = int(target.ne(tokenizer.pad_token_id).sum())
turns = conversation.split(conv.sep2)
cur_len = 1
target[:cur_len] = IGNORE_TOKEN_ID
for i, turn in enumerate(turns):
if turn == "":
break
turn_len = len(tokenizer(turn).input_ids)
parts = turn.split(sep)
if len(parts) != 2:
break
parts[0] += sep
# "-2" is hardcoded for the LLaMA tokenizer to make the offset correct.
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
# Ignore the user instructions
target[cur_len:cur_len + instruction_len] = IGNORE_TOKEN_ID
cur_len += turn_len
target[cur_len:] = IGNORE_TOKEN_ID
if False: # Inspect and check the correctness of masking
z = target.clone()
z = torch.where(z == IGNORE_TOKEN_ID, tokenizer.unk_token_id, z)
rank0_print(tokenizer.decode(z))
if cur_len < tokenizer.model_max_length:
if cur_len != total_len:
target[:] = IGNORE_TOKEN_ID
rank0_print(
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
f" (ignored)")
return dict(
input_ids=input_ids,
labels=targets,
attention_mask=input_ids.ne(tokenizer.pad_token_id),
)
class SupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
super(SupervisedDataset, self).__init__()
rank0_print("Formatting inputs...")
sources = [example["conversations"] for example in raw_data]
data_dict = preprocess(sources, tokenizer)
self.input_ids = data_dict["input_ids"]
self.labels = data_dict["labels"]
self.attention_mask = data_dict["attention_mask"]
def __len__(self):
return len(self.input_ids)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
return dict(
input_ids=self.input_ids[i],
labels=self.labels[i],
attention_mask=self.attention_mask[i],
)
class LazySupervisedDataset(Dataset):
"""Dataset for supervised fine-tuning."""
def __init__(self, raw_data, tokenizer: transformers.PreTrainedTokenizer):
super(LazySupervisedDataset, self).__init__()
self.tokenizer = tokenizer
rank0_print("Formatting inputs...Skip in lazy mode")
self.tokenizer = tokenizer
self.raw_data = raw_data
self.cached_data_dict = {}
def __len__(self):
return len(self.raw_data)
def __getitem__(self, i) -> Dict[str, torch.Tensor]:
if i in self.cached_data_dict:
return self.cached_data_dict[i]
ret = preprocess([self.raw_data[i]["conversations"]], self.tokenizer)
ret = dict(
input_ids=ret["input_ids"][0],
labels=ret["labels"][0],
attention_mask=ret["attention_mask"][0],
)
self.cached_data_dict[i] = ret
return ret
def make_supervised_data_module(tokenizer: transformers.PreTrainedTokenizer,
data_args) -> Dict:
"""Make dataset and collator for supervised fine-tuning."""
dataset_cls = (LazySupervisedDataset
if data_args.lazy_preprocess else SupervisedDataset)
rank0_print("Loading data...")
train_json = json.load(open(data_args.data_path, "r"))
train_dataset = dataset_cls(train_json, tokenizer=tokenizer)
if data_args.eval_data_path:
eval_json = json.load(open(data_args.eval_data_path, "r"))
eval_dataset = dataset_cls(eval_json, tokenizer=tokenizer)
else:
eval_dataset = None
return dict(train_dataset=train_dataset, eval_dataset=eval_dataset)
class CheckpointCallback(transformers.TrainerCallback):
def on_save(self, args, state, control, **kwargs):
"""Add complete indicator to avoid incomplete checkpoints."""
if state.is_world_process_zero:
ckpt_path = os.path.join(args.output_dir,
f'checkpoint-{state.global_step}')
with open(os.path.join(ckpt_path, 'complete'), 'w') as f:
f.write('')
print(f'Checkpoint {state.global_step} saved.')
torch.distributed.barrier()
def cleanup_incomplete_checkpoints(output_dir):
"""Remove incomplete checkpoints."""
checkpoints = list(pathlib.Path(output_dir).glob('checkpoint-*'))
checkpoints = [c for c in checkpoints if c.name.split('-')[-1].isdigit()]
checkpoints = sorted(checkpoints,
key=lambda x: int(x.name.split('-')[-1]),
reverse=True)
for checkpoint in checkpoints:
if not (checkpoint / 'complete').exists():
print(f'Removing incomplete checkpoint {checkpoint}')
shutil.rmtree(checkpoint)
else:
print(f'Using checkpoint {checkpoint}, copying to ~/tmp/ for '
'optimization of loading.')
tmp_dir = os.path.expanduser('~/tmp')
os.makedirs(tmp_dir, exist_ok=True)
try:
# Optimization for checkpoint loading. This is to force the
# mounting tool to download the checkpoints in parallel first.
# It will improve the loading speed of the checkpoints
# significantly.
subprocess.run(
['gsutil', '-m', 'rsync', '-r', checkpoint, tmp_dir],
check=True)
except:
print('Failed to optimize checkpoint loading. Skip.')
break
def train():
global local_rank
parser = transformers.HfArgumentParser(
(ModelArguments, DataArguments, TrainingArguments))
model_args, data_args, training_args = parser.parse_args_into_dataclasses()
local_rank = training_args.local_rank
if local_rank == 0:
cleanup_incomplete_checkpoints(training_args.output_dir)
torch.distributed.barrier()
# Check the existence of checkpoints in all processes
# All ranks must simultaneously resume from a checkpoint if it exists.
# Otherwise, upon recovery the model weights may not reload correctly,
# causing loss spikes.
resume_from_checkpoint = False
checkpoints = list(
pathlib.Path(training_args.output_dir).glob('checkpoint-*'))
checkpoints = [c for c in checkpoints if c.name.split('-')[-1].isdigit()]
if checkpoints:
resume_from_checkpoint = True
model = transformers.AutoModelForCausalLM.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
)
model.config.use_cache = False
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
cache_dir=training_args.cache_dir,
model_max_length=training_args.model_max_length,
padding_side="right",
use_fast=False,
)
tokenizer.pad_token = tokenizer.unk_token
data_module = make_supervised_data_module(tokenizer=tokenizer,
data_args=data_args)
trainer = Trainer(model=model,
tokenizer=tokenizer,
args=training_args,
**data_module)
trainer.add_callback(CheckpointCallback)
trainer.train(resume_from_checkpoint=resume_from_checkpoint)
trainer.save_state()
safe_save_model_for_hf_trainer(trainer=trainer,
output_dir=training_args.output_dir)
if __name__ == "__main__":
train()
scripts/train_flash_attn.py
# Make it more memory efficient by monkey patching the LLaMA model with FlashAttn.
# Need to call this before importing transformers.
from flash_attn_patch import replace_llama_attn_with_flash_attn
replace_llama_attn_with_flash_attn()
from train import train
if __name__ == "__main__":
train()
scripts/train_xformers.py
# This code is based on lmsys-org/fastchat. Below is the original copyright:
#
# Copyright 2023 FastChat authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# Make it more memory efficient by monkey patching the LLaMA model with xformers attention.
# Need to call this before importing transformers.
from xformers_patch import replace_llama_attn_with_xformers_attn
replace_llama_attn_with_xformers_attn()
from train import train
if __name__ == "__main__":
train()
scripts/xformers_patch.py
# This code is based on lmsys-org/fastchat. Below is the original copyright:
#
# Copyright 2023 FastChat authors
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# https://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Directly copied the code from https://raw.githubusercontent.com/oobabooga/text-generation-webui/main/modules/llama_attn_hijack.py and made some adjustments
"""
import logging
import math
from typing import Optional, Tuple
import torch
from torch import nn
import transformers.models.llama.modeling_llama
try:
import xformers.ops
except ImportError:
logging.error(
"xformers not found! Please install it before trying to use it.")
def replace_llama_attn_with_xformers_attn():
transformers.models.llama.modeling_llama.LlamaAttention.forward = xformers_forward
def xformers_forward(
self,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
past_key_value: Optional[Tuple[torch.Tensor]] = None,
output_attentions: bool = False,
use_cache: bool = False,
) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
# pylint: disable=duplicate-code
bsz, q_len, _ = hidden_states.size()
query_states = (self.q_proj(hidden_states).view(bsz, q_len, self.num_heads,
self.head_dim).transpose(
1, 2))
key_states = (self.k_proj(hidden_states).view(bsz, q_len, self.num_heads,
self.head_dim).transpose(
1, 2))
value_states = (self.v_proj(hidden_states).view(bsz, q_len, self.num_heads,
self.head_dim).transpose(
1, 2))
kv_seq_len = key_states.shape[-2]
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
(
query_states,
key_states,
) = transformers.models.llama.modeling_llama.apply_rotary_pos_emb(
query_states, key_states, cos, sin, position_ids)
# [bsz, nh, t, hd]
if past_key_value is not None:
# reuse k, v, self_attention
key_states = torch.cat([past_key_value[0], key_states], dim=2)
value_states = torch.cat([past_key_value[1], value_states], dim=2)
past_key_value = (key_states, value_states) if use_cache else None
# We only apply xformers optimizations if we don't need to output the whole attention matrix
if not output_attentions:
query_states = query_states.transpose(1, 2)
key_states = key_states.transpose(1, 2)
value_states = value_states.transpose(1, 2)
# This is a nasty hack. We know attention_mask in transformers is either LowerTriangular or all Zeros.
# We therefore check if one element in the upper triangular portion is zero. If it is, then the mask is all zeros.
if attention_mask is None or attention_mask[0, 0, 0, 1] == 0:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states, key_states, value_states, attn_bias=None)
else:
# input and output should be of form (bsz, q_len, num_heads, head_dim)
attn_output = xformers.ops.memory_efficient_attention(
query_states,
key_states,
value_states,
attn_bias=xformers.ops.LowerTriangularMask(),
)
attn_weights = None
else:
attn_weights = torch.matmul(query_states, key_states.transpose(
2, 3)) / math.sqrt(self.head_dim)
if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
raise ValueError(
f"Attention weights should be of size {(bsz * self.num_heads, q_len, kv_seq_len)}, but is"
f" {attn_weights.size()}")
if attention_mask is not None:
if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
raise ValueError(
f"Attention mask should be of size {(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
)
attn_weights = attn_weights + attention_mask
attn_weights = torch.max(
attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min))
# upcast attention to fp32
attn_weights = nn.functional.softmax(attn_weights,
dim=-1,
dtype=torch.float32).to(
query_states.dtype)
attn_output = torch.matmul(attn_weights, value_states)
if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
raise ValueError(
f"`attn_output` should be of size {(bsz, self.num_heads, q_len, self.head_dim)}, but is"
f" {attn_output.size()}")
attn_output = attn_output.transpose(1, 2)
attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
attn_output = self.o_proj(attn_output)
return attn_output, attn_weights, past_key_value
serve.yaml
envs:
MODEL_CKPT: <bucket-path-to-your-model-ckpt>
resources:
accelerators: A100:1
disk_size: 1024
disk_tier: best
memory: 32+
file_mounts:
/skypilot-vicuna:
source: $MODEL_CKPT
mode: COPY
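# COPY mode downloads the checkpoint from the bucket to the VM's local disk before serving, instead of streaming it from a mounted bucket.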
setup: |
conda activate chatbot
if [ $? -ne 0 ]; then
conda create -n chatbot python=3.10 -y
conda activate chatbot
fi
# Install dependencies
pip install git+https://github.com/lm-sys/FastChat.git
run: |
conda activate chatbot
echo 'Starting controller...'
python -u -m fastchat.serve.controller --host 127.0.0.1 > ~/controller.log 2>&1 &
sleep 10
echo 'Starting model worker...'
python -u -m fastchat.serve.model_worker \
--model-path /skypilot-vicuna 2>&1 \
--host 127.0.0.1 \
| tee model_worker.log &
echo 'Waiting for model worker to start...'
while ! `cat model_worker.log | grep -q 'Uvicorn running on'`; do sleep 1; done
echo 'Starting gradio server...'
python -u -m fastchat.serve.gradio_web_server --share | tee ~/gradio.log
train.yaml
envs:
HF_TOKEN: # TODO: Fill with your own huggingface token, or use --env to pass.
ARTIFACT_BUCKET_NAME: # TODO: Fill with your unique bucket name, or use --env to pass.
WANDB_API_KEY: # TODO: Fill with your own WANDB_API_KEY, or use --env to pass.
MODEL_SIZE: 7
USE_XFORMERS: 1
resources:
accelerators: A100-80GB:8
disk_size: 1024
use_spot: true
num_nodes: 1
file_mounts:
/artifacts:
name: $ARTIFACT_BUCKET_NAME
mode: MOUNT
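# MOUNT mode keeps /artifacts backed by the bucket, so checkpoints written during training persist and can be reused after spot preemptions.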
workdir: .
setup: |
# Download the ShareGPT dataset
# Change to your OWN dataset if you want to train your own model
mkdir -p $HOME/data
wget https://huggingface.co/datasets/anon8231489123/ShareGPT_Vicuna_unfiltered/resolve/main/ShareGPT_V3_unfiltered_cleaned_split.json -O $HOME/data/sharegpt.json
# Setup the environment
conda activate chatbot
if [ $? -ne 0 ]; then
conda create -n chatbot python=3.10 -y
conda activate chatbot
fi
cd ./scripts
# Use an older version of fastchat to install transformers==4.28.1, as the transformers>=4.31
# has issues with checkpoint saving -- saving additional large files in the checkpoint folder
pip install git+https://github.com/lm-sys/FastChat.git@cfc73bf3e13c22ded81e89675e0d7b228cf4b342
if [ $USE_XFORMERS -eq 1 ]; then
pip install -U xformers
fi
python hardcoded_questions.py
python -m fastchat.data.merge --in $HOME/data/sharegpt.json hardcoded.json --out $HOME/data/mydata.json
python -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"
run: |
cd scripts
conda activate chatbot
if [ $USE_XFORMERS -eq 1 ]; then
TRAIN_SCRIPT=train_xformers.py
else
TRAIN_SCRIPT=train.py
fi
PER_DEVICE_BATCH_SIZE=4
SEQ_LEN=2048
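# gradient_accumulation_steps below is derived to keep the token budget per optimizer step fixed (128 sequences x 512 tokens), regardless of SEQ_LEN, node count, and GPUs per node.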
NUM_NODES=`echo "$SKYPILOT_NODE_IPS" | wc -l`
HOST_ADDR=`echo "$SKYPILOT_NODE_IPS" | head -n1`
# Turn off wandb if no api key is provided
if [ "$WANDB_API_KEY" == "" ]; then
export WANDB_MODE="offline"
fi
torchrun \
--nnodes=$NUM_NODES \
--nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
--master_port=12375 \
--master_addr=$HOST_ADDR \
--node_rank=${SKYPILOT_NODE_RANK} \
$TRAIN_SCRIPT \
--model_name_or_path meta-llama/Llama-2-${MODEL_SIZE}b-hf \
--data_path $HOME/data/mydata.json \
--bf16 True \
--output_dir /artifacts/chatbot/${MODEL_SIZE}b \
--num_train_epochs 3 \
--per_device_train_batch_size $PER_DEVICE_BATCH_SIZE \
--per_device_eval_batch_size $PER_DEVICE_BATCH_SIZE \
--gradient_accumulation_steps $((128 * 512 / $SEQ_LEN / $PER_DEVICE_BATCH_SIZE / $NUM_NODES / $SKYPILOT_NUM_GPUS_PER_NODE)) \
--evaluation_strategy "no" \
--save_strategy "steps" \
--save_steps 600 \
--save_total_limit 10 \
--learning_rate 2e-5 \
--weight_decay 0. \
--warmup_ratio 0.03 \
--lr_scheduler_type "cosine" \
--logging_steps 1 \
--fsdp "full_shard auto_wrap" \
--fsdp_transformer_layer_cls_to_wrap 'LlamaDecoderLayer' \
--tf32 True \
--model_max_length ${SEQ_LEN} \
--run_name $SKYPILOT_TASK_ID \
--gradient_checkpointing True \
--lazy_preprocess True
returncode=$?
exit $returncode