/*
  This file is part of CDO. CDO is a collection of Operators to
  manipulate and analyse Climate model Data.

  Copyright (C) 2003-2019 Uwe Schulzweida, <uwe.schulzweida AT mpimet.mpg.de>
  See COPYING file for copying and redistribution conditions.

  This program is free software; you can redistribute it and/or modify
  it under the terms of the GNU General Public License as published by
  the Free Software Foundation; version 2 of the License.

  This program is distributed in the hope that it will be useful,
  but WITHOUT ANY WARRANTY; without even the implied warranty of
  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
  GNU General Public License for more details.
*/
#include <vector>
#include <cmath>
#include <string.h>

#ifdef HAVE_CONFIG_H
#include "config.h"
#endif

#include "cdo_options.h"
#include "cimdOmp.h"
#include "dmemory.h"
#include "constants.h"
#include "cimdOmp.h"

#define OPENMP4 201307
#if defined(_OPENMP) && defined(OPENMP4) && _OPENMP >= OPENMP4
#define HAVE_OPENMP4 1
#endif

extern "C" {
void gaussaw(double *pa, double *pw, size_t nlat);
}

/* jspleg1 - Legendre functions (no derivatives) for one Gaussian abscissa. */
static void
jspleg1(double *pleg, double plat, long ktrunc, double *work)
{
  /*
     jspleg1 - Routine to calculate legendre functions

     Purpose
     --------

     This routine calculates the legendre functions for one latitude.
     (but not their derivatives)


     Interface
     ----------

     jspleg1( pleg, plat, ktrunc, work )


     Input parameters
     ----------------

     plat      - SINE of the latitude (the Gaussian abscissa mu), not the
                 latitude in radians: the code uses plat directly as
                 sin(lat) and derives cos(lat) = sqrt(1 - plat*plat).
     ktrunc    - Spectral truncation
     work      - Scratch array of at least 3*(KTRUNC+1) doubles
                 (three consecutive helper vectors of KTRUNC+1 entries).


     Output parameters
     -----------------

     pleg      - Array of legendre functions for one latitude.
                 The array must be at least (KTRUNC+1)*(KTRUNC+4)/2
                 words long.
                 Layout: for each order m = 0..KTRUNC the degrees
                 n = m..KTRUNC+1 are stored consecutively, i.e.
                 KTRUNC+2-m values per order.

     Method
     ------

     Recurrence relation with explicit relations for P(m,m) and
     P(m,m+1), starting from P(0,0) = 1 and P(0,1) = sqrt(3)*plat.


     AUTHOR
     ------

     J.D.Chambers         ECMWF        9 November 1993


     Modifications
     -------------

     None

  */
  long itout1, i1m, ilm, jm, jcn, im2;
  double zsin, zcos, zf1m, zre1, zf2m, znsqr, ze1, ze2;
  double zjmsqr;
  double *zhlp1, *zhlp2, *zhlp3;

  /* Initialization */

  itout1 = ktrunc + 1;
  /* plat already is sin(latitude); the conversion below is intentionally disabled. */
  /*  zsin   = sin(plat); */
  zsin = plat;
  zcos = sqrt(1. - zsin * zsin);

  /* Split the scratch array into three helper vectors of itout1 entries each. */
  zhlp1 = work;
  zhlp2 = work + itout1;
  zhlp3 = work + itout1 + itout1;

  /*  Step 1.        M = 0, N = 0 and N = 1 */

  ilm = 1;
  pleg[0] = 1.0;
  zf1m = sqrt(3.0);
  pleg[1] = zf1m * zsin;

  /*  Step 2.       Precompute per-order factors for M = 0 to T (T = truncation):
                    zhlp1[m] = sqrt(2m+3), zhlp2[m] = 1/sqrt(2m) */

  for (jm = 1; jm < itout1; jm++)
    {
      zhlp1[jm] = sqrt(2. * jm + 3.);
      zhlp2[jm] = 1. / sqrt(2. * jm);
    }

  zhlp1[0] = sqrt(3.);

  for (jm = 0; jm < itout1; jm++)
    {
      i1m = jm - 1;
      zre1 = zhlp1[jm];
      ze1 = 1. / zre1;

      /*   Step 3.       M > 0 only */

      if (i1m != -1)
        {
          /* zf2m becomes P(m,m) from P(m-1,m) via a cos(lat) factor; zf1m is
             reused as the scale for P(m,m+1). */
          zf2m = zf1m * zcos * zhlp2[jm];
          zf1m = zf2m * zre1;

          /*  Step 4.       N = M and N = M+1 */

          ilm = ilm + 1;
          pleg[ilm] = zf2m;
          ilm = ilm + 1;
          pleg[ilm] = zf1m * zsin;

          /* When output truncation is reached, return to calling program
             (for M = T only N = T and N = T+1 are stored) */

          if (jm == (itout1 - 1)) break;
        }

      /*  Step 5.       Sum for N = M+2 to T+1 */

      zjmsqr = jm * jm;
      im2 = i1m + 2;

      /* Recurrence coefficients sqrt((4n^2-1)/(n^2-m^2)) with n = jcn+1. */
      for (jcn = im2; jcn < itout1; jcn++)
        {
          znsqr = (jcn + 1) * (jcn + 1);
          zhlp3[jcn] = sqrt((4. * znsqr - 1.) / (znsqr - zjmsqr));
        }

      /* Three-term recurrence in n using the two previously stored values of this order. */
      for (jcn = im2; jcn < itout1; jcn++)
        {
          ze2 = zhlp3[jcn];
          ilm = ilm + 1;
          pleg[ilm] = ze2 * (zsin * pleg[ilm - 1] - ze1 * pleg[ilm - 2]);
          ze1 = 1. / ze2;
        }
    }
}

/* ============================================= */
/* phcs - Compute values of Legendre polynomials */
/*        and their meridional derivatives       */
/* ============================================= */
/*
   pnm    - Output: Legendre functions. Row jm (order m = jm) starts at
            pnm[jm * waves]; entry jn of that row is degree n = jm + jn,
            jn = 0..waves-1. While computing rows jm >= 2 the loop also writes
            up to jn = 2*waves - jm - 1, i.e. into the following row, which
            serves as scratch until that row is computed. The caller must
            provide at least waves * (waves + 1) doubles.
   hnm    - Output: derivative terms, same layout and size as pnm, obtained
            from pnm by the recurrences below (described as meridional
            derivatives in the header; see the callers for the scaling applied).
   waves  - Truncation + 1 (number of orders / degrees per row).
   pmu    - Abscissa; the Fourier series use the angle acos(pmu), and
            sqrt(1 - pmu*pmu) is used as its sine.
   ztemp1 - Scratch, at least 2 * waves doubles.
   ztemp2 - Scratch, at least 2 * waves doubles.
*/
static void
phcs(double *pnm, double *hnm, long waves, double pmu, double *ztemp1, double *ztemp2)
{
  long twowaves;

  long jk, jn, jm;

  double jnmjk;
  double zcos2;
  double lat;
  double zan;
  double zsinpar;
  double zcospar;
  double zsqp;
  double zcosfak;
  double zsinfak;
  double zq;
  double zwm2;
  double zw;
  double zwq;
  double zq2m1;
  double zwm2q2;
  double z2q2;
  double zcnm;
  double zdnm;
  double zenm;

  twowaves = waves << 1;

  zcos2 = sqrt(1.0 - pmu * pmu);
  lat = acos(pmu);
  zan = 1.0;

  ztemp1[0] = 0.5;

  /* Orders 0 and 1 for indices up to 2*waves-1, evaluated as Fourier series in
     the angle `lat`: ztemp1[jn] gets the order-0 value of degree jn,
     ztemp2[jn-1] the order-1 value of degree jn. */
  for (jn = 1; jn < twowaves; jn++)
    {
      zsqp = 1.0 / sqrt((double) (jn + jn * jn));
      zan *= sqrt(1.0 - 1.0 / (4 * jn * jn));

      zcospar = cos(lat * jn);
      zsinpar = sin(lat * jn) * jn * zsqp;
      zcosfak = 1.0;

      for (jk = 2; jk < jn; jk += 2)
        {
          jnmjk = jn - jk;
          zcosfak *= (jk - 1.0) * (jn + jnmjk + 2.0) / (jk * (jn + jnmjk + 1.0));
          zsinfak = zcosfak * (jnmjk) *zsqp;
          zcospar += zcosfak * cos(lat * jnmjk);
          zsinpar += zsinfak * sin(lat * jnmjk);
        }

      /*  code for jk == jn: the constant term of the cosine series (only for
          even jn) enters with half weight */

      if ((jn & 1) == 0)
        {
          zcosfak *= (double) ((jn - 1) * (jn + 2)) / (double) (jn * (jn + 1));
          zcospar += zcosfak * 0.5;
        }
      ztemp1[jn] = zan * zcospar;
      ztemp2[jn - 1] = zan * zsinpar;
    }

  /* Rows 0 and 1 of pnm are copied straight from the series results. */
  memcpy(pnm, ztemp1, waves * sizeof(double));
  pnm += waves;
  memcpy(pnm, ztemp2, waves * sizeof(double));
  pnm += waves;

  /* Derivative terms for order 0 ... */
  hnm[0] = 0.0;
  for (jn = 1; jn < waves; jn++) hnm[jn] = jn * (pmu * ztemp1[jn] - sqrt((jn + jn + 1.0) / (jn + jn - 1.0)) * ztemp1[jn - 1]);

  hnm += waves;

  /* ... and for order 1. */
  hnm[0] = pmu * ztemp2[0];

  for (jn = 1; jn < waves; jn++)
    hnm[jn]
        = (jn + 1) * pmu * ztemp2[jn] - sqrt(((jn + jn + 3.0) * ((jn + 1) * (jn + 1) - 1.0)) / (jn + jn + 1.0)) * ztemp2[jn - 1];

  hnm += waves;

  /* Orders m >= 2: on entry ztemp1 holds order m-2 and ztemp2 order m-1. */
  for (jm = 2; jm < waves; jm++)
    {
      pnm[0] = sqrt(1.0 + 1.0 / (jm + jm)) * zcos2 * ztemp2[0];
      hnm[0] = jm * pmu * pnm[0];
      /* The loop below has a loop-carried dependence on pnm[jn - 1], so vectorization is disabled on these compilers. */
#if defined(CRAY)
#pragma _CRI novector
#endif
#if defined(__uxp__)
#pragma loop scalar
#endif
      for (jn = 1; jn < twowaves - jm; jn++)
        {
          zq = jm + jm + jn - 1;
          zwm2 = zq + jn;
          zw = zwm2 + 2;
          zwq = zw * zq;
          zq2m1 = zq * zq - 1.;
          zwm2q2 = zwm2 * zq2m1;
          z2q2 = zq2m1 * 2;
          zcnm = sqrt((zwq * (zq - 2.)) / (zwm2q2 - z2q2));
          zdnm = sqrt((zwq * (jn + 1.)) / zwm2q2);
          zenm = sqrt(zw * jn / ((zq + 1.0) * zwm2));
          pnm[jn] = zcnm * ztemp1[jn] - pmu * (zdnm * ztemp1[jn + 1] - zenm * pnm[jn - 1]);
          hnm[jn] = (jm + jn) * pmu * pnm[jn] - sqrt(zw * jn * (zq + 1) / zwm2) * pnm[jn - 1];
        }
      /* Rotate: order m-1 becomes the new "m-2", the freshly computed order m the new "m-1". */
      memcpy(ztemp1, ztemp2, twowaves * sizeof(double));
      memcpy(ztemp2, pnm, twowaves * sizeof(double));
      pnm += waves;
      hnm += waves;
    }
}

void
after_legini_full(long ntr, long nlat, double *restrict poli, double *restrict pold, double *restrict pdev, double *restrict pol2,
                  double *restrict pol3, double *restrict coslat)
{
  long waves = ntr + 1;
  long dimsp = (ntr + 1) * (ntr + 2);

  std::vector<double> gmu(nlat);
  std::vector<double> gwt(nlat);

  gaussaw(gmu.data(), gwt.data(), nlat);

#ifndef _OPENMP
  std::vector<double> pnm(dimsp);
  std::vector<double> hnm(dimsp);
  std::vector<double> ztemp1(waves << 1);
  std::vector<double> ztemp2(waves << 1);
#endif

#ifdef _OPENMP
#pragma omp parallel for default(none) shared(nlat, dimsp, waves, poli, pold, pdev, pol2, pol3, coslat, gmu, gwt, PlanetRadius)
#endif
  for (long jgl = 0; jgl < nlat; ++jgl)
    {
#ifdef _OPENMP
      std::vector<double> pnm(dimsp);
      std::vector<double> hnm(dimsp);
      std::vector<double> ztemp1(waves << 1);
      std::vector<double> ztemp2(waves << 1);
#endif
      double gmusq = 1.0 - gmu[jgl] * gmu[jgl];
      coslat[jgl] = sqrt(gmusq);

      phcs(pnm.data(), hnm.data(), waves, gmu[jgl], ztemp1.data(), ztemp2.data());

      double zgwt = gwt[jgl];
      double zrafgmusqr = 1. / (PlanetRadius * gmusq);
      double zradsqrtgmusqr = 1. / (-PlanetRadius * sqrt(gmusq));

      const int lpold = pold != nullptr;
      const int lpdev = pdev != nullptr;
      const int lpol2 = pol2 != nullptr;
      const int lpol3 = pol3 != nullptr;

      long jsp = jgl;
      for (long jm = 0; jm < waves; ++jm)
        for (long jn = 0; jn < waves - jm; ++jn)
          {
            poli[jsp] = pnm[jm * waves + jn] * 2.0;
            if (lpold) pold[jsp] = pnm[jm * waves + jn] * zgwt;
            if (lpdev) pdev[jsp] = hnm[jm * waves + jn] * 2.0 * zradsqrtgmusqr;
            if (lpol2) pol2[jsp] = hnm[jm * waves + jn] * zgwt * zrafgmusqr;
            if (lpol3) pol3[jsp] = pnm[jm * waves + jn] * zgwt * jm * zrafgmusqr;
            jsp += nlat;
          }
    }
}

/*
  after_legini - Legendre functions and weighted Legendre functions on a Gaussian grid.

  ntr    - Spectral truncation; orders m = 0..ntr (waves = ntr + 1).
  nlat   - Number of Gaussian latitudes.
  poli   - Output: Legendre functions. Coefficient k (orders ascending, degrees
           n = m..ntr within each order) of latitude j is poli[k * nlat + j].
           Needs nlat * (ntr + 1) * (ntr + 2) / 2 doubles.
  pold   - Output: same layout as poli, multiplied by half the Gaussian weight.
  coslat - Output: sqrt(1 - mu^2) for each Gaussian abscissa mu (nlat doubles).

  Only the northern rows (plus the equator row for odd nlat) are evaluated with
  jspleg1(). The southern rows are mirrored using the parity (-1)^(n-m). This
  assumes gaussaw() returns abscissae symmetric about the equator
  (mu[nlat-1-j] == -mu[j]); TODO confirm for odd nlat.
*/
void
after_legini(long ntr, long nlat, double *restrict poli, double *restrict pold, double *restrict coslat)
{
  long waves = ntr + 1;
  long dimpnm = (ntr + 1) * (ntr + 4) / 2;

  std::vector<double> gmu(nlat);
  std::vector<double> gwt(nlat);
#ifdef _OPENMP
  // One jspleg1 output/work buffer per thread. This assumes the parallel region below
  // never runs with more than Threading::ompNumThreads threads (TODO confirm).
  std::vector<std::vector<double>> pnm2(Threading::ompNumThreads);
  std::vector<std::vector<double>> work2(Threading::ompNumThreads);
  for (long i = 0; i < Threading::ompNumThreads; ++i) pnm2[i].resize(dimpnm);
  for (long i = 0; i < Threading::ompNumThreads; ++i) work2[i].resize(3 * waves);
#else
  std::vector<double> pnm(dimpnm);
  std::vector<double> work(3 * waves);
#endif

  gaussaw(gmu.data(), gwt.data(), nlat);
  for (long jgl = 0; jgl < nlat; ++jgl) gwt[jgl] *= 0.5;

  for (long jgl = 0; jgl < nlat; ++jgl) coslat[jgl] = sqrt(1.0 - gmu[jgl] * gmu[jgl]);

  // (nlat + 1) / 2 also covers the equator row when nlat is odd; previously it was never written.
#ifdef _OPENMP
#pragma omp parallel for default(none) shared(nlat, ntr, waves, pnm2, work2, gwt, gmu, poli, pold)
#endif
  for (long jgl = 0; jgl < (nlat + 1) / 2; jgl++)
    {
#ifdef _OPENMP
      const int ompthID = cdo_omp_get_thread_num();
      double *pnm = pnm2[ompthID].data();
      double *work = work2[ompthID].data();
#endif
      double zgwt = gwt[jgl];

      jspleg1(&pnm[0], gmu[jgl], ntr, &work[0]);

      // Offset from a northern row to its mirror row; 0 for the equator row of an odd grid.
      const long southOffset = nlat - 2 * jgl - 1;
      long latn = jgl;  // poli/pold index of the first degree of the current order
      long isp = 0;     // pnm index of the first degree of the current order
      for (long jm = 0; jm < waves; ++jm)
        {
          const long nn = waves - jm;  // degrees n = jm..ntr
#if defined(SX)
#pragma vdir nodep
#endif
#ifdef HAVE_OPENMP4
#pragma omp simd
#endif
          for (long jn = 0; jn < nn; ++jn)
            {
              // All indices derive from jn and temporaries are loop-local, so the
              // iterations are independent, as "omp simd" requires.
              const double sign = (jn % 2 == 0) ? 1.0 : -1.0;  // parity (-1)^(n-m)
              const long north = latn + jn * nlat;
              const long south = north + southOffset;
              const double p = pnm[isp + jn];
              // The mirrored row is written first, so that on the equator row
              // (south == north) the unmirrored value is the one kept.
              poli[south] = p * sign;
              pold[south] = p * zgwt * sign;
              poli[north] = p;
              pold[north] = p * zgwt;
            }
          latn += nn * nlat;
          isp += nn + 1;  // jspleg1 also stores degree ntr+1 for each order; skip it
        }
    }
}

/* to slow for nec, 2.0 instead of 2.3 GFlops ( vector length too small ) */
void
sp2fctest(const double *sa, double *fa, const double *poli, long nlev, long nlat, long nfc, long nt)
{
  long lats, is;
  double sar, sai;
  double saris, saiis;
  double *restrict far, *restrict fai;

  long nsp2 = (nt + 1) * (nt + 2);

  for (long lev = 0; lev < nlev; lev++)
    {
      const double *restrict pol = poli;
      const double *restrict sal = sa + lev * nsp2;
      double *fal = fa + lev * nfc * nlat;
      memset(fal, 0, nfc * nlat * sizeof(double));

      for (long jm = 0; jm <= nt; jm++)
        {
          for (long jn = 0; jn <= nt - jm; jn++)
            {
              is = (jn + 1) % 2 * 2 - 1;
              sar = *sal++;
              sai = *sal++;
              saris = sar * is;
              saiis = sai * is;
              far = fal;
              fai = fal + nlat;
#if defined(SX)
#pragma vdir nodep
#endif
#ifdef HAVE_OPENMP4
#pragma omp simd
#endif
              for (long latn = 0; latn < nlat / 2; latn++)
                {
                  lats = nlat - latn - 1;
                  far[latn] += pol[latn] * sar;
                  fai[latn] += pol[latn] * sai;
                  far[lats] += pol[latn] * saris;
                  fai[lats] += pol[latn] * saiis;
                }
              pol += nlat;
            }
          fal += 2 * nlat;
        }
    }
}

void
sp2fc(const double *sa, double *fa, const double *poli, long nlev, long nlat, long nfc, long nt)
{
  long nsp2 = (nt + 1) * (nt + 2);

#ifdef _OPENMP
#pragma omp parallel for default(shared)
#endif
  for (long lev = 0; lev < nlev; lev++)
    {
      const double *restrict pol = poli;
      const double *restrict sal = sa + lev * nsp2;
      double *fal = fa + lev * nfc * nlat;
      memset(fal, 0, nfc * nlat * sizeof(double));

      double *restrict far, *restrict fai;
      double sar, sai;

      for (long jmm = 0; jmm <= nt; jmm++)
        {
          for (long jfc = jmm; jfc <= nt; jfc++)
            {
              sar = *sal++;
              sai = *sal++;
              far = fal;
              fai = fal + nlat;
              /* unaligned loop start
#ifdef  HAVE_OPENMP4
#pragma omp simd
#endif
              */
              for (long lat = 0; lat < nlat; lat++)
                {
                  far[lat] += pol[lat] * sar;
                  fai[lat] += pol[lat] * sai;
                }
              pol += nlat;
            }
          fal += 2 * nlat;
        }
    }
}

void
fc2sp(double *fa, double *sa, const double *poli, long nlev, long nlat, long nfc, long nt)
{
  long nsp2 = (nt + 1) * (nt + 2);

#ifdef _OPENMP
#pragma omp parallel for default(shared)
#endif
  for (long lev = 0; lev < nlev; lev++)
    {
      const double *restrict pol = poli;
      double *fal = fa + lev * nfc * nlat;
      double *sal = sa + lev * nsp2;

      const double *restrict far;
      const double *restrict fai;
      double sar, sai;
      long jmm, jfc, lat;

      for (jmm = 0; jmm <= nt; jmm++)
        {
          for (jfc = jmm; jfc <= nt; jfc++)
            {
              far = fal;
              fai = fal + nlat;
              sar = 0.0;
              sai = 0.0;
#ifdef HAVE_OPENMP4
#pragma omp simd reduction(+ : sar) reduction(+ : sai)
#endif
              for (lat = 0; lat < nlat; lat++)
                {
                  sar += pol[lat] * far[lat];
                  sai += pol[lat] * fai[lat];
                }
              *sal++ = sar;
              *sal++ = sai;
              pol += nlat;
            }
          fal += 2 * nlat;
        }
    }
}

/* ======================================== */
/* Convert Spectral Array to new truncation */
/* ======================================== */
/*
  Spectral arrays hold, for each outer index n = 0..trunc, the entries
  m = n..trunc, each as two doubles (real, imaginary).
  Narrowing keeps the leading part of every kept row and skips the rest.
  Widening copies every input row, zero-pads it to the new length, and appends
  all-zero rows for the new outer indices. Elements are copied one by one in
  ascending order.
*/
void
sp2sp(double *arrayIn, long truncIn, double *arrayOut, long truncOut)
{
  double *src = arrayIn;
  double *dst = arrayOut;

  if (truncOut > truncIn)
    {
      const long npad = 2 * (truncOut - truncIn);
      for (long n = 0; n <= truncIn; ++n)
        {
          const long nkeep = 2 * (truncIn - n + 1);
          for (long k = 0; k < nkeep; ++k) *dst++ = *src++;
          for (long k = 0; k < npad; ++k) *dst++ = 0.0;
        }
      for (long n = truncIn + 1; n <= truncOut; ++n)
        {
          const long nzero = 2 * (truncOut - n + 1);
          for (long k = 0; k < nzero; ++k) *dst++ = 0.0;
        }
      return;
    }

  const long nskip = 2 * (truncIn - truncOut);
  for (long n = 0; n <= truncOut; ++n)
    {
      const long nkeep = 2 * (truncOut - n + 1);
      for (long k = 0; k < nkeep; ++k) *dst++ = *src++;
      src += nskip;
    }
}

/* ======================================== */
/* Cut spectral wave numbers                */
/* ======================================== */
/*
  Copies a spectral array of truncation `trunc`. Each entry (outer index n,
  inner index m = n..trunc) is kept when waves[m] is non-zero and replaced by
  0.0 otherwise. Every entry is two doubles (real, imaginary). The input is not
  read for dropped entries.
*/
void
spcut(double *arrayIn, double *arrayOut, long trunc, const int *waves)
{
  long idx = 0;
  for (long outer = 0; outer <= trunc; ++outer)
    {
      for (long inner = outer; inner <= trunc; ++inner, idx += 2)
        {
          const int keep = waves[inner] != 0;
          arrayOut[idx] = keep ? arrayIn[idx] : 0.0;
          arrayOut[idx + 1] = keep ? arrayIn[idx + 1] : 0.0;
        }
    }
}
