/*============================================================================

    This file is part of FLINT.

    FLINT is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    FLINT is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with FLINT; if not, write to the Free Software
    Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301 USA

===============================================================================*/
/******************************************************************************

 ZmodF.h

 Copyright (C) 2007, David Harvey
 
 Routines for arithmetic on elements of Z/pZ where p = B^n + 1,
 B = 2^FLINT_BITS.
 
 These are currently used only in the ZmodF_poly module, which supplies the
 Schoenhage-Strassen FFT code.
 
******************************************************************************/

#ifndef FLINT_ZMODF_H
#define FLINT_ZMODF_H

#include <stdlib.h>
#include <gmp.h>

#ifdef __cplusplus
 extern "C" {
#endif

#include "flint.h"

/*
   Add the given *signed* limb to the buffer [x, x+count), much like
   mpn_add_1 and mpn_sub_1 (except it's always inplace). Any carry or borrow
   out of the top limb x[count-1] is discarded.

   PRECONDITIONS:
      count >= 1

   NOTE:
      The branch predictability of this function is optimised for the case that
      abs(limb) is relatively small and that the first limb of x is randomly
      distributed, which should be the normal usage in the FFT routines.

      The magnitude of a negative limb is computed in unsigned arithmetic, so
      the most negative mp_limb_signed_t value is handled without signed
      overflow.
*/
static inline
void signed_add_1(mp_limb_t* x, unsigned long count, mp_limb_signed_t limb)
{
   FLINT_ASSERT(count >= 1);

   // Two's complement bit pattern of "limb"; all arithmetic on it wraps
   // modulo B instead of overflowing.
   mp_limb_t ulimb = (mp_limb_t) limb;

   // If the high bit of x[0] doesn't change when we add "limb" to it,
   // then there's no possibility of overflow.
   mp_limb_t temp = x[0] + ulimb;
   if ((mp_limb_signed_t)(temp ^ x[0]) >= 0)
   {
      // the likely case
      x[0] = temp;
   }
   else if (limb >= 0)
   {
      // unlikely: a carry may have to ripple through the higher limbs
      mpn_add_1(x, x, count, ulimb);
   }
   else
   {
      // unlikely: a borrow may have to ripple through the higher limbs.
      // 0 - ulimb is abs(limb), obtained without negating a signed value.
      mpn_sub_1(x, x, count, (mp_limb_t) 0 - ulimb);
   }
}

/*
A ZmodF_t is stored as a *signed* value in two's complement format, using
n+1 limbs. The value is not normalised into any particular range, so the top
limb may pick up overflow bits. Of course the arithmetic functions in this
module may implicitly reduce mod p whenever they like.

More precisely, suppose that the first n limbs are x[0], ..., x[n-1] (unsigned)
and the last limb is x[n] (signed). Then the value being represented is
 x[0] + x[1]*B + ... + x[n-1]*B^(n-1) - x[n]   (mod p).
The top limb enters with a minus sign because p = B^n + 1, i.e.
B^n = -1 (mod p), so x[n]*B^n = -x[n] (mod p).

The type itself is only a pointer to the first limb; the length n is not
stored and is passed separately to every routine in this module.
*/
typedef mp_limb_t* ZmodF_t;


/* ============================================================================

    Normalisations and simple data movement

============================================================================ */


/*
   Exchanges the two handles *a and *b.
   Only the pointers are swapped; no limb data is moved.
*/
static inline
void ZmodF_swap(ZmodF_t* a, ZmodF_t* b)
{
   ZmodF_t saved_b = *b;
   *b = *a;
   *a = saved_b;
}


/*
   Normalises a into the range [0, p).
   (Note that the top limb will be set if and only if a = -1 mod p.)
*/
void ZmodF_normalise(ZmodF_t a, unsigned long n);


/*
   Replaces a by an equivalent residue mod p whose top limb lies in the
   interval [0, 2].

   This will in general be faster than ZmodF_normalise(); in particular
   the branching is much more predictable.
*/
static inline
void ZmodF_fast_reduce(ZmodF_t a, unsigned long n)
{
   // The top limb contributes -a[n] to the value (B^n = -1 mod p).
   // Reset it to 1, which contributes -1, and fold the difference
   // 1 - a[n] back into the bottom of the buffer.
   const mp_limb_t top = a[n];
   a[n] = 1;
   signed_add_1(a, n + 1, 1 - top);
}


/*
   a := 0

   Clears all n+1 limbs of a, including the top (overflow) limb a[n].
*/
static inline
void ZmodF_zero(ZmodF_t a, unsigned long n)
{
   unsigned long remaining = n + 1;
   while (remaining > 0)
   {
      remaining--;
      a[remaining] = 0;
   }
}


/*
   b := a

   Copies all n+1 limbs of a into b, working from the top limb a[n]
   down to a[0].
*/
static inline
void ZmodF_set(ZmodF_t b, ZmodF_t a, unsigned long n)
{
   unsigned long remaining = n + 1;
   while (remaining > 0)
   {
      remaining--;
      b[remaining] = a[remaining];
   }
}


/* ============================================================================

    Basic arithmetic

============================================================================ */


/*
   b := -a

   PRECONDITIONS:
      a and b may alias each other
*/
static inline
void ZmodF_neg(ZmodF_t b, ZmodF_t a, unsigned long n)
{
   // Complementing every limb gives -a - 1. The missing +1 is supplied
   // mod p by taking 1 off the top limb, since -B^n = 1 mod p.
   b[n] = ~a[n] - 1;

   // complement the remaining n limbs, from limb n-1 down to limb 0
   unsigned long k = n;
   while (k > 0)
   {
      k--;
      b[k] = ~a[k];
   }
}


/*
   res := a + b

   The operands are added as (n+1)-limb integers; the carry out of the
   top limb is discarded.

   PRECONDITIONS:
      Any combination of aliasing among res, a, b is allowed.
*/
static inline
void ZmodF_add(ZmodF_t res, ZmodF_t a, ZmodF_t b, unsigned long n)
{
   const unsigned long limbs = n + 1;
   (void) mpn_add_n(res, a, b, limbs);
}


/*
   res := a - b

   The operands are subtracted as (n+1)-limb integers; the borrow out of
   the top limb is discarded.

   PRECONDITIONS:
      Any combination of aliasing among res, a, b is allowed.
*/
static inline
void ZmodF_sub(ZmodF_t res, ZmodF_t a, ZmodF_t b, unsigned long n)
{
   const unsigned long limbs = n + 1;
   (void) mpn_sub_n(res, a, b, limbs);
}


/*
   b := 2^(-s) a

   PRECONDITIONS:
      0 < s < FLINT_BITS
      b may alias a

   NOTE:
      a is passed through ZmodF_fast_reduce() first, so a may get
      modified mod p.
*/
static inline
void ZmodF_short_div_2exp(ZmodF_t b, ZmodF_t a,
                          unsigned long s, unsigned long n)
{
   FLINT_ASSERT(s > 0 && s < FLINT_BITS);

   // make the top limb of a small and non-negative, so that the plain
   // (unsigned) right shift below acts on a non-negative integer
   ZmodF_fast_reduce(a, n);

   // shift all n+1 limbs right by s bits, keeping hold of the bits that
   // were shifted out at the bottom
   mp_limb_t dropped = mpn_rshift(b, a, n+1, s);

   // the dropped bits are worth dropped/B, and 1/B = -B^(n-1) mod p, so
   // subtract them at limb n-1 (borrowing from the top limb if necessary)
   mp_limb_t* tail = b + n - 1;
   mpn_sub_1(tail, tail, 2, dropped);
}


/*
   b := B^s a

   PRECONDITIONS:
      0 < s < n
      b must not alias a
*/
static inline
void ZmodF_mul_Bexp(ZmodF_t b, ZmodF_t a, unsigned long s, unsigned long n)
{
   FLINT_ASSERT(s > 0);
   FLINT_ASSERT(s < n);
   FLINT_ASSERT(b != a);

   // Split a = ex*B^n + hi*B^(n-s) + lo,
   // where  0 <= lo < B^(n-s)  and  0 <= hi < B^s  and  abs(ex) < B/2.
   // Because B^n = -1 mod p, the result is  lo*B^s - hi - ex*B^s  (mod p).

   const unsigned long low_len = n - s;
   unsigned long k;

   // low s limbs of b: the complement of hi, i.e. B^s - hi - 1
   for (k = 0; k < s; k++)
      b[k] = ~a[low_len + k];

   // limbs s..n-1 of b: lo moved up by s limbs, i.e. lo*B^s
   for (k = 0; k < low_len; k++)
      b[s + k] = a[k];

   // top limb -1, i.e. -B^n = +1 mod p, cancelling the -1 picked up by the
   // complement above
   b[n] = ~(mp_limb_t) 0;

   // what is still missing is (-ex-1)*B^s; in two's complement
   // -ex-1 is simply the bitwise complement of ex
   signed_add_1(b + s, low_len + 1, ~a[n]);
}


/*
   c := a - B^(-s) b

   PRECONDITIONS:
      0 < s < n
      b must not alias c
      a may alias b or c
*/
static inline
void ZmodF_div_Bexp_sub(ZmodF_t c, ZmodF_t a, ZmodF_t b,
                        unsigned long s, unsigned long n)
{
   FLINT_ASSERT(s > 0);
   FLINT_ASSERT(s < n);
   FLINT_ASSERT(b != c);

   // B^(-s) = -B^(n-s) mod p, so the result is a + B^(n-s) b.
   const unsigned long split = n - s;

   // high s limbs: low limbs of b added to high limbs of a; the carry
   // goes into the top limb
   mp_limb_t carry = mpn_add_n(c + split, b, a + split, s);
   c[n] = a[n] + carry;

   // low n-s limbs: high limbs of b subtracted from low limbs of a
   mp_limb_t borrow = mpn_sub_n(c, a, b + s, split);

   // both the borrow and the top limb of b count negatively at limb n-s
   mp_limb_t deficit = b[n] + borrow;
   signed_add_1(c + split, s + 1, -deficit);
}


/*
   c := a + B^(-s) b

   PRECONDITIONS:
      0 < s < n
      b must not alias c
      a may alias b or c
*/
static inline
void ZmodF_div_Bexp_add(ZmodF_t c, ZmodF_t a, ZmodF_t b,
                        unsigned long s, unsigned long n)
{
   FLINT_ASSERT(s > 0);
   FLINT_ASSERT(s < n);
   FLINT_ASSERT(b != c);

   // B^(-s) = -B^(n-s) mod p, so the result is a - B^(n-s) b.
   const unsigned long split = n - s;

   // high s limbs: low limbs of b subtracted from high limbs of a; the
   // borrow comes out of the top limb
   mp_limb_t borrow = mpn_sub_n(c + split, a + split, b, s);
   c[n] = a[n] - borrow;

   // low n-s limbs: high limbs of b added to low limbs of a
   mp_limb_t carry = mpn_add_n(c, a, b + s, split);

   // both the carry and the top limb of b count positively at limb n-s
   mp_limb_t surplus = b[n] + carry;
   signed_add_1(c + split, s + 1, surplus);
}


/*
   c := B^s (a - b)

   PRECONDITIONS:
      c must not alias a or b
      0 < s < n
*/
static inline
void ZmodF_sub_mul_Bexp(ZmodF_t c, ZmodF_t a, ZmodF_t b,
                        unsigned long s, unsigned long n)
{
   FLINT_ASSERT(s > 0);
   FLINT_ASSERT(s < n);
   FLINT_ASSERT(c != a);
   FLINT_ASSERT(c != b);

   const unsigned long split = n - s;

   // limbs s..n-1 of c: low limbs of a - b; a borrow there becomes -1 in
   // the top limb
   mp_limb_t low_borrow = mpn_sub_n(c + s, a, b, split);
   c[n] = -low_borrow;

   // low s limbs of c: high limbs of b - a (the sign flips because these
   // limbs wrap around past B^n = -1 mod p)
   mp_limb_t high_borrow = mpn_sub_n(c, b + split, a + split, s);

   // the top limbs wrap around the same way and land at limb s, together
   // with the borrow from the low part
   mp_limb_t adjust = b[n] - a[n] - high_borrow;
   signed_add_1(c + s, split + 1, adjust);
}


/*
   b := B^s (1 - B^(n/2)) a
   
   PRECONDITIONS:
      0 <= s < 2n
      n must be odd
      b must not alias a
*/
void ZmodF_mul_pseudosqrt2_n_odd(ZmodF_t b, ZmodF_t a,
                                 unsigned long s, unsigned long n);


/*
   b := B^s (1 - B^(n/2)) a
   
   PRECONDITIONS:
      0 <= s < 2n
      n must be even
      b must not alias a
*/
void ZmodF_mul_pseudosqrt2_n_even(ZmodF_t b, ZmodF_t a,
                                  unsigned long s, unsigned long n);


/*
   b := 2^s a

   PRECONDITIONS:
      0 <= s < n*FLINT_BITS
      b may not alias a
*/
void ZmodF_mul_2exp(ZmodF_t b, ZmodF_t a, unsigned long s, unsigned long n);


/*
   b := 2^(s/2) a

   PRECONDITIONS:
      0 <= s < 2*n*FLINT_BITS

*/
void ZmodF_mul_sqrt2exp(ZmodF_t b, ZmodF_t a,
                        unsigned long s, unsigned long n);


/*
   c := 2^s (a - b)

   PRECONDITIONS:
      c must not alias a or b
      0 <= s < n*FLINT_BITS
*/
void ZmodF_sub_mul_2exp(ZmodF_t c, ZmodF_t a, ZmodF_t b,
                        unsigned long s, unsigned long n);


/* ============================================================================

    Butterflies

============================================================================ */

/*
   a := a + b
   b := B^s (a - b)
   z := destroyed

   PRECONDITIONS:
      a, b, z may not alias each other
      0 < s < n

   NOTE: a, b, z may get permuted (here the handles *b and *z are
   exchanged: *b takes over the scratch buffer holding the product, and
   *z is left with the old buffer of b).
*/
static inline
void ZmodF_forward_butterfly_Bexp(ZmodF_t* a, ZmodF_t* b, ZmodF_t* z,
                                  unsigned long s, unsigned long n)
{
   FLINT_ASSERT(s > 0);
   FLINT_ASSERT(s < n);
   FLINT_ASSERT(*a != *b);
   FLINT_ASSERT(*a != *z);
   FLINT_ASSERT(*z != *b);

   ZmodF_t x = *a;
   ZmodF_t y = *b;
   ZmodF_t scratch = *z;

   // scratch := B^s (x - y); must happen while x still holds its input
   ZmodF_sub_mul_Bexp(scratch, x, y, s, n);
   // x := x + y, in place
   ZmodF_add(x, x, y, n);

   // hand the scratch buffer to b
   ZmodF_swap(z, b);
}


/*
   a := a + b
   b := 2^s (a - b)
   z := destroyed

   PRECONDITIONS:
      a, b, z may not alias each other
      0 <= s < n*FLINT_BITS
      
   NOTE: a, b, z may get permuted
*/
void ZmodF_forward_butterfly_2exp(ZmodF_t* a, ZmodF_t* b, ZmodF_t* z,
                                  unsigned long s, unsigned long n);


/*
   a := a + b
   b := 2^(s/2) (a - b)
   z := destroyed

   PRECONDITIONS:
      a, b, z may not alias each other
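      0 <= s < 4*FLINT_BITS
      (NOTE(review): this bound does not involve n, whereas
      ZmodF_mul_sqrt2exp and ZmodF_inverse_butterfly_sqrt2exp both document
      0 <= s < 2*n*FLINT_BITS -- confirm the intended range against the
      implementation.)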
      
   NOTE: a, b, z may get permuted
*/
void ZmodF_forward_butterfly_sqrt2exp(ZmodF_t* a, ZmodF_t* b, ZmodF_t* z,
                                      unsigned long s, unsigned long n);


/*
   a := a + B^(-s) b
   b := a - B^(-s) b
   z := destroyed

   PRECONDITIONS:
      a, b, z may not alias each other
      0 < s < n

   NOTE: a, b, z may get permuted (here the handles *b and *z are
   exchanged: *b takes over the scratch buffer holding the difference, and
   *z is left with the old buffer of b).
*/
static inline
void ZmodF_inverse_butterfly_Bexp(ZmodF_t* a, ZmodF_t* b, ZmodF_t* z,
                                  unsigned long s, unsigned long n)
{
   FLINT_ASSERT(s > 0);
   FLINT_ASSERT(s < n);
   FLINT_ASSERT(*a != *b);
   FLINT_ASSERT(*a != *z);
   FLINT_ASSERT(*z != *b);

   ZmodF_t x = *a;
   ZmodF_t y = *b;
   ZmodF_t scratch = *z;

   // scratch := x - B^(-s) y; must happen while x still holds its input
   ZmodF_div_Bexp_sub(scratch, x, y, s, n);
   // x := x + B^(-s) y, in place
   ZmodF_div_Bexp_add(x, x, y, s, n);

   // hand the scratch buffer to b
   ZmodF_swap(b, z);
}


/*
   a := a + 2^(-s) b
   b := a - 2^(-s) b
   z := destroyed

   PRECONDITIONS:
      a, b, z may not alias each other
      0 <= s < n*FLINT_BITS
      
   NOTE: a, b, z may get permuted
*/
void ZmodF_inverse_butterfly_2exp(ZmodF_t* a, ZmodF_t* b, ZmodF_t* z,
                                  unsigned long s, unsigned long n);


/*
   a := a + 2^(-s/2) b
   b := a - 2^(-s/2) b
   z := destroyed

   PRECONDITIONS:
      a, b, z may not alias each other
      0 <= s < 2*n*FLINT_BITS
      
   NOTE: a, b, z may get permuted
*/
void ZmodF_inverse_butterfly_sqrt2exp(ZmodF_t* a, ZmodF_t* b, ZmodF_t* z,
                                      unsigned long s, unsigned long n);


/*
   a := a + b
   b := a - b
   z := destroyed

   PRECONDITIONS:
      a, b, z may not alias each other

   NOTE: a, b, z may get permuted (here the handles *a and *z are
   exchanged: *a takes over the scratch buffer holding the sum, and *z is
   left with the old buffer of a).
*/
static inline
void ZmodF_simple_butterfly(ZmodF_t* a, ZmodF_t* b, ZmodF_t* z,
                            unsigned long n)
{
   FLINT_ASSERT(*a != *b);
   FLINT_ASSERT(*a != *z);
   FLINT_ASSERT(*z != *b);

   ZmodF_t x = *a;
   ZmodF_t y = *b;
   ZmodF_t scratch = *z;

   // scratch := x + y; x is left untouched because the difference below
   // still needs it
   ZmodF_add(scratch, x, y, n);
   // y := x - y, in place
   ZmodF_sub(y, x, y, n);

   // hand the scratch buffer to a
   ZmodF_swap(a, z);
}




/* ============================================================================

    Miscellaneous

============================================================================ */


/*
   b := a / 3

   PRECONDITIONS:
      n < 2^(FLINT_BITS/2)
   
   NOTE:
      a and b may alias each other
      a may get modified mod p
*/
void ZmodF_divby3(ZmodF_t b, ZmodF_t a, unsigned long n);

#ifdef __cplusplus
 }
#endif

#endif

// end of file ****************************************************************
