Source code for rztdl.dl.components.layers.dropout.spatial_dropout2d

# -*- coding: utf-8 -*-
"""
| *@created on:* 2020-01-30,
| *@author:* shubham,
| *@version:* v0.0.1
|
| *Description:*
| Spatial 2D dropout layer component built on ``tf.keras.layers.SpatialDropout2D``.
| *Sphinx Documentation Status:* Complete
|
"""
import tensorflow as tf
from typeguard import typechecked

from rztdl.dl.components.layers.layer import Layer
from rztdl.dl.constants import TENSOR_OR_STR
from rztdl.dl.constants.string_constants import DataFormat
from rztdl.utils.exceptions import ComponentException, DimensionError
from rztdl.utils.py_utils import raise_component_exception


class SpatialDropout2D(tf.keras.layers.SpatialDropout2D, Layer):
    """
    Spatial 2D version of Dropout.

    Combines ``tf.keras.layers.SpatialDropout2D`` with the rztdl ``Layer`` component base class.
    """

    @raise_component_exception
    @typechecked
    def __init__(self, name: str, rate: float, data_format: DataFormat = DataFormat.CHANNELS_LAST,
                 inputs: TENSOR_OR_STR = None, outputs: str = None):
        """
        :param name: Name of component
        :param rate: drop probability (as with Dropout). Must lie in the closed range [0, 1].
            The multiplicative noise will have standard deviation sqrt(rate / (1 - rate))
        :param data_format: 'CHANNELS_FIRST' / 'CHANNELS_LAST'
            In 'CHANNELS_FIRST' mode, the channels dimension (the depth) is at index 1.
            In 'CHANNELS_LAST' mode it is at index 3.
        :param inputs: Input component/tensor
        :param outputs: Output name
        :raises ComponentException: if ``rate`` is outside [0, 1] (or is NaN)
        """
        # Validate first so an invalid rate surfaces as a ComponentException
        # before the Keras layer is initialised.
        self.parameter_validation(name=name, rate=rate)
        # Keras takes data_format as a lowercase string ('channels_first' / 'channels_last'),
        # derived here from the enum member name.
        tf.keras.layers.SpatialDropout2D.__init__(self, rate=rate, data_format=data_format.name.lower(), name=name)
        Layer.__init__(self, inputs=inputs, name=name, outputs=outputs)

    def create(self, inputs):
        """
        Apply this layer to ``inputs``.

        :param inputs: Input tensor
        :return: The layer output, which is also stored on ``self.tensor_output``
        """
        self.tensor_output = self(inputs=inputs)
        return self.tensor_output

    def validate(self, inputs):
        """
        Check that ``inputs`` has exactly 4 dimensions.

        NOTE(review): presumably (batch, rows, cols, channels) or (batch, channels, rows, cols)
        depending on ``data_format`` — only the rank is checked here.

        :param inputs: Input tensor
        :raises DimensionError: if ``inputs`` is not 4-dimensional
        """
        if len(inputs.shape) != 4:
            raise DimensionError(entity_name=self.name, message=f'SpatialDropout2D takes 4 dimensional input. '
                                                                f'Given {inputs.shape}')

    def parameter_validation(self, rate, name):
        """
        Validate the dropout rate.

        The accepted range is the closed interval [0, 1]. Written as a chained comparison
        so that NaN (for which every comparison is False) is rejected as well.

        NOTE(review): rate == 1 passes this check; confirm the TensorFlow version in use
        accepts it at call time.

        :param rate: Drop probability
        :param name: Component name, used in the error report
        :raises ComponentException: if ``rate`` is outside [0, 1] or is NaN
        """
        if not 0 <= rate <= 1:
            raise ComponentException(component_name=name,
                                     message=f"Rate for SpatialDropout2D should be in range of [0, 1]. Given {rate}")