# Source code for rztdl.dl.groups

# -*- coding: utf-8 -*-
"""
@created on: 01/20/20,
@author: Prathyush SP,
@version: v0.0.1

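Description: Defines Group, a reusable collection of components that are validated lazily on add() and
instantiated per named instance, with component names prefixed by the instance name.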

Sphinx Documentation Status:

"""

import logging
import typing
from copy import deepcopy

from typeguard import typechecked

from rztdl.dl.components.component import RZTComponent
from rztdl.dl.components.layers import Input
from rztdl.dl.components.layers.layer import Layer
from rztdl.dl.components.optimizers import Optimizer
from rztdl.dl.constants import TENSOR_OR_STR
from rztdl.dl.model import RZTModel
from rztdl.utils.exceptions import ComponentException

logger = logging.getLogger(__name__)


class Group(RZTModel):
    """
    Group Class

    A reusable collection of components. Components are validated and stored lazily by :meth:`add`;
    after :meth:`close`, every named instance receives its own deep copy of the components, with
    component/output names prefixed by ``<instance_name><delimiter>``.
    """

    @typechecked
    def __init__(self, name: str, inputs: typing.List[TENSOR_OR_STR]):
        """
        :param name: Name of the group
        :param inputs: Inputs for the group. When more than one input is given, the first component
            added to the group must specify its input explicitly.
        :raises ComponentException: If ``inputs`` is empty
        """
        super().__init__(name=name)
        self.component_list = []
        self.instances = {}
        self._delimiter = '_'
        self._is_finalized = False
        # Backup of the original group inputs; ``self.inputs`` grows as components are added
        self._group_inputs = deepcopy(inputs)
        # Own copy, so that add() never appends to the list passed in by the caller
        self.inputs = list(inputs)
        if not self.inputs:
            raise ComponentException(component_name=self.name, message=f"Group {self.name} inputs cannot be empty")
        # If inputs are > 1, to avoid ambiguity, the first component in the group needs to specify the component input
        self._mandatory_input_for_first_component = len(self.inputs) > 1

    @typechecked
    def add(self, component: RZTComponent):
        """
        Add components to the group - Lazy addition (Components are validated and are stored in a list.
        Components are created per instance at compile time)

        :param component: RZTComponent
        :return: Informational message (group components are lazily evaluated)
        :raises ComponentException: If the group is finalized, the component is an Optimizer or Input,
            the first component of a multi-input group has no inputs, or a referenced input is neither a
            group input nor a name registered by a previously added component
        """
        if self._is_finalized:
            raise ComponentException(component_name=self.name,
                                     message=f"Cannot add component. Group: {self.name} is finalized")
        # Optimizer and Input types are not supported
        if isinstance(component, Optimizer):
            raise ComponentException(component_name=self.name,
                                     message=f"Component {component.name} is of type Optimizer and is not supported in groups")
        if isinstance(component, Input):
            raise ComponentException(component_name=self.name,
                                     message=f"Component {component.name} is of type Input and is not supported in groups")
        # Test for mandatory input
        if self._mandatory_input_for_first_component and not component.inputs:
            raise ComponentException(component_name=self.name,
                                     message=f"Group {self.name} has multiple inputs. Input for the first group component is mandatory")
        # Check that the specified inputs are group inputs or names registered by earlier group components
        if component.inputs:
            for val in component.inputs.values():
                if isinstance(val, str):
                    if val not in self.inputs:
                        raise ComponentException(component_name=self.name, message=f"Specified input {val} not found.")
                elif isinstance(val, list):  # For inputs is a list eg : concat
                    for input_name in val:
                        if input_name not in self.inputs:
                            raise ComponentException(component_name=self.name,
                                                     message=f"Specified input {input_name} not found.")
        # Cleared only after validation succeeded, so a rejected first component does not waive the requirement
        self._mandatory_input_for_first_component = False
        # Lazily add component to component_list
        self.component_list.append(component)
        # Register the component name and its (truthy) output names as valid inputs for later components
        self.inputs.append(component.name)
        outputs = component.outputs
        if isinstance(outputs, dict):
            self.inputs.extend(val for val in outputs.values() if val)
        elif isinstance(outputs, list):
            self.inputs.extend(val for val in outputs if val)
        return "Group components are lazily evaluated. Use group instance and add the group to the model"

    def _create_instance(self, instance_name: str):
        """
        Create a new instance: a deep copy of every group component in which the component name, and
        every input/output name that is not a group input, is prefixed with the instance name.

        :param instance_name: Name for the instance
        :return: New Instance (New list of components)
        """
        # A single-input group lets its first Layer omit the input, so wire it to that input here.
        # With several group inputs the first component had to name its input explicitly (see add),
        # so it must not be overwritten.
        if len(self._group_inputs) == 1 and self.component_list and isinstance(self.component_list[0], Layer):
            self.component_list[0].inputs['inputs'] = self._group_inputs[0]
        new_component_instances = []
        for component in self.component_list:
            # Copy from original component
            new_component = deepcopy(component)
            # Modify name to support multiple / sharable groups
            new_component._name = self._join_instance_name(instance_name, new_component._name)
            # Replace raw names with instance based names for inputs (group inputs stay as-is)
            for k, v in new_component.inputs.items():
                if isinstance(v, str):
                    if v not in self._group_inputs:
                        new_component.inputs[k] = self._join_instance_name(instance_name, v)
                elif isinstance(v, list):
                    for e, ip in enumerate(v):
                        if isinstance(ip, str) and ip not in self._group_inputs:
                            v[e] = self._join_instance_name(instance_name, ip)
            # Replace raw names with instance based names for outputs (non-string entries such as None are skipped)
            if isinstance(new_component.outputs, dict):
                for k, v in new_component.outputs.items():
                    if isinstance(v, str):
                        new_component.outputs[k] = self._join_instance_name(instance_name, v)
            elif isinstance(new_component.outputs, list):
                for e, op in enumerate(new_component.outputs):
                    if isinstance(op, str):
                        new_component.outputs[e] = self._join_instance_name(instance_name, op)
            new_component_instances.append(new_component)
        return new_component_instances

    def _join_instance_name(self, instance_name: str, component_name: str):
        """
        Append instance name to the component name

        :param instance_name: Name of the Instance
        :param component_name: Name of the component
        :return: Updated component name (``<instance_name><delimiter><component_name>``)
        """
        return instance_name + self._delimiter + component_name

    def _remove_instance_name(self, instance_name: str, complete_name: str):
        """
        Remove the instance prefix from the component name - the inverse of _join_instance_name

        Only a leading ``<instance_name><delimiter>`` is removed; occurrences elsewhere in the name
        are part of the original component name and are kept.

        :param instance_name: Name of the Instance
        :param complete_name: Name of the Component
        :return: Updated component name (unchanged if it does not carry the prefix)
        """
        prefix = instance_name + self._delimiter
        if complete_name.startswith(prefix):
            return complete_name[len(prefix):]
        return complete_name

    def _collect_group_output(self, instance_name: str, component, complete_name: str, group_output: dict,
                              components_to_return: dict):
        """
        Record the tensor registered under ``complete_name`` if the group instance requested it.

        :param instance_name: Name of the Instance
        :param component: Component that produced ``complete_name``
        :param complete_name: Instance-prefixed component or output name
        :param group_output: Requested outputs, keyed by un-prefixed name
        :param components_to_return: Result dictionary, updated in place
        """
        default_component_name = self._remove_instance_name(instance_name, complete_name)
        if default_component_name in group_output:
            alias = group_output[default_component_name]
            components_to_return[alias] = self.get_component(complete_name).tensor_output
            self.components[alias] = component

    def _compile(self, instance_name: str, component_list: list, group_input: dict, group_output: dict):
        """
        Compile the sub model by passing inputs

        :param instance_name: Name of the Instance
        :param component_list: Components of the instance (see _get_or_create_instance)
        :param group_input: Inputs for the sub model, keyed by the input names used in component inputs
            (values are expected to expose ``.name`` - presumably tensors; confirm with the Model API)
        :param group_output: Outputs from the sub model, mapping un-prefixed name -> returned name
        :return: Dictionary of component {str:Tensor}
        """
        # Clear dag and model
        self.clear()
        # Dependency on Model API
        # NOTE: list-valued group inputs are not registered here; no use case was found for them.
        for component_tensor in group_input.values():
            self.tensor_mapping[component_tensor.name] = component_tensor.name
        components_to_return = {}
        # Set first component in the group as input. Inputs are not supported by Group (Model API dependency)
        if component_list:
            component_list[0].treat_component_as_input = True
        for component in component_list:
            # Input mapping: replace names of group inputs with the tensors fed into the group
            for k, v in component.inputs.items():
                if isinstance(v, str):
                    if v in group_input:
                        component.inputs[k] = group_input[v]
                elif isinstance(v, list):
                    for e, inp in enumerate(v):
                        if isinstance(inp, str) and inp in group_input:
                            v[e] = group_input[inp]
            # Use model api to add component and carry out required ops (Heavy lifting is done by Model API)
            super().add(component)
            # Return the outputs requested by the group instance: the component itself, then its named outputs
            self._collect_group_output(instance_name, component, component.name, group_output, components_to_return)
            if isinstance(component.outputs, dict):
                output_names = list(component.outputs.values())
            elif isinstance(component.outputs, list):
                output_names = component.outputs
            else:
                output_names = []
            for output_name in output_names:
                if isinstance(output_name, str):
                    self._collect_group_output(instance_name, component, output_name, group_output,
                                               components_to_return)
        return components_to_return

    def _get_or_create_instance(self, instance_name: str):
        """
        Get or Create Instance

        :param instance_name: Name of the Instance
        :return: List of components for the given instance (an existing instance is reused, which makes
            instances sharable)
        :raises ComponentException: If the group has not been finalized with close()
        """
        if not self._is_finalized:
            raise ComponentException(component_name=self.name,
                                     message=f"Group {self.name} is not finalized. Call group.close() before instantiation.")
        if instance_name not in self.instances:
            self.instances[instance_name] = self._create_instance(instance_name)
        return self.instances[instance_name]

    @typechecked
    def _get_usage_of_input(self, instance_name: str):
        """
        Get usage of Inputs - Used in Model API to detect the input mapping to create links between model and group dags

        :param instance_name: Name of the Instance
        :return: List of (group input name, instance-prefixed component name) tuples
        """
        usage_of_inputs = []
        for component in self.component_list:
            # NOTE: only string inputs are reported; list inputs that reference group inputs are not
            # (the original author found no use case for them).
            for v in component.inputs.values():
                if isinstance(v, str) and v in self._group_inputs:
                    usage_of_inputs.append((v, self._join_instance_name(instance_name, component.name)))
        return usage_of_inputs

    def close(self):
        """
        Finalize Group - no components can be added afterwards, and instances can be created

        :return: None
        """
        self._is_finalized = True