00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #include <cassert>
00025 #include <limits>
00026 #include <cmath>
00027
00028 #include "WeakLearners/StumpLearner.h"
00029
00030 #include "Utils/Utils.h"
00031 #include "IO/Serialization.h"
00032 #include "IO/SortedData.h"
00033
00034 namespace MultiBoost {
00035
00036
00037
00038 void StumpLearner::initOptions(nor_utils::Args& args)
00039 {
00040
00041 if ( args.hasArgument("-edgeoffset") )
00042 args.getValue("-edgeoffset", 0, _theta);
00043
00044
00045 if ( args.hasArgument("-abstention") )
00046 {
00047 string abstType;
00048 args.getValue("-abstention", 0, abstType);
00049
00050 if (abstType == "greedy")
00051 _abstention = ABST_GREEDY;
00052 else if (abstType == "full")
00053 _abstention = ABST_FULL;
00054 else
00055 {
00056 cerr << "ERROR: Invalid type of abstention <" << abstType << ">!!" << endl;
00057 exit(1);
00058 }
00059 }
00060 }
00061
00062
00063
// Declares the command-line arguments specific to stump learners so the
// argument parser can recognize and document them.
// NOTE(review): -edgeoffset is read in initOptions() but not declared here --
// presumably declared by a base class or elsewhere; confirm.
void StumpLearner::declareArguments(nor_utils::Args& args)
{
   args.declareArgument("-abstention",
                        "Activate the abstention. Available types are:\n"
                        " greedy: sorting and checking in O(k^2)\n"
                        " full: the O(2^k) full search", 1, "<type>");
}
00071
00072
00073
00074 InputData* StumpLearner::createInputData()
00075 {
00076 return new SortedData();
00077 }
00078
00079
00080
// Classifies example idx for class classIdx: the stump's raw decision
// phi(.) on the selected column's value, multiplied by the per-class
// vote _v[classIdx] (which may be 0 if the learner abstains on that class).
void StumpLearner::classify -- see below
char StumpLearner::classify(InputData* pData, const int idx, const int classIdx)
{
   return _v[classIdx] * phi( pData->getValue(idx, _selectedColumn), classIdx );
}
00085
00086
00087
00088 double StumpLearner::getEnergy(vector<sRates>& mu, double& alpha, vector<char>& v)
00089 {
00090 const int numClasses = ClassMappings::getNumClasses();
00091
00092 sRates eps;
00093
00094
00095 for (int l = 0; l < numClasses; ++l)
00096 {
00097 eps.rMin += mu[l].rMin;
00098 eps.rPls += mu[l].rPls;
00099 }
00100
00101
00102 assert( eps.rMin + eps.rPls <= 1 + _smallVal &&
00103 eps.rMin + eps.rPls >= 1 - _smallVal);
00104
00105 double currEnergy;
00106 if ( nor_utils::is_zero(_theta) )
00107 {
00108 alpha = getAlpha(eps.rMin, eps.rPls);
00109 currEnergy = 2 * sqrt( eps.rMin * eps.rPls );
00110 }
00111 else
00112 {
00113 alpha = getAlpha(eps.rMin, eps.rPls, _theta);
00114 currEnergy = exp( _theta * alpha ) *
00115 ( eps.rMin * exp(alpha) + eps.rPls * exp(alpha) );
00116 }
00117
00118
00119 switch(_abstention)
00120 {
00121 case ABST_GREEDY:
00122
00123 currEnergy = doGreedyAbstention(mu, currEnergy, eps, alpha, v);
00124 break;
00125 case ABST_FULL:
00126
00127 currEnergy = doFullAbstention(mu, currEnergy, eps, alpha, v);
00128 break;
00129 case ABST_NO_ABSTENTION:
00130 break;
00131 }
00132
00133
00134 if (eps.rMin >= eps.rPls)
00135 currEnergy = numeric_limits<double>::max();
00136
00137 return currEnergy;
00138 }
00139
00140
00141
// Greedy abstention search: repeatedly tries to abstain on one more class
// (zeroing its vote) whenever doing so does not increase the energy, until a
// full pass makes no change. O(k^2) in the number of classes.
// @param mu         Per-class rates; sorted in place (so passed non-const).
// @param currEnergy The energy of the stump without further abstention.
// @param eps        In/out: the accumulated rates; updated as classes abstain.
// @param alpha      In/out: the coefficient; updated as classes abstain.
// @param v          In/out: per-class votes; abstaining classes are set to 0.
// @return The (possibly lowered) energy.
double StumpLearner::doGreedyAbstention(vector<sRates>& mu, double currEnergy,
                                        sRates& eps, double& alpha, vector<char>& v)
{
   const int numClasses = ClassMappings::getNumClasses();

   // Sort the classes by their rates so the most promising candidates for
   // abstention are considered first.
   // NOTE(review): assumes sRates defines operator< -- ordering criterion not
   // visible here; confirm.
   sort(mu.begin(), mu.end());

   bool changed;
   sRates newEps;
   double newAlpha;
   double newEnergy;

   do
   {
      changed = false;

      for (int l = 0; l < numClasses; ++l)
      {
         // Only consider classes we have not already abstained on.
         if ( v[ mu[l].classIdx ] != 0 )
         {
            // Rates if class mu[l] were to abstain: its error/correct mass
            // moves into the abstention (zero) mass.
            newEps.rMin = eps.rMin - mu[l].rMin;
            newEps.rPls = eps.rPls - mu[l].rPls;
            newEps.rZero = eps.rZero + mu[l].rZero;

            if ( nor_utils::is_zero(_theta) )
            {
               // Z = 2*sqrt(eps- * eps+) + eps0.
               newEnergy = 2 * sqrt(newEps.rMin * newEps.rPls) + newEps.rZero;
               newAlpha = getAlpha(newEps.rMin, newEps.rPls);
            }
            else
            {
               // Z = e^(theta*alpha) * (eps+ e^-alpha + eps- e^alpha + eps0).
               newAlpha = getAlpha(newEps.rMin, newEps.rPls, _theta);
               newEnergy = exp( _theta * newAlpha ) *
                           ( newEps.rPls * exp(-newAlpha) +
                             newEps.rMin * exp(newAlpha) +
                             newEps.rZero );
            }

            // Accept if the energy does not increase (ties within _smallVal
            // are accepted too, i.e. abstaining is preferred on equal energy).
            if ( newEnergy < currEnergy + _smallVal)
            {
               changed = true;

               currEnergy = newEnergy;
               eps = newEps;

               // Abstain on this class.
               v[ mu[l].classIdx ] = 0;
               alpha = newAlpha;

               // All three rates (including abstention) must still sum to one.
               assert( eps.rMin + eps.rPls + eps.rZero <= 1 + _smallVal &&
                       eps.rMin + eps.rPls + eps.rZero >= 1 - _smallVal );
            }
         }
      }

   } while (changed);

   return currEnergy;
}
00207
00208
00209
00210 double StumpLearner::doFullAbstention(const vector<sRates>& mu, double currEnergy,
00211 sRates& eps, double& alpha, vector<char>& v)
00212 {
00213 const int numClasses = ClassMappings::getNumClasses();
00214
00215 vector<char> best(numClasses, 1);
00216 vector<char> candidate(numClasses);
00217 sRates newEps;
00218 double newAlpha;
00219 double newEnergy;
00220
00221 sRates bestEps;
00222
00223 for (int l = 1; l < numClasses; ++l)
00224 {
00225
00226
00227 fill( candidate.begin(), candidate.begin()+l, 0 );
00228 fill( candidate.begin()+l, candidate.end(), 1 );
00229
00230
00231 do {
00232
00233 newEps = eps;
00234
00235 for ( int j = 0; j < numClasses; ++j )
00236 {
00237 if ( candidate[j] == 0 )
00238 {
00239 newEps.rMin -= mu[j].rMin;
00240 newEps.rPls -= mu[j].rPls;
00241 newEps.rZero += mu[j].rZero;
00242 }
00243 }
00244
00245 if ( nor_utils::is_zero(_theta) )
00246 {
00247 newEnergy = 2 * sqrt(newEps.rMin * newEps.rPls) + newEps.rZero;
00248 newAlpha = getAlpha(newEps.rMin, newEps.rPls);
00249 }
00250 else
00251 {
00252 newAlpha = getAlpha(newEps.rMin, newEps.rPls, _theta);
00253 newEnergy = exp( _theta * newAlpha ) *
00254 ( newEps.rPls * exp(-newAlpha) +
00255 newEps.rMin * exp(newAlpha) +
00256 newEps.rZero );
00257 }
00258
00259 if ( newEnergy < currEnergy + _smallVal)
00260 {
00261 currEnergy = newEnergy;
00262
00263 best = candidate;
00264 alpha = newAlpha;
00265 bestEps = newEps;
00266
00267
00268 assert( newEps.rMin + newEps.rPls + newEps.rZero <= 1 + _smallVal &&
00269 newEps.rMin + newEps.rPls + newEps.rZero >= 1 - _smallVal );
00270 }
00271
00272 } while ( next_permutation(candidate.begin(), candidate.end()) );
00273
00274 }
00275
00276 for (int l = 0; l < numClasses; ++l)
00277 v[l] *= best[l];
00278
00279 eps = bestEps;
00280
00281 return currEnergy;
00282 }
00283
00284
00285
00286 void StumpLearner::save(ofstream& outputStream, const int numTabs)
00287 {
00288
00289 BaseLearner::save(outputStream, numTabs);
00290
00291
00292 outputStream << Serialization::standardTag("column", _selectedColumn, numTabs) << endl;
00293
00294 vector<int> vInt(_v.size());
00295 copy(_v.begin(), _v.end(), vInt.begin());
00296
00297 outputStream << Serialization::vectorTag("vArray", vInt, numTabs) << endl;
00298 }
00299
00300
00301
00302 void StumpLearner::load(nor_utils::StreamTokenizer& st)
00303 {
00304
00305 BaseLearner::load(st);
00306
00307 _selectedColumn = UnSerialization::seekAndParseEnclosedValue<int>(st, "column");
00308
00309
00310 string rawTag;
00311 string tag, tagParam, tagValue;
00312
00313
00314 vector<int> vInt;
00315
00316
00317 UnSerialization::seekAndParseVectorTag(st, "vArray", vInt);
00318 for (vector<int>::const_iterator it = vInt.begin(); it != vInt.end(); ++it)
00319 _v.push_back((char)*it);
00320
00321 }
00322
00323
00324
00325 }