# Source code for cords.utils.data.datasets.SSL.builder

import os
import numpy as np
from torch.utils.data import DataLoader
from torchvision import transforms
from . import utils
from . import dataset_class
from .augmentation.builder import gen_strong_augmentation, gen_weak_augmentation
from .augmentation.augmentation_pool import numpy_batch_gcn, ZCA, GCN


def __val_labeled_unlabeled_split(cfg, train_data, test_data, num_classes, ul_data=None):
    """Split the training data into validation, labeled, and unlabeled subsets."""
    num_validation = int(np.round(len(train_data["images"]) * cfg.dataset.val_ratio))

    np.random.seed(cfg.train_args.seed)

    permutation = np.random.permutation(len(train_data["images"]))
    train_data["images"] = train_data["images"][permutation]
    train_data["labels"] = train_data["labels"][permutation]

    val_data, train_data = utils.dataset_split(train_data, num_validation, num_classes, cfg.dataset.random_split)
    l_train_data, ul_train_data = utils.dataset_split(train_data, cfg.dataset.num_labels, num_classes)

    if ul_data is not None:
        ul_train_data["images"] = np.concatenate([ul_train_data["images"], ul_data["images"]], 0)
        ul_train_data["labels"] = np.concatenate([ul_train_data["labels"], ul_data["labels"]], 0)

    return val_data, l_train_data, ul_train_data


def __labeled_unlabeled_split(cfg, train_data, test_data, num_classes, ul_data=None):
    """Split the training data into labeled and unlabeled subsets."""
    np.random.seed(cfg.train_args.seed)

    permutation = np.random.permutation(len(train_data["images"]))
    train_data["images"] = train_data["images"][permutation]
    train_data["labels"] = train_data["labels"][permutation]

    l_train_data, ul_train_data = utils.dataset_split(train_data, cfg.dataset.num_labels, num_classes)

    if ul_data is not None:
        ul_train_data["images"] = np.concatenate([ul_train_data["images"], ul_data["images"]], 0)
        ul_train_data["labels"] = np.concatenate([ul_train_data["labels"], ul_data["labels"]], 0)

    return l_train_data, ul_train_data
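

# Illustrative sketch, not part of the library: shows how the split helpers above
# are typically exercised. The cfg layout (cfg.train_args.seed,
# cfg.dataset.num_labels) mirrors the attributes they read; the toy data and the
# concrete values are hypothetical.
def _example_split():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        dataset=SimpleNamespace(num_labels=40),
        train_args=SimpleNamespace(seed=0),
    )
    toy = {"images": np.zeros((1000, 32, 32, 3), dtype=np.uint8),
           "labels": np.arange(1000) % 10}
    l_data, ul_data = __labeled_unlabeled_split(cfg, toy, test_data=None, num_classes=10)
    # l_data should hold cfg.dataset.num_labels examples and ul_data the rest,
    # assuming utils.dataset_split draws the first split evenly per class.
    return l_data, ul_data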


def gen_dataloader(root, dataset, validation_split, cfg, logger=None):
    """
    Generate train, validation, and test dataloaders.

    Parameters
    ----------
    root: str
        root directory in which the data is present or needs to be downloaded
    dataset: str
        dataset name, ['cifar10', 'cifar100', 'svhn', 'stl10', 'cifarOOD']
    validation_split: bool
        if True, also return a validation loader; the validation data is split
        off from the training data
    cfg: argparse.Namespace or dict
        configuration containing the arguments needed to build the dataloaders
    logger: logging.Logger
        logger used to report dataset statistics
    """
    ul_train_data = None
    if dataset == "svhn":
        train_data, test_data = utils.get_svhn(root)
        num_classes = 10
        img_size = 32
    elif dataset == "stl10":
        train_data, ul_train_data, test_data = utils.get_stl10(root)
        num_classes = 10
        img_size = 96
    elif dataset == "cifar10":
        train_data, test_data = utils.get_cifar10(root)
        num_classes = 10
        img_size = 32
    elif dataset == "cifar100":
        train_data, test_data = utils.get_cifar100(root)
        num_classes = 100
        img_size = 32
    elif dataset == "cifarOOD":
        train_data, test_data = utils.get_cifarOOD(root, cfg.dataset.ood_ratio)
        num_classes = 6
        img_size = 32
    else:
        raise NotImplementedError

    if validation_split:
        val_data, l_train_data, ul_train_data = __val_labeled_unlabeled_split(
            cfg, train_data, test_data, num_classes, ul_train_data)
    else:
        l_train_data, ul_train_data = __labeled_unlabeled_split(
            cfg, train_data, test_data, num_classes, ul_train_data)
        val_data = None

    # the labeled examples are also included in the unlabeled pool
    ul_train_data["images"] = np.concatenate([ul_train_data["images"], l_train_data["images"]], 0)
    ul_train_data["labels"] = np.concatenate([ul_train_data["labels"], l_train_data["labels"]], 0)

    if logger is not None:
        logger.info(
            "number of:\n"
            "    training data: %d\n"
            "    labeled data: %d\n"
            "    unlabeled data: %d\n"
            "    validation data: %d\n"
            "    test data: %d",
            len(train_data["images"]), len(l_train_data["images"]),
            len(ul_train_data["images"]),
            0 if val_data is None else len(val_data["images"]),
            len(test_data["images"]))

    labeled_train_data = dataset_class.LabeledDataset(l_train_data)
    unlabeled_train_data = dataset_class.UnlabeledDataset(ul_train_data)

    train_data = np.concatenate([
        labeled_train_data.dataset["images"],
        unlabeled_train_data.dataset["images"]
    ], 0)

    # compute normalization statistics over labeled + unlabeled training images
    if cfg.dataset.whiten:
        mean = train_data.mean((0, 1, 2)) / 255.
        scale = train_data.std((0, 1, 2)) / 255.
    elif cfg.dataset.zca:
        mean, scale = utils.get_zca_normalization_param(numpy_batch_gcn(train_data))
    else:
        # from [0, 1] to [-1, 1]
        mean = [0.5, 0.5, 0.5]
        scale = [0.5, 0.5, 0.5]

    # set augmentation
    # RA: RandAugment, WA: Weak Augmentation
    randauglist = "fixmatch" if cfg.ssl_args.alg == "pl" else "uda"
    flags = [b == "t" for b in cfg.dataset.wa.split(".")]

    if cfg.dataset.labeled_aug == "RA":
        labeled_augmentation = gen_strong_augmentation(
            img_size, mean, scale, flags[0], flags[1], randauglist, cfg.dataset.zca)
    elif cfg.dataset.labeled_aug == "WA":
        labeled_augmentation = gen_weak_augmentation(img_size, mean, scale, *flags, cfg.dataset.zca)
    else:
        raise NotImplementedError
    labeled_train_data.transform = labeled_augmentation

    if cfg.dataset.unlabeled_aug == "RA":
        unlabeled_augmentation = gen_strong_augmentation(
            img_size, mean, scale, flags[0], flags[1], randauglist, cfg.dataset.zca)
    elif cfg.dataset.unlabeled_aug == "WA":
        unlabeled_augmentation = gen_weak_augmentation(img_size, mean, scale, *flags, cfg.dataset.zca)
    else:
        raise NotImplementedError

    if logger is not None:
        logger.info("labeled augmentation")
        logger.info(labeled_augmentation)
        logger.info("unlabeled augmentation")
        logger.info(unlabeled_augmentation)

    unlabeled_train_data.weak_augmentation = unlabeled_augmentation

    if cfg.dataset.strong_aug:
        strong_augmentation = gen_strong_augmentation(
            img_size, mean, scale, flags[0], flags[1], randauglist, cfg.dataset.zca)
        unlabeled_train_data.strong_augmentation = strong_augmentation
        if logger is not None:
            logger.info(strong_augmentation)

    if cfg.dataset.zca:
        test_transform = transforms.Compose([GCN(), ZCA(mean, scale)])
    else:
        test_transform = transforms.Compose([transforms.Normalize(mean, scale, inplace=True)])
    test_data = dataset_class.LabeledDataset(test_data, test_transform)

    l_train_loader = DataLoader(
        labeled_train_data, cfg.dataloader.l_batch_size,
        sampler=utils.InfiniteSampler(len(labeled_train_data),
                                      cfg.train_args.iteration * cfg.dataloader.l_batch_size),
        num_workers=cfg.dataloader.num_workers
    )
    ul_train_loader = DataLoader(
        unlabeled_train_data, cfg.dataloader.ul_batch_size,
        sampler=utils.InfiniteSampler(len(unlabeled_train_data),
                                      cfg.train_args.iteration * cfg.dataloader.ul_batch_size),
        num_workers=cfg.dataloader.num_workers
    )
    test_loader = DataLoader(
        test_data, 1, shuffle=False, drop_last=False,
        num_workers=cfg.dataloader.num_workers
    )

    if validation_split:
        validation_data = dataset_class.LabeledDataset(val_data, test_transform)
        val_loader = DataLoader(
            validation_data, 1, shuffle=False, drop_last=False,
            num_workers=cfg.dataloader.num_workers
        )
        return (l_train_loader, ul_train_loader, val_loader, test_loader,
                num_classes, img_size)
    else:
        return (l_train_loader, ul_train_loader, test_loader,
                num_classes, img_size)
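

# A minimal usage sketch, not part of the library: builds SSL dataloaders for
# CIFAR-10. The cfg layout below is an assumption inferred from the attributes
# gen_dataloader reads (cfg.dataset.*, cfg.dataloader.*, cfg.train_args.*,
# cfg.ssl_args.*); the concrete values (4000 labels, "t.t.f" weak-augmentation
# flags, batch sizes, ...) are hypothetical, not recommendations.
def _example_gen_dataloader():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        dataset=SimpleNamespace(
            num_labels=4000, whiten=False, zca=False, wa="t.t.f",
            labeled_aug="WA", unlabeled_aug="WA", strong_aug=False),
        dataloader=SimpleNamespace(l_batch_size=64, ul_batch_size=448, num_workers=4),
        train_args=SimpleNamespace(seed=0, iteration=100000),
        ssl_args=SimpleNamespace(alg="vat"),
    )
    l_loader, ul_loader, test_loader, num_classes, img_size = gen_dataloader(
        "./data", "cifar10", validation_split=False, cfg=cfg)
    return l_loader, ul_loader, test_loader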


def gen_dataset(root, dataset, validation_split, cfg, logger=None):
    """
    Generate train, validation, and test datasets.

    Parameters
    ----------
    root: str
        root directory in which the data is present or needs to be downloaded
    dataset: str
        dataset name; existing choices: ['cifar10', 'cifar100', 'svhn', 'stl10',
        'cifarOOD', 'mnistOOD', 'cifarImbalance']
    validation_split: bool
        if True, also return a validation set; a cfg.dataset.val_ratio random
        split of the training data is used as validation data
    cfg: argparse.Namespace or dict
        configuration containing the arguments needed to generate the dataset
    logger: logging.Logger
        logger used to report dataset statistics
    """
    ul_train_data = None
    if dataset == "svhn":
        train_data, test_data = utils.get_svhn(root)
        num_classes = 10
        img_size = 32
    elif dataset == "stl10":
        train_data, ul_train_data, test_data = utils.get_stl10(root)
        num_classes = 10
        img_size = 96
    elif dataset == "cifar10":
        train_data, test_data = utils.get_cifar10(root)
        num_classes = 10
        img_size = 32
    elif dataset == "cifar100":
        train_data, test_data = utils.get_cifar100(root)
        num_classes = 100
        img_size = 32
    elif dataset == "cifarOOD":
        train_data, ul_train_data, test_data = utils.get_cifarOOD(root, cfg.dataset.ood_ratio)
        num_classes = 6
        img_size = 32
    elif dataset == "mnistOOD":
        train_data, ul_train_data, test_data = utils.get_mnistOOD(root, cfg.dataset.ood_ratio)
        num_classes = 6
        img_size = 28
    elif dataset == "cifarImbalance":
        train_data, ul_train_data, test_data = utils.get_cifarClassImb(root, cfg.dataset.ood_ratio)
        num_classes = 10
        img_size = 32
    else:
        raise NotImplementedError

    if validation_split:
        val_data, l_train_data, ul_train_data = __val_labeled_unlabeled_split(
            cfg, train_data, test_data, num_classes, ul_train_data)
    else:
        l_train_data, ul_train_data = __labeled_unlabeled_split(
            cfg, train_data, test_data, num_classes, ul_train_data)
        val_data = None

    # the labeled examples are also included in the unlabeled pool
    ul_train_data["images"] = np.concatenate([ul_train_data["images"], l_train_data["images"]], 0)
    ul_train_data["labels"] = np.concatenate([ul_train_data["labels"], l_train_data["labels"]], 0)

    if logger is not None:
        logger.info(
            "number of:\n"
            "    training data: %d\n"
            "    labeled data: %d\n"
            "    unlabeled data: %d\n"
            "    validation data: %d\n"
            "    test data: %d",
            len(train_data["images"]), len(l_train_data["images"]),
            len(ul_train_data["images"]),
            0 if val_data is None else len(val_data["images"]),
            len(test_data["images"]))

    labeled_train_data = dataset_class.LabeledDataset(l_train_data)
    unlabeled_train_data = dataset_class.UnlabeledDataset(ul_train_data)

    train_data = np.concatenate([
        labeled_train_data.dataset["images"],
        unlabeled_train_data.dataset["images"]
    ], 0)

    # compute normalization statistics over labeled + unlabeled training images
    if cfg.dataset.whiten:
        mean = train_data.mean((0, 1, 2)) / 255.
        scale = train_data.std((0, 1, 2)) / 255.
    elif cfg.dataset.zca:
        mean, scale = utils.get_zca_normalization_param(numpy_batch_gcn(train_data))
    elif dataset == 'mnistOOD':
        # single-channel images
        mean = [0.5]
        scale = [0.5]
    else:
        # from [0, 1] to [-1, 1]
        mean = [0.5, 0.5, 0.5]
        scale = [0.5, 0.5, 0.5]

    # set augmentation
    # RA: RandAugment, WA: Weak Augmentation
    randauglist = "fixmatch" if cfg.ssl_args.alg == "pl" else "uda"
    flags = [b == "t" for b in cfg.dataset.wa.split(".")]

    if cfg.dataset.labeled_aug == "RA":
        labeled_augmentation = gen_strong_augmentation(
            img_size, mean, scale, flags[0], flags[1], randauglist, cfg.dataset.zca)
    elif cfg.dataset.labeled_aug == "WA":
        labeled_augmentation = gen_weak_augmentation(img_size, mean, scale, *flags, cfg.dataset.zca)
    else:
        raise NotImplementedError
    labeled_train_data.transform = labeled_augmentation

    if cfg.dataset.unlabeled_aug == "RA":
        unlabeled_augmentation = gen_strong_augmentation(
            img_size, mean, scale, flags[0], flags[1], randauglist, cfg.dataset.zca)
    elif cfg.dataset.unlabeled_aug == "WA":
        unlabeled_augmentation = gen_weak_augmentation(img_size, mean, scale, *flags, cfg.dataset.zca)
    else:
        raise NotImplementedError

    if logger is not None:
        logger.info("labeled augmentation")
        logger.info(labeled_augmentation)
        logger.info("unlabeled augmentation")
        logger.info(unlabeled_augmentation)

    unlabeled_train_data.weak_augmentation = unlabeled_augmentation

    if cfg.dataset.strong_aug:
        strong_augmentation = gen_strong_augmentation(
            img_size, mean, scale, flags[0], flags[1], randauglist, cfg.dataset.zca)
        unlabeled_train_data.strong_augmentation = strong_augmentation
        if logger is not None:
            logger.info(strong_augmentation)

    if cfg.dataset.zca:
        test_transform = transforms.Compose([GCN(), ZCA(mean, scale)])
    else:
        test_transform = transforms.Compose([transforms.Normalize(mean, scale, inplace=True)])
    test_data = dataset_class.LabeledDataset(test_data, test_transform)

    if validation_split:
        validation_data = dataset_class.LabeledDataset(val_data, test_transform)
        return (labeled_train_data, unlabeled_train_data, validation_data,
                test_data, num_classes, img_size)
    else:
        return (labeled_train_data, unlabeled_train_data, test_data,
                num_classes, img_size)
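

# Usage sketch for gen_dataset, analogous to _example_gen_dataloader above but
# with a validation split, so cfg.dataset.val_ratio and cfg.dataset.random_split
# are also required. All values are hypothetical; the returned objects are
# dataset instances, which callers can wrap in their own samplers/dataloaders.
def _example_gen_dataset():
    from types import SimpleNamespace
    cfg = SimpleNamespace(
        dataset=SimpleNamespace(
            num_labels=4000, val_ratio=0.1, random_split=False,
            whiten=False, zca=False, wa="t.t.f",
            labeled_aug="WA", unlabeled_aug="WA", strong_aug=False),
        train_args=SimpleNamespace(seed=0),
        ssl_args=SimpleNamespace(alg="vat"),
    )
    labeled_set, unlabeled_set, val_set, test_set, num_classes, img_size = gen_dataset(
        "./data", "cifar10", validation_split=True, cfg=cfg)
    return labeled_set, unlabeled_set, val_set, test_set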