Source code for rztdl.dl.dataset.dataset_handler

# -*- coding: utf-8 -*-
"""
@created on: 12/15/19,
@author: Himaprasoon,
@version: v0.0.1

Description: DatasetHandler collects Dataset objects and builds batched tf.data pipelines from their buffers for training and inference.

Sphinx Documentation Status:

"""
import itertools
import tensorflow as tf
from typing import List, Dict
from typeguard import typechecked
from collections import defaultdict

from rztdl.dl.dataset.dataset import Dataset
from rztdl.dl.dataset.splits import TrainSplitHandler
from rztdl.utils.exceptions import DatasetHandlerException, DatasetException
from rztdl.utils.py_utils import RejectingDict


class DatasetHandler:
    def __init__(self, name="DatasetHandler"):
        self.name = name
        self.datasets = RejectingDict(prefix="Dataset ")
        self.handle = None
        self.dataset_split_buffer_name_mapping = defaultdict(set)

    @typechecked
    def add(self, dataset: Dataset):
        """
        Adds a Dataset to the DatasetHandler.

        :param dataset: Dataset instance to register
        :return:
        """
        self.datasets[dataset.name] = dataset
        self.check_split_buffer_duplicate(dataset=dataset)
    def check_split_buffer_duplicate(self, dataset):
        """
        Checks whether any two datasets have the same buffers for the same split.

        :param dataset: Dataset whose buffers are checked against previously added datasets
        :return:
        """
        for split_name, buffer_names in dataset.split_buffer_name_mapping.items():
            split_buffer_intersection = self.dataset_split_buffer_name_mapping[split_name].intersection(buffer_names)
            if split_buffer_intersection:
                raise DatasetException(dataset_name=dataset.name,
                                       message=f"Buffer {','.join(split_buffer_intersection)} for split"
                                               f" : '{split_name}' already present in another "
                                               f"dataset")
            self.dataset_split_buffer_name_mapping[split_name].update(set(buffer_names))
    @typechecked
    def initialize(self, split_required_buffer_mapping: Dict[str, List[str]], batch_size: int,
                   as_dict: bool = True, split_handler: TrainSplitHandler = None):
        """
        :param split_handler: Split handler describing the splits (None for inference)
        :param batch_size: Batch size of data to be read
        :param split_required_buffer_mapping: Input layer (buffer) names to read data for, keyed by split name
        :param as_dict: Indicates whether the output should be a dictionary or a list
        :return:
        """
        self.dataset_split_buffer_name_mapping = defaultdict(set)
        all_required_buffers = list(itertools.chain(*split_required_buffer_mapping.values()))
        all_col_names = []
        split_given_column_names_mapping = defaultdict(set)
        # Prepare each dataset and re-check for buffer collisions across datasets
        for dataset in self.datasets.values():
            dataset.prepare_dataset(required_inputs=split_required_buffer_mapping, split_handler=split_handler)
            self.check_split_buffer_duplicate(dataset=dataset)
        # Collect the buffers each split actually provides, and all required column names
        for dataset in self.datasets.values():
            for split_name, required_buffers in split_required_buffer_mapping.items():
                split_given_column_names_mapping[split_name].update(
                    [i for i in dataset.split_buffer_name_mapping.get(split_name, []) if i in required_buffers])
            all_col_names.extend([i for i in dataset.buffer_names if i in all_required_buffers])
        all_col_names = sorted(all_col_names)

        def map_func(*t):
            return list(itertools.chain(*t))

        def dict_map_func(*t):
            return {buffer: data for buffer, data in zip(all_col_names, itertools.chain(*t))}

        if split_handler:  # For train
            for split_name, required_buffers in split_required_buffer_mapping.items():
                diff = set(required_buffers).difference(split_given_column_names_mapping[split_name])
                if diff:
                    raise DatasetHandlerException(name=self.name,
                                                  message=f"Missing buffers {','.join(diff)}"
                                                          f" for split {split_name} in datasets")
            self.handle = {}
            for split_name, split in split_handler.splits.items():
                split_datasets = tuple([i.tf_dataset[split_name] for i in self.datasets.values()
                                        if i.tf_dataset and split_name in i.tf_dataset])
                # if len(split_datasets) < 1:
                #     raise DatasetHandlerException(name=self.name, message=f"No dataset for split {split_name} found")
                self.handle[split_name] = tf.data.Dataset.zip(split_datasets).map(
                    dict_map_func if as_dict else map_func).batch(batch_size)
            return self

        # No split handler: inference path
        diff = set(all_required_buffers).difference(set(all_col_names))
        if diff:
            raise DatasetHandlerException(name=self.name,
                                          message=f"Missing buffers {','.join(diff)} in datasets")
        self.handle = tf.data.Dataset.zip(tuple([i.tf_dataset for i in self.datasets.values() if i.tf_dataset])).map(
            dict_map_func if as_dict else map_func).batch(batch_size)
        return self
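

# ---------------------------------------------------------------------------
# Minimal usage sketch (illustration only, not part of the original module).
# It assumes `feature_dataset` and `label_dataset` are already-constructed
# rztdl Dataset instances exposing "features" and "labels" buffers; the real
# Dataset constructor lives in rztdl.dl.dataset.dataset, and split names are
# whatever the configured TrainSplitHandler defines.
#
#   handler = DatasetHandler(name="example_handler")
#   handler.add(dataset=feature_dataset)
#   handler.add(dataset=label_dataset)
#
#   # Inference path (no split handler): `handler.handle` is a single
#   # batched tf.data.Dataset.
#   handler.initialize(
#       split_required_buffer_mapping={"inference": ["features", "labels"]},
#       batch_size=32,
#       as_dict=True,
#   )
#   for batch in handler.handle:
#       ...  # with as_dict=True each batch is a dict {buffer_name: tensor}
#
#   # Training path: pass split_handler=some_train_split_handler; then
#   # `handler.handle` is a dict of tf.data.Dataset objects keyed by split name.
# ---------------------------------------------------------------------------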