成都创新互联网站制作重庆分公司

Pytorch中accuracy和loss的计算知识点总结-创新互联

这几天关于accuracy和loss的计算有一些疑惑,原来是自己还没有弄清楚。

创新互联公司专注于企业营销型网站、网站重做改版、广西网站定制设计、自适应品牌网站建设、H5开发商城网站定制开发、集团公司官网建设、成都外贸网站制作、高端网站制作、响应式网页设计等建站业务,价格优惠性价比高,为广西等各大城市提供网站开发制作服务。

给出实例

def train(train_loader, model, criteon, optimizer, epoch):
  train_loss = 0
  train_acc = 0
  num_correct= 0
  for step, (x,y) in enumerate(train_loader):

    # x: [b, 3, 224, 224], y: [b]
    x, y = x.to(device), y.to(device)

    model.train()
    logits = model(x)
    loss = criteon(logits, y)

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    train_loss += float(loss.item())
    train_losses.append(train_loss)
    pred = logits.argmax(dim=1)
    num_correct += torch.eq(pred, y).sum().float().item()
  logger.info("Train Epoch: {}\t Loss: {:.6f}\t Acc: {:.6f}".format(epoch,train_loss/len(train_loader),num_correct/len(train_loader.dataset)))
  return num_correct/len(train_loader.dataset), train_loss/len(train_loader)

文章标题:Pytorch中accuracy和loss的计算知识点总结-创新互联
网页URL:http://cxhlcq.com/article/doecsd.html

其他资讯

在线咨询

微信咨询

电话咨询

028-86922220(工作日)

18980820575(7×24)

提交需求

返回顶部