Cloud TPU#

SkyPilot 支持在 Google 的 Cloud TPU 上运行作业,Cloud TPU 是一种专用于 ML 工作负载的硬件加速器。

通过 TPU 研究云 (TRC) 获得免费 TPU#

鼓励 ML 研究人员和学生通过 TPU 研究云 (TRC) 项目申请免费的 TPU 访问权限!

一个命令获取 TPU#

使用一个命令快速获取用于开发的 TPU 节点

# Use latest TPU v6 (Trillium) VMs:
sky launch --gpus tpu-v6e-8
# Use TPU v4 (Titan) VMs:
sky launch --gpus tpu-v4-8
# Preemptible TPUs:
sky launch --gpus tpu-v6e-8 --use-spot

命令完成后,您将进入 TPU 主机 VM,并可以立即开始开发代码。

下面,我们将展示使用 SkyPilot 的示例:(1) 在 TPU VM/Pod 上训练 LLM,以及 (2) 在 TPU 节点(旧版)上训练 MNIST。

TPU 架构#

GCP 上提供两种不同的 TPU 架构

SkyPilot 支持这两种架构。我们推荐 GCP 推崇的更新架构:TPU VM 和 Pod。

两种架构区别如下。

  • 对于 TPU VM/Pod,您可以直接 SSH 连接到与 TPU 设备物理连接的“TPU 主机”VM。

  • 对于 TPU 节点,必须单独调配一个用户 VM(一个 n1 实例)通过 gRPC 与不可访问的 TPU 主机通信。

更多详情请参见 GCP 文档

TPU VM/Pod#

Google 最新的 TPU v6 (Trillium) VM 提供了出色的性能,现已获得 SkyPilot 支持。

要使用 TPU VM/Pod,请在任务 YAML 的 resources 字段中设置以下内容

resources:
   accelerators: tpu-v6e-8
   accelerator_args:
      runtime_version: v2-alpha-tpuv6e  # optional

字段 accelerators 指定 TPU 类型,字典 accelerator_args 包含可选的布尔值 tpu_vm(默认为 true,表示使用 TPU VM)和可选的 TPU runtime_version 字段。要显示支持的 TPU 类型,请运行 sky show-gpus

这里是一个完整的任务 YAML,用于使用 Torch XLA 在 TPU VM 上训练 Llama 3 模型

resources:
   accelerators: tpu-v6e-8 # Fill in the accelerator type you want to use

envs:
   HF_TOKEN: # fill in your huggingface token

workdir: .

setup: |
   pip3 install huggingface_hub
   python3 -c "import huggingface_hub; huggingface_hub.login('${HF_TOKEN}')"

   # Setup TPU
   pip3 install cloud-tpu-client
   sudo apt update
   sudo apt install -y libopenblas-base
   pip3 install --pre torch==2.6.0.dev20240916+cpu torchvision==0.20.0.dev20240916+cpu \
      --index-url https://download.pytorch.org/whl/nightly/cpu
   pip install "torch_xla[tpu]@https://storage.googleapis.com/pytorch-xla-releases/wheels/tpuvm/torch_xla-2.6.0.dev20240916-cp310-cp310-linux_x86_64.whl" \
      -f https://storage.googleapis.com/libtpu-releases/index.html
   pip install torch_xla[pallas] \
      -f https://storage.googleapis.com/jax-releases/jax_nightly_releases.html \
      -f https://storage.googleapis.com/jax-releases/jaxlib_nightly_releases.html

   # Setup runtime for training
   git clone -b flash_attention https://github.com/pytorch-tpu/transformers.git
   cd transformers
   pip3 install -e .
   pip3 install datasets evaluate scikit-learn accelerate

run: |
   unset LD_PRELOAD
   PJRT_DEVICE=TPU XLA_USE_SPMD=1 ENABLE_PJRT_COMPATIBILITY=true \
   python3 transformers/examples/pytorch/language-modeling/run_clm.py \
      --dataset_name wikitext \
      --dataset_config_name wikitext-2-raw-v1 \
      --per_device_train_batch_size 16 \
      --do_train \
      --output_dir /home/$USER/tmp/test-clm \
      --overwrite_output_dir \
      --config_name /home/$USER/sky_workdir/config-8B.json \
      --cache_dir /home/$USER/cache \
      --tokenizer_name meta-llama/Meta-Llama-3-8B \
      --block_size 8192 \
      --optim adafactor \
      --save_strategy no \
      --logging_strategy no \
      --fsdp "full_shard" \
      --fsdp_config /home/$USER/sky_workdir/fsdp_config.json \
      --torch_dtype bfloat16 \
      --dataloader_drop_last yes \
      --flash_attention \
      --max_steps 20

此 YAML 位于 SkyPilot 仓库下,您也可以将其粘贴到本地文件中。

使用以下命令启动它

$ HF_TOKEN=<your-huggingface-token> sky launch train-llama3-8b.yaml -c llama-3-train --env HF_TOKEN

作业完成后,您应该会看到以下输出。

$ sky launch train-llama3-8b.yaml -c llama-3-train
(task, pid=17499) ***** train metrics *****
(task, pid=17499)   epoch                    =      1.1765
(task, pid=17499)   total_flos               = 109935420GF
(task, pid=17499)   train_loss               =     10.6011
(task, pid=17499)   train_runtime            =  0:11:12.77
(task, pid=17499)   train_samples            =         282
(task, pid=17499)   train_samples_per_second =       0.476
(task, pid=17499)   train_steps_per_second   =        0.03

多主机 TPU Pod#

一个 TPU Pod 是由专用高速网络接口连接起来的一组 TPU 设备,用于高性能训练。

要使用 TPU Pod,只需更改任务 YAML 中的 accelerators 字段(例如,将 tpu-v6e-8 改为 tpu-v6e-32)。

resources:
   accelerators: tpu-v6e-32  # Pods have > 8 cores (the last number)

注意

TPU VM 和 TPU 节点这两种 TPU 架构都可以与 TPU Pod 结合使用。下面的示例基于 TPU VM。

要显示所有可用的 TPU Pod 类型,请运行 sky show-gpus(超过 8 核表示 Pod)

GOOGLE_TPU    AVAILABLE_QUANTITIES
tpu-v6e-8     1
tpu-v6e-32    1
tpu-v6e-128   1
tpu-v6e-256   1
...

创建 TPU Pod 后,会启动多个主机 VM(例如,tpu-v6e-32 附带 4 个主机 VM)。通常,用户需要 SSH 连接到所有主机来准备文件和设置环境,然后在每个主机上启动作业,这个过程既繁琐又容易出错。

SkyPilot 使这个复杂性自动化。从您的笔记本电脑上,一个简单的 sky launch 命令将执行

  • 工作目录/文件挂载同步;以及

  • 在 Pod 的每个主机上执行 setup/run 命令。

我们可以使用以下命令在 TPU Pod 上运行相同的 Llama 3 训练作业,只需对 YAML 进行微小更改(将 --per_device_train_batch_size 从 16 改为 32)

$ HF_TOKEN=<your-huggingface-token> sky launch -c tpu-pod --gpus tpu-v6e-32 train-llama3-8b.yaml --env HF_TOKEN

您应该会看到以下输出。

(head, rank=0, pid=17894) ***** train metrics *****
(head, rank=0, pid=17894)   epoch                    =         2.5
(head, rank=0, pid=17894)   total_flos               = 219870840GF
(head, rank=0, pid=17894)   train_loss               =     10.1527
(head, rank=0, pid=17894)   train_runtime            =  0:11:13.18
(head, rank=0, pid=17894)   train_samples            =         282
(head, rank=0, pid=17894)   train_samples_per_second =       0.951
(head, rank=0, pid=17894)   train_steps_per_second   =        0.03

(worker1, rank=1, pid=15406, ip=10.164.0.57) ***** train metrics *****
(worker1, rank=1, pid=15406, ip=10.164.0.57)   epoch                    =         2.5
(worker1, rank=1, pid=15406, ip=10.164.0.57)   total_flos               = 219870840GF
(worker1, rank=1, pid=15406, ip=10.164.0.57)   train_loss               =     10.1527
(worker1, rank=1, pid=15406, ip=10.164.0.57)   train_runtime            =  0:11:15.08
(worker1, rank=1, pid=15406, ip=10.164.0.57)   train_samples            =         282
(worker1, rank=1, pid=15406, ip=10.164.0.57)   train_samples_per_second =       0.948
(worker1, rank=1, pid=15406, ip=10.164.0.57)   train_steps_per_second   =        0.03

(worker2, rank=2, pid=16552, ip=10.164.0.58) ***** train metrics *****
(worker2, rank=2, pid=16552, ip=10.164.0.58)   epoch                    =         2.5
(worker2, rank=2, pid=16552, ip=10.164.0.58)   total_flos               = 219870840GF
(worker2, rank=2, pid=16552, ip=10.164.0.58)   train_loss               =     10.1527
(worker2, rank=2, pid=16552, ip=10.164.0.58)   train_runtime            =  0:11:15.61
(worker2, rank=2, pid=16552, ip=10.164.0.58)   train_samples            =         282
(worker2, rank=2, pid=16552, ip=10.164.0.58)   train_samples_per_second =       0.947
(worker2, rank=2, pid=16552, ip=10.164.0.58)   train_steps_per_second   =        0.03

(worker3, rank=3, pid=17469, ip=10.164.0.59) ***** train metrics *****
(worker3, rank=3, pid=17469, ip=10.164.0.59)   epoch                    =         2.5
(worker3, rank=3, pid=17469, ip=10.164.0.59)   total_flos               = 219870840GF
(worker3, rank=3, pid=17469, ip=10.164.0.59)   train_loss               =     10.1527
(worker3, rank=3, pid=17469, ip=10.164.0.59)   train_runtime            =  0:11:15.10
(worker3, rank=3, pid=17469, ip=10.164.0.59)   train_samples            =         282
(worker3, rank=3, pid=17469, ip=10.164.0.59)   train_samples_per_second =       0.948
(worker3, rank=3, pid=17469, ip=10.164.0.59)   train_steps_per_second   =        0.03

要向同一个 TPU Pod 提交更多作业,请使用 sky exec

$ HF_TOKEN=<your-huggingface-token> sky exec tpu-pod train-llama3-8b.yaml --env HF_TOKEN

您可以在 SkyPilot 仓库中找到更多在 TPU 上部署 LLM 的有用示例。

TPU 节点(旧版)#

在 TPU 节点中,需要调配一个普通的 CPU VM(一个 n1 实例)来与 TPU 主机/设备通信。

要使用 TPU 节点,请在任务 YAML 的 resources 字段中设置以下内容

resources:
   instance_type: n1-highmem-8
   accelerators: tpu-v2-8
   accelerator_args:
      runtime_version: 2.12.0  # optional, TPU runtime version.
      tpu_vm: False

上面的 YAML 将 n1-highmem-8 视为主机,将 tpu-v2-8 视为 TPU 节点资源。您可以修改主机实例类型或 TPU 类型。

这里是一个完整的任务 YAML,使用 TensorFlow 在 TPU 节点上运行 MNIST 训练

name: mnist-tpu-node

resources:
   accelerators: tpu-v2-8
   accelerator_args:
      runtime_version: 2.12.0  # optional, TPU runtime version.
      tpu_vm: False

# TPU node requires loading data from a GCS bucket.
# We use SkyPilot bucket mounting to mount a GCS bucket to /dataset.
file_mounts:
   /dataset:
      name: mnist-tpu-node
      store: gcs
      mode: MOUNT

setup: |
   git clone https://github.com/tensorflow/models.git

   conda activate mnist
   if [ $? -eq 0 ]; then
      echo 'conda env exists'
   else
      conda create -n mnist python=3.8 -y
      conda activate mnist
      pip install tensorflow==2.12.0 tensorflow-datasets tensorflow-model-optimization cloud-tpu-client
   fi

run: |
   conda activate mnist
   cd models/official/legacy/image_classification/

   export STORAGE_BUCKET=gs://mnist-tpu-node
   export MODEL_DIR=${STORAGE_BUCKET}/mnist
   export DATA_DIR=${STORAGE_BUCKET}/data

   export PYTHONPATH=/home/gcpuser/sky_workdir/models

   python3 mnist_main.py \
      --tpu=${TPU_NAME} \
      --model_dir=${MODEL_DIR} \
      --data_dir=${DATA_DIR} \
      --train_epochs=10 \
      --distribution_strategy=tpu \
      --download

注意

TPU 节点需要从 GCS 存储桶加载数据。上面的 file_mounts 规范通过使用 SkyPilot 存储桶挂载来创建新存储桶/挂载现有存储桶,从而简化了此过程。如果您遇到存储桶 Permission denied 错误,请确保存储桶与主机 VM/TPU 节点位于同一区域,并且 Cloud TPU 的 IAM 权限已正确设置(按照此处的说明进行)。

注意

特殊的环环境变量 $TPU_NAME 由 SkyPilot 在运行时自动设置,因此可以在 run 命令中使用它。

此 YAML 位于 SkyPilot 仓库下(examples/tpu/tpu_node_mnist.yaml)。使用以下命令启动它

$ sky launch examples/tpu/tpu_node_mnist.yaml  -c mycluster
...
(mnist-tpu-node pid=28961) Epoch 9/10
(mnist-tpu-node pid=28961) 58/58 [==============================] - 1s 19ms/step - loss: 0.1181 - sparse_categorical_accuracy: 0.9646 - val_loss: 0.0921 - val_sparse_categorical_accuracy: 0.9719
(mnist-tpu-node pid=28961) Epoch 10/10
(mnist-tpu-node pid=28961) 58/58 [==============================] - 1s 20ms/step - loss: 0.1139 - sparse_categorical_accuracy: 0.9655 - val_loss: 0.0831 - val_sparse_categorical_accuracy: 0.9742
...
(mnist-tpu-node pid=28961) {'accuracy_top_1': 0.9741753339767456, 'eval_loss': 0.0831054300069809, 'loss': 0.11388632655143738, 'training_accuracy_top_1': 0.9654667377471924}