src/Classifier.cpp

00001 /*
00002 * This file is part of MultiBoost, a multi-class 
00003 * AdaBoost learner/classifier
00004 *
00005 * Copyright (C) 2005 Norman Casagrande
00006 * For information write to nova77@gmail.com
00007 *
00008 * This library is free software; you can redistribute it and/or
00009 * modify it under the terms of the GNU Lesser General Public
00010 * License as published by the Free Software Foundation; either
00011 * version 2.1 of the License, or (at your option) any later version.
00012 *
00013 * This library is distributed in the hope that it will be useful,
00014 * but WITHOUT ANY WARRANTY; without even the implied warranty of
00015 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00016 * Lesser General Public License for more details.
00017 *
00018 * You should have received a copy of the GNU Lesser General Public
00019 * License along with this library; if not, write to the Free Software
00020 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
00021 *
00022 */
00023 
00024 #include "IO/Serialization.h"
00025 #include "IO/OutputInfo.h"
00026 #include "Classifier.h"
00027 
00028 namespace MultiBoost {
00029 
00030 // -------------------------------------------------------------------------
00031 
00032 Classifier::Classifier(nor_utils::Args &args, int verbose)
00033 : _args(args), _verbose(verbose)
00034 {
00035    // If the user asked for step-by-step information, remember the
00036    // file name it should be written to (read by computeResults).
00037    if ( args.hasArgument("-outputinfo") )
00038    {
00039       args.getValue("-outputinfo", 0, _outputInfoFile);
00040    }
00041 }
00039 
00040 // -------------------------------------------------------------------------
00041 
00042 /**
00043 * Loads the strong hypothesis and the test data, classifies every example,
00044 * and (if _verbose > 0) prints the per-class and overall error for the
00045 * first numRanksEnclosed ranks.
00046 * \param dataFileName The file with the test data.
00047 * \param shypFileName The file with the strong hypothesis.
00048 * \param numRanksEnclosed Number of ranks for which the error is computed.
00049 */
00050 void Classifier::run(const string& dataFileName, const string& shypFileName, const int numRanksEnclosed)
00051 {
00052    InputData* pData = loadInputData(dataFileName, shypFileName);
00053 
00054    if (_verbose > 0)
00055    {
00056       cout << "Loading strong hypothesis...";
00057       cout.flush();
00058    }
00059 
00060    // The class that loads the weak hypotheses
00061    UnSerialization us;
00062 
00063    // Where to put the weak hypotheses
00064    vector<BaseLearner*> weakHypotheses;
00065 
00066    // loads them
00067    us.loadHypotheses(shypFileName, weakHypotheses);
00068    // NOTE(review): the weak hypotheses loaded above are never freed here;
00069    // if loadHypotheses allocates them with new this leaks. Verify ownership
00070    // (and that BaseLearner has a virtual destructor) before adding a
00071    // delete loop.
00072 
00073    // where the results go: one ExampleResults per example
00074    vector< ExampleResults > results;
00075 
00076    if (_verbose > 0)
00077    {
00078       cout << "Classifying...";
00079       cout.flush();
00080    }
00081 
00082    // accumulate the weighted votes of all the weak hypotheses
00083    computeResults( pData, weakHypotheses, results);
00084 
00085    const int numClasses = ClassMappings::getNumClasses();
00086 
00087    if (_verbose > 0)
00088    {
00089       cout << "Done!" << endl;
00090 
00091       // well.. if verbose = 0 no results are displayed! :)
00092       vector< vector<double> > rankedError(numRanksEnclosed);
00093 
00094       // Get the per-class error for the numRanksEnclosed-th ranks
00095       for (int i = 0; i < numRanksEnclosed; ++i)
00096          getClassError( pData, results, rankedError[i], i );
00097 
00098       // output it
00099       cout << endl;
00100       cout << "Error Summary" << endl;
00101       cout << "=============" << endl;
00102 
00103       for ( int l = 0; l < numClasses; ++l )
00104       {
00105          // first rank (winner): rankedError[0]
00106          cout << "Class '" << ClassMappings::getClassNameFromIdx(l) << "': "
00107               << rankedError[0][l] * 100 << "%";
00108 
00109          // output the other ranks on its side
00110          if (numRanksEnclosed > 1 && _verbose > 1)
00111          {
00112             cout << " (";
00113             for (int i = 1; i < numRanksEnclosed; ++i)
00114                cout << " " << i+1 << ":[" << rankedError[i][l] * 100 << "%]";
00115             cout << " )";
00116          }
00117 
00118          cout << endl;
00119       }
00120 
00121       // the overall error
00122       cout << "\n--> Overall Error: " << getOverallError(pData, results, 0) * 100 << "%";
00123 
00124       // output the other ranks on its side
00125       if (numRanksEnclosed > 1 && _verbose > 1)
00126       {
00127          cout << " (";
00128          for (int i = 1; i < numRanksEnclosed; ++i)
00129             cout << " " << i+1 << ":[" << getOverallError(pData, results, i) * 100 << "%]";
00130          cout << " )";
00131       }
00132 
00133       cout << endl;
00134 
00135    } // verbose
00136 
00137    // delete the input data file
00138    if (pData) 
00139       delete pData;
00140 }
00130 
00131 // -------------------------------------------------------------------------
00132 
00133 /**
00134 * Opens the strong hypothesis file, checks that it is a valid MultiBoost
00135 * file, finds out which weak learner produced it, and asks that learner
00136 * for an InputData object loaded with the test data.
00137 * \remark Exits the program on any error.
00138 */
00139 InputData* Classifier::loadInputData(const string& dataFileName, const string& shypFileName)
00140 {
00141    // Open the strong hypothesis file; bail out if it cannot be read.
00142    ifstream shypStream(shypFileName.c_str());
00143    if ( !shypStream.is_open() )
00144    {
00145       cerr << "ERROR: Cannot open strong hypothesis file <" << shypFileName << ">!" << endl;
00146       exit(1);
00147    }
00148 
00149    // Tokenizer used to scan the XML-like tags of the hypothesis file.
00150    nor_utils::StreamTokenizer tokenizer(shypStream, "<>\n\r\t");
00151 
00152    // The file must contain a <multiboost> tag, otherwise it is not ours.
00153    if ( !UnSerialization::seekSimpleTag(tokenizer, "multiboost") )
00154    {
00155       cerr << "ERROR: Not a valid MultiBoost Strong Hypothesis file!!" << endl;
00156       exit(1);
00157    }
00158 
00159    // Read which weak learner produced this strong hypothesis (the <algo> tag).
00160    const string weakLearnerName =
00161       UnSerialization::seekAndParseEnclosedValue<string>(tokenizer, "algo");
00162 
00163    // The weak learner must be known to the registry.
00164    if ( !BaseLearner::RegisteredLearners().hasLearner(weakLearnerName) )
00165    {
00166       cerr << "ERROR: Weak learner <" << weakLearnerName << "> not registered!!" << endl;
00167       exit(1);
00168    }
00169 
00170    // Ask the registered learner for the input-data object it works with,
00171    // configure it with the command-line options, and load the test data.
00172    InputData* pInputData =
00173       BaseLearner::RegisteredLearners().getLearner(weakLearnerName)->createInputData();
00174    pInputData->initOptions(_args);
00175    pInputData->load(dataFileName, IT_TEST, _verbose);
00176 
00177    return pInputData;
00178 }
00174 
00175 // -------------------------------------------------------------------------
00176 
00177 // Returns the results into ptRes
00178 /**
00179 * Fills \a results with the cumulative alpha-weighted votes of every weak
00180 * hypothesis, for every example and every class. If an output-info file
00181 * name was given on the command line, the per-iteration error is also
00182 * written to it.
00183 */
00184 void Classifier::computeResults(InputData* pData, vector<BaseLearner*>& weakHypotheses, 
00185                                 vector< ExampleResults >& results)
00186 {
00187    assert( !weakHypotheses.empty() );
00188 
00189    const int numClasses = ClassMappings::getNumClasses();
00190    const int numExamples = pData->getNumExamples();
00191 
00192    // Step-by-step information is produced only if a file name was given.
00193    OutputInfo* pOutInfo = NULL;
00194    if ( !_outputInfoFile.empty() )
00195       pOutInfo = new OutputInfo(_outputInfoFile);
00196 
00197    // One (initially empty) ExampleResults per example.
00198    results.clear();
00199    results.reserve(numExamples);
00200    for (int i = 0; i < numExamples; ++i)
00201       results.push_back( ExampleResults(i, numClasses) );
00202 
00203    // Accumulate the vote of every weak hypothesis: iterations 1..T.
00204    int t = 0;
00205    for (vector<BaseLearner*>::const_iterator itWH = weakHypotheses.begin();
00206         itWH != weakHypotheses.end(); ++itWH, ++t)
00206    {
00207       const double alpha = (*itWH)->getAlpha();
00208 
00209       // Add alpha * h_t(x_i, l) for every class l and every example i.
00210       for (int l = 0; l < numClasses; ++l)
00211       {
00212          for (int i = 0; i < numExamples; ++i)
00213             results[i].votesVector[l] += alpha * (*itWH)->classify(pData, i, l);
00214       }
00215 
00216       // if needed output the step-by-step information
00217       if ( pOutInfo )
00218       {
00219          pOutInfo->outputIteration(t);
00220          pOutInfo->outputError(pData, *itWH);
00221 
00222          // Margins and edge require an update of the weights,
00223          // therefore they are kept out for the moment
00224          //outInfo.outputMargins(pData, *itWH);
00225          //outInfo.outputEdge(pData, *itWH);
00226          pOutInfo->endLine();
00227       }
00228    }
00229 
00230    if (pOutInfo)
00231       delete pOutInfo;
00232 }
00235 
00236 // -------------------------------------------------------------------------
00237 
00238 double Classifier::getOverallError( InputData* pData, const vector<ExampleResults>& results, 
00239                                     int atLeastRank )
00240 {
00241    const int numExamples = pData->getNumExamples();
00242    int numErrors = 0;
00243 
00244    assert(atLeastRank >= 0);
00245 
00246    for (int i = 0; i < numExamples; ++i)
00247    {
00248       // if the real class is not the one with the highest vote in the
00249       // vote vector, then it is an error!
00250       if ( !results[i].isWinner( pData->getClass(i), atLeastRank ) )
00251          ++numErrors;
00252    }  
00253 
00254    // makes the error between 0 and 1
00255    return (double)numErrors / (double)numExamples;
00256 }
00257 
00258 // -------------------------------------------------------------------------
00259 
00260 void Classifier::getClassError( InputData* pData, const vector<ExampleResults>& results, 
00261                                 vector<double>& classError, int atLeastRank )
00262 {
00263    const int numExamples = pData->getNumExamples();
00264    const int numClasses = ClassMappings::getNumClasses();
00265 
00266    int numErrors = 0;
00267    classError.resize( numClasses, 0 );
00268 
00269    assert(atLeastRank >= 0);
00270 
00271    for (int i = 0; i < numExamples; ++i)
00272    {
00273       // if the real class is not the one with the highest vote in the
00274       // vote vector, then it is an error!
00275       if ( !results[i].isWinner( pData->getClass(i), atLeastRank ) )
00276          ++classError[ pData->getClass(i) ];
00277    }
00278 
00279    // makes the error between 0 and 1
00280    for (int l = 0; l < numClasses; ++l)
00281       classError[l] /= (double)pData->getNumExamplesPerClass(l);
00282 }
00283 
00284 // -------------------------------------------------------------------------
00285 
00286 } // end of namespace MultiBoost

Generated on Mon Nov 28 21:43:46 2005 for MultiBoost by  doxygen 1.4.5