// Copyright (C) 2008-2009 Anders Logg.
// Licensed under the GNU LGPL Version 2.1.
//
// First added:  2008-05-15
// Last changed: 2009-09-08

#include <dolfin/common/Array.h>
#include "LinearAlgebraFactory.h"
#include "GenericMatrix.h"
#include "GenericVector.h"
#include "SparsityPattern.h"
#include "SingularSolver.h"

using namespace dolfin;

//-----------------------------------------------------------------------------
SingularSolver::SingularSolver(std::string solver_type,
                               std::string pc_type)
  : linear_solver(solver_type, pc_type),
    B(0),
    y(0),
    c(0)
{
  // The extended system (B, y, c) is not allocated here; init() creates it
  // on first use, so all three pointers start out null.
}
//-----------------------------------------------------------------------------
SingularSolver::~SingularSolver()
{
  // Release the extended system allocated by init(). Deleting a null pointer
  // is a no-op, so this is safe even if solve() was never called.
  // NOTE(review): the class owns raw pointers; unless the header disables
  // copying, a copied SingularSolver would double-delete these — confirm.
  delete c;
  delete y;
  delete B;
}
//-----------------------------------------------------------------------------
dolfin::uint SingularSolver::solve(const GenericMatrix& A,
                                   GenericVector& x,
                                   const GenericVector& b)
{
  // Solve A x = b for a singular (presumably rank-one-deficient) A. We solve
  // the (N+1) x (N+1) extended system B y = c built by create(), in which the
  // extra row and column are all ones and c = [b; 0], and then copy the first
  // N entries of y into x. Returns the iteration count reported by the
  // underlying linear solver.
  info(TRACE, "Solving singular system...");

  // Propagate parameters
  linear_solver.parameters.update(parameters("linear_solver"));

  // Initialize data structures for extended system
  init(A);

  // Create extended system (no mass matrix: unit weights in last row/column)
  create(A, b, 0);

  // Solve extended system
  const uint num_iterations = linear_solver.solve(*B, *y, *c);

  // Extract solution. y has N + 1 entries, and the last one is the extra
  // unknown introduced by the added row/column. x must receive exactly N
  // values, so copy into an array of the right size instead of handing
  // set_local an array of length N + 1.
  const uint N = A.size(0);
  Array<double> y_vals(N + 1);
  y->get_local(y_vals);
  Array<double> x_vals(N);
  for (uint i = 0; i < N; i++)
    x_vals[i] = y_vals[i];
  x.resize(N);
  x.set_local(x_vals);

  // Finalize, matching how create() applies its set_local on c
  x.apply("insert");

  return num_iterations;
}
//-----------------------------------------------------------------------------
dolfin::uint SingularSolver::solve(const GenericMatrix& A,
                                   GenericVector& x,
                                   const GenericVector& b,
                                   const GenericMatrix& M)
{
  // Solve A x = b for a singular (presumably rank-one-deficient) A. The extra
  // row and column of the (N+1) x (N+1) extended system B y = c are weighted
  // by the row sums of M, i.e. M * [1, ..., 1]. The first N entries of y are
  // copied into x. Returns the iteration count reported by the underlying
  // linear solver.
  info(TRACE, "Solving singular system...");

  // Propagate parameters
  linear_solver.parameters.update(parameters("linear_solver"));

  // Initialize data structures for extended system
  init(A);

  // Create extended system weighted by the lumped mass matrix of M
  create(A, b, &M);

  // Solve extended system
  const uint num_iterations = linear_solver.solve(*B, *y, *c);

  // Extract solution. y has N + 1 entries, and the last one is the extra
  // unknown introduced by the added row/column. x must receive exactly N
  // values, so copy into an array of the right size instead of handing
  // set_local an array of length N + 1.
  const uint N = A.size(0);
  Array<double> y_vals(N + 1);
  y->get_local(y_vals);
  Array<double> x_vals(N);
  for (uint i = 0; i < N; i++)
    x_vals[i] = y_vals[i];
  x.resize(N);
  x.set_local(x_vals);

  // Finalize, matching how create() applies its set_local on c
  x.apply("insert");

  return num_iterations;
}
//-----------------------------------------------------------------------------
void SingularSolver::init(const GenericMatrix& A)
{
  // Check size of system
  if (A.size(0) != A.size(1))
    error("Matrix must be square.");
  if (A.size(0) == 0)
    error("Matrix size must be non-zero.");

  // Get dimension
  const uint N = A.size(0);

  // Check if we have already initialized system
  if (B && B->size(0) == N + 1 && B->size(1) == N + 1)
    return;

  // Delete any old data
  delete B;
  delete y;
  delete c;

  // Create sparsity pattern for B
  SparsityPattern s(SparsityPattern::unsorted);
  uint dims[2] = {N + 1, N + 1};
  s.init(2, dims);

  // Copy sparsity pattern for A and last column
  std::vector<uint> columns;
  std::vector<double> dummy;
  uint num_rows[2];
  const uint* rows[2];
  for (uint i = 0; i < N; i++)
  {
    // Get row
    A.getrow(i, columns, dummy);

    // Copy columns to array
    const uint num_cols = columns.size() + 1;
    uint* cols = new uint[num_cols];
    for (uint j = 0; j < columns.size(); j++)
      cols[j] = columns[j];

    // Add last entry
    cols[num_cols - 1] = N;

    // Insert into sparsity pattern
    num_rows[0] = 1;
    num_rows[1] = num_cols;
    rows[0] = &i;
    rows[1] = cols;
    s.insert(num_rows, rows);

    // Delete temporary array
    delete [] cols;
  }

  // Add last row
  const uint num_cols = N;
  uint* cols = new uint[num_cols];
  for (uint j = 0; j < num_cols; j++)
    cols[j] = j;
  const uint row = N;
  num_rows[0] = 1;
  num_rows[1] = num_cols;
  rows[0] = &row;
  rows[1] = cols;
  s.insert(num_rows, rows);
  delete [] cols;

  // Create matrix and vector
  B = A.factory().create_matrix();
  y = A.factory().create_vector();
  c = A.factory().create_vector();
  B->init(s);
  y->resize(N + 1);
  c->resize(N + 1);

  // FIXME: Do these need to be zeroed?
  y->zero();
  c->zero();
}
//-----------------------------------------------------------------------------
void SingularSolver::create(const GenericMatrix& A, const GenericVector& b,
                            const GenericMatrix* M)
{
  assert(B);
  assert(c);

  info(TRACE, "Creating extended hopefully non-singular system...");

  // Reset matrix
  B->zero();

  // Copy rows from A into B
  const uint N = A.size(0);
  std::vector<uint> columns;
  std::vector<double> values;
  for (uint i = 0; i < N; i++)
  {
    A.getrow(i, columns, values);
    B->setrow(i, columns, values);
  }

  // Compute lumped mass matrix
  columns.resize(N);
  values.resize(N);
  if (M)
  {
    GenericVector* ones = A.factory().create_vector();
    GenericVector* z = A.factory().create_vector();
    ones->resize(N);
    *ones = 1.0;
    z->resize(N);
    // FIXME: Do we need to zero z?
    z->zero();
    M->mult(*ones, *z);
    for (uint i = 0; i < N; i++)
    {
      columns[i] = i;
      values[i] = (*z)[i];
    }
    delete ones;
    delete z;
  }
  else
  {
    for (uint i = 0; i < N; i++)
    {
      columns[i] = i;
      values[i] = 1.0;
    }
  }

  // Add last row
  B->setrow(N, columns, values);

  // Add last column
  for (uint i = 0; i < N; i++)
    B->set(&values[i], 1, &i, 1, &N);

  // Copy values from b into c
  Array<double> vals(N + 1);
  b.get_local(vals);
  vals[N] = 0.0;
  c->set_local(vals);

  // Apply changes
  B->apply("insert");
  c->apply("insert");
}
//-----------------------------------------------------------------------------
