使用基础模型微调 API 创建训练运行

重要说明

此功能在以下区域提供公共预览版centraluseastuseastus2northcentraluswestus

本文介绍如何使用基础模型微调(现为 Mosaic AI 模型训练的一部分)API 创建和配置训练运行,并描述 API 调用中使用的所有参数。 也可以使用 UI 创建运行。 有关说明,请参阅使用基础模型微调 UI 创建训练运行

要求

请参阅要求

创建训练运行

若要以编程方式创建训练运行,请使用 create() 函数。 此函数基于提供的数据集训练模型,并保存已训练的模型以供推理。

所需的输入是要训练的模型、训练数据集的位置以及注册模型的位置。 还有一些可选参数可用于执行评估和更改运行的超参数。

运行完成后,保存已完成的运行和最终检查点,克隆模型,并将该克隆注册到 Unity Catalog 作为模型版本进行推理。

已完成运行的模型(不是 Unity Catalog 中的克隆模型版本)将保存到 MLflow。 检查点可用于持续性的微调任务。

有关 函数的参数的详细信息,请参阅create()

from databricks.model_training import foundation_model as fm

run = fm.create(
  model='meta-llama/Llama-3.2-3B-Instruct',
  train_data_path='dbfs:/Volumes/main/mydirectory/ift/train.jsonl', # UC Volume with JSONL formatted data
  # Public HF dataset is also supported
  # train_data_path='mosaicml/dolly_hhrlhf/train'
  register_to='main.mydirectory', # UC catalog and schema to register the model to
)

配置训练运行

下表汇总了 foundation_model.create() 函数的参数。

参数 必需 类型 说明
model x str 要使用的模型的名称。 请参阅支持的模型
train_data_path x str 训练数据的位置。 这可以是 Unity Catalog 中的位置(<catalog>.<schema>.<table>dbfs:/Volumes/<catalog>/<schema>/<volume>/<dataset>.jsonl),也可以是 HuggingFace 数据集。
对于 INSTRUCTION_FINETUNE,应使用包含 promptresponse 字段的每一行设置数据格式。
对于 CONTINUED_PRETRAIN,这是 .txt 文件的文件夹。 请参阅准备基础模型微调的数据了解接受的数据格式,以及参阅模型训练的建议数据大小获取数据大小建议。
register_to x str Unity Catalog 目录和架构(<catalog>.<schema><catalog>.<schema>.<custom-name>),训练后在其中注册模型,以便于部署。 如果未提供 custom-name,则默认为运行的名称。
data_prep_cluster_id str 要用于 Spark 数据处理的群集的群集 ID。 对于训练数据位于 Delta 表中的指令训练任务,这是必需的。 有关如何查找群集 ID 的信息,请参阅获取群集 ID
experiment_path str MLflow 试验的路径,其中保存了训练运行输出(指标和检查点)。 默认为用户个人工作区(即 /Users/<username>/<run_name>)中的运行名称。
task_type str 要运行的任务的类型。 可以是 CHAT_COMPLETION(默认值)、CONTINUED_PRETRAININSTRUCTION_FINETUNE
eval_data_path str 评估数据(如果有)的远程位置。 必须遵循与 train_data_path 相同的格式。
eval_prompts 列表[字符串] 用于在评估期间生成响应的提示字符串列表。 默认值为 None(不生成提示)。 每次对模型设置检查点时,结果都会记录到试验中。 生成发生在具有以下生成参数的每个模型检查点上:max_new_tokens: 100temperature: 1top_k: 50top_p: 0.95do_sample: true
custom_weights_path str 用于训练的自定义模型检查点的远程位置。 默认值为 None,表示运行从所选模型的原始预训练权重开始。 如果提供了自定义权重,则使用这些权重,而不使用模型的原始预训练权重。 这些权重必须已由微调 API 生成,并与指定的 model 的体系结构匹配。 请参阅基于自定义模型权重生成
注意:如果你在 2025 年 3 月 26 日之前训练了模型,则无法再从这些模型检查点持续进行训练。 任何以前完成的训练运行仍可使用预配的吞吐量进行处理,而不会出现问题。
training_duration str 运行的总持续时间。 默认值为一个纪元或 1ep。 可以按纪元 (10ep) 或标记 (1000000tok) 指定。
learning_rate str 模型训练的学习率。 所有模型都使用 AdamW 优化器进行训练,并采用学习速率预热。 默认学习率因模型而异。 建议运行超参数扫描,尝试不同的学习率和训练持续时间,以获得最高质量的模型。
context_length str 数据样本的最大序列长度。 这用于截断任何过长的数据,并将较短的序列打包在一起,以提高效率。
默认值为 8192 个令牌或所提供模型的最大上下文长度(以较低者为准)。 可以使用此参数来配置上下文长度,但不支持在每个模型的最大上下文长度之外进行配置。 有关每个模型支持的最大上下文长度,请参阅支持的模型
validate_inputs 布尔值 是否在提交训练作业之前验证对输入路径的访问权限。 默认为 True

基于自定义模型权重生成

注意:如果你在 2025 年 3 月 26 日之前训练了模型,则无法再从这些模型检查点持续进行训练。 任何以前完成的训练运行仍可使用预配的吞吐量进行处理,而不会出现问题。

基础模型微调支持使用可选参数 custom_weights_path 添加自定义权重来训练和自定义模型。

首先,请将 custom_weights_path 设置为先前微调 API 训练运行的检查点路径。 可以在先前 MLflow 运行的“工件”选项卡中找到检查点路径。 检查点文件夹名称对应于特定快照的批处理和纪元,例如 ep29-ba30/

上一次 MLflow 运行的“工件”选项卡

  • 若要提供先前运行的最新检查点,请将 custom_weights_path 设置为微调 API 生成的检查点。 例如 custom_weights_path=dbfs:/databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<run_name>/checkpoints/latest-sharded-rank0.symlink
  • 若要提供更早的检查点,请将 custom_weights_path 设置为包含与所需检查点对应的 .distcp 文件的文件夹的路径,例如 custom_weights_path=dbfs:/databricks/mlflow-tracking/<experiment_id>/<run_id>/artifacts/<run_name>/checkpoints/ep#-ba#

接下来,更新 model 参数以匹配传递给 custom_weights_path 的检查点的基础模型。

在以下示例中,ift-meta-llama-3-1-70b-instruct-ohugkq 是对 meta-llama/Meta-Llama-3.1-70B 进行微调的先前运行。 若要微调 ift-meta-llama-3-1-70b-instruct-ohugkq 的最新检查点,请按如下所示设置 modelcustom_weights_path 变量:

from databricks.model_training import foundation_model as fm

run = fm.create(
model = 'meta-llama/Meta-Llama-3.1-70B'
custom_weights_path = 'dbfs:/databricks/mlflow-tracking/2948323364469837/d4cd1fcac71b4fb4ae42878cb81d8def/artifacts/ift-meta-llama-3-1-70b-instruct-ohugkq/checkpoints/latest-sharded-rank0.symlink'
... ## other parameters for your fine-tuning run

)

请参阅配置训练运行,了解如何在微调运行中配置其他参数。

获取群集 ID

检索群集 ID:

  1. 在 Databricks 工作区的左侧导航栏中,单击“计算”。

  2. 在表中,单击群集的名称。

  3. 单击右上角的 “更多”按钮,然后从下拉菜单中选择“查看 JSON”。

  4. 此时将显示群集 JSON 文件。 复制群集 ID,这是文件中的第一行内容。

    群集 ID

获取运行状态

可以使用 Databricks UI 中的“试验”页或使用 API 命令 get_events() 跟踪运行进度。 有关详细信息,请参阅查看、管理和分析基础模型微调运行

get_events() 的示例输出:

使用 API 获取运行状态

“试验”页上的示例运行详细信息:

从试验 UI 获取运行状态

后续步骤

训练运行完成后,可以在 MLflow 中查看指标,并部署模型进行推理。 请参阅教程:创建和部署基础模型微调运行的第 5 至 7 步。

请参阅指令微调:命名实体识别演示笔记本,查看指令微调示例,按步骤进行数据准备、微调训练运行配置和部署。

笔记本示例

以下笔记本示例演示如何使用 Meta Llama 3.1 405B Instruct 模型生成合成数据,并使用该数据对模型进行微调:

使用 Llama 3.1 405B Instruct 笔记本生成合成数据

获取笔记本

其他资源