# -*- coding: utf-8 -*-
"""This module handles the Function class in Python.
"""
# Copyright (C) 2009-2014 Johan Hake
#
# This file is part of DOLFIN.
#
# DOLFIN is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# DOLFIN is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with DOLFIN. If not, see <http://www.gnu.org/licenses/>.
#
# Modified by Martin Sandve Alnæs 2013-2014
# Modified by Anders Logg 2015

__all__ = ["Function", "TestFunction", "TrialFunction", "Argument",
           "TestFunctions", "TrialFunctions"]

import types

# Import UFL and SWIG-generated extension module (DOLFIN C++)
import ufl
from ufl import product
from ufl.utils.indexflattening import flatten_multiindex, shape_to_strides
import dolfin.cpp as cpp
import numpy

from dolfin.functions.functionspace import FunctionSpace
from dolfin.functions.constant import Constant
from six.moves import xrange as range

def _assign_error():
    """Report a right-hand side that is not a valid linear combination.

    Delegates to ``cpp.dolfin_error``. Callers treat this as terminating
    (presumably it raises from the C++ layer -- confirm there).
    """
    location = "function.py"
    task = "assign function"
    reason = ("Expects only linear combinations of Functions in "
              "the same FunctionSpaces")
    cpp.dolfin_error(location, task, reason)

def _check_mul_and_division(e, linear_comb, scalar_weight=1.0, multi_index=None):
    """
    Utility func for checking division and multiplication of a Function
    with scalars in linear combinations of Functions.

    A scalar is either a UFL ScalarValue or a Constant holding a single
    value (``value_size() == 1``). The same test is used for both factors
    of a Product and for the divisor of a Division. The scalar is folded
    into ``scalar_weight`` and the remaining operand is processed,
    recursively for nested products, divisions, sums and component
    tensors.

    *Arguments*
        e
            A UFL Product or Division (anything else is an error).
        linear_comb
            List of (Function, weight) tuples collected so far.
        scalar_weight
            float weight accumulated from enclosing products/divisions.
        multi_index
            MultiIndex of an enclosing ComponentTensor, or None.

    *Returns*
        The updated list of (Function, weight) tuples. Callers should use
        the return value rather than relying on in-place mutation.
    """
    from ufl.constantvalue import ScalarValue
    from ufl.classes import ComponentTensor, MultiIndex, Indexed
    from ufl.algebra import Division, Product, Sum

    # FIXME: What should be checked!?
    # martinal: This code has never done anything sensible,
    #   but I don't know what it was supposed to do so I can't fix it.
    #same_multi_index = lambda x, y: (x.ufl_free_indices == y.ufl_free_indices \
    #                        and x.ufl_index_dimensions == y.ufl_index_dimensions)

    assert isinstance(scalar_weight, float)

    def _is_scalar(op):
        # A literal scalar, or a Constant holding exactly one value. Both
        # the Product and the Division branch must use the same test.
        return (isinstance(op, ScalarValue) or
                (isinstance(op, Constant) and op.value_size() == 1))

    # Split passed expression into scalar and expr
    if isinstance(e, Product):
        for i, op in enumerate(e.ufl_operands):
            if _is_scalar(op):
                scalar = op
                expr = e.ufl_operands[1 - i]
                break
        else:
            _assign_error()

        scalar_weight *= float(scalar)
    elif isinstance(e, Division):
        expr, scalar = e.ufl_operands
        if not _is_scalar(scalar):
            _assign_error()
        scalar_weight /= float(scalar)
    else:
        _assign_error()

    # If a ComponentTensor was unpacked by the caller we expect expr to be
    # an Indexed wrapping either a Function or another ComponentTensor,
    # where the latter results in a recursive call
    if multi_index is not None:
        assert isinstance(multi_index, MultiIndex)
        assert isinstance(expr, Indexed)

        # Unpack Indexed and check equality with passed multi_index
        expr, multi_index2 = expr.ufl_operands
        assert isinstance(multi_index2, MultiIndex)
        #if not same_multi_index(multi_index, multi_index2):
        #    _assign_error()

    if isinstance(expr, Function):
        linear_comb.append((expr, scalar_weight))

    elif isinstance(expr, (ComponentTensor, Product, Division, Sum)):
        # If componentTensor we need to unpack the MultiIndices
        if isinstance(expr, ComponentTensor):
            expr, multi_index = expr.ufl_operands

        if isinstance(expr, (Product, Division)):
            linear_comb = _check_mul_and_division(expr, linear_comb,
                                                  scalar_weight, multi_index)
        elif isinstance(expr, Sum):
            linear_comb = _check_and_extract_functions(expr, linear_comb,
                                                       scalar_weight,
                                                       multi_index)
        else:
            _assign_error()
    else:
        _assign_error()

    return linear_comb

def _check_and_extract_functions(e, linear_comb=None, scalar_weight=1.0,
                                 multi_index=None):
    """
    Recursively collect (Function, weight) pairs from an expression that
    is expected to be a linear combination of Functions.

    Accepted forms are a bare Function ``u``, scalar products/quotients
    such as ``a*u*b``, ``u/a/b`` and ``a*u/b`` (optionally wrapped in a
    ComponentTensor), and Sums of these. Anything else is reported via
    ``_assign_error``.

    *Returns*
        The list of (Function, weight) tuples.
    """
    from ufl.classes import ComponentTensor, Sum, Product, Division
    collected = linear_comb or []

    # Plain Function: u
    if isinstance(e, Function):
        collected.append((e, scalar_weight))
        return collected

    # Scaled terms: a*u*b, u/a/b, a*u/b with scalar a and b
    if isinstance(e, (Product, Division)):
        return _check_mul_and_division(e, collected, scalar_weight,
                                       multi_index)

    # Scaled tensor-valued terms wrapped in a ComponentTensor
    if isinstance(e, ComponentTensor):
        inner, tensor_index = e.ufl_operands
        return _check_mul_and_division(inner, collected, scalar_weight,
                                       tensor_index)

    # Anything that is not a Sum cannot be a linear combination
    if not isinstance(e, Sum):
        _assign_error()
        return collected

    for term in e.ufl_operands:
        collected = _check_and_extract_functions(term, collected,
                                                 scalar_weight, multi_index)
    return collected

def _check_and_contract_linear_comb(expr, self, multi_index):
    """
    Utility func for checking and contracting linear combinations of
    Functions.

    Extracts (Function, weight) pairs from ``expr``, checks that every
    Function is contained in the function space of the first one, and
    merges the weights of repeated Functions. If the assignment target
    ``self`` appears on the right-hand side, it is replaced by a deep
    copy, so that overwriting ``self`` does not corrupt the input.

    *Returns*
        A list of (Function, weight) tuples.
    """
    linear_comb = _check_and_extract_functions(expr, multi_index=multi_index)
    funcs = []
    weights = []
    funcspace = None
    for func, weight in linear_comb:
        # Explicit None test: don't rely on truthiness of a wrapped object
        if funcspace is None:
            funcspace = func.function_space()
        if func not in funcspace:
            _assign_error()
        try:
            # Check if the exact same Function is already present
            ind = funcs.index(func)
        except ValueError:
            funcs.append(func)
            weights.append(weight)
        else:
            weights[ind] += weight

    # Check that rhs does not include self
    for ind, func in enumerate(funcs):
        if func == self:
            # If so make a copy
            funcs[ind] = self.copy(deepcopy=True)
            break

    return list(zip(funcs, weights))

class Function(ufl.Coefficient, cpp.Function):
    r"""This class represents a function :math:`u_h` in a finite
    element function space :math:`V_h`, given by

    .. math::

        u_h = \sum_{i=1}^n U_i \phi_i,

    where :math:`\{\phi_i\}_{i=1}^n` is a basis for :math:`V_h`,
    and :math:`U` is a vector of expansion coefficients for
    :math:`u_h`.

    *Arguments*
        At most two positional arguments are accepted. The first
        argument must be a Function, a cpp.Function or a
        :py:class:`FunctionSpace
        <dolfin.functions.functionspace.FunctionSpace>`.

        If instantiated from another Function, the (optional)
        second argument must be an integer giving the index of the
        sub function to extract.

        In addition, a ``name`` keyword argument may be passed to
        override the default name ``f_<count>``.

    *Examples*
        Create a Function:

        - from a :py:class:`FunctionSpace
          <dolfin.functions.functionspace.FunctionSpace>` ``V``

          .. code-block:: python

              f = Function(V)

        - from cpp.Function ``f``

          *Warning: this constructor is intended for internal library use only.*
          No copying is done - ``f`` is only wrapped as Function.

          .. code-block:: python

              g = Function(f)

        - from a :py:class:`FunctionSpace
          <dolfin.functions.functionspace.FunctionSpace>` ``V`` and a
          :py:class:`GenericVector <dolfin.cpp.GenericVector>` ``v``

          *Warning: this constructor is intended for internal library use only.*

          .. code-block:: python

              g = Function(V, v)

        - from a :py:class:`FunctionSpace
          <dolfin.functions.functionspace.FunctionSpace>` and a
          filename containing a :py:class:`GenericVector
          <dolfin.cpp.GenericVector>`

          .. code-block:: python

              g = Function(V, 'MyVectorValues.xml')

    """

    def __init__(self, *args, **kwargs):
        """Initialize Function."""
        # Initial quick check for valid arguments (other checks
        # sprinkled below)
        if len(args) == 0:
            raise TypeError("expected 1 or more arguments")

        # Type switch on argument types
        if isinstance(args[0], Function):
            other = args[0]
            if len(args) == 1:
                # NOTE: Turn this into error when removing deprecation
                # warning
                cpp.deprecation("Function copy constructor", "2016.1",
                                "Use 'Function.copy(deepcopy=True)' for copying.")
                self.__init_copy_constructor(other)
            elif len(args) == 2:
                i = args[1]
                if not isinstance(i, int):
                    raise TypeError("Invalid subfunction number %s" % (i,))
                self.__init_subfunction_constructor(other, i)
            else:
                raise TypeError("expected one or two arguments when "
                                "instantiating from another Function")
        elif isinstance(args[0], cpp.Function):
            other = args[0]
            if len(args) == 1:
                # If creating a dolfin.Function from a cpp.Function
                self.__init_from_cpp_function(other)
            else:
                raise TypeError("expected only one argument when passing cpp.Function "
                                "to dolfin.Function constructor")
        elif isinstance(args[0], FunctionSpace):
            V = args[0]
            # If initialising from a FunctionSpace
            if len(args) == 1:
                # If passing only the FunctionSpace
                self.__init_from_function_space(V)
            elif len(args) == 2:
                # If passing FunctionSpace together with cpp.Function
                # Attached passed FunctionSpace and initialize the
                # cpp.Function using the passed Function
                other = args[1]
                if isinstance(other, cpp.Function):
                    self.__init_from_function_space_and_cpp_function(V, other)
                else:
                    self.__init_from_function_space_and_function(V, other)
            else:
                raise TypeError("too many arguments")
        else:
            raise TypeError("expected a FunctionSpace or a Function as argument 1")

        # Set name as given or automatic
        name = kwargs.get("name") or "f_%d" % self.count()
        self.rename(name, "a Function")

    @staticmethod
    def __check_subfunction_index(func, i):
        """Raise RuntimeError unless i is a valid sub function index of func.

        Valid indices are 0..N-1, where N is the number of sub spaces of
        func's function space.
        """
        num_sub_spaces = func.function_space().num_sub_spaces()
        if num_sub_spaces <= 1:
            raise RuntimeError("No subfunctions to extract")
        # Reject negative indices too; they would otherwise reach C++
        if not 0 <= i < num_sub_spaces:
            raise RuntimeError("Can only extract subfunctions "
                               "with i = 0..%d" % (num_sub_spaces - 1))

    def __init_copy_constructor(self, other):
        cpp.Function.__init__(self, other)
        ufl.Coefficient.__init__(self, other.ufl_function_space(), count=self.id())

    def __init_from_cpp_function(self, other):
        # Assign all the members (including 'this' pointer to SWIG wrapper)
        # NOTE: This in fact performs assignment of C++ context
        self.__dict__ = other.__dict__

        # Initialize the ufl.FunctionSpace (Not calling cpp.Function.__init__)
        ufl.Coefficient.__init__(self, other.function_space().ufl_function_space(), count=self.id())

    def __init_subfunction_constructor(self, other, i):
        self.__check_subfunction_index(other, i)
        cpp.Function.__init__(self, other, i)
        ufl.Coefficient.__init__(self, self.function_space().ufl_function_space(), count=self.id())

    def __init_from_function_space(self, V):
        cpp.Function.__init__(self, V)
        ufl.Coefficient.__init__(self, V.ufl_function_space(), count=self.id())

    def __init_from_function_space_and_cpp_function(self, V, other):
        # Simple consistency checks on function spaces
        if other.function_space().dim() != V.dim():
            raise ValueError("non matching dimensions on passed FunctionSpaces")
        cpp.Function.__init__(self, other)
        ufl.Coefficient.__init__(self, V.ufl_function_space(), count=self.id())

    def __init_from_function_space_and_function(self, V, other):
        cpp.Function.__init__(self, V, other)
        ufl.Coefficient.__init__(self, V.ufl_function_space(), count=self.id())

    def sub(self, i, deepcopy=False):
        """
        Return a sub function.

        The sub functions are numbered from i = 0..N-1, where N is the
        total number of sub spaces.

        *Arguments*
            i : int
                The number of the sub function
            deepcopy : bool
                Copy the sub function vector instead of sharing it

        *Raises*
            TypeError if i is not an int, RuntimeError if i is out of
            range or there are no sub functions.
        """
        if not isinstance(i, int):
            raise TypeError("expects an 'int' as first argument")
        self.__check_subfunction_index(self, i)

        # Create and instantiate the Function
        name = '%s-%d' % (str(self), i)
        if deepcopy:
            return Function(self.function_space().sub(i),
                            cpp.Function.sub(self, i),
                            name=name)
        return Function(self, i, name=name)

    def assign(self, rhs):
        """
        Assign either a Function or linear combination of Functions.

        *Arguments*
            rhs (_Function_)
                A Function or a linear combination of Functions. If a linear
                combination is passed all Functions need to be in the same
                FunctionSpaces.
        """
        from ufl.classes import ComponentTensor, Sum, Product, Division
        if isinstance(rhs, (cpp.Function, cpp.Expression, cpp.FunctionAXPY)):
            # Avoid self assignment
            if self == rhs:
                return

            self._assign(rhs)
        elif isinstance(rhs, (Sum, Product, Division, ComponentTensor)):
            if isinstance(rhs, ComponentTensor):
                rhs, multi_index = rhs.ufl_operands
            else:
                multi_index = None
            linear_comb = _check_and_contract_linear_comb(rhs, self,
                                                          multi_index)
            assert linear_comb

            # If the assigned Function lives in a different FunctionSpace
            # we cannot operate on this function directly. All terms were
            # checked to share one space, so testing the first suffices.
            same_func_space = linear_comb[0][0] in self.function_space()

            # Start from the last term; the remaining ones are AXPY'd below
            func, weight = linear_comb.pop()

            # Assign values from that func
            if not same_func_space:
                self._assign(func)
                vector = self.vector()
            else:
                vector = self.vector()
                vector[:] = func.vector()

            # If its weight is not 1, scale
            if weight != 1.0:
                vector *= weight

            # AXPY the other functions
            for func, weight in linear_comb:
                if weight == 0.0:
                    continue
                vector.axpy(weight, func.vector())

        else:
            cpp.dolfin_error("function.py",
                             "function assignment",
                             "Expects a Function or linear combinations of "
                             "Functions in the same FunctionSpaces")

    def split(self, deepcopy=False):
        """
        Extract any sub functions.

        A sub function can be extracted from a discrete function that
        is in a mixed, vector, or tensor FunctionSpace. The sub
        function resides in the subspace of the mixed space.

        *Arguments*
            deepcopy
                Copy sub function vector instead of sharing

        *Returns*
            A tuple with one Function per sub space (empty if the space
            reports zero sub spaces).
        """

        num_sub_spaces = self.function_space().num_sub_spaces()
        if num_sub_spaces == 1:
            raise RuntimeError("No subfunctions to extract")
        return tuple(self.sub(i, deepcopy) for i in range(num_sub_spaces))

    def __str__(self):
        """Return a pretty print representation of itself (its name)."""
        return self.name()

    def __repr__(self):
        """Return a str repr of itself.

        Must use ufl.__repr__ for this"""
        return ufl.Coefficient.__repr__(self)

    def str(self, verbose=False):
        """Return an informative str representation of itself"""
        # FIXME: We might change this using rank and dimension instead
        return "<Function in %s>" % str(self.function_space())

    def ufl_evaluate(self, x, component, derivatives):
        """Function used by ufl to evaluate the Function.

        If component is non-empty, the full value is evaluated at x and
        the flattened entry for that component is returned; otherwise
        the (scalar) value at x is returned.
        """
        assert derivatives == () # TODO: Handle derivatives

        if component:
            shape = self.ufl_shape
            assert len(shape) == len(component)
            value_size = product(shape)
            index = flatten_multiindex(component, shape_to_strides(shape))
            values = numpy.zeros(value_size)
            self(*x, values=values)
            return values[index]
        else:
            # Scalar evaluation
            return self(*x)

    def __float__(self):
        if self.ufl_shape != ():
            raise RuntimeError("Cannot convert nonscalar function to float.")
        elm = self.ufl_element()
        if elm.family() != "Real":
            raise RuntimeError("Cannot convert spatially varying function to float.")
        # FIXME: This could be much simpler be exploiting that the
        # vector is ghosted
        # Gather value directly from vector in a parallel safe way
        vec = self.vector()
        indices = numpy.zeros(1, dtype=cpp.la_index_dtype())
        values = vec.gather(indices)
        return float(values[0])

    def __call__(self, *args, **kwargs):
        """
        Evaluates the Function.

        A single '+' or '-' argument, or an (x, mapping) pair whose
        mapping contains this Function, is forwarded to UFL. Otherwise
        the arguments are interpreted as a point coordinate. For a
        scalar-valued Function without ``values``, a float is returned;
        otherwise the values array is returned.

        *Examples*
            1) Using an iterable as x:

              .. code-block:: python

                  x0 = (1.,0.5,0.5)
                  x1 = [1.,0.5,0.5]
                  x2 = numpy.array([1.,0.5,0.5])
                  v0 = f(x0)
                  v1 = f(x1)
                  v2 = f(x2)

            2) Using multiple scalar args for x, interpreted as a
            point coordinate

              .. code-block:: python

                  v0 = f(1.,0.5,0.5)

            3) Using a Point

              .. code-block:: python

                  p0 = Point(1.,0.5,0.5)
                  v0 = f(p0)

            4) Passing return array (``fv`` has three value components)

              .. code-block:: python

                  x0 = numpy.array([1.,0.5,0.5])
                  v0 = numpy.zeros(3)
                  fv(x0, values = v0)

              .. note::

                  ``values`` must be a double NumPy array whose length
                  equals the value size of the Function. Slices of a
                  larger array may be passed, so that one can quickly
                  fill up an array with several evaluations.

              .. code-block:: python

                  values = numpy.zeros(9)
                  for i in range(0, 9, 3):
                      fv(x[i:i+3], values = values[i:i+3])

        """

        if len(args) == 0:
            raise TypeError("expected at least 1 argument")

        # Test for ufl restriction
        if len(args) == 1 and isinstance(args[0], str):
            if args[0] in ('+', '-'):
                return ufl.Coefficient.__call__(self, *args)

        # Test for ufl mapping
        if len(args) == 2 and isinstance(args[1], dict) and self in args[1]:
            return ufl.Coefficient.__call__(self, *args)

        # Some help variables
        value_size = product(self.ufl_element().value_shape())

        # If values (return argument) is passed, check the type and length
        values = kwargs.get("values", None)
        if values is not None:
            if not isinstance(values, numpy.ndarray):
                raise TypeError("expected a NumPy array for 'values'")
            if len(values) != value_size or \
                   not numpy.issubdtype(values.dtype, 'd'):
                raise TypeError("expected a double NumPy array of length"
                                " %d for return values." % value_size)
            values_provided = True
        else:
            values_provided = False
            values = numpy.zeros(value_size, dtype='d')

        # Get the geometric dimension we live in
        dim = self.ufl_domain().geometric_dimension()

        # Assume all args are x argument
        x = args

        # If only one x argument has been provided, unpack it if it's
        # an iterable
        if len(x) == 1:
            if isinstance(x[0], cpp.Point):
                x = [x[0][i] for i in range(dim)]
            elif hasattr(x[0], '__iter__'):
                x = x[0]

        # Convert it to an 1D numpy array
        try:
            x = numpy.fromiter(x, 'd')
        except (TypeError, ValueError, AssertionError):
            raise TypeError("expected scalar arguments for the coordinates")

        if len(x) == 0:
            raise TypeError("coordinate argument too short")

        if len(x) != dim:
            raise TypeError("expected the geometry argument to be of "
                            "length %d" % dim)

        # The actual evaluation
        self.eval(values, x)

        # If scalar return statement, return scalar value.
        if value_size == 1 and not values_provided:
            return values[0]

        return values

#--- Subclassing of ufl.{Basis, Trial, Test}Function ---

# TODO: Update this message to clarify dolfin.FunctionSpace vs ufl.FunctionSpace
# Error text used by Argument when it is handed a UFL FiniteElement or
# ufl.FunctionSpace instead of a DOLFIN FunctionSpace.
_ufl_dolfin_difference_message = """\
When constructing an Argument, TestFunction or TrialFunction,
you must provide a FunctionSpace and not a FiniteElement.
The FiniteElement class provided by ufl only represents an
abstract finite element space and is only used in standalone
.ufl files, while the FunctionSpace provides a full discrete
function space over a given mesh and should be used in dolfin
programs in Python.
"""

class Argument(ufl.Argument):
    """UFL value: Representation of an argument to a form.

    This is the overloaded PyDOLFIN variant.
    """
    def __init__(self, V, number, part=None):
        """Create an Argument on a DOLFIN function space.

        *Arguments*
            V
                A dolfin FunctionSpace or a cpp.MultiMeshFunctionSpace. For
                a MultiMesh space, the space is stored as ``_V_multi`` and
                its first part is used for the UFL Argument.
            number
                The argument number (0 for test, 1 for trial functions).
            part
                Optional part passed on to ufl.Argument.

        *Raises*
            TypeError if V is not a supported function space.
        """
        # Check argument
        if not isinstance(V, (FunctionSpace, cpp.MultiMeshFunctionSpace)):
            if isinstance(V, (ufl.FiniteElementBase, ufl.FunctionSpace)):
                # Give a more helpful message for the common UFL mix-up
                raise TypeError(_ufl_dolfin_difference_message)
            raise TypeError("Illegal argument for creation of Argument, not a FunctionSpace: " + str(V))

        # Handle MultiMesh
        if isinstance(V, cpp.MultiMeshFunctionSpace):
            self._V_multi = V
            V = V._parts[0]

        # Initialize UFL Argument
        ufl.Argument.__init__(self, V.ufl_function_space(), number, part)

        self._V = V

    def function_space(self):
        "Return the FunctionSpace"
        return self._V

    def __eq__(self, other):
        """Extending UFL __eq__ here to distinguish test and trial
        functions in different function spaces with same ufl element."""
        return (isinstance(other, Argument) and
                self.number() == other.number() and
                self.part() == other.part() and
                self._V == other._V)

    def __hash__(self):
        # __eq__ is overridden above, so __hash__ must be restated
        # explicitly; reuse the UFL hash
        return ufl.Argument.__hash__(self)

def TestFunction(V, part=None):
    """UFL value: Create a test function argument to a form.

    The test function is the Argument numbered 0 on V.

    This is the overloaded PyDOLFIN variant.
    """
    test_number = 0
    return Argument(V, test_number, part)

def TrialFunction(V, part=None):
    """UFL value: Create a trial function argument to a form.

    The trial function is the Argument numbered 1 on V.

    This is the overloaded PyDOLFIN variant.
    """
    trial_number = 1
    return Argument(V, trial_number, part)

#--- TestFunctions and TrialFunctions ---

def Arguments(V, number):
    """UFL value: Create an Argument in a mixed space, and return a
    tuple with the function components corresponding to the subelements.

    This is the overloaded PyDOLFIN variant.
    """
    argument = Argument(V, number)
    return ufl.split(argument)

def TestFunctions(V):
    """UFL value: Create a TestFunction in a mixed space, and return a
    tuple with the function components corresponding to the subelements.

    Equivalent to splitting the Argument numbered 0 on V.

    This is the overloaded PyDOLFIN variant.
    """
    return Arguments(V, 0)

def TrialFunctions(V):
    """UFL value: Create a TrialFunction in a mixed space, and return a
    tuple with the function components corresponding to the subelements.

    Equivalent to splitting the Argument numbered 1 on V.

    This is the overloaded PyDOLFIN variant.
    """
    return Arguments(V, 1)
