00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00028 #ifndef __CLASSIFIER_H
00029 #define __CLASSIFIER_H
00030
00031 #include "Utils/Args.h"
00032 #include "IO/InputData.h"
00033 #include "Utils/Utils.h"
00034 #include "WeakLearners/BaseLearner.h"
00035
00036 #include <string>
00037 #include <cassert>
00038
00039 using namespace std;
00040
00041 namespace MultiBoost {
00042
00045
00062 class Classifier
00063 {
00064 public:
00065
00075 Classifier(nor_utils::Args& args, int verbose = 1);
00076
00093 void run(const string& dataFileName, const string& shypFileName, const int numRanksEnclosed = 2);
00094
00095 protected:
00096
00108 InputData* loadInputData(const string& dataFileName, const string& shypFileName);
00109
00117 class ExampleResults
00118 {
00119 public:
00120
00127 ExampleResults(const int idx, const int numClasses)
00128 : idx(idx), votesVector(numClasses, 0) {}
00129
00130 int idx;
00131
00137 vector<double> votesVector;
00138
00145 pair<int, double> getWinner(const int rank = 0)
00146 {
00147 assert(rank > 0);
00148
00149 vector< pair<int, double> > rankedList;
00150 getRankedList(rankedList);
00151 return rankedList[rank];
00152 }
00153
00168 bool isWinner(const int idxRealClass, const int atLeastRank = 0) const
00169 {
00170 assert(atLeastRank >= 0);
00171
00172 vector< pair<int, double> > rankedList;
00173 getRankedList(rankedList);
00174
00175 for (int i = 0; i <= atLeastRank; ++i)
00176 {
00177 if ( rankedList[i].first == idxRealClass )
00178 return true;
00179 }
00180
00181 return false;
00182 }
00183
00184 private:
00185
00194 void getRankedList( vector< pair<int, double> >& rankedList ) const
00195 {
00196 rankedList.resize(votesVector.size());
00197
00198 vector<double>::const_iterator vIt;
00199 int i;
00200 for (vIt = votesVector.begin(), i = 0; vIt != votesVector.end(); ++vIt, ++i )
00201 rankedList[i] = make_pair(i, *vIt);
00202
00203 sort( rankedList.begin(), rankedList.end(),
00204 nor_utils::comparePairOnSecond< int, double, greater<double> > );
00205 }
00206 };
00207
00208
00218 void computeResults(InputData* pData, vector<BaseLearner*>& weakHypotheses,
00219 vector< ExampleResults >& results);
00220
00233 double getOverallError( InputData* pData, const vector<ExampleResults>& results,
00234 int atLeastRank = 0 );
00235
00248 void getClassError( InputData* pData, const vector<ExampleResults>& results,
00249 vector<double>& classError, int atLeastRank = 0 );
00250
00257 int _verbose;
00258
00259 nor_utils::Args& _args;
00260 string _outputInfoFile;
00261
00262 };
00263
00264 }
00265
00266 #endif // __CLASSIFIER_H