Hierarchical.cpp

Go to the documentation of this file.
00001 /*
00002  * This program is free software; you can redistribute it and/or modify
00003  * it under the terms of the GNU General Public License as published by
00004  * the Free Software Foundation; either version 3 of the License, or
00005  * (at your option) any later version.
00006  *
00007  * Written (W) 2007-2008 Soeren Sonnenburg
00008  * Copyright (C) 2007-2008 Fraunhofer Institute FIRST and Max-Planck-Society
00009  */
00010 
00011 #include "clustering/Hierarchical.h"
00012 #include "distance/Distance.h"
00013 #include "features/Labels.h"
00014 #include "features/RealFeatures.h"
00015 #include "lib/Mathematics.h"
00016 #include "base/Parallel.h"
00017 
00018 #ifndef WIN32
00019 #include <pthread.h>
00020 #endif
00021 
/** Pair of vector indices; train() creates one per computed distance so the
 * two endpoints can be recovered after the distances have been sorted.
 */
struct pair
{
    /** index of the first vector of the pair */
    int32_t idx1;

    /** index of the second vector of the pair */
    int32_t idx2;
};
00029 
00030 CHierarchical::CHierarchical()
00031 : CDistanceMachine(), merges(3), dimensions(0), assignment(NULL),
00032     table_size(0), pairs(NULL), merge_distance(NULL)
00033 {
00034 }
00035 
00036 CHierarchical::CHierarchical(int32_t merges_, CDistance* d)
00037 : CDistanceMachine(), merges(merges_), dimensions(0), assignment(NULL),
00038     table_size(0), pairs(NULL), merge_distance(NULL)
00039 {
00040     set_distance(d);
00041 }
00042 
00043 CHierarchical::~CHierarchical()
00044 {
00045     delete[] merge_distance;
00046     delete[] assignment;
00047     delete[] pairs;
00048 }
00049 
00050 bool CHierarchical::train()
00051 {
00052     ASSERT(distance);
00053     CFeatures* lhs=distance->get_lhs();
00054     ASSERT(lhs);
00055 
00056     int32_t num=lhs->get_num_vectors();
00057     ASSERT(num>0);
00058 
00059     const int32_t num_pairs=num*(num-1)/2;
00060 
00061     delete[] merge_distance;
00062     merge_distance=new float64_t[num];
00063     CMath::fill_vector(merge_distance, num, -1.0);
00064 
00065     delete[] assignment;
00066     assignment=new int32_t[num];
00067     CMath::range_fill_vector(assignment, num);
00068 
00069     delete[] pairs;
00070     pairs=new int32_t[2*num];
00071     CMath::fill_vector(pairs, 2*num, -1);
00072 
00073     pair* index=new pair[num_pairs];
00074     float64_t* distances=new float64_t[num_pairs];
00075 
00076     int32_t offs=0;
00077     for (int32_t i=0; i<num; i++)
00078     {
00079         for (int32_t j=i+1; j<num; j++)
00080         {
00081             distances[offs]=distance->distance(i,j);
00082             index[offs].idx1=i;
00083             index[offs].idx2=j;
00084             offs++;                 //offs=i*(i+1)/2+j
00085         }
00086         SG_PROGRESS(i, 0, num-1);
00087     }
00088 
00089     CMath::qsort_index<float64_t,pair>(distances, index, (num-1)*num/2);
00090     //CMath::display_vector(distances, (num-1)*num/2, "dists");
00091 
00092     int32_t k=-1;
00093     int32_t l=0;
00094     for (; l<num && (num-l)>=merges && k<num_pairs-1; l++)
00095     {
00096         while (k<num_pairs-1)
00097         {
00098             k++;
00099 
00100             int32_t i=index[k].idx1;
00101             int32_t j=index[k].idx2;
00102             int32_t c1=assignment[i];
00103             int32_t c2=assignment[j];
00104 
00105             if (c1==c2)
00106                 continue;
00107             
00108             SG_PROGRESS(k, 0, num_pairs-1);
00109 
00110             if (c1<c2)
00111             {
00112                 pairs[2*l]=c1;
00113                 pairs[2*l+1]=c2;
00114             }
00115             else
00116             {
00117                 pairs[2*l]=c2;
00118                 pairs[2*l+1]=c1;
00119             }
00120             merge_distance[l]=distances[k];
00121 
00122             int32_t c=num+l;
00123             for (int32_t m=0; m<num; m++)
00124             {
00125                 if (assignment[m] == c1 || assignment[m] == c2)
00126                     assignment[m] = c;
00127             }
00128 #ifdef DEBUG_HIERARCHICAL
00129             SG_PRINT("l=%04i i=%04i j=%04i c1=%+04d c2=%+04d c=%+04d dist=%6.6f\n", l,i,j, c1,c2,c, merge_distance[l]);
00130 #endif
00131             break;
00132         }
00133     }
00134 
00135     assignment_size=num;
00136     table_size=l-1;
00137     ASSERT(table_size>0);
00138     delete[] distances;
00139     delete[] index;
00140     SG_UNREF(lhs)
00141 
00142     return true;
00143 }
00144 
00145 bool CHierarchical::load(FILE* srcfile)
00146 {
00147     return false;
00148 }
00149 
00150 bool CHierarchical::save(FILE* dstfile)
00151 {
00152     return false;
00153 }

SHOGUN Machine Learning Toolbox - Documentation