incremental_sgd.hxx

Go to the documentation of this file.
00001 //#include "incremental_sgd.h"
00002 // This file is included in the header function.
00004 #define HINGELOSS 1
00005 #define SMOOTHHINGELOSS 2
00006 #define SQUAREDHINGELOSS 3
00007 #define LOGLOSS 10
00008 #define LOGLOSSMARGIN 11
00009 
00011 #define LOSS HINGELOSS
00012 
00015 #define BIAS 1
00016 
00017 
00021 inline
00022 double loss(double z)
00023 {
00024 #if LOSS == LOGLOSS
00025   if (z > 18)
00026     return exp(-z);
00027   if (z < -18)
00028     return -z;
00029   return log(1+exp(-z));
00030 #elif LOSS == LOGLOSSMARGIN
00031   if (z > 18)
00032     return exp(1-z);
00033   if (z < -18)
00034     return 1-z;
00035   return log(1+exp(1-z));
00036 #elif LOSS == SMOOTHHINGELOSS
00037   if (z < 0)
00038     return 0.5 - z;
00039   if (z < 1)
00040     return 0.5 * (1-z) * (1-z);
00041   return 0;
00042 #elif LOSS == SQUAREDHINGELOSS
00043   if (z < 1)
00044     return 0.5 * (1 - z) * (1 - z);
00045   return 0;
00046 #elif LOSS == HINGELOSS
00047   if (z < 1)
00048     return 1 - z;
00049   return 0;
00050 #else
00051 # error "Undefined loss"
00052 #endif
00053 }
00054 
00055 
00059 inline 
00060 double dloss(double z)
00061 { 
00062 #if LOSS == LOGLOSS
00063   if (z > 18)
00064     return exp(-z);
00065   if (z < -18)
00066     return 1;
00067   return 1 / (exp(z) + 1);
00068 #elif LOSS == LOGLOSSMARGIN
00069   if (z > 18)
00070     return exp(1-z);
00071   if (z < -18)
00072     return 1;
00073   return 1 / (exp(z-1) + 1);
00074 #elif LOSS == SMOOTHHINGELOSS
00075   if (z < 0)
00076     return 1;
00077   if (z < 1)
00078     return 1-z;
00079   return 0;
00080 #elif LOSS == SQUAREDHINGELOSS
00081   if (z < 1)
00082     return (1 - z);
00083   return 0;
00084 #else
00085   if (z < 1)
00086     return 1;
00087   return 0;
00088 #endif
00089 }
00090 
00096 template<class T>
00097 IncrementalSGD<T>::
00098 IncrementalSGD(int dim, double l) : lambda(l),_m(dim) { 
00099   // Shift t in order to have a
00100   // reasonable initial learning rate.
00101   // This assumes |x| \approx 1.
00102   double maxw = 1.0 / sqrt(lambda);
00103   double typw = sqrt(maxw);
00104   double eta0 = typw / std::max(1.0,dloss(-typw));
00105   _m.t = 1 / (eta0 * lambda);    
00106 } 
00107 
00108 template<class T>
00109 void
00110 IncrementalSGD<T>::resetModel() {
00111   _m.resetModel();
00112   double maxw = 1.0 / sqrt(lambda);
00113   double typw = sqrt(maxw);
00114   double eta0 = typw / std::max(1.0,dloss(-typw));
00115   _m.t = 1 / (eta0 * lambda);    
00116 }
00117 
00123 template<class T>
00124 void
00125 IncrementalSGD<T>::oneStep(double y, T x) {
00126   double eta = 1.0 / (lambda * _m.t);
00127   double s = 1 - eta * lambda;
00128   _m.wscale *= s;
00129   if (_m.wscale < 1e-9) {
00130     _m.w.scale(_m.wscale);
00131     _m.wscale = 1;
00132   }
00133   double wx = dot(_m.w,x) * _m.wscale;
00134   double z = y * (wx + _m.bias);
00135 #if LOSS < LOGLOSS
00136   if (z < 1)
00137 #endif
00138     {
00139       double etd = eta * dloss(z);
00140       _m.w.add(x, etd * y / _m.wscale);
00141 #if BIAS
00142       // Slower rate on the bias because
00143       // it learns at each iteration.
00144       _m.bias += etd * y * 0.01;
00145 #endif
00146     }
00147   _m.t += 1;
00148 }
00149 
00155 template<class T>
00156 bool
00157 IncrementalSGD<T>::classifyExample(T x) {
00158   double wx = dot(_m.w,x) * _m.wscale;
00159   double product = wx + _m.bias;
00160   return product > 0;
00161 }
00162 
00163 
00170 // TrueSGD
00171 template<class T>
00172 bool
00173 TrueSGD<T>::addExample(int y, T ex) { oneStep(y,ex); return true; }
00174 
00181 template<class T>
00182 bool 
00183 ReservoirSGD<T>::addExample(int y, T ex)
00184 { 
00185   _r.addExample(y,ex);
00186   if(!_r.isFull()) { return false; }      
00187   // otherwise we have work to do.
00188   for(int i = 0; i < RESERVOIR_ITERATIONS; i ++) {
00189     for(int j = 0; j < _r.getCurrentSize(); j ++) {
00190       oneStep(_r.getExampleClass(j), _r.getExampleFeatureVector(j));        
00191     }
00192   }
00193   
00194   return true;
00195 }

Generated on Wed Dec 15 10:46:15 2010 for Hazy_System by  doxygen 1.4.7