00001 #include "HMMModule.h"
00002 #include <string>
00003 #include "../framedata/VectorFrameData.h"
00004 #include "../framedata/IntegerFrameData.h"
00005 #include "../framedata/SignalFrameData.h"
00006 #include "../visualizer/HMMModuleGUI.h"
00007 #include "../util/Error.h"
00008 #include <iostream>
00009
00010 HMMModule::HMMModule(AbstractModule *successor) :
00011 ProcessModule(successor),
00012 m_hasGestureStarted(false),
00013 m_trainGestureID(0)
00014 {
00015 setGUI(new HMMModuleGUI(this));
00016 }
00017
00018 HMMModule::~HMMModule()
00019 {
00020 for (Evaluators::iterator it = m_gestureClassificators.begin();
00021 it != m_gestureClassificators.end(); ++it) {
00022 delete it->first;
00023 delete it->second;
00024 }
00025 }
00026
00027 void HMMModule::processFrameData(IFrameData *data)
00028 {
00029 if (isSignal(data, SignalFrameData::SIG_START_GESTURE)) {
00030 m_hasGestureStarted = true;
00031 m_gesture.clear();
00032 delete data;
00033 return;
00034 }
00035
00036 if (!m_hasGestureStarted) {
00037 delete data;
00038 return;
00039 }
00040
00041
00042 if (isSignal(data, SignalFrameData::SIG_STOP_GESTURE)) {
00043 delete data;
00044 if (m_trainGestureID == 0) {
00045 classifyRecordedGesture();
00046 } else {
00047 trainRecordedGesture();
00048 }
00049 return;
00050 }
00051
00052 assertFramedataType(data, "Vector3D", "HMMModule::processFrameData");
00053
00054
00055 m_gesture.push_back(static_cast<VectorFrameData*>(data)->getData());
00056 delete data;
00057 }
00058
00059 int HMMModule::addGesture(int nClusters, int nStates)
00060 {
00061 m_gestureClassificators.push_back(std::make_pair(
00062 new KMeanCluster(nClusters),
00063 new LeftToRightHMM(nStates, nClusters)
00064 ));
00065 return m_gestureClassificators.size()-1;
00066 }
00067
00068 void HMMModule::delGesture(size_t gestureID)
00069 {
00070
00071 Evaluators::iterator it = m_gestureClassificators.begin();
00072 for (size_t i = 0; it != m_gestureClassificators.end(); ++it, ++i) {
00073 if (i == gestureID) {
00074 delete it->first;
00075 delete it->second;
00076 m_gestureClassificators.erase(it);
00077 return;
00078 }
00079 }
00080 }
00081
00082 void HMMModule::clearGestures()
00083 {
00084 Evaluators::iterator it;
00085 for (it = m_gestureClassificators.begin(); it != m_gestureClassificators.end(); ++it) {
00086 delete it->first;
00087 delete it->second;
00088 }
00089 m_gestureClassificators.clear();
00090 }
00091
00092 void HMMModule::trainGesture(size_t gestureID)
00093 {
00094 m_trainGestureID = gestureID;
00095 }
00096
00097 void HMMModule::stopTraining()
00098 {
00099 m_trainGestureID = 0;
00100 }
00101
00102 void HMMModule::classifyRecordedGesture()
00103 {
00104 double maxProp = 0.4, prop = -1;
00105 int maxPropGestureID = 0, gestureID = 1;
00106 LeftToRightHMM::Observation obs;
00107
00108 for (Evaluators::iterator evalIt = m_gestureClassificators.begin();
00109 evalIt != m_gestureClassificators.end(); ++evalIt, ++gestureID) {
00110 obs.clear();
00111
00112 for (Gesture::iterator it = m_gesture.begin(); it != m_gesture.end(); ++it) {
00113 obs.push_back(evalIt->first->clusterize(*it));
00114 }
00115 prop = evalIt->second->evaluate(obs);
00116 if (prop > maxProp) {
00117 maxPropGestureID = gestureID;
00118 }
00119 }
00120 ProcessModule::processFrameData(new IntegerFrameData(gestureID));
00121 }
00122
00123 void HMMModule::trainRecordedGesture()
00124 {
00125 KMeanCluster *kmean = m_gestureClassificators[m_trainGestureID].first;
00126 LeftToRightHMM *hmm = m_gestureClassificators[m_trainGestureID].second;
00127
00128 if (kmean == NULL || hmm == NULL) {
00129 throw Error("KMean Clusterizer or HMM is NULL!");
00130 }
00131
00132 std::cout << m_gesture.size() << std::endl;
00133 for (Gesture::iterator it = m_gesture.begin(); it != m_gesture.end(); ++it) {
00134 std::cout << "Vector: " << it->get(0) << " " << it->get(1) << " " << it->get(2) << std::endl;
00135 }
00136
00137 kmean->train(m_gesture);
00138 LeftToRightHMM::Observation obs;
00139 for (Gesture::iterator it = m_gesture.begin(); it != m_gesture.end(); ++it) {
00140 obs.push_back(kmean->clusterize(*it));
00141 }
00142 hmm->train(obs, 10);
00143
00144 }
00145
00146 void HMMModule::addGesure(KMeanCluster *clusterizer, LeftToRightHMM *hmm)
00147 {
00148 m_gestureClassificators.push_back(std::make_pair(clusterizer, hmm));
00149 }
00150
00151 bool HMMModule::serializeGesture(ConfigManager &gestureFile, std::string gestureName, int gestureId)
00152 {
00153 KMeanCluster *kmean = m_gestureClassificators.at(gestureId).first;
00154 LeftToRightHMM *hmm = m_gestureClassificators.at(gestureId).second;
00155
00156 if ((kmean == NULL) || (hmm == NULL) || gestureFile.hasSection(gestureName))
00157 return false;
00158
00159 ConfigSection gestureSection;
00160 gestureSection.setItem<int>("id", gestureId);
00161 gestureSection.setItem<int>("clusters", kmean->getClusterCount());
00162 gestureSection.setItem<int>("states", hmm->getStateCount());
00163 gestureSection.setMatrixDouble("transition", hmm->getTransitionProps());
00164 gestureSection.setMatrixDouble("output", hmm->getOutputProps());
00165
00166
00167 int index = 0;
00168 for (KMeanCluster::VectorList::iterator it = kmean->begin(); it != kmean->end(); ++it, ++index) {
00169 std::string name("center_");
00170 name += convertTo<int, std::string>(index);
00171 gestureSection.setVector3d(name, *it);
00172 }
00173
00174 gestureFile.setSection(gestureName, gestureSection);
00175
00176 return true;
00177 }
00178
00179 bool HMMModule::deserializeGesture(ConfigSection §ion, std::string gestureName)
00180 {
00181
00182
00183 int clusters = section.get<int>("clusters");
00184 KMeanCluster::VectorList centers;
00185 for (int i = 0; i < clusters; ++i) {
00186 centers.push_back(section.getVector3d("center_"+convertTo<int, std::string>(i)));
00187 }
00188
00189 KMeanCluster *kmean = new KMeanCluster(clusters);
00190 kmean->setCentres(centers);
00191
00192
00193 int states = section.get<int>("states");
00194 Matrix<double> transitions = section.getMatrixDouble("transition");
00195 Matrix<double> outputs = section.getMatrixDouble("output");
00196
00197 LeftToRightHMM *hmm = new LeftToRightHMM(states, clusters);
00198 hmm->setTransitionProps(transitions);
00199 hmm->setOutputProps(outputs);
00200
00201
00202 int id = section.get<int>("id");
00203 if (m_gestureClassificators.size() <= static_cast<size_t>(id)) {
00204 m_gestureClassificators.resize(id+1);
00205 }
00206 m_gestureClassificators.at(id) = std::make_pair(kmean, hmm);
00207
00208 return true;
00209 }