来源:examples/airflow

使用 SkyPilot API 服务器在 Airflow 中运行 SkyPilot 任务#

本指南将展示如何使用 SkyPilot 轻松开发包含数据预处理、训练和评估的训练工作流程,然后将其在 Airflow 中进行编排。

本示例使用远程 SkyPilot API 服务器来管理跨调用共享的状态,并包含一个失败回调函数,用于在任务失败时销毁 SkyPilot 集群。

💡 提示: SkyPilot 也支持在不使用 Airflow 的情况下定义和运行流水线。请查看 任务流水线 获取更多信息。

为何在 Airflow 中使用 SkyPilot?#

在 AI 工作流程中,从开发到生产的过渡非常困难

工作流程开发通常是临时性的,需要与代码和数据进行大量交互。当将其迁移到生产环境中的 Airflow DAG 时,管理依赖项、环境以及工作流程的基础设施需求会变得复杂。将代码移植到 Airflow 需要花费大量时间来测试和验证任何更改,而且往往需要将代码重写为 Airflow 操作符。

SkyPilot 无缝弥合了开发 -> 生产的鸿沟.

SkyPilot 可以在您的任何基础设施上运行,让您可以在生产 Airflow 集群上打包并运行您在开发期间运行的相同代码。在后台,SkyPilot 处理环境设置、依赖项管理和基础设施编排,让您可以专注于您的代码。

以下是如何使用 SkyPilot 将您的开发工作流程带到 Airflow 的生产环境:

  1. 将您的工作流程定义为 SkyPilot 任务并进行测试.

  2. 通过在 Airflow DAG 中将 SkyPilot 任务的 YAML 作为任务调用 sky launch在 Airflow 中编排 SkyPilot 任务

    • Airflow 负责调度、日志记录和监控,而 SkyPilot 处理基础设施设置和任务执行。

先决条件#

  • 本地安装了 Airflow (本地SequentialExecutor)

  • 用于发送请求的 SkyPilot API 服务器端点。

    • 如果您没有,请参考 API 服务器文档 来部署一个。

    • 对于本示例:API 服务器应具有 AWS/GCS 访问权限,以便创建存储中间任务输出的存储桶。

配置 API 服务器端点#

部署 API 服务器后,您需要配置 Airflow 来使用它。在 Airflow 中设置 SKYPILOT_API_SERVER_ENDPOINT 变量 - run_sky_task 函数将使用它向 API 服务器发送请求

airflow variables set SKYPILOT_API_SERVER_ENDPOINT https://<api-server-endpoint>

您也可以使用 Airflow Web UI 设置该变量

定义任务#

我们将定义以下任务来模拟训练工作流程

  1. data_preprocessing.yaml:生成数据并将其写入存储桶。

  2. train.yaml:使用存储桶中的数据训练模型。

  3. eval.yaml:评估模型并将评估结果写入存储桶。

我们在此目录中定义了这些任务,并将它们上传到了一个 Git 仓库

在开发工作流程时,您可以使用 sky launch 独立运行任务

# Run the data preprocessing task, replacing <bucket-name> with the bucket you created above
sky launch -c data --env DATA_BUCKET_NAME=<bucket-name> --env DATA_BUCKET_STORE_TYPE=s3 data_preprocessing.yaml

训练和评估步骤也可以类似方式运行

# Run the train task
sky launch -c train --env DATA_BUCKET_NAME=<bucket-name> --env DATA_BUCKET_STORE_TYPE=s3 train.yaml

提示:您可以使用 ssh 和 VSCode 交互式地开发 和调试任务。

注意:eval 可以选择在与 train 相同的集群上通过 sky exec 运行。

编写 Airflow DAG#

开发完任务后,我们可以无缝地在 Airflow 中运行它们。

  1. 我们的任务无需更改 - 我们使用上一步编写的相同 YAML 文件在 sky_train_dag.py 中创建 Airflow DAG。

  2. Airflow 原生日志记录 - SkyPilot 日志被写入容器标准输出,这些输出被捕获为 Airflow 中的任务日志并在 UI 中显示。

  3. 轻松调试 - 如果任务失败,您可以使用 sky launch 独立运行任务来调试问题。SkyPilot 将重新创建任务失败时的环境。

以下是 sky_train_dag.py 中 DAG 声明的片段

with DAG(dag_id='sky_train_dag',
         default_args=default_args,
         schedule_interval=None,
         catchup=False) as dag:
    # Path to SkyPilot YAMLs. Can be a git repo or local directory.
    base_path = 'https://github.com/skypilot-org/mock-train-workflow.git'

    # Generate bucket UUID as first task
    bucket_uuid = generate_bucket_uuid()
    
    # Use the bucket_uuid from previous task
    common_envs = {
        'DATA_BUCKET_NAME': f"sky-data-demo-{{{{ task_instance.xcom_pull(task_ids='generate_bucket_uuid') }}}}",
        'DATA_BUCKET_STORE_TYPE': 's3'
    }
    
    preprocess = run_sky_task.override(task_id="data_preprocess")(
        repo_url, 'data_preprocessing.yaml', envs_override=common_envs, git_branch='clientserver_example')
    train_task = run_sky_task.override(task_id="train")(
        repo_url, 'train.yaml', envs_override=common_envs, git_branch='clientserver_example')
    eval_task = run_sky_task.override(task_id="eval")(
        repo_url, 'eval.yaml', envs_override=common_envs, git_branch='clientserver_example')

    # Define the workflow
    bucket_uuid >> preprocess >> train_task >> eval_task

在后台,run_sky_task 使用 Airflow 原生 Python 操作符调用 SkyPilot API。所有 SkyPilot API 调用都发送到远程 API 服务器,该服务器使用 SKYPILOT_API_SERVER_ENDPOINT 变量配置。

任务 YAML 文件可以通过两种方式获取

  1. 来自 Git 仓库(如上所示)

    repo_url = 'https://github.com/skypilot-org/mock-train-workflow.git'
    run_sky_task(...)(repo_url, 'path/to/yaml', git_branch='optional_branch')
    

    任务在执行前会自动克隆仓库并检出指定分支。

  2. 来自本地路径:

    local_path = '/path/to/local/directory'
    run_sky_task(...)(local_path, 'path/to/yaml')
    

    这在开发期间或任务存储在本地时很有用。

所有集群在任务完成后都会设置为自动关机,因此不会留下悬空的集群。

运行 DAG#

  1. 将 DAG 文件复制到 Airflow 的 DAGs 目录。

    cp sky_train_dag.py /path/to/airflow/dags                                               
    # If your Airflow is running on Kubernetes, you may use kubectl cp to copy the file to the pod
    # kubectl cp sky_train_dag.py <airflow-pod-name>:/opt/airflow/dags 
    
  2. 运行 airflow dags list 确认 DAG 已加载。

  3. 在 Airflow UI 中找到 DAG(通常是 http://localhost:8080)并启用它。UI 可能需要几分钟才能反映更改。如果 DAG 被暂停,可以使用 airflow dags unpause sky_train_dag 强制取消暂停。

  4. 使用 Airflow UI 中的 Trigger DAG 按钮触发 DAG。

  5. 导航到 Airflow UI 中的运行实例,查看 DAG 进度和每个任务的日志。

如果任务失败,task_failure_callback 将自动销毁 SkyPilot 集群。

未来工作:基于 SkyPilot 构建的 Airflow 原生执行器#

目前本示例依赖于一个助手方法 run_sky_task 来将 SkyPilot 调用包装在 @task 中,但未来 SkyPilot 可以提供 Airflow 原生执行器。

在这种设置下,也不需要 SkyPilot 状态管理,因为执行器将处理 SkyPilot 集群的启动和终止。

包含的文件#