00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #include "MultiStumpLearner.h"
00025
00026 #include "IO/Serialization.h"
00027 #include "IO/SortedData.h"
00028
00029 #include <limits>
00030 #include <cassert>
00031
00032 namespace MultiBoost {
00033
00034 REGISTER_LEARNER(MultiStumpLearner)
00035
00036
00037
00038 void MultiStumpLearner::run(InputData* pData)
00039 {
00040 const int numClasses = ClassMappings::getNumClasses();
00041 const int numColumns = pData->getNumColumns();
00042
00043
00044
00045 setSmoothingVal( 1.0 / (double)pData->getNumExamples() * 0.01 );
00046
00047
00048 _leftErrors.resize(numClasses);
00049 _rightErrors.resize(numClasses);
00050 _bestErrors.resize(numClasses);
00051 _weightsPerClass.resize(numClasses);
00052 _halfWeightsPerClass.resize(numClasses);
00053
00054 vector<sRates> mu(numClasses);
00055
00056 vector<char> tmpV(numClasses);
00057 vector<double> tmpThresholds(numClasses);
00058 double tmpAlpha;
00059
00060 double bestE = numeric_limits<double>::max();
00061 double tmpE;
00062
00063 for (int j = 0; j < numColumns; ++j)
00064 {
00065 findThresholds(pData, j, tmpThresholds, mu, tmpV);
00066
00067 tmpE = getEnergy(mu, tmpAlpha, tmpV);
00068 if (tmpE < bestE)
00069 {
00070
00071
00072
00073
00074
00075
00076
00077 _alpha = tmpAlpha;
00078 _v = tmpV;
00079 _selectedColumn = j;
00080 _thresholds = tmpThresholds;
00081
00082 bestE = tmpE;
00083 }
00084
00085 }
00086
00087 }
00088
00089
00090
00091 char MultiStumpLearner::phi(double val, int classIdx)
00092 {
00093 if (val > _thresholds[classIdx])
00094 return +1;
00095 else
00096 return -1;
00097 }
00098
00099
00100
00101 void MultiStumpLearner::findThresholds(InputData* pData, const int columnIdx,
00102 vector<double>& thresholds,
00103 vector<sRates>& mu, vector<char>& v)
00104 {
00105 const vpIterator dataBegin = static_cast<SortedData*>(pData)->getSortedBegin(columnIdx);
00106 const vpIterator dataEnd = static_cast<SortedData*>(pData)->getSortedEnd(columnIdx);
00107
00108 const int numClasses = ClassMappings::getNumClasses();
00109
00110
00111 fill(_leftErrors.begin(), _leftErrors.end(), 0);
00112 fill(_weightsPerClass.begin(), _weightsPerClass.end(), 0);
00113
00114
00115 const size_t numPoints = dataEnd - dataBegin;
00116
00117 vpIterator currentSplitPos;
00118 vpIterator previousSplitPos;
00119 cvpIterator endArray;
00120
00122
00123
00124
00125 double tmpRightError;
00126
00127 for (int l = 0; l < numClasses; ++l)
00128 {
00129 tmpRightError = 0;
00130
00131 for( currentSplitPos = dataBegin; currentSplitPos != dataEnd; ++currentSplitPos)
00132 {
00133 double weight = pData->getWeight(currentSplitPos->first, l);
00134
00135
00136
00137
00138
00139
00140 if ( pData->getClass(currentSplitPos->first) != l )
00141 tmpRightError += weight;
00142
00143 _weightsPerClass[l] += weight;
00144 }
00145
00146 _halfWeightsPerClass[l] = _weightsPerClass[l] / 2;
00147
00148 assert(tmpRightError < 1);
00149 _rightErrors[l] = tmpRightError;
00150 _bestErrors[l] = numeric_limits<double>::max();
00151 }
00152
00154
00155 currentSplitPos = dataBegin;
00156 endArray = dataEnd;
00157 --endArray;
00158
00159 double tmpError = 0;
00160 double bestError = numeric_limits<double>::max();
00161 bool flipIt;
00162
00163
00164 while (currentSplitPos != endArray)
00165 {
00166
00167
00168
00169 previousSplitPos = currentSplitPos;
00170 ++currentSplitPos;
00171
00172
00173 while ( previousSplitPos->second == currentSplitPos->second && currentSplitPos != endArray)
00174 {
00175 for (int l = 0; l < numClasses; ++l)
00176 {
00177 if ( pData->getClass( previousSplitPos->first ) == l )
00178 _leftErrors[l] += pData->getWeight(previousSplitPos->first, l);
00179 else
00180 _rightErrors[l] -= pData->getWeight(previousSplitPos->first, l);
00181 }
00182
00183 previousSplitPos = currentSplitPos;
00184 ++currentSplitPos;
00185 }
00186
00187 for (int l = 0; l < numClasses; ++l)
00188 {
00189 if ( pData->getClass( previousSplitPos->first ) == l )
00190 {
00191
00192
00193 _leftErrors[l] += pData->getWeight(previousSplitPos->first, l);
00194 }
00195 else
00196 {
00197
00198
00199 _rightErrors[l] -= pData->getWeight(previousSplitPos->first, l);
00200 }
00201
00202 tmpError = _rightErrors[l] + _leftErrors[l];
00203
00204
00205 if(tmpError > _halfWeightsPerClass[l] + _smallVal)
00206 {
00207 tmpError = _weightsPerClass[l] - tmpError;
00208 flipIt = true;
00209 }
00210 else
00211 flipIt = false;
00212
00213
00214 assert(tmpError <= _halfWeightsPerClass[l] + _smallVal);
00215
00216
00217 if (tmpError < _bestErrors[l] + _smallVal)
00218 {
00219 _bestErrors[l] = tmpError;
00220
00221 thresholds[l] = ( previousSplitPos->second + currentSplitPos->second ) / 2;
00222
00223
00224
00225
00226
00227
00228 if (flipIt)
00229 v[l] = -1;
00230 else
00231 v[l] = +1;
00232
00233 }
00234 }
00235 }
00236
00238
00239
00240
00241 for (int l = 0; l < numClasses; ++l)
00242 {
00243 mu[l].classIdx = l;
00244
00245 mu[l].rPls = _weightsPerClass[l]-_bestErrors[l];
00246 mu[l].rMin = _bestErrors[l];
00247 mu[l].rZero = mu[l].rPls + mu[l].rMin;
00248 }
00249
00250 }
00251
00252
00253
00254 void MultiStumpLearner::save(ofstream& outputStream, const int numTabs)
00255 {
00256
00257 StumpLearner::save(outputStream, numTabs);
00258
00259
00260 outputStream << Serialization::vectorTag("thArray", _thresholds, numTabs) << endl;
00261 }
00262
00263
00264
00265 void MultiStumpLearner::load(nor_utils::StreamTokenizer& st)
00266 {
00267
00268 StumpLearner::load(st);
00269
00270
00271 UnSerialization::seekAndParseVectorTag(st, "thArray", _thresholds);
00272 }
00273
00274
00275
00276 }