from __future__ import print_function, absolute_import
import time
import torch
from torch.autograd import Variable
from .evaluation_metrics import accuracy
from .loss import OIMLoss, TripletLoss
from .utils.meters import AverageMeter
class BaseTrainer(object):
    """Skeleton of a one-epoch training loop.

    Subclasses supply ``_parse_data`` (turn a raw batch from the data loader
    into ``(inputs, targets)``) and ``_forward`` (run the model and return
    ``(loss, precision)``). ``train`` drives the loop, performs the optimizer
    step for every batch and prints running statistics.
    """

    def __init__(self, model, criterion):
        """Store the model to optimize and the criterion used to score it."""
        super(BaseTrainer, self).__init__()
        self.model = model
        self.criterion = criterion

    @staticmethod
    def _to_scalar(value):
        """Return ``value`` as a plain Python number.

        Objects exposing ``.item()`` (e.g. 0-dim or one-element tensors) are
        converted through it, because indexing a 0-dim tensor with ``[0]``
        raises ``IndexError`` on current PyTorch. Objects without ``.item()``
        but with a ``.data`` attribute fall back to the legacy ``.data[0]``
        access. Anything else is returned unchanged.
        """
        if hasattr(value, 'item'):
            return value.item()
        if hasattr(value, 'data'):
            return value.data[0]
        return value

    def train(self, epoch, data_loader, optimizer, print_freq=1):
        """Run one pass over ``data_loader``, updating the model per batch.

        Args:
            epoch: Epoch identifier; only used in the progress message.
            data_loader: Iterable of raw batches. It must support ``len()``,
                and each batch is handed to ``_parse_data``.
            optimizer: Object providing ``zero_grad()`` and ``step()``.
            print_freq: Print a progress line every ``print_freq`` batches.
                A value of ``0`` disables progress output.
        """
        self.model.train()

        batch_time = AverageMeter()
        data_time = AverageMeter()
        losses = AverageMeter()
        precisions = AverageMeter()

        end = time.time()
        for i, inputs in enumerate(data_loader):
            # Time spent waiting for the loader to deliver this batch.
            data_time.update(time.time() - end)

            inputs, targets = self._parse_data(inputs)
            loss, prec1 = self._forward(inputs, targets)

            # Meters are weighted by the number of targets in the batch.
            batch_size = targets.size(0)
            losses.update(self._to_scalar(loss), batch_size)
            precisions.update(prec1, batch_size)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            # Full wall-clock time of the iteration, data loading included.
            batch_time.update(time.time() - end)
            end = time.time()

            # ``print_freq == 0`` would otherwise divide by zero below.
            if print_freq and (i + 1) % print_freq == 0:
                print('Epoch: [{}][{}/{}]\t'
                      'Time {:.3f} ({:.3f})\t'
                      'Data {:.3f} ({:.3f})\t'
                      'Loss {:.3f} ({:.3f})\t'
                      'Prec {:.2%} ({:.2%})\t'
                      .format(epoch, i + 1, len(data_loader),
                              batch_time.val, batch_time.avg,
                              data_time.val, data_time.avg,
                              losses.val, losses.avg,
                              precisions.val, precisions.avg))

    def _parse_data(self, inputs):
        """Convert a raw batch into ``(inputs, targets)``; subclass hook."""
        raise NotImplementedError

    def _forward(self, inputs, targets):
        """Return ``(loss, precision)`` for one batch; subclass hook."""
        raise NotImplementedError
class Trainer(BaseTrainer):
    """Trainer for batches shaped ``(imgs, _, pids, _)``.

    Supports three criteria: ``torch.nn.CrossEntropyLoss``, ``OIMLoss`` and
    ``TripletLoss``. Any other criterion makes ``_forward`` raise
    ``ValueError``.
    """

    def _parse_data(self, inputs):
        """Split a raw batch into model inputs and target identities.

        Only the first and third elements of the 4-tuple are used. The images
        are wrapped without any device transfer (NOTE(review): presumably the
        model wrapper places them on the right device -- confirm against the
        caller). The targets are moved to the GPU when CUDA is available and
        left on the CPU otherwise, so CPU-only hosts no longer crash here.

        Returns:
            ``([images], targets)`` -- the inputs are a one-element list
            because ``_forward`` unpacks them with ``*inputs``.
        """
        imgs, _, pids, _ = inputs
        inputs = [Variable(imgs)]
        if torch.cuda.is_available():
            pids = pids.cuda()
        targets = Variable(pids)
        return inputs, targets

    def _forward(self, inputs, targets):
        """Run the model and score its outputs with the configured criterion.

        Returns:
            ``(loss, prec)``. For ``CrossEntropyLoss`` and ``OIMLoss`` the
            precision is the first entry of the single result returned by
            ``accuracy``; ``TripletLoss`` supplies its own precision.

        Raises:
            ValueError: If the criterion is not one of the supported losses.
        """
        outputs = self.model(*inputs)
        if isinstance(self.criterion, torch.nn.CrossEntropyLoss):
            loss = self.criterion(outputs, targets)
            prec, = accuracy(outputs.data, targets.data)
            prec = prec[0]
        elif isinstance(self.criterion, OIMLoss):
            # OIMLoss returns the loss together with the outputs to score.
            loss, outputs = self.criterion(outputs, targets)
            prec, = accuracy(outputs.data, targets.data)
            prec = prec[0]
        elif isinstance(self.criterion, TripletLoss):
            loss, prec = self.criterion(outputs, targets)
        else:
            # Format into one string; passing two arguments made the message
            # render as a tuple.
            raise ValueError(
                "Unsupported loss: {!r}".format(self.criterion))
        return loss, prec