Training uses the following custom, multi-level nested network:
import torch
import torch.nn as nn


class FC1_bot(nn.Module):
    def __init__(self):
        super(FC1_bot, self).__init__()
        self.embeddings = nn.Sequential(
            nn.Linear(10, 10)
        )

    def forward(self, x):
        emb = self.embeddings(x)
        return emb
class FC1_top(nn.Module):
    def __init__(self, in_dim):
        super(FC1_top, self).__init__()
        self.prediction = nn.Sequential(
            nn.Dropout(p=0.5),
            # in_dim must match the width of the aggregated embedding (10 * num)
            nn.Linear(in_dim, 10)
        )

    def forward(self, x):
        logit = self.prediction(x)
        return logit
class FC1(nn.Module):
    def __init__(self, num):
        super(FC1, self).__init__()
        self.num = num
        # Note: a plain Python list does NOT register these sub-modules with FC1,
        # so model.parameters() will not see the bottom models' weights.
        self.bot = []
        for _ in range(num):
            self.bot.append(FC1_bot())
        # the concatenated embedding has width 10 * num
        self.top = FC1_top(10 * num)
        self.softmax = nn.Softmax(dim=1)

    def forward(self, x):
        # x is a sequence of `num` feature tensors, one per bottom model
        x = list(x)
        emb = []
        for i in range(self.num):
            emb.append(self.bot[i](x[i]))
        agg_emb = self._aggregate(emb)
        logit = self.top(agg_emb)
        pred = self.softmax(logit)
        return emb, pred

    def _aggregate(self, x):
        # Note: x is a list of tensors.
        return torch.cat(x, dim=1)
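Because the bottom models are kept in a plain Python list, they are not registered with the parent module. A minimal sanity check, using only the classes defined above (the variable name m is just a throwaway):

m = FC1(4)
# only registered sub-modules contribute here; the four FC1_bot instances
# stored in the plain list are skipped by parameters()
total = sum(p.numel() for p in m.parameters())
top_only = sum(p.numel() for p in m.top.parameters())
print(total, top_only, total == top_only)  # the two counts are equal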
The training code is as follows:
num = 4
model = FC1(num)
optimizer_entire = torch.optim.SGD(model.parameters(), lr=0.01)

def train(epochs, data, labels):
    # train the entire model with a single optimizer;
    # data is a list of `num` feature tensors (one per bottom model), labels are class indices
    model.train()
    criterion = nn.CrossEntropyLoss()  # note: expects raw scores, while pred below is softmax output
    for epoch in range(epochs):
        _, pred = model(data)
        loss = criterion(pred, labels)
        # zero grad
        optimizer_entire.zero_grad()
        loss.backward()
        # update parameters
        optimizer_entire.step()
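The consequence can be observed directly. The sketch below uses made-up placeholder tensors for data and labels (their shapes are assumptions, not part of the original code) and compares one bottom weight and one top weight after a single call to train:

data = [torch.randn(8, 10) for _ in range(num)]  # hypothetical batch: one tensor per bottom model
labels = torch.randint(0, 10, (8,))              # hypothetical class indices

bot_w_before = model.bot[0].embeddings[0].weight.detach().clone()
top_w_before = model.top.prediction[1].weight.detach().clone()

train(epochs=1, data=data, labels=labels)

print(model.bot[0].embeddings[0].weight.grad is not None)            # True: a gradient was computed
print(torch.equal(bot_w_before, model.bot[0].embeddings[0].weight))  # True: the bottom weight did not move
print(torch.equal(top_w_before, model.top.prediction[1].weight))     # False: the top weight was updated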
An optimizer has to be set up for every model parameter that is actually used; otherwise only the top part is trained. Since the bottom models sit in a plain Python list and are therefore not registered sub-modules, model.parameters() does not include them: they still receive gradients during backward(), but their parameters are never updated. The corrected version below creates an optimizer for each sub-module explicitly:
num = 4
model = FC1(num)
# model.parameters() only yields the registered (top) parameters
optimizer_entire = torch.optim.SGD(model.parameters(), lr=0.01)
optimizer_top = torch.optim.SGD(model.top.parameters(), lr=0.01)
optimizer_bot = []
for i in range(num):
    optimizer_bot.append(torch.optim.SGD(model.bot[i].parameters(), lr=0.01))

def train(epochs, data, labels):
    # switch every sub-module to training mode; model.train() does not reach
    # the unregistered bottom models, so they are handled explicitly
    model.train()
    model.top.train()
    for i in range(num):
        model.bot[i].train()
    criterion = nn.CrossEntropyLoss()
    for epoch in range(epochs):
        _, pred = model(data)
        loss = criterion(pred, labels)
        # zero grad for all optimizers
        optimizer_entire.zero_grad()
        optimizer_top.zero_grad()
        for i in range(num):
            optimizer_bot[i].zero_grad()
        loss.backward()
        # update parameters for all optimizers
        optimizer_entire.step()
        optimizer_top.step()
        for i in range(num):
            optimizer_bot[i].step()
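An alternative sketch, assuming the model definition itself may be changed (FC1Registered is a hypothetical name introduced here): wrapping the bottom models in nn.ModuleList registers them with the parent, so a single optimizer over model.parameters() covers every sub-module, and model.train() reaches them as well.

class FC1Registered(FC1):
    def __init__(self, num):
        super(FC1Registered, self).__init__(num)
        # re-wrap the plain list so the bottom models become registered sub-modules
        self.bot = nn.ModuleList(self.bot)

model = FC1Registered(4)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # now covers the top and all bottom models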