src/WeakLearners/SingleStumpLearner.cpp

00001 /*
00002 * This file is part of MultiBoost, a multi-class 
00003 * AdaBoost learner/classifier
00004 *
00005 * Copyright (C) 2005 Norman Casagrande
00006 * For informations write to nova77@gmail.com
00007 *
00008 * This library is free software; you can redistribute it and/or
00009 * modify it under the terms of the GNU Lesser General Public
00010 * License as published by the Free Software Foundation; either
00011 * version 2.1 of the License, or (at your option) any later version.
00012 *
00013 * This library is distributed in the hope that it will be useful,
00014 * but WITHOUT ANY WARRANTY; without even the implied warranty of
00015 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00016 * Lesser General Public License for more details.
00017 *
00018 * You should have received a copy of the GNU Lesser General Public
00019 * License along with this library; if not, write to the Free Software
00020 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
00021 *
00022 */
00023 
00024 #include "SingleStumpLearner.h"
00025 
00026 #include "IO/Serialization.h"
00027 #include "IO/SortedData.h"
00028 
00029 #include <limits> // for numeric_limits<>
00030 #include <cassert>
00031 
00032 namespace MultiBoost {
00033 
00034 
// NOTE(review): REGISTER_LEARNER is defined outside this file; it presumably makes
// SingleStumpLearner known to the learner registry under its class name -- confirm.
// The commented-out line is an older/alternative invocation and is not compiled.
00035 //REGISTER_LEARNER_NAME(SingleStump, SingleStumpLearner)
00036 REGISTER_LEARNER(SingleStumpLearner)
00037 
00038 // ------------------------------------------------------------------------------
00039 
00040 void SingleStumpLearner::run(InputData* pData)
00041 {
00042    const int numClasses = ClassMappings::getNumClasses();
00043    const int numColumns = pData->getNumColumns();
00044 
00045    // set the smoothing value to avoid numerical problem
00046    // when theta=0.
00047    setSmoothingVal( 1.0 / (double)pData->getNumExamples() * 0.01 );
00048 
00049    // resize
00050    _leftErrors.resize(numClasses);
00051    _rightErrors.resize(numClasses);
00052    _bestErrors.resize(numClasses);
00053    _weightsPerClass.resize(numClasses);
00054    _halfWeightsPerClass.resize(numClasses);
00055 
00056    vector<sRates> mu(numClasses); // The class-wise rates. See BaseLearner::sMu for more info.
00057    vector<char> tmpV(numClasses); // The class-wise votes/abstentions
00058 
00059    double tmpThreshold;
00060    double tmpAlpha;
00061 
00062    double bestE = numeric_limits<double>::max();
00063    double tmpE;
00064 
00065    for (int j = 0; j < numColumns; ++j)
00066    {
00067       findThreshold(pData, j, tmpThreshold, mu, tmpV);
00068 
00069       tmpE = getEnergy(mu, tmpAlpha, tmpV);
00070       if (tmpE < bestE)
00071       {
00072          // Store it in the current algorithm
00073          // note: I don't really like having so many temp variables
00074          // but the alternative would be a structure, which would need
00075          // to be inheritable to make things more consistent. But this would
00076          // make it less flexible. Therefore, I am still undecided. This
00077          // might change!
00078 
00079          _alpha = tmpAlpha;
00080          _v = tmpV;
00081          _selectedColumn = j;
00082          _threshold = tmpThreshold;
00083 
00084          bestE = tmpE;
00085       }
00086 
00087    }
00088 
00089 }
00090 
00091 // ------------------------------------------------------------------------------
00092 
00093 char SingleStumpLearner::phi(double val, int classIdx)
00094 {
00095    if (val > _threshold)
00096       return +1;
00097    else
00098       return -1;
00099 }
00100 
00101 // -----------------------------------------------------------------------
00102 
00103 void SingleStumpLearner::findThreshold(InputData* pData, const int columnIdx, 
00104                                        double& threshold,
00105                                        vector<sRates>& mu, vector<char>& v)
00106 {
00107    const vpIterator dataBegin = static_cast<SortedData*>(pData)->getSortedBegin(columnIdx);
00108    const vpIterator dataEnd = static_cast<SortedData*>(pData)->getSortedEnd(columnIdx);
00109 
00110    const int numClasses = ClassMappings::getNumClasses();
00111 
00112    // resize and set to 0
00113    fill(_leftErrors.begin(), _leftErrors.end(), 0);
00114    fill(_weightsPerClass.begin(), _weightsPerClass.end(), 0);
00115 
00116    // get the number of examples
00117    const size_t numPoints = dataEnd - dataBegin;
00118 
00119    vpIterator currentSplitPos; // the iterator of the currently examined example
00120    vpIterator previousSplitPos; // the iterator of the example before the current example
00121    cvpIterator endArray; // the iterator on the last example (just before dataEnd)
00122 
00124    // Initialization of the class-wise error
00125 
00126    // The class-wise error on the right side of the threshold
00127    double tmpRightError;
00128 
00129    for (int l = 0; l < numClasses; ++l)
00130    {
00131       tmpRightError = 0;
00132 
00133       for( currentSplitPos = dataBegin; currentSplitPos != dataEnd; ++currentSplitPos)
00134       {
00135          double weight = pData->getWeight(currentSplitPos->first, l);
00136 
00137          // We assume that class "currClass" is always on the right side;
00138          // therefore, all points l that are not currClass (x) on right side,
00139          // are considered error.
00140          // <l x l x x x l x x> = 3 (if each example has weight 1)
00141          // ^-- tmpError: error if we set the cut at the extreme left side
00142          if ( pData->getClass(currentSplitPos->first) != l )
00143             tmpRightError += weight;
00144 
00145          _weightsPerClass[l] += weight;
00146       }
00147 
00148       _halfWeightsPerClass[l] = _weightsPerClass[l] / 2;
00149 
00150       assert(tmpRightError < 1);
00151       _rightErrors[l] = tmpRightError; // store the class-wise error
00152    }
00153 
00155 
00156    currentSplitPos = dataBegin; // reset position
00157    endArray = dataEnd;
00158    --endArray;
00159 
00160    double tmpError = 0;
00161    double currError = 0;
00162    double bestError = numeric_limits<double>::max();
00163 
00164    // find the best threshold (cutting point)
00165    while (currentSplitPos != endArray)
00166    {
00167       // at the first split we have
00168       // first split: x | x x x x x x x x ..
00169       //    previous -^   ^- current
00170       previousSplitPos = currentSplitPos;
00171       ++currentSplitPos; 
00172 
00173       // point at the same position: to skip because we cannot find a cutting point here!
00174       while ( previousSplitPos->second == currentSplitPos->second && currentSplitPos != endArray)
00175       {
00176          for (int l = 0; l < numClasses; ++l)
00177          { 
00178             if ( pData->getClass( previousSplitPos->first ) == l )
00179                _leftErrors[l] += pData->getWeight(previousSplitPos->first, l);
00180             else
00181                _rightErrors[l] -= pData->getWeight(previousSplitPos->first, l);
00182          }
00183 
00184          previousSplitPos = currentSplitPos;
00185          ++currentSplitPos; 
00186       }
00187 
00188       currError = 0;
00189 
00190       for (int l = 0; l < numClasses; ++l)
00191       { 
00192          if ( pData->getClass( previousSplitPos->first ) == l )
00193          {
00194             // c=current class, x=other class
00195             // .. c | x x c x c x .. 
00196             _leftErrors[l] += pData->getWeight(previousSplitPos->first, l);
00197          }
00198          else
00199          {
00200             // c=current class, x=other class
00201             // .. x | x x c x c x ..
00202             _rightErrors[l] -= pData->getWeight(previousSplitPos->first, l);
00203          }
00204 
00205          tmpError = _rightErrors[l] + _leftErrors[l];
00206 
00207          // switch the class-wise error if it is bigger than chance
00208          if(tmpError > _halfWeightsPerClass[l] + _smallVal)
00209             tmpError = _weightsPerClass[l] - tmpError;
00210 
00211          currError += tmpError;
00212          assert(tmpError <= 0.5); // The summed error MUST be smaller than chance
00213       }
00214 
00215       // the overall error is smaller!
00216       if (currError < bestError + _smallVal)
00217       {
00218          bestError = currError;
00219          // compute the threshold
00220          threshold = ( previousSplitPos->second + currentSplitPos->second ) / 2;
00221 
00222          for (int l = 0; l < numClasses; ++l)
00223          { 
00224             _bestErrors[l] = _rightErrors[l] + _leftErrors[l];  
00225 
00226             // If we assume that class [l] is always on the right side,
00227             // here we must flip, as the lowest error is on the left side.
00228             // example:
00229             // c=current class, x=other class
00230             // .. c c c x | c x x x .. = 2 errors (if we flip!)
00231             if (_bestErrors[l] > _halfWeightsPerClass[l] + _smallVal)
00232             {
00233                // In the binary case would be (1-error)
00234                _bestErrors[l] = _weightsPerClass[l] - _bestErrors[l];
00235                v[l] = -1;
00236             }
00237             else
00238                v[l] = +1;
00239          }
00240 
00241       }
00242 
00243    }
00244 
00246 
00247    // Fill the mus. This could have been done in the threshold loop, 
00248    // but here is done just once
00249    for (int l = 0; l < numClasses; ++l)
00250    {
00251       mu[l].classIdx = l;
00252 
00253       mu[l].rPls  = _weightsPerClass[l]-_bestErrors[l];
00254       mu[l].rMin  = _bestErrors[l];
00255       mu[l].rZero = mu[l].rPls + mu[l].rMin;
00256    }
00257 
00258 }
00259 
00260 // -----------------------------------------------------------------------
00261 
00262 void SingleStumpLearner::save(ofstream& outputStream, const int numTabs)
00263 {
00264    // Calling the super-class method
00265    StumpLearner::save(outputStream, numTabs);
00266 
00267    // save selectedCoulumn
00268    outputStream << Serialization::standardTag("threshold", _threshold, numTabs) << endl;
00269 }
00270 
00271 // -----------------------------------------------------------------------
00272 
00273 void SingleStumpLearner::load(nor_utils::StreamTokenizer& st)
00274 {
00275    // Calling the super-class method
00276    StumpLearner::load(st);
00277 
00278    _threshold = UnSerialization::seekAndParseEnclosedValue<double>(st, "threshold");
00279 
00280 }
00281 
00282 // -----------------------------------------------------------------------
00283 
00284 } // end of namespace MultiBoost

Generated on Mon Nov 28 21:43:46 2005 for MultiBoost by  doxygen 1.4.5