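[Paper reproduction] Accuracy when training with the high-level API falls short of the low-level API implementation - Python Paddle
Another author has already reproduced the paper successfully with Paddle, but used the low-level API. When I reproduce the same project with the high-level API, the loss barely converges and the accuracy does not reach what the low-level implementation achieves. The data processing, optimizer and hyperparameters are all identical. This is the original author's implementation; after 5 epochs the accuracy reaches 79.97%.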
import os
import time

import numpy as np
import paddle

# Training
def training():
    global save_every
    model.train()
    if not os.path.exists(save_dir):
        os.makedirs(save_dir)
    best_acc = 0
    losses = []
    for epoch in range(epochs):
        if epoch > 180:
            save_every = 2
        # train for one epoch
        start = time.time()
        # model.train()
        for _iter, (x, y) in enumerate(train_loader):
            y = paddle.reshape(y, (-1, 1))
            loss = loss_fn(model(x), y)
            loss.backward()
            opt.step()
            opt.clear_grad()
            losses.append(np.mean(loss.numpy()))
            if _iter % 10 == 0:
                print('iter:%d loss:%.4f' % (_iter, np.mean(losses)))
        print('Time per epoch:%.2f,loss:%.4f' % (time.time() - start, np.mean(losses)))
        if (epoch + 1) % save_every == 0 or epoch + 1 == epochs:
            # evaluate on validation set
            eval_acc, eval_loss = test()
            print("Validation accuracy/loss: %.2f%%,%.4f" % (eval_acc, eval_loss))
            model.train()
            paddle.save(model.state_dict(), os.path.join(save_dir, 'checkpoint_{}.pdparams'.format(epoch)))
            paddle.save(opt.state_dict(), os.path.join(save_dir, 'checkpoint_{}.pdopt'.format(epoch)))
            if eval_acc > best_acc:
                paddle.save(model.state_dict(), os.path.join(save_dir, 'checkpoint.pdparams'))
                paddle.save(opt.state_dict(), os.path.join(save_dir, 'checkpoint.pdopt'))
            best_acc = max(eval_acc, best_acc)
        # the learning-rate scheduler advances once per epoch here
        scheduler.step()
    paddle.save(model.state_dict(), os.path.join(save_dir, 'model.pdparams'))
    paddle.save(opt.state_dict(), os.path.join(save_dir, 'model.pdopt'))
    print('Best accuracy on validation dataset: %.2f%%' % (best_acc))

def test():
    model.eval()
    accuracies = []
    losses = []
    for (x, y) in val_loader:
        with paddle.no_grad():
            logits = model(x)
            y = paddle.reshape(y, (-1, 1))
            loss = loss_fn(logits, y)
            acc = acc_fn(logits, y)
        accuracies.append(np.mean(acc.numpy()))
        losses.append(np.mean(loss.numpy()))
    return np.mean(accuracies) * 100, np.mean(losses)

import warnings
warnings.filterwarnings("ignore", category=Warning)

model = WideResNet(28, 10, 20, 0.3)
epochs = 400
save_every = 1
loss_fn = paddle.nn.CrossEntropyLoss()
acc_fn = paddle.metric.accuracy
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[60, 120, 160, 200, 240, 260, 280],
    values=[0.05, 0.01, 0.002, 0.0004, 0.0002, 0.0001, 0.00005],
    verbose=True)
opt = paddle.optimizer.Momentum(parameters=model.parameters(), learning_rate=scheduler, momentum=0.9, weight_decay=0.0005)
save_dir = '/home/aistudio/models/cifar10/ResNet_wide'
training()
This is my implementation with the high-level API; after 5 epochs the accuracy is only 48.01%.
import math
import paddle
from wide_resnet import WideResNet

model = paddle.Model(WideResNet(28, 10, 20, 0.3))
loss_fn = paddle.nn.CrossEntropyLoss()
acc_fn = paddle.metric.Accuracy()
scheduler = paddle.optimizer.lr.PiecewiseDecay(
    boundaries=[60, 120, 160, 200, 240, 260, 280],
    values=[0.05, 0.01, 0.002, 0.0004, 0.0002, 0.0001, 0.00005],
    verbose=False)
opt = paddle.optimizer.Momentum(parameters=model.parameters(), learning_rate=scheduler, momentum=0.9, weight_decay=0.0005)
model.prepare(opt, loss_fn, acc_fn)
model.fit(train_loader,            # training dataset
          val_loader,              # evaluation dataset
          epochs=400,              # total number of training epochs
          save_freq=5,
          verbose=1,               # log display format
          save_dir='./chk_points/')
The AI Studio project is here: https://aistudio.baidu.com/aistudio/projectdetail/2304027?shared=1
3 Answers:
Hi! We've received your issue; please be patient while waiting for a response. We will arrange for technicians to answer your question as soon as possible. Please check again that you have provided a clear problem description, reproduction code, environment and version details, and any error messages. You can also look for an answer in the official API documentation, the FAQ, past GitHub issues and the AI community. Have a nice day!
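This happens because, with the high-level API, the lr_scheduler is updated on every step rather than every epoch. After 60 steps the learning rate has already dropped to 0.01, and it keeps falling quickly from there.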
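To make the effect concrete, the following is a minimal pure-Python sketch of a piecewise schedule, using the boundaries and values from the question. The helper `piecewise_lr` is hypothetical and only models the lookup in simplified form; it is not Paddle's implementation. The figure of 391 steps per epoch is an assumption (50,000 CIFAR-10 training images with a batch size of 128, which the thread does not state).

```python
def piecewise_lr(count, boundaries, values):
    """Simplified model: learning rate after `count` scheduler steps."""
    for boundary, value in zip(boundaries, values):
        if count < boundary:
            return value
    return values[-1]


boundaries = [60, 120, 160, 200, 240, 260, 280]
values = [0.05, 0.01, 0.002, 0.0004, 0.0002, 0.0001, 0.00005]

# Assumption: 50,000 training images, batch size 128 -> 391 batches per epoch.
steps_per_epoch = 391

# Low-level loop: the scheduler advances once per epoch,
# so after the first epoch the counter is 1.
lr_per_epoch = piecewise_lr(1, boundaries, values)

# Per-step updates: after the first epoch the counter is already 391,
# which is past the last boundary (280).
lr_per_step = piecewise_lr(steps_per_epoch, boundaries, values)

print("stepped per epoch:", lr_per_epoch)
print("stepped per batch:", lr_per_step)
```

Under these assumptions the per-batch variant has reached the smallest learning rate before the first epoch ends, which would be consistent with the slow convergence described in the question.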
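There are two ways to fix this. One is shown in the figure below; the other is to express the boundaries in steps (epoch_size here is presumably the number of steps per epoch):
boundaries=[60*epoch_size,120*epoch_size,1...]
That works too.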
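The boundary scaling can be sketched as follows. The helper `scale_boundaries` is a hypothetical name, and the dataset size and batch size are assumptions (50,000 CIFAR-10 training images, batch size 128); substitute the real values from your own data loader. With a map-style Paddle `DataLoader`, `len(train_loader)` should give the same number of steps per epoch.

```python
import math


def scale_boundaries(epoch_boundaries, num_samples, batch_size):
    """Convert boundaries given in epochs into boundaries given in steps."""
    # Assumes the last partial batch is kept (drop_last=False).
    steps_per_epoch = math.ceil(num_samples / batch_size)
    return [b * steps_per_epoch for b in epoch_boundaries]


epoch_boundaries = [60, 120, 160, 200, 240, 260, 280]

# Assumed values: 50,000 training samples, batch size 128.
step_boundaries = scale_boundaries(epoch_boundaries, num_samples=50_000, batch_size=128)
print(step_boundaries)
```

The resulting list would then be passed as `boundaries` to `paddle.optimizer.lr.PiecewiseDecay`, leaving `values` unchanged.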
After making the change shown in the figure above, training looks like this:
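The figure mentioned in this answer is not preserved, so the exact change it showed is unknown. One way to get per-epoch scheduler updates with the high-level API, offered here as an assumption rather than as the answerer's exact fix, is to pass a `paddle.callbacks.LRScheduler` callback configured with `by_epoch=True`. The helper name `make_epoch_lr_callback` is hypothetical.

```python
def make_epoch_lr_callback():
    """Build a callback that steps the LR scheduler once per epoch."""
    # Imported lazily so this helper can be defined without Paddle installed.
    import paddle

    # by_step=False / by_epoch=True switches from per-batch to per-epoch updates.
    return paddle.callbacks.LRScheduler(by_step=False, by_epoch=True)


# Usage with the objects from the question (model, train_loader, val_loader):
#
# model.fit(train_loader,
#           val_loader,
#           epochs=400,
#           save_freq=5,
#           verbose=1,
#           save_dir='./chk_points/',
#           callbacks=[make_epoch_lr_callback()])
```

With this callback the original epoch-based `boundaries` can stay as they are, so only one of the two fixes should be applied, not both.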
Thanks