00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00028 #ifndef __STUMP_LEARNER_H
00029 #define __STUMP_LEARNER_H
00030
00031 #include "WeakLearners/BaseLearner.h"
00032 #include "Utils/Args.h"
00033 #include "IO/InputData.h"
00034
00035 #include <vector>
00036 #include <fstream>
00037
00038 using namespace std;
00039
00042
00043 namespace MultiBoost {
00044
00050 class StumpLearner : public BaseLearner
00051 {
00052 public:
00053
00059 StumpLearner()
00060 : _theta(0), _abstention(ABST_NO_ABSTENTION), _selectedColumn(-1) {}
00061
00068 virtual void initOptions(nor_utils::Args& args);
00069
00083 virtual void declareArguments(nor_utils::Args& args);
00084
00094 virtual InputData* createInputData();
00095
00108 virtual char classify(InputData* pData, const int idx, const int classIdx);
00109
00110
00122 virtual void save(ofstream& outputStream, const int numTabs = 0);
00123
00131 virtual void load(nor_utils::StreamTokenizer& st);
00132
00133 protected:
00134
00144 virtual char phi(double val, int classIdx) = 0;
00145
00158 struct sRates
00159 {
00160 sRates() : classIdx(-1), rPls(0), rMin(0), rZero(0) {}
00161
00162 int classIdx;
00163
00164 double rPls;
00165 double rMin;
00166 double rZero;
00167
00173 bool operator<(const sRates& el) const
00174 {
00175 return el.rPls * el.rMin < this->rPls * this->rMin;
00176 }
00177 };
00178
00192 virtual double getEnergy(vector<sRates>& mu, double& alpha, vector<char>& v);
00193
00212 virtual double doGreedyAbstention(vector<sRates>& mu, double currEnergy,
00213 sRates& eps, double& alpha, vector<char>& v);
00214
00228 virtual double doFullAbstention(const vector<sRates>& mu, double currEnergy,
00229 sRates& eps, double& alpha, vector<char>& v);
00230
00244 vector<char> _v;
00245 int _selectedColumn;
00246
00247 double _theta;
00248
00255 enum eAbstType
00256 {
00257 ABST_NO_ABSTENTION,
00258 ABST_GREEDY,
00259 ABST_FULL
00260 };
00261 eAbstType _abstention;
00262
00263 vector<double> _rightErrors;
00264 vector<double> _leftErrors;
00265 vector<double> _bestErrors;
00266
00267 vector<double> _weightsPerClass;
00268 vector<double> _halfWeightsPerClass;
00269
00270 };
00271
00272
00273
00274
00275 }
00276
00277 #endif // __STUMP_LEARNER_H