from abc import abstractmethod
from cords.utils.data.data_utils import WeightedSubset
from torch.utils.data.dataloader import DataLoader
import torch
import numpy as np
# Base objects
class DSSDataLoader:
    """
    Implementation of DSSDataLoader class which serves as base class for dataloaders of other
    selection strategies for semi-supervised learning framework.

    On construction it draws a random subset of ``int(len(full_data) * dss_args.fraction)``
    samples, wraps it (with unit weights) in a ``WeightedSubset`` and builds a
    ``torch.utils.data.DataLoader`` over it (``subset_loader``, also installed as
    ``curr_loader``). Subclasses update ``subset_indices`` / ``subset_weights`` and call
    ``_refresh_subset_loader`` to rebuild the loader.

    Parameters
    -----------
    full_data: torch.utils.data.Dataset Class
        Full dataset from which data subset needs to be selected.
    dss_args: dict
        Data subset selection arguments dictionary. Must contain the key ``'fraction'``
        (a number in ``[0, 1]``). It is read both through ``keys()`` and as the
        attribute ``dss_args.fraction``, so presumably an attribute-style dict
        (e.g. DotMap) — confirm against callers.
    logger: class
        Logger class for logging the information
    *args, **kwargs
        Forwarded unchanged to every ``DataLoader`` this class builds.

    Raises
    ------
    ValueError
        If ``'fraction'`` is missing from ``dss_args`` or does not lie in ``[0, 1]``.
    """

    def __init__(self, full_data, dss_args, logger, *args, **kwargs):
        """
        Constructor function

        Validates ``dss_args.fraction``, stores the loader arguments and builds the
        initial random subset loader.
        """
        super().__init__()
        # TODO: Integrate verbose in logging
        self.len_full = len(full_data)
        # Arguments check. Explicit raises (not ``assert``) so validation still
        # happens when Python runs with -O.
        if "fraction" not in dss_args.keys():
            raise ValueError("'fraction' is a compulsory argument. Include it as a key in dss_args")
        fraction = dss_args.fraction
        # The chained comparison is False for NaN as well, so NaN is rejected here.
        if not 0 <= fraction <= 1:
            raise ValueError("'fraction' should lie between 0 and 1")
        self.fraction = fraction
        # Number of samples in the selected subset.
        self.budget = int(self.len_full * self.fraction)
        self.logger = logger
        self.dataset = full_data
        # Positional / keyword arguments forwarded to every DataLoader built here.
        self.loader_args = args
        self.loader_kwargs = kwargs
        # Subset state; filled in by _init_subset_loader below.
        self.subset_indices = None
        self.subset_weights = None
        self.subset_loader = None
        self.curr_loader = None
        self.batch_wise_indices = None
        # Not used within this class; presumably maintained by subclasses — confirm.
        self.strategy = None
        self.cur_iter = 1
        # The whole dataset as a WeightedSubset with every weight equal to 1.
        self.wt_trainset = WeightedSubset(full_data, list(range(self.len_full)), [1] * self.len_full)
        self._init_subset_loader()

    def __getattr__(self, item):
        """
        Delegate attribute lookups that fail on this object to ``curr_loader``.

        Only explicit attribute access is forwarded; implicit special-method
        lookups (e.g. ``len(obj)``, ``iter(obj)``) go through the type and do not
        reach ``__getattr__``.
        """
        # object.__getattribute__ (rather than self.curr_loader) so that a missing
        # 'curr_loader' raises AttributeError instead of re-entering __getattr__.
        return object.__getattribute__(self, "curr_loader").__getattribute__(item)

    def _build_subset_loader(self):
        """Return a DataLoader over ``WeightedSubset(dataset, subset_indices, subset_weights)``."""
        return DataLoader(WeightedSubset(self.dataset, self.subset_indices, self.subset_weights),
                          *self.loader_args, **self.loader_kwargs)

    def _init_subset_loader(self):
        """
        Function that initializes the random data subset loader

        Draws the initial subset indices, assigns weight 1 to every selected sample,
        builds ``subset_loader`` and installs it as ``curr_loader``.
        """
        # All strategies start with random selection
        self.subset_indices = self._init_subset_indices()
        self.logger.debug("Length of the data subset: %d", len(self.subset_indices))
        # Sized from the indices actually returned (not self.budget) so an override
        # of _init_subset_indices returning a different count still gets exactly
        # one weight per selected sample.
        self.subset_weights = torch.ones(len(self.subset_indices))
        self.subset_loader = self._build_subset_loader()
        # One materialized pass of the batch sampler; the indices are positions in
        # the WeightedSubset (presumably positions within subset_indices — confirm).
        # NOTE(review): with a shuffling sampler this is an independent draw and
        # need not match the batch order of a later iteration over subset_loader.
        self.batch_wise_indices = list(self.subset_loader.batch_sampler)
        self.curr_loader = self.subset_loader

    # Default subset indices comes from random selection
    def _init_subset_indices(self):
        """
        Function that initializes the subset indices randomly

        Returns ``self.budget`` distinct indices drawn uniformly, using NumPy's
        global random state, from ``range(self.len_full)``.
        """
        return np.random.choice(self.len_full, size=self.budget, replace=False)

    def _refresh_subset_loader(self):
        """
        Function that regenerates the data subset loader using new subset indices and subset weights

        The caller is expected to have updated ``subset_indices`` and
        ``subset_weights`` beforehand.
        """
        self.subset_loader = self._build_subset_loader()
        self.logger.debug("Subset Loader Refreshed")
        # NOTE(review): unlike _init_subset_loader, curr_loader is not reassigned
        # here, so attribute delegation keeps using the previously current loader —
        # confirm this is what subclasses rely on.
        self.batch_wise_indices = list(self.subset_loader.batch_sampler)