multi_class.h

Go to the documentation of this file.
00001 #include "hazy_sgd.h"
// NB:
// Classifier updates use labels in {-1, 1}, but reads return labels in {0, 1}.
// This mismatch is a known source of bugs; keep it in mind whenever you map
// between the two encodings.
00009 template <class T>
00013 class Node {   
00014  protected:
00016   Hazy_Sgd<T>* _classifier;  
00017  public:  
00023   virtual void single_entity_read(int entity_id, int &sClass) = 0;
00029   virtual void insert_example(int label, T& vec) = 0;
00033   virtual ~Node() { //delete _classifier; 
00034   }
00035 };
00036 
00040 template <class T>
00041 class MultiClass_Internal_Node : public Node<T> {
00042   Node<T> *l, *r;
00043   int search_key;
00049   int find_direction(int label) {return (label <= search_key) ? 0 : 1;}
00050  public:
00058  MultiClass_Internal_Node(int search_key, Node<T> *l, Node<T> *r, Hazy_Sgd<T>* s) :
00059   l(l), r(r), search_key(search_key) {
00060     this->_classifier = s;
00061     
00062   }
00068   void single_entity_read(int entity_id, int &sClass) {
00069     int this_class = 0;
00070     this->_classifier->readEntityClass(entity_id, this_class);
00071     if (this_class == 0) {
00072       l->single_entity_read(entity_id, sClass);
00073     } else {
00074       r->single_entity_read(entity_id, sClass);
00075     }
00076   }
00082   void insert_example(int label, T& vec) {
00083     int direction = find_direction(label);
00084     LOGGING_ONLY(Timer total_time(true););
00085     this->_classifier->updateModel(vec, (direction == 0) ? -1 : 1);
00086     LOGGING_ONLY(std::cout << "\tUpdating with " << label << " sk=" << search_key << " train time=" << total_time.stop() << " direction=" << direction << std::endl;)
00087       if(direction == 0) {
00088         l->insert_example(label, vec);
00089       } else {
00090         r->insert_example(label, vec);
00091       }
00092     LOGGING_ONLY(std::cout << "\t[Multiclass] Insert to children for " << search_key << " took " << total_time.stop() << std::endl;);
00093   }
00094 };
00095 
00099 template
00100 <class T>
00101 class MultiClass_Leaf_Node : public Node<T> {
00102   int global_label, local_label;
00103   
00104  public:
00111   MultiClass_Leaf_Node(int global_label, int local_label, Hazy_Sgd<T>* s) 
00112     { this->global_label = global_label; this->local_label = local_label; this->_classifier = s; }
00118   void single_entity_read(int entity_id, int &sClass) {sClass = global_label;}
00124   void insert_example(int label, T& vec) { this->_classifier->updateModel(vec, (local_label == 0) ? -1 : 1); } 
00125 };
00126 
00127 // models are indexed by search key which we assume are dense.
00136 template 
00137 <class T>
00138 Node<T>*
00139 construct_tree(Hazy_Sgd<T> **models, int low_value, int high_value, int parents_search_key, int parents_label_of_us) {
00140   assert(low_value <= high_value);
00141   if(low_value == high_value) {
00142     // here we are leaf
00143     return new MultiClass_Leaf_Node<T>(low_value, parents_label_of_us, models[parents_search_key]);
00144   } else {
00145     // build our children
00146     int mid_point = (high_value + low_value)/2;
00147     int our_search_key = mid_point;
00148     VERBOSE_ONLY(std::cout << "\t our_search_key = " << our_search_key << " [" << low_value << ", " << high_value << "] " << std::endl;);
00149     VERBOSE_ONLY(std::cout << "\t\tmidpoint = " << mid_point << std::endl;);
00150     VERBOSE_ONLY(std::cout << "\t\tparent=" << parents_search_key << " parents opinion=" << parents_label_of_us << std::endl;);
00151     Node<T> *l = construct_tree<T>(models, low_value  , mid_point , our_search_key, 0);
00152     Node<T> *r = construct_tree<T>(models, mid_point+1, high_value, our_search_key, 1);
00153     return new MultiClass_Internal_Node<T>(our_search_key, l, r, models[our_search_key]);
00154   }
00155 }
00156 
00162 template
00163 <class T>
00164 Node<T> *
00165 build_root_node(Hazy_Sgd<T> **models, int nLabels) {
00166   int parent_id = 1;
00167   // find the next highest power of two.
00168   while(parent_id < nLabels) { parent_id = parent_id << 1; }
00169   return construct_tree<T>(models, 0, nLabels - 1, parent_id, 0);
00170 }
// Example: for nLabels == 8 this calls construct_tree(models, 0, 7, 8, 0);
00172 
00173 
00179 template
00180 <class T>
00181 Node<T> *
00182 build_ova(Hazy_Sgd<T> **models, int nLabels) {
00183   DEBUG_ONLY(std::cout << "Building one-versus-all tree" << std::endl;);
00184   Node<T>** leaf = new Node<T>*[nLabels];
00185   for(int i = 0; i < nLabels - 1; i++) {
00186     leaf[i]   = new MultiClass_Leaf_Node<T>(i, 0, models[i]);  
00187   }
00188   leaf[nLabels-1] = new MultiClass_Leaf_Node<T>(nLabels-1, 1, models[nLabels-2]); 
00189   Node<T> *last_node = leaf[nLabels - 1];
00190   for(int i = nLabels - 2; i >= 0 ; i--) {
00191     Node<T> *this_node = new MultiClass_Internal_Node<T>(i, leaf[i], last_node, models[i]);
00192     last_node = this_node;
00193   }
00194   
00195   DEBUG_ONLY(std::cout << "Finished building one-versus-all tree" << std::endl;);
00196   return last_node;
00197 }

Generated on Wed Dec 15 10:46:15 2010 for Hazy_System by  doxygen 1.4.7