00001 #include "hazy_sgd.h"
00006
00007
00008
00009 template <class T>
00013 class Node {
00014 protected:
00016 Hazy_Sgd<T>* _classifier;
00017 public:
00023 virtual void single_entity_read(int entity_id, int &sClass) = 0;
00029 virtual void insert_example(int label, T& vec) = 0;
00033 virtual ~Node() {
00034 }
00035 };
00036
00040 template <class T>
00041 class MultiClass_Internal_Node : public Node<T> {
00042 Node<T> *l, *r;
00043 int search_key;
00049 int find_direction(int label) {return (label <= search_key) ? 0 : 1;}
00050 public:
00058 MultiClass_Internal_Node(int search_key, Node<T> *l, Node<T> *r, Hazy_Sgd<T>* s) :
00059 l(l), r(r), search_key(search_key) {
00060 this->_classifier = s;
00061
00062 }
00068 void single_entity_read(int entity_id, int &sClass) {
00069 int this_class = 0;
00070 this->_classifier->readEntityClass(entity_id, this_class);
00071 if (this_class == 0) {
00072 l->single_entity_read(entity_id, sClass);
00073 } else {
00074 r->single_entity_read(entity_id, sClass);
00075 }
00076 }
00082 void insert_example(int label, T& vec) {
00083 int direction = find_direction(label);
00084 LOGGING_ONLY(Timer total_time(true););
00085 this->_classifier->updateModel(vec, (direction == 0) ? -1 : 1);
00086 LOGGING_ONLY(std::cout << "\tUpdating with " << label << " sk=" << search_key << " train time=" << total_time.stop() << " direction=" << direction << std::endl;)
00087 if(direction == 0) {
00088 l->insert_example(label, vec);
00089 } else {
00090 r->insert_example(label, vec);
00091 }
00092 LOGGING_ONLY(std::cout << "\t[Multiclass] Insert to children for " << search_key << " took " << total_time.stop() << std::endl;);
00093 }
00094 };
00095
00099 template
00100 <class T>
00101 class MultiClass_Leaf_Node : public Node<T> {
00102 int global_label, local_label;
00103
00104 public:
00111 MultiClass_Leaf_Node(int global_label, int local_label, Hazy_Sgd<T>* s)
00112 { this->global_label = global_label; this->local_label = local_label; this->_classifier = s; }
00118 void single_entity_read(int entity_id, int &sClass) {sClass = global_label;}
00124 void insert_example(int label, T& vec) { this->_classifier->updateModel(vec, (local_label == 0) ? -1 : 1); }
00125 };
00126
00127
00136 template
00137 <class T>
00138 Node<T>*
00139 construct_tree(Hazy_Sgd<T> **models, int low_value, int high_value, int parents_search_key, int parents_label_of_us) {
00140 assert(low_value <= high_value);
00141 if(low_value == high_value) {
00142
00143 return new MultiClass_Leaf_Node<T>(low_value, parents_label_of_us, models[parents_search_key]);
00144 } else {
00145
00146 int mid_point = (high_value + low_value)/2;
00147 int our_search_key = mid_point;
00148 VERBOSE_ONLY(std::cout << "\t our_search_key = " << our_search_key << " [" << low_value << ", " << high_value << "] " << std::endl;);
00149 VERBOSE_ONLY(std::cout << "\t\tmidpoint = " << mid_point << std::endl;);
00150 VERBOSE_ONLY(std::cout << "\t\tparent=" << parents_search_key << " parents opinion=" << parents_label_of_us << std::endl;);
00151 Node<T> *l = construct_tree<T>(models, low_value , mid_point , our_search_key, 0);
00152 Node<T> *r = construct_tree<T>(models, mid_point+1, high_value, our_search_key, 1);
00153 return new MultiClass_Internal_Node<T>(our_search_key, l, r, models[our_search_key]);
00154 }
00155 }
00156
00162 template
00163 <class T>
00164 Node<T> *
00165 build_root_node(Hazy_Sgd<T> **models, int nLabels) {
00166 int parent_id = 1;
00167
00168 while(parent_id < nLabels) { parent_id = parent_id << 1; }
00169 return construct_tree<T>(models, 0, nLabels - 1, parent_id, 0);
00170 }
00171
00172
00173
00179 template
00180 <class T>
00181 Node<T> *
00182 build_ova(Hazy_Sgd<T> **models, int nLabels) {
00183 DEBUG_ONLY(std::cout << "Building one-versus-all tree" << std::endl;);
00184 Node<T>** leaf = new Node<T>*[nLabels];
00185 for(int i = 0; i < nLabels - 1; i++) {
00186 leaf[i] = new MultiClass_Leaf_Node<T>(i, 0, models[i]);
00187 }
00188 leaf[nLabels-1] = new MultiClass_Leaf_Node<T>(nLabels-1, 1, models[nLabels-2]);
00189 Node<T> *last_node = leaf[nLabels - 1];
00190 for(int i = nLabels - 2; i >= 0 ; i--) {
00191 Node<T> *this_node = new MultiClass_Internal_Node<T>(i, leaf[i], last_node, models[i]);
00192 last_node = this_node;
00193 }
00194
00195 DEBUG_ONLY(std::cout << "Finished building one-versus-all tree" << std::endl;);
00196 return last_node;
00197 }