/*********************************************************************/
/* File:   mgpre.cc                                                */
/* Author: Joachim Schoeberl                                         */
/* Date:   20. Apr. 2000                                             */
/*********************************************************************/

/* 
   Multigrid Preconditioner
*/

#include <multigrid.hpp>

namespace ngmg
{
  using namespace ngmg;

  // Construct a multigrid preconditioner for the given mesh, finite element
  // space and bilinear form. The smoother, prolongation and coarse-grid
  // preconditioner are flagged as owned (deleted in the destructor) until a
  // caller changes that via the SetOwn* methods.
  MultigridPreconditioner ::
  MultigridPreconditioner (const MeshAccess & ama,
			   const FESpace & afespace,
			   const BilinearForm & abiform,
			   Smoother * asmoother,
			   Prolongation * aprolongation)
    : BaseMatrix (), ma(ama), fespace(afespace), biform(abiform), 
      smoother(asmoother), prolongation(aprolongation)
  {
    // no coarse-grid preconditioner exists yet
    coarsegridpre = 0;

    // cycle parameters: one smoothing step, one recursive visit per level,
    // no growth of the smoothing steps towards coarser levels
    SetSmoothingSteps (1);
    SetCycle (1);
    SetIncreaseSmoothingSteps (1);

    // coarsest level: exact inverse, one coarse step
    SetCoarseType (EXACT_COARSE);
    SetCoarseSmoothingSteps (1);

    // ownership of the components passed in / created later
    SetOwnSmoother (1);
    SetOwnProlongation (1);
    SetOwnCoarseGridPreconditioner (1);

    // refresh all levels when the form is assembled in Galerkin mode
    SetUpdateAll (biform.UseGalerkin());

    checksumcgpre = -17;
  }

  
  // Release exactly those components whose ownership flag is set.
  MultigridPreconditioner :: ~MultigridPreconditioner ()
  {
    if (ownsmoother) delete smoother;
    if (ownprolongation) delete prolongation;
    if (owncoarsegridpre) delete coarsegridpre;
  }



  // Base number of pre- and post-smoothing steps per level; MGM multiplies
  // it by the level-dependent factor incsm.
  void MultigridPreconditioner :: SetSmoothingSteps (int sstep)
  { smoothingsteps = sstep; }

  // Number of recursive visits of the next coarser level per cycle
  // (1 = V-cycle, 2 = W-cycle); values <= 0 skip the coarse-grid correction.
  void MultigridPreconditioner :: SetCycle (int c)
  { cycle = c; }

  // Factor applied to the smoothing-step multiplier each time MGM descends
  // one level.
  void MultigridPreconditioner :: SetIncreaseSmoothingSteps (int incsm)
  { incsmooth = incsm; }
  
  // Select how MGM treats the coarsest level (EXACT_COARSE, USER_COARSE,
  // CG_COARSE or SMOOTHING_COARSE).
  void MultigridPreconditioner :: SetCoarseType (COARSETYPE ctyp)
  { coarsetype = ctyp; }

  void MultigridPreconditioner :: 
  SetCoarseGridPreconditioner (const BaseMatrix * acoarsegridpre)
  {
    coarsetype = USER_COARSE;
    coarsegridpre = const_cast<BaseMatrix*> (acoarsegridpre);
  }

  // Coarsest-level step count: smoothing steps for SMOOTHING_COARSE; for
  // EXACT_COARSE/USER_COARSE any value > 1 adds one extra correction step.
  void MultigridPreconditioner :: SetCoarseSmoothingSteps (int cstep)
  { coarsesmoothingsteps = cstep; }

  // Nonzero: the destructor deletes the smoother.
  void MultigridPreconditioner :: SetOwnSmoother (int os)
  { ownsmoother = os; }

  // Set the update-all flag. When it is nonzero, Update() rebuilds the exact
  // coarse-grid inverse on every call (otherwise only while the bilinear
  // form has a single level). The flag is also forwarded to the smoother.
  void MultigridPreconditioner :: SetUpdateAll (int ua)
  {
    updateall = ua;
    // smoother may be NULL (MemoryUsage guards against that as well); the
    // constructor calls this method, so an unguarded call would crash there
    if (smoother)
      smoother->SetUpdateAll (ua);
  }

  // Nonzero: the destructor deletes the prolongation.
  void MultigridPreconditioner :: SetOwnProlongation (int op)
  { ownprolongation = op; }

  // Nonzero: the destructor deletes the coarse-grid preconditioner.
  void MultigridPreconditioner :: SetOwnCoarseGridPreconditioner (int oc)
  { owncoarsegridpre = oc; }



  void MultigridPreconditioner :: Update ()
  {
    smoother->Update();
    if (prolongation)
      prolongation->Update();


    //  coarsegridpre = biform.GetMatrix(1).CreateJacobiPrecond();
    // InverseMatrix();

    if (biform.GetNLevels() == 1 || updateall)
      {
	if (coarsetype == EXACT_COARSE)
	  {
	    /*
	      double checksum = biform.GetMatrix(1).CheckSum();
	      if (checksum != checksumcgpre)
	      {
	      cout << "factor coarse" << endl;
	      checksumcgpre = checksum;
	    */
	    cout << "factor coarse grid matrix" << endl;
	    if (coarsegridpre) delete coarsegridpre;
	    coarsegridpre =
	      dynamic_cast<const BaseSparseMatrix&> (biform.GetMatrix(0))
	      .InverseMatrix();
	    /*
	      }
	      else
	      {
	      cout << "do not factor coarse" << endl;
	      }
	    */
	  }
	else
	  {
	    /*
	      if (coarsegridpre)
	      coarsegridpre->Update();
	    */
	  }
      }
    //  SetSymmetric (biform.GetMatrix(1).Symmetric());


    if (prol_projection.Size() < ma.GetNLevels() && prolongation)
      {
	// BitArray * innerdof = prolongation->GetInnerDofs();
	BitArray * innerdof = biform.GetFESpace().CreateIntermediatePlanes (1);

	if (innerdof && 0)
	  {
	    /*
	    const SparseMatrix<Mat<3> > & m = 
	      dynamic_cast<const SparseMatrix<Mat<3> > &> (biform.GetMatrix());
	    BaseMatrix * inv =
	      new SparseCholesky<Mat<3> > (m, innerdof);
	    */
	    const SparseMatrix<double> & m = 
	      dynamic_cast<const SparseMatrix<double> &> (biform.GetMatrix());
	    BaseMatrix * inv =
	      new SparseCholesky<double> (m, innerdof);
	    prol_projection.Append (inv);
	  }
      }
   }

  // Apply the preconditioner: y is reset to zero and one multigrid cycle is
  // run on the highest level (ma.GetNLevels()-1) with right-hand side x.
  //
  // Exceptions are annotated with this call site and rethrown; foreign
  // std::exceptions are converted into the project's Exception type.
  void MultigridPreconditioner ::
  Mult (const BaseVector & x, BaseVector & y) const
  {
    try
      {
	y = 0;
	MGM (ma.GetNLevels()-1, y, x);
      }
    // Handlers are matched in order, so the more specific Exception must come
    // first: if Exception derives from std::exception, the previous order made
    // this handler unreachable and every Exception was re-wrapped instead.
    catch (Exception & e)
      {
	e.Append ("in MultigridPreconditioner::Mult\n");
	throw;
      }
    catch (exception & e)
      {
	cout << "typeid(x) = " << typeid(x).name() << endl;
	cout << "typeid(y) = " << typeid(y).name() << endl;
	throw Exception(e.what() +
			string ("\ncaught in MultigridPreconditioner::Mult\n"));
      }
  }

  void MultigridPreconditioner :: 
  MGM (int level, BaseVector & u, 
       const BaseVector & f, int incsm) const
  {
    int j;
    if (level <= 0)
      {
	switch (coarsetype)
	  {
	  case EXACT_COARSE:
	  case USER_COARSE:
	    {
	      u = (*coarsegridpre) * f;
	      if (coarsesmoothingsteps > 1)
		{
		  BaseVector & d = *smoother->CreateVector(0);
		  BaseVector & w = *smoother->CreateVector(0);
		  
		  smoother->Residuum (level, u, f, d);
		  w = (*coarsegridpre) * d;
		  u += w;
		  
		  delete &w;
		  delete &d;
		}
	      break;
	    }
	  case CG_COARSE:
	    {
	      CGSolver<double> inv (biform.GetMatrix (1));
	      u = inv * f;
	      break;
	    }
	  case SMOOTHING_COARSE:
	    {
	      smoother->PreSmooth (level, u, f, coarsesmoothingsteps);
	      smoother->PostSmooth (level, u, f, coarsesmoothingsteps);
	      break;
	    }
	  }
      }
    else
      {
	smoother->PreSmooth (level, u, f, smoothingsteps * incsm);

	if (cycle > 0)
	  {
	    BaseVector & d = *smoother->CreateVector(level);
	    BaseVector & w = *smoother->CreateVector(level);

	    TempVector dt = d.Range (0, fespace.GetNDofLevel(level-1));
	    TempVector wt = w.Range (0, fespace.GetNDofLevel(level-1));

	    smoother->Residuum (level, u, f, d);

	    /*
	    prol_projection[level]->Mult (d, w);
	    u.Range (0,w.Size()) += w;
	    smoother->Residuum (level, u, f, d);
	    */

	    prolongation->RestrictInline (level, d);

	    w = 0;
	    for (j = 1; j <= cycle; j++)
	      MGM (level-1, *wt, *dt, incsm * incsmooth);
	    
	    prolongation->ProlongateInline (level, w);
	    u += w;

	    /*
	    smoother->Residuum (level, u, f, d);
	    prol_projection[level]->Mult (d, w);
	    u.Range (0,w.Size()) += w;
	    */

	    delete &w;
	    delete &d;
	  }

	smoother->PostSmooth (level, u, f, smoothingsteps * incsm);
      }
  }


  // Append the memory statistics of the coarse-grid preconditioner and the
  // smoother (whichever are present) to mu.
  void MultigridPreconditioner :: MemoryUsage (ARRAY<MemoryUsageStruct*> & mu) const
  {
    if (coarsegridpre != NULL)
      coarsegridpre->MemoryUsage (mu);
    if (smoother != NULL)
      smoother->MemoryUsage (mu);
  }








  // Two-level preconditioner built from the fine matrix amat, the coarse
  // preconditioner acpre and a smoother acting on level alevel. The smoother
  // is adopted (deleted in the destructor) and updated immediately.
  TwoLevelMatrix :: 
  TwoLevelMatrix (const BaseMatrix * amat, 
		  const BaseMatrix * acpre, 
		  Smoother * asmoother, 
		  int alevel)
    : mat(amat), cpre (acpre), smoother(asmoother), level(alevel)
  {
    own_smoother = 1;        // take ownership of asmoother
    SetSmoothingSteps (1);   // default: one pre- and one post-step
    Update();
  }
  
  // Delete the smoother only when this object owns it.
  TwoLevelMatrix :: ~TwoLevelMatrix ()
  {
    if (!own_smoother)
      return;
    delete smoother;
  }

  // Refresh the smoother; the coarse preconditioner cpre is not updated here.
  void TwoLevelMatrix :: Update()
  {
    smoother->Update();
  }

  void TwoLevelMatrix :: Mult (const BaseVector & x, BaseVector & y) const
  {
    BaseVector & cres = *cpre->CreateVector();
    BaseVector & cw = *cpre->CreateVector();
    BaseVector & res = *CreateVector();

    y = 0;
    //  jacsmoother->GSSmooth (y, x);
    smoother->PreSmooth (level, y, x, smoothingsteps);

    res = x - (*mat) * y;  
    cres = *res.Range (0, cres.Size());
    cw = (*cpre) * cres;
    *(y.Range (0, cw.Size())) += cw;

    smoother->PostSmooth (level, y, x, smoothingsteps);
    // jacsmoother->GSSmoothBack (y, x);

    delete &cres;
    delete &cw;
    delete &res;
  }

  // Vectors of this operator have the layout of the fine matrix's vectors.
  BaseVector * TwoLevelMatrix :: CreateVector () const
  {
    return mat->CreateVector();
  }

  // Short textual description; returns the stream for chaining.
  ostream & TwoLevelMatrix :: Print (ostream & s) const
  {
    return s << "Twolevel Preconditioner\n";
  }




  // Append the memory statistics of cpre and the smoother (if present) to mu.
  void TwoLevelMatrix :: MemoryUsage (ARRAY<MemoryUsageStruct*> & mu) const
  {
    if (cpre != NULL)
      cpre->MemoryUsage (mu);
    if (smoother != NULL)
      smoother->MemoryUsage (mu);
  }




}
