#include "KMeanCluster.h"
#include <iostream>
#include <limits>
#include <vector>
00003
00004 KMeanCluster::KMeanCluster(int nCentres)
00005 {
00006 for (int i = 0; i < nCentres; ++i) {
00007 m_centres.push_back(Vector3d(0));
00008 }
00009 }
00010
00011 void KMeanCluster::train(VectorList &v)
00012 {
00013 std::cout << "enter KMeanCluster::train" << std::endl;
00014 VectorList::iterator itCentres = m_centres.begin();
00015 VectorList::iterator itTraining = v.begin();
00016
00017
00018 std::cout << "\tinit centres" << std::endl;
00019 for (; itCentres != m_centres.end() && itTraining != v.end();
00020 ++itCentres, ++itTraining) {
00021 *itCentres = *itTraining;
00022 }
00023 if (itTraining == v.end())
00024 return;
00025
00026
00027 std::cout << "\tassign groups" << std::endl;
00028 VectorList oldCentres;
00029 VectorList groups[m_centres.size()];
00030
00031 do {
00032 oldCentres = m_centres;
00033 for (itTraining = v.begin(); itTraining != v.end(); ++itTraining) {
00034 groups[clusterize(*itTraining)].push_back(*itTraining);
00035 }
00036
00037
00038 std::cout << "\t\treassign centres" << std::endl;
00039 Vector3d tmp(0);
00040 int n = 0;
00041 for (itCentres = m_centres.begin();
00042 itCentres != m_centres.end();
00043 ++itCentres, ++n) {
00044 for (itTraining = groups[n].begin(); itTraining != groups[n].end();
00045 ++itTraining){
00046 tmp += *itTraining;
00047 }
00048 tmp /= groups[n].size();
00049 }
00050 } while (isSameSet(oldCentres, m_centres));
00051 std::cout << "leave KMeanCluster::train" << std::endl;
00052 }
00053
00054 int KMeanCluster::clusterize(Vector3d &v)
00055 {
00056 float dist, bestDist = 1e100;
00057 int n = 0, nBest = 0;
00058 Vector3d tmp;
00059 for(VectorList::iterator it = m_centres.begin(); it != m_centres.end(); ++n, ++it) {
00060 tmp = *it - v;
00061 dist = tmp.get(0)*tmp.get(0) + tmp.get(1)*tmp.get(1) + tmp.get(2)*tmp.get(2);
00062 if (dist < bestDist) {
00063 nBest = n;
00064 bestDist = dist;
00065 }
00066 }
00067 return nBest;
00068 }
00069
00070 bool KMeanCluster::isSameSet(VectorList &a, VectorList &b)
00071 {
00072
00073 VectorList::iterator itA, itB;
00074 for (itA = a.begin(), itB = b.begin(); itA != a.end() && itB != b.end();
00075 ++itA, ++itB) {
00076 if (*itA != *itB)
00077 return false;
00078 }
00079 return true;
00080 }