/* Ergo, version 3.2, a program for linear scaling electronic structure
 * calculations.
 * Copyright (C) 2012 Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek.
 * 
 * This program is free software: you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation, either version 3 of the License, or
 * (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 * You should have received a copy of the GNU General Public License
 * along with this program.  If not, see <http://www.gnu.org/licenses/>.
 * 
 * Primary academic reference:
 * Kohn−Sham Density Functional Theory Electronic Structure Calculations 
 * with Linearly Scaling Computational Time and Memory Usage,
 * Elias Rudberg, Emanuel H. Rubensson, and Pawel Salek,
 * J. Chem. Theory Comput. 7, 340 (2011),
 * <http://dx.doi.org/10.1021/ct100611z>
 * 
 * For further information about Ergo, see <http://www.ergoscf.org>.
 */

#include "hermite_conversion_prep.h"

#include <stdio.h>
#include <memory.h>
#include <assert.h>

#include <algorithm>
#include <cmath>
#include <vector>

#include "hermite_conversion_symb.h"
#include "monomial_info.h"

/* Upper limit on the number of contributions collected for a single
   (n1, n2) conversion list in the constructor below. Builds that allow
   basis-function polynomial degree 6 or higher get the larger limit. */
#if BASIS_FUNC_POLY_MAX_DEGREE >= 6
const int MAX_NO_OF_CONTRIBS = 10000000;
#else
const int MAX_NO_OF_CONTRIBS = 1000000;
#endif

hermite_conversion_info_struct::hermite_conversion_info_struct()
{

  const int nmax = HERMITE_CONVERSION_MAX_N;

  monomial_info_struct monomial_info;

  for(int n1 = 0; n1 <= nmax; n1++)
    for(int n2 = 0; n2 <= nmax; n2++)
      {
	hermite_conversion_contrib_struct* currlist = new hermite_conversion_contrib_struct[MAX_NO_OF_CONTRIBS];
	int count = 0;
	int nMon1 = monomial_info.no_of_monomials_list[n1];
	int nMon2 = monomial_info.no_of_monomials_list[n2];
	symb_matrix_element* list = new symb_matrix_element[nMon1*nMon1];
	int inverseFlag = 1;
	get_hermite_conversion_matrix_symb(&monomial_info,
					   n1,
					   inverseFlag,
					   list);      
	for(int j = 0; j < nMon1; j++)
	  for(int i = 0; i < nMon2; i++)
	    {
	      // Now take care of matrix element (i j)
	      for(int k = 0; k < nMon1; k++)
		{
		  int idx = j*nMon1+k;
		  if(std::fabs(list[idx].coeff) > 1e-5)
		    {
		      assert(count < MAX_NO_OF_CONTRIBS);
		      currlist[count].destIndex = j*nMon2+i;
		      currlist[count].sourceIndex = k*nMon2+i;
		      currlist[count].a_power = list[idx].ia;
		      currlist[count].coeff = list[idx].coeff;
		      count++;
		    }
		}
	    } // END FOR i j
	list_right[n1][n2] = new hermite_conversion_contrib_struct[count];
	memcpy(list_right[n1][n2], currlist, count*sizeof(hermite_conversion_contrib_struct));
	counters_right[n1][n2] = count;
	delete []currlist;
        delete []list;
      } // END FOR n1 n2

  for(int n1 = 0; n1 <= nmax; n1++)
    for(int n2 = 0; n2 <= nmax; n2++)
      {
	hermite_conversion_contrib_struct* currlist = new hermite_conversion_contrib_struct[MAX_NO_OF_CONTRIBS];
	int count = 0;
	int nMon1 = monomial_info.no_of_monomials_list[n1];
	int nMon2 = monomial_info.no_of_monomials_list[n2];
	symb_matrix_element* list = new symb_matrix_element[nMon2*nMon2];
	int inverseFlag = 1;
	get_hermite_conversion_matrix_symb(&monomial_info,
					   n2,
					   inverseFlag,
					   list);      
	for(int j = 0; j < nMon1; j++)
	  for(int i = 0; i < nMon2; i++)
	    {
	      // Now take care of matrix element (i j)
	      for(int k = 0; k < nMon2; k++)
		{
		  int idx = i*nMon2+k;
		  if(std::fabs(list[idx].coeff) > 1e-5)
		    {
		      assert(count < MAX_NO_OF_CONTRIBS);
		      currlist[count].destIndex = j*nMon2+i;
		      currlist[count].sourceIndex = j*nMon2+k;
		      currlist[count].a_power = list[idx].ia;
		      currlist[count].coeff = list[idx].coeff;
		      count++;
		    }
		}
	    } // END FOR i j
	list_left[n1][n2] = new hermite_conversion_contrib_struct[count];
	memcpy(list_left[n1][n2], currlist, count*sizeof(hermite_conversion_contrib_struct));
	counters_left[n1][n2] = count;
	delete []currlist;
        delete []list;
      } // END FOR n1 n2
  
}


/* Releases every per-(n1, n2) contribution array allocated with new[]
   by the constructor. */
hermite_conversion_info_struct::~hermite_conversion_info_struct()
{
  const int lastN = HERMITE_CONVERSION_MAX_N;

  for(int row = 0; row <= lastN; ++row)
    {
      for(int col = 0; col <= lastN; ++col)
	{
	  delete [] list_right[row][col];
	  delete [] list_left[row][col];
	}
    }
}


int hermite_conversion_info_struct::multiply_by_hermite_conversion_matrix_from_right(int n1max,        
										     int n2max,        
										     ergo_real a,      
										     ergo_real* A,     
										     ergo_real* result) const
{
  int noOfContribs = counters_right[n1max][n2max];
  hermite_conversion_contrib_struct* list = list_right[n1max][n2max];
  
  int nMon1 = monomial_info.no_of_monomials_list[n1max];
  int nMon2 = monomial_info.no_of_monomials_list[n2max];

  int Ntot = n1max + n2max;
  ergo_real apowlist[Ntot+1];
  apowlist[0] = 1;
  for(int i = 1; i <= Ntot; i++)
    apowlist[i] = apowlist[i-1] * a;
  
  for(int i = 0; i < nMon1*nMon2; i++)
    result[i] = 0;
  
  for(int i = 0; i < noOfContribs; i++)
    result[list[i].destIndex] += A[list[i].sourceIndex] * list[i].coeff * apowlist[-list[i].a_power];
  
  return 0;
}


int hermite_conversion_info_struct::multiply_by_hermite_conversion_matrix_from_left(int n1max,        
										    int n2max,        
										    ergo_real a,      
										    ergo_real* A,     
										    ergo_real* result) const
{
  int noOfContribs = counters_left[n1max][n2max];
  hermite_conversion_contrib_struct* list = list_left[n1max][n2max];
  
  int nMon1 = monomial_info.no_of_monomials_list[n1max];
  int nMon2 = monomial_info.no_of_monomials_list[n2max];
  
  int Ntot = n1max + n2max;
  ergo_real apowlist[Ntot+1];
  apowlist[0] = 1;
  for(int i = 1; i <= Ntot; i++)
    apowlist[i] = apowlist[i-1] * a;
  
  for(int i = 0; i < nMon1*nMon2; i++)
    result[i] = 0;
  
  for(int i = 0; i < noOfContribs; i++)
    result[list[i].destIndex] += A[list[i].sourceIndex] * list[i].coeff * apowlist[-list[i].a_power];
  
  return 0;
}


