Source code for rztdl.dl.components.layers.dropout.gaussian_dropout

# -*- coding: utf-8 -*-
"""
| *@created on:* 2020-01-30,
| *@author:* shubham,
| *@version:* v0.0.1
|
| *Description:*
| 
| *Sphinx Documentation Status:* Complete
|
"""
import tensorflow as tf
from typeguard import typechecked

from rztdl.dl.components.layers.layer import Layer
from rztdl.dl.constants import TENSOR_OR_STR
from rztdl.utils.exceptions import ComponentException
from rztdl.utils.py_utils import raise_component_exception


class GaussianDropout(tf.keras.layers.GaussianDropout, Layer):
    """
    Apply multiplicative 1-centered Gaussian noise
    """

    @raise_component_exception
    @typechecked
    def __init__(self, name: str, rate: float, inputs: TENSOR_OR_STR = None, outputs: str = None):
        """
        :param name: Name of component
        :param rate: drop probability (as with Dropout). The multiplicative noise will have standard deviation
                     sqrt(rate / (1 - rate))
        :param inputs: Input component/tensor
        :param outputs: Output name
        """
        # Runs before either base class is initialised, so it must not rely
        # on attributes (e.g. ``self.rate``) that the base initialisers set.
        self.parameter_validation()
        # Both bases are initialised explicitly, each with its own keyword set.
        tf.keras.layers.GaussianDropout.__init__(self, rate=rate, name=name)
        Layer.__init__(self, inputs=inputs, name=name, outputs=outputs)

    def create(self, inputs):
        """
        Apply this layer to ``inputs`` and remember the result.

        :param inputs: Input forwarded to the layer call
        :return: Output of the layer call, also stored on ``self.tensor_output``
        """
        self.tensor_output = self(inputs=inputs)
        return self.tensor_output

    def validate(self, inputs):
        """
        Check that the configured drop rate lies in the closed interval [0, 1].

        :param inputs: Unused by this check; kept for interface compatibility
        :raises ComponentException: if ``self.rate`` is below 0, above 1, or NaN
        """
        # The chained comparison is False for NaN, so a NaN rate is rejected
        # as well (two separate one-sided comparisons would let it through).
        if not 0 <= self.rate <= 1:
            raise ComponentException(
                component_name=self.name,
                message=f"Rate for GaussianDropout should be in range of [0, 1]. Given {self.rate}",
            )

    def parameter_validation(self):
        """
        Hook called at the start of ``__init__``; no checks are performed here.
        """
        pass