Flash attention implementation (xformers):
import torch
from xformers import ops as xops
import time
bs = 32
seq_len = 512
n_head = 16
head_dim = 64
query_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
key_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
value_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
# xformers expects (batch, seq_len, n_head, head_dim) layout, so transpose from (bs, n_head, seq_len, head_dim)
flash_query_states = query_states.transpose(1, 2)
flash_key_states = key_states.transpose(1, 2)
flash_value_states = value_states.transpose(1, 2)
torch.cuda.synchronize()
start_time = time.time()
# memory-efficient (flash) attention from xformers; LowerTriangularMask applies a causal mask
flash_attn_output = xops.memory_efficient_attention(
    flash_query_states, flash_key_states, flash_value_states,
    attn_bias=xops.LowerTriangularMask()
)
torch.cuda.synchronize()  # CUDA kernels run asynchronously; wait for them before reading the clock
print(f'flash attention time: {(time.time()-start_time)*1000} ms')
print(torch.cuda.max_memory_allocated("cuda:0")/1024**2)  # peak allocated: ~192 MB
print("=============================")
print(torch.cuda.memory_allocated("cuda:0")/1024**2)  # currently allocated: ~128 MB
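As a quick sanity check, the xformers result can be compared against PyTorch's built-in scaled_dot_product_attention (a minimal sketch, assuming PyTorch 2.0+; SDPA takes (batch, heads, seq, dim) inputs, so the xformers output is transposed back before comparing):
import torch.nn.functional as F
# reference result in the (bs, n_head, seq_len, head_dim) layout, with the same causal mask
ref = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=True)
print(torch.allclose(flash_attn_output.transpose(1, 2), ref, rtol=2e-3, atol=2e-3))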
Standard attention implementation:
import torch
import torch.nn as nn
import math
import time
bs = 32
seq_len = 512
n_head = 16
head_dim = 64
query_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
key_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
value_states = torch.randn((bs, n_head, seq_len, head_dim), dtype=torch.float16).to("cuda:0")
# causal mask: lower-triangular boolean matrix, broadcast over (bs, n_head)
attention_mask = torch.tril(torch.ones((seq_len, seq_len), dtype=torch.bool)).view(1, 1, seq_len, seq_len)
attention_mask = attention_mask.to(dtype=torch.float16).cuda()  # fp16 compatibility
# convert to an additive bias: 0 where attending is allowed, the most negative fp16 value where masked
attention_mask = (1.0 - attention_mask) * torch.finfo(torch.float16).min
def standard_attention(query_states, key_states, value_states, attention_mask):
    attn_weights = torch.matmul(query_states, key_states.transpose(2, 3)) / math.sqrt(head_dim)
    attn_weights = attn_weights + attention_mask
    # upcast attention to fp32 for a numerically stable softmax, then cast back
    attn_weights = nn.functional.softmax(attn_weights, dim=-1, dtype=torch.float32).to(query_states.dtype)
    attn_output = torch.matmul(attn_weights, value_states)
    # transpose to (bs, seq_len, n_head, head_dim) to match the xformers output layout
    attn_output = attn_output.transpose(1, 2)
    return attn_output
torch.cuda.synchronize()
start_time = time.time()
attn_output = standard_attention(query_states, key_states, value_states, attention_mask)
torch.cuda.synchronize()
print(f'standard attention time: {(time.time()-start_time)*1000} ms')
#print(torch.allclose(attn_output, flash_attn_output, rtol=2e-3, atol=2e-3))  # checks element-wise closeness; floating-point results that are mathematically equal can still differ slightly
print(torch.cuda.max_memory_allocated("cuda:0")/1024**2)  # peak allocated: ~1128 MB
print("=============================")
print(torch.cuda.memory_allocated("cuda:0")/1024**2)  # currently allocated: ~136 MB
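The peak-memory gap is easy to account for: standard attention materializes the full seq_len x seq_len score matrix for every head, plus an fp32 copy during the softmax upcast. A back-of-the-envelope check with the shapes above (plain arithmetic, not measured):
n_elems = 32 * 16 * 512 * 512           # scores S = Q @ K^T: (bs, n_head, seq_len, seq_len)
print(n_elems * 2 / 1024**2)            # fp16 scores: 256 MB
print(n_elems * 4 / 1024**2)            # fp32 softmax copy: 512 MB
print(n_elems * 2 / 1024**2)            # fp16 cast back: 256 MB
These transient buffers, on top of the ~128 MB of q/k/v/output tensors, roughly match the measured ~1128 MB peak, while the flash kernel never materializes S and stays near 192 MB.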
Flash attention 1 implementation:
import torch
torch.manual_seed(456)
N, d = 16, 8
Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))
# reference result: standard PyTorch softmax attention
expected_softmax = torch.softmax(Q_mat @ K_mat.T, dim=1)
expected_attention = expected_softmax @ V_mat
# tile (block) sizes; in the real kernel these are derived from the SRAM size (fixed here for the demo)
Br = 4
Bc = d
# Algorithm step 2: allocate the output O in HBM, initialized to zeros
O = torch.zeros((N, d))
# Algorithm step 2: l holds the running softmax denominator per row, kept in HBM
l = torch.zeros((N, 1))
# Algorithm step 2: m holds the running row max, kept in HBM
m = torch.full((N, 1), -torch.inf)
# Algorithm step 5: outer loop over K/V blocks
for block_start_Bc in range(0, N, Bc):
    block_end_Bc = block_start_Bc + Bc
    # Algorithm line 6: load one block of K and V from HBM into SRAM
    Kj = K_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
    Vj = V_mat[block_start_Bc:block_end_Bc, :]  # shape Bc x d
    # Algorithm line 7: inner loop over Q blocks
    for block_start_Br in range(0, N, Br):
        block_end_Br = block_start_Br + Br
        # Algorithm line 8: load the running statistics and the Q block from HBM into SRAM
        mi = m[block_start_Br:block_end_Br, :]  # shape Br x 1
        li = l[block_start_Br:block_end_Br, :]  # shape Br x 1
        Oi = O[block_start_Br:block_end_Br, :]  # shape Br x d
        Qi = Q_mat[block_start_Br:block_end_Br, :]  # shape Br x d
        # Algorithm line 9: scores for this tile
        Sij = Qi @ Kj.T  # shape Br x Bc
        # Algorithm line 10: row max of the current tile
        mij_hat = torch.max(Sij, dim=1).values[:, None]
        # Algorithm line 10: unnormalized softmax numerator and its row sum for this tile
        pij_hat = torch.exp(Sij - mij_hat)
        lij_hat = torch.sum(pij_hat, dim=1)[:, None]
        # Algorithm line 11: new running row max (current tile vs. all previous blocks)
        mi_new = torch.max(torch.column_stack([mi, mij_hat]), dim=1).values[:, None]
        # Algorithm line 11: running softmax denominator with the online correction; this is the same
        # rescaling identity as online safe softmax, extended from scalars to whole rows at once
        li_new = torch.exp(mi - mi_new) * li + torch.exp(mij_hat - mi_new) * lij_hat
        # Algorithm line 12: rescale the previous partial output and add this tile's contribution
        Oi = (li * torch.exp(mi - mi_new) * Oi / li_new) + (torch.exp(mij_hat - mi_new) * pij_hat / li_new) @ Vj
        # Algorithm line 13: write the updated statistics back to HBM
        m[block_start_Br:block_end_Br, :] = mi_new  # running row max
        l[block_start_Br:block_end_Br, :] = li_new  # running softmax denominator
        # Algorithm line 12: write Oi back to HBM
        O[block_start_Br:block_end_Br, :] = Oi
print(torch.allclose(O, expected_attention))
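The correction factors exp(mi - mi_new) and exp(mij_hat - mi_new) are the online (streaming) safe-softmax identity applied row-wise. A minimal scalar sketch of the same idea (new demo code, processing one score at a time):
import math
import torch
torch.manual_seed(0)
x = torch.randn(16).tolist()
m_run, l_run = -math.inf, 0.0
for xi in x:
    m_new = max(m_run, xi)
    # rescale the old denominator to the new max, then add the new term
    l_run = math.exp(m_run - m_new) * l_run + math.exp(xi - m_new)
    m_run = m_new
# matches the one-shot safe-softmax denominator
print(abs(l_run - sum(math.exp(xi - max(x)) for xi in x)) < 1e-6)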
Flash attention 2 implementation:
import torch
torch.manual_seed(456)
N, d = 16, 8
Q_mat = torch.rand((N, d))
K_mat = torch.rand((N, d))
V_mat = torch.rand((N, d))
# reference result: standard PyTorch softmax attention
expected_softmax = torch.softmax(Q_mat @ K_mat.T, dim=1)
expected_attention = expected_softmax @ V_mat
# tile (block) sizes; in the real kernel these are derived from the SRAM size (fixed here for the demo)
Br = 4
Bc = d
O = torch.zeros((N, d))
# Algorithm step 3: outer loop over Q blocks (the loop order is swapped relative to flash attention 1)
for block_start_Br in range(0, N, Br):
    block_end_Br = block_start_Br + Br
    # Algorithm step 4: load one block of Q from HBM into SRAM
    Qi = Q_mat[block_start_Br:block_end_Br, :]
    # Algorithm step 5: initialize the per-block accumulators
    Oi = torch.zeros((Br, d))             # shape Br x d
    li = torch.zeros((Br, 1))             # shape Br x 1
    mi = torch.full((Br, 1), -torch.inf)  # shape Br x 1
    # Algorithm step 6: inner loop over K/V blocks
    for block_start_Bc in range(0, N, Bc):
        block_end_Bc = block_start_Bc + Bc
        # Algorithm step 7: load Kj, Vj into SRAM
        Kj = K_mat[block_start_Bc:block_end_Bc, :]
        Vj = V_mat[block_start_Bc:block_end_Bc, :]
        # Algorithm step 8: scores for this tile
        Sij = Qi @ Kj.T
        # Algorithm step 9: update the running row max, numerator, and denominator
        mi_new = torch.max(torch.column_stack([mi, torch.max(Sij, dim=1).values[:, None]]), dim=1).values[:, None]
        Pij_hat = torch.exp(Sij - mi_new)
        li = torch.exp(mi - mi_new) * li + torch.sum(Pij_hat, dim=1)[:, None]
        # Algorithm step 10: rescale the previous partial output and accumulate; the division by li
        # is deferred until after the inner loop, saving work compared with flash attention 1
        Oi = Oi * torch.exp(mi - mi_new) + Pij_hat @ Vj
        mi = mi_new
    # Step 12: normalize once per Q block
    Oi = Oi / li
    # Step 14: write Oi back to HBM
    O[block_start_Br:block_end_Br, :] = Oi
print(torch.allclose(O, expected_attention))
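In symbols, the inner-loop update above is the recurrence (a direct restatement of the code, with $m_i$, $\ell_i$, $O_i$ the running row max, softmax denominator, and unnormalized output of one Q block):
\[ m_i^{\mathrm{new}} = \max\big(m_i,\ \mathrm{rowmax}(S_{ij})\big), \qquad \tilde{P}_{ij} = \exp\big(S_{ij} - m_i^{\mathrm{new}}\big), \]
\[ \ell_i \leftarrow e^{m_i - m_i^{\mathrm{new}}}\,\ell_i + \mathrm{rowsum}(\tilde{P}_{ij}), \qquad O_i \leftarrow e^{m_i - m_i^{\mathrm{new}}}\,O_i + \tilde{P}_{ij} V_j, \]
with a single normalization $O_i \leftarrow O_i / \ell_i$ after the last K/V block. Flash attention 1 instead rescales by $\ell$ inside every inner iteration, which is the extra work flash attention 2 removes.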
PyTorch SDPA backend comparison:
import torch
import torch.nn.functional as F
from rich import print
from torch.backends.cuda import sdp_kernel  # context manager for selecting the SDPA kernel backend
from enum import IntEnum
import torch.utils.benchmark as benchmark
device = "cuda" if torch.cuda.is_available() else "cpu"  # the flash / memory-efficient kernels require a GPU
# hyperparameters
batch_size = 64
max_sequence_len = 256
num_heads = 32
embed_dimension = 32
dtype = torch.float16
# random q, k, v tensors
query = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
key = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
value = torch.rand(batch_size, num_heads, max_sequence_len, embed_dimension, device=device, dtype=dtype)
# timing helper based on torch.utils.benchmark:
def torch_timer(f, *args, **kwargs):
    t0 = benchmark.Timer(
        stmt="f(*args, **kwargs)", globals={"args": args, "kwargs": kwargs, "f": f}
    )
    return t0.blocked_autorange().mean * 1e6  # mean runtime in microseconds
# this enum also exists in torch.backends.cuda; it is spelled out here to make backend_map easier to follow
class SDPBackend(IntEnum):
    r"""
    Enum class for the scaled dot product attention backends.
    """
    ERROR = -1
    MATH = 0
    FLASH_ATTENTION = 1
    EFFICIENT_ATTENTION = 2
# each backend is selected through the sdp_kernel context manager;
# backend_map holds the enable flags for the three explicit choices
backend_map = {
    SDPBackend.MATH: {  # pure PyTorch (math) implementation
        "enable_math": True,
        "enable_flash": False,
        "enable_mem_efficient": False},
    SDPBackend.FLASH_ATTENTION: {  # FlashAttention kernel
        "enable_math": False,
        "enable_flash": True,
        "enable_mem_efficient": False},
    SDPBackend.EFFICIENT_ATTENTION: {  # memory-efficient attention kernel
        "enable_math": False,
        "enable_flash": False,
        "enable_mem_efficient": True}
}
# baseline: let PyTorch pick the backend automatically
print(f"default dispatch time: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# default dispatch time: 558.831 microseconds
# force each kernel explicitly via the context manager
with sdp_kernel(**backend_map[SDPBackend.MATH]):
    print(f"math backend time: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
# math backend time: 1013.422 microseconds
with sdp_kernel(**backend_map[SDPBackend.FLASH_ATTENTION]):
    try:
        print(f"flash attention time: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("FlashAttention is not supported")
# flash attention time: 557.343 microseconds
with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
    try:
        print(f"memory efficient time: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")
    except RuntimeError:
        print("EfficientAttention is not supported")
# memory efficient time: 428.007 microseconds
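Note that torch.backends.cuda.sdp_kernel has since been deprecated; on recent PyTorch releases (roughly 2.3+, an assumption worth checking against your installed version) the same experiment can use torch.nn.attention.sdpa_kernel, which takes the backend enum directly:
from torch.nn.attention import SDPBackend, sdpa_kernel
# same benchmark, written against the newer context-manager API
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
    print(f"flash attention time: {torch_timer(F.scaled_dot_product_attention, query, key, value):.3f} microseconds")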