pytorch RNN

发布时间:2023年12月20日

RNN

from torch import nn
import torch
# # 表示feature_len=100(如每个单词用100维向量表示), hidden_len=10(隐藏单元的尺寸)
# rnn = nn.RNN(100, 10,1,batch_first=True)
# # odict_keys(['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'])
# print(rnn._parameters.keys())
#
# print(rnn.weight_ih_l0.shape)  # torch.Size([10, 100])
# print(rnn.weight_hh_l0.shape)  # torch.Size([10, 10])
# print(rnn.bias_ih_l0.shape)  # torch.Size([10])
# print(rnn.bias_hh_l0.shape)  # torch.Size([10])
#
# x=torch.randn(3,10,100)  #第一维是batch,第二维是seq_len序列长度,第三维是词向量的维度
# out,h=rnn(x,torch.zeros(1,3,10))
# print(out.shape)
# print(h.shape)

# 表示feature_len=100(如每个单词用100维向量表示), hidden_len=10(隐藏单元的尺寸)
rnn = nn.RNN(100, 10,2)
# odict_keys(['weight_ih_l0', 'weight_hh_l0', 'bias_ih_l0', 'bias_hh_l0'])
print(rnn._parameters.keys())

# print(rnn.weight_ih_l0.shape)  # torch.Size
文章来源:https://blog.csdn.net/qq_40107571/article/details/135105833
本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。