Running SkyPilot tasks in Airflow with the SkyPilot API server#
This guide shows how you can use SkyPilot to easily develop a training workflow with data preprocessing, training, and evaluation, and then orchestrate it in Airflow.
This example uses a remote SkyPilot API server to manage state shared across invocations, and includes a failure callback that tears down the SkyPilot cluster when a task fails.
💡 Tip: SkyPilot also supports defining and running pipelines without Airflow. Check out Job Pipelines for more information.
Why use SkyPilot with Airflow?#
In AI workflows, the transition from development to production is hard.
Workflow development happens ad-hoc, with a lot of interaction with code and data. When moving to production Airflow DAGs, managing dependencies, environments, and the infra requirements of the workflow becomes complex. Porting the code to Airflow requires significant time to test and validate any changes, and often means rewriting the code as Airflow operators.
SkyPilot seamlessly bridges the dev -> production gap.
SkyPilot can run on any of your infra, allowing you to package and run the same code that you ran during development on a production Airflow cluster. Behind the scenes, SkyPilot handles environment setup, dependency management, and infra orchestration, letting you focus on your code.
Here's how you can use SkyPilot to take your dev workflows to production in Airflow:
1. Define and test your workflow as SkyPilot tasks. Use sky launch and the Sky VSCode integration to run, debug, and iterate on your code.
2. Orchestrate the SkyPilot tasks in Airflow by invoking sky launch on their YAMLs as tasks in the Airflow DAG. Airflow handles scheduling, logging, and monitoring, while SkyPilot handles the infra setup and task execution.
Prerequisites#
Configure the API server endpoint#
Once you have deployed the API server, you need to configure Airflow to use it. Set the SKYPILOT_API_SERVER_ENDPOINT variable in Airflow; the run_sky_task function uses it to send requests to the API server:
airflow variables set SKYPILOT_API_SERVER_ENDPOINT https://<api-server-endpoint>
You can also set the variable through the Airflow web UI.
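If you want to confirm the variable before running the DAG, you can read it back with the Airflow CLI:
# Print the stored endpoint to verify the variable is set
airflow variables get SKYPILOT_API_SERVER_ENDPOINT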
Defining the tasks#
We will define the following tasks to mock a training workflow:
- data_preprocessing.yaml: generates data and writes it to a bucket.
- train.yaml: trains a model on the data in the bucket.
- eval.yaml: evaluates the model and writes the evaluation results to the bucket.
We have defined these tasks in this directory and uploaded them to a Git repo.
When developing the workflow, you can run the tasks independently with sky launch:
# Run the data preprocessing task, replacing <bucket-name> with the bucket you created above
sky launch -c data --env DATA_BUCKET_NAME=<bucket-name> --env DATA_BUCKET_STORE_TYPE=s3 data_preprocessing.yaml
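Once the run finishes, you can optionally check that the generated files landed in the bucket. A minimal sketch, assuming an S3 bucket and a configured AWS CLI (replace <bucket-name> as above):
# List the mock data files written by the preprocessing task
aws s3 ls s3://<bucket-name>/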
The train and eval steps can be run in a similar way:
# Run the train task
sky launch -c train --env DATA_BUCKET_NAME=<bucket-name> --env DATA_BUCKET_STORE_TYPE=s3 train.yaml
Tip: You can use ssh and VSCode to interactively develop and debug the tasks.
Note: eval can optionally be run on the same cluster as train with sky exec.
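For example, a minimal sketch that reuses the train cluster launched above for evaluation (bucket name is a placeholder):
# Run eval on the existing `train` cluster instead of launching a new one
sky exec train --env DATA_BUCKET_NAME=<bucket-name> --env DATA_BUCKET_STORE_TYPE=s3 eval.yaml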
Writing the Airflow DAG#
Once we have developed the tasks, we can seamlessly run them in Airflow.
- No changes required to our tasks - we use the same YAMLs we wrote in the previous step to create an Airflow DAG in sky_train_dag.py.
- Airflow-native logging - SkyPilot logs are written to container stdout, which is captured as task logs in Airflow and displayed in the UI.
- Easy debugging - if a task fails, you can run it independently with sky launch to debug the issue. SkyPilot will recreate the environment in which the task failed.
Here's a snippet of the DAG declaration in sky_train_dag.py:
with DAG(dag_id='sky_train_dag',
default_args=default_args,
schedule_interval=None,
catchup=False) as dag:
# Path to SkyPilot YAMLs. Can be a git repo or local directory.
base_path = 'https://github.com/skypilot-org/mock-train-workflow.git'
# Generate bucket UUID as first task
bucket_uuid = generate_bucket_uuid()
# Use the bucket_uuid from previous task
common_envs = {
'DATA_BUCKET_NAME': f"sky-data-demo-{{{{ task_instance.xcom_pull(task_ids='generate_bucket_uuid') }}}}",
'DATA_BUCKET_STORE_TYPE': 's3'
}
    preprocess = run_sky_task.override(task_id="data_preprocess")(
        base_path, 'data_preprocessing.yaml', envs_override=common_envs, git_branch='clientserver_example')
    train_task = run_sky_task.override(task_id="train")(
        base_path, 'train.yaml', envs_override=common_envs, git_branch='clientserver_example')
    eval_task = run_sky_task.override(task_id="eval")(
        base_path, 'eval.yaml', envs_override=common_envs, git_branch='clientserver_example')
# Define the workflow
bucket_uuid >> preprocess >> train_task >> eval_task
Behind the scenes, run_sky_task uses the Airflow-native Python operator to invoke the SkyPilot API. All SkyPilot API calls are sent to the remote API server, which is configured through the SKYPILOT_API_SERVER_ENDPOINT variable.
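When debugging locally, you can point your own sky CLI or SDK at the same remote API server by exporting the endpoint, mirroring what run_sky_task does when it sets os.environ (this assumes the CLI reads the same SKYPILOT_API_SERVER_ENDPOINT environment variable; the URL is a placeholder):
# Talk to the same remote API server from a local shell
export SKYPILOT_API_SERVER_ENDPOINT=https://<api-server-endpoint>
sky status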
The task YAML files can be sourced in two ways:
1. From a Git repository (as shown above):
repo_url = 'https://github.com/skypilot-org/mock-train-workflow.git'
run_sky_task(...)(repo_url, 'path/to/yaml', git_branch='optional_branch')
The repository is automatically cloned and the specified branch checked out before the task runs.
2. From a local path:
local_path = '/path/to/local/directory'
run_sky_task(...)(local_path, 'path/to/yaml')
This is useful during development or when the tasks are stored locally.
All clusters are set to auto-down after the task completes, so no dangling clusters are left behind.
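When iterating manually with the CLI, the equivalent behavior comes from the --down flag, which tears the cluster down once the job finishes (a sketch; the cluster and bucket names are placeholders):
# Auto-down the cluster after the training job completes
sky launch -c train --down --env DATA_BUCKET_NAME=<bucket-name> --env DATA_BUCKET_STORE_TYPE=s3 train.yaml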
Running the DAG#
Copy the DAG file to the Airflow DAGs directory.
cp sky_train_dag.py /path/to/airflow/dags
# If your Airflow is running on Kubernetes, you may use kubectl cp to copy the file to the pod
# kubectl cp sky_train_dag.py <airflow-pod-name>:/opt/airflow/dags
1. Run airflow dags list to confirm that the DAG is loaded.
2. Find the DAG in the Airflow UI (typically http://localhost:8080) and enable it. The UI may take a couple of minutes to reflect the changes. If the DAG is paused, you can force-unpause it with airflow dags unpause sky_train_dag.
3. Trigger the DAG from the Airflow UI using the Trigger DAG button (or from the CLI, as shown below).
4. Navigate to the run in the Airflow UI to see the DAG progress and the logs of each task.
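If you prefer the terminal, you can also trigger and inspect runs with the Airflow CLI, using the dag_id defined in sky_train_dag.py:
# Trigger a new run of the DAG
airflow dags trigger sky_train_dag
# List runs and their current state
airflow dags list-runs -d sky_train_dag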
If a task fails, task_failure_callback will automatically tear down the SkyPilot cluster.
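If the callback itself cannot run (for example, the Airflow worker dies mid-task), you can clean up manually against the same API server; <cluster-name> is the name pushed to XCom and shown in the task logs:
# Find any stray clusters, then tear them down
sky status
sky down <cluster-name>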
Future work: a native Airflow Executor built on SkyPilot#
Currently this example relies on a helper method, run_sky_task, to wrap SkyPilot invocations in @task, but in the future SkyPilot could provide a native Airflow Executor.
In such a setup, no SkyPilot state management would be needed either, since the executor would handle launching and terminating the SkyPilot clusters.
Included files#
data_preprocessing.yaml
resources:
cpus: 1
envs:
DATA_BUCKET_NAME: sky-demo-data-test
DATA_BUCKET_STORE_TYPE: s3
file_mounts:
/data:
name: $DATA_BUCKET_NAME
store: $DATA_BUCKET_STORE_TYPE
setup: |
echo "Setting up dependencies for data preprocessing..."
run: |
echo "Running data preprocessing..."
# Generate few files with random data to simulate data preprocessing
for i in {0..9}; do
dd if=/dev/urandom of=/data/file_$i bs=1M count=10
done
echo "Data preprocessing completed, wrote to $DATA_BUCKET_NAME"
eval.yaml
resources:
cpus: 1
# Add GPUs here
envs:
DATA_BUCKET_NAME: sky-demo-data-test
DATA_BUCKET_STORE_TYPE: s3
file_mounts:
/data:
name: $DATA_BUCKET_NAME
store: $DATA_BUCKET_STORE_TYPE
setup: |
echo "Setting up dependencies for eval..."
run: |
echo "Evaluating the trained model..."
# Run a mock evaluation job that reads the trained model from /data/trained_model.txt
cat /data/trained_model.txt || true
# Generate a mock accuracy
ACCURACY=$(shuf -i 90-100 -n 1)
echo "Metric - accuracy: $ACCURACY%"
echo "Evaluation report" > /data/evaluation_report.txt
echo "Evaluation completed, report written to $DATA_BUCKET_NAME"
sky_train_dag.py
import os
import uuid
from airflow import DAG
from airflow.decorators import task
from airflow.models import Variable
from airflow.utils.dates import days_ago
import yaml
default_args = {
'owner': 'airflow',
'start_date': days_ago(1),
}
# Unique bucket name for this DAG run
DATA_BUCKET_NAME = str(uuid.uuid4())[:4]
def task_failure_callback(context):
"""Callback to shut down SkyPilot cluster on task failure."""
cluster_name = context['task_instance'].xcom_pull(
task_ids=context['task_instance'].task_id, key='cluster_name')
if cluster_name:
print(
f"Task failed or was cancelled. Shutting down SkyPilot cluster: {cluster_name}"
)
import sky
down_request = sky.down(cluster_name)
sky.stream_and_get(down_request)
@task(on_failure_callback=task_failure_callback)
def run_sky_task(base_path: str,
yaml_path: str,
envs_override: dict = None,
git_branch: str = None,
**kwargs):
"""Generic function to run a SkyPilot task.
This is a blocking call that runs the SkyPilot task and streams the logs.
In the future, we can use deferrable tasks to avoid blocking the worker
while waiting for cluster to start.
Args:
base_path: Base path (local directory or git repo URL)
yaml_path: Path to the YAML file (relative to base_path)
envs_override: Dictionary of environment variables to override in the task config
git_branch: Optional branch name to checkout (only used if base_path is a git repo)
"""
import subprocess
import tempfile
# Set the SkyPilot API server endpoint from Airflow Variables
endpoint = Variable.get('SKYPILOT_API_SERVER_ENDPOINT', None)
if not endpoint:
raise ValueError('SKYPILOT_API_SERVER_ENDPOINT is not set in airflow.')
os.environ['SKYPILOT_API_SERVER_ENDPOINT'] = endpoint
original_cwd = os.getcwd()
try:
# Handle git repos vs local paths
if base_path.startswith(('http://', 'https://', 'git://')):
with tempfile.TemporaryDirectory() as temp_dir:
# TODO(romilb): This assumes git credentials are available in the airflow worker
subprocess.run(['git', 'clone', base_path, temp_dir],
check=True)
# Checkout specific branch if provided
if git_branch:
subprocess.run(['git', 'checkout', git_branch],
cwd=temp_dir,
check=True)
full_yaml_path = os.path.join(temp_dir, yaml_path)
# Change to the temp dir to set context
os.chdir(temp_dir)
# Run the sky task
return _run_sky_task(full_yaml_path, envs_override or {},
kwargs)
else:
full_yaml_path = os.path.join(base_path, yaml_path)
os.chdir(base_path)
# Run the sky task
return _run_sky_task(full_yaml_path, envs_override or {}, kwargs)
finally:
os.chdir(original_cwd)
def _run_sky_task(yaml_path: str, envs_override: dict, kwargs: dict):
"""Internal helper to run the sky task after directory setup."""
import sky
with open(os.path.expanduser(yaml_path), 'r', encoding='utf-8') as f:
task_config = yaml.safe_load(f)
# Initialize envs if not present
if 'envs' not in task_config:
task_config['envs'] = {}
# Update the envs with the override values
# task.update_envs() is not used here, see https://github.com/skypilot-org/skypilot/issues/4363
task_config['envs'].update(envs_override)
task = sky.Task.from_yaml_config(task_config)
cluster_uuid = str(uuid.uuid4())[:4]
task_name = os.path.splitext(os.path.basename(yaml_path))[0]
cluster_name = f'{task_name}-{cluster_uuid}'
kwargs['ti'].xcom_push(key='cluster_name',
value=cluster_name) # For failure callback
launch_request_id = sky.launch(task, cluster_name=cluster_name, down=True)
job_id, _ = sky.stream_and_get(launch_request_id)
# TODO(romilb): In the future, we can use deferrable tasks to avoid blocking
# the worker while waiting for cluster to start.
# Stream the logs for airflow logging
sky.tail_logs(cluster_name=cluster_name, job_id=job_id, follow=True)
# Terminate the cluster after the task is done
down_id = sky.down(cluster_name)
sky.stream_and_get(down_id)
@task
def generate_bucket_uuid(**context):
bucket_uuid = str(uuid.uuid4())[:4]
return bucket_uuid
with DAG(dag_id='sky_train_dag',
default_args=default_args,
schedule_interval=None,
catchup=False) as dag:
# Path to SkyPilot YAMLs. Can be a git repo or local directory.
base_path = 'https://github.com/skypilot-org/mock-train-workflow.git'
# Generate bucket UUID as first task
# See https://stackoverflow.com/questions/55748050/generating-uuid-and-use-it-across-airflow-dag
bucket_uuid = generate_bucket_uuid()
# Use the bucket_uuid from previous task
common_envs = {
'DATA_BUCKET_NAME': f"sky-data-demo-{{{{ task_instance.xcom_pull(task_ids='generate_bucket_uuid') }}}}",
'DATA_BUCKET_STORE_TYPE': 's3'
}
preprocess = run_sky_task.override(task_id="data_preprocess")(
base_path,
'data_preprocessing.yaml',
envs_override=common_envs,
git_branch='clientserver_example')
train_task = run_sky_task.override(task_id="train")(
base_path,
'train.yaml',
envs_override=common_envs,
git_branch='clientserver_example')
eval_task = run_sky_task.override(task_id="eval")(
base_path,
'eval.yaml',
envs_override=common_envs,
git_branch='clientserver_example')
# Define the workflow
bucket_uuid >> preprocess >> train_task >> eval_task
train.yaml
resources:
cpus: 1
# Add GPUs here
envs:
DATA_BUCKET_NAME: sky-demo-data-test
DATA_BUCKET_STORE_TYPE: s3
NUM_EPOCHS: 2
file_mounts:
/data:
name: $DATA_BUCKET_NAME
store: $DATA_BUCKET_STORE_TYPE
setup: |
echo "Setting up dependencies for training..."
run: |
echo "Running training..."
# Run a mock training job that loops through the files in /data starting with 'file_'
for (( i=1; i<=NUM_EPOCHS; i++ )); do
for file in /data/file_*; do
echo "Epoch $i: Training on $file"
sleep 2
done
done
# Mock checkpointing the trained model to the data bucket
echo "Trained model" > /data/trained_model.txt
echo "Training completed, model written to to $DATA_BUCKET_NAME"