深度学习论文: Attention is All You Need及其PyTorch实现

news/发布时间2024/5/14 6:36:47

深度学习论文: Attention is All You Need及其PyTorch实现
Attention is All You Need
PDF:https://arxiv.org/abs/1706.03762.pdf
PyTorch: https://github.com/shanglianlm0525/PyTorch-Networks

大多数先进的神经序列转换模型采用编码器-解码器结构,其中编码器将输入符号序列转换为连续表示,解码器则基于这些表示逐个生成输出符号序列。在每个步骤中,模型采用自回归方式,将先前生成的符号作为额外输入来生成下一个符号。
在这里插入图片描述

1 Encoder and Decoder Stacks

1-1 Encoder 编码器

编码器采用N=6个结构相同的层堆叠而成。每一层包含两个子层,第一个是多头自注意力机制,第二个则是简单的位置全连接前馈网络。为提升性能,在每个子层之间引入了残差连接,并实施了层归一化。具体来说,每个子层的输出经过LayerNorm(x + Sublayer(x))计算得出,其中Sublayer(x)代表子层自身的功能实现。为了确保残差连接的顺畅进行,编码器中的所有子层以及嵌入层均生成维度为dmodel=512的输出。

def clones(module, N):"Produce N identical layers."return nn.ModuleList([copy.deepcopy(module) for _ in range(N)])class LayerNorm(nn.Module):"Construct a layernorm module (See citation for details)."def __init__(self, features, eps=1e-6):super(LayerNorm, self).__init__()self.a_2 = nn.Parameter(torch.ones(features))self.b_2 = nn.Parameter(torch.zeros(features))self.eps = epsdef forward(self, x):mean = x.mean(-1, keepdim=True)std = x.std(-1, keepdim=True)return self.a_2 * (x - mean) / (std + self.eps) + self.b_2class Encoder(nn.Module):"Core encoder is a stack of N layers"def __init__(self, layer, N):super(Encoder, self).__init__()self.layers = clones(layer, N)self.norm = LayerNorm(layer.size)def forward(self, x, mask):"Pass the input (and mask) through each layer in turn."for layer in self.layers:x = layer(x, mask)return self.norm(x)class SublayerConnection(nn.Module):"""A residual connection followed by a layer norm.Note for code simplicity the norm is first as opposed to last."""def __init__(self, size, dropout):super(SublayerConnection, self).__init__()self.norm = LayerNorm(size)self.dropout = nn.Dropout(dropout)def forward(self, x, sublayer):"Apply residual connection to any sublayer with the same size."return x + self.dropout(sublayer(self.norm(x)))class EncoderLayer(nn.Module):"Encoder is made up of self-attn and feed forward (defined below)"def __init__(self, size, self_attn, feed_forward, dropout):super(EncoderLayer, self).__init__()self.self_attn = self_attnself.feed_forward = feed_forwardself.sublayer = clones(SublayerConnection(size, dropout), 2)self.size = sizedef forward(self, x, mask):"Follow Figure 1 (left) for connections."x = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, mask))return self.sublayer[1](x, self.feed_forward)

1-2 Decoder解码器

解码器同样由N=6个结构相同的层堆叠而成。除了编码器层中的两个子层,解码器还额外引入了一个第三子层,专门用于对编码器堆叠的输出执行多头注意力机制。与编码器设计相似,解码器中的每个子层也采用残差连接与层归一化。此外,还对解码器堆叠中的自注意力子层进行了优化,确保在生成输出时,位置i的预测仅依赖于位置小于i的已知输出,这通过掩码和输出嵌入偏移一个位置的方式实现。

class Decoder(nn.Module):"Generic N layer decoder with masking."def __init__(self, layer, N):super(Decoder, self).__init__()self.layers = clones(layer, N)self.norm = LayerNorm(layer.size)def forward(self, x, memory, src_mask, tgt_mask):for layer in self.layers:x = layer(x, memory, src_mask, tgt_mask)return self.norm(x)class DecoderLayer(nn.Module):"Decoder is made of self-attn, src-attn, and feed forward (defined below)"def __init__(self, size, self_attn, src_attn, feed_forward, dropout):super(DecoderLayer, self).__init__()self.size = sizeself.self_attn = self_attnself.src_attn = src_attnself.feed_forward = feed_forwardself.sublayer = clones(SublayerConnection(size, dropout), 3)def forward(self, x, memory, src_mask, tgt_mask):"Follow Figure 1 (right) for connections."m = memoryx = self.sublayer[0](x, lambda x: self.self_attn(x, x, x, tgt_mask))x = self.sublayer[1](x, lambda x: self.src_attn(x, m, m, src_mask))return self.sublayer[2](x, self.feed_forward)

2 Attention

在这里插入图片描述
Transformer巧妙地运用了三种不同的多头注意力机制:

  • 在“编码器-解码器注意力”层中,查询源于前一解码器层,而键和值则汲取自编码器的输出。这种设计使得解码器中的任意位置都能关注输入序列的每一个细节,完美模拟了序列到序列模型中典型的编码器-解码器注意力机制。

  • 编码器内部则嵌入了自注意力层。在这个自注意力层中,键、值和查询均来自编码器前一层的输出,确保编码器中的每个位置都能全面捕捉到前一层中的信息。

  • 解码器同样采用了自注意力层,允许解码器中的每个位置关注到包括该位置在内的解码器内部所有位置的信息。为了确保自回归属性的保持,我们精心设计了一个机制:在缩放点积注意力内部,将对应非法连接的softmax输入值设为−∞,从而有效屏蔽了向左的信息流动。

2-1 Scaled Dot-Product Attention

Scaled Dot-Product Attention 缩放点积注意力,输入包括维度为 d k d_{k} dk的查询和键,以及维度为 d v d_{v} dv的值。计算查询与所有键的点积,每个都除以 d k \sqrt{d_{k} } dk ,然后应用softmax函数以获得值的权重。
在这里插入图片描述

def attention(query, key, value, mask=None, dropout=None):"Compute 'Scaled Dot Product Attention'"d_k = query.size(-1)scores = torch.matmul(query, key.transpose(-2, -1)) \/ math.sqrt(d_k)if mask is not None:scores = scores.masked_fill(mask == 0, -1e9)p_attn = F.softmax(scores, dim = -1)if dropout is not None:p_attn = dropout(p_attn)return torch.matmul(p_attn, value), p_attn

2-2 Multi-Head Attention

与其仅使用具有dmodel维度的键、值和查询来执行单一的注意力函数,将查询、键和值分别通过不同的、学习得到的线性投影,各自线性投影 h h h次,分别映射到 d k d_{k} dk d k d_{k} dk d v d_{v} dv维度后效果更好。随后在这些投影后的查询、键和值上并行地执行注意力函数,从而得到 d v d_{v} dv维度的输出值。最后,我们将这些输出值进行拼接,并再次通过投影得到最终的值。这种方法能够更充分地利用输入信息,提升模型的性能。
在这里插入图片描述
在这项研究中,我们采用了h=8个并行的注意力层,也称之为注意力头。对于每一个注意力头,我们都设定其维度为 d k d_{k} dk = d v d_{v} dv = d m o d e l / h d_{model} / h dmodel/h=64。由于每个注意力头的维度有所降低,因此其整体计算成本与使用完整维度的单头注意力相近。

class MultiHeadedAttention(nn.Module):def __init__(self, h, d_model, dropout=0.1):"Take in model size and number of heads."super(MultiHeadedAttention, self).__init__()assert d_model % h == 0# We assume d_v always equals d_kself.d_k = d_model // hself.h = hself.linears = clones(nn.Linear(d_model, d_model), 4)self.attn = Noneself.dropout = nn.Dropout(p=dropout)def forward(self, query, key, value, mask=None):"Implements Figure 2"if mask is not None:# Same mask applied to all h heads.mask = mask.unsqueeze(1)nbatches = query.size(0)# 1) Do all the linear projections in batch from d_model => h x d_k query, key, value = \[l(x).view(nbatches, -1, self.h, self.d_k).transpose(1, 2)for l, x in zip(self.linears, (query, key, value))]# 2) Apply attention on all the projected vectors in batch. x, self.attn = attention(query, key, value, mask=mask, dropout=self.dropout)# 3) "Concat" using a view and apply a final linear. x = x.transpose(1, 2).contiguous() \.view(nbatches, -1, self.h * self.d_k)return self.linears[-1](x)

3 Position-wise Feed-Forward Networks

除了注意力子层,编码器和解码器的每一层还包括一个全连接前馈网络,它独立且相同地应用于每个位置,包含两个线性变换和一个ReLU激活函数。
在这里插入图片描述
这里的输入和输出的维度是 d m o d e l = 512 d_{model}=512 dmodel=512,而内层的维度是 d f f = 2048 d_{ff}=2048 dff=2048

class PositionwiseFeedForward(nn.Module):"Implements FFN equation."def __init__(self, d_model, d_ff, dropout=0.1):super(PositionwiseFeedForward, self).__init__()self.w_1 = nn.Linear(d_model, d_ff)self.w_2 = nn.Linear(d_ff, d_model)self.dropout = nn.Dropout(dropout)def forward(self, x):return self.w_2(self.dropout(F.relu(self.w_1(x))))

4 Embeddings and Softmax

与其他序列转换模型相似,这里使用学习嵌入将输入标记和输出标记转换为 d m o d e l d_{model} dmodel维向量。解码器输出通过线性变换和softmax函数转换为预测概率。在我们的模型中,嵌入层和softmax前的线性变换共享权重矩阵,并在嵌入层中将权重乘以 d m o d e l \sqrt{d_{model}} dmodel

class Embeddings(nn.Module):def __init__(self, d_model, vocab):super(Embeddings, self).__init__()self.lut = nn.Embedding(vocab, d_model)self.d_model = d_modeldef forward(self, x):return self.lut(x) * math.sqrt(self.d_model)

5 Positional Encoding

在编码器和解码器堆栈底部的输入嵌入中添加“位置编码”。位置编码的维度 d m o d e l d_{model} dmodel与嵌入的维度相同,以便可以将两者相加。位置编码有很多选择,包括学习和固定的。

在这里,选用不同频率的正弦和余弦函数:
在这里插入图片描述
其中pos是位置,i是维度。也就是说,位置编码的每个维度都对应于一个正弦波。波长从2π到10000·2π形成几何级数。选择这个函数是因为它可以使模型容易学习通过相对位置进行关注,因为对于任何固定的偏移量k, P E p o s + k PE_{pos+k} PEpos+k都可以表示为 P E p o s PE_{pos} PEpos的线性函数。

此外,在编码器和解码器堆栈中都对嵌入和位置编码的和使用了Dropout。对于基础模型,使用的Dropout ratio 为 P d r o p P_{drop} Pdrop=0.1。

class PositionalEncoding(nn.Module):"Implement the PE function."def __init__(self, d_model, dropout, max_len=5000):super(PositionalEncoding, self).__init__()self.dropout = nn.Dropout(p=dropout)# Compute the positional encodings once in log space.pe = torch.zeros(max_len, d_model)position = torch.arange(0, max_len).unsqueeze(1)div_term = torch.exp(torch.arange(0, d_model, 2) *-(math.log(10000.0) / d_model))pe[:, 0::2] = torch.sin(position * div_term)pe[:, 1::2] = torch.cos(position * div_term)pe = pe.unsqueeze(0)self.register_buffer('pe', pe)def forward(self, x):x = x + Variable(self.pe[:, :x.size(1)], requires_grad=False)return self.dropout(x)

参考资料:
1 Attention is All You Need
2 The Illustrated Transformer
3 The Annotated Transformer

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.bcls.cn/hZjZ/29654.shtml

如若内容造成侵权/违法违规/事实不符,请联系编程老四网进行投诉反馈email:xxxxxxxx@qq.com,一经查实,立即删除!

相关文章

没学数模电可以玩单片机吗?

我们首先来看一下数电模电在单片机中的应用。数电知识在单片机中主要解决各种数字信号的处理、运算,如数制转换、数据运算等。模电知识在单片机中主要解决各种模拟信号的处理问题,如采集光照强度、声音的分贝、温度等模拟信号。而数电、模电的相互转换就…

RSTP环路避免实验(华为)

思科设备参考:RSTP环路避免实验(思科) 一,技术简介 RSTP (Rapid Spanning Tree Protocol) 是从STP发展而来 • RSTP标准版本为IEEE802.1w • RSTP具备STP的所有功能,可以兼容STP运行 • RSTP和STP有所不同 减少了…

SpringMVC第一个helloword项目

文章目录 前言一、SpringMVC是什么?二、使用步骤1.引入库2.创建控制层3.创建springmvc.xml4.配置web.xml文件5.编写视图页面 总结 前言 提示:这里可以添加本文要记录的大概内容: SpringMVC 提示:以下是本篇文章正文内容&#xf…

rust标准库std::env环境相关的常量

在env这个库中,有一些环境相关的常量,这些常量在std::env::consts这个模块下面,通过这个依赖库可以获取到当前程序所运行的环境和运行的目录地址等信息。 env 常量 std::env下面一些系统相关的常量: ARCH DLL_EXTENSION DLL_P…

2024/03/28(C++·day4)

一、思维导图 二、练习题 1、写出三种构造函数&#xff0c;算术运算符、关系运算符、逻辑运算符重载尝试实现自增、自减运算符的重载 #include <iostream>using namespace std;// 构造函数示例 class MyClass { private:int data; public:// 默认构造函数MyClass() {da…

Notepad++:格式化json字符串(带转义)

目录 一、效果呈现 二、去除json字符串转义 三、格式化json字符串 一、效果呈现 格式化前 带字符串转义&#xff0c;带unicode编码字符 格式化后 二、去除json字符串转义 方法&#xff1a;采用Notepad的普通替换 第一&#xff1a;\"替换为" 第二&#xff1a;\\…

SAP BTP云上一个JVM与DB Connection纠缠的案例

前言 最近在CF (Cloud Foundry) 云平台上遇到一个比较经典的案例。因为牵扯到JVM &#xff08;app进程&#xff09;与数据库连接两大块&#xff0c;稍有不慎&#xff0c;很容易引起不快。 在云环境下&#xff0c;有时候相互扯皮的事蛮多。如果是DB的问题&#xff0c;就会找DB…

进入消息传递的魔法之门:ActiveMQ原理与使用详解

嗨&#xff0c;亲爱的童鞋们&#xff01;欢迎来到这个充满魔法的世界&#xff0c;今天我们将一同揭开消息中间件ActiveMQ的神秘面纱。如果你是一个对编程稍有兴趣&#xff0c;但又对消息中间件一知半解的小白&#xff0c;不要害怕&#xff0c;我将用最简单、最友好的语言为你呈…

HarmonyOS 应用开发之UIAbility组件生命周期

概述 当用户打开、切换和返回到对应应用时&#xff0c;应用中的UIAbility实例会在其生命周期的不同状态之间转换。UIAbility类提供了一系列回调&#xff0c;通过这些回调可以知道当前UIAbility实例的某个状态发生改变&#xff0c;会经过UIAbility实例的创建和销毁&#xff0c;…

【数据结构刷题专题】—— 二叉树

二叉树 二叉树刷题框架 二叉树的定义&#xff1a; struct TreeNode {int val;TreeNode* left;TreeNode* right;TreeNode(int x) : val(x), left(NULL), right(NULL); };1 二叉树的遍历方式 【1】前序遍历 class Solution { public:void traversal(TreeNode* node, vector&…

win11蓝牙图标点击变灰,修复过程

问题发现 有一天突然心血来潮想着连接蓝牙音响放歌来听,才发现win11系统右下角菜单里的蓝牙开关有问题。 打开蓝牙设置,可以正常直接连上并播放声音,点击右下角菜单里的蓝牙开关按钮后,蓝牙设备也能正常断开,但是按钮直接变深灰色,无法再点击打开。 重启电脑,蓝牙开关显…

AMD本月发布的成本优化型Spartan UltraScale+ FPGA系列

随着 FPGA 在更多应用中的使用&#xff0c;AMD 推出了最新的成本、功耗与性能平衡的系列产品。为了扩展其可编程逻辑产品组合&#xff0c;AMD最近推出了最新的成本优化型 Spartan FPGA 系列。随着 FPGA 应用于越来越多的产品和设备&#xff0c;设计人员可能经常发现自己正在寻找…

Electron+Vue构建项目时出错:Error: Exit code: ENOENT. spawn /usr/bin/python ENOENT

问题&#xff1a;ElectronVue构建项目时出错&#xff1a;Error: Exit code: ENOENT. spawn /usr/bin/python ENOENT URL:https://github.com/nklayman/vue-cli-plugin-electron-builder/issues/1701 一&#xff0c;构建时node版本要低 同时构建命令如下&#xff1a; "el…

Excel·VBA数组分组问题

看到一个帖子《excel吧-数据分组问题》&#xff0c;对一组数据分成4组&#xff0c;使每组的和值相近 目录 代码思路1&#xff0c;分组形式、可分组数代码1代码2代码2举例 2&#xff0c;数组所有分组形式举例 这个问题可以转化为2步&#xff1a;第1步&#xff0c;获取一组数据…

数据分析能力模型分析与展示

具体内容&#xff1a; 专业素质 专业素质-01 数据处理 能力定义•能通过各种数据处理工具及数据处理方法&#xff0c;对内外部海量数据进行清洗和运用&#xff0c;提供统一数据标准&#xff0c;为业务分析做好数据支持工作。 L1•掌握一…

AJAX-项目优化(目录、基地址、token、请求拦截器)

目录管理 基地址存储 在utils/request.js配置axios请求基地址 作用&#xff1a;提取公共前缀地址&#xff0c;配置后axios请求时都会baseURLurl 填写API的公共前缀后&#xff0c;将js文件导入到html文件中 <script src"../../utils/request.js"></script&…

OpenHarmony实战开发-Web组件的使用

介绍 本篇Codelab使用ArkTS语言实现一个简单的免登录过程&#xff0c;向大家介绍基本的cookie管理操作。主要包含以下功能&#xff1a; 获取指定url对应的cookie的值。设置cookie。清除所有cookie。免登录访问账户中心。 原理说明 本应用旨在说明Web组件中cookie的管理操作。…

5个适用于 Windows/PC 的水印去除软件(视频/图像)

水印是文本、徽标、印记、图像或签名&#xff0c;通常叠加在视频、其他图像或具有较高透明度的 PDF 文档上。当您免费使用某些产品&#xff08;例如视频编辑器&#xff09;时&#xff0c;最终输出通常带有代表您使用的编辑器的水印。您可能需要出于您的目的从此类媒体文件中删除…

目标检测+车道线识别+追踪

一种方法&#xff1a; 车道线检测-canny边缘检测-霍夫变换 一、什么是霍夫变换 霍夫变换&#xff08;Hough Transform&#xff09;是一种在图像处理和计算机视觉中广泛使用的特征检测技术&#xff0c;主要用于识别图像中的几何形状&#xff0c;尤其是直线、圆和椭圆等常见形状…

ChatGPTGPT4科研应用、数据分析与机器学习、论文高效写作、AI绘图技术教程

原文链接&#xff1a;ChatGPTGPT4科研应用、数据分析与机器学习、论文高效写作、AI绘图技术教程https://mp.weixin.qq.com/s?__bizMzUzNTczMDMxMg&mid2247598798&idx2&sn014f5ae90306a3b1e8fd87ab58561411&chksmfa820329cdf58a3f72799a43016b223057fd1bd02284…
推荐文章