src/WeakLearners/StumpLearner.cpp

00001 /*
00002 * This file is part of MultiBoost, a multi-class 
00003 * AdaBoost learner/classifier
00004 *
00005 * Copyright (C) 2005 Norman Casagrande
00006 * For informations write to nova77@gmail.com
00007 *
00008 * This library is free software; you can redistribute it and/or
00009 * modify it under the terms of the GNU Lesser General Public
00010 * License as published by the Free Software Foundation; either
00011 * version 2.1 of the License, or (at your option) any later version.
00012 *
00013 * This library is distributed in the hope that it will be useful,
00014 * but WITHOUT ANY WARRANTY; without even the implied warranty of
00015 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
00016 * Lesser General Public License for more details.
00017 *
00018 * You should have received a copy of the GNU Lesser General Public
00019 * License along with this library; if not, write to the Free Software
00020 * Foundation, Inc., 51 Franklin St, Fifth Floor, Boston, MA  02110-1301  USA
00021 *
00022 */
00023 
00024 #include <cassert>
00025 #include <limits> // for numeric_limits<>
00026 #include <cmath>
00027 
00028 #include "WeakLearners/StumpLearner.h"
00029 
00030 #include "Utils/Utils.h"
00031 #include "IO/Serialization.h"
00032 #include "IO/SortedData.h"
00033 
00034 namespace MultiBoost {
00035 
00036 // ------------------------------------------------------------------------------
00037 
00038 void StumpLearner::initOptions(nor_utils::Args& args)
00039 {
00040    // Set the value of theta
00041    if ( args.hasArgument("-edgeoffset") )
00042       args.getValue("-edgeoffset", 0, _theta);   
00043 
00044    // set abstention
00045    if ( args.hasArgument("-abstention") )
00046    {
00047       string abstType;
00048       args.getValue("-abstention", 0, abstType);
00049 
00050       if (abstType == "greedy")
00051          _abstention = ABST_GREEDY;
00052       else if (abstType == "full")
00053          _abstention = ABST_FULL;
00054       else
00055       {
00056          cerr << "ERROR: Invalid type of abstention <" << abstType << ">!!" << endl;
00057          exit(1);
00058       }
00059    }
00060 }
00061 
00062 // ------------------------------------------------------------------------------
00063 
00064 void StumpLearner::declareArguments(nor_utils::Args& args)
00065 {
00066    args.declareArgument("-abstention", 
00067                         "Activate the abstention. Available types are:\n"
00068                         "  greedy: sorting and checking in O(k^2)\n"
00069                         "  full: the O(2^k) full search", 1, "<type>");
00070 }
00071 
00072 // ------------------------------------------------------------------------------
00073 
00074 InputData* StumpLearner::createInputData()
00075 {
00076    return new SortedData();
00077 }
00078 
00079 // ------------------------------------------------------------------------------
00080 
00081 char StumpLearner::classify(InputData* pData, const int idx, const int classIdx)
00082 {
00083    return _v[classIdx] * phi( pData->getValue(idx, _selectedColumn), classIdx );
00084 }
00085 
00086 // ------------------------------------------------------------------------------
00087 
00088 double StumpLearner::getEnergy(vector<sRates>& mu, double& alpha, vector<char>& v)
00089 {
00090    const int numClasses = ClassMappings::getNumClasses();
00091 
00092    sRates eps;
00093 
00094    // Get the overall error and correct rates
00095    for (int l = 0; l < numClasses; ++l)
00096    {
00097       eps.rMin += mu[l].rMin;
00098       eps.rPls += mu[l].rPls;
00099    }
00100 
00101    // assert: eps- + eps+ + eps0 = 1
00102    assert( eps.rMin + eps.rPls <= 1 + _smallVal &&
00103            eps.rMin + eps.rPls >= 1 - _smallVal);
00104 
00105    double currEnergy;
00106    if ( nor_utils::is_zero(_theta) )
00107    {
00108       alpha = getAlpha(eps.rMin, eps.rPls);
00109       currEnergy = 2 * sqrt( eps.rMin * eps.rPls );
00110    }
00111    else
00112    {
00113       alpha = getAlpha(eps.rMin, eps.rPls, _theta);
00114       currEnergy = exp( _theta * alpha ) * 
00115                   ( eps.rMin * exp(alpha) + eps.rPls * exp(alpha) );
00116    }
00117 
00118    // perform abstention
00119    switch(_abstention)
00120    {
00121       case ABST_GREEDY:
00122          // alpha and v are updated!
00123          currEnergy = doGreedyAbstention(mu, currEnergy, eps, alpha, v);
00124          break;
00125       case ABST_FULL:
00126          // alpha and v are updated!
00127          currEnergy = doFullAbstention(mu, currEnergy, eps, alpha, v);
00128          break;
00129       case ABST_NO_ABSTENTION:
00130          break;
00131    }
00132 
00133    // Condition: eps_pls > eps_min!!
00134    if (eps.rMin >= eps.rPls)
00135       currEnergy = numeric_limits<double>::max();
00136 
00137    return currEnergy; // this is what we are trying to minimize: 2*sqrt(eps+*eps-)+eps0
00138 }
00139 
00140 // -----------------------------------------------------------------------
00141 
00142 double StumpLearner::doGreedyAbstention(vector<sRates>& mu, double currEnergy, 
00143                                         sRates& eps, double& alpha, vector<char>& v)
00144 {
00145    const int numClasses = ClassMappings::getNumClasses();
00146 
00147    // Abstention is performed by evaluating the class-wise error
00148    // and the case in which one element (the one with the highest mu_pls * mu_min value)
00149    // is ignored, that is has v[el] = 0
00150 
00151    // Sorting the energies for each vote
00152    sort(mu.begin(), mu.end());
00153 
00154    bool changed;
00155    sRates newEps;
00156    double newAlpha;
00157    double newEnergy;
00158 
00159    do
00160    {
00161       changed = false;
00162 
00163       for (int l = 0; l < numClasses; ++l)
00164       {
00165          if ( v[ mu[l].classIdx ] != 0 ) 
00166          {
00167             newEps.rMin = eps.rMin - mu[l].rMin;
00168             newEps.rPls = eps.rPls - mu[l].rPls;
00169             newEps.rZero = eps.rZero + mu[l].rZero;
00170 
00171             if ( nor_utils::is_zero(_theta) )
00172             {
00173                newEnergy = 2 * sqrt(newEps.rMin * newEps.rPls) + newEps.rZero;
00174                newAlpha = getAlpha(newEps.rMin, newEps.rPls);
00175             }
00176             else
00177             {
00178                newAlpha = getAlpha(newEps.rMin, newEps.rPls, _theta);
00179                newEnergy = exp( _theta * newAlpha ) *
00180                            ( newEps.rPls * exp(-newAlpha) + 
00181                              newEps.rMin * exp(newAlpha) + 
00182                              newEps.rZero );
00183             }
00184 
00185             if ( newEnergy < currEnergy + _smallVal)
00186             {
00187                // ok, this is v = 0!!
00188                changed = true;
00189 
00190                currEnergy = newEnergy;
00191                eps = newEps;
00192 
00193                v[ mu[l].classIdx ] = 0;
00194                alpha = newAlpha;
00195 
00196                // assert: eps- + eps+ + eps0 = 1
00197                assert( eps.rMin + eps.rPls + eps.rZero <= 1 + _smallVal &&
00198                        eps.rMin + eps.rPls + eps.rZero >= 1 - _smallVal );
00199             }
00200          } // if
00201       } //for
00202 
00203    } while (changed);
00204 
00205    return currEnergy;
00206 }
00207 
00208 // -----------------------------------------------------------------------
00209 
00210 double StumpLearner::doFullAbstention(const vector<sRates>& mu, double currEnergy, 
00211                                       sRates& eps, double& alpha, vector<char>& v)
00212 {
00213    const int numClasses = ClassMappings::getNumClasses();
00214 
00215    vector<char> best(numClasses, 1);
00216    vector<char> candidate(numClasses);
00217    sRates newEps; // candidate
00218    double newAlpha;
00219    double newEnergy;
00220 
00221    sRates bestEps;
00222 
00223    for (int l = 1; l < numClasses; ++l)
00224    {
00225       // starts with an array with just one 0 (and the rest 1), 
00226       // then two 0, then three 0, etc..
00227       fill( candidate.begin(), candidate.begin()+l, 0 );
00228       fill( candidate.begin()+l, candidate.end(), 1 );
00229 
00230       // checks all the possible permutations of such array
00231       do {
00232 
00233          newEps = eps;
00234 
00235          for ( int j = 0; j < numClasses; ++j )
00236          {
00237             if ( candidate[j] == 0 )
00238             {
00239                newEps.rMin -= mu[j].rMin;
00240                newEps.rPls -= mu[j].rPls;
00241                newEps.rZero += mu[j].rZero;
00242             }
00243          }
00244 
00245          if ( nor_utils::is_zero(_theta) )
00246          {
00247             newEnergy = 2 * sqrt(newEps.rMin * newEps.rPls) + newEps.rZero;
00248             newAlpha = getAlpha(newEps.rMin, newEps.rPls);
00249          }
00250          else
00251          {
00252             newAlpha = getAlpha(newEps.rMin, newEps.rPls, _theta);
00253             newEnergy = exp( _theta * newAlpha ) *
00254                         ( newEps.rPls * exp(-newAlpha) + 
00255                           newEps.rMin * exp(newAlpha) + 
00256                           newEps.rZero );
00257          }
00258 
00259          if ( newEnergy < currEnergy + _smallVal)
00260          {
00261             currEnergy = newEnergy;
00262 
00263             best = candidate;
00264             alpha = newAlpha;
00265             bestEps = newEps;
00266 
00267             // assert: eps- + eps+ + eps0 = 1
00268             assert( newEps.rMin + newEps.rPls + newEps.rZero <= 1 + _smallVal &&
00269                     newEps.rMin + newEps.rPls + newEps.rZero >= 1 - _smallVal );
00270          }
00271 
00272       } while ( next_permutation(candidate.begin(), candidate.end()) );
00273 
00274    }
00275 
00276    for (int l = 0; l < numClasses; ++l)
00277       v[l] *= best[l];
00278 
00279    eps = bestEps;
00280 
00281    return currEnergy; // this is what we are trying to minimize: 2*sqrt(eps+*eps-)+eps0
00282 }
00283 
00284 // -----------------------------------------------------------------------
00285 
00286 void StumpLearner::save(ofstream& outputStream, const int numTabs)
00287 {
00288    // Calling the super-class method
00289    BaseLearner::save(outputStream, numTabs);
00290 
00291    // save selectedCoulumn
00292    outputStream << Serialization::standardTag("column", _selectedColumn, numTabs) << endl;
00293 
00294    vector<int> vInt(_v.size());
00295    copy(_v.begin(), _v.end(), vInt.begin()); // copy to an integer array to allow serialization of a number
00296    // save the v vector
00297    outputStream << Serialization::vectorTag("vArray", vInt, numTabs) << endl;
00298 }
00299 
00300 // -----------------------------------------------------------------------
00301 
00302 void StumpLearner::load(nor_utils::StreamTokenizer& st)
00303 {
00304    // Calling the super-class method
00305    BaseLearner::load(st);
00306 
00307    _selectedColumn = UnSerialization::seekAndParseEnclosedValue<int>(st, "column");
00308 
00309    // move until vArray tag
00310    string rawTag;
00311    string tag, tagParam, tagValue;
00312  
00313    // need to use an int because the string stream will load characters for char
00314    vector<int> vInt;
00315 
00316    // load vArray data
00317    UnSerialization::seekAndParseVectorTag(st, "vArray", vInt);
00318    for (vector<int>::const_iterator it = vInt.begin(); it != vInt.end(); ++it)
00319       _v.push_back((char)*it);
00320 
00321 }
00322 
00323 // -----------------------------------------------------------------------
00324 
00325 } // end of namespace MultiBoost

Generated on Mon Nov 28 21:43:47 2005 for MultiBoost by  doxygen 1.4.5