//                                               -*- C++ -*-
/**
 * @file  PythonNumericalMathEvaluationImplementation.cxx
 * @brief This class binds a Python function to an Open TURNS' NumericalMathFunction
 *
 *  (C) Copyright 2005-2012 EDF-EADS-Phimeca
 *
 *  This library is free software; you can redistribute it and/or
 *  modify it under the terms of the GNU Lesser General Public
 *  License as published by the Free Software Foundation; either
 *  version 2.1 of the License.
 *
 *  This library is distributed in the hope that it will be useful
 *  but WITHOUT ANY WARRANTY; without even the implied warranty of
 *  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 *  Lesser General Public License for more details.
 *
 *  You should have received a copy of the GNU Lesser General Public
 *  License along with this library; if not, write to the Free Software
 *  Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307 USA
 *
 * \author $LastChangedBy: lebrun $
 * \date   $LastChangedDate: 2012-03-20 06:42:21 +0100 (Tue, 20 Mar 2012) $
 */

#include "PythonNumericalMathEvaluationImplementation.hxx"
#include "OSS.hxx"
#include "Description.hxx"
#include "PythonWrappingFunctions.hxx"
#include "PersistentObjectFactory.hxx"
#include "Exception.hxx"

#include "swig_runtime.hxx"

// Python marshalling functions not defined in Python.h
//#include "marshal.h"

BEGIN_NAMESPACE_OPENTURNS

// Local shorthands for the cache key / value / container types declared by
// NumericalMathEvaluationImplementation (the key and value types are used by the
// cached evaluation operators of this file)
typedef NumericalMathEvaluationImplementation::CacheKeyType             CacheKeyType;
typedef NumericalMathEvaluationImplementation::CacheValueType           CacheValueType;
typedef NumericalMathEvaluationImplementation::CacheType                CacheType;


// OpenTURNS class-name registration macro; presumably defines GetClassName() used by __repr__/__str__ — TODO confirm
CLASSNAMEINIT(PythonNumericalMathEvaluationImplementation);

// Registers this class with the persistence Factory under its class name,
// presumably so that instances can be rebuilt by name when a study is reloaded — TODO confirm
static Factory<PythonNumericalMathEvaluationImplementation> RegisteredFactory("PythonNumericalMathEvaluationImplementation");



/* Default constructor
 * Builds a wrapper that holds no Python object: pyObj_ stays NULL until an
 * object is attached (load() is the only method of this file doing so). */
PythonNumericalMathEvaluationImplementation::PythonNumericalMathEvaluationImplementation()
  : NumericalMathEvaluationImplementation()
  , pyObj_(NULL)
{
  // No Python reference is taken here, hence nothing to increment
}


/* Constructor from Python object*/
PythonNumericalMathEvaluationImplementation::PythonNumericalMathEvaluationImplementation(PyObject * pyCallable)
  : NumericalMathEvaluationImplementation(),
    pyObj_(pyCallable)
{
  Py_XINCREF( pyCallable );

  // Set the name of the object as its Python classname
  PyObject * cls = PyObject_GetAttrString( pyObj_,
                                           const_cast<char *>( "__class__" ) );
  PyObject * name = PyObject_GetAttrString( cls,
                                            const_cast<char *>( "__name__" ) );
  setName( PyString_AsString( name ) );
  Py_XDECREF( name );
  Py_XDECREF( cls );

  // One MUST initialize the description with the correct dimension
  PyObject * descIn  = PyObject_GetAttrString( pyObj_,
                                               const_cast<char *>( "descIn" ) );
  PyObject * descOut = PyObject_GetAttrString( pyObj_,
                                               const_cast<char *>( "descOut" ) );

  const UnsignedLong inputDimension  = getInputDimension();
  const UnsignedLong outputDimension = getOutputDimension();
  Description description(inputDimension + outputDimension);

  if ( ( descIn != NULL ) &&
       PySequence_Check( descIn ) &&
       ( PySequence_Size( descIn ) == inputDimension ) ) {
    PyObject * newPyObj = PySequence_Fast( descIn, "" );

    for (UnsignedLong i = 0; i < inputDimension; ++i) {
      PyObject * elt = PySequence_Fast_GET_ITEM( newPyObj, i );
      if ( isAPython<_PyString_>( elt ) ) description[i] = PyString_AsString( elt );
      else throw InvalidArgumentException(HERE) << "Input description of Python NumericalMathFunction contains a non-string object at position " << i;
    }

    Py_XDECREF( newPyObj );
  } else for (UnsignedLong i = 0; i < inputDimension; ++i) description[i] = (OSS() << "x" << i);


  if ( ( descOut != NULL ) &&
       PySequence_Check( descOut ) &&
       ( PySequence_Size( descOut ) == outputDimension ) ) {
    PyObject * newPyObj = PySequence_Fast( descOut, "" );

    for (UnsignedLong i = 0; i < outputDimension; ++i) {
      PyObject * elt = PySequence_Fast_GET_ITEM( newPyObj, i );
      if ( isAPython<_PyString_>( elt ) ) description[inputDimension + i] = PyString_AsString( elt );
      else throw InvalidArgumentException(HERE) << "Output description of Python NumericalMathFunction contains a non-string object at position " << i;
    }


    Py_XDECREF( newPyObj );
  } else for (UnsignedLong i = 0; i < outputDimension; ++i) description[inputDimension + i] = (OSS() << "y" << i);

  setDescription(description);
  Py_XDECREF( descIn );
  Py_XDECREF( descOut );
}

/* Virtual constructor
 * Returns a heap-allocated copy made by the copy constructor; the copy
 * shares (and holds its own reference to) the same Python object. */
PythonNumericalMathEvaluationImplementation * PythonNumericalMathEvaluationImplementation::clone() const
{
  PythonNumericalMathEvaluationImplementation * duplicate = new PythonNumericalMathEvaluationImplementation( *this );
  return duplicate;
}

/* Copy constructor
 * Both instances wrap the very same Python object, so an extra reference is
 * taken to balance the Py_XDECREF performed by each destructor.
 * NOTE(review): copy-assignment is not defined in this file; unless the header
 * defines or forbids it, the implicit one copies pyObj_ without a Py_XINCREF — confirm. */
PythonNumericalMathEvaluationImplementation::PythonNumericalMathEvaluationImplementation(const PythonNumericalMathEvaluationImplementation & other)
  : NumericalMathEvaluationImplementation(other)
  , pyObj_(other.pyObj_)
{
  Py_XINCREF( other.pyObj_ );
}

/* Destructor
 * Releases the reference held on the wrapped Python object, if there is one. */
PythonNumericalMathEvaluationImplementation::~PythonNumericalMathEvaluationImplementation()
{
  if ( pyObj_ != NULL ) Py_DECREF( pyObj_ );
}

/* Comparison operator
 * NOTE(review): any two instances compare equal, whatever Python objects they
 * wrap; this cannot be used to tell two Python functions apart. */
Bool PythonNumericalMathEvaluationImplementation::operator ==(const PythonNumericalMathEvaluationImplementation & other) const
{
  const Bool alwaysEqual = true;
  return alwaysEqual;
}

/* String converter
 * Full representation: class name, object name, description and parameters. */
String PythonNumericalMathEvaluationImplementation::__repr__() const
{
  return OSS() << "class=" << PythonNumericalMathEvaluationImplementation::GetClassName()
               << " name=" << getName()
               << " description=" << getDescription()
               << " parameters=" << getParameters();
}

/* String converter
 * Short representation: class name and object name only (offset is ignored). */
String PythonNumericalMathEvaluationImplementation::__str__(const String & offset) const
{
  return OSS() << "class=" << PythonNumericalMathEvaluationImplementation::GetClassName()
               << " name=" << getName();
}

/* Test for actual implementation
 * This wrapper always reports itself as an actual implementation. */
Bool PythonNumericalMathEvaluationImplementation::isActualImplementation() const
{
  const Bool isActual = true;
  return isActual;
}


void PythonNumericalMathEvaluationImplementation::handleException() const
{
  PyObject * exception = PyErr_Occurred();

  if ( exception ) {

    PyObject *type = NULL, *value = NULL, *traceback = NULL;
    PyErr_Fetch( &type, &value, &traceback );

    String typeString;
    String valueString;
    String tracebackString;

    if ( type ) {
      typeString = PyString_AsString( PyObject_Str( type ) );
    }
    if ( value ) {
      valueString = PyString_AsString( PyObject_Str( value ) );
    }
    if ( traceback ) {
      tracebackString = PyString_AsString( PyObject_Str( traceback ) );
    }

    String exceptionMessage("Python exception caught, Type=" + typeString + ", Value=" + valueString);

    PyErr_Restore( type, value, traceback );

    if ( PyErr_ExceptionMatches(PyExc_RuntimeError) )
      throw InternalException(HERE) << "Call to Python method failed. " + exceptionMessage;
    else if ( PyErr_ExceptionMatches( PyExc_TypeError ) )
      throw InvalidArgumentException(HERE) << "Invalid argument passed to Python method. " + exceptionMessage;
    else {
      throw InternalException(HERE) << exceptionMessage;
    }
  }
}



/* Here is the interface that all derived classes must implement */

/* Operator ()
 *
 * Evaluate the wrapped Python callable at inP. The point is handed to Python as a
 * SWIG-wrapped copy of inP; the callable must return a sequence convertible to a
 * NumericalPoint of dimension getOutputDimension().
 * The cache is consulted / fed when enabled, and the input / output history is
 * recorded when enabled.
 *
 * @throw InvalidArgumentException if inP has the wrong dimension or the Python result is not a valid point
 * @throw InternalException / InvalidArgumentException if the Python call raises (see handleException())
 */
NumericalPoint PythonNumericalMathEvaluationImplementation::operator() (const NumericalPoint & inP) const
/*        throw(InvalidArgumentException,InternalException)*/
{
  const UnsignedLong dimension( inP.getDimension() );

  if ( dimension != getInputDimension() )
    throw InvalidArgumentException(HERE) << "Input point has incorrect dimension. Got " << dimension << ". Expected " << getInputDimension();

  NumericalPoint outP;
  CacheKeyType inKey( inP.getCollection() );
  if ( p_cache_->isEnabled() && p_cache_->hasKey( inKey ) )
    {
      outP = NumericalPoint::ImplementationType( p_cache_->find( inKey ) );
    }
  else
    {
      ++ callsNumber_;

      // SWIG_POINTER_OWN: the Python wrapper owns (and will delete) the cloned point
      PyObject * point = SWIG_NewPointerObj( inP.clone(), SWIG_TypeQuery("OT::NumericalPoint *"), SWIG_POINTER_OWN | 0 );

      PyObject * result = PyObject_CallFunctionObjArgs( pyObj_, point, NULL );
      // The call holds its own reference to its argument: drop ours now,
      // before anything below gets a chance to throw
      Py_XDECREF( point );

      if ( result == NULL ) {
        handleException();
        // handleException() only throws when a Python error is pending
        throw InternalException(HERE) << "Call to Python method failed without setting a Python exception";
      }

      try {
        outP = convert<_PySequence_,NumericalPoint>( result );
      } catch (const InvalidArgumentException &) {
        Py_XDECREF( result );
        throw InvalidArgumentException(HERE) << "Output value for " << getName() << "._exec() method is not a sequence object (list, tuple, NumericalPoint, etc.)";
      } catch (...) {
        Py_XDECREF( result );
        throw;
      }
      Py_XDECREF( result );

      // Same consistency requirement as the sample version of this operator
      if ( outP.getDimension() != getOutputDimension() )
        throw InvalidArgumentException(HERE) << "Output value for " << getName() << "._exec() method has incorrect dimension. Got "
                                             << outP.getDimension() << ". Expected " << getOutputDimension();

      if ( p_cache_->isEnabled() )
        {
          CacheValueType outValue( outP.getCollection() );
          p_cache_->add( inKey, outValue );
        }
    }
  if (isHistoryEnabled_)
    {
      inputStrategy_.store(inP);
      outputStrategy_.store(outP);
    }
  return outP;
}


/* Operator ()
 *
 * Evaluate the wrapped Python callable on every point of inS.
 * When the cache is enabled, points already in the cache are not re-evaluated.
 * All remaining points are sent to Python in a single call, as a tuple of tuples of
 * floats; the callable must return a sequence holding, for each of them, a sequence
 * of getOutputDimension() numbers. New results are added to the cache when it is
 * enabled, and the input / output history is recorded when enabled.
 *
 * @throw InvalidArgumentException if inS has the wrong dimension or the Python result is malformed
 * @throw InternalException / InvalidArgumentException if the Python call raises (see handleException())
 */
NumericalSample PythonNumericalMathEvaluationImplementation::operator() (const NumericalSample & inS) const
/*        throw(InvalidArgumentException,InternalException)*/
{
  const UnsignedLong size( inS.getSize() );
  const UnsignedLong inDim( inS.getDimension() );
  const UnsignedLong outDim( getOutputDimension() );

  if ( inDim != getInputDimension() )
    throw InvalidArgumentException(HERE) << "Input point has incorrect dimension. Got " << inDim << ". Expected " << getInputDimension();

  // Indices (in inS) of the points that must actually be sent to Python
  Indices toDo;
  NumericalSample outS( size, outDim );
  if ( p_cache_->isEnabled() )
    {
      for (UnsignedLong i = 0; i < size; ++ i )
        {
          CacheKeyType inKey( inS[i].getCollection() );
          if ( p_cache_->hasKey( inKey ) )
            {
              outS[i] = NumericalPoint::ImplementationType( p_cache_->find( inKey ) );
            }
          else
            {
              toDo.add( i );
            }
        }
    }
  else
    {
      toDo = Indices( size );
      toDo.fill();
    }
  const UnsignedLong toDoSize( toDo.getSize() );

  if ( toDoSize > 0 )
    {
      callsNumber_ += toDoSize;

      PyObject * inTuple = PyTuple_New( toDoSize );
      for ( UnsignedLong i = 0; i < toDoSize; ++ i ) {
        PyObject * eltTuple = PyTuple_New( inDim );
        // PyTuple_SetItem steals the reference to the inserted item
        for ( UnsignedLong j = 0; j < inDim; ++ j ) PyTuple_SetItem( eltTuple, j, PyFloat_FromDouble( inS[toDo[i]][j] ) );
        PyTuple_SetItem( inTuple, i, eltTuple );
      }

      PyObject * result = PyObject_CallFunctionObjArgs( pyObj_, inTuple, NULL );
      // The call holds its own reference to its argument: release ours before anything can throw
      Py_XDECREF( inTuple );

      if ( result == NULL ) {
        handleException();
        // handleException() only throws when a Python error is pending
        throw InternalException(HERE) << "Call to Python method failed without setting a Python exception";
      }

      if ( ! PySequence_Check( result ) ) {
        Py_XDECREF( result );
        throw InvalidArgumentException(HERE) << "Python NumericalMathFunction returned an object which is NOT a sequence";
      }

      const long lengthResult = PySequence_Size( result );
      if ( ( lengthResult < 0 ) || ( static_cast<UnsignedLong>( lengthResult ) != toDoSize ) ) {
        if ( lengthResult < 0 ) PyErr_Clear();
        Py_XDECREF( result );
        throw InvalidArgumentException(HERE) << "Python NumericalMathFunction returned an sequence object with incorrect size (got "
                                             << lengthResult << ", expected " << toDoSize << ")";
      }

      for ( UnsignedLong i = 0; i < toDoSize; ++i ) {
        PyObject * elt = PySequence_GetItem( result, i ); // new reference
        if ( ( elt == NULL ) || ! PySequence_Check( elt ) ) {
          if ( elt == NULL ) PyErr_Clear();
          Py_XDECREF( elt );
          Py_XDECREF( result );
          throw InvalidArgumentException(HERE) << "Python NumericalMathFunction returned an object which is NOT a sequence (at position "
                                               << i << ")";
        }

        const long lengthElt = PySequence_Size( elt );
        if ( ( lengthElt < 0 ) || ( static_cast<UnsignedLong>( lengthElt ) != outDim ) ) {
          if ( lengthElt < 0 ) PyErr_Clear();
          Py_XDECREF( elt );
          Py_XDECREF( result );
          throw InvalidArgumentException(HERE) << "Python NumericalMathFunction returned an sequence object with incorrect dimension (at position "
                                               << i << ")";
        }

        for ( UnsignedLong j = 0; j < outDim; ++j ) {
          PyObject * val = PySequence_GetItem( elt, j ); // new reference
          const double x = ( val != NULL ) ? PyFloat_AsDouble( val ) : -1.0;
          // PySequence_GetItem() and PyFloat_AsDouble() report failure through the Python error indicator
          const bool failed = ( x == -1.0 ) && ( PyErr_Occurred() != NULL );
          Py_XDECREF( val );
          if ( failed ) {
            PyErr_Clear();
            Py_XDECREF( elt );
            Py_XDECREF( result );
            throw InvalidArgumentException(HERE) << "Python NumericalMathFunction returned a non-numerical value (at position "
                                                 << i << ", component " << j << ")";
          }
          outS[toDo[i]][j] = x;
        }
        Py_XDECREF( elt );
      }

      Py_XDECREF( result );
    }

  if ( p_cache_->isEnabled() )
    {
      for (UnsignedLong i = 0; i < toDoSize; ++i)
        {
          CacheKeyType inKey( inS[toDo[i]].getCollection() );
          CacheValueType outValue( outS[toDo[i]].getCollection() );
          p_cache_->add( inKey, outValue );
        }
    }
  if (isHistoryEnabled_)
    {
      inputStrategy_.store(inS);
      outputStrategy_.store(outS);
    }

  return outS;
}


/* Accessor for input point dimension
 * Calls the getInputDimension() method of the wrapped Python object.
 * @throw InternalException / InvalidArgumentException if the call fails (see handleException())
 * @throw InvalidArgumentException if the returned value is a negative integer
 */
UnsignedLong PythonNumericalMathEvaluationImplementation::getInputDimension() const
/*        throw(InternalException)*/
{
  PyObject * result = PyObject_CallMethod( pyObj_,
                                           const_cast<char *>( "getInputDimension" ),
                                           const_cast<char *>( "()" ) ); // new reference
  if ( result == NULL ) {
    handleException();
    // handleException() only throws when a Python error is pending
    throw InternalException(HERE) << "Call to Python method getInputDimension() failed";
  }
  const long dim = PyLong_AsLong( result );
  Py_XDECREF( result );
  // PyLong_AsLong() signals failure by returning -1 with a Python error set
  if ( ( dim == -1 ) && PyErr_Occurred() ) handleException();
  if ( dim < 0 )
    throw InvalidArgumentException(HERE) << "Python method getInputDimension() returned a negative dimension: " << dim;
  return static_cast<UnsignedLong>( dim );
}


/* Accessor for output point dimension
 * Calls the getOutputDimension() method of the wrapped Python object.
 * @throw InternalException / InvalidArgumentException if the call fails (see handleException())
 * @throw InvalidArgumentException if the returned value is a negative integer
 */
UnsignedLong PythonNumericalMathEvaluationImplementation::getOutputDimension() const
/*        throw(InternalException)*/
{
  PyObject * result = PyObject_CallMethod( pyObj_,
                                           const_cast<char *>( "getOutputDimension" ),
                                           const_cast<char *>( "()" ) ); // new reference
  if ( result == NULL ) {
    handleException();
    // handleException() only throws when a Python error is pending
    throw InternalException(HERE) << "Call to Python method getOutputDimension() failed";
  }
  const long dim = PyLong_AsLong( result );
  Py_XDECREF( result );
  // PyLong_AsLong() signals failure by returning -1 with a Python error set
  if ( ( dim == -1 ) && PyErr_Occurred() ) handleException();
  if ( dim < 0 )
    throw InvalidArgumentException(HERE) << "Python method getOutputDimension() returned a negative dimension: " << dim;
  return static_cast<UnsignedLong>( dim );
}


/* Method save() stores the object through the StorageManager
 *
 * After the base-class attributes, the wrapped Python object is serialized with
 * pickle.dumps() and stored as the string attribute "pyInstance_".
 * @throw InternalException if there is no Python object, the pickle module is unusable
 *        or pickling fails (Python errors are translated by handleException())
 */
void PythonNumericalMathEvaluationImplementation::save(Advocate & adv) const
{
  NumericalMathEvaluationImplementation::save( adv );

  if ( pyObj_ == NULL )
    throw InternalException(HERE) << "Cannot save a Python NumericalMathFunction that wraps no Python object";

  PyObject * pickleModule = PyImport_ImportModule( "pickle" ); // new reference
  if ( pickleModule == NULL ) {
    handleException();
    throw InternalException(HERE) << "Unable to import Python 'pickle' module";
  }

  // Borrowed references: kept alive by pickleModule, not released here
  PyObject * pickleDict = PyModule_GetDict( pickleModule );
  PyObject * dumpsMethod = PyDict_GetItemString( pickleDict, "dumps" );
  if ( ( dumpsMethod == NULL ) || ! PyCallable_Check( dumpsMethod ) ) {
    Py_XDECREF( pickleModule );
    throw InternalException(HERE) << "Python 'pickle' module has no 'dumps' method";
  }

  PyObject * pyInstanceSt = PyObject_CallFunction( dumpsMethod, const_cast<char *>("O"), pyObj_ ); // new reference
  Py_XDECREF( pickleModule );
  if ( pyInstanceSt == NULL ) {
    handleException();
    throw InternalException(HERE) << "Python 'pickle.dumps' failed";
  }

  const char * pickledChars = PyString_AsString( pyInstanceSt );
  if ( pickledChars == NULL ) {
    Py_XDECREF( pyInstanceSt );
    handleException();
    throw InternalException(HERE) << "Python 'pickle.dumps' did not return a string";
  }
  // Copy the data before releasing the Python string that owns it
  const String pickled( pickledChars );
  Py_XDECREF( pyInstanceSt );

  adv.saveAttribute( "pyInstance_", pickled );
}


/* Method load() reloads the object from the StorageManager
 *
 * After the base-class attributes, the wrapped Python object is rebuilt with
 * pickle.loads() from the string attribute "pyInstance_". The previously wrapped
 * object (if any) is released only once the new one has been built, so a failed
 * load leaves this object unchanged on the Python side.
 * @throw InternalException if the pickle module is unusable or unpickling fails
 *        (Python errors are translated by handleException())
 */
void PythonNumericalMathEvaluationImplementation::load(Advocate & adv)
{
  NumericalMathEvaluationImplementation::load( adv );

  String pyInstanceSt;
  adv.loadAttribute( "pyInstance_", pyInstanceSt );

  PyObject * pickleModule = PyImport_ImportModule( "pickle" ); // new reference
  if ( pickleModule == NULL ) {
    handleException();
    throw InternalException(HERE) << "Unable to import Python 'pickle' module";
  }

  // Borrowed references: kept alive by pickleModule, not released here
  PyObject * pickleDict = PyModule_GetDict( pickleModule );
  PyObject * loadsMethod = PyDict_GetItemString( pickleDict, "loads" );
  if ( ( loadsMethod == NULL ) || ! PyCallable_Check( loadsMethod ) ) {
    Py_XDECREF( pickleModule );
    throw InternalException(HERE) << "Python 'pickle' module has no 'loads' method";
  }

  PyObject * newPyObj = PyObject_CallFunction( loadsMethod, const_cast<char *>("s"), pyInstanceSt.c_str() ); // new reference
  Py_XDECREF( pickleModule );
  if ( newPyObj == NULL ) {
    handleException();
    throw InternalException(HERE) << "Python 'pickle.loads' failed";
  }

  // Install the new object first, then drop the reference to the old one
  PyObject * oldPyObj = pyObj_;
  pyObj_ = newPyObj;
  Py_XDECREF( oldPyObj );
}


END_NAMESPACE_OPENTURNS
