# Source code for rztdl.dl.dataset.generator_dataset

# -*- coding: utf-8 -*-
"""
@created on: 1/8/20,
@author: Himaprasoon,
@version: v0.0.1

Description:

Sphinx Documentation Status:

"""
import itertools
import types
import typing

import tensorflow as tf

from rztdl.dl.dataset.dataset import Dataset
from rztdl.utils.exceptions import DatasetHandlerException, DatasetException
from logging import getLogger

logger = getLogger(__name__)


class GeneratorDataset(Dataset):
    """Dataset fed by user-supplied Python generator function(s).

    ``gen_function`` is either a single callable (shared across splits /
    inference) or a dict mapping split names to callables. Each callable,
    when invoked, must yield dict-like records indexed by buffer name
    (a missing buffer key raises ``KeyError`` during iteration).
    """

    def __init__(self, name: str, buffer_names: typing.List[str],
                 gen_function: typing.Union[types.FunctionType,
                                            typing.Dict[str, types.FunctionType]]):
        """
        :param name: Dataset name.
        :param buffer_names: Names of the buffers (record keys) this dataset exposes.
        :param gen_function: Generator callable, or dict of split name -> callable.
        """
        super().__init__(buffer_names=buffer_names, name=name)
        self.gen_function = gen_function
        if not isinstance(self.gen_function, dict):
            # A single shared generator is allowed but discouraged in production.
            logger.warning("gen_function should be a dict : Not recommended in production ")
        else:
            self.set_split_buffer_name_mapping(self.gen_function.keys())
        # Built lazily by prepare_dataset; dict of split name -> tf.data.Dataset
        # in training mode, a single tf.data.Dataset in inference mode.
        self.tf_dataset = None

    def prepare_dataset(self, required_inputs: dict, split_handler):
        """Build ``tf.data.Dataset`` object(s) from the generator function(s).

        With a truthy ``split_handler`` one dataset is built per split
        (``self.tf_dataset`` becomes a dict keyed by split name); without one
        (inference mode) a single dataset is built over all required buffers.

        :param required_inputs: Mapping of split name -> required buffer names.
        :param split_handler: Object exposing ``splits`` (dict keyed by split
            name), or a falsy value for inference mode.
        :raises DatasetException: if ``gen_function`` is a dict in inference mode.
        :return: ``self`` (fluent style).
        """

        def get_require_col(func, req_buffers):
            # Wrap the user generator so it yields tuples containing only the
            # requested buffers, in ``req_buffers`` order.
            def _get_required_col():
                for record in func():
                    try:
                        yield tuple(record[col] for col in req_buffers)
                    except KeyError as e:
                        # Chain the original error so which record failed is
                        # preserved in the traceback.
                        raise KeyError(
                            f"key {str(e)} not found in "
                            f"{self.__class__.__name__} {self.name}") from e

            return _get_required_col

        self.tf_dataset = {}
        if split_handler:
            if not isinstance(self.gen_function, dict):
                # Single callable: reuse it for every split.
                self.set_split_buffer_name_mapping(split_handler.splits.keys())
                self.gen_function = {split_name: self.gen_function
                                    for split_name in split_handler.splits}
            # Only the split names are needed here, not the split objects.
            for split_name in split_handler.splits:
                # Keep buffer order stable by filtering self.buffer_names.
                req_buffer = [i for i in self.buffer_names
                              if i in required_inputs[split_name]]
                if split_name in self.gen_function:
                    generator_function = get_require_col(
                        func=self.gen_function[split_name], req_buffers=req_buffer)
                    # NOTE(review): output_types assumes every buffer is float32
                    # — confirm against the generator's actual record values.
                    self.tf_dataset[split_name] = tf.data.Dataset.from_generator(
                        generator_function,
                        output_types=(tf.float32,) * len(req_buffer))
        else:
            # Inference: one generator over the union of all required buffers.
            all_required_buffers = list(itertools.chain(*required_inputs.values()))
            all_required_buffers = [i for i in self.buffer_names
                                    if i in all_required_buffers]
            if not all_required_buffers:
                # Nothing from this dataset is needed; leave tf_dataset empty.
                return self
            if isinstance(self.gen_function, dict):
                raise DatasetException(dataset_name=self.name,
                                       message="gen function cannot be dict in Inference")
            self.tf_dataset = tf.data.Dataset.from_generator(
                get_require_col(func=self.gen_function,
                                req_buffers=all_required_buffers),
                output_types=(tf.float32,) * len(all_required_buffers))
        return self