00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00028 #ifndef __BASE_LEARNER_H
00029 #define __BASE_LEARNER_H
00030
00031 #include <algorithm>
00032 #include <vector>
00033
00034 #include "IO/InputData.h"
00035 #include "Utils/Args.h"
00036
00037 #include "Utils/StreamTokenizer.h"
00038
00039 using namespace std;
00040
00043
00044 namespace MultiBoost {
00045
00051 class BaseLearner
00052 {
00053
00054 private:
00055
00057
00066 class LearnersRegs
00067 {
00068 public:
00069
00077 void addLearner(const string& learnerName, BaseLearner* pLearnerToRegister)
00078 { _learners[learnerName] = pLearnerToRegister; }
00079
00085 bool hasLearner(const string& learnerName)
00086 { return ( _learners.find(learnerName) != _learners.end() ); }
00087
00093 BaseLearner* getLearner(const string& learnerName)
00094 { return _learners[learnerName]; }
00095
00101 void getList(vector<string>& learnersList)
00102 {
00103 learnersList.clear();
00104 learnersList.reserve(_learners.size());
00105 map<string, BaseLearner*>::const_iterator it;
00106 for (it = _learners.begin(); it != _learners.end(); ++it)
00107 learnersList.push_back( it->first );
00108 }
00109
00110 private:
00111 map<string, BaseLearner*> _learners;
00112 };
00113
00115
00116 public:
00117
00134 static LearnersRegs& RegisteredLearners()
00135 {
00136
00137
00138
00139
00140
00141 static LearnersRegs* regLerners = new LearnersRegs();
00142 return *regLerners;
00143 }
00144
00150 BaseLearner() : _smallVal(1E-10), _smoothingVal(_smallVal), _alpha(0) {}
00151
00160 virtual void initOptions(nor_utils::Args& args) {}
00161
00173 virtual void declareArguments(nor_utils::Args& args) = 0;
00174
00189 virtual BaseLearner* create() = 0;
00190
00200 virtual InputData* createInputData();
00201
00211 virtual void run(InputData* pData) = 0;
00212
00224 virtual char classify(InputData* pData, const int idx, const int classIdx) = 0;
00225
00236 const double getAlpha() const { return _alpha; }
00237
00248 virtual void save(ofstream& outputStream, const int numTabs = 0);
00249
00258 virtual void load(nor_utils::StreamTokenizer& st);
00259
00260 protected:
00261
00274 virtual void setSmoothingVal(const double smoothingVal) { _smoothingVal = smoothingVal; }
00275
00290 virtual double getAlpha(const double error);
00291
00311 virtual double getAlpha(const double eps_min, const double eps_pls);
00312
00340 virtual double getAlpha(const double eps_min, const double eps_pls, double theta);
00341
00347 double _smoothingVal;
00348
00349 double _alpha;
00350 const double _smallVal;
00351
00352 };
00353
00354
00355
00356 }
00357
00358
/**
 * Registers learner class X in BaseLearner's registry under the name "X".
 * Expands to a helper struct Register_X and a file-static instance r_X whose
 * constructor, run when the translation unit's static objects are initialized,
 * adds a `new X()` (never deleted) to the registry.
 * BaseLearner is fully qualified so the macro also works outside
 * namespace MultiBoost.
 */
#define REGISTER_LEARNER(X) \
   struct Register_##X \
   { Register_##X() { ::MultiBoost::BaseLearner::RegisteredLearners().addLearner(#X, new X()); } }; \
   static Register_##X r_##X;
00367
/**
 * Registers learner class X in BaseLearner's registry under the name "NAME"
 * (the stringified NAME token) instead of the class name.
 * Like REGISTER_LEARNER, it expands to a helper struct Register_X plus a
 * file-static instance r_X that adds a `new X()` (never deleted) to the
 * registry. Because both generated names derive from X, the same X cannot be
 * registered twice in one scope. BaseLearner is fully qualified so the macro
 * also works outside namespace MultiBoost.
 */
#define REGISTER_LEARNER_NAME(NAME, X) \
   struct Register_##X \
   { Register_##X() { ::MultiBoost::BaseLearner::RegisteredLearners().addLearner(#NAME, new X()); } }; \
   static Register_##X r_##X;
00377
00378
00379 #endif // __BASE_LEARNER_H