00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #include "IO/Serialization.h"
00025 #include "IO/OutputInfo.h"
00026 #include "Classifier.h"
00027
00028 namespace MultiBoost {
00029
00030
00031
00032 Classifier::Classifier(nor_utils::Args &args, int verbose)
00033 : _args(args), _verbose(verbose)
00034 {
00035
00036 if ( args.hasArgument("-outputinfo") )
00037 args.getValue("-outputinfo", 0, _outputInfoFile);
00038 }
00039
00040
00041
00042 void Classifier::run(const string& dataFileName, const string& shypFileName, const int numRanksEnclosed)
00043 {
00044 InputData* pData = loadInputData(dataFileName, shypFileName);
00045
00046 if (_verbose > 0)
00047 {
00048 cout << "Loading strong hypothesis...";
00049 cout.flush();
00050 }
00051
00052
00053 UnSerialization us;
00054
00055
00056 vector<BaseLearner*> weakHypotheses;
00057
00058
00059 us.loadHypotheses(shypFileName, weakHypotheses);
00060
00061
00062 vector< ExampleResults > results;
00063
00064 if (_verbose > 0)
00065 {
00066 cout << "Classifying...";
00067 cout.flush();
00068 }
00069
00070
00071 computeResults( pData, weakHypotheses, results);
00072
00073 const int numClasses = ClassMappings::getNumClasses();
00074 const int numExamples = pData->getNumExamples();
00075
00076 if (_verbose > 0)
00077 {
00078 cout << "Done!" << endl;
00079
00080
00081 vector< vector<double> > rankedError(numRanksEnclosed);
00082
00083
00084 for (int i = 0; i < numRanksEnclosed; ++i)
00085 getClassError( pData, results, rankedError[i], i );
00086
00087
00088 cout << endl;
00089 cout << "Error Summary" << endl;
00090 cout << "=============" << endl;
00091
00092 for ( int l = 0; l < numClasses; ++l )
00093 {
00094
00095 cout << "Class '" << ClassMappings::getClassNameFromIdx(l) << "': "
00096 << rankedError[0][l] * 100 << "%";
00097
00098
00099 if (numRanksEnclosed > 1 && _verbose > 1)
00100 {
00101 cout << " (";
00102 for (int i = 1; i < numRanksEnclosed; ++i)
00103 cout << " " << i+1 << ":[" << rankedError[i][l] * 100 << "%]";
00104 cout << " )";
00105 }
00106
00107 cout << endl;
00108 }
00109
00110
00111 cout << "\n--> Overall Error: " << getOverallError(pData, results, 0) * 100 << "%";
00112
00113
00114 if (numRanksEnclosed > 1 && _verbose > 1)
00115 {
00116 cout << " (";
00117 for (int i = 1; i < numRanksEnclosed; ++i)
00118 cout << " " << i+1 << ":[" << getOverallError(pData, results, i) * 100 << "%]";
00119 cout << " )";
00120 }
00121
00122 cout << endl;
00123
00124 }
00125
00126
00127 if (pData)
00128 delete pData;
00129 }
00130
00131
00132
00133 InputData* Classifier::loadInputData(const string& dataFileName, const string& shypFileName)
00134 {
00135
00136 ifstream inFile(shypFileName.c_str());
00137 if (!inFile.is_open())
00138 {
00139 cerr << "ERROR: Cannot open strong hypothesis file <" << shypFileName << ">!" << endl;
00140 exit(1);
00141 }
00142
00143
00144 nor_utils::StreamTokenizer st(inFile, "<>\n\r\t");
00145
00146
00147 if ( !UnSerialization::seekSimpleTag(st, "multiboost") )
00148 {
00149
00150 cerr << "ERROR: Not a valid MultiBoost Strong Hypothesis file!!" << endl;
00151 exit(1);
00152 }
00153
00154
00155 string basicLearnerName = UnSerialization::seekAndParseEnclosedValue<string>(st, "algo");
00156
00157
00158 if ( !BaseLearner::RegisteredLearners().hasLearner(basicLearnerName) )
00159 {
00160 cerr << "ERROR: Weak learner <" << basicLearnerName << "> not registered!!" << endl;
00161 exit(1);
00162 }
00163
00164
00165 InputData* pData = BaseLearner::RegisteredLearners().getLearner(basicLearnerName)->createInputData();
00166
00167
00168 pData->initOptions(_args);
00169
00170 pData->load(dataFileName, IT_TEST, _verbose);
00171
00172 return pData;
00173 }
00174
00175
00176
00177
00178 void Classifier::computeResults(InputData* pData, vector<BaseLearner*>& weakHypotheses,
00179 vector< ExampleResults >& results)
00180 {
00181 assert( !weakHypotheses.empty() );
00182
00183 const int numClasses = ClassMappings::getNumClasses();
00184 const int numExamples = pData->getNumExamples();
00185
00186
00187 OutputInfo* pOutInfo = NULL;
00188
00189 if ( !_outputInfoFile.empty() )
00190 pOutInfo = new OutputInfo(_outputInfoFile);
00191
00192
00193
00194 results.clear();
00195 results.reserve(numExamples);
00196 for (int i = 0; i < numExamples; ++i)
00197 results.push_back( ExampleResults(i, numClasses) );
00198
00199
00200 vector<BaseLearner*>::const_iterator whyIt;
00201 int t;
00202
00203
00204 for (whyIt = weakHypotheses.begin(), t = 0;
00205 whyIt != weakHypotheses.end(); ++whyIt, ++t)
00206 {
00207 double alpha = (*whyIt)->getAlpha();
00208
00209
00210 for (int l = 0; l < numClasses; ++l)
00211 {
00212
00213 for (int i = 0; i < numExamples; ++i)
00214 results[i].votesVector[l] += alpha * (*whyIt)->classify(pData, i, l);
00215 }
00216
00217
00218 if ( pOutInfo )
00219 {
00220 pOutInfo->outputIteration(t);
00221 pOutInfo->outputError(pData, *whyIt);
00222
00223
00224
00225
00226
00227 pOutInfo->endLine();
00228 }
00229 }
00230
00231 if (pOutInfo)
00232 delete pOutInfo;
00233
00234 }
00235
00236
00237
00238 double Classifier::getOverallError( InputData* pData, const vector<ExampleResults>& results,
00239 int atLeastRank )
00240 {
00241 const int numExamples = pData->getNumExamples();
00242 int numErrors = 0;
00243
00244 assert(atLeastRank >= 0);
00245
00246 for (int i = 0; i < numExamples; ++i)
00247 {
00248
00249
00250 if ( !results[i].isWinner( pData->getClass(i), atLeastRank ) )
00251 ++numErrors;
00252 }
00253
00254
00255 return (double)numErrors / (double)numExamples;
00256 }
00257
00258
00259
00260 void Classifier::getClassError( InputData* pData, const vector<ExampleResults>& results,
00261 vector<double>& classError, int atLeastRank )
00262 {
00263 const int numExamples = pData->getNumExamples();
00264 const int numClasses = ClassMappings::getNumClasses();
00265
00266 int numErrors = 0;
00267 classError.resize( numClasses, 0 );
00268
00269 assert(atLeastRank >= 0);
00270
00271 for (int i = 0; i < numExamples; ++i)
00272 {
00273
00274
00275 if ( !results[i].isWinner( pData->getClass(i), atLeastRank ) )
00276 ++classError[ pData->getClass(i) ];
00277 }
00278
00279
00280 for (int l = 0; l < numClasses; ++l)
00281 classError[l] /= (double)pData->getNumExamplesPerClass(l);
00282 }
00283
00284
00285
00286 }