# Source code for cords.selectionstrategies.SL.submodularselectionstrategy

import apricot
import numpy as np
import torch
import torch.nn.functional as F
from scipy.sparse import csr_matrix
from .dataselectionstrategy import DataSelectionStrategy
from torch.utils.data.sampler import SubsetRandomSampler


class SubmodularSelectionStrategy(DataSelectionStrategy):
    """
    This class extends :class:`selectionstrategies.supervisedlearning.dataselectionstrategy.DataSelectionStrategy`
    to include submodular optimization functions using apricot for data selection.

    Parameters
    ----------
    trainloader: class
        Loading the training data using pytorch DataLoader
    valloader: class
        Loading the validation data using pytorch DataLoader
    model: class
        Model architecture used for training
    loss: class
        The loss criterion
    device: str
        The device being utilized - cpu | cuda
    num_classes: int
        The number of target classes in the dataset
    linear_layer: bool
        Apply linear transformation to the data
    if_convex: bool
        If convex or not
    selection_type: str
        PerClass or Supervised
    submod_func_type: str
        The type of submodular optimization function. Must be one of
        'facility-location', 'graph-cut', 'sum-redundancy', 'saturated-coverage'
    optimizer: str
        The apricot optimizer used for submodular maximization. Must be one of
        'random', 'modular', 'naive', 'lazy', 'approximate-lazy', 'two-stage',
        'stochastic', 'sample', 'greedi', 'bidirectional'
    """

    def __init__(self, trainloader, valloader, model, loss,
                 device, num_classes, linear_layer, if_convex,
                 selection_type, submod_func_type, optimizer):
        """
        Constructor method
        """
        # Base class stores the loaders, model, loss and device.
        super().__init__(trainloader, valloader, model, num_classes, linear_layer, loss, device)
        self.if_convex = if_convex
        self.selection_type = selection_type
        self.submod_func_type = submod_func_type
        self.optimizer = optimizer
[docs] def distance(self, x, y, exp=2): """ Compute the distance. Parameters ---------- x: Tensor First input tensor y: Tensor Second input tensor exp: float, optional The exponent value (default: 2) Returns ---------- dist: Tensor Output tensor """ n = x.size(0) m = y.size(0) d = x.size(1) x = x.unsqueeze(1).expand(n, m, d) y = y.unsqueeze(0).expand(n, m, d) dist = torch.pow(x - y, exp).sum(2) # dist = torch.exp(-1 * torch.pow(x - y, 2).sum(2)) return dist
[docs] def compute_score(self, model_params, idxs): """ Compute the score of the indices. Parameters ---------- model_params: OrderedDict Python dictionary object containing models parameters idxs: list The indices """ trainset = self.trainloader.sampler.data_source subset_loader = torch.utils.data.DataLoader(trainset, batch_size=self.trainloader.batch_size, shuffle=False, sampler=SubsetRandomSampler(idxs), pin_memory=True) self.model.load_state_dict(model_params) self.N = 0 g_is = [] if self.if_convex: for batch_idx, (inputs, targets) in enumerate(subset_loader): inputs, targets = inputs, targets if self.selection_type == 'PerBatch': self.N += 1 g_is.append(inputs.view(inputs.size()[0], -1).mean(dim=0).view(1, -1)) else: self.N += inputs.size()[0] g_is.append(inputs.view(inputs.size()[0], -1)) else: embDim = self.model.get_embedding_dim() for batch_idx, (inputs, targets) in enumerate(subset_loader): inputs, targets = inputs.to(self.device), targets.to(self.device, non_blocking=True) if self.selection_type == 'PerBatch': self.N += 1 else: self.N += inputs.size()[0] out, l1 = self.model(inputs, freeze=True, last=True) loss = self.loss(out, targets).sum() l0_grads = torch.autograd.grad(loss, out)[0] if self.linear_layer: l0_expand = torch.repeat_interleave(l0_grads, embDim, dim=1) l1_grads = l0_expand * l1.repeat(1, self.num_classes) if self.selection_type == 'PerBatch': g_is.append(torch.cat((l0_grads, l1_grads), dim=1).mean(dim=0).view(1, -1)) else: g_is.append(torch.cat((l0_grads, l1_grads), dim=1)) else: if self.selection_type == 'PerBatch': g_is.append(l0_grads.mean(dim=0).view(1, -1)) else: g_is.append(l0_grads) self.dist_mat = torch.zeros([self.N, self.N], dtype=torch.float32) first_i = True if self.selection_type == 'PerBatch': g_is = torch.cat(g_is, dim=0) self.dist_mat = self.distance(g_is, g_is).cpu() else: for i, g_i in enumerate(g_is, 0): if first_i: size_b = g_i.size(0) first_i = False for j, g_j in enumerate(g_is, 0): self.dist_mat[i * size_b: i * 
size_b + g_i.size(0), j * size_b: j * size_b + g_j.size(0)] = self.distance(g_i, g_j).cpu() self.const = torch.max(self.dist_mat).item() self.dist_mat = (self.const - self.dist_mat).numpy()
[docs] def compute_gamma(self, idxs): """ Compute the gamma values for the indices. Parameters ---------- idxs: list The indices Returns ---------- gamma: list Gradient values of the input indices """ if self.selection_type == 'PerClass': gamma = [0 for i in range(len(idxs))] best = self.dist_mat[idxs] # .to(self.device) rep = np.argmax(best, axis=0) for i in rep: gamma[i] += 1 elif self.selection_type == 'Supervised': gamma = [0 for i in range(len(idxs))] best = self.dist_mat[idxs] # .to(self.device) rep = np.argmax(best, axis=0) for i in range(rep.shape[1]): gamma[rep[0, i]] += 1 return gamma
[docs] def get_similarity_kernel(self): """ Obtain the similarity kernel. Returns ---------- kernel: ndarray Array of kernel values """ for batch_idx, (inputs, targets) in enumerate(self.trainloader): if batch_idx == 0: labels = targets else: tmp_target_i = targets labels = torch.cat((labels, tmp_target_i), dim=0) kernel = np.zeros((labels.shape[0], labels.shape[0])) for target in np.unique(labels): x = np.where(labels == target)[0] # prod = np.transpose([np.tile(x, len(x)), np.repeat(x, len(x))]) for i in x: kernel[i, x] = 1 return kernel
[docs] def select(self, budget, model_params): """ Data selection method using different submodular optimization functions. Parameters ---------- budget: int The number of data points to be selected model_params: OrderedDict Python dictionary object containing models parameters optimizer: str The optimization approach for data selection. Must be one of 'random', 'modular', 'naive', 'lazy', 'approximate-lazy', 'two-stage', 'stochastic', 'sample', 'greedi', 'bidirectional' Returns ---------- total_greedy_list: list List containing indices of the best datapoints gammas: list List containing gradients of datapoints present in greedySet """ for batch_idx, (inputs, targets) in enumerate(self.trainloader): if batch_idx == 0: x_trn, labels = inputs, targets else: tmp_inputs, tmp_target_i = inputs, targets labels = torch.cat((labels, tmp_target_i), dim=0) per_class_bud = int(budget / self.num_classes) total_greedy_list = [] gammas = [] if self.selection_type == 'PerClass': for i in range(self.num_classes): idxs = torch.where(labels == i)[0] self.compute_score(model_params, idxs) if self.submod_func_type == 'facility-location': fl = apricot.functions.facilityLocation.FacilityLocationSelection(random_state=0, metric='precomputed', n_samples=per_class_bud, optimizer=self.optimizer) elif self.submod_func_type == 'graph-cut': fl = apricot.functions.graphCut.GraphCutSelection(random_state=0, metric='precomputed', n_samples=per_class_bud, optimizer=self.optimizer) elif self.submod_func_type == 'sum-redundancy': fl = apricot.functions.sumRedundancy.SumRedundancySelection(random_state=0, metric='precomputed', n_samples=per_class_bud, optimizer=self.optimizer) elif self.submod_func_type == 'saturated-coverage': fl = apricot.functions.saturatedCoverage.SaturatedCoverageSelection(random_state=0, metric='precomputed', n_samples=per_class_bud, optimizer=self.optimizer) sim_sub = fl.fit_transform(self.dist_mat) greedyList = list(np.argmax(sim_sub, axis=1)) gamma = 
self.compute_gamma(greedyList) total_greedy_list.extend(idxs[greedyList]) gammas.extend(gamma) elif self.selection_type == 'Supervised': for i in range(self.num_classes): if i == 0: idxs = torch.where(labels == i)[0] N = len(idxs) self.compute_score(model_params, idxs) row = idxs.repeat_interleave(N) col = idxs.repeat(N) data = self.dist_mat.flatten() else: idxs = torch.where(labels == i)[0] N = len(idxs) self.compute_score(model_params, idxs) row = torch.cat((row, idxs.repeat_interleave(N)), dim=0) col = torch.cat((col, idxs.repeat(N)), dim=0) data = np.concatenate([data, self.dist_mat.flatten()], axis=0) sparse_simmat = csr_matrix((data, (row.numpy(), col.numpy())), shape=(self.N_trn, self.N_trn)) self.dist_mat = sparse_simmat if self.submod_func_type == 'facility-location': fl = apricot.functions.facilityLocation.FacilityLocationSelection(random_state=0, metric='precomputed', n_samples=per_class_bud, optimizer=self.optimizer) elif self.submod_func_type == 'graph-cut': fl = apricot.functions.graphCut.GraphCutSelection(random_state=0, metric='precomputed', n_samples=per_class_bud, optimizer=self.optimizer) elif self.submod_func_type == 'sum-redundancy': fl = apricot.functions.sumRedundancy.SumRedundancySelection(random_state=0, metric='precomputed', n_samples=per_class_bud, optimizer=self.optimizer) elif self.submod_func_type == 'saturated-coverage': fl = apricot.functions.saturatedCoverage.SaturatedCoverageSelection(random_state=0, metric='precomputed', n_samples=per_class_bud, optimizer=self.optimizer) sim_sub = fl.fit_transform(sparse_simmat) total_greedy_list = list(np.array(np.argmax(sim_sub, axis=1)).reshape(-1)) gammas = self.compute_gamma(total_greedy_list) return total_greedy_list, gammas