Basic GAN

Published: December 27, 2023

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import numpy as np
import matplotlib.pyplot as plt
import torchvision 
from torchvision import transforms


transform=transforms.Compose([
    transforms.ToTensor(),              # scale pixel values to [0, 1]; tensor shape (c, h, w)
    transforms.Normalize(0.5, 0.5)      # map [0, 1] to [-1, 1] to match the generator's Tanh output
])
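A quick aside (my addition, not from the original post): ToTensor scales pixel values to [0, 1], and Normalize(0.5, 0.5) then maps them to [-1, 1], which matches the Tanh output range of the generator defined below. A minimal sanity check, assuming a dummy all-black PIL image:

from PIL import Image
dummy = Image.fromarray(np.zeros((28, 28), dtype=np.uint8))   # hypothetical all-black 28x28 image
x_check = transform(dummy)
print(x_check.min().item(), x_check.max().item())   # both -1.0, since (0 - 0.5) / 0.5 = -1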

train_ds=torchvision.datasets.MNIST('data',
                                   train=True,
                                   transform=transform,
                                   download=True)

dataloader=torch.utils.data.DataLoader(train_ds,batch_size=512,shuffle=True)
imgs,_=next(iter(dataloader))
print(imgs.shape)   # torch.Size([512, 1, 28, 28])

# Generator: input is a length-100 noise vector drawn from a normal distribution; output is a tensor reshaped to a 28x28 image
class Generator(nn.Module):
    def __init__(self):
        super(Generator,self).__init__()
        self.main=nn.Sequential(
                                nn.Linear(100,256),
                                nn.ReLU(),
                                nn.Linear(256,512),
                                nn.ReLU(),
                                nn.Linear(512,28*28),
                                nn.Tanh()
        )
    def forward(self,x):     # x is the noise input of length 100
        img=self.main(x)
        img=img.view(-1,28,28)
        return img
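
A quick shape check (my addition, illustrative only): pushing a small batch of noise through an untrained Generator should produce a batch of 28x28 images.

# Sanity check on CPU with an untrained generator
_g = Generator()
_z = torch.randn(4, 100)
print(_g(_z).shape)   # torch.Size([4, 28, 28])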
    
# Discriminator: takes an image-shaped tensor and outputs a single probability in (0, 1), produced by a Sigmoid at the output.
# PyTorch's BCELoss computes the binary cross-entropy for this two-class setup; LeakyReLU is the recommended activation for the discriminator.

class Discriminator(nn.Module):
    def __init__(self):
        super(Discriminator,self).__init__()
        self.main=nn.Sequential(
                                nn.Linear(28*28,512),
                                nn.LeakyReLU(),
                                nn.Linear(512,256),
                                nn.LeakyReLU(),
                                nn.Linear(256,1),
                                nn.Sigmoid()
        )
        
    def forward(self,x):
        x=x.view(-1,28*28)
        x=self.main(x)
        return x
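
Similarly, a hedged sanity check for the Discriminator (my addition): it flattens any (batch, 28, 28) input and returns one probability per image.

# Sanity check on CPU with an untrained discriminator
_d = Discriminator()
_fake = torch.randn(4, 28, 28)
print(_d(_fake).shape)   # torch.Size([4, 1]); values lie in (0, 1) because of the Sigmoid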

device = 'cuda' if torch.cuda.is_available() else 'cpu'

gen=Generator().to(device)
dis=Discriminator().to(device)
d_optim=torch.optim.Adam(dis.parameters(),lr=0.0001)
g_optim=torch.optim.Adam(gen.parameters(),lr=0.0001)

loss_fn=torch.nn.BCELoss()
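
To preview how BCELoss is used in the training loop below (a minimal sketch with made-up probabilities, not from the original post): real images are paired with all-ones targets and generated images with all-zeros targets.

# Illustration of the two target patterns used later in the loop
probs = torch.tensor([[0.9], [0.2]])               # pretend discriminator outputs
print(loss_fn(probs, torch.ones_like(probs)))      # loss if both samples were labeled real
print(loss_fn(probs, torch.zeros_like(probs)))     # loss if both samples were labeled fake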

def gen_img_plot(model,test_input):
    # Plot a 4x4 grid of generated images from a fixed batch of noise
    prediction = np.squeeze(model(test_input).detach().cpu().numpy())
    fig = plt.figure(figsize=(4,4))
    for i in range(16):
        plt.subplot(4,4,i+1)
        plt.imshow((prediction[i]+1)/2)   # map Tanh output from [-1, 1] back to [0, 1] for display
        plt.axis('off')
    plt.show()

test_input = torch.randn(16,100,device=device)

D_loss = []
G_loss = []


# Training loop
for epoch in range(20):
    d_epoch_loss = 0
    g_epoch_loss = 0
    count=len(dataloader)
    
    for step, (img, _) in enumerate(dataloader):
        img = img.to(device)
        size=img.size(0)
        random_noisy=torch.randn(size,100,device=device)
        
        d_optim.zero_grad()
        real_output=dis(img)   # feed real images to the discriminator
        d_real_loss = loss_fn(real_output,
                              torch.ones_like(real_output))    # discriminator loss on real images
        d_real_loss.backward()
        
        gen_img = gen(random_noisy)  # generate fake images
        fake_output = dis(gen_img.detach())   # feed fakes to the discriminator; detach() blocks gradients from reaching the generator
        d_fake_loss = loss_fn(fake_output,
                              torch.zeros_like(fake_output))   # discriminator loss on fake images
        
        d_fake_loss.backward()
        
        d_loss = d_real_loss + d_fake_loss
        d_optim.step()
        
        # Discriminator loss and update above; generator loss and update below
        g_optim.zero_grad()
        gen_img = gen(random_noisy)
        fake_output = dis(gen_img)
        g_loss = loss_fn(fake_output,
                         torch.ones_like(fake_output))   # generator loss: push the discriminator to output "real"
        g_loss.backward()
        g_optim.step()
        
        with torch.no_grad():
            d_epoch_loss += d_loss
            g_epoch_loss += g_loss

    # End of the inner loop: average the losses over the epoch and log progress
    with torch.no_grad():
        d_epoch_loss /= count
        g_epoch_loss /= count
        D_loss.append(d_epoch_loss.item())
        G_loss.append(g_epoch_loss.item())
        print('Epoch:', epoch)
        gen_img_plot(gen, test_input)

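After training, the recorded per-epoch losses can be visualized; a minimal plotting sketch (my addition, assuming D_loss and G_loss hold plain floats as appended above):

plt.figure()
plt.plot(D_loss, label='Discriminator loss')
plt.plot(G_loss, label='Generator loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()
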
Source: https://blog.csdn.net/m0_56294205/article/details/135236350