PyTorch概述(七)---Optim

news/发布时间2024/9/20 5:53:44
  • torch.optim是一个实现多种优化算法的包;
  • 很多常用的方法已经被支持;
  • 接口丰富;
  • 容易整合更为复杂的算法;

如何使用一个优化器

  • 为了使用torch.optim包功能;
  • 用户必须构建一个优化器对象;
  • 该优化器将保持当前的参数状态且基于计算的梯度更新参数;

构建优化器

  • 要构建一个优化器;
  • 必须给优化器一个可迭代的对象;
  • 该对象包含可优化的参数(应当是变量s);
  • 然后,用户可以指定具体的优化器参数,比如学习率,权重衰减等;
import torch.optim as optim
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
optimizer=optim.Adam([var1,var2],lr=0.0001)

单一参数设置

  • 优化器也支持每一个参数的设置;
  • 为了这样做,不要给优化器传入一个可迭代的变量s;
  • 而是给优化器传入一个可迭代的字典s;
  • 每个字典将定义一个分离的参数组;
  • 参数组内应当包含一个参数键,该参数键包含一个属于他的参数列表;
  • 其他键应当匹配优化器可接受的关键字参数;
  • 且该键将被用作这个组内的优化选项;
  • 依然可以传递选项作为关键字参数,他们将被用于默认参数;
  • 组内对他们并不覆写;
  • 当用户想变化单一的选项时会很有用;
  • 同时保持其他的参数组一致;
  • 比如,当想指定每一层的学习率时:
optim.SGD([{'params':model.base.parameters()},{'params':model.classifier.parameters(),'lr':1e-3}],lr=1e-2,momentum=0.9)
  • 上述代码意味着model.base的参数将使用默认的学习率:1e^{-2};
  • model.classifier的参数将使用1e^{-3}的学习率;
  • momentum=0.9将会被所有的参数使用;

优化步骤

  • 所有的优化器都实现一个step()方法;
  • 该方法对参数进行更新;
  • 有两种使用方式:
  • optimizier.step()
  • optimizer.step(closure)

optimizer.step()

  • 大多数优化器都支持的一个简单版本用法;
  • 当使用backward()计算梯度后调用该方法;
for input,target in dataset:optimizer.zero_grad()output=model(input)loss=loss_fn(output,target)loss.backward()optimizer.step()

optimizer.step(closure)

  • 一些优化算法,比如共轭梯度和LBFGS;
  • 需要多次评估该函数;
  • 用户必须传递closure参数以允许算法重新计算模型;
  • closure参数应当清理梯度,计算损失,并返回;
for input,target in dataset:def closure():optimizer.zero_grad()output=model(input)loss=loss_fn(output,target)loss.backward()optimizer.step(closure)

基础类

  • 类 torch.optim.Optimizer(params,defaults)
  • 是所有优化器的基础类;
  • 必须以集合的方式指定参数;
  • 集合内的参数具有确定的顺序且同实际运行中的一致;
  • 不满足要求的是set和字典键值迭代器;
  • params(iterable)---一个可迭代的torch.Tensor或者字典,指定需要优化的张量类型;
  • defaults(Dict[str,Any])---具有默认优化选项值得字典(当参数组没有指定时使用);

Optimizer.add_param_group

给优化器参数组增加成员

Optimizer.load_state_dict

加载优化器状态

Optimizer.state_dict

以字典的方式返回优化器状态

Optimizer.step

执行一个优化器步(参数更新)

Optimizer.zero_grad

对所有优化器张量重置梯度

算法 

Adadelta

实现 Adadelta 算法;

Adagrad

实现 Adagrad 算法.

Adam

实现Adam 算法.

AdamW

实现AdamW 算法.

SparseAdam

适合稀疏矩阵的Adam算法的掩码版本

Adamax

实现Adamax 算法(基于无线范数的Adam变体).

ASGD

实现平均随机梯度下降.

LBFGS

实现L-BFGS 算法.

NAdam

实现NAdam 算法.

RAdam

实现RAdam 算法.

RMSprop

实现RMSprop 算法.

Rprop

实现弹性反向传播算法.

SGD

实现随机梯度下降算法(动量选项可选).

  •  很多算法对于优化性能\可读性和通用性具有不同的实现;
  • 如果用户没有特别的指定算法的实现方法,默认情况下针对用户设备尝试最快的实现方法;
  • 有三大类主要的实现:for-loop,foreach(多张量),fused;
  • 最直接的是对参数的for-loop循环实现,并进行大量的计算;
  • for-loop实现通常较foreach实现更慢,foreach实现一次合并参数到多张量中且进行大量计算;
  • foreach实现节省了很多序列化的内核调用;
  • 一些优化器具有更快速的融合实现;
  • 这些优化器融合大量的计算到一个内核中;
  • 我们可以认为foreach实现是水平融合,融合实现是foreach实现水平融合的垂直融合;
  • 一般来讲,三大类实现的性能排序为:fused>foreach>for-loop;
  • 应用时,默认情况下优先采用foreach;
  • 可应用意味着foreach实现是可用的,用户没有指定任何实现的细节参数(比如fused,foreach,differentiable),且所有张量是本地的在CUDA上;
  • 注意,虽然融合的实现较foreach实现应当更快;
  • 但这些实现是比较新的,在任何地方应用之前应该具有更多的实验时间;
  • 欢迎大家尝试;

目前算法的状态

Algorithm

Default

Has foreach?

Has fused?

Adadelta

foreach

yes

no

Adagrad

foreach

yes

no

Adam

foreach

yes

yes

AdamW

foreach

yes

yes

SparseAdam

for-loop

no

no

Adamax

foreach

yes

no

ASGD

foreach

yes

no

LBFGS

for-loop

no

no

NAdam

foreach

yes

no

RAdam

foreach

yes

no

RMSprop

foreach

yes

no

Rprop

foreach

yes

no

SGD

foreach

yes

no

如何调整学习率

  • torch.optim.lr_scheduler 提供一些方法基于训练的代数调整学习率;
  • torch.optim.lr_scheduler.ReduceLROnPlateau 基于一些验证测量允许动态的减少学习率;
  • 学习率调度应当在优化器更新后再应用;
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
scheduler=ExponentialLR(optimizer,gamma=0.9)
for epoch in range(20):for input,target in dataset:optimizer.zero_grad()output=model(input)loss=loss_fn(output,target)loss.backward()optimizer.step()scheduler.step()
  • 很多学习率调度器被称为背靠背调度器(也成为链式调度器);
  • 结果是调度器被一个个的应用到另一个之前的调度器获取到的学习率上;
optimizer=optim.SGD(model.parameters(),lr=0.01,momentum=0.9)
scheduler1=ExponentialLR(optimizer,gamma=0.9)
scheduler2=MultiStepLR(optimizer,milestones=[30,80],gamma=0.1)
for epoch in range(20):for input,target in dataset:optimizer.zero_grad()output=model(input)loss=loss_fn(output,target)loss.backward()optimizer.step()scheduler1.step()scheduler2.step()

注意

  • 在1.1.0版本之前,学习率调度器被期望在优化器更新之前调用;
  • 1.1.0版本改变了这一特性;
  • 如果用户在优化器更新之前使用了学习率调度器;
  • 将忽略学习率调度器中的第一个值;
  • 如果在更新PyTorch1.1.0之后无法重新生成结果;
  • 请检查是否在错误的位置调用了学习率调度器;

lr_scheduler.LambdaLR

设置每一个参数组的学习率为初始学习率乘以一个给定函数.

lr_scheduler.MultiplicativeLR

将每个参数组的学习率乘以指定函数中给出的系数.

lr_scheduler.StepLR

每个步长的epoch以gamma衰减每个参数组的学习率.

lr_scheduler.MultiStepLR

一旦训练的代数达到了一个里程碑以gamma衰减每组参数的学习率;

lr_scheduler.ConstantLR

用一个小的常值系数衰减每个参数组的学习率直到训练的代数达到了预定义的里程碑total_iters;

lr_scheduler.LinearLR

通过线性改变小的乘法系数衰减每个参数组的学习率,直到训练的代数达到了预定义的里程碑:total_iters;

lr_scheduler.ExponentialLR

每一代通过gamma衰减每一个参数组的学习率;

lr_scheduler.PolynomialLR

在给定的total_iters内,使用多项式函数衰减每一个参数组的学习率;

lr_scheduler.CosineAnnealingLR

使用余弦退火时间表设置每一个参数组的学习率,这里\eta_{max}设置为初始的学习率lr,T_{cur}为训练的代数自从在SGDR中重新启动以来.

lr_scheduler.ChainedScheduler

学习率调度器链表.

lr_scheduler.SequentialLR

接收一个在优化过程中被期望顺序调用的调度器列表和提供具体的间隔以反映在一个给定的代数哪一个调度器被猜测调用的里程碑点.

lr_scheduler.ReduceLROnPlateau

当度量停止改进时减少学习率;

lr_scheduler.CyclicLR

根据周期学习率政策(CLR)设置每一个参数组的学习率;

lr_scheduler.OneCycleLR

根据1周期学习率政策设置每一个参数组的学习率.

lr_scheduler.CosineAnnealingWarmRestarts

使用余弦退火时间表设置每一个参数组的学习率,这里\eta_{max}设置为初始的学习率lr,T_{cur}为训练的代数自从在SGDR中重新启动以来.T_i为SGDR中的两次热重启之间的代数.

权值平均(SWA和EMA)

  • torch.optim.swa.utils 实现随机权值平均(SWA)和指数移动平均(EMA);
  • 特别的torch.optim.swa_utils.AveragedModel类实现SWA和EMA模型;
  • torch.optim.swa_utils.SWALR实现SWA学习率调度器;
  • torch.optim.swa_utils.update_bn()是一个工具函数用于在训练结束时更新SWA/EMA批次标准化统计;
  • SWA在Averaging Weights Leads to Wider Optima and Better Generalization.中被提出;
  • EMA是一个通过减少需要更新的权重的数量减少训练时间的广为人知的技术;
  • EMA是一个Polyak averaging的变体,但是在迭代中使用等权重而不是指数权重;

构建平均模型

  • AveragedModel类服务于计算SWA和EMA模型的权重;
  • 创建SWA平均模型:
averaged_model=AveragedModel(model)
  • 通过指定multi_avg_fn参数构建EMA模型:
decay=0.999
averaged_model=AveragedModel(model,multi_avg_fn=get_ema_multi_avg_fn(decay))
  • decay是一个在0和1之间的参数,控制平均化的参数以多快的速度衰减;
  • 如果decay参数没有提供给get_ema_multi_avg_fn,默认值为0.999;
  • get_ema_multi_avg_fn返回一个应用以下EMA权重公式的函数:

W_{t+1}^{EMA}=\alpha W_t^{EMA}+(1-\alpha)W_t^{model}

  • 这里,\alpha是EMA衰减因子,model可以是任何的torch.nn.Moudle对象;
  • averaged_model将保持追踪运行中的模型的参数均值化;
  • 为了更新这些均值,用户应当使用update_parameters()函数在optimizer.step()之后;
averaged_model.update_parameters(model)
  • 对于SWA和EMA,该函数的调用通常紧随optimizer step()函数;
  • 在SWA中,在训练起始的一些次数下被略过;

制定均值策略

  • 默认情况下,torch.optim.swa_utils.AveragedModel计算一个用户提供的参数的运行等效均值;
  • 但是用户可以使用定制的均值函数,使用avg_fn或者multi_avg_fn参数;
  • avg_fn允许定义一个函数操作每一个参数元组(平均的参数,模型参数)且应当返回平均的参数;
  • multi_avg_fn允许定义对元组参数列表(均值的参数列表,模型参数列表)的更有效的操作;
  • 同时,比如使用torch._foreach* 函数,该函数必须在位更新均值化的参数;
  • 下例中ema_model计算一个指数移动平均使用avg_fn参数:
import torch.optim.swa_utils
ema_avg=lambda averaged_model_parameter,model_parameter,
num_averaged:0.9*averaged_model_parameter+0.1*model_parameter
ema_model=torch.optim.swa_utils.AveragedModel(model,avg_fn=ema_avg)
  • 以下的实例ema_model计算一个指数移动平均使用更为高效的multi_avg_fn参数:
ema_model=AveragedModel(model,multi_avg_fn=get_ema_multi_avg_fn(0.9))

swa学习率计划

  • 通常,SWA中学习率被设置为一个大的常量数值;
  • SWALR是一个学习率调度器,将学习率调整到一个固定的值,并保持为常量;
  • 比如以下实例代码创建一个调度器线性调整学习率在5代训练的每个参数组中从初始值到0.05;
swa_scheduler=torch.optim.swa_utils.SWALR(optimizer,anneal_strategy="linear",anneal_epochs=5,swa_lr=0.05)
  • 可以使用余弦退火到一个固定的学习率值而不是使用线性退火通过设置annel_strategy='cos';

关注批量规范化

  • update_bn()是一个有用的函数允许计算SWA模型在一个给定加载器loader中的批规范统计,在训练的末尾;
torch.optim.swa_utils.update_fn(loader,swa_model)
  • update_bn()应用swa_model模型到数据加载器中的每一个单元;
  • 且在模型中每一个批标准化层计算激活数据;
  • update_fn()假设数据加载器loader中的每一批不是张量就是张量列表;
  • 且张量或张量列表中的第一个元素是网络swa_model应当应用到的张量;
  • 如果用户的加载器具有不同的结构;
  • 可以更新swa_model模型的批标准化数据;
  • 通过对数据集的每个元素使用swa_model做前向传递;

SWA综述

  • 以下实例,swa_model是一个累积权重均值的SWA模型;
  • 对模型训练300代,调整学习率计划,在训练代160时手机SWA参数的均值:
import torch
loader,optimizer,model,loss_fn=...
swa_model=torch.opim.swa_utils.AveragedModel(model)
scheduler=torch.optim.lr_scheduler.CosineAnnealingLR(optimizer,T_max=300)
swa_start=160
swa_scheduler=SWALR(optimizer,swa_lr=0.05)for epoch in range(300):for input,target_in_loader:optimizer.zero_grad()loss_fn(model(input),target).backward()optimizer.step()if epoch>swa_start:swa_model.update_parameters(model)swa_scheduler.step()else:scheduler.step()
torch.optim.swa_utils.update_bn(loader,swa_model)
preds=swa_model(test_input)

EMA综述

  • 以下实例中,ema_model是一个EMA模型的实例;
  • 该实例累积权重均值的指数衰减;
  • 衰减率为0.999;
  • 训练模型300代,从训练一开始就收集EMA均值;
import torch
loader,optimizer,model,loss_fn=...
ema_model=torch.opim.swa_utils.AveragedModel(model,multi_avg_fn=torch.optim.swa_utils.get_ema_multi_avg_fn(0.999))
for epoch in range(300):for input,target_in_loader:optimizer.zero_grad()loss_fn(model(input),target).backward()optimizer.step()ema_model.update_parameters(model)
torch.optim.swa_utils.update_bn(loader,ema_model)
preds=swa_model(test_input)

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.bcls.cn/EmMX/9434.shtml

如若内容造成侵权/违法违规/事实不符,请联系编程老四网进行投诉反馈email:xxxxxxxx@qq.com,一经查实,立即删除!

相关文章

SD-WAN技术:优化国内外服务器访问的关键

在全球化的商业环境中,企业经常需要在国内访问国外的服务器。然而,由于地理位置和网络架构的限制,这种跨国访问往往会遇到速度慢、延迟高等问题。SD-WAN(软件定义广域网)技术的兴起,为企业提供了一种新的解…

【MATLAB源码-第148期】基于matlab的BP神经网络2/4ASK,2/4FSK,2/4PSK信号识别仿真。

操作环境: MATLAB 2022a 1、算法描述 1. 调制技术基础 调制技术是通信技术中的基础,它允许数据通过无线电波或其他形式的信号进行传输。调制可以根据信号的振幅、频率或相位的变化来进行,分别对应于ASK、FSK和PSK。 1.1 2ASK与4ASK 振幅…

技术派数据库表自动初始化(学习)

不需要在db中手动创建或者导入相关的schema、data&#xff0c;项目启动自动创建对应的表&#xff0c;并初始化。实现该过程。 Liquibase数据库版本管理 依赖配置 在paicoding-web模块中&#xff0c;pom.xml 文件中添加 <dependency><groupId>org.liquibase</g…

音视频数字化(数字与模拟-电视)

上一篇文章【音视频数字化(数字与模拟-音频广播)】谈了音频的广播,这次我们聊电视系统,这是音频+视频的采集、传输、接收系统,相对比较复杂。 音频系统的广播是将声音转为电信号,再调制后发射出去,利用“共振”原理,收音机接收后解调,将音频信号还原再推动扬声器,我…

力扣链表篇

以下刷题思路来自代码随想录以及官方题解 文章目录 203.移除链表元素707.设计链表206.反转链表24.两两交换链表中的节点19.删除链表的倒数第N个节点面试题 02.07. 链表相交142.环形链表II 203.移除链表元素 给你一个链表的头节点 head 和一个整数 val &#xff0c;请你删除链…

wcf 简单实践 数据绑定 数据更新ui

1.概要 2.代码 2.1 xaml <Window x:Class"WpfApp3.MainWindow"xmlns"http://schemas.microsoft.com/winfx/2006/xaml/presentation"xmlns:x"http://schemas.microsoft.com/winfx/2006/xaml"xmlns:d"http://schemas.microsoft.com/expr…

基于x86架构的OpenHarmony应用生态挑战赛等你来战!

为了更快速推进OpenHarmony在PC领域的进一步落地&#xff0c;加快x86架构下基于OpenHarmony的应用生态的繁荣&#xff0c;为北向应用开发者提供一个更加便捷的开发环境&#xff0c;推动OpenHarmony北向应用开发者的增加&#xff0c;助力OpenHarmony在PC领域实现新的突破&#x…

Linux系统Docker部署Nexus Maven并实现远程访问本地管理界面

文章目录 1. Docker安装Nexus2. 本地访问Nexus3. Linux安装Cpolar4. 配置Nexus界面公网地址5. 远程访问 Nexus界面6. 固定Nexus公网地址7. 固定地址访问Nexus Nexus是一个仓库管理工具&#xff0c;用于管理和组织软件构建过程中的依赖项和构件。它与Maven密切相关&#xff0c;可…

Sui在AIBC Eurasia奖项评选中被评为2024年度最佳区块链解决方案

自2023年主网上线以来&#xff0c;经历了爆炸性增长的Layer1区块链Sui在2月25–27日迪拜举办的第二届AIBC Eurasia活动中获得“2024最佳区块链解决方案奖”&#xff08;Best Real World Application Award 2024&#xff09;。这个盛大的活动以世界级的参与者和往届获奖者而闻名…

一篇关于,搬运机器人的介绍

搬运机器人是一种能够自动运输和搬运物品的机器人。它们通常配备有传感器和导航系统&#xff0c;可以在工厂、仓库、医院或其他场所自主移动&#xff0c;并且可以根据预先设定的路径或指令进行操作。 搬运机器人可以用于搬运重物、物料搬运、装卸货物、仓库管理等任务。它们可以…

Python程序的流程

归纳编程学习的感悟&#xff0c; 记录奋斗路上的点滴&#xff0c; 希望能帮到一样刻苦的你&#xff01; 如有不足欢迎指正&#xff01; 共同学习交流&#xff01; &#x1f30e;欢迎各位→点赞 &#x1f44d; 收藏⭐ 留言​&#x1f4dd; 年轻是我们唯一拥有权利去编制梦想的时…

springboot227旅游管理系统

springboot旅游管理系统设计与实现 摘 要 现代经济快节奏发展以及不断完善升级的信息化技术&#xff0c;让传统数据信息的管理升级为软件存储&#xff0c;归纳&#xff0c;集中处理数据信息的管理方式。本旅游管理系统就是在这样的大环境下诞生&#xff0c;其可以帮助使用者在…

《大模型时代-ChatGPT开启通用人工智能浪潮》精华摘抄

原书很长&#xff0c;有19.3w字&#xff0c;本文尝试浓缩一下其中的精华。 知识点 GPT相关 谷歌发布LaMDA、BERT和PaLM-E&#xff0c;PaLM 2 Facebook的母公司Meta推出LLaMA&#xff0c;并在博客上免费公开LLM&#xff1a;OPT-175B。 在GPT中&#xff0c;P代表经过预训练(…

如何运行github上的项目

为了讲明白这个过程&#xff0c;特意做了一个相对来说比较好读懂的原理图&#xff0c;希望和我一样初学的小伙伴也能很快上手哈&#x1f60a; 在Github中找到想要部署的项目&#xff0c;这里以BartoszJarocki/CV&#xff08;线上简历&#x1f4c4;&#xff09;项目为例 先从头…

前端视角对Rust的浅析

概述 本文将从 Rust 的历史&#xff0c;前端的使用场景和业界使用案例一步步带你走进 Rust的世界。并且通过一些简单的例子&#xff0c;了解 Rust 如何应用到前端&#xff0c;提高前端的生产效率。 Rust简史 2006年&#xff0c;软件开发者Graydon Hoare在Mozilla工作期间&#…

C#与VisionPro联合开发——INI存储和CSV存储

1、INI存储 INI 文件是一种简单的文本文件格式&#xff0c;通常用于在 Windows 环境中存储配置数据。INI 文件格式由一系列节&#xff08;section&#xff09;和键值对&#xff08;key-value pairs&#xff09;组成&#xff0c;用于表示应用程序的配置信息。一个典型的 INI 文…

Flink代码单词统计 ---批处理

flatMap&#xff1a;一对多转换操作&#xff0c;输入句子&#xff0c;输出分词后的每个词groupBy&#xff1a;按Key分组&#xff0c;0代表选择第1列作为Keysum&#xff1a;求和&#xff0c;1代表按照第2列进行累加print&#xff1a;打印最终结果 1.WordCount代码编写 需求&am…

k8s资源管理之声明式管理方式

1 声明式管理方式 1.1 声明式管理方式支持的格式 JSON 格式&#xff1a;主要用于 api 接口之间消息的传递 YAML 格式&#xff1a;用于配置和管理&#xff0c;YAML 是一种简洁的非标记性语言&#xff0c;内容格式人性化&#xff0c;较易读 1.2 YAML 语法格式&#xff1a; ●…

C# Onnx 使用onnxruntime部署实时视频帧插值

目录 介绍 效果 模型信息 项目 代码 下载 C# Onnx 使用onnxruntime部署实时视频帧插值 介绍 github地址&#xff1a;https://github.com/google-research/frame-interpolation FILM: Frame Interpolation for Large Motion, In ECCV 2022. The official Tensorflow 2…

五.AV Foundation 视频播放 - 标题和字幕

引言 本篇博客主要介绍使用AV Foundation加载视频资源的时候&#xff0c;如何获取视频标题&#xff0c;获取字幕并让其显示到播放界面。 设置标题 资源标题的元数据内容&#xff0c;我们需要从资源的commonMetadata中获取&#xff0c;在加载AVPlayerItem的时候我们已经指定了…
推荐文章