Source code for rztdl.dl.dataset.splits

from rztdl.dl.constants.string_constants import IntervalType
from rztdl.dl.model import RZTModel
from typeguard import typechecked
from typing import Union, List

from rztdl.utils.exceptions import DatasetSplitException, SplitHandlerException
from rztdl.utils.py_utils import RejectingDict


class DataSplit:
    """
    Base configuration of a dataset split: name, ratio, metric names and log frequency.
    """

    @typechecked
    def __init__(self, name: str, split_ratio: int, metrics: List[str], log_frequency: Union[int, None] = 1):
        """
        Store the split configuration and validate it.

        :param name: Name of split
        :param split_ratio: Ratio of this split; must satisfy 0 < split_ratio <= 100
        :param metrics: Names of the metrics attached to this split; at least one is required
        :param log_frequency: Either None or an integer >= 0
        :raises DatasetSplitException: if ``metrics`` is empty, ``log_frequency`` is negative,
            or ``split_ratio`` is outside (0, 100]
        """
        self.name = name
        self.split_ratio = split_ratio
        self.metrics = metrics
        # Placeholder value; nothing in this class replaces it.
        self.total_batches = "?"
        self.log_frequency = log_frequency

        if not metrics:
            raise DatasetSplitException(split_name=self.name, message="At least one metric should be added")
        if self.log_frequency is not None and self.log_frequency < 0:
            raise DatasetSplitException(
                split_name=self.name,
                message=f"Log frequency should be either None or >=0 : Given {self.log_frequency}")
        if self.split_ratio > 100 or self.split_ratio <= 0:
            raise DatasetSplitException(
                split_name=self.name,
                message=f"Split ratio should be between 0 and 100 Given {self.split_ratio}")
# TODO : Himaprasoon : Make sure you reset metrics for train split also
class TrainDataSplit(DataSplit):
    """
    Train split. Keeps the model components that correspond to its metric names.
    """

    @typechecked
    def __init__(self, name: str, split_ratio: int, metrics: List[str], log_frequency: Union[int, None] = None):
        """
        :param name: Name of split
        :param split_ratio: Ratio of this split (validated by ``DataSplit``)
        :param metrics: Names of the metrics attached to this split
        :param log_frequency: Either None or an integer >= 0
        """
        super().__init__(
            name=name,
            split_ratio=split_ratio,
            metrics=metrics,
            log_frequency=log_frequency,
        )
        # Filled in by initialize(); reset_metrics() iterates over it.
        self.metric_components = None

    def initialize(self, model: RZTModel):
        """
        Look up every metric name on the model and keep the returned components.

        :param model: Model whose ``get_component`` is called once per metric name
        """
        components = []
        for metric_name in self.metrics:
            components.append(model.get_component(metric_name))
        self.metric_components = components

    def reset_metrics(self):
        """
        Call ``reset_states()`` on every stored metric component.
        """
        for component in self.metric_components:
            component.reset_states()
class ValidationDataSplit(DataSplit):
    """
    Split for Validation
    """

    @typechecked
    def __init__(self, name: str, split_ratio: int, metrics: List[str], interval, frequency: int,
                 log_frequency: Union[int, None] = None):
        """
        :param name: Name of split
        :param split_ratio: Ratio of this split (validated by ``DataSplit``)
        :param metrics: Names of the metrics attached to this split
        :param interval: Must be ``IntervalType.BATCH_END`` or ``IntervalType.EPOCH_END``
        :param frequency: Integer >= 1 (presumably the number of intervals between
            validation runs -- confirm against the caller)
        :param log_frequency: Either None or an integer >= 0
        :raises DatasetSplitException: if the base validation fails, ``interval`` is not one
            of the two accepted values, or ``frequency`` is below 1
        """
        self.interval = interval
        self.frequency = frequency
        super().__init__(
            name=name,
            split_ratio=split_ratio,
            metrics=metrics,
            log_frequency=log_frequency,
        )
        # Both are filled in by initialize().
        self.metric_components = None
        self.run_function = None

        allowed_intervals = (IntervalType.BATCH_END, IntervalType.EPOCH_END)
        if self.interval not in allowed_intervals:
            raise DatasetSplitException(split_name=self.name,
                                        message="Interval Type should be either epoch or batch")
        if self.frequency < 1:
            raise DatasetSplitException(split_name=self.name,
                                        message=f"Frequency should be >0 : Given {self.frequency}")

    @typechecked
    def initialize(self, model: RZTModel):
        """
        Resolve the metric components and obtain the validation run function from the model.

        :param model: Model providing ``get_component`` and ``get_validation_model``
        """
        components = []
        for metric_name in self.metrics:
            components.append(model.get_component(metric_name))
        self.metric_components = components
        self.run_function = model.get_validation_model(component_or_output_names=self.metrics)

    def reset_metrics(self):
        """
        Call ``reset_states()`` on every stored metric component.
        """
        for component in self.metric_components:
            component.reset_states()
class TestDataSplit(DataSplit):
    """
    Test split. Keeps its metric components and a run function obtained from the model.
    """

    @typechecked
    def __init__(self, name: str, split_ratio: int, metrics: List[str], log_frequency: Union[int, None] = None):
        """
        :param name: Name of split
        :param split_ratio: Ratio of this split (validated by ``DataSplit``)
        :param metrics: Names of the metrics attached to this split
        :param log_frequency: Either None or an integer >= 0
        """
        super().__init__(
            name=name,
            split_ratio=split_ratio,
            metrics=metrics,
            log_frequency=log_frequency,
        )
        # Both are filled in by initialize().
        self.metric_components = None
        self.run_function = None

    @typechecked
    def initialize(self, model: RZTModel):
        """
        Resolve the metric components and obtain the run function from the model.

        The run function comes from ``model.get_validation_model``, the same call the
        validation split uses.

        :param model: Model providing ``get_component`` and ``get_validation_model``
        """
        components = []
        for metric_name in self.metrics:
            components.append(model.get_component(metric_name))
        self.metric_components = components
        self.run_function = model.get_validation_model(component_or_output_names=self.metrics)

    def reset_metrics(self):
        """
        Call ``reset_states()`` on every stored metric component.
        """
        for component in self.metric_components:
            component.reset_states()
class TrainSplitHandler:
    """
    Registry of ``DataSplit`` objects.

    Enforces unique split names, at most one ``TrainDataSplit`` while adding, and on
    ``validate()`` / ``close()`` that the split ratios sum to 100 and a train split exists.
    """

    def __init__(self, name: str = "TrainSplitHandler"):
        """
        :param name: Handler name, passed to ``SplitHandlerException`` when one is raised
        """
        self.name = name
        self.splits = RejectingDict(prefix="Split")
        self.closed = False
        # Union of the metric names of all added splits.
        self.metrics = set()
        # None (falsy) until a TrainDataSplit is added, then True.
        self.train_split_added = None
        self.ratio_sum = 0

    def __iter__(self):
        """
        Iterate over the registered split objects.
        """
        return iter(self.splits.values())

    def get_split(self, sp_name: str):
        """
        Return the split registered under ``sp_name``.

        :param sp_name: Name of the split to look up
        """
        return self.splits[sp_name]

    @typechecked
    def add(self, split: DataSplit):
        """
        Register a split and accumulate its ratio and metric names.

        :param split: Split to register
        :raises SplitHandlerException: if the handler is closed or a second
            ``TrainDataSplit`` is added
        :raises DatasetSplitException: if a split with the same name is already registered
        """
        self.check_closed()
        if split.name in self.splits:
            # TODO exception handling
            raise DatasetSplitException(split_name=split.name, message="Duplicate name")
        if isinstance(split, TrainDataSplit):
            if self.train_split_added:
                raise SplitHandlerException(name=self.name, message="Cannot add two TrainSplits")
            self.train_split_added = True
        self.ratio_sum += split.split_ratio
        self.splits[split.name] = split
        self.metrics.update(split.metrics)

    def validate(self):
        """
        Check the registered splits as a whole.

        :raises SplitHandlerException: if the split ratios do not sum to 100, or no
            ``TrainDataSplit`` has been added
        """
        if self.ratio_sum != 100:
            raise SplitHandlerException(name=self.name,
                                        message=f"Sum of split ratios {self.ratio_sum} != 100")
        if not self.train_split_added:
            raise SplitHandlerException(
                name=self.name,
                message=f"At least one train split must be added to {self.__class__.__name__}")

    def check_closed(self):
        """
        Refuse further modification once the handler is closed.

        :raises SplitHandlerException: if ``close()`` has already succeeded
        """
        if self.closed:
            raise SplitHandlerException(name=self.name, message="Split Handler Closed")

    def close(self):
        """
        Validate the registered splits and mark the handler closed.

        The handler stays open if validation raises.
        """
        self.validate()
        self.closed = True