/* $Id: test_pm.C,v 1.3 2006/02/25 02:11:37 mfreed Exp $ */

/*
 *
 * Copyright (C) 2005 Michael J. Freedman (mfreedman at alum.mit.edu)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#define USE_PCTR 0

#include "crypt.h"
#include "pm.h"
#include "bench.h"
#include "homoenc.h"
#include "paillier.h"
#include "elgamal.h"

// Timing accumulators, in get_time () units (labelled by TIME_LABEL when
// reported).  main () resets them before each repetition and test_pm ()
// adds to them:
//   c1time - client set-up (set_polynomial / get_polynomial)
//   s1time - server work (evaluate_intersection)
//   c2time - client recovery (decrypt_intersection)
u_int64_t c1time;
u_int64_t s1time;
u_int64_t c2time;

// Exclusive upper bounds for the randomly chosen client/server set sizes
// used by main () when -c / -s are not given (rnd.getword () % max).
static const size_t maxKC    = 50;
static const size_t maxKS    = 50;
// Default element size for -k; passed to random_bigint () and each
// element keeps keysz/8 characters (so presumably measured in bits).
static const size_t dkeysz   = 160;

// Default number of repetitions (-r); each repetition generates a new key.
static const size_t repeat   = 1;
// Protocol rounds run by each test_pm () call; main () also divides the
// accumulated times by this to report per-set figures.
static const size_t cnt      = 1;

void
prepare_inputs (size_t match, size_t KC, size_t KS, size_t keysz,
		vec<str> &clntin, vec<str> &srvin)
{
  for (size_t i=0; i < match; i++) {
    str tmp = strbuf () << random_bigint (keysz);
    str key (tmp.cstr (), (keysz/8));
    clntin.push_back (key);
    srvin.push_back  (key);
  }
  
  for (size_t i=0; i < (KC-match); i++) {
    str tmp = strbuf () << 0 << random_bigint (keysz);
    str key (tmp.cstr (), (keysz/8));
    clntin.push_back (key);
  }

  for (size_t i=0; i < (KS-match); i++) {
    str tmp = strbuf () << 1 << random_bigint (keysz);
    str key (tmp.cstr (), (keysz/8));
    srvin.push_back (key);
  }

  assert (clntin.size () == KC);
  assert (srvin.size ()  == KS);
}


/*
 * Load the server's input set.  Any previous contents of srv->inputs are
 * discarded.  Each element of `srvin` is then inserted as a key whose
 * payload plaintext is "2" followed by the element itself.
 */
void
prepare_srv (pm_server *srv, const vec<str> &srvin)
{
  srv->inputs.clear ();

  const size_t n = srvin.size ();
  for (size_t idx = 0; idx < n; idx++) {
    str elem = srvin[idx];

    ppayload payload;
    payload.ptxt = strbuf () << 2 << elem;
    srv->inputs.insert (elem, payload);
  }
}


void
test_pm (bool verbose, size_t KC, size_t KS, size_t keysz, 
	 ref<homoenc_priv> sk)
{
  u_int64_t tmp1, tmp2, tmp3, tmp4;

  pm_client clnt (sk);
  pm_server srv;

  // size_t K = min (KS, KC);

  for (size_t i=0; i < cnt; i++) {

    // size_t match = rnd.getword () % K;
    size_t match = min (KC, KS);
    assert (match <= KC && match <= KS);

    vec<str> clntin, srvin;
    prepare_inputs (match, KC, KS, keysz, clntin, srvin);
    prepare_srv    (&srv, srvin);

    tmp1 = get_time ();

    // KC client coefficients
    if (!clnt.set_polynomial (clntin))
      panic << "Failed set polynomial\n";
    
    const vec<crypt_ctext> &clntcoeffs = clnt.get_polynomial ();

    tmp2 = get_time ();
    
    // KS server evaluations on client coeffs
    vec<cpayload> srvres;
    srv.evaluate_intersection (&srvres, &clntcoeffs, sk);

    tmp3 = get_time ();

    // KS client decryptions of set returned from server
    vec<str> intersect;
    clnt.decrypt_intersection (intersect, srvres);

    tmp4 = get_time ();
    
    c1time += (tmp2 - tmp1);
    s1time += (tmp3 - tmp2);
    c2time += (tmp4 - tmp3);

    if (intersect.size () != match) {
      strbuf sb;
      for (size_t i=0; i < intersect.size (); i++)
	sb << "  " << (i+1) << ".\t" << intersect[i] << "\n";
      panic << "Failed intersection: " << match 
	    << " items matched != " << intersect.size () 
	    << " intersection set size\n" << sb << "\n";
    }

    if (verbose) {
      strbuf sb;
      for (size_t i=0; i < intersect.size (); i++)
	sb << "  " << (i+1) << ".\t" << intersect[i] << "\n";
      warn << "Intersected at " << intersect.size () << " elements:\n"
	   << sb << "\n";
    }
  }
}


/*
 * Parse the value that follows option argv[i] as a positive int.
 *
 * Advances i past the consumed value, so the value is not re-examined as
 * an option.  The value is parsed as int (not size_t) so that negative
 * input is rejected rather than wrapping to a huge unsigned value.
 * atoi yields 0 for non-numeric text, which the positivity check also
 * rejects.
 */
static int
positive_arg (int argc, char **argv, int &i)
{
  assert (argc > i+1);
  int v = atoi (argv[++i]);
  assert (v > 0);
  return v;
}

/*
 * Benchmark / regression driver for the private matching protocol.
 *
 * Options:
 *   -v        report per-phase timings for every repetition
 *   -V        like -v, and also print each recovered intersection
 *   -e        use ElGamal keys instead of Paillier
 *   -b <n>    key size in bits reported as the key size (default 1024)
 *   -a <n>    second key-generation parameter (default 160)
 *   -c <n>    client set size (default: random in [0, maxKC))
 *   -s <n>    server set size (default: random in [0, maxKS))
 *   -r <n>    repetitions, each with a freshly generated key (default: repeat)
 *   -k <n>    element size passed to prepare_inputs (default: dkeysz)
 */
int
main (int argc, char **argv)
{
  bool opt_v   = false, opt_vv = false;
  int    vsz   = 1024;
  int    asz   = 160;
  size_t kc    = 0;
  size_t ks    = 0;
  size_t rep   = repeat;
  size_t keysz = dkeysz;

  bool paillier = true;

  for (int i=1; i < argc; i++) {
    if (!strcmp (argv[i], "-v"))
      opt_v = true;
    else if (!strcmp (argv[i], "-V"))
      opt_vv = opt_v = true;
    else if (!strcmp (argv[i], "-e"))
      paillier = false;
    else if (!strcmp (argv[i], "-b"))
      vsz = positive_arg (argc, argv, i);
    else if (!strcmp (argv[i], "-a"))
      asz = positive_arg (argc, argv, i);
    else if (!strcmp (argv[i], "-c"))
      kc = positive_arg (argc, argv, i);
    else if (!strcmp (argv[i], "-s"))
      ks = positive_arg (argc, argv, i);
    else if (!strcmp (argv[i], "-r"))
      rep = positive_arg (argc, argv, i);   // previously checked `repeat`
    else if (!strcmp (argv[i], "-k"))
      keysz = positive_arg (argc, argv, i);
  }

  setprogname (argv[0]);
  random_update ();

  // NOTE(review): without -v the key parameters are randomized, which
  // also overrides any -b / -a given on the command line -- confirm this
  // is intended.
  if (!opt_v) {
    vsz = 424 + rnd.getword () % 256;
    asz = 160 + rnd.getword () % 256;
  }

  for (size_t i = 0; i < rep; i++) {

    c1time = s1time = c2time = 0;
    ptr<homoenc_priv> sk;

    if (paillier)
      sk = New refcounted<paillier_priv> (paillier_keygen (vsz, asz));
    else 
      sk = New refcounted<elgamal_priv> (elgamal_keygen (vsz, asz));

    size_t KC = kc ? kc : rnd.getword () % maxKC;
    size_t KS = ks ? ks : rnd.getword () % maxKS;

    test_pm (opt_vv, KC, KS, keysz, sk);

    if (opt_v) {
      if (paillier)
	warn ("Private matching protocol with %d bit Paillier key [%d]\n", 
	      vsz, asz);
      else
	warn ("Private matching protocol with %d bit ElGamal key [%d]\n", 
	      vsz, asz);

      // size_t values are cast to match the %u conversions (passing a
      // size_t through varargs for %u is a type mismatch on LP64).
      warn ("  Clnt prepare %u sets of size %u in %" U64F "u "
	    TIME_LABEL " per set\n", static_cast<unsigned> (cnt),
	    static_cast<unsigned> (KC), c1time / cnt);
      warn ("  Srv evaluate %u sets of size %u in %" U64F "u "
	    TIME_LABEL " per set\n", static_cast<unsigned> (cnt),
	    static_cast<unsigned> (KS), s1time / cnt);
      // The client's recovery step decrypts the KS results returned by the
      // server, so report KS here.
      warn ("  Clnt recover %u sets of size %u in %" U64F "u "
	    TIME_LABEL " per set\n\n", static_cast<unsigned> (cnt),
	    static_cast<unsigned> (KS), c2time / cnt);
    }
  }
  return 0;
}
