Source: examples/tpu
TPU
This example shows how to launch TPU jobs with SkyPilot.
Note: some of these examples may be outdated. Check the files under the v6e/ directory for the most recent examples. See also: https://docs.skypilot.org.cn/en/latest/reference/tpu.html.
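As a quick orientation, the typical workflow for the YAML files below looks like the following sketch (assuming the standard SkyPilot CLI with GCP TPU access; the cluster name `mytpu` is arbitrary):

```bash
# Provision a TPU cluster and run the task defined in the YAML.
sky launch -c mytpu examples/tpu/tpuvm_mnist.yaml

# Stream the job's output once it is running.
sky logs mytpu
```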
Included files
tpu_app.py
```python
import sky

with sky.Dag() as dag:
    # The working directory contains all code and will be synced to remote.
    workdir = './examples/tpu/tpu_app_code'

    # The setup command. Will be run under the working directory.
    setup = 'pip install --upgrade pip && \
        conda activate huggingface || \
          (conda create -n huggingface python=3.8 -y && \
           conda activate huggingface && \
           pip install -r requirements.txt)'

    # The command to run. Will be run under the working directory.
    run = 'conda activate huggingface && python -u run_tpu.py'

    train = sky.Task(
        'train',
        workdir=workdir,
        setup=setup,
        run=run,
    )

    train.set_resources({
        sky.Resources(accelerators='tpu-v3-8',
                      accelerator_args={
                          'runtime_version': '2.12.0',
                          'tpu_name': 'weilin-bert-test-big'
                      }),
    })

sky.launch(dag)
```
tpu_app.yaml
```yaml
name: tpu_app

# The working directory contains all code and will be synced to remote.
workdir: ./examples/tpu/tpu_app_code

resources:
  accelerators: tpu-v2-8

# The setup command. Will be run under the working directory.
setup: |
  pip install --upgrade pip
  conda activate huggingface
  if [ $? -eq 0 ]; then
    echo 'conda env exists'
  else
    conda create -n huggingface python=3.8 -y
    conda activate huggingface
    pip install -r requirements.txt
  fi

# The command to run. Will be run under the working directory.
run: |
  conda activate huggingface
  python -u run_tpu.py
```
tpu_app_code/requirements.txt
```
tensorflow==2.5.1
tensorflow-datasets==4.4.0
transformers==4.12.0
tensorflow-text==2.5.0
cloud-tpu-client==0.10
```
tpu_app_code/run_tpu.py
```python
import tensorflow as tf
import tensorflow_datasets as tfds
import tensorflow_text as tf_text
from transformers import TFBertForSequenceClassification
from transformers import TFDistilBertForSequenceClassification

tpu = tf.distribute.cluster_resolver.TPUClusterResolver()
tf.config.experimental_connect_to_cluster(tpu)
tf.tpu.experimental.initialize_tpu_system(tpu)
strategy = tf.distribute.experimental.TPUStrategy(tpu)

ds_train, ds_info = tfds.load('amazon_us_reviews/Books_v1_02',
                              split='train[:5%]',
                              with_info=True,
                              data_dir="gs://weilin-bert-test")

MAX_SEQ_LEN = 512
bert_tokenizer = tf_text.BertTokenizer(
    vocab_lookup_table='gs://weilin-bert-test/vocab.txt',
    token_out_type=tf.int64,
    lower_case=True)


def preprocessing_fn(inputs):
    """Preprocess input column of text into transformed columns of.
        * input token ids
        * input mask
        * input type ids
    """
    CLS_ID = tf.constant(101, dtype=tf.int64)
    SEP_ID = tf.constant(102, dtype=tf.int64)
    PAD_ID = tf.constant(0, dtype=tf.int64)

    def tokenize_text(text, sequence_length=MAX_SEQ_LEN):
        """
        Perform the BERT preprocessing from text -> input token ids
        """
        # convert text into token ids
        tokens = bert_tokenizer.tokenize(text)

        # flatten the output ragged tensors
        tokens = tokens.merge_dims(1, 2)[:, :sequence_length]

        # Add start and end token ids to the id sequence
        start_tokens = tf.fill([tf.shape(text)[0], 1], CLS_ID)
        end_tokens = tf.fill([tf.shape(text)[0], 1], SEP_ID)
        tokens = tokens[:, :sequence_length - 2]
        tokens = tf.concat([start_tokens, tokens, end_tokens], axis=1)

        # truncate sequences greater than MAX_SEQ_LEN
        tokens = tokens[:, :sequence_length]

        # pad shorter sequences with the pad token id
        tokens = tokens.to_tensor(default_value=PAD_ID)
        pad = sequence_length - tf.shape(tokens)[1]
        tokens = tf.pad(tokens, [[0, 0], [0, pad]], constant_values=PAD_ID)

        # and finally reshape the word token ids to fit the output
        # data structure of TFT
        return tf.reshape(tokens, [-1, sequence_length])

    def preprocess_bert_input(text):
        """
        Convert input text into the input_word_ids, input_mask, input_type_ids
        """
        input_word_ids = tokenize_text(text)
        input_mask = tf.cast(input_word_ids > 0, tf.int64)
        input_mask = tf.reshape(input_mask, [-1, MAX_SEQ_LEN])

        zeros_dims = tf.stack(tf.shape(input_mask))
        input_type_ids = tf.fill(zeros_dims, 0)
        input_type_ids = tf.cast(input_type_ids, tf.int64)

        return (tf.squeeze(input_word_ids,
                           axis=0), tf.squeeze(input_mask, axis=0),
                tf.squeeze(input_type_ids, axis=0))

    input_word_ids, input_mask, input_type_ids = preprocess_bert_input(
        [inputs['data']['review_body']])

    return (dict({
        'input_ids': input_word_ids,
        'token_type_ids': input_type_ids,
        'attention_mask': input_mask
    }), inputs['data']['star_rating'])


from transformers import BertTokenizerFast

tokenizer = BertTokenizerFast.from_pretrained('bert-base-uncased')


def dataset_fn(ds):
    return ds.filter(lambda x: x['data']['helpful_votes'] >= 7)


ds_train_filtered = ds_train.apply(dataset_fn)


def process(example):
    return (dict(tokenizer(
        example['data']['review_body'].numpy().decode('utf-8')),
                 truncation=True,
                 padding=True), example['data']['star_rating'].numpy())


def process_py(inp1, inp2):
    return [
        dict(tokenizer(inp1.numpy().decode('utf-8')),
             truncation=True,
             padding=True),
        inp2.numpy()
    ]


ds_train_filtered_2 = ds_train_filtered.map(preprocessing_fn)

tf.keras.mixed_precision.experimental.set_policy('mixed_bfloat16')

with strategy.scope():
    model = TFBertForSequenceClassification.from_pretrained('bert-base-uncased',
                                                            num_labels=1)
    optimizer = tf.keras.optimizers.Adam(learning_rate=5e-5)
    model.compile(optimizer=optimizer,
                  loss=model.compute_loss)  # can also use any keras loss fn
    model.summary()

inuse_dataset = ds_train_filtered_2.shuffle(1000).batch(256).prefetch(
    tf.data.experimental.AUTOTUNE)
model.fit(inuse_dataset, epochs=1, batch_size=256)
```
tpu_node_mnist.yaml
```yaml
name: mnist-tpu-node

resources:
  instance_type: n1-highmem-8
  accelerators: tpu-v2-8
  accelerator_args:
    runtime_version: 2.12.0
    tpu_vm: False

file_mounts:
  /dataset:
    name: demo-mnist-tpu
    store: gcs
    mode: MOUNT

# The setup command. Will be run under the working directory.
setup: |
  git clone https://github.com/tensorflow/models.git

  conda activate mnist
  if [ $? -eq 0 ]; then
    echo 'conda env exists'
  else
    conda create -n mnist python=3.8 -y
    conda activate mnist
    pip install tensorflow==2.12.0 tensorflow-datasets tensorflow-model-optimization cloud-tpu-client
  fi

# The command to run. Will be run under the working directory.
run: |
  conda activate mnist
  cd models/official/legacy/image_classification/

  export STORAGE_BUCKET=gs://demo-mnist-tpu
  export MODEL_DIR=${STORAGE_BUCKET}/mnist
  export DATA_DIR=${STORAGE_BUCKET}/data
  export PYTHONPATH=/home/gcpuser/sky_workdir/models

  python3 mnist_main.py \
    --tpu=${TPU_NAME} \
    --model_dir=${MODEL_DIR} \
    --data_dir=${DATA_DIR} \
    --train_epochs=10 \
    --distribution_strategy=tpu \
    --download
```
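This example uses the older TPU Node architecture (`tpu_vm: False`) and mounts a GCS bucket for the MNIST data. GCS bucket names are globally unique, so you may need to replace `demo-mnist-tpu` (under `file_mounts` and in `STORAGE_BUCKET`) with a bucket of your own before launching; a sketch:

```bash
# Launch the TPU Node MNIST example (the cluster name is arbitrary).
sky launch -c mnist-tpu-node examples/tpu/tpu_node_mnist.yaml
```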
tpuvm_mnist.yaml
```yaml
name: tpuvm_mnist

resources:
  accelerators: tpu-v2-8

# The setup command. Will be run under the working directory.
setup: |
  git clone https://github.com/google/flax.git --branch v0.10.1

  conda activate flax
  if [ $? -eq 0 ]; then
    echo 'conda env exists'
  else
    conda create -n flax python=3.10 -y
    conda activate flax
    # Make sure to install TPU related packages in a conda env to avoid package conflicts.
    pip install \
      -f https://storage.googleapis.com/jax-releases/libtpu_releases.html "jax[tpu]==0.4.35" \
      clu \
      tensorflow tensorflow-datasets
    pip install -e flax
  fi

# The command to run. Will be run under the working directory.
run: |
  conda activate flax
  pip install clu
  cd flax/examples/mnist
  python3 main.py --workdir=/tmp/mnist \
    --config=configs/default.py \
    --config.learning_rate=0.05 \
    --config.num_epochs=10
```
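Once this YAML has been launched onto a cluster (for example `mytpu` from the workflow sketch at the top of this page), the training can be re-run on the same machines without reprovisioning; a sketch:

```bash
# Re-run the task's `run` commands on the existing cluster, skipping setup.
sky exec mytpu examples/tpu/tpuvm_mnist.yaml
```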
v6e/README.md
# TPU v6e

Trillium (also referred to as v6e) is Cloud TPU's latest generation AI accelerator. SkyPilot supports provisioning, training, and serving on TPU v6e.

## Catalog

Currently, the public APIs for TPU v6e regions and pricing have not been released, and pricing information for `us-central1`, `us-central2`, and `us-south1` is not available. For now, we set the price in these regions to `0.0`.

## Provisioning

To provision TPU v6e, use the following command:

```bash
$ sky launch --gpus tpu-v6e-16 -c tpu-v6e
```

After that, you can SSH to the instance and start developing your model:

```bash
$ ssh tpu-v6e
```
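The cluster created above behaves like any other SkyPilot cluster; a minimal sketch of checking on it and tearing it down (the name `tpu-v6e` comes from the `-c` flag above):

```bash
# Show clusters and their resources (including the TPU).
sky status

# Tear the cluster down when you are done.
sky down tpu-v6e
```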
## Training

The example in this directory (`train-llama3-8b.yaml`) shows how to use TPU v6e to train a Llama3 8b model on the wikitext dataset with PyTorch (XLA). To start the training, use the following command:

```bash
$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b --env HF_TOKEN
```
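Here, `HF_TOKEN=hf_xxx` sets the variable in your local shell and `--env HF_TOKEN` forwards it into the task's `envs`; the value can also be passed inline, a sketch:

```bash
$ sky launch train-llama3-8b.yaml -c train-llama3-8b --env HF_TOKEN=hf_xxx
```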
### Single-host training

The training throughput for a `tpu-v6e-8` instance should be around 0.5 samples/s:

```
(task, pid=17499) ***** train metrics *****
(task, pid=17499) epoch = 1.1765
(task, pid=17499) total_flos = 109935420GF
(task, pid=17499) train_loss = 10.6011
(task, pid=17499) train_runtime = 0:11:12.77
(task, pid=17499) train_samples = 282
(task, pid=17499) train_samples_per_second = 0.476
(task, pid=17499) train_steps_per_second = 0.03
INFO: Job finished (status: SUCCEEDED).
```
### Multi-host training

By changing the TPU type to `tpu-v6e-16` and `--per_device_train_batch_size` to `32`, the training throughput increases to around 1 sample/s:

```
(head, rank=0, pid=17894) ***** train metrics *****
(head, rank=0, pid=17894) epoch = 2.5
(head, rank=0, pid=17894) total_flos = 219870840GF
(head, rank=0, pid=17894) train_loss = 10.1527
(head, rank=0, pid=17894) train_runtime = 0:11:13.18
(head, rank=0, pid=17894) train_samples = 282
(head, rank=0, pid=17894) train_samples_per_second = 0.951
(head, rank=0, pid=17894) train_steps_per_second = 0.03
(worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics *****
(worker1, rank=1, pid=15406, ip=10.164.0.57) epoch = 2.5
(worker1, rank=1, pid=15406, ip=10.164.0.57) total_flos = 219870840GF
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_loss = 10.1527
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_runtime = 0:11:15.08
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples = 282
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_samples_per_second = 0.948
(worker1, rank=1, pid=15406, ip=10.164.0.57) train_steps_per_second = 0.03
(worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics *****
(worker2, rank=2, pid=16552, ip=10.164.0.58) epoch = 2.5
(worker2, rank=2, pid=16552, ip=10.164.0.58) total_flos = 219870840GF
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_loss = 10.1527
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_runtime = 0:11:15.61
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples = 282
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_samples_per_second = 0.947
(worker2, rank=2, pid=16552, ip=10.164.0.58) train_steps_per_second = 0.03
(worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics *****
(worker3, rank=3, pid=17469, ip=10.164.0.59) epoch = 2.5
(worker3, rank=3, pid=17469, ip=10.164.0.59) total_flos = 219870840GF
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_loss = 10.1527
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_runtime = 0:11:15.10
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples = 282
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_samples_per_second = 0.948
(worker3, rank=3, pid=17469, ip=10.164.0.59) train_steps_per_second = 0.03
INFO: Job finished (status: SUCCEEDED).
```
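The larger slice can be requested without editing the YAML by overriding the accelerator on the command line (the `--per_device_train_batch_size` change still has to be made in the YAML's `run` command); a sketch:

```bash
# Launch the same task on a tpu-v6e-16 (multi-host) slice.
$ HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b \
    --gpus tpu-v6e-16 --env HF_TOKEN
```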
## Serving

TPU v6e also supports serving. The example in this directory (`serve-llama2-7b.yaml`) shows how to use TPU v6e to serve a Llama2 7b model with PyTorch (XLA) and the JetStream library. To start the serving, use the following command:

```bash
$ HF_TOKEN=hf_xxx sky launch serve-llama2-7b.yaml -c serve-llama2-7b --env HF_TOKEN
```

After the server is ready, you should see the following message:

```
(task, pid=26431) 2024-09-24 19:58:15,160 - root - INFO - Starting server on port 9000 with 64 threads
(task, pid=26431) I0924 19:58:15.160293 140454572087296 server_lib.py:155] Starting server on port 9000 with 64 threads
(task, pid=26431) 2024-09-24 19:58:15,161 - root - INFO - Not starting JAX profiler server: False
(task, pid=26431) I0924 19:58:15.161907 140454572087296 server_lib.py:164] Not starting JAX profiler server: False
(task, pid=26431) Started jetstream_server....
```

Now you can start a benchmark to test the serving performance:

```bash
$ sky exec serve-llama2-7b benchmark-llama2-7b.yaml
... (emitted logs)
(task, pid=25491) Successful requests: 100
(task, pid=25491) Benchmark duration: 8.753792 s
(task, pid=25491) Total input tokens: 21888
(task, pid=25491) Total generated tokens: 18803
(task, pid=25491) Request throughput: 11.42 requests/s
(task, pid=25491) Input token throughput: 2500.40 tokens/s
(task, pid=25491) Output token throughput: 2147.98 tokens/s
(task, pid=25491) Mean TTFT: 1981.93 ms
(task, pid=25491) Median TTFT: 1829.33 ms
(task, pid=25491) P99 TTFT: 4511.95 ms
(task, pid=25491) Mean TPOT: 130.71 ms
(task, pid=25491) Median TPOT: 18.88 ms
(task, pid=25491) P99 TPOT: 2487.37 ms
```
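Both the server (from `sky launch`) and the benchmark (from `sky exec`) run as jobs on the same cluster; a sketch of inspecting them with standard SkyPilot commands:

```bash
# List jobs on the serving cluster.
sky queue serve-llama2-7b

# Stream the logs of a specific job, e.g. the benchmark job.
sky logs serve-llama2-7b 2
```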
v6e/benchmark-llama2-7b.yaml
```yaml
envs:
  model_name: llama-2
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

run: |
  cd JetStream
  python benchmarks/benchmark_serving.py \
    --tokenizer=$tokenizer_path --num-prompts=100 \
    --dataset openorca --save-request-outputs \
    --warmup-mode=sampled --model=$model_name
```
v6e/config-8B.json
```json
{
  "architectures": [
    "LlamaForCausalLM"
  ],
  "attention_bias": false,
  "attention_dropout": 0.0,
  "bos_token_id": 128000,
  "eos_token_id": 128001,
  "hidden_act": "silu",
  "hidden_size": 4096,
  "initializer_range": 0.02,
  "intermediate_size": 14336,
  "max_position_embeddings": 8192,
  "model_type": "llama",
  "num_attention_heads": 32,
  "num_hidden_layers": 32,
  "num_key_value_heads": 8,
  "pretraining_tp": 1,
  "rms_norm_eps": 1e-05,
  "rope_scaling": null,
  "rope_theta": 500000.0,
  "tie_word_embeddings": false,
  "torch_dtype": "bfloat16",
  "transformers_version": "4.40.0.dev0",
  "use_cache": true,
  "vocab_size": 128256
}
```
v6e/fsdp_config.json
```json
{
  "fsdp_transformer_layer_cls_to_wrap": [
    "LlamaDecoderLayer"
  ],
  "xla": true,
  "xla_fsdp_v2": true,
  "xla_fsdp_grad_ckpt": true
}
```
v6e/serve-llama2-7b.yaml
```yaml
resources:
  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token
  HF_REPO_ID: meta-llama/Llama-2-7b
  model_name: llama-2
  input_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original
  output_ckpt_dir: /home/gcpuser/sky_workdir/ckpt/llama2-7b/converted
  tokenizer_path: /home/gcpuser/sky_workdir/ckpt/llama2-7b/original/tokenizer.model

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

  # Setup runtime for serving
  git clone https://github.com/google/JetStream.git
  cd JetStream
  git checkout main
  git pull origin main
  pip install -e .
  cd benchmarks
  pip install -r requirements.in
  cd ../..
  git clone https://github.com/google/jetstream-pytorch.git
  cd jetstream-pytorch/
  git checkout jetstream-v0.2.3
  source install_everything.sh
  pip3 install -U --pre jax jaxlib libtpu-nightly requests \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/libtpu_releases.html

  # Prepare checkpoint, inside jetstream-pytorch repo
  mkdir -p ${input_ckpt_dir}
  python3 -c "import huggingface_hub; huggingface_hub.snapshot_download('${HF_REPO_ID}', local_dir='${input_ckpt_dir}')"
  mkdir -p ${output_ckpt_dir}
  python -m convert_checkpoints --model_name=$model_name \
    --input_checkpoint_dir=$input_ckpt_dir \
    --output_checkpoint_dir=$output_ckpt_dir

run: |
  cd jetstream-pytorch
  python run_server.py --model_name=$model_name \
    --size=7b --batch_size=24 --max_cache_length=2048 \
    --checkpoint_path=$output_ckpt_dir \
    --tokenizer_path=$tokenizer_path \
    --sharding_config="default_shardings/llama.yaml"
```
v6e/train-llama3-8b.yaml
```yaml
resources:
  accelerators: tpu-v6e-8  # Fill in the accelerator type you want to use

envs:
  HF_TOKEN: # fill in your huggingface token

workdir: .

setup: |
  pip3 install huggingface_hub
  python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

  # Setup TPU
  pip3 install cloud-tpu-client
  sudo apt update
  sudo apt install -y libopenblas-base
  pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
    --index-url https://download.pytorch.org/whl/nightly/cpu
  pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
    -f https://storage.googleapis.com/libtpu-releases/index.html
  pip install torch_xla[pallas] \
    -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
    -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

  # Setup runtime for training
  git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
  cd transformers
  pip3 install -e .
  pip3 install datasets evaluate scikit-learn accelerate

run: |
  unset LD_PRELOAD
  PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \
  python3 transformers/examples/pytorch/language-modeling/run_clm.py \
    --dataset_name wikitext \
    --dataset_config_name wikitext-2-raw-v1 \
    --per_device_train_batch_size 16 \
    --do_train \
    --output_dir /home/$USER/tmp/test-clm \
    --overwrite_output_dir \
    --config_name /home/$USER/sky_workdir/config-8B.json \
    --cache_dir /home/$USER/cache \
    --tokenizer_name meta-llama/Meta-Llama-3-8B \
    --block_size 8192 \
    --optim adafactor \
    --save_strategy no \
    --logging_strategy no \
    --fsdp "full_shard" \
    --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \
    --torch_dtype bfloat16 \
    --dataloader_drop_last yes \
    --flash_attention \
    --max_steps 20
```
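Note that `workdir: .` syncs the launch directory to `~/sky_workdir` on the remote VM, which is where the `run` command above expects `config-8B.json` and `fsdp_config.json` (the `/home/$USER/sky_workdir/...` paths). A sketch of launching from inside the `v6e/` directory so those files are included:

```bash
cd examples/tpu/v6e
HF_TOKEN=hf_xxx sky launch train-llama3-8b.yaml -c train-llama3-8b --env HF_TOKEN
```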