import math
import time
import torch
import numpy as np
from .dataselectionstrategy import DataSelectionStrategy
from ..helpers import OrthogonalMP_REG_Parallel, OrthogonalMP_REG, OrthogonalMP_REG_Parallel_V1
from torch.utils.data import Subset, DataLoader
class GradMatchStrategy(DataSelectionStrategy):
    """
    Implementation of the OMPGradMatch strategy from the paper
    :footcite:`pmlr-v139-killamsetty21a` for supervised learning frameworks.

    OMPGradMatch tries to solve the optimization problem given below:

    .. math::
        \\underset{\\mathcal{S} \\subseteq \\mathcal{U}:|\\mathcal{S}| \\leq k, \{\\mathbf{w}_j\}_{j \\in [1, |\\mathcal{S}|]}:\\forall_{j} \\mathbf{w}_j \\geq 0}{\\operatorname{argmin\\hspace{0.7mm}}} \\left \\Vert \\underset{i \\in \\mathcal{U}}{\\sum} \\mathbf{m}_i \\nabla_{\\theta}l_u(x_i, \\theta) - \\underset{j \\in \\mathcal{S}}{\\sum} \\mathbf{m}_j \\mathbf{w}_j \\nabla_{\\theta} l_u(x_j, \\theta)\\right \\Vert

    In the above equation, :math:`\\mathbf{w}` denotes the weight vector that contains
    the weights for each data instance, :math:`\\mathcal{U}` denotes the unlabeled set
    where :math:`(x^i, y^i)` denotes the :math:`i^{th}` training data point and label
    respectively, :math:`l_u` denotes the unlabeled loss, :math:`\\mathcal{S}` denotes
    the data subset selected at each round, and :math:`k` is the budget for the subset.

    The optimization problem is solved using the Orthogonal Matching Pursuit (OMP)
    algorithm.

    Parameters
    ----------
    trainloader: class
        Loading the training data using pytorch DataLoader
    valloader: class
        Loading the validation data using pytorch DataLoader
    model: class
        Model architecture used for training
    tea_model: class
        Teacher model architecture used for training
    ssl_alg: class
        SSL algorithm class
    loss: class
        Consistency loss function for unlabeled data with no reduction
    eta: float
        Learning rate. Step size for the one step gradient update
    device: str
        The device being utilized - cpu | cuda
    num_classes: int
        The number of target classes in the dataset
    linear_layer: bool
        Apply linear transformation to the data
    selection_type: str
        Type of selection -
        - 'PerClass': OMP algorithm is applied on the data points of each class separately.
        - 'PerBatch': OMP algorithm is applied on minibatches of data points.
        - 'PerClassPerGradient': same as PerClass, but only the gradient of the
          classification-layer column corresponding to that class is used.
    logger : class
        logger file for printing the info
    valid : bool, optional
        If valid==True we use validation dataset gradient sum in OMP otherwise we use
        training dataset (default: False)
    v1 : bool
        If v1==True, we use newer version of OMP solver that is more accurate
    lam : float
        Regularization constant of OMP solver
    eps : float
        Epsilon parameter to which the above optimization problem is solved using OMP
    """

    def __init__(self, trainloader, valloader, model, tea_model, ssl_alg, loss,
                 eta, device, num_classes, linear_layer, selection_type, logger,
                 valid=False, v1=True, lam=0, eps=1e-4):
        """
        Constructor method
        """
        super().__init__(trainloader, valloader, model, tea_model, ssl_alg,
                         num_classes, linear_layer, loss, device, logger)
        self.eta = eta  # step size for the one step gradient update
        self.device = device
        self.selection_type = selection_type
        self.valid = valid
        self.lam = lam
        self.eps = eps
        self.v1 = v1

    def ompwrapper(self, X, Y, bud):
        """
        Wrapper function that instantiates the OMP algorithm.

        Parameters
        ----------
        X:
            Individual datapoint gradients (one column per datapoint)
        Y:
            Gradient sum that needs to be matched to
        bud:
            Budget of datapoints that needs to be sampled from the unlabeled set

        Returns
        ----------
        idxs: list
            List containing indices of the best datapoints
        gammas: list
            List containing the OMP weight of each selected instance
        """
        if self.device == "cpu":
            # FIX: the CPU path previously hard-coded lam=0 and omitted tol,
            # silently ignoring the configured regularization/tolerance that the
            # GPU path honors. Pass self.lam / self.eps here as well.
            reg = OrthogonalMP_REG(X.numpy(), Y.numpy(), nnz=bud,
                                   positive=True, lam=self.lam, tol=self.eps)
            ind = np.nonzero(reg)[0]
        else:
            # Choose the (more accurate) V1 solver when requested.
            solver = OrthogonalMP_REG_Parallel_V1 if self.v1 else OrthogonalMP_REG_Parallel
            reg = solver(X, Y, nnz=bud, positive=True, lam=self.lam,
                         tol=self.eps, device=self.device)
            ind = torch.nonzero(reg).view(-1)
        return ind.tolist(), reg[ind].tolist()

    def _prepare_class_loaders(self, class_idx):
        """
        Build the per-class train loader (and, when self.valid is set, the
        per-class validation loader) as side effects on self.pctrainloader /
        self.pcvalloader, and return the training indices belonging to class_idx.
        """
        trn_subset_idx = torch.where(self.trn_lbls == class_idx)[0].tolist()
        trn_data_sub = Subset(self.trainloader.dataset, trn_subset_idx)
        self.pctrainloader = DataLoader(trn_data_sub, batch_size=self.trainloader.batch_size,
                                        shuffle=False, pin_memory=True)
        if self.valid:
            val_subset_idx = torch.where(self.val_lbls == class_idx)[0].tolist()
            val_data_sub = Subset(self.valloader.dataset, val_subset_idx)
            self.pcvalloader = DataLoader(val_data_sub, batch_size=self.trainloader.batch_size,
                                          shuffle=False, pin_memory=True)
        return trn_subset_idx

    def select(self, budget, model_params, tea_model_params):
        """
        Apply the OMP algorithm for data selection.

        Parameters
        ----------
        budget: int
            The number of data points to be selected
        model_params: OrderedDict
            Python dictionary object containing model's parameters
        tea_model_params: OrderedDict
            Python dictionary object containing teacher model's parameters

        Returns
        --------
        idxs: list
            List containing indices of the best datapoints
        gammas: list
            Weights of the selected instances
        """
        omp_start_time = time.time()
        self.update_model(model_params, tea_model_params)

        if self.selection_type == 'PerClass':
            self.get_labels(valid=self.valid)
            idxs, gammas = [], []
            for i in range(self.num_classes):
                trn_subset_idx = self._prepare_class_loaders(i)
                self.compute_gradients(self.valid, perBatch=False, perClass=True)
                trn_gradients = self.grads_per_elem
                # Target: validation gradient sum when available, else training sum.
                if self.valid:
                    sum_val_grad = torch.sum(self.val_grads_per_elem, dim=0)
                else:
                    sum_val_grad = torch.sum(trn_gradients, dim=0)
                # Budget is split across classes proportionally to class size.
                idxs_temp, gammas_temp = self.ompwrapper(
                    torch.transpose(trn_gradients, 0, 1), sum_val_grad,
                    math.ceil(budget * len(trn_subset_idx) / self.N_trn))
                idxs.extend(list(np.array(trn_subset_idx)[idxs_temp]))
                gammas.extend(gammas_temp)

        elif self.selection_type == 'PerBatch':
            self.compute_gradients(self.valid, perBatch=True, perClass=False)
            idxs, gammas = [], []
            trn_gradients = self.grads_per_elem
            if self.valid:
                sum_val_grad = torch.sum(self.val_grads_per_elem, dim=0)
            else:
                sum_val_grad = torch.sum(trn_gradients, dim=0)
            # OMP selects whole minibatches here, so the budget is in batches.
            idxs_temp, gammas_temp = self.ompwrapper(
                torch.transpose(trn_gradients, 0, 1), sum_val_grad,
                math.ceil(budget / self.trainloader.batch_size))
            batch_wise_indices = list(self.trainloader.batch_sampler)
            # Expand each selected batch into its element indices; every element
            # of a batch inherits the batch's OMP weight.
            for i in range(len(idxs_temp)):
                tmp = batch_wise_indices[idxs_temp[i]]
                idxs.extend(tmp)
                gammas.extend(list(gammas_temp[i] * np.ones(len(tmp))))

        elif self.selection_type == 'PerClassPerGradient':
            self.get_labels(valid=self.valid)
            idxs, gammas = [], []
            embDim = self.model.get_embedding_dim()
            for i in range(self.num_classes):
                trn_subset_idx = self._prepare_class_loaders(i)
                self.compute_gradients(self.valid, perBatch=False, perClass=True)
                trn_gradients = self.grads_per_elem
                # Keep only the bias column of class i plus the embedding-weight
                # slice of class i from the flattened last-layer gradient.
                tmp_gradients = trn_gradients[:, i].view(-1, 1)
                tmp1_gradients = trn_gradients[:,
                                 self.num_classes + (embDim * i): self.num_classes + (embDim * (i + 1))]
                trn_gradients = torch.cat((tmp_gradients, tmp1_gradients), dim=1)
                if self.valid:
                    val_gradients = self.val_grads_per_elem
                    tmp_gradients = val_gradients[:, i].view(-1, 1)
                    tmp1_gradients = val_gradients[:,
                                     self.num_classes + (embDim * i): self.num_classes + (embDim * (i + 1))]
                    val_gradients = torch.cat((tmp_gradients, tmp1_gradients), dim=1)
                    sum_val_grad = torch.sum(val_gradients, dim=0)
                else:
                    sum_val_grad = torch.sum(trn_gradients, dim=0)
                idxs_temp, gammas_temp = self.ompwrapper(
                    torch.transpose(trn_gradients, 0, 1), sum_val_grad,
                    math.ceil(budget * len(trn_subset_idx) / self.N_trn))
                idxs.extend(list(np.array(trn_subset_idx)[idxs_temp]))
                gammas.extend(gammas_temp)

        omp_end_time = time.time()

        # OMP may under-select; top up with random unselected points at weight 1.
        diff = budget - len(idxs)
        if diff > 0:
            remainList = set(np.arange(self.N_trn)).difference(set(idxs))
            new_idxs = np.random.choice(list(remainList), size=diff, replace=False)
            idxs.extend(new_idxs)
            gammas.extend([1 for _ in range(diff)])

        idxs = np.array(idxs)
        gammas = np.array(gammas)
        if self.selection_type in ["PerClass", "PerClassPerGradient"]:
            # Shuffle so the subset is not ordered class-by-class (already
            # np.array here, so no redundant re-wrapping before indexing).
            rand_indices = np.random.permutation(len(idxs))
            idxs = list(idxs[rand_indices])
            gammas = list(gammas[rand_indices])

        self.logger.debug("OMP algorithm Subset Selection time is: %f", omp_end_time - omp_start_time)
        return idxs, gammas