Source code for rztdl.dl.model

# -*- coding: utf-8 -*-
"""
@created on: 12/13/19,
@author: Himaprasoon,
@version: v0.0.1

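Description: Defines ``RZTModel``, which wires RZT components into a networkx DAG and builds Keras models from it.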

Sphinx Documentation Status:

"""
import logging
from typing import List, Union
import networkx as nx
import tensorflow as tf
from typeguard import typechecked
from rztdl.dl.components import RZTComponent
from rztdl.dl.components.layers import Input
from rztdl.dl.components.layers.layer import Layer
from rztdl.dl.components.metrics.metric import Metric
from rztdl.dl.components.optimizers import Optimizer
from rztdl.utils.exceptions import ComponentException
from rztdl.utils.py_utils import partition, validate_name, RejectingDict, raise_component_exception
from rztdl.dl.components.group_instance import GroupInstance

logger = logging.getLogger(__name__)


def check_closed(f):
    """
    Decorator for model methods that require a finalized model.

    The wrapped method raises ``Exception("Not finalized")`` unless its first positional
    argument (``self``) has a truthy ``finalized`` attribute (``RZTModel.close`` sets it).

    :param f: method to guard
    :return: the guarded method, keeping ``f``'s name and docstring
    """
    from functools import wraps

    # wraps keeps __name__/__doc__ of the decorated method (Sphinx, tracebacks, introspection)
    @wraps(f)
    def wrapper(*args, **kwargs):
        if not args[0].finalized:
            raise Exception("Not finalized")
        return f(*args, **kwargs)

    return wrapper
class RZTModel:
    """
    Model Class

    Keeps RZT components in a networkx ``MultiDiGraph`` (``self.dag``) whose edges follow the data flow
    between components, and builds ``tf.keras.Model`` objects from that graph.
    """

    @typechecked
    def __init__(self, name: str, dist=False):
        """
        :param name: model name (validated with ``validate_name``)
        :param dist: if True, initialise horovod, enable memory growth on every visible GPU and make only
            the GPU at ``hvd.local_rank()`` visible to this process
        """
        self.name = validate_name(name)
        # component name -> component
        self.components = RejectingDict(prefix="Component")
        # component name / output name -> component that produces it
        self.component_outputs = RejectingDict(prefix="Output")
        self.dag = nx.MultiDiGraph()
        # Implicit input for components added without ``inputs`` (sequential style)
        self.previous_component = None
        self.finalized = False
        self.input_layers = []
        self.dist = dist
        self.scopes = []
        self.optimizers = []
        if self.dist:
            import horovod.tensorflow as hvd
            hvd.init()
            gpus = tf.config.experimental.list_physical_devices('GPU')
            for gpu in gpus:
                tf.config.experimental.set_memory_growth(gpu, True)
            if gpus:
                tf.config.experimental.set_visible_devices(gpus[hvd.local_rank()], 'GPU')
        # tensor name -> component that produced the tensor
        self.tensor_mapping = {}

    def _get_input_map(self, input_name, component, input_key):
        """
        Resolve one input reference of ``component`` to a tensor and record the matching DAG edge.

        :param input_name: a component name (or component), an output name, or a tensor produced by a
            component already in the model
        :param component: the component consuming the input
        :param input_key: key of ``component.inputs`` this reference belongs to
        :return: the tensor to feed into ``component``
        :raises Exception: if a component name refers to a component with several outputs
        :raises KeyError: if the reference is neither a known component, output nor tensor
        """
        try:
            # input is a component name
            inp_component = self.get_component_by_name(input_name)
            if len(inp_component.outputs) > 1:
                raise Exception(
                    f"input {input_name} is a component with multiple outputs. Use outputs instead of component name")
            self.dag.add_edge(inp_component, component, source='output', target=input_key)
            return inp_component.tensor_output
        except KeyError:
            # input is a component output name
            inp_component = self.get_component_from_output(input_name)
            # NOTE(review): source/target look swapped relative to the other branches; confirm against
            # whatever reads these edge attributes before changing
            self.dag.add_edge(inp_component, component, source=input_key, target=input_name)
            return inp_component.get_tensor_output_by_name(output_name=input_name)
        except TypeError:
            # Reached when input_name is neither str nor RZTComponent (i.e. a tensor): relies on
            # @typechecked raising TypeError, as typeguard 2.x does (3.x raises TypeCheckError) - confirm
            inp_component = self.get_component_from_tensor(input_name)
            self.dag.add_edge(inp_component, component, source='output', target=input_key)
            return input_name

    @raise_component_exception
    @typechecked
    def add(self, component: RZTComponent):
        """
        Adds a component to the graph.

        Validates and creates the component, wires it into the DAG, merges group instances, registers
        its tensors/outputs, applies normalization/dropout for layers and records scopes/optimizers.

        :param component: component to add
        :return: the component's tensor output (after normalization/dropout for layers)
        """
        self.dag.add_node(component)
        self.components[component.name] = component
        self.component_outputs[component.name] = component

        # Input layers, or components meant to be treated as input (first component in a Group)
        if isinstance(component, Input) or component.treat_component_as_input:
            self.input_layers.append(component)
            component.validate(component.inputs)
            logger.info(f"{component.__class__.__name__} {component.__class__.__bases__[-1].__name__}"
                        f" '{component.name}' validated successfully")
            if component.treat_component_as_input:
                component.create(component.inputs['inputs'])
            else:
                component.create()
        else:
            if not component.inputs:
                # Sequential style: no inputs given, so feed from the previously added component
                try:
                    component.inputs['inputs'] = self.previous_component.name
                except AttributeError as e:
                    raise ValueError("First component in a model has to be a Input Layer") from e
                if len(self.previous_component.outputs) > 1:
                    raise ComponentException(component_name=component.name,
                                             message="Previous component is a multi output layer. "
                                                     "Please provide inputs")
            input_mapping = {}
            for input_key, val in component.inputs.items():
                if isinstance(val, list):
                    # list-valued input, e.g. concat
                    input_mapping[input_key] = [
                        self._get_input_map(input_name=input_name, component=component, input_key=input_key)
                        for input_name in val]
                else:
                    # one reference per key (components with several input keys, e.g. mse)
                    input_mapping[input_key] = self._get_input_map(input_name=val, component=component,
                                                                   input_key=input_key)
            component.validate(**input_mapping)
            logger.info(f"{component.__class__.__name__} {component.__class__.__bases__[-1].__name__}"
                        f" '{component.name}' validated successfully")
            component.create(**input_mapping)

        if isinstance(component, GroupInstance):
            instance = component.instance
            # Copy entries one by one so RejectingDict is respected; a shared instance may already have
            # registered some of them, so only new names are added.
            for k, v in instance.components.items():
                k = component._change_component_name(old_name=k)
                if k not in self.components:
                    self.components[k] = v
            for k, v in instance.component_outputs.items():
                k = component._change_component_name(old_name=k)
                if k not in self.component_outputs:
                    self.component_outputs[k] = v
            self.tensor_mapping = {**self.tensor_mapping, **instance.tensor_mapping}
            self.dag = nx.compose(G=self.dag, H=instance.dag)
            # Link the producers of the group's inputs to the group components that consume them
            usage_of_inputs = instance._get_usage_of_input(instance_name=component.sharable or component.name)
            for inp, comp in usage_of_inputs:
                inp_ref = component.inputs[inp]
                if isinstance(inp_ref, tf.Tensor):
                    source_component = self.get_component_from_tensor(inp_ref)
                else:
                    source_component = self.get_component(inp_ref)
                self.dag.add_edge(source_component, self.get_component(comp), source='link', target=comp)

        # Map produced tensors back to this component so tensors can be passed as inputs later
        tensor_output = component.tensor_output
        if isinstance(tensor_output, tf.Tensor):
            self.tensor_mapping[tensor_output.name] = component
        elif isinstance(tensor_output, dict):
            for val in tensor_output.values():
                if isinstance(val, list):
                    for t in val:
                        # map to the component (previously mapped to the str t.name by mistake)
                        self.tensor_mapping[t.name] = component
                else:
                    self.tensor_mapping[val.name] = component
        elif isinstance(tensor_output, list):
            for val in tensor_output:
                self.tensor_mapping[val.name] = component

        if isinstance(component.outputs, dict):
            output_names = component.outputs.values()
        elif isinstance(component.outputs, list):
            output_names = component.outputs
        else:
            output_names = ()
        for output_name in output_names:
            if output_name:
                self.component_outputs[output_name] = component

        # Normalization / dropout for single-tensor layers
        if isinstance(component.tensor_output, tf.Tensor) and isinstance(component, Layer):
            component.tensor_output = component.apply_normalization()
            if component.dropout_rate:
                component.tensor_output = component.apply_dropout()
            if isinstance(component.tensor_output, tf.Tensor):
                # The returned (final) tensor must also resolve to this component when used as an input
                self.tensor_mapping[component.tensor_output.name] = component
        self.previous_component = component

        if isinstance(component, Layer) and component.scopes:
            self.scopes.extend(component.scopes)
        if isinstance(component, Optimizer):
            # Keep optimizers and check their trainable variables / scopes
            self.optimizers.append(component)
            component.check_optimizer_path_and_scopes(
                keras_model=self.get_validation_model([component.name]), model_scopes=self.scopes)
        return component.tensor_output

    @typechecked
    def get_component_from_tensor(self, tensor: tf.Tensor):
        """
        Check if tensor is part of graph and gets corresponding component

        :param tensor: tensor produced by a component of this model
        :return: the component that produced ``tensor``
        :raises ValueError: if ``tensor`` has no ``name`` attribute (eager tensor)
        :raises KeyError: if no component of this model produced ``tensor``
        """
        if not hasattr(tensor, "name"):
            raise ValueError(f"Eager tensor {tensor} not part of graph")
        try:
            return self.tensor_mapping[tensor.name]
        except KeyError as e:
            raise KeyError(f"Tensor {tensor.name} not found in graph") from e

    @typechecked
    def get_required_buffers(self, components: List[Union[str, RZTComponent]], get_names=False):
        """
        Return the Input layers needed to compute the given components

        :param components: components, component names or output names
        :param get_names: if True, return component names instead of components
        :return: Input layers that are the given components or their DAG ancestors, sorted by name
        """
        buffers = set()
        for component in components:
            if isinstance(component, str):
                component = self.get_component(component)
            if isinstance(component, Input):
                buffers.add(component)
                continue
            buffers.update(node for node in nx.ancestors(self.dag, component) if isinstance(node, Input))
        ordered = sorted(buffers, key=lambda x: x.name)
        if get_names:
            return [i.name for i in ordered]
        return ordered

    def get_complete_model(self, outputs=None):
        """
        Build a keras model over the whole graph

        :param outputs: extra output tensors; the tensor outputs of all leaf components (and, for
            optimizer leaves, of their inputs) are appended to a copy of it
        :return: ``tf.keras.Model`` from the required Input layers to those outputs
        """
        leaf_nodes = [x for x in self.dag.nodes() if self.dag.out_degree(x) == 0]
        optimizers, other_nodes = partition(lambda x: isinstance(x, Optimizer), leaf_nodes)
        for optimizer in optimizers:
            # Optimizers are leaves but produce no model output; expose their inputs instead
            optimizer_input = optimizer.inputs['inputs']
            if isinstance(optimizer_input, tf.Tensor):
                other_nodes.append(optimizer_input)
            else:
                other_nodes.append(self.get_component(optimizer_input))
        # Copy so a caller-supplied list is not mutated
        outputs = [] if outputs is None else list(outputs)
        for node in other_nodes:
            if isinstance(node, tf.Tensor):
                outputs.append(node)
            elif isinstance(node, str):
                continue
            elif isinstance(node.tensor_output, dict):
                outputs.extend(node.tensor_output.values())
            else:
                outputs.append(node.tensor_output)
        return tf.keras.Model(inputs=[i.tensor_output for i in self.get_required_buffers(self.input_layers)],
                              outputs=outputs)

    @check_closed
    @typechecked
    def write_graph_tensorboard(self, logdir: str):
        """
        Attach the complete model to a TensorBoard callback (model must be closed)

        :param logdir: TensorBoard log directory
        :return: the ``tf.keras.callbacks.TensorBoard`` callback
        """
        keras_model = self.get_complete_model()
        tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir)
        tensorboard_callback.set_model(keras_model)
        return tensorboard_callback

    @typechecked
    def get_validation_model(self, component_or_output_names: List[str], get_mapping=False):
        """
        Build a keras model whose outputs are the given components/outputs

        For an optimizer, its input (the tensor it optimizes) is used as the output.

        :param component_or_output_names: component names or output names
        :param get_mapping: if True, also return the output names in model-output order
        :return: ``tf.keras.Model``, or ``(tf.keras.Model, mapping_outputs)`` if ``get_mapping``
        """

        def get_mapping_input(cmp_object, comp_name_or_out_name):
            # Pick the tensor of a multi-output component that matches an output name
            for k, v in cmp_object.outputs.items():
                if v == comp_name_or_out_name:
                    return cmp_object.tensor_output[k]
            raise ValueError(
                f"{comp_name_or_out_name} is a multi output tensor. Cannot use component name : use outputs instead")

        mapping_outputs = []
        tensor_outputs = []
        for component_name in component_or_output_names:
            component = self.get_component(component_name)
            if isinstance(component, Optimizer):
                t = self.get_component(component.inputs['inputs'])
                if isinstance(t.tensor_output, dict):
                    # If metric has multiple outputs
                    mapping_outputs.append(component.inputs['inputs'])
                    tensor_outputs.append(
                        get_mapping_input(cmp_object=t, comp_name_or_out_name=component.inputs['inputs']))
                else:
                    mapping_outputs.append(t.name)
                    tensor_outputs.append(t.tensor_output)
            else:
                if isinstance(component.tensor_output, dict):
                    tensor_outputs.append(get_mapping_input(cmp_object=component,
                                                            comp_name_or_out_name=component_name))
                else:
                    tensor_outputs.append(component.tensor_output)
                mapping_outputs.append(component_name)
        keras_model = tf.keras.Model(
            inputs=[i.tensor_output for i in self.get_required_buffers(component_or_output_names)],
            outputs=tensor_outputs)
        if get_mapping:
            # Returns the order of metrics/losses
            return keras_model, mapping_outputs
        return keras_model

    @typechecked
    def get_component_by_name(self, name: Union[str, RZTComponent]):
        """
        :param name: component name, or a component (returned as-is)
        :return: the component
        :raises KeyError: if no component has that name
        """
        if isinstance(name, RZTComponent):
            return name
        return self.components[name]

    @typechecked
    def get_component_from_output(self, output_name: Union[str, RZTComponent]):
        """
        :param output_name: output (or component) name, or a component (returned as-is)
        :return: the component producing that output
        :raises KeyError: if the name is unknown
        """
        if isinstance(output_name, RZTComponent):
            return output_name
        try:
            return self.component_outputs[output_name]
        except KeyError as e:
            raise KeyError(f"Unknown Output or Component : {output_name}") from e

    @typechecked
    def get_component(self, name: Union[str, RZTComponent]):
        """
        Look up a component by component name first, then by output name

        :param name: component name, output name, or a component
        :return: the component
        :raises KeyError: if neither lookup succeeds
        """
        try:
            return self.get_component_by_name(name=name)
        except KeyError:
            return self.get_component_from_output(output_name=name)

    def close(self):
        """Mark the model as finalized (required by methods decorated with ``check_closed``)."""
        self.finalized = True

    def add_group(self):
        """Placeholder: currently does nothing."""
        pass

    def clear(self):
        """Reset the model to an empty, non-finalized state."""
        self.components.clear()
        self.component_outputs.clear()
        self.finalized = False
        self.input_layers = []
        self.previous_component = None
        self.dag = nx.MultiDiGraph()
        self.tensor_mapping.clear()
        # Previously not reset, so stale scopes/optimizers leaked into the next build
        self.scopes = []
        self.optimizers = []

    def predecessors(self, component):
        """
        Gets predecessors of the component in the dag

        :param component: component or component name
        :return: list of predecessor nodes
        """
        return list(
            self.dag.predecessors(self.get_component_by_name(component) if isinstance(component, str) else component))