#ifndef INDII_CLUSTER_KMEANSCLUSTERER_HPP
#define INDII_CLUSTER_KMEANSCLUSTERER_HPP

#include "ClusterVector.hpp"
#include "DataSet.hpp"
#include "PearsonDistance.hpp"

#include "boost/random.hpp"

#include <cassert>
#include <cstddef>
#include <vector>

namespace indii {
  namespace cluster {
/**
 * K-means clustering algorithm.
 *
 * @author Lawrence Murray <lawrence@indii.org>
 * @version $Rev: 116 $
 * @date $Date: 2009-09-19 17:32:46 +0800 (Sat, 19 Sep 2009) $
 *
 * @tparam T Datum element type.
 * @tparam W Weight type.
 * @tparam K Cluster index type.
 * @tparam D Distance metric type.
 */
template <class T = float, class W = unsigned, class K = unsigned,
    class D = PearsonDistance<T> >
class KMeansClusterer {
public:
  /**
   * Constructor
   *
   * @param k Number of clusters.
   * @param seed Random number seed.
   */
  KMeansClusterer(K k, const unsigned seed);

  /**
   * Perform steps of the clustering algorithm until convergence or a
   * maximum number of iterations is reached.
   *
   * @param data Data set to cluster.
   * @param maxIters Maximum number of iterations to perform, zero for
   * no limit (use with caution!).
   */
  void cluster(const DataSet<T,W>& data, unsigned maxIters = 0);

  /**
   * Determine the nearest cluster for a given point.
   *
   * @param x The point. This may be arbitrary, it needn't have been used in
   * forming the current set of clusters.
   * @param distance If given, on return, gives the distance of the datum from
   * its assigned cluster centroid.
   *
   * @return Nearest cluster for the given point.
   */
  K assign(const typename ClusterVector<T>::type& x, T* distance = NULL) const;

  /**
   * Determine normalised distance of point from each centroid.
   *
   * @param x The point. This may be arbitrary, it needn't have been used in
   * forming the current set of clusters.
   * @param[out] ds Vector of length equal to the number of clusters, into which
   * normalised distances will be written.
   */
  void distances(const typename ClusterVector<T>::type& x,
      std::vector<T>& ds) const;

  /**
   * Get centroids.
   *
   * @return Centroids.
   */
  const std::vector<typename ClusterVector<T>::type>& getCentroids();

  /**
   * Get centroid.
   *
   * @param i Cluster number.
   *
   * @return Centroid for the given cluster.
   */
  const typename ClusterVector<T>::type& getCentroid(const unsigned i);

  /**
   * Set centroid.
   *
   * @param i Cluster number.
   * @param x Centroid for cluster @p i
   */
  void setCentroid(const unsigned i,
      const typename ClusterVector<T>::type& x);

  /**
   * Get estimate of error inherent in the current clustering.
   *
   * @return Error of the current clustering.
   */
  double getError();

private:
  /**
   * Calculate centroids.
   *
   * @param data Data set.
   * @param clusters Assignment of points in data set to clusters.
   */
  void calculateCentroids(const DataSet<T,W>& data,
      const std::vector<K>& clusters);

  /**
   * Number of clusters.
   */
  const unsigned short k;

  /**
   * Error inherent in current clustering.
   */
  double error;

  /**
   * Centroids.
   */
  std::vector<typename ClusterVector<T>::type> centroids;

  /*
   * Random number generation. Mersenne Twister would be better than Linear
   * Congruential for random number generation, but appears to give compile
   * errors in Visual C++ 9
   */
  typedef boost::uniform_int<> RandomClusterDist;
  typedef boost::minstd_rand RandomClusterRNG;
  typedef boost::variate_generator<RandomClusterRNG&, RandomClusterDist>
      RandomClusterGen;

  RandomClusterDist dist;
  RandomClusterRNG rng;
  RandomClusterGen randomCluster;

};

  }
}

/**
 * Construct a clusterer for @p k clusters.
 *
 * Seeds the random number engine with @p seed, stores @p k default-constructed
 * centroids and sets the stored error to zero, so that getError() returns a
 * defined value even before cluster() has been run.
 *
 * @param k Number of clusters. The random cluster distribution is built over
 * the range [0, k-1], so @p k is expected to be at least one.
 * @param seed Random number seed.
 */
template <class T, class W, class K, class D>
indii::cluster::KMeansClusterer<T,W,K,D>::KMeansClusterer(K k,
    const unsigned seed) : k(k), error(0.0), dist(0, k-1), rng(),
    randomCluster(rng,dist) {
  rng.seed(seed);

  /* centroids: one default-constructed placeholder per cluster, to be
   * replaced later by cluster() or setCentroid() */
  typename ClusterVector<T>::type centroid;
  centroids.assign(k, centroid);
}

/**
 * Cluster a data set.
 *
 * Replaces the current centroids with randomly filled vectors, then
 * alternates between assigning every datum to its nearest centroid and
 * recomputing the centroids, until an assignment pass changes nothing or the
 * iteration limit is reached. The stored error is the sum of squared
 * distances accumulated during the last assignment pass.
 *
 * @param data Data set to cluster.
 * @param maxIters Maximum number of iterations to perform, zero for no
 * limit.
 */
template <class T, class W, class K, class D>
void indii::cluster::KMeansClusterer<T,W,K,D>::cluster(const DataSet<T,W>& data,
    unsigned maxIters) {
  const unsigned size = data.getSize();
  unsigned i, j, iters = 0, changes = 1;
  T distance;
  K newK;
  std::vector<K> clusters;
  typename std::vector<K>::iterator clustersIter;
  typename DataSet<T,W>::data_set_const_iterator dataIter, dataEnd;

  /* allocate space for centroids, filling each with random values drawn
   * from the cluster generator and passing it through D::prepare() */
  typename ClusterVector<T>::type x(data.getDimensions());
  centroids.clear();
  centroids.reserve(k);
  for (i = 0; i < k; i++) {
    for (j = 0; j < x.size(); j++) {
      x(j) = randomCluster();
    }
    D::prepare(x);
    centroids.push_back(x);
  }

  /* initialise with random clustering */
  clusters.reserve(size);
  for (i = 0; i < size; i++) {
    clusters.push_back(static_cast<K>(randomCluster()));
  }

  /* iterate; a maxIters of zero means "no limit", so the iteration count
   * only bounds the loop when a non-zero limit was given */
  while (changes > 0 && (maxIters == 0 || iters < maxIters)) {
    changes = 0;
    error = 0.0;

    /* assign data to nearest centroid */
    clustersIter = clusters.begin();
    dataEnd = data.end();
    for (dataIter = data.begin(); dataIter != dataEnd;
        ++dataIter, ++clustersIter) {
      newK = assign(dataIter->first, &distance);
      error += distance*distance;
      if (newK != *clustersIter) {
        changes++;
        *clustersIter = newK;
      }
    }

    /* recalculate centroids */
    calculateCentroids(data, clusters);

    iters++;
  }
}

/**
 * Find the centroid closest to a point under the metric D.
 *
 * Ties are resolved in favour of the lowest cluster index, because a later
 * centroid only replaces the current best when it is strictly closer.
 *
 * @param x The point.
 * @param distance If not NULL, receives the distance from @p x to the
 * selected centroid.
 *
 * @return Index of the nearest centroid.
 */
template <class T, class W, class K, class D>
inline K indii::cluster::KMeansClusterer<T,W,K,D>::assign(
    const typename ClusterVector<T>::type& x, T* distance) const {
  /* pre-condition */
  assert (centroids.size() == k);

  /* start with the first centroid as the best candidate, then scan the rest */
  unsigned nearest = 0;
  T best = D::distance(centroids[0], x);
  unsigned c = 1;
  while (c < k) {
    const T candidate = D::distance(centroids[c], x);
    if (candidate < best) {
      best = candidate;
      nearest = c;
    }
    ++c;
  }

  if (distance != NULL) {
    *distance = best;
  }
  return nearest;
}

/**
 * Compute the distance from a point to every centroid.
 *
 * @param x The point.
 * @param[out] ds Receives one distance per cluster, in cluster order. Its
 * size must already equal the number of clusters.
 */
template <class T, class W, class K, class D>
inline void indii::cluster::KMeansClusterer<T,W,K,D>::distances(
    const typename ClusterVector<T>::type& x, std::vector<T>& ds) const {
  /* pre-condition */
  assert (ds.size() == k);

  /* walk the output vector in step with the centroid index */
  typename std::vector<T>::iterator out = ds.begin();
  for (unsigned c = 0; c < k; ++c, ++out) {
    *out = D::distance(centroids[c], x);
  }
}

/**
 * Read-only access to the full list of centroids.
 *
 * @return Reference to the internal centroid vector.
 */
template <class T, class W, class K, class D>
inline const std::vector<typename indii::cluster::ClusterVector<T>::type>&
    indii::cluster::KMeansClusterer<T,W,K,D>::getCentroids() {
  return this->centroids;
}

/**
 * Read-only access to a single centroid.
 *
 * @param i Cluster number, which must be less than the number of clusters.
 *
 * @return Reference to the centroid of cluster @p i.
 */
template <class T, class W, class K, class D>
inline const typename indii::cluster::ClusterVector<T>::type&
    indii::cluster::KMeansClusterer<T,W,K,D>::getCentroid(const unsigned i) {
  /* pre-condition */
  assert (i < k);

  const typename ClusterVector<T>::type& centroid = centroids[i];
  return centroid;
}

/**
 * Overwrite the centroid of one cluster.
 *
 * @param i Cluster number, which must be less than the number of clusters.
 * @param x New centroid for cluster @p i.
 */
template <class T, class W, class K, class D>
void indii::cluster::KMeansClusterer<T,W,K,D>::setCentroid(const unsigned i,
    const typename ClusterVector<T>::type& x) {
  /* pre-condition */
  assert (i < k);

  /* match the size of the stored vector to the source before copying */
  typename ClusterVector<T>::type& target = centroids[i];
  target.resize(x.size());
  target = x;
}

/**
 * Get the error recorded by the most recent clustering run.
 *
 * @return The stored error value.
 */
template <class T, class W, class K, class D>
inline double indii::cluster::KMeansClusterer<T,W,K,D>::getError() {
  return this->error;
}

/**
 * Recalculate every centroid as the weighted mean of the data assigned to it.
 *
 * Each datum contributes its vector (@c first) multiplied by its weight
 * (@c second). The accumulated vector of a cluster is divided by the
 * cluster's total weight when that total is positive; a cluster with no
 * weight keeps the cleared vector. D::prepare() is then applied to every
 * centroid.
 *
 * @param data Data set.
 * @param clusters Assignment of points in data set to clusters, in the same
 * order as @p data and of the same length.
 */
template <class T, class W, class K, class D>
void indii::cluster::KMeansClusterer<T,W,K,D>::calculateCentroids(
    const DataSet<T,W>& data, const std::vector<K>& clusters) {
  /* pre-condition */
  assert (data.getSize() == clusters.size());

  typename std::vector<typename ClusterVector<T>::type>::iterator centroidsIter, centroidsEnd;
  typename std::vector<W>::iterator centroidWeightsIter;
  typename DataSet<T,W>::data_set_const_iterator dataIter, dataEnd;
  typename std::vector<K>::const_iterator clustersIter;

  /* per-cluster weight totals, initialised to zero; these are kept in the
   * weight type W (not the index type K) so that non-integer weights are
   * summed without truncation */
  std::vector<W> centroidWeights;
  centroidWeights.resize(k, W());

  /* initialise centroids to zero */
  centroidsEnd = centroids.end();
  for (centroidsIter = centroids.begin(); centroidsIter != centroidsEnd;
      ++centroidsIter) {
    centroidsIter->clear();
  }

  /* calculate centroids as means: first accumulate weighted sums... */
  clustersIter = clusters.begin();
  dataEnd = data.end();
  for (dataIter = data.begin(); dataIter != dataEnd;
      ++dataIter, ++clustersIter) {
    noalias(centroids[*clustersIter]) += dataIter->first * dataIter->second;
    centroidWeights[*clustersIter] += dataIter->second;
  }

  /* ...then divide by the total weight, skipping clusters with no weight to
   * avoid a division by zero */
  centroidWeightsIter = centroidWeights.begin();
  for (centroidsIter = centroids.begin(); centroidsIter != centroidsEnd;
      ++centroidsIter, ++centroidWeightsIter) {
    if (*centroidWeightsIter > W()) {
      *centroidsIter /= *centroidWeightsIter;
    }
    D::prepare(*centroidsIter);
  }
}

#endif
