Model Training Guide#

This guide covers best practices and examples for high-performance distributed training with SkyPilot.

Distributed training basics#

SkyPilot supports all distributed training frameworks, including but not limited to PyTorch (DDP and torchrun), DeepSpeed, and Ray.

The choice of framework depends on your specific needs, and any of them can be configured easily through SkyPilot's YAML specification.
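
For example, a multi-node PyTorch job launched with torchrun can be expressed as a single SkyPilot task. The following is a minimal sketch, not a drop-in recipe: train.py, the node count, and the accelerator choice are placeholders, while the SKYPILOT_* variables are environment variables that SkyPilot sets on every node.

# Minimal multi-node PyTorch (torchrun) sketch; adapt the script and resources.
name: ddp-train

num_nodes: 2

resources:
  accelerators: A100:8  # Placeholder; any supported accelerator works.

run: |
  MASTER_ADDR=$(echo "$SKYPILOT_NODE_IPS" | head -n1)
  torchrun \
    --nnodes=$SKYPILOT_NUM_NODES \
    --nproc_per_node=$SKYPILOT_NUM_GPUS_PER_NODE \
    --node_rank=$SKYPILOT_NODE_RANK \
    --master_addr=$MASTER_ADDR \
    --master_port=12355 \
    train.py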

Best practices#

High-performance instances#

Choose high-performance instances for the best training performance. SkyPilot lets you specify instance types with powerful GPUs and high-bandwidth networking:

  • Use the latest GPU accelerators (A100, H100, etc.) for faster training

  • For large models, consider instances with higher memory bandwidth and more device memory

Example configuration:

resources:
  accelerators:
    A100:1
    A100-80GB:1
    H100:1

Use high-performance networking#

AWS Elastic Fabric Adapter (EFA) is a network interface, similar to Nvidia InfiniBand, that enables users to run applications requiring high levels of inter-node communication at scale on AWS. You can enable EFA on AWS HyperPod/EKS clusters by adding a single setting to your SkyPilot YAML.

Example configuration:

config:
  kubernetes:
    pod_config:
      spec:
        containers:
          - resources:
              limits:
                vpc.amazonaws.com/efa: 4
              requests:
                vpc.amazonaws.com/efa: 4

For more details, see the EFA example.

GPUDirect-TCPX is a high-performance networking technology that enables direct communication between GPUs and the network interfaces of a3-highgpu-8g and a3-edgegpu-8g VMs. You can enable it by simply adding the following settings to your SkyPilot YAML.

Example configuration:

config:
  gcp:
    enable_gpu_direct: true

For more details, see the GPUDirect-TCPX example.

Use disk_tier: best#

Fast storage is critical for loading and storing data and model checkpoints. SkyPilot's disk_tier option supports the fastest available storage, using high-performance local SSDs to reduce I/O bottlenecks.

Example configuration:

resources:
  disk_tier: best  # Use highest performance disk tier.
  disk_size: 1000 # GiB. Make the disk size large enough for checkpoints.

Use MOUNT_CACHED for checkpointing#

A cloud bucket mounted in MOUNT_CACHED mode provides high-performance writes, making it ideal for model checkpoints, logs, and other outputs that need fast local writes.

Unlike MOUNT mode, it supports all write and append operations by using local disk as a cache for files written to the cloud bucket. Compared to MOUNT mode, it can deliver up to 9x faster writes for large checkpoints.

Example configuration:

file_mounts:
  /checkpoints:
    name: my-checkpoint-bucket
    mode: MOUNT_CACHED

For more information on the differences between MOUNT and MOUNT_CACHED, see Storage mount modes.

Robust checkpointing for spot instances#

When using spot instances, robust checkpointing is essential for recovering from preemptions. Your job should follow two key principles:

  1. Write checkpoints periodically during training to save your progress

  2. Always attempt to load a checkpoint at startup, whether it is the first run or a restart after preemption

This approach ensures your job can seamlessly pick up where it left off after a preemption. On the first run no checkpoint will exist, but on every subsequent restart your job automatically restores its state, as in the sketch below.
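
Put together, a training loop that follows both principles looks roughly like this. It is an illustrative sketch only: train_one_step is a placeholder for your training logic, and save_checkpoint/load_checkpoint stand in for helpers like those shown in the next sections, where load_checkpoint is assumed to return the last saved step and model state, or None when no checkpoint exists.

import torch

def train(model: torch.nn.Module,
          optimizer: torch.optim.Optimizer,
          num_steps: int,
          save_every: int = 1000):
    # Principle 2: always try to restore state at startup. On the first run
    # nothing is found and training starts from step 0.
    restored = load_checkpoint('/checkpoints')
    if restored is None:
        start_step = 0
    else:
        last_step, state_dict = restored
        model.load_state_dict(state_dict)
        start_step = last_step + 1

    for step in range(start_step, num_steps):
        train_one_step(model, optimizer)  # placeholder for your training logic

        # Principle 1: periodically persist progress, so a preemption loses
        # at most `save_every` steps of work.
        if step > 0 and step % save_every == 0:
            save_checkpoint(step, model)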

Basic checkpointing#

Saving to the bucket is straightforward: simply write to the mounted directory /checkpoints specified above, just as you would write to local disk.

def save_checkpoint(step: int, model: torch.nn.Module):
    # save checkpoint to local disk with step number
    torch.save(model.state_dict(), f'/checkpoints/model_{step}.pt')

To make checkpoint loading robust against preemptions and incomplete checkpoints:

  • Always try to load the latest checkpoint first

  • If the latest checkpoint turns out to be corrupted or incomplete, fall back to an earlier one

Here is a simplified example showing the core loading logic:

from pathlib import Path
import logging

import torch

logger = logging.getLogger(__name__)

def load_checkpoint(save_dir: str = '/checkpoints'):
    try:
        # Find all checkpoints, sorted by step (newest first)
        checkpoints = sorted(
            Path(save_dir).glob("model_*.pt"),  # matches save_checkpoint above
            key=lambda x: int(x.stem.split('_')[-1]),
            reverse=True
        )

        # Try each checkpoint from newest to oldest
        for checkpoint in checkpoints:
            try:
                step = int(checkpoint.stem.split('_')[-1])
                # Replace with your actual loading logic if needed.
                state_dict = torch.load(checkpoint)
                return step, state_dict
            except Exception as e:
                logger.warning(f"Failed to load checkpoint {checkpoint}: {e}")
                continue
        return None
    except Exception as e:
        logger.error(f"Failed to find checkpoints: {e}")
        return None

Robust checkpointing with error handling#

For a complete implementation with additional features such as custom prefixes, extended metadata, and more detailed error handling, see the code below.

Full implementation:
from datetime import datetime
import functools
import json
import logging
import os
from pathlib import Path
from typing import Any, Callable, Dict, Optional, TypeVar, Union

import torch

logger = logging.getLogger(__name__)

T = TypeVar('T')

def save_checkpoint(
    save_dir: str,
    max_checkpoints: int = 5,
    checkpoint_prefix: str = "checkpoint",
):
    """
    Decorator for saving checkpoints with fallback mechanism.

    Args:
        save_dir: Directory to save checkpoints
        max_checkpoints: Maximum number of checkpoints to keep
        checkpoint_prefix: Prefix for checkpoint files

    Examples:
        # Basic usage with a simple save function
        @save_checkpoint(save_dir="checkpoints")
        def save_model(step: int, model: torch.nn.Module):
            torch.save(model.state_dict(), f"checkpoints/model_{step}.pt")

        # With custom save function that includes optimizer
        @save_checkpoint(save_dir="checkpoints")
        def save_training_state(step: int, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
            torch.save({
                'model': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'step': step
            }, f"checkpoints/training_{step}.pt")

        # With additional data and custom prefix
        @save_checkpoint(save_dir="checkpoints", checkpoint_prefix="experiment1")
        def save_with_metrics(step: int, model: torch.nn.Module, metrics: Dict[str, float]):
            torch.save({
                'model': model.state_dict(),
                'metrics': metrics,
                'step': step
            }, f"checkpoints/experiment1_step_{step}.pt")
    """
    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        # Initialize state
        save_dir_path = Path(save_dir)
        save_dir_path.mkdir(parents=True, exist_ok=True)

        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> T:
            # Get current step from kwargs or args
            step = kwargs.get('step', args[0] if args else None)
            if step is None:
                return func(*args, **kwargs)

            try:
                # Call the original save function
                result = func(*args, **kwargs)

                # Save metadata
                metadata = {
                    'step': step,
                    'timestamp': datetime.now().isoformat(),
                    'model_type': kwargs.get('model', args[1] if len(args) > 1 else None).__class__.__name__,
                }

                metadata_path = save_dir_path / f"{checkpoint_prefix}_step_{step}_metadata.json"
                with open(metadata_path, 'w') as f:
                    json.dump(metadata, f)

                # Cleanup old checkpoints
                checkpoints = sorted(
                    [f for f in save_dir_path.glob(f"{checkpoint_prefix}_step_*.pt")],
                    key=lambda x: int(x.stem.split('_')[-1])
                )

                while len(checkpoints) > max_checkpoints:
                    oldest_checkpoint = checkpoints.pop(0)
                    oldest_checkpoint.unlink()
                    metadata_path = oldest_checkpoint.with_name(oldest_checkpoint.stem + '_metadata.json')
                    if metadata_path.exists():
                        metadata_path.unlink()

                logger.info(f"Saved checkpoint at step {step}")
                return result

            except Exception as e:
                logger.error(f"Failed to save checkpoint at step {step}: {str(e)}")
                return func(*args, **kwargs)

        return wrapper
    return decorator

def load_checkpoint(
    save_dir: str,
    checkpoint_prefix: str = "checkpoint",
):
    """
    Decorator for loading checkpoints with fallback mechanism.
    Tries to load from the latest checkpoint, if that fails tries the second latest, and so on.

    Args:
        save_dir: Directory containing checkpoints
        checkpoint_prefix: Prefix for checkpoint files

    Examples:
        # Basic usage with a simple load function
        @load_checkpoint(save_dir="checkpoints")
        def load_model(step: int, model: torch.nn.Module):
            model.load_state_dict(torch.load(f"checkpoints/model_{step}.pt"))

        # Loading with optimizer
        @load_checkpoint(save_dir="checkpoints")
        def load_training_state(step: int, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
            checkpoint = torch.load(f"checkpoints/training_{step}.pt")
            model.load_state_dict(checkpoint['model'])
            optimizer.load_state_dict(checkpoint['optimizer'])
            return checkpoint['step']

        # Loading with custom prefix and additional data
        @load_checkpoint(save_dir="checkpoints", checkpoint_prefix="experiment1")
        def load_with_metrics(step: int, model: torch.nn.Module):
            checkpoint = torch.load(f"checkpoints/experiment1_step_{step}.pt")
            model.load_state_dict(checkpoint['model'])
            return checkpoint['metrics']
    """
    def decorator(func: Callable[..., T]) -> Callable[..., T]:
        save_dir_path = Path(save_dir)

        @functools.wraps(func)
        def wrapper(*args, **kwargs) -> T:
            try:
                # Find available checkpoints
                checkpoints = sorted(
                    [f for f in save_dir_path.glob(f"{checkpoint_prefix}_step_*.pt")],
                    key=lambda x: int(x.stem.split('_')[-1]),
                    reverse=True  # Sort in descending order (newest first)
                )

                if not checkpoints:
                    logger.warning("No checkpoints found")
                    return func(*args, **kwargs)

                # Try each checkpoint from newest to oldest
                for checkpoint in checkpoints:
                    try:
                        step = int(checkpoint.stem.split('_')[-1])

                        # Call the original load function with the current step
                        if 'step' in kwargs:
                            kwargs['step'] = step
                        elif args:
                            args = list(args)
                            args[0] = step
                            args = tuple(args)

                        result = func(*args, **kwargs)
                        logger.info(f"Successfully loaded checkpoint from step {step}")
                        return result

                    except Exception as e:
                        logger.warning(f"Failed to load checkpoint at step {step}, trying previous checkpoint: {str(e)}")
                        continue

                # If we get here, all checkpoints failed
                logger.error("Failed to load any checkpoint")
                return func(*args, **kwargs)

            except Exception as e:
                logger.error(f"Failed to find checkpoints: {str(e)}")
                return func(*args, **kwargs)

        return wrapper
    return decorator

Here are some common ways to use this checkpointing system:

Basic model saving:

@save_checkpoint(save_dir="checkpoints")
def save_model(step: int, model: torch.nn.Module):
    torch.save(model.state_dict(), f"checkpoints/model_{step}.pt")

Saving optimizer state:

@save_checkpoint(save_dir="checkpoints")
def save_training_state(step: int, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    torch.save({
        'model': model.state_dict(),
        'optimizer': optimizer.state_dict(),
        'step': step
    }, f"checkpoints/training_{step}.pt")

Saving metrics with a custom prefix:

@save_checkpoint(save_dir="checkpoints", checkpoint_prefix="experiment1")
def save_with_metrics(step: int, model: torch.nn.Module, metrics: Dict[str, float]):
    torch.save({
        'model': model.state_dict(),
        'metrics': metrics,
        'step': step
    }, f"checkpoints/experiment1_step_{step}.pt")

Loading checkpoints:

# Basic model loading
@load_checkpoint(save_dir="checkpoints")
def load_model(step: int, model: torch.nn.Module):
    model.load_state_dict(torch.load(f"checkpoints/model_{step}.pt"))

# Loading with optimizer
@load_checkpoint(save_dir="checkpoints")
def load_training_state(step: int, model: torch.nn.Module, optimizer: torch.optim.Optimizer):
    checkpoint = torch.load(f"checkpoints/training_{step}.pt")
    model.load_state_dict(checkpoint['model'])
    optimizer.load_state_dict(checkpoint['optimizer'])
    return checkpoint['step']

# Loading with custom prefix and metrics
@load_checkpoint(save_dir="checkpoints", checkpoint_prefix="experiment1")
def load_with_metrics(step: int, model: torch.nn.Module):
    checkpoint = torch.load(f"checkpoints/experiment1_step_{step}.pt")
    model.load_state_dict(checkpoint['model'])
    return checkpoint['metrics']

Examples#

End-to-end BERT#

We can take the SkyPilot YAML for BERT fine-tuning from above and add checkpointing/recovery to make it work end to end.

Note

You can find all the code for this example in the documentation.

In this example, we use HuggingFace to fine-tune a BERT model on a question answering task.

This example:

  • lets SkyPilot find a V100 instance on any cloud,

  • uses spot instances to save cost, and

  • uses checkpointing to quickly recover a preempted job.

# bert_qa.yaml
name: bert-qa

resources:
  accelerators: V100:1
  use_spot: true  # Use spot instances to save cost.
  disk_tier: best # using highest performance disk tier

file_mounts:
  /checkpoint:
    name: # NOTE: Fill in your bucket name
    mode: MOUNT_CACHED

envs:
  # Fill in your wandb key: copy from https://wandb.ai/authorize
  # Alternatively, you can use `--env WANDB_API_KEY=$WANDB_API_KEY`
  # to pass the key in the command line, during `sky jobs launch`.
  WANDB_API_KEY:

# Assume your working directory is under `~/transformers`.
workdir: ~/transformers

setup: |
  pip install -e .
  cd examples/pytorch/question-answering/
  pip install -r requirements.txt torch==1.12.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113
  pip install wandb

run: |
  cd examples/pytorch/question-answering/
  python run_qa.py \
    --model_name_or_path bert-base-uncased \
    --dataset_name squad \
    --do_train \
    --do_eval \
    --per_device_train_batch_size 12 \
    --learning_rate 3e-5 \
    --num_train_epochs 50 \
    --max_seq_length 384 \
    --doc_stride 128 \
    --report_to wandb \
    --output_dir /checkpoint/bert_qa/ \
    --run_name $SKYPILOT_TASK_ID \
    --save_total_limit 10 \
    --save_steps 1000

The file_mounts section above adds a bucket for storing checkpoints. Since HuggingFace has built-in support for periodic checkpointing, we only need to pass the checkpointing arguments in the run command (--output_dir under the bucket mount, --save_steps, and --save_total_limit) to save checkpoints to the bucket. (See the Huggingface API for more information.) For another example of periodic checkpointing with PyTorch, check out our ResNet example.

We also set --run_name to $SKYPILOT_TASK_ID so that logs from all recoveries of the same job are saved to the same run in Weights & Biases.

Note

The environment variable $SKYPILOT_TASK_ID (example: "sky-managed-2022-10-06-05-17-09-750781_bert-qa_8-0") can be used to identify the same job, i.e., it stays the same across all recoveries of the job. It can be accessed in the task's run commands or directly in the program itself (e.g., via os.environ) and passed to Weights & Biases for tracking in your training script. It is available each time the task is invoked. See Environment variables provided by SkyPilot for more information.
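
For example, the training script itself can read the variable and reuse it as the run name, as in this minimal sketch (the project name is a placeholder):

import os

import wandb

# SKYPILOT_TASK_ID is identical across all recoveries of a managed job, so
# using it as the run name groups every restart under the same W&B run.
task_id = os.environ.get('SKYPILOT_TASK_ID', 'local-debug-run')
wandb.init(project='bert-qa', name=task_id)  # 'bert-qa' is a placeholder project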

With these changes, the managed job can now resume training after preemption! We get the cost savings of spot instances without worrying about preemptions or lost progress.

$ sky jobs launch -n bert-qa bert_qa.yaml

Real-world examples#