上一篇文章中我们提到,CRNN模型中用于预测特征序列上下文的模块为双向LSTM模块,本篇中就来针对该模块的结构和实现做一些理解。
Bidirectional LSTM模块结构如下图所示:
在Pytorch中,已经集成了LSTM模块,定义如下:
CLASStorch.nn.LSTM(self, input_size, hidden_size, num_layers=1, bias=True, batch_first=False, dropout=0.0, bidirectional=False, proj_size=0, device=None, dtype=None)
?参数说明:
- input_size:输入的feature数;
- hidden_size:隐藏状态h的的feature数;
- num_layers:递归层的数量。如果num_layers=2,意味着将两个LSTM堆叠在一起,第二个LSTM模块的输入为第一个LSTM的输出,由第二个LSTM输出最终结果。
- bias:偏置,默认为True,若设为False,则不使用b_ih和b_hh(这两个参数会在下文说明)
- batch_first:为True时,输入和输出形状为(batch, seq, feature),否则为(seq, batch, feature)。
- dropout:默认为0。若为非零值,则在每个LSTM层的输出上引入dropout层,dropout概率为设置的dropout值。
- bidirectional:默认为False。若为True,则为双向LSTM,在CRNN网络中,我们将该参数设置为True。?
- proj_size:默认为0。若设置为非零值,意味着使用映射大小的size,关于proj_size,pytorch文档中有如下说明:
对于输入序列中的每个元素,每一层需要进行如下计算:
其中,
- xt?为t时刻的输入;
- ht是t时刻的隐藏状态,ht-1为t-1时刻的隐藏状态;
- ct为t时刻的元组状态(cell state);
- it、ft、gt和ot分别为input gate、forget gate、cell gate、和output gate;
- σ为sigmoid 函数;
- ⊙为Hadamard乘积(矩阵点乘)。
对于多层LSTM,输入为上一层的隐含状态,Pytorch文档中对此有较详细的说明:
关于可学习模型参数的说明:
- weight_ih_l[k]:第k层的input-hidden权重参数(W_ii|W_if|W_ig|W_io),k=0时,形状为(4*hidden_size, input_size) ;k>0时,,若proj_size为0,则权重参数形状为(4*hidden_size, num_directions * hidden_size),若proj_size>0,则权重参数形状为(4*hidden_size, num_directions * proj_size)。
- weight_hh_l[k]:第k层的hidden-hidden权重参数(W_hi|W_hf|W_hg|W_ho)。若proj_size=0,权重形状:(4*hidden_size, hidden_size);若proj_size>0,则权重形状为(4*hidden_size, proj_size)。
- bias_ih_l[k]:第k层的input-hidden偏置(b_hi|b_hf|b_hg|b_ho),形状为(4*hidden_size)。
- bias_hh_l[k]:第k层的hidden-hidden偏置(b_hi|b_hf|b_hg|b_ho),形状为(4*hidden_size)。
- weight_hr_l[k]:第k层的projection权重,形状为(proj_size, hidden_size),该参数只有在proj_size>0的时候存在。
- weight_ih_l[k]_reverse:weight_ih_l[k]的反向权重,只在bidirectional=True的时候存在。
- weight_hh_l[k]_reverse:weight_hh_l[k]的反向权重,只在bidirectional=True的时候存在。
- bias_ih_l[k]_reverse:bias_ih_l[k]的反向权重,只在bidirectional=True的时候存在。
- bias_hh_l[k]_reverse:bias_hh_l[k]的反向权重,只在bidirectional=True的时候存在。
- weight_hr_l[k]_reverse:weight_hr_l[k]的反向权重,只在bidirectional=True的时候存在。
参考资料:LSTM — PyTorch 2.1 documentation