import apricot
import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse import csr_matrix
from .dataselectionstrategy import DataSelectionStrategy
from torch.utils.data.sampler import SubsetRandomSampler
class SubmodularSelectionStrategy(DataSelectionStrategy):
    """
    This class extends :class:`selectionstrategies.supervisedlearning.dataselectionstrategy.DataSelectionStrategy`
    to include submodular optimization functions using apricot for data selection.

    Parameters
    ----------
    trainloader: class
        Loading the training data using pytorch DataLoader. Positions in its
        iteration order are used as indices into ``trainloader.sampler.data_source``,
        so it is assumed to iterate the dataset in order (no shuffling) -- confirm
        at the call site.
    valloader: class
        Loading the validation data using pytorch DataLoader
    model: class
        Model architecture used for training. When ``if_convex`` is False it must
        provide ``get_embedding_dim()`` and accept ``model(inputs, freeze=True, last=True)``
        returning ``(outputs, embedding)``.
    loss: class
        The loss criterion; called as ``loss(outputs, targets)`` and the result is summed
    device: str
        The device being utilized - cpu | cuda
    num_classes: int
        The number of target classes in the dataset
    linear_layer: bool
        If True, gradient features also include the output gradients multiplied
        by the embedding (presumably the last-layer weight gradients)
    if_convex: bool
        If True, flattened raw inputs are used as features instead of gradients
    selection_type: str
        'PerClass' or 'Supervised' (both handled by :meth:`select`);
        :meth:`compute_score` additionally understands 'PerBatch'
    submod_func_type: str
        The type of submodular optimization function. Must be one of
        'facility-location', 'graph-cut', 'sum-redundancy', 'saturated-coverage'
    optimizer: str
        The apricot optimizer used for selection. Must be one of
        'random', 'modular', 'naive', 'lazy', 'approximate-lazy', 'two-stage',
        'stochastic', 'sample', 'greedi', 'bidirectional'
    """

    def __init__(self, trainloader, valloader, model, loss,
                 device, num_classes, linear_layer, if_convex, selection_type, submod_func_type, optimizer):
        """
        Constructor method
        """
        super().__init__(trainloader, valloader, model, num_classes, linear_layer, loss, device)
        self.if_convex = if_convex
        self.selection_type = selection_type
        self.submod_func_type = submod_func_type
        self.optimizer = optimizer

    def distance(self, x, y, exp=2):
        """
        Compute pairwise distances between the rows of two tensors.

        Parameters
        ----------
        x: Tensor
            First input tensor of shape (n, d)
        y: Tensor
            Second input tensor of shape (m, d)
        exp: float, optional
            The exponent value (default: 2)

        Returns
        ----------
        dist: Tensor
            Tensor of shape (n, m) whose entry (i, j) is
            ``sum_k |x[i, k] - y[j, k]| ** exp`` (squared Euclidean for exp=2)
        """
        n, m, d = x.size(0), y.size(0), x.size(1)
        x = x.unsqueeze(1).expand(n, m, d)
        y = y.unsqueeze(0).expand(n, m, d)
        # abs() keeps the result a non-negative distance for odd exponents;
        # for the default exp=2 it is a no-op.
        return torch.abs(x - y).pow(exp).sum(2)

    def compute_score(self, model_params, idxs):
        """
        Compute the similarity matrix between the samples at ``idxs``.

        Loads ``model_params`` into the model, builds one feature row per sample
        (or one per batch for 'PerBatch'), and stores the results on ``self``:
        ``self.N`` (number of rows), ``self.const`` (maximum pairwise distance) and
        ``self.dist_mat`` (a numpy similarity matrix ``const - distance``; row k
        corresponds to ``idxs[k]`` when not 'PerBatch').

        Parameters
        ----------
        model_params: OrderedDict
            Python dictionary object containing models parameters
        idxs: list
            Indices into ``trainloader.sampler.data_source``
        """
        trainset = self.trainloader.sampler.data_source
        # A Subset iterated sequentially keeps feature row k aligned with idxs[k];
        # a SubsetRandomSampler would shuffle the rows and break that mapping.
        subset = torch.utils.data.Subset(trainset, [int(i) for i in idxs])
        subset_loader = torch.utils.data.DataLoader(subset, batch_size=self.trainloader.batch_size,
                                                    shuffle=False, pin_memory=True)
        self.model.load_state_dict(model_params)
        per_batch = self.selection_type == 'PerBatch'
        if self.if_convex:
            g_is = self._input_features(subset_loader, per_batch)
        else:
            g_is = self._gradient_features(subset_loader, per_batch)
        self.N = sum(g.size(0) for g in g_is)

        if per_batch:
            feats = torch.cat(g_is, dim=0)
            dist_mat = self.distance(feats, feats).cpu()
        else:
            dist_mat = torch.zeros([self.N, self.N], dtype=torch.float32)
            # Row/column offset of each batch inside the full matrix.
            offsets = [0]
            for g in g_is:
                offsets.append(offsets[-1] + g.size(0))
            # Fill block by block to avoid materializing an N x N x d tensor.
            for i, g_i in enumerate(g_is):
                for j, g_j in enumerate(g_is):
                    dist_mat[offsets[i]:offsets[i + 1],
                             offsets[j]:offsets[j + 1]] = self.distance(g_i, g_j).cpu()
        self.const = torch.max(dist_mat).item()
        # Convert distances into non-negative similarities for apricot.
        self.dist_mat = (self.const - dist_mat).numpy()

    @staticmethod
    def _input_features(loader, per_batch):
        """Return flattened inputs per batch (batch mean as one row when per_batch)."""
        feats = []
        for inputs, _ in loader:
            flat = inputs.view(inputs.size(0), -1)
            feats.append(flat.mean(dim=0).view(1, -1) if per_batch else flat)
        return feats

    def _gradient_features(self, loader, per_batch):
        """Return loss-gradient features per batch (batch mean as one row when per_batch)."""
        emb_dim = self.model.get_embedding_dim()
        feats = []
        for inputs, targets in loader:
            inputs = inputs.to(self.device)
            targets = targets.to(self.device, non_blocking=True)
            out, l1 = self.model(inputs, freeze=True, last=True)
            loss = self.loss(out, targets).sum()
            # Gradient of the summed loss w.r.t. the model outputs.
            l0_grads = torch.autograd.grad(loss, out)[0]
            if self.linear_layer:
                l0_expand = torch.repeat_interleave(l0_grads, emb_dim, dim=1)
                l1_grads = l0_expand * l1.repeat(1, self.num_classes)
                grads = torch.cat((l0_grads, l1_grads), dim=1)
            else:
                grads = l0_grads
            # Only the values are needed; never keep an autograd graph alive.
            grads = grads.detach()
            feats.append(grads.mean(dim=0).view(1, -1) if per_batch else grads)
        return feats

    def compute_gamma(self, idxs):
        """
        Compute the weight of each selected index.

        For every column of ``self.dist_mat`` (dense array or sparse matrix), find
        which of the rows ``idxs`` has the largest similarity, and count how many
        columns each selected row wins.

        Parameters
        ----------
        idxs: list
            Row indices of ``self.dist_mat`` that were selected

        Returns
        ----------
        gamma: list
            ``gamma[k]`` is the number of columns whose most similar selected row is ``idxs[k]``

        Raises
        ----------
        ValueError
            If ``selection_type`` is not 'PerClass' or 'Supervised'
        """
        if self.selection_type not in ('PerClass', 'Supervised'):
            raise ValueError("compute_gamma supports selection_type 'PerClass' or 'Supervised', "
                             "got {!r}".format(self.selection_type))
        best = self.dist_mat[idxs]
        # .argmax works for both ndarrays and scipy sparse matrices; the sparse
        # version returns a (1, N) matrix, hence the flattening.
        rep = np.asarray(best.argmax(axis=0)).ravel()
        return np.bincount(rep, minlength=len(idxs)).tolist()

    def _collect_labels(self):
        """Concatenate the targets of every trainloader batch, in iteration order."""
        return torch.cat([targets for _, targets in self.trainloader], dim=0)

    def get_similarity_kernel(self):
        """
        Obtain the label-equality similarity kernel.

        Returns
        ----------
        kernel: ndarray
            Float array of shape (N, N) with ``kernel[i, j] = 1`` when training
            samples i and j (in trainloader order) share a label, else 0
        """
        labels = self._collect_labels().cpu().numpy()
        return (labels[:, None] == labels[None, :]).astype(np.float64)

    def _make_selector(self, n_samples):
        """Build the apricot selector configured by ``submod_func_type``."""
        selectors = {
            'facility-location': apricot.functions.facilityLocation.FacilityLocationSelection,
            'graph-cut': apricot.functions.graphCut.GraphCutSelection,
            'sum-redundancy': apricot.functions.sumRedundancy.SumRedundancySelection,
            'saturated-coverage': apricot.functions.saturatedCoverage.SaturatedCoverageSelection,
        }
        try:
            selector_cls = selectors[self.submod_func_type]
        except KeyError:
            raise ValueError("submod_func_type must be one of {}, got {!r}".format(
                sorted(selectors), self.submod_func_type)) from None
        return selector_cls(random_state=0, metric='precomputed',
                            n_samples=n_samples, optimizer=self.optimizer)

    def select(self, budget, model_params):
        """
        Data selection method using different submodular optimization
        functions (chosen by ``submod_func_type`` and ``optimizer``).

        Parameters
        ----------
        budget: int
            The number of data points to be selected. With 'PerClass' each class
            receives ``budget // num_classes`` points, capped at its size.
        model_params: OrderedDict
            Python dictionary object containing models parameters

        Returns
        ----------
        total_greedy_list: list
            List containing indices of the best datapoints
        gammas: list
            Weight of each selected point, as computed by :meth:`compute_gamma`

        Raises
        ----------
        ValueError
            If ``selection_type`` or ``submod_func_type`` is not supported
        """
        labels = self._collect_labels()
        per_class_bud = budget // self.num_classes

        if self.selection_type == 'PerClass':
            total_greedy_list = []
            gammas = []
            for i in range(self.num_classes):
                idxs = torch.where(labels == i)[0]
                n_select = min(per_class_bud, len(idxs))
                if n_select == 0:
                    # Empty class (or zero budget): nothing to score or select.
                    continue
                self.compute_score(model_params, idxs)
                fl = self._make_selector(n_select)
                # ranking holds the selected rows; avoids ambiguity of argmax on ties.
                greedy_list = [int(r) for r in fl.fit(self.dist_mat).ranking]
                gammas.extend(self.compute_gamma(greedy_list))
                total_greedy_list.extend(int(idxs[r]) for r in greedy_list)
            return total_greedy_list, gammas

        if self.selection_type == 'Supervised':
            rows, cols, data = [], [], []
            for i in range(self.num_classes):
                idxs = torch.where(labels == i)[0]
                n = len(idxs)
                if n == 0:
                    continue
                self.compute_score(model_params, idxs)
                # dist_mat.flatten() is row-major: element (a, b) -> (idxs[a], idxs[b]).
                rows.append(idxs.repeat_interleave(n))
                cols.append(idxs.repeat(n))
                data.append(self.dist_mat.flatten())
            # NOTE(review): N_trn comes from DataSelectionStrategy, presumably the training-set size.
            sparse_simmat = csr_matrix((np.concatenate(data),
                                        (torch.cat(rows).numpy(), torch.cat(cols).numpy())),
                                       shape=(self.N_trn, self.N_trn))
            self.dist_mat = sparse_simmat
            # The matrix covers every class, so the whole budget is selected at once.
            fl = self._make_selector(min(budget, self.N_trn))
            total_greedy_list = [int(r) for r in fl.fit(sparse_simmat).ranking]
            gammas = self.compute_gamma(total_greedy_list)
            return total_greedy_list, gammas

        raise ValueError("select supports selection_type 'PerClass' or 'Supervised', "
                         "got {!r}".format(self.selection_type))