import numpy as np
import tensorflow as tf

from tracegnn.models.trace_anomaly.tfsnippet.shortcuts import VarScopeObject
from tracegnn.models.trace_anomaly.tfsnippet.utils import (DocInherit, add_name_and_scope_arg_doc,
                             get_default_scope_name, is_tensor_object,
                             reopen_variable_scope)

# Public names exported by `from <this module> import *`.
__all__ = ['BaseLayer']


# We choose to derive `BaseLayer` from `VarScopeObject` even though a layer
# may not strictly need a variable scope.  Doing so gives every layer object a
# uniquified name, which is then added to the name scopes generated by its
# methods, making TensorFlow's debugging messages much clearer.

@DocInherit
class BaseLayer(VarScopeObject):
    """
    Common base class of neural network layers.

    A layer is built lazily: :meth:`build` creates its variables exactly once
    inside the layer's variable scope, while :meth:`apply` (also reachable by
    calling the layer object) converts its input to tensor(s), builds the
    layer on first use, and then delegates to :meth:`_apply` inside a name
    scope derived from ``'apply'`` and this layer.
    """

    _build_require_input = False  #: if True, `build` rejects a :obj:`None` input

    @add_name_and_scope_arg_doc
    def __init__(self, name=None, scope=None):
        """
        Initialize a :class:`BaseLayer` instance.
        """
        # NOTE: explicit two-argument `super` on purpose; the class decorator
        # may replace the class object, which would confuse zero-arg `super()`.
        super(BaseLayer, self).__init__(name=name, scope=scope)
        # Becomes True once `build` has run `_build` without raising.
        self._has_built = False

    def _build(self, input=None):
        """
        Create the variables of this layer; subclasses must override this.

        Invoked by :meth:`build` within this layer's reopened variable scope.
        """
        raise NotImplementedError()

    def build(self, input=None):
        """
        Build the layer by creating all of its variables.

        Args:
            input (Tensor or list[Tensor] or None): The input tensor(s) when
                :meth:`build` is triggered from :meth:`apply`, or :obj:`None`
                when it is invoked on its own.

        Raises:
            RuntimeError: If the layer has already been built.
            ValueError: If the layer requires `input` to build (see
                ``_build_require_input``) but `input` is :obj:`None`.
        """
        already_built = self._has_built
        if already_built:
            raise RuntimeError(
                'Layer has already been built: {!r}'.format(self))

        input_missing = self._build_require_input and input is None
        if input_missing:
            raise ValueError(
                '`{}` requires `input` to build.'.format(
                    self.__class__.__name__))

        with reopen_variable_scope(self.variable_scope):
            self._build(input)
            # Reached only if `_build` did not raise, so a failed build can
            # be retried.
            self._has_built = True

    def _apply(self, input):
        """
        Compute the output of this layer; subclasses must override this.

        Receives `input` already converted to tensor(s) by :meth:`apply`, and
        runs inside the name scope opened there.
        """
        raise NotImplementedError()

    def apply(self, input):
        """
        Run the layer on `input` and return its output.

        A single tensor or NumPy array is converted with
        ``tf.convert_to_tensor``; any other iterable is converted element by
        element into a list of tensors.  The layer is built on first use.

        Args:
            input (Tensor or list[Tensor]): A single input tensor, or a list
                of input tensors.

        Returns:
            The output tensor, or a list of output tensors.
        """
        single_input = is_tensor_object(input) or isinstance(input, np.ndarray)
        if single_input:
            input = tf.convert_to_tensor(input)
            ns_values = [input]
        else:
            # Treat every other iterable as a collection of separate inputs;
            # the same list is used both as the input and the scope values.
            ns_values = input = list(map(tf.convert_to_tensor, input))

        if not self._has_built:
            self.build(input)

        scope_name = get_default_scope_name('apply', self)
        with tf.name_scope(scope_name, values=ns_values):
            return self._apply(input)

    def __call__(self, input):
        """
        Shorthand for :meth:`apply`.
        """
        return self.apply(input)
