Source code for cords.selectionstrategies.SSL.randomstrategy

import numpy as np
import torch


class RandomStrategy(object):
    """
    Random subset selection strategy.

    Selects a uniformly random set of training indices as the data subset.
    It is typically used as a baseline against which other subset selection
    strategies are compared.

    Parameters
    ----------
    trainloader: class
        Loading the training data using pytorch DataLoader. Its sampler must
        expose a ``data_source`` attribute, whose length is taken as the
        number of training points.
    online: bool, optional
        If True, a fresh random subset is drawn on every call to ``select``.
        If False (default), the subset drawn on the first call is cached and
        returned again on later calls that request the same budget.
    """

    def __init__(self, trainloader, online=False):
        """
        Constructor method
        """
        self.trainloader = trainloader
        # Population size that `select` samples from.
        self.N_trn = len(trainloader.sampler.data_source)
        self.online = online
        # Most recent selection; both stay None until `select` is first called.
        self.indices = None
        self.gammas = None

    def select(self, budget):
        """
        Perform random sampling of indices of size budget.

        Parameters
        ----------
        budget: int
            The number of data points to be selected

        Returns
        ----------
        indices: ndarray
            Array of indices of size budget selected randomly
        gammas: Tensor
            Gradient weight values of selected indices (all ones)

        Raises
        ------
        ValueError
            Raised by ``np.random.choice`` when ``budget`` is larger than the
            number of training points, since sampling is without replacement.
        """
        # Resample when running online, when nothing has been selected yet,
        # or when the cached selection was drawn for a different budget.
        # Without the last check an offline strategy would hand back a
        # subset whose size does not match the requested budget.
        needs_resample = (
            self.online
            or self.indices is None
            or len(self.indices) != budget
        )
        if needs_resample:
            self.indices = np.random.choice(self.N_trn, size=budget, replace=False)
            self.gammas = torch.ones(budget)
        return self.indices, self.gammas