src/WeakLearners/MultiStumpLearner.cpp

00001 /*
00002 * This file is part of MultiBoost, a multi-class 
00003 * AdaBoost learner/classifier
00004 *
00005 * Copyright (C) 2005 Norman Casagrande
00006 * For informations write to nova77@gmail.com
00007 *
00008 * This library is free software; you can redistribute it and/or
00009 * modify it under the terms of the GNU Lesser General Public
00010 * License as published by the Free Software Foundation; either
00011 * version 2.1 of the License, or (at your option) any later version.
00012 *
00013 * This library is distributed in the hope that it will be useful,
00014 * but WITHOUT ANY WARRANTY; without even the implied warranty of
00015 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00016 * Lesser General Public License for more details.
00017 *
00018 * You should have received a copy of the GNU Lesser General Public
00019 * License along with this library; if not, write to the Free Software
00020 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
00021 *
00022 */
00023 
00024 #include "MultiStumpLearner.h"
00025 
00026 #include "IO/Serialization.h"
00027 #include "IO/SortedData.h"
00028 
00029 #include <limits> // for numeric_limits<>
00030 #include <cassert>
00031 
00032 namespace MultiBoost {
00033 
00034 REGISTER_LEARNER(MultiStumpLearner)
00035 
00036 // ------------------------------------------------------------------------------
00037 
00038 void MultiStumpLearner::run(InputData* pData)
00039 {
00040    const int numClasses = ClassMappings::getNumClasses();
00041    const int numColumns = pData->getNumColumns();
00042 
00043    // set the smoothing value to avoid numerical problem
00044    // when theta=0.
00045    setSmoothingVal( 1.0 / (double)pData->getNumExamples() * 0.01 );
00046 
00047    // resize
00048    _leftErrors.resize(numClasses);
00049    _rightErrors.resize(numClasses);
00050    _bestErrors.resize(numClasses);
00051    _weightsPerClass.resize(numClasses);
00052    _halfWeightsPerClass.resize(numClasses);
00053 
00054    vector<sRates> mu(numClasses); // The class-wise rates. See BaseLearner::sMu for more info.
00055 
00056    vector<char> tmpV(numClasses); // The class-wise votes/abstentions
00057    vector<double> tmpThresholds(numClasses);
00058    double tmpAlpha;
00059 
00060    double bestE = numeric_limits<double>::max();
00061    double tmpE;
00062 
00063    for (int j = 0; j < numColumns; ++j)
00064    {
00065       findThresholds(pData, j, tmpThresholds, mu, tmpV);
00066 
00067       tmpE = getEnergy(mu, tmpAlpha, tmpV);
00068       if (tmpE < bestE)
00069       {
00070          // Store it in the current algorithm
00071          // note: I don't really like having so many temp variables
00072          // but the alternative would be a structure, which would need
00073          // to be inheritable to make things more consistent. But this would
00074          // make it less flexible. Therefore, I am still undecided. This
00075          // might change!
00076 
00077          _alpha = tmpAlpha;
00078          _v = tmpV;
00079          _selectedColumn = j;
00080          _thresholds = tmpThresholds;
00081 
00082          bestE = tmpE;
00083       }
00084 
00085    }
00086 
00087 }
00088 
00089 // ------------------------------------------------------------------------------
00090 
00091 char MultiStumpLearner::phi(double val, int classIdx)
00092 {
00093    if (val > _thresholds[classIdx])
00094       return +1;
00095    else
00096       return -1;
00097 }
00098 
00099 // -----------------------------------------------------------------------
00100 
00101 void MultiStumpLearner::findThresholds(InputData* pData, const int columnIdx, 
00102                                        vector<double>& thresholds,
00103                                        vector<sRates>& mu, vector<char>& v)
00104 {
00105    const vpIterator dataBegin = static_cast<SortedData*>(pData)->getSortedBegin(columnIdx);
00106    const vpIterator dataEnd = static_cast<SortedData*>(pData)->getSortedEnd(columnIdx);
00107 
00108    const int numClasses = ClassMappings::getNumClasses();
00109 
00110    // resize and set to 0
00111    fill(_leftErrors.begin(), _leftErrors.end(), 0);
00112    fill(_weightsPerClass.begin(), _weightsPerClass.end(), 0);
00113 
00114    // get the number of examples
00115    const size_t numPoints = dataEnd - dataBegin;
00116 
00117    vpIterator currentSplitPos; // the iterator of the currently examined example
00118    vpIterator previousSplitPos; // the iterator of the example before the current example
00119    cvpIterator endArray; // the iterator on the last example (just before dataEnd)
00120 
00122    // Initialization of the class-wise error
00123 
00124    // The class-wise error on the right side of the threshold
00125    double tmpRightError;
00126 
00127    for (int l = 0; l < numClasses; ++l)
00128    {
00129       tmpRightError = 0;
00130 
00131       for( currentSplitPos = dataBegin; currentSplitPos != dataEnd; ++currentSplitPos)
00132       {
00133          double weight = pData->getWeight(currentSplitPos->first, l);
00134 
00135          // We assume that class "currClass" is always on the right side;
00136          // therefore, all points l that are not currClass (x) on right side,
00137          // are considered error.
00138          // <l x l x x x l x x> = 3 (if each example has weight 1)
00139          // ^-- tmpError: error if we set the cut at the extreme left side
00140          if ( pData->getClass(currentSplitPos->first) != l )
00141             tmpRightError += weight;
00142 
00143          _weightsPerClass[l] += weight;
00144       }
00145 
00146       _halfWeightsPerClass[l] = _weightsPerClass[l] / 2;
00147 
00148       assert(tmpRightError < 1);
00149       _rightErrors[l] = tmpRightError; // store the class-wise error
00150       _bestErrors[l] = numeric_limits<double>::max();
00151    }
00152 
00154 
00155    currentSplitPos = dataBegin; // reset position
00156    endArray = dataEnd;
00157    --endArray;
00158 
00159    double tmpError = 0;
00160    double bestError = numeric_limits<double>::max();
00161    bool flipIt;
00162 
00163    // find the best threshold (cutting point)
00164    while (currentSplitPos != endArray)
00165    {
00166       // at the first split we have
00167       // first split: x | x x x x x x x x ..
00168       //    previous -^   ^- current
00169       previousSplitPos = currentSplitPos;
00170       ++currentSplitPos; 
00171 
00172       // point at the same position: to skip because we cannot find a cutting point here!
00173       while ( previousSplitPos->second == currentSplitPos->second && currentSplitPos != endArray)
00174       {
00175          for (int l = 0; l < numClasses; ++l)
00176          { 
00177             if ( pData->getClass( previousSplitPos->first ) == l )
00178                _leftErrors[l] += pData->getWeight(previousSplitPos->first, l);
00179             else
00180                _rightErrors[l] -= pData->getWeight(previousSplitPos->first, l);
00181          }
00182 
00183          previousSplitPos = currentSplitPos;
00184          ++currentSplitPos; 
00185       }
00186 
00187       for (int l = 0; l < numClasses; ++l)
00188       { 
00189          if ( pData->getClass( previousSplitPos->first ) == l )
00190          {
00191             // c=current class, x=other class
00192             // .. c | x x c x c x .. 
00193             _leftErrors[l] += pData->getWeight(previousSplitPos->first, l);
00194          }
00195          else
00196          {
00197             // c=current class, x=other class
00198             // .. x | x x c x c x ..
00199             _rightErrors[l] -= pData->getWeight(previousSplitPos->first, l);
00200          }
00201 
00202          tmpError = _rightErrors[l] + _leftErrors[l];
00203 
00204          // switch the class-wise error if it is bigger than chance
00205          if(tmpError > _halfWeightsPerClass[l] + _smallVal)
00206          {
00207             tmpError = _weightsPerClass[l] - tmpError;
00208             flipIt = true;
00209          }
00210          else
00211             flipIt = false;
00212 
00213          // The summed error MUST be smaller than chance
00214          assert(tmpError <= _halfWeightsPerClass[l] + _smallVal); 
00215 
00216          // the overall error is smaller!
00217          if (tmpError < _bestErrors[l] + _smallVal)
00218          {
00219             _bestErrors[l] = tmpError;
00220             // compute the thresholds
00221             thresholds[l] = ( previousSplitPos->second + currentSplitPos->second ) / 2;
00222 
00223             // If we assume that class [l] is always on the right side,
00224             // here we must flip, as the lowest error is on the left side.
00225             // example:
00226             // c=current class, x=other class
00227             // .. c c c x | c x x x .. = 2 errors (if we flip!)
00228             if (flipIt)
00229                v[l] = -1;
00230             else
00231                v[l] = +1;
00232 
00233          }
00234       } // for l
00235    } // while (currentSplitPos != endArray)
00236 
00238 
00239    // Fill the mus. This could have been done in the threshold loop, 
00240    // but here is done just once
00241    for (int l = 0; l < numClasses; ++l)
00242    {
00243       mu[l].classIdx = l;
00244 
00245       mu[l].rPls  = _weightsPerClass[l]-_bestErrors[l];
00246       mu[l].rMin  = _bestErrors[l];
00247       mu[l].rZero = mu[l].rPls + mu[l].rMin;
00248    }
00249 
00250 }
00251 
00252 // -----------------------------------------------------------------------
00253 
00254 void MultiStumpLearner::save(ofstream& outputStream, const int numTabs)
00255 {
00256    // Calling the super-class method
00257    StumpLearner::save(outputStream, numTabs);
00258 
00259    // save all the thresholds
00260    outputStream << Serialization::vectorTag("thArray", _thresholds, numTabs) << endl;
00261 }
00262 
00263 // -----------------------------------------------------------------------
00264 
00265 void MultiStumpLearner::load(nor_utils::StreamTokenizer& st)
00266 {
00267    // Calling the super-class method
00268    StumpLearner::load(st);
00269 
00270    // load vArray data
00271    UnSerialization::seekAndParseVectorTag(st, "thArray", _thresholds);
00272 }
00273 
00274 // -----------------------------------------------------------------------
00275 
00276 } // end of namespace MultiBoost

Generated on Mon Nov 28 21:43:46 2005 for MultiBoost by  doxygen 1.4.5