当前位置: 首页 > news >正文

西安哪里有做网站的做视频网站要用到的服务器

西安哪里有做网站的,做视频网站要用到的服务器,建设景区网站的目的,河北省建设厅报名网站目录 4 GRPO源码分析4.1 数据类 GRPOScriptArguments4.2 系统提示字符串 SYSTEM_PROMPT4.3 奖励函数4.3.1 accuracy_reward函数4.3.2 verify函数4.3.3 format_reward函数 4.4 将数据集格式化为对话形式4.5 初始化GRPO Trainer 【复现DeepSeek-R1之Open R1实战】系列3&#xff1…

目录

  • 4 GRPO源码分析
    • 4.1 数据类 `GRPOScriptArguments`
    • 4.2 系统提示字符串 `SYSTEM_PROMPT`
    • 4.3 奖励函数
      • 4.3.1 accuracy_reward函数
      • 4.3.2 verify函数
      • 4.3.3 format_reward函数
    • 4.4 将数据集格式化为对话形式
    • 4.5 初始化GRPO Trainer


【复现DeepSeek-R1之Open R1实战】系列3:SFT和GRPO源码逐行深度解析(上)
【复现DeepSeek-R1之Open R1实战】系列5:SFT和GRPO源码逐行深度解析(中)

4 GRPO源码分析

前面两篇博文已经详细介绍了一些基础知识和SFT源码,本文继续解读GRPO源码。与SFT源码差不多的部分,我们就不展开细说了,这里只解析GRPO独特的部分。

4.1 数据类 GRPOScriptArguments

该类使用了 Python 的 dataclass 装饰器,这是一种简化类定义的方式,特别是对于那些主要用来存储数据的类。它继承自 ScriptArguments 类。

  • reward_funcs: 这是一个列表,包含了一系列可能的奖励函数名称,默认值为 ["accuracy", "format"]。这些奖励函数可能是用于评估模型性能的不同标准。

    reward_funcs: list[str] = field(default_factory=lambda: ["accuracy", "format"],metadata={"help": "List of reward functions. Possible values: 'accuracy', 'format', 'reasoning_steps', 'cosine', 'repetition_penalty', 'length'"},
    )
    
  • cosine_min_value_wrongcosine_max_value_wrong: 分别表示错误答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.0-0.5

  • cosine_min_value_correctcosine_max_value_correct: 分别表示正确答案在余弦相似度尺度上的最小和最大奖励值,默认分别为 0.51.0

  • cosine_max_len: 表示余弦相似度尺度的最大长度,默认值为 1000

  • repetition_n_grams: 表示用于重复惩罚奖励的n-gram数量,默认值为 3

  • repetition_max_penalty: 表示重复惩罚奖励的最大负值,默认值为 -1.0

每个字段都使用了 field() 函数来定义其默认值和元数据(如帮助信息)。这有助于工具和库更好地理解和处理这些字段,例如生成命令行解析器时。

4.2 系统提示字符串 SYSTEM_PROMPT

SYSTEM_PROMPT = ("A conversation between User and Assistant. The user asks a question, and the Assistant solves it. The assistant ""first thinks about the reasoning process in the mind and then provides the user with the answer. The reasoning ""process and answer are enclosed within <think> </think> and <answer> </answer> tags, respectively, i.e., ""<think> reasoning process here </think><answer> answer here </answer>"
)

字符串描述了一个对话场景,用户先提问,助手首先思考推理过程,然后提供答案。推理过程和答案分别用 <think><answer> 标签包裹,这种格式化有助于区分和识别不同的部分,和DeepSeek-R1的思考过程格式一致。

4.3 奖励函数

奖励函数的定义如下,GRPO默认用到了accuracy_reward和format_reward这两个函数。

# Get reward functionsREWARD_FUNCS_REGISTRY = {"accuracy": accuracy_reward,"format": format_reward,"reasoning_steps": reasoning_steps_reward,"cosine": get_cosine_scaled_reward(min_value_wrong=script_args.cosine_min_value_wrong,max_value_wrong=script_args.cosine_max_value_wrong,min_value_correct=script_args.cosine_min_value_correct,max_value_correct=script_args.cosine_max_value_correct,max_len=script_args.cosine_max_len,),"repetition_penalty": get_repetition_penalty_reward(ngram_size=script_args.repetition_n_grams,max_penalty=script_args.repetition_max_penalty,),"length": len_reward,}reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]

这段代码定义了一个奖励函数注册表 REWARD_FUNCS_REGISTRY,并根据用户提供的配置动态生成一个奖励函数列表 reward_funcs。每个奖励函数用于评估模型输出的不同方面,如准确性、格式、推理步骤等。

  1. 注册表定义
  • accuracy: 使用 accuracy_reward 函数评估模型输出的准确性。
  • format: 使用 format_reward 函数评估模型输出的格式。
  • reasoning_steps: 使用 reasoning_steps_reward 函数评估模型输出的推理步骤。
  • cosine: 使用 get_cosine_scaled_reward 函数计算余弦相似度奖励,参数包括:
    • min_value_wrong: 错误情况下的最小值。
    • max_value_wrong: 错误情况下的最大值。
    • min_value_correct: 正确情况下的最小值。
    • max_value_correct: 正确情况下的最大值。
    • max_len: 最大长度。
  • repetition_penalty: 使用 get_repetition_penalty_reward 函数计算重复惩罚奖励,参数包括:
    • ngram_size: n-gram 的大小。
    • max_penalty: 最大惩罚值。
  • length: 使用 len_reward 函数评估模型输出的长度。
  1. 动态生成奖励函数列表
reward_funcs = [REWARD_FUNCS_REGISTRY[func] for func in script_args.reward_funcs]
  • 根据 script_args.reward_funcs 中指定的奖励函数名称,从 REWARD_FUNCS_REGISTRY 中获取相应的奖励函数,并生成一个列表 reward_funcs

4.3.1 accuracy_reward函数

该函数用于计算模型生成的补全与真实答案之间的准确性奖励。它通过解析和验证生成的内容与真实答案来确定奖励值。

def accuracy_reward(completions, solution, **kwargs):"""Reward function that checks if the completion is the same as the ground truth."""contents = [completion[0]["content"] for completion in completions]rewards = []for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)if len(gold_parsed) != 0:# We require the answer to be provided in correct latex (no malformed operators)answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)# Reward 1 if the content is the same as the ground truth, 0 otherwisereward = float(verify(answer_parsed, gold_parsed))else:# If the gold solution is not parseable, we reward 1 to skip this examplereward = 1.0print("Failed to parse gold solution: ", sol)rewards.append(reward)return rewards
  • completions (list): 包含多个补全结果的列表,每个补全结果是一个包含内容的字典列表。
  • solution (list): 真实答案的列表。
  • kwargs: 其他可选参数(在本函数中未使用)。
  1. 提取补全内容

    contents = [completion[0]["content"] for completion in completions]
    
    • completions 列表中提取每个补全的第一个内容(假设每个补全是单个元素的列表),形成一个新的 contents 列表。
  2. 初始化奖励列表

    rewards = []
    
  3. 遍历每个补全和对应的真实答案

    for content, sol in zip(contents, solution):gold_parsed = parse(sol,extraction_mode="first_match",extraction_config=[LatexExtractionConfig()],)
    
    • 使用 zip 函数将 contentssolution 配对。
    • 对于每一对补全内容和真实答案,首先解析真实答案 sol,使用 parse 函数提取其中的信息。
  4. 处理解析结果

    if len(gold_parsed) != 0:answer_parsed = parse(content,extraction_config=[LatexExtractionConfig(normalization_config=NormalizationConfig(nits=False,malformed_operators=False,basic_latex=True,equations=True,boxed="all",units=True,),# Ensures that boxed is tried firstboxed_match_priority=0,try_extract_without_anchor=False,)],extraction_mode="first_match",)
    
    • 如果解析得到的真实答案 gold_parsed 非空,则继续解析生成的补全内容 content
    • 使用 LatexExtractionConfigNormalizationConfig 进行详细配置,确保解析过程中考虑了各种格式要求(如方程、单位等)。
  5. 计算奖励

    reward = float(verify(answer_parsed, gold_parsed))
    
    • 使用 verify 函数比较生成的补全解析结果和真实答案的解析结果。
    • 如果两者匹配,则返回 1.0,否则返回 0.0
  6. 处理无法解析的情况

    else:reward = 1.0print("Failed to parse gold solution: ", sol)
    
    • 如果真实答案无法解析,则默认给予奖励 1.0 并打印警告信息。
  7. 添加奖励到列表

    rewards.append(reward)
    
  8. 返回所有奖励

    return rewards
    

4.3.2 verify函数

该函数用于验证目标表达式是否与参考表达式匹配,它通过多种比较策略来处理不同的数学对象(如数字、表达式、集合、矩阵等),并提供灵活的配置选项以适应不同的需求。

def verify(gold: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, target: list[Basic | MatrixBase | str] | Basic | MatrixBase | str, float_rounding: int=6,numeric_precision: int=15,strict: bool=True,timeout_seconds: int=3
) -> bool:
  • gold: 参考或正确的表达式,可以是单个 SymPy 表达式(BasicMatrixBase)、字符串或这些类型的列表。
  • target: 需要验证的表达式,类型同 gold
  • float_rounding: 浮点数舍入的小数位数,默认为 6。
  • numeric_precision: 数值比较时考虑的小数位数,默认为 15。
  • strict: 是否启用严格比较模式,默认为 True
    • 在严格模式下:变量很重要,集合不可与元组比较。
    • 在非严格模式下:变量按位置匹配,集合可与元组比较。
  • timeout_seconds: 单次比较操作的最大超时时间(秒),默认为 3 秒。
  1. 定义内部比较函数 compare_single_extraction

    @timeout(timeout_seconds=timeout_seconds)
    def compare_single_extraction(gold: Basic | MatrixBase | str, target: Basic | MatrixBase | str) -> bool:...
    
    • 使用装饰器 @timeout 设置超时保护,默认超时时间为 timeout_seconds
    • 比较两个表达式:
      • 如果两者都是 SymPy 表达式(BasicMatrixBase),则调用 sympy_expr_eq 进行比较。
      • 如果两者都是字符串,则进行简单的字符串比较。
  2. 定义包装函数 compare_single_extraction_wrapper

    def compare_single_extraction_wrapper(g, t):try:return compare_single_extraction(g, t)except Exception as e:logger.exception(f"Error comparing {g} and {t}")return False
    
    • 包装 compare_single_extraction,捕获并记录任何异常,返回 False 以避免程序中断。
  3. 处理输入列表

    if not isinstance(gold, list):gold = [gold]
    if not isinstance(target, list):target = [target]
    
    • 如果 goldtarget 不是列表,则将其转换为单元素列表,以便统一处理。
  4. 组合所有可能的比较

    return any(compare_single_extraction_wrapper(g, t) for g, t in product(gold, target))
    
    • 使用 itertools.product 生成所有可能的 goldtarget 组合。
    • 对每个组合调用 compare_single_extraction_wrapper,如果任意一对匹配成功,则返回 True

4.3.3 format_reward函数

函数用于检查给定的完成文本是否符合特定的格式,它验证完成文本是否包含 <think><answer> 标签,并且这两个标签的内容是非空的。

def format_reward(completions, **kwargs):"""Reward function that checks if the completion has a specific format."""pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"completion_contents = [completion[0]["content"] for completion in completions]matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]return [1.0 if match else 0.0 for match in matches]
  • completions: 这是一个列表,其中每个元素都是一个包含完成内容的对象(通常是字典)。假设每个完成对象的第一个元素包含一个键 "content",其值是需要检查的文本。
  • kwargs: 其他关键字参数,这里没有使用,但可以为未来的扩展提供灵活性。
  1. 正则表达式模式定义

    pattern = r"^<think>.*?</think>\s*<answer>.*?</answer>$"
    
    • 这个正则表达式用于匹配字符串是否以 <think> 开始,紧接着是任意字符(非贪婪匹配),然后是 </think>,接着可能有任意数量的空白字符(包括换行符),最后是以 <answer> 开始并以 </answer> 结束。
    • .*? 是非贪婪匹配,确保尽可能少地匹配字符。
    • \s* 匹配零个或多个空白字符(包括换行符)。
    • re.DOTALL | re.MULTILINE 标志允许点号 . 匹配所有字符(包括换行符),并且使多行文本中的每一行都可以独立匹配。
  2. 提取完成内容

    completion_contents = [completion[0]["content"] for completion in completions]
    
    • 这里通过列表推导式从 completions 列表中提取每个完成对象的第一个元素的 "content" 字段,形成一个新的列表 completion_contents
  3. 匹配正则表达式

    matches = [re.match(pattern, content, re.DOTALL | re.MULTILINE) for content in completion_contents]
    
    • 使用 re.match 函数对 completion_contents 中的每个内容应用正则表达式模式。
    • matches 列表将包含 re.Match 对象(如果匹配成功)或 None(如果匹配失败)。
  4. 生成奖励分数

    return [1.0 if match else 0.0 for match in matches]
    
    • 最后一步是根据匹配结果生成奖励分数。如果匹配成功(即 match 不是 None),则返回 1.0;否则返回 0.0

示例代码:

completions = [[{"content": "<think>This is reasoning.</think><answer>This is answer.</answer>"}],[{"content": "<think>This is reasoning.</think>"}],[{"content": "<answer>This is answer.</answer>"}],[{"content": "This does not match."}]
]reward_scores = format_reward(completions)
print(reward_scores)  # 输出: [1.0, 0.0, 0.0, 0.0]

在这个例子中:

  • 第一个完成内容完全匹配正则表达式,因此得分为 1.0
  • 后三个完成内容不符合要求,因此得分均为 0.0

4.4 将数据集格式化为对话形式

# Format into conversationdef make_conversation(example):return {"prompt": [{"role": "system", "content": SYSTEM_PROMPT},{"role": "user", "content": example["problem"]},],}dataset = dataset.map(make_conversation)for split in dataset:if "messages" in dataset[split].column_names:dataset[split] = dataset[split].remove_columns("messages")

将一个数据集中的每个示例转换为对话格式,并确保数据集中没有多余的列(如 messages)。

  • 输入example 是一个字典,包含单个数据样本的信息,其中 "problem" 键对应的值是用户的问题或任务描述。
  • 输出:返回一个新的字典,包含一个 "prompt" 键,其值是一个对话列表:
    • 第一条消息是系统消息,内容由 SYSTEM_PROMPT 定义。
    • 第二条消息是用户消息,内容是 example["problem"]
  • dataset.map(make_conversation):使用 map 方法将 make_conversation 函数应用到数据集的每个示例上,生成新的对话格式。
  • 移除多余列:遍历数据集的每个拆分(split),如果存在 "messages" 列,则将其移除。

4.5 初始化GRPO Trainer

trainer = GRPOTrainer(model=model_args.model_name_or_path,reward_funcs=reward_funcs,args=training_args,train_dataset=dataset[script_args.dataset_train_split],eval_dataset=dataset[script_args.dataset_test_split] if training_args.eval_strategy != "no" else None,peft_config=get_peft_config(model_args),callbacks=get_callbacks(training_args, model_args),)

篇幅有限,训练部分的代码我们放到下一篇博文详细解读!

http://www.yayakq.cn/news/972956/

相关文章:

  • .电子商务网站建设的核心是百度一下你就知道官方网站
  • 深圳网站设计与制作微信客户端官网
  • 校园网站建设提升小程序制作需要什么条件
  • 民治专业做网站公司网络营销推广方法集锦
  • 代发网站建设教程网址域名大全2345网址
  • 网上如何做网站耒阳网站开发
  • 网站已收录的404页面的查询查网站的关键词排名吗
  • 注册网站查询官网个人网站名称备案
  • 百度公司做网站优化多少钱如何制作境外网站
  • 单页网站与传统网站的区别新营销方式有哪些
  • 如何再网站上做免费广告wordpress live-calendar
  • 揭阳建站服务基于vue的毕业设计题目
  • 北京网站搭建费用网站名和域名
  • 网站建设策划 流程linux服务器wordpress
  • 张家口高新区做网站网站怎么做前后台存取
  • 珠宝首饰网站建设规划书百度合作的网盟网站
  • 连锁酒店网站方案wordpress 简约论坛
  • 新做的网站如何行业 专业 网站建设
  • 文章博客媒体网站模板办公室装修公司费用
  • pc端购物网站建站彩虹云主机官网
  • 手机网站左右滑动效果wordpress 贴吧主题
  • 做写手哪个网站好网站建设费是
  • 无锡门户网站制作电话用微魔方做的网站一定要加
  • 做网站好还是做安卓app好烟台建设协会网站
  • 购买域名后 可以做网站么重庆好的seo平台
  • 网站开发要学些什么网站开发与应用总结
  • html5 单页 响应式 网站模板申请一个网站需要怎么做
  • 免费做网站表白服务质量好的crm系统
  • 为什么小城市做不出来好的网站上海做网站 公司
  • 西安高端网站建设哪家好网站未做安全隐患检测怎么拿shell