# Source code for rztdl.dl.components.layers.dropout.spatial_dropout1d

# -*- coding: utf-8 -*-
"""
| *@created on:* 2020-01-30,
| *@author:* shubham,
| *@version:* v0.0.1
|
| *Description:*
| rztdl layer component wrapping ``tf.keras.layers.SpatialDropout1D``.
| *Sphinx Documentation Status:* Complete
|
"""
import tensorflow as tf
from typeguard import typechecked

from rztdl.dl.components.layers.layer import Layer
from rztdl.dl.constants import TENSOR_OR_STR
from rztdl.utils.exceptions import ComponentException, DimensionError
from rztdl.utils.py_utils import raise_component_exception


class SpatialDropout1D(tf.keras.layers.SpatialDropout1D, Layer):
    """
    Spatial 1D version of Dropout.

    Combines ``tf.keras.layers.SpatialDropout1D`` (which performs the actual dropout) with the rztdl
    ``Layer`` base, which is initialised with the component's ``inputs``, ``name`` and ``outputs``.
    """

    @raise_component_exception
    @typechecked
    def __init__(self, name: str, rate: float, inputs: TENSOR_OR_STR = None, outputs: str = None):
        """
        :param name: Name of component
        :param rate: Drop probability (as with Dropout), forwarded unchanged to
            ``tf.keras.layers.SpatialDropout1D``. Must satisfy ``0 <= rate < 1``.
        :param inputs: Input component/tensor
        :param outputs: Output name
        :raises ValueError: if ``rate`` is outside ``[0, 1)``. NOTE(review): the exception passes through
            ``raise_component_exception``, which may re-wrap it; confirm against that decorator.
        """
        # Validate first so a bad rate fails fast, before either base class is initialised.
        self.parameter_validation(rate=rate, name=name)
        tf.keras.layers.SpatialDropout1D.__init__(self, rate=rate, name=name)
        Layer.__init__(self, inputs=inputs, name=name, outputs=outputs)

    def create(self, inputs):
        """
        Apply the layer to ``inputs`` and remember the result.

        :param inputs: Input tensor
        :return: The output tensor, which is also stored on ``self.tensor_output``
        """
        self.tensor_output = self(inputs=inputs)
        return self.tensor_output

    def validate(self, inputs):
        """
        Check that ``inputs`` has exactly three dimensions.

        Keras documents the expected layout as ``(samples, timesteps, channels)``.

        :param inputs: Input tensor
        :raises DimensionError: if ``inputs`` is not 3-dimensional
        """
        if len(inputs.shape) != 3:
            raise DimensionError(entity_name=self.name,
                                 message=f'SpatialDropout1D takes 3 dimensional input. '
                                         f'Given {inputs.shape}')

    def parameter_validation(self, rate, name):
        """
        Validate constructor arguments.

        :param rate: Drop probability; must satisfy ``0 <= rate < 1``
        :param name: Component name (not used by the current checks; kept for interface compatibility)
        :raises ValueError: if ``rate`` is outside ``[0, 1)``, including NaN
        """
        # tf.nn.dropout rejects rate >= 1 at call time, so reject it here instead of mid-training.
        # The chained comparison is False for NaN, which the old `rate < 0 or rate > 1` let through.
        # ValueError subclasses Exception, so callers catching the old generic Exception still work.
        if not 0 <= rate < 1:
            raise ValueError(f"Rate for SpatialDropout1D should be in range of [0,1). Given {rate}")