Source code for cords.selectionstrategies.SSL.craigstrategy

import math
import time

import apricot
import numpy as np
import torch
from scipy.sparse import csr_matrix
from torch.utils.data.sampler import SubsetRandomSampler

from .dataselectionstrategy import DataSelectionStrategy


class CRAIGStrategy(DataSelectionStrategy):
    """
    Adapted implementation of the CRAIG strategy from the paper
    :footcite:`pmlr-v119-mirzasoleiman20a` for the semi-supervised learning
    setting. For convex loss functions, the CRAIG strategy tries to solve the
    optimization problem given below:

    .. math::
        \\sum_{i \\in \\mathcal{U}} \\min_{j \\in S, |S| \\leq k} \\| x^i - x^j \\|

    In the above equation, :math:`\\mathcal{U}` denotes the training set,
    :math:`(x^i, y^i)` denotes the :math:`i^{th}` training data point and
    label respectively, :math:`S` denotes the data subset selected at each
    round, and :math:`k` is the budget for the subset. Since this optimization
    problem does not depend on the model parameters, the subset selection is
    run only once, right before training starts.

    For non-convex loss functions, the CRAIG strategy tries to solve the
    optimization problem given below:

    .. math::
        \\underset{\\mathcal{S} \\subseteq \\mathcal{U}: |\\mathcal{S}| \\leq k}{\\operatorname{argmin}}
        \\underset{i \\in \\mathcal{U}}{\\sum}\\; \\underset{j \\in \\mathcal{S}}{\\min}
        \\left\\Vert \\mathbf{m}_i \\nabla_{\\theta} l_u(x_i, \\theta) - \\mathbf{m}_j \\nabla_{\\theta} l_u(x_j, \\theta) \\right\\Vert

    In the above equation, :math:`\\mathcal{U}` denotes the unlabeled set,
    :math:`l_u` denotes the unlabeled loss, :math:`\\mathcal{S}` denotes the
    data subset selected at each round, and :math:`k` is the budget for the
    subset. In this case, CRAIG acts as an adaptive subset selection strategy
    that selects a new subset every epoch.

    Both optimization problems above are instances of the facility location
    problem, which is submodular. Hence, they can be solved near-optimally
    using greedy selection methods.

    Parameters
    ----------
    trainloader: class
        Loading the training data using pytorch DataLoader
    valloader: class
        Loading the validation data using pytorch DataLoader
    model: class
        Model architecture used for training
    tea_model: class
        Teacher model architecture used for training
    ssl_alg: class
        SSL algorithm class
    loss: class
        Consistency loss function for unlabeled data with no reduction
    device: str
        The device being utilized - cpu | cuda
    num_classes: int
        The number of target classes in the dataset
    linear_layer: bool
        Apply linear transformation to the data
    if_convex: bool
        If convex or not
    selection_type: str
        Type of selection:
         - 'PerClass': PerClass implementation where the facility location problem is solved for each class separately for speed ups.
         - 'Supervised': Supervised implementation where the facility location problem is solved using a sparse similarity matrix by assigning the similarity of a point with points of a different class to zero.
         - 'PerBatch': PerBatch implementation where the facility location problem tries to select a subset of mini-batches.
    logger: class
        Logger class for logging the information
    optimizer: str
        Type of greedy algorithm. Must be one of 'random', 'modular',
        'naive', 'lazy', 'approximate-lazy', 'two-stage', 'stochastic',
        'sample', 'greedi', 'bidirectional'
    """

    def __init__(self, trainloader, valloader, model, tea_model, ssl_alg, loss,
                 device, num_classes, linear_layer, if_convex, selection_type,
                 logger, optimizer='lazy'):
        """
        Constructor method
        """
        super().__init__(trainloader, valloader, model, tea_model, ssl_alg,
                         num_classes, linear_layer, loss, device, logger)
        self.if_convex = if_convex
        self.selection_type = selection_type
        self.optimizer = optimizer
        self.dist_mat = None

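    # A minimal illustration of the greedy facility location maximization that
    # ``select`` delegates to apricot. The similarity values are hypothetical,
    # and ``FacilityLocationSelection`` is assumed to be apricot-select's
    # top-level export:
    #
    #   import numpy as np
    #   from apricot import FacilityLocationSelection
    #
    #   sim = np.array([[3., 2., 0.],
    #                   [2., 3., 1.],
    #                   [0., 1., 3.]])
    #   fl = FacilityLocationSelection(2, metric='precomputed', optimizer='lazy')
    #   fl.fit(sim)
    #   # fl.ranking holds the greedily chosen row indices, most
    #   # representative first.
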
    def distance(self, x, y, exp=2):
        """
        Compute the pairwise distance between two sets of points.

        Parameters
        ----------
        x: Tensor
            First input tensor of shape (n, d)
        y: Tensor
            Second input tensor of shape (m, d)
        exp: float, optional
            The exponent value (default: 2)

        Returns
        ----------
        dist: Tensor
            Output tensor of shape (n, m) holding the pairwise distances
        """
        n = x.size(0)
        m = y.size(0)
        d = x.size(1)
        # Broadcast both inputs to (n, m, d) and reduce over the feature dim.
        x = x.unsqueeze(1).expand(n, m, d)
        y = y.unsqueeze(0).expand(n, m, d)
        dist = torch.pow(x - y, exp).sum(2)
        return dist

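    # For exp=2, ``distance`` returns squared Euclidean distances, which (up
    # to floating point error) should match torch.cdist(x, y) ** 2; note that
    # it materializes the full (n, m, d) broadcast, so memory grows as
    # n * m * d. A sanity-check sketch, assuming ``strategy`` is a constructed
    # instance:
    #
    #   x, y = torch.randn(4, 8), torch.randn(5, 8)
    #   assert torch.allclose(strategy.distance(x, y),
    #                         torch.cdist(x, y) ** 2, atol=1e-5)
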
    def compute_score(self, model_params, tea_model_params, idxs):
        """
        Compute the score of the indices.

        Parameters
        ----------
        model_params: OrderedDict
            Python dictionary object containing the model's parameters
        tea_model_params: OrderedDict
            Python dictionary object containing the teacher model's parameters
        idxs: list
            The indices
        """
        trainset = self.trainloader.sampler.data_source
        subset_loader = torch.utils.data.DataLoader(trainset,
                                                    batch_size=self.trainloader.batch_size,
                                                    shuffle=False,
                                                    sampler=SubsetRandomSampler(idxs),
                                                    pin_memory=True)
        self.model.load_state_dict(model_params)
        if self.tea_model is not None:
            self.tea_model.load_state_dict(tea_model_params)
        self.N = 0
        g_is = []
        if self.if_convex:
            # For convex losses, the raw (strongly augmented) inputs serve as
            # the feature vectors.
            for batch_idx, (ul_weak_aug, ul_strong_aug, _) in enumerate(subset_loader):
                if self.selection_type == 'PerBatch':
                    self.N += 1
                    g_is.append(ul_strong_aug.view(ul_strong_aug.size()[0], -1).mean(dim=0).view(1, -1))
                else:
                    self.N += ul_strong_aug.size()[0]
                    g_is.append(ul_strong_aug.view(ul_strong_aug.size()[0], -1))
        else:
            # For non-convex losses, use last-layer (and optionally
            # penultimate-layer) gradients of the SSL consistency loss.
            embDim = self.model.get_embedding_dim()
            for batch_idx, (ul_weak_aug, ul_strong_aug, _) in enumerate(subset_loader):
                ul_weak_aug, ul_strong_aug = ul_weak_aug.to(self.device), ul_strong_aug.to(self.device)
                if self.selection_type == 'PerBatch':
                    self.N += 1
                else:
                    self.N += ul_strong_aug.size()[0]
                loss, out, l1, _, _ = self.ssl_loss(ul_weak_data=ul_weak_aug, ul_strong_data=ul_strong_aug)
                loss = loss.sum()
                l0_grads = torch.autograd.grad(loss, out)[0]
                if self.linear_layer:
                    l0_expand = torch.repeat_interleave(l0_grads, embDim, dim=1)
                    l1_grads = l0_expand * l1.repeat(1, self.num_classes)
                    if self.selection_type == 'PerBatch':
                        g_is.append(torch.cat((l0_grads, l1_grads), dim=1).mean(dim=0).view(1, -1))
                    else:
                        g_is.append(torch.cat((l0_grads, l1_grads), dim=1))
                else:
                    if self.selection_type == 'PerBatch':
                        g_is.append(l0_grads.mean(dim=0).view(1, -1))
                    else:
                        g_is.append(l0_grads)
        # Pairwise distances between the feature/gradient vectors.
        self.dist_mat = torch.zeros([self.N, self.N], dtype=torch.float32)
        first_i = True
        if self.selection_type == 'PerBatch':
            g_is = torch.cat(g_is, dim=0)
            self.dist_mat = self.distance(g_is, g_is).cpu()
        else:
            for i, g_i in enumerate(g_is, 0):
                if first_i:
                    size_b = g_i.size(0)
                    first_i = False
                for j, g_j in enumerate(g_is, 0):
                    self.dist_mat[i * size_b: i * size_b + g_i.size(0),
                                  j * size_b: j * size_b + g_j.size(0)] = self.distance(g_i, g_j).cpu()
        # Convert distances into similarities for facility location.
        self.const = torch.max(self.dist_mat).item()
        self.dist_mat = (self.const - self.dist_mat).numpy()

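    # Facility location maximizes similarity, while the gradient features
    # above yield distances, hence the final flip in ``compute_score``. A toy
    # trace of that transformation (values hypothetical):
    #
    #   d = np.array([[0., 4.],
    #                 [4., 0.]])
    #   const = d.max()      # 4.0
    #   sim = const - d      # [[4., 0.], [0., 4.]]
    #
    # Identical points now share the largest similarity (``const``) and the
    # most distant pair gets similarity 0.
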
    def compute_gamma(self, idxs):
        """
        Compute the gamma values for the indices, i.e., how many ground-set
        points each selected index represents.

        Parameters
        ----------
        idxs: list
            The indices

        Returns
        ----------
        gamma: list
            Weight (representation count) of each input index
        """
        if self.selection_type in ['PerClass', 'PerBatch']:
            gamma = [0 for i in range(len(idxs))]
            # For every column (point), find the most similar selected row.
            best = self.dist_mat[idxs]
            rep = np.argmax(best, axis=0)
            for i in rep:
                gamma[i] += 1
        elif self.selection_type == 'Supervised':
            gamma = [0 for i in range(len(idxs))]
            # dist_mat is a sparse matrix here, so argmax returns a 2-D result.
            best = self.dist_mat[idxs]
            rep = np.argmax(best, axis=0)
            for i in range(rep.shape[1]):
                gamma[rep[0, i]] += 1
        return gamma

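    # A small worked example of the gamma counts, with 2 selected rows and 3
    # ground-set points (similarity values hypothetical):
    #
    #   best = np.array([[5., 1., 4.],
    #                    [2., 3., 0.]])
    #   np.argmax(best, axis=0)    # -> array([0, 1, 0])
    #   # gamma == [2, 1]: row 0 represents points 0 and 2, row 1 point 1.
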
    def get_similarity_kernel(self):
        """
        Obtain the similarity kernel.

        Returns
        ----------
        kernel: ndarray
            Array of kernel values
        """
        # Gather the labels of the full training set in loader order.
        for batch_idx, (inputs, targets) in enumerate(self.trainloader):
            if batch_idx == 0:
                labels = targets
            else:
                labels = torch.cat((labels, targets), dim=0)
        kernel = np.zeros((labels.shape[0], labels.shape[0]))
        for target in np.unique(labels):
            x = np.where(labels == target)[0]
            # Points sharing a label get similarity 1; cross-class entries
            # stay 0.
            for i in x:
                kernel[i, x] = 1
        return kernel

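    # The kernel above is a 0/1 co-membership matrix: entry (i, j) is 1 iff
    # points i and j share a label. For labels [0, 1, 0] it would be:
    #
    #   [[1., 0., 1.],
    #    [0., 1., 0.],
    #    [1., 0., 1.]]
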
    def select(self, budget, model_params, tea_model_params):
        """
        Data selection method using different submodular optimization
        functions.

        Parameters
        ----------
        budget: int
            The number of data points to be selected
        model_params: OrderedDict
            Python dictionary object containing the model's parameters
        tea_model_params: OrderedDict
            Python dictionary object containing the teacher model's parameters

        Returns
        ----------
        total_greedy_list: list
            List containing indices of the best datapoints
        gammas: list
            List containing the weights (gamma values) of the selected
            datapoints
        """
        total_greedy_list = []
        gammas = []
        start_time = time.time()
        if self.selection_type == 'PerClass':
            # Solve one facility location problem per class, with the budget
            # split in proportion to the class sizes.
            self.get_labels(valid=False)
            for i in range(self.num_classes):
                idxs = torch.where(self.trn_lbls == i)[0]
                self.compute_score(model_params, tea_model_params, idxs)
                fl = apricot.functions.facilityLocation.FacilityLocationSelection(
                    random_state=0, metric='precomputed',
                    n_samples=math.ceil(budget * len(idxs) / self.N_trn),
                    optimizer=self.optimizer)
                sim_sub = fl.fit_transform(self.dist_mat)
                greedyList = list(np.argmax(sim_sub, axis=1))
                gamma = self.compute_gamma(greedyList)
                total_greedy_list.extend(idxs[greedyList])
                gammas.extend(gamma)
            rand_indices = np.random.permutation(len(total_greedy_list))
            total_greedy_list = list(np.array(total_greedy_list)[rand_indices])
            gammas = list(np.array(gammas)[rand_indices])
        elif self.selection_type == 'Supervised':
            # Build a block-diagonal sparse similarity matrix (cross-class
            # similarities are zero) and solve one facility location problem.
            self.get_labels(valid=False)
            for i in range(self.num_classes):
                if i == 0:
                    idxs = torch.where(self.trn_lbls == i)[0]
                    N = len(idxs)
                    self.compute_score(model_params, tea_model_params, idxs)
                    row = idxs.repeat_interleave(N)
                    col = idxs.repeat(N)
                    data = self.dist_mat.flatten()
                else:
                    idxs = torch.where(self.trn_lbls == i)[0]
                    N = len(idxs)
                    self.compute_score(model_params, tea_model_params, idxs)
                    row = torch.cat((row, idxs.repeat_interleave(N)), dim=0)
                    col = torch.cat((col, idxs.repeat(N)), dim=0)
                    data = np.concatenate([data, self.dist_mat.flatten()], axis=0)
            sparse_simmat = csr_matrix((data, (row.numpy(), col.numpy())),
                                       shape=(self.N_trn, self.N_trn))
            self.dist_mat = sparse_simmat
            fl = apricot.functions.facilityLocation.FacilityLocationSelection(
                random_state=0, metric='precomputed', n_samples=budget,
                optimizer=self.optimizer)
            sim_sub = fl.fit_transform(sparse_simmat)
            total_greedy_list = list(np.array(np.argmax(sim_sub, axis=1)).reshape(-1))
            gammas = self.compute_gamma(total_greedy_list)
        elif self.selection_type == 'PerBatch':
            # Select whole mini-batches, then expand the chosen batches back
            # into their constituent sample indices.
            idxs = torch.arange(self.N_trn)
            self.compute_score(model_params, tea_model_params, idxs)
            fl = apricot.functions.facilityLocation.FacilityLocationSelection(
                random_state=0, metric='precomputed',
                n_samples=math.ceil(budget / self.trainloader.batch_size),
                optimizer=self.optimizer)
            sim_sub = fl.fit_transform(self.dist_mat)
            temp_list = list(np.array(np.argmax(sim_sub, axis=1)).reshape(-1))
            gammas_temp = self.compute_gamma(temp_list)
            batch_wise_indices = list(self.trainloader.batch_sampler)
            for i in range(len(temp_list)):
                tmp = batch_wise_indices[temp_list[i]]
                total_greedy_list.extend(tmp)
                gammas.extend(list(gammas_temp[i] * np.ones(len(tmp))))
        end_time = time.time()
        self.logger.debug("CRAIG subset selection time is: %f", end_time - start_time)
        return total_greedy_list, gammas
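
# A minimal usage sketch of the call pattern. The concrete objects below --
# `unlabeled_loader`, `val_loader`, `model`, `teacher`, `ssl_alg`,
# `consistency_loss`, and `logger` -- are assumptions standing in for the
# surrounding CORDS SSL training pipeline:
#
#   strategy = CRAIGStrategy(unlabeled_loader, val_loader, model, teacher,
#                            ssl_alg, consistency_loss, device='cuda',
#                            num_classes=10, linear_layer=False,
#                            if_convex=False, selection_type='PerBatch',
#                            logger=logger, optimizer='lazy')
#   subset_idxs, gammas = strategy.select(budget=5000,
#                                         model_params=model.state_dict(),
#                                         tea_model_params=teacher.state_dict())
#   # `subset_idxs` indexes the unlabeled set; `gammas` weight the per-sample
#   # losses of the selected points during training.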