在 PyTorch 中，同一种操作往往有多种写法。下面以张量加法为例，分别演示 `+` 运算符、`torch.add()` 函数以及原地（in-place）的 `Tensor.add_()` 方法。
官方文档相关信息地址：https://pytorch.org/docs/stable/generated/torch.add.html
def add_one(rows=4, cols=2):
    """Demonstrate tensor addition with the ``+`` operator.

    Builds two ``rows x cols`` tensors filled by ``torch.rand`` (uniform
    values in [0, 1)) and prints their element-wise sum.

    Args:
        rows: Number of rows in each random tensor (default 4, as in the
            original demo).
        cols: Number of columns in each random tensor (default 2).

    Returns:
        torch.Tensor: The new tensor holding the sum (it is also printed).
    """
    # The snippet has no module-level ``import torch``; import locally so the
    # demo runs on its own.
    import torch

    random_a = torch.rand(rows, cols)
    random_b = torch.rand(rows, cols)
    result = random_a + random_b
    print(result)
    return result
if __name__ == "__main__":
    # Runs only when this file is executed as a script, not on import.
    add_one()
Output:
Connected to pydev debugger (build 193.7288.30)
tensor([[0.4214, 0.9930],
[1.8052, 1.3819],
[0.3334, 1.6448],
[0.2293, 1.3217]])
def add_two(rows=7, cols=5):
    """Demonstrate tensor addition with the functional ``torch.add`` form.

    Builds two ``rows x cols`` tensors filled by ``torch.rand`` (uniform
    values in [0, 1)) and prints their element-wise sum, computed with
    ``torch.add`` instead of the ``+`` operator.

    Args:
        rows: Number of rows in each random tensor (default 7, as in the
            original demo).
        cols: Number of columns in each random tensor (default 5).

    Returns:
        torch.Tensor: The new tensor holding the sum (it is also printed).
    """
    # The snippet has no module-level ``import torch``; import locally so the
    # demo runs on its own.
    import torch

    random_a = torch.rand(rows, cols)
    random_b = torch.rand(rows, cols)
    result = torch.add(random_a, random_b)
    print(result)
    return result
if __name__ == "__main__":
    # Runs only when this file is executed as a script, not on import.
    add_two()
Output:
Connected to pydev debugger (build 193.7288.30)
tensor([[0.9120, 0.6793, 1.8955, 1.0702, 0.5505],
[0.2064, 0.1667, 1.3587, 1.6728, 1.2151],
[1.3669, 0.5354, 0.8769, 1.1060, 1.7966],
[0.6240, 1.2662, 1.1661, 0.9901, 1.1229],
[1.1226, 1.2311, 0.8754, 0.4553, 1.7378],
[0.7544, 0.3557, 0.7631, 0.4962, 1.1400],
[0.5023, 0.3433, 0.9610, 0.9195, 1.8071]])
def add_three(rows=6, cols=3):
    """Demonstrate in-place tensor addition with ``Tensor.add_``.

    ``add_`` is the in-place counterpart of ``torch.add``. Rather than
    allocating a new tensor, it writes the element-wise sum into
    ``random_a`` and returns that same tensor.

    Args:
        rows: Number of rows in each random tensor (default 6, as in the
            original demo).
        cols: Number of columns in each random tensor (default 3).

    Returns:
        torch.Tensor: ``random_a`` after the in-place addition (also printed).
    """
    # The snippet has no module-level ``import torch``; import locally so the
    # demo runs on its own.
    import torch

    random_a = torch.rand(rows, cols)
    random_b = torch.rand(rows, cols)
    # Trailing underscore = in-place: random_a itself is overwritten.
    result = random_a.add_(random_b)
    print(result)
    return result
if __name__ == "__main__":
    # Runs only when this file is executed as a script, not on import.
    add_three()
Output:
Connected to pydev debugger (build 193.7288.30)
tensor([[0.6409, 0.8297, 1.8872],
[1.4139, 0.8349, 1.4194],
[1.0500, 1.1814, 0.8922],
[0.2682, 0.8406, 1.3403],
[0.8939, 1.1223, 0.8175],
[1.0493, 0.4294, 0.9734]])