from __future__ import absolute_import
import torch
from .evaluators import extract_features
from .metric_learning import get_metric
[docs]class DistanceMetric(object):
def __init__(self, algorithm='euclidean', *args, **kwargs):
super(DistanceMetric, self).__init__()
self.algorithm = algorithm
self.metric = get_metric(algorithm, *args, **kwargs)
def train(self, model, data_loader):
if self.algorithm == 'euclidean': return
features, labels = extract_features(model, data_loader)
features = torch.stack(features.values()).numpy()
labels = torch.Tensor(list(labels.values())).numpy()
self.metric.fit(features, labels)
def transform(self, X):
if torch.is_tensor(X):
X = X.numpy()
X = self.metric.transform(X)
X = torch.from_numpy(X)
else:
X = self.metric.transform(X)
return X