# Source code for reid.dist_metric

from __future__ import absolute_import

import torch

from .evaluators import extract_features
from .metric_learning import get_metric


class DistanceMetric(object):
    """Thin wrapper around a metric object obtained from ``get_metric``.

    The wrapper remembers which algorithm was requested, fits the metric on
    features extracted from a model, and applies the metric's ``transform``
    to either torch tensors or other array-like inputs.
    """

    def __init__(self, algorithm='euclidean', *args, **kwargs):
        """Create the underlying metric.

        Args:
            algorithm: Name handed to ``get_metric``; defaults to
                ``'euclidean'``.
            *args: Extra positional arguments forwarded to ``get_metric``.
            **kwargs: Extra keyword arguments forwarded to ``get_metric``.
        """
        super(DistanceMetric, self).__init__()
        self.algorithm = algorithm
        self.metric = get_metric(algorithm, *args, **kwargs)

    def train(self, model, data_loader):
        """Fit the metric on features extracted from ``data_loader``.

        Nothing is done (and ``None`` is returned) when the algorithm is
        ``'euclidean'``. Otherwise ``extract_features(model, data_loader)``
        is called; its two results are treated as mappings whose values are
        the per-sample feature tensors and labels.

        Args:
            model: Passed through to ``extract_features``.
            data_loader: Passed through to ``extract_features``.
        """
        if self.algorithm == 'euclidean':
            return

        features, labels = extract_features(model, data_loader)
        # torch.stack needs a real sequence of tensors; a dict view (what
        # ``.values()`` returns on Python 3) is rejected with a TypeError.
        feature_matrix = torch.stack(list(features.values())).numpy()
        label_vector = torch.Tensor(list(labels.values())).numpy()
        self.metric.fit(feature_matrix, label_vector)

    def transform(self, X):
        """Apply the fitted metric's ``transform`` to ``X``.

        Args:
            X: A torch tensor or any input the metric's ``transform``
                accepts directly.

        Returns:
            A torch tensor when ``X`` was a torch tensor, otherwise whatever
            the metric's ``transform`` returns.
        """
        if not torch.is_tensor(X):
            return self.metric.transform(X)

        # The metric works on NumPy arrays, so round-trip through NumPy.
        # NOTE(review): ``Tensor.numpy()`` presumably requires a CPU tensor
        # that does not track gradients -- confirm against callers.
        transformed = self.metric.transform(X.numpy())
        return torch.from_numpy(transformed)