

import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt
import numpy as np
import math
from math import sqrt
import os

class AutoCorrelation(nn.Module):
    """
    AutoCorrelation Mechanism with the following two phases:
    (1) period-based dependencies discovery
    (2) time delay aggregation
    This block can replace the self-attention family mechanism seamlessly.

    Args:
        mask_flag: Stored on the instance; not used by this implementation
            (kept so the constructor matches attention-style modules).
        factor: Scaling factor ``c`` in ``top_k = int(c * log(length))``,
            the number of time delays that are aggregated.
        scale: Stored on the instance; not used by this implementation.
        attention_dropout: Probability for ``self.dropout``. The dropout
            module is created but never applied inside this class.
        output_attention: If True, ``forward`` also returns the correlation
            tensor; otherwise ``None`` is returned in its place.
    """

    def __init__(self, mask_flag=True, factor=1, scale=None, attention_dropout=0.1, output_attention=False):
        super(AutoCorrelation, self).__init__()
        self.factor = factor  # scaling factor for the number of selected delays
        self.scale = scale  # not used by this class
        self.mask_flag = mask_flag  # not used by this class
        self.output_attention = output_attention  # whether forward() returns corr
        self.dropout = nn.Dropout(attention_dropout)  # created but never applied

    def _top_k(self, length):
        """Return int(factor * log(length)) clamped to the range [1, length]."""
        # Lower bound: for short series (e.g. length <= 2 with factor=1) the raw
        # formula yields 0, which would make the aggregation return all zeros.
        # Upper bound: torch.topk raises when k exceeds the series length.
        top_k = int(self.factor * math.log(length))
        return max(1, min(top_k, length))

    def time_delay_agg_training(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the training phase.

        The top-k delays are chosen once from the correlation averaged over
        batch, heads and channels, and shared by every sample in the batch.

        Args:
            values: Tensor indexed as (batch, head, channel, length).
            corr: Correlation tensor indexed as (batch, head, channel, length).

        Returns:
            Weighted sum of circularly shifted ``values`` (same shape, float).
        """
        length = values.shape[3]
        top_k = self._top_k(length)
        # Average over heads and channels -> (batch, length).
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        # Delays shared by the whole batch.
        index = torch.topk(torch.mean(mean_value, dim=0), top_k, dim=-1)[1]
        # Per-sample correlation at the shared delays -> (batch, top_k).
        weights = mean_value[:, index]
        tmp_corr = torch.softmax(weights, dim=-1)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            pattern = torch.roll(values, -int(index[i]), -1)
            # Broadcast each sample's weight over (head, channel, length).
            delays_agg = delays_agg + pattern * tmp_corr[:, i].view(-1, 1, 1, 1)
        return delays_agg

    def time_delay_agg_inference(self, values, corr):
        """
        SpeedUp version of Autocorrelation (a batch-normalization style design)
        This is for the inference phase.

        Each sample selects its own top-k delays from the correlation
        averaged over heads and channels.

        Args:
            values: Tensor indexed as (batch, head, channel, length).
            corr: Correlation tensor indexed as (batch, head, channel, length).

        Returns:
            Weighted sum of circularly shifted ``values`` (same shape, float).
        """
        batch, head, channel, length = values.shape
        # Time positions 0..length-1, created on the same device as the input
        # so the gather index never mismatches devices (CPU or GPU).
        init_index = torch.arange(length, device=values.device).view(1, 1, 1, length).repeat(batch, head, channel, 1)
        top_k = self._top_k(length)
        mean_value = torch.mean(torch.mean(corr, dim=1), dim=1)
        weights, delay = torch.topk(mean_value, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)
        # Doubling along time lets index t + delay (< 2 * length) read the
        # circularly shifted series without a modulo.
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[:, i].view(-1, 1, 1, 1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[:, i].view(-1, 1, 1, 1)
        return delays_agg

    def time_delay_agg_full(self, values, corr):
        """
        Standard version of Autocorrelation

        Every (batch, head, channel) series selects its own top-k delays.
        ``corr`` must be broadcast-compatible with ``values`` except along
        the last (time) dimension.

        Args:
            values: Tensor indexed as (batch, head, channel, length).
            corr: Correlation tensor indexed as (batch, head, channel, length).

        Returns:
            Weighted sum of circularly shifted ``values`` (same shape, float).
        """
        batch, head, channel, length = values.shape
        # Built on the input's device instead of a hard-coded .cuda().
        init_index = torch.arange(length, device=values.device).view(1, 1, 1, length).repeat(batch, head, channel, 1)
        top_k = self._top_k(length)
        weights, delay = torch.topk(corr, top_k, dim=-1)
        tmp_corr = torch.softmax(weights, dim=-1)
        tmp_values = values.repeat(1, 1, 1, 2)
        delays_agg = torch.zeros_like(values).float()
        for i in range(top_k):
            tmp_delay = init_index + delay[..., i].unsqueeze(-1)
            pattern = torch.gather(tmp_values, dim=-1, index=tmp_delay)
            delays_agg = delays_agg + pattern * tmp_corr[..., i].unsqueeze(-1)
        return delays_agg

    def forward(self, queries, keys, values, attn_mask):
        """
        Discover period-based dependencies via FFT and aggregate by time delay.

        Args:
            queries: Tensor indexed as (B, L, H, E).
            keys: Tensor indexed as (B, S, H, E).
            values: Tensor indexed as (B, S, H, D).
            attn_mask: Accepted for interface compatibility; not used.

        Returns:
            ``(V, corr)`` where ``V`` is indexed as (B, L, H, D) and ``corr``
            is the correlation indexed as (B, L, H, E) when
            ``output_attention`` is True, otherwise ``None``.
        """
        B, L, H, E = queries.shape
        _, S, _, D = values.shape
        if L > S:
            # Zero-pad keys/values along time to the query length. Each pad
            # is shaped after the tensor it extends, so D may differ from E.
            value_pad = values.new_zeros((B, L - S, H, D))
            key_pad = keys.new_zeros((B, L - S, H, keys.shape[-1]))
            values = torch.cat([values, value_pad], dim=1)
            keys = torch.cat([keys, key_pad], dim=1)
        else:
            values = values[:, :L, :, :]
            keys = keys[:, :L, :, :]

        # period-based dependencies: circular cross-correlation via FFT.
        q_fft = torch.fft.rfft(queries.permute(0, 2, 3, 1).contiguous(), dim=-1)
        k_fft = torch.fft.rfft(keys.permute(0, 2, 3, 1).contiguous(), dim=-1)
        res = q_fft * torch.conj(k_fft)
        # n=L is required: irfft's default output length is even, which
        # drops one delay when L is odd.
        corr = torch.fft.irfft(res, n=L, dim=-1)

        # time delay agg
        permuted_values = values.permute(0, 2, 3, 1).contiguous()
        if self.training:
            V = self.time_delay_agg_training(permuted_values, corr).permute(0, 3, 1, 2)
        else:
            V = self.time_delay_agg_inference(permuted_values, corr).permute(0, 3, 1, 2)

        if self.output_attention:
            return (V.contiguous(), corr.permute(0, 3, 1, 2))
        else:
            return (V.contiguous(), None)

class AutoCorrelationLayer(nn.Module):
    """
    Multi-head wrapper around a correlation module such as ``AutoCorrelation``.

    Projects queries/keys/values into ``n_heads`` heads, calls the inner
    correlation, merges the heads and projects back to ``d_model``.

    Args:
        correlation: Callable invoked as
            ``correlation(queries, keys, values, attn_mask)`` that returns
            ``(out, attn)``, with ``out`` indexed as (B, L, H, d_values).
        d_model: Input/output feature size.
        n_heads: Number of heads H.
        d_keys: Per-head size of queries and keys (default ``d_model // n_heads``).
        d_values: Per-head size of values (default ``d_model // n_heads``).
    """

    def __init__(self, correlation, d_model, n_heads, d_keys=None,
                 d_values=None):
        super(AutoCorrelationLayer, self).__init__()

        d_keys = d_keys or (d_model // n_heads)
        d_values = d_values or (d_model // n_heads)

        self.inner_correlation = correlation
        self.query_projection = nn.Linear(d_model, d_keys * n_heads)
        self.key_projection = nn.Linear(d_model, d_keys * n_heads)
        self.value_projection = nn.Linear(d_model, d_values * n_heads)
        self.out_projection = nn.Linear(d_values * n_heads, d_model)
        self.n_heads = n_heads

    def forward(self, queries, keys, values, attn_mask):
        """
        Args:
            queries: Tensor indexed as (B, L, d_model).
            keys: Tensor indexed as (B, S, d_model).
            values: Tensor indexed as (B, S, d_model); same S as ``keys``.
            attn_mask: Passed through unchanged to the inner correlation.

        Returns:
            ``(output, attn)``: ``output`` indexed as (B, L, d_model) and
            ``attn`` exactly as returned by the inner correlation.
        """
        B, L, _ = queries.shape
        _, S, _ = keys.shape
        H = self.n_heads

        # Split the projected features into H heads.
        queries = self.query_projection(queries).view(B, L, H, -1)
        keys = self.key_projection(keys).view(B, S, H, -1)
        values = self.value_projection(values).view(B, S, H, -1)

        out, attn = self.inner_correlation(
            queries,
            keys,
            values,
            attn_mask
        )
        # reshape rather than view: the injected correlation may return a
        # non-contiguous tensor, on which view() would raise.
        out = out.reshape(B, L, -1)

        return self.out_projection(out), attn
