00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020
00021
00022
00023
00024 #include "SingleStumpLearner.h"
00025
00026 #include "IO/Serialization.h"
00027 #include "IO/SortedData.h"
00028
00029 #include <limits>
00030 #include <cassert>
00031
00032 namespace MultiBoost {
00033
00034
00035
// NOTE(review): presumably registers SingleStumpLearner with the learner
// registry so that it can be instantiated by name -- the macro is defined
// outside this file, confirm against its definition.
00036 REGISTER_LEARNER(SingleStumpLearner)
00037
00038
00039
00040 void SingleStumpLearner::run(InputData* pData)
00041 {
00042 const int numClasses = ClassMappings::getNumClasses();
00043 const int numColumns = pData->getNumColumns();
00044
00045
00046
00047 setSmoothingVal( 1.0 / (double)pData->getNumExamples() * 0.01 );
00048
00049
00050 _leftErrors.resize(numClasses);
00051 _rightErrors.resize(numClasses);
00052 _bestErrors.resize(numClasses);
00053 _weightsPerClass.resize(numClasses);
00054 _halfWeightsPerClass.resize(numClasses);
00055
00056 vector<sRates> mu(numClasses);
00057 vector<char> tmpV(numClasses);
00058
00059 double tmpThreshold;
00060 double tmpAlpha;
00061
00062 double bestE = numeric_limits<double>::max();
00063 double tmpE;
00064
00065 for (int j = 0; j < numColumns; ++j)
00066 {
00067 findThreshold(pData, j, tmpThreshold, mu, tmpV);
00068
00069 tmpE = getEnergy(mu, tmpAlpha, tmpV);
00070 if (tmpE < bestE)
00071 {
00072
00073
00074
00075
00076
00077
00078
00079 _alpha = tmpAlpha;
00080 _v = tmpV;
00081 _selectedColumn = j;
00082 _threshold = tmpThreshold;
00083
00084 bestE = tmpE;
00085 }
00086
00087 }
00088
00089 }
00090
00091
00092
00093 char SingleStumpLearner::phi(double val, int classIdx)
00094 {
00095 if (val > _threshold)
00096 return +1;
00097 else
00098 return -1;
00099 }
00100
00101
00102
00103 void SingleStumpLearner::findThreshold(InputData* pData, const int columnIdx,
00104 double& threshold,
00105 vector<sRates>& mu, vector<char>& v)
00106 {
00107 const vpIterator dataBegin = static_cast<SortedData*>(pData)->getSortedBegin(columnIdx);
00108 const vpIterator dataEnd = static_cast<SortedData*>(pData)->getSortedEnd(columnIdx);
00109
00110 const int numClasses = ClassMappings::getNumClasses();
00111
00112
00113 fill(_leftErrors.begin(), _leftErrors.end(), 0);
00114 fill(_weightsPerClass.begin(), _weightsPerClass.end(), 0);
00115
00116
00117 const size_t numPoints = dataEnd - dataBegin;
00118
00119 vpIterator currentSplitPos;
00120 vpIterator previousSplitPos;
00121 cvpIterator endArray;
00122
00124
00125
00126
00127 double tmpRightError;
00128
00129 for (int l = 0; l < numClasses; ++l)
00130 {
00131 tmpRightError = 0;
00132
00133 for( currentSplitPos = dataBegin; currentSplitPos != dataEnd; ++currentSplitPos)
00134 {
00135 double weight = pData->getWeight(currentSplitPos->first, l);
00136
00137
00138
00139
00140
00141
00142 if ( pData->getClass(currentSplitPos->first) != l )
00143 tmpRightError += weight;
00144
00145 _weightsPerClass[l] += weight;
00146 }
00147
00148 _halfWeightsPerClass[l] = _weightsPerClass[l] / 2;
00149
00150 assert(tmpRightError < 1);
00151 _rightErrors[l] = tmpRightError;
00152 }
00153
00155
00156 currentSplitPos = dataBegin;
00157 endArray = dataEnd;
00158 --endArray;
00159
00160 double tmpError = 0;
00161 double currError = 0;
00162 double bestError = numeric_limits<double>::max();
00163
00164
00165 while (currentSplitPos != endArray)
00166 {
00167
00168
00169
00170 previousSplitPos = currentSplitPos;
00171 ++currentSplitPos;
00172
00173
00174 while ( previousSplitPos->second == currentSplitPos->second && currentSplitPos != endArray)
00175 {
00176 for (int l = 0; l < numClasses; ++l)
00177 {
00178 if ( pData->getClass( previousSplitPos->first ) == l )
00179 _leftErrors[l] += pData->getWeight(previousSplitPos->first, l);
00180 else
00181 _rightErrors[l] -= pData->getWeight(previousSplitPos->first, l);
00182 }
00183
00184 previousSplitPos = currentSplitPos;
00185 ++currentSplitPos;
00186 }
00187
00188 currError = 0;
00189
00190 for (int l = 0; l < numClasses; ++l)
00191 {
00192 if ( pData->getClass( previousSplitPos->first ) == l )
00193 {
00194
00195
00196 _leftErrors[l] += pData->getWeight(previousSplitPos->first, l);
00197 }
00198 else
00199 {
00200
00201
00202 _rightErrors[l] -= pData->getWeight(previousSplitPos->first, l);
00203 }
00204
00205 tmpError = _rightErrors[l] + _leftErrors[l];
00206
00207
00208 if(tmpError > _halfWeightsPerClass[l] + _smallVal)
00209 tmpError = _weightsPerClass[l] - tmpError;
00210
00211 currError += tmpError;
00212 assert(tmpError <= 0.5);
00213 }
00214
00215
00216 if (currError < bestError + _smallVal)
00217 {
00218 bestError = currError;
00219
00220 threshold = ( previousSplitPos->second + currentSplitPos->second ) / 2;
00221
00222 for (int l = 0; l < numClasses; ++l)
00223 {
00224 _bestErrors[l] = _rightErrors[l] + _leftErrors[l];
00225
00226
00227
00228
00229
00230
00231 if (_bestErrors[l] > _halfWeightsPerClass[l] + _smallVal)
00232 {
00233
00234 _bestErrors[l] = _weightsPerClass[l] - _bestErrors[l];
00235 v[l] = -1;
00236 }
00237 else
00238 v[l] = +1;
00239 }
00240
00241 }
00242
00243 }
00244
00246
00247
00248
00249 for (int l = 0; l < numClasses; ++l)
00250 {
00251 mu[l].classIdx = l;
00252
00253 mu[l].rPls = _weightsPerClass[l]-_bestErrors[l];
00254 mu[l].rMin = _bestErrors[l];
00255 mu[l].rZero = mu[l].rPls + mu[l].rMin;
00256 }
00257
00258 }
00259
00260
00261
00262 void SingleStumpLearner::save(ofstream& outputStream, const int numTabs)
00263 {
00264
00265 StumpLearner::save(outputStream, numTabs);
00266
00267
00268 outputStream << Serialization::standardTag("threshold", _threshold, numTabs) << endl;
00269 }
00270
00271
00272
00273 void SingleStumpLearner::load(nor_utils::StreamTokenizer& st)
00274 {
00275
00276 StumpLearner::load(st);
00277
00278 _threshold = UnSerialization::seekAndParseEnclosedValue<double>(st, "threshold");
00279
00280 }
00281
00282
00283
00284 }