import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.fft
from layers.Embed import DataEmbedding
from layers.Conv_Blocks import Inception_Block_V1
#定义一个用于执行傅里叶变换的函数,并根据傅里叶变换的振幅找到数据的周期
def FFT_for_Period(x, k=2):
? ? # [B, T, C]
? ? xf = torch.fft.rfft(x, dim=1) ? #执行实数的快速傅里叶变换
? ? # find period by amplitudes
? ??
? ? frequency_list = abs(xf).mean(0).mean(-1) ? ? #计算变换后振幅的均值
? ??
? ? frequency_list[0] = 0 ? ?#忽略直流分量
? ? _, top_list = torch.topk(frequency_list, k) ? #找到top k个频率成分
? ? top_list = top_list.detach().cpu().numpy() ? ?#从gpu转移cpu并转换为numpy数组
? ? period = x.shape[1] // top_list ? ? ? ? ? ? ? #计算周期
? ? return period, abs(xf).mean(-1)[:, top_list] ?#返回周期和这些周期对应的平均振幅
#定义神经网络模块,继承nn父类,用于处理时间序列中的时间周期变化
class TimesBlock(nn.Module):
? ? def __init__(self, configs):
? ? ? ? #调用父类初始化函数
? ? ? ? super(TimesBlock, self).__init__()
? ? ? ? #配置参数
? ? ? ? self.seq_len = configs.seq_len ? ?#序列长度
? ? ? ? self.pred_len = configs.pred_len ?#预测长度
? ? ? ? self.k = configs.top_k ? ? ? ? ? ?#顶部元素数量
? ? ? ? # parameter-efficient design
? ? ? ? #构建一个卷积网络作为模型的一部分
? ? ? ? self.conv = nn.Sequential(
? ? ? ? ? ? Inception_Block_V1(configs.d_model, configs.d_ff,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?num_kernels=configs.num_kernels),
? ? ? ? ? ? nn.GELU(),
? ? ? ? ? ? Inception_Block_V1(configs.d_ff, configs.d_model,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?num_kernels=configs.num_kernels)
? ? ? ? )
? ? #前向传播函数,处理时间序列x,并返回经过处理的序列
? ? def forward(self, x):
? ? ? ? B, T, N = x.size()
? ? ? ? period_list, period_weight = FFT_for_Period(x, self.k)
? ? ? ? res = []
? ? ? ? for i in range(self.k):
? ? ? ? ? ? period = period_list[i]
? ? ? ? ? ? # padding
? ? ? ? ? ? if (self.seq_len + self.pred_len) % period != 0:
? ? ? ? ? ? ? ? length = (
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?((self.seq_len + self.pred_len) // period) + 1) * period
? ? ? ? ? ? ? ? padding = torch.zeros([x.shape[0], (length - (self.seq_len + self.pred_len)), x.shape[2]]).to(x.device)
? ? ? ? ? ? ? ? out = torch.cat([x, padding], dim=1)
? ? ? ? ? ? else:
? ? ? ? ? ? ? ? length = (self.seq_len + self.pred_len)
? ? ? ? ? ? ? ? out = x
? ? ? ? ? ? # reshape
? ? ? ? ? ? out = out.reshape(B, length // period, period,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? N).permute(0, 3, 1, 2).contiguous()
? ? ? ? ? ? # 2D conv: from 1d Variation to 2d Variation
? ? ? ? ? ? out = self.conv(out)
? ? ? ? ? ? # reshape back
? ? ? ? ? ? out = out.permute(0, 2, 3, 1).reshape(B, -1, N)
? ? ? ? ? ? res.append(out[:, :(self.seq_len + self.pred_len), :])
? ? ? ? res = torch.stack(res, dim=-1)
? ? ? ? # adaptive aggregation
? ? ? ? period_weight = F.softmax(period_weight, dim=1)
? ? ? ? period_weight = period_weight.unsqueeze(
? ? ? ? ? ? 1).unsqueeze(1).repeat(1, T, N, 1)
? ? ? ? res = torch.sum(res * period_weight, -1)
? ? ? ? # residual connection
? ? ? ? res = res + x
? ? ? ? return res
#model类
class Model(nn.Module):
? ? """
? ? Paper link: https://openreview.net/pdf?id=ju_Uqw384Oq
? ? """
? ? #构造函数,初始化数据
? ? def __init__(self, configs):
? ? ? ? super(Model, self).__init__()
? ? ? ? self.configs = configs
? ? ? ? self.task_name = configs.task_name
? ? ? ? self.seq_len = configs.seq_len
? ? ? ? self.label_len = configs.label_len
? ? ? ? self.pred_len = configs.pred_len
? ? ? ? self.model = nn.ModuleList([TimesBlock(configs)
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? for _ in range(configs.e_layers)])
? ? ? ? self.enc_embedding = DataEmbedding(configs.enc_in, configs.d_model, configs.embed, configs.freq,
? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ? ?configs.dropout)
? ? ? ? self.layer = configs.e_layers
? ? ? ? self.layer_norm = nn.LayerNorm(configs.d_model)
? ? ? ? if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
? ? ? ? ? ? self.predict_linear = nn.Linear(
? ? ? ? ? ? ? ? self.seq_len, self.pred_len + self.seq_len)
? ? ? ? ? ? self.projection = nn.Linear(
? ? ? ? ? ? ? ? configs.d_model, configs.c_out, bias=True)
? ? ? ? if self.task_name == 'imputation' or self.task_name == 'anomaly_detection':
? ? ? ? ? ? self.projection = nn.Linear(
? ? ? ? ? ? ? ? configs.d_model, configs.c_out, bias=True)
? ? ? ? if self.task_name == 'classification':
? ? ? ? ? ? self.act = F.gelu
? ? ? ? ? ? self.dropout = nn.Dropout(configs.dropout)
? ? ? ? ? ? self.projection = nn.Linear(
? ? ? ? ? ? ? ? configs.d_model * configs.seq_len, configs.num_class)
? ? #预测函数,用于长短期预测任务
? ? def forecast(self, x_enc, x_mark_enc, x_dec, x_mark_dec):
? ? ? ? # Normalization from Non-stationary Transformer
? ? ? ? means = x_enc.mean(1, keepdim=True).detach()
? ? ? ? x_enc = x_enc - means
? ? ? ? stdev = torch.sqrt(
? ? ? ? ? ? torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
? ? ? ? x_enc /= stdev
? ? ? ? # embedding
? ? ? ? enc_out = self.enc_embedding(x_enc, x_mark_enc) ?# [B,T,C]
? ? ? ? enc_out = self.predict_linear(enc_out.permute(0, 2, 1)).permute(
? ? ? ? ? ? 0, 2, 1) ?# align temporal dimension
? ? ? ? # TimesNet
? ? ? ? for i in range(self.layer):
? ? ? ? ? ? enc_out = self.layer_norm(self.model[i](enc_out))
? ? ? ? # porject back
? ? ? ? dec_out = self.projection(enc_out)
? ? ? ? # De-Normalization from Non-stationary Transformer
? ? ? ? dec_out = dec_out * \
? ? ? ? ? ? ? ? ? (stdev[:, 0, :].unsqueeze(1).repeat(
? ? ? ? ? ? ? ? ? ? ? 1, self.pred_len + self.seq_len, 1))
? ? ? ? dec_out = dec_out + \
? ? ? ? ? ? ? ? ? (means[:, 0, :].unsqueeze(1).repeat(
? ? ? ? ? ? ? ? ? ? ? 1, self.pred_len + self.seq_len, 1))
? ? ? ? return dec_out
? ? #插补函数,用于数据插补
? ? def imputation(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask):
? ? ? ? # Normalization from Non-stationary Transformer
? ? ? ? means = torch.sum(x_enc, dim=1) / torch.sum(mask == 1, dim=1)
? ? ? ? means = means.unsqueeze(1).detach()
? ? ? ? x_enc = x_enc - means
? ? ? ? x_enc = x_enc.masked_fill(mask == 0, 0)
? ? ? ? stdev = torch.sqrt(torch.sum(x_enc * x_enc, dim=1) /
? ? ? ? ? ? ? ? ? ? ? ? ? ?torch.sum(mask == 1, dim=1) + 1e-5)
? ? ? ? stdev = stdev.unsqueeze(1).detach()
? ? ? ? x_enc /= stdev
? ? ? ? # embedding
? ? ? ? enc_out = self.enc_embedding(x_enc, x_mark_enc) ?# [B,T,C]
? ? ? ? # TimesNet
? ? ? ? for i in range(self.layer):
? ? ? ? ? ? enc_out = self.layer_norm(self.model[i](enc_out))
? ? ? ? # porject back
? ? ? ? dec_out = self.projection(enc_out)
? ? ? ? # De-Normalization from Non-stationary Transformer
? ? ? ? dec_out = dec_out * \
? ? ? ? ? ? ? ? ? (stdev[:, 0, :].unsqueeze(1).repeat(
? ? ? ? ? ? ? ? ? ? ? 1, self.pred_len + self.seq_len, 1))
? ? ? ? dec_out = dec_out + \
? ? ? ? ? ? ? ? ? (means[:, 0, :].unsqueeze(1).repeat(
? ? ? ? ? ? ? ? ? ? ? 1, self.pred_len + self.seq_len, 1))
? ? ? ? return dec_out
? ? #异常检测函数,用于异常检测任务
? ? def anomaly_detection(self, x_enc):
? ? ? ? # Normalization from Non-stationary Transformer
? ? ? ? means = x_enc.mean(1, keepdim=True).detach()
? ? ? ? x_enc = x_enc - means
? ? ? ? stdev = torch.sqrt(
? ? ? ? ? ? torch.var(x_enc, dim=1, keepdim=True, unbiased=False) + 1e-5)
? ? ? ? x_enc /= stdev
? ? ? ? # embedding
? ? ? ? enc_out = self.enc_embedding(x_enc, None) ?# [B,T,C]
? ? ? ? # TimesNet
? ? ? ? for i in range(self.layer):
? ? ? ? ? ? enc_out = self.layer_norm(self.model[i](enc_out))
? ? ? ? # porject back
? ? ? ? dec_out = self.projection(enc_out)
? ? ? ? # De-Normalization from Non-stationary Transformer
? ? ? ? dec_out = dec_out * \
? ? ? ? ? ? ? ? ? (stdev[:, 0, :].unsqueeze(1).repeat(
? ? ? ? ? ? ? ? ? ? ? 1, self.pred_len + self.seq_len, 1))
? ? ? ? dec_out = dec_out + \
? ? ? ? ? ? ? ? ? (means[:, 0, :].unsqueeze(1).repeat(
? ? ? ? ? ? ? ? ? ? ? 1, self.pred_len + self.seq_len, 1))
? ? ? ? return dec_out
? ? #分类函数,用于数据分类
? ? def classification(self, x_enc, x_mark_enc):
? ? ? ? # embedding
? ? ? ? enc_out = self.enc_embedding(x_enc, None) ?# [B,T,C]
? ? ? ? # TimesNet
? ? ? ? for i in range(self.layer):
? ? ? ? ? ? enc_out = self.layer_norm(self.model[i](enc_out))
? ? ? ? # Output
? ? ? ? # the output transformer encoder/decoder embeddings don't include non-linearity
? ? ? ? output = self.act(enc_out)
? ? ? ? output = self.dropout(output)
? ? ? ? # zero-out padding embeddings
? ? ? ? output = output * x_mark_enc.unsqueeze(-1)
? ? ? ? # (batch_size, seq_length * d_model)
? ? ? ? output = output.reshape(output.shape[0], -1)
? ? ? ? output = self.projection(output) ?# (batch_size, num_classes)
? ? ? ? return output
? ? #根据任务类型选择相应的处理函数
? ? def forward(self, x_enc, x_mark_enc, x_dec, x_mark_dec, mask=None):
? ? ? ? if self.task_name == 'long_term_forecast' or self.task_name == 'short_term_forecast':
? ? ? ? ? ? dec_out = self.forecast(x_enc, x_mark_enc, x_dec, x_mark_dec)
? ? ? ? ? ? return dec_out[:, -self.pred_len:, :] ?# [B, L, D]
? ? ? ? if self.task_name == 'imputation':
? ? ? ? ? ? dec_out = self.imputation(
? ? ? ? ? ? ? ? x_enc, x_mark_enc, x_dec, x_mark_dec, mask)
? ? ? ? ? ? return dec_out ?# [B, L, D]
? ? ? ? if self.task_name == 'anomaly_detection':
? ? ? ? ? ? dec_out = self.anomaly_detection(x_enc)
? ? ? ? ? ? return dec_out ?# [B, L, D]
? ? ? ? if self.task_name == 'classification':
? ? ? ? ? ? dec_out = self.classification(x_enc, x_mark_enc)
? ? ? ? ? ? return dec_out ?# [B, N]
? ? ? ? return None
?