#include "pvFile.h"
#include "pvVector.h"
#include "eval/eval.h"
#include "eval/evalFactory.h"

#include "osl/state/numEffectState.h"
#include "osl/state/historyState.h"
#include "osl/search/shouldPromoteCut.h"
#include "osl/record/kisen.h"
#include "osl/eval/evalTraits.h"
#include "osl/eval/ml/openMidEndingEval.h"
#include "osl/progress/ml/newProgress.h"
#include "osl/stat/average.h"

#include <boost/program_options.hpp>
#include <boost/scoped_ptr.hpp>
#include <functional>
#include <algorithm>
#include <iostream>

namespace po = boost::program_options;
using namespace osl;

/// Evaluation function shared by all predicates; created in main() from the
/// --eval option before any position is processed.
boost::scoped_ptr<gpsshogi::Eval> my_eval;

/// Plays every move of @p pv, in order, on a HistoryState built from @p src
/// and returns my_eval's score of the position reached.
/// @p src is taken by const reference and is not modified.
/// @param src  position the PV starts from
/// @param pv   move sequence to apply (may be empty: then @p src is scored)
/// @return     raw value of my_eval->eval() on the final position
int evaluate(const NumEffectState& src, const gpsshogi::PVVector& pv)
{
  HistoryState history(src);
  const size_t length = pv.size();
  size_t ply = 0;
  while (ply < length)
  {
    history.makeMove(pv[ply]);
    ++ply;
  }
  return my_eval->eval(history.state());
}

class PVCompare
{
protected:
  const NumEffectState *state;
public:
  virtual ~PVCompare()
  {
  }
  virtual void init(const NumEffectState& state, const vector<Move>&, int next) { this->state = &state; }
  virtual void add(const gpsshogi::PVVector& pv)=0;
  virtual void finish() {}
};

class PassIsBetter : public PVCompare
{
  gpsshogi::PVVector pv_main, pv_pass;
  Move next_move;
public:
  void init(const NumEffectState& state, const vector<Move>& moves, int next) { 
    this->state = &state;
    next_move = moves[next];
    pv_main.clear(); pv_pass.clear(); 
  }
  void add(const gpsshogi::PVVector& pv)
  {
    if (pv.empty())
      return;
    if (pv[0].isPass())
      pv_pass = pv;
    else if (pv[0] == next_move)
      pv_main = pv;
  }
  void finish() 
  {
    if (pv_main.empty() || pv_pass.empty())
      return;
    const int eval_main = evaluate(*state, pv_main);
    const int eval_pass = evaluate(*state, pv_pass);
    if (eval::betterThan(state->turn(), eval_pass, eval_main)) {
      std::cout << *state;
      std::cout << eval_main << ' ' << pv_main
		<< eval_pass << ' ' << pv_pass << "\n";
    }
  }
};

class MoveAfterPass : public PVCompare
{
  gpsshogi::PVVector pv_pass;
  Move next2_move;
public:
  void init(const NumEffectState& state, const vector<Move>& moves, int next) { 
    this->state = &state;
    next2_move = Move();
    if (next+1 < (int)moves.size())
      next2_move = moves[next+1];
    pv_pass.clear(); 
  }
  void add(const gpsshogi::PVVector& pv)
  {
    if (pv.empty())
      return;
    if (pv[0].isPass())
      pv_pass = pv;
  }
  void finish() 
  {
    if (!next2_move.isNormal() || pv_pass.size() < 2)
      return;
    if (pv_pass[1] != next2_move) {
      std::cout << *state;
      std::cout << next2_move << "\n" << pv_pass << "\n";
    }
  }
};

class EvalRange : public PVCompare
{
  vector<double> values;
  double selected_value;
  CArray<stat::Average,16> best2, best10, best20, median, worst, selected;
  Player turn;
  Move next_move;
public:
  ~EvalRange() 
  {
  }
  void init(const NumEffectState& state, const vector<Move>& moves, int next) 
  { 
    this->state = &state;
    values.clear();
    turn = state.turn();
    next_move = moves[next];
  }
  void add(const gpsshogi::PVVector& pv)
  {
    const double eval = evaluate(*state, pv)
      *eval::delta(turn)*100.0/my_eval->pawnValue();
    values.push_back(eval);
    if (!pv.empty() && pv[0] == next_move)
      selected_value = eval;
  }
  void finish() 
  {    
    if (values.size() < 20)
      return;
    std::sort(values.begin(), values.end(), std::greater<int>());
    const int p = progress::ml::NewProgress(*state).progress16().value();
    best2[p].add(values[0]-values[1]);
    best10[p].add(values[0]-values[9]);
    best20[p].add(values[0]-values[19]);
    median[p].add(values[0]-values[values.size()/2]);
    worst[p].add(values[0]-values.back());
    selected[p].add(values[0]-selected_value);
    if ((best2[0].numElements() % 1024) == 0) {
      for (int i=0; i<16; ++i) {
	std::cout << i << ' ' << best2[i].getAverage()
		  << ' ' << best10[i].getAverage()
		  << ' ' << best20[i].getAverage()
		  << ' ' << median[i].getAverage() << ' ' << worst[i].getAverage()
		  << ' ' << selected[i].getAverage() << "\n" << std::flush;
      }
    }
  }
};

/// Reads PV files and applies the chosen predicate (pass_is_better,
/// move_after_pass or eval_range) to each position they describe.
///
/// Every position in a PV file is identified by (record, position).  The
/// position is rebuilt by replaying the Kisen record's moves from the initial
/// position, and all PVs listed for it are fed to the predicate.
/// @return 0 on success, 1 on an option, eval-type or predicate error.
int main(int argc, char **argv)
{
  std::string eval_type, eval_data, kisen_filename, predicate_name;
  std::vector<std::string> pv_filenames;
  // NOTE(review): show_record_move and quiet are parsed but not used yet.
  bool show_record_move, quiet;

  boost::program_options::options_description command_line_options;
  command_line_options.add_options()
    ("predicate",
     boost::program_options::value<std::string>(&predicate_name)->
     default_value("pass_is_better"),
     "Predicate to use.  Valid options are pass_is_better, "
     "move_after_pass, and eval_range")
    ("pv-file", po::value<std::vector<std::string> >(),
     "filename containing PVs")
    ("kisen-file,k",
     po::value<std::string>(&kisen_filename)->default_value(""),
     "Kisen filename corresponding to pv file")
    ("show-record-move",
     po::value<bool>(&show_record_move)->default_value(false),
     "show record move in addition to position when predicate matches")
    ("eval,e",
     po::value<std::string>(&eval_type)->default_value(std::string("piece")),
     "evaluation function (king or piece)")
    ("eval-data",
     po::value<std::string>(&eval_data)->default_value(""))
    ("quiet,q",
     po::value<bool>(&quiet)->default_value(false),
     "counting only.  do not show positions matched.")
    ;
  po::variables_map vm;
  try
  {
    po::store(po::parse_command_line(argc, argv, command_line_options), vm);
    po::notify(vm);
    if (vm.count("pv-file"))
      pv_filenames = vm["pv-file"].as<std::vector<std::string> >();
    else
    {
      std::cerr << "PV file wasn't specified" << std::endl;
      return 1;
    }
  }
  catch (std::exception& e)
  {
    std::cerr << "error in parsing options" << std::endl
              << e.what() << std::endl;
    std::cerr << command_line_options << std::endl;
    return 1;
  }

  if (!osl::eval::ml::OpenMidEndingEval::setUp()) {
    std::cerr << "OpenMidEndingEval set up failed\n";
    // fall through as this might not be fatal depending on eval type used
  }
  if (!osl::progress::ml::NewProgress::setUp()) {
    std::cerr << "NewProgress set up failed\n";
    // fall through as this might not be fatal depending on eval type used
  }
  my_eval.reset(gpsshogi::EvalFactory::newEval(eval_type));
  if (! my_eval) {
    // Report and exit like the other fatal errors instead of throwing an
    // exception nobody catches.
    std::cerr << "unknown eval type " << eval_type << "\n";
    return 1;
  }
  if (!eval_data.empty())
    if (! my_eval->load(eval_data.c_str()))
      std::cerr << "load failed " << eval_data << "\n";

  osl::record::KisenFile kisen(kisen_filename);

  boost::scoped_ptr<PVCompare> compare;
  if (predicate_name == "pass_is_better")
  {
    compare.reset(new PassIsBetter);
  }
  else if (predicate_name == "eval_range")
  {
    compare.reset(new EvalRange);
  }
  else if (predicate_name == "move_after_pass")
  {
    compare.reset(new MoveAfterPass);
  }
  else
  {
    std::cerr << "Unknown predicate "  << predicate_name << "\n";
    return 1;
  }

  osl::vector<osl::Move> moves;
  osl::state::NumEffectState state(kisen.getInitialState());
  for (size_t i = 0; i < pv_filenames.size(); ++i)
  {
    gpsshogi::PVFileReader pr(pv_filenames[i].c_str());
    int record, position;
    int cur_record = -1;
    int cur_position = 0;
    while (pr.newPosition(record, position))
    {
      const bool new_record = (record != cur_record);
      if (new_record)
      {
        cur_record = record;
        moves = kisen.getMoves(cur_record);
      }
      // Replay from the initial position whenever the game changes (even if
      // its first listed position is not 0) or the requested position is
      // behind the one already reached.
      if (new_record || position < cur_position)
      {
        state = osl::state::NumEffectState(kisen.getInitialState());
        cur_position = 0;
      }
      // Never read past the end of the record.
      while (cur_position < position && cur_position < (int)moves.size())
      {
        state.makeMove(moves[cur_position]);
        ++cur_position;
      }
      // No record move to compare against at or beyond the end of the game.
      if (cur_position >= (int)moves.size())
        continue;

      compare->init(state, moves, cur_position);
      gpsshogi::PVVector pv;
      while (pr.readPv(pv))
      {
        compare->add(pv);
        pv.clear();
      }
      compare->finish();
    }
  }
  return 0;
}
// ;;; Local Variables:
// ;;; mode:c++
// ;;; c-basic-offset:2
// ;;; End:
