KRR.cpp
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include "lib/config.h"
00013
00014 #ifdef HAVE_LAPACK
00015 #include "regression/KRR.h"
00016 #include "lib/lapack.h"
00017 #include "lib/Mathematics.h"
00018
00019 CKRR::CKRR()
00020 : CKernelMachine()
00021 {
00022 alpha=NULL;
00023 tau=1e-6;
00024 }
00025
00026 CKRR::CKRR(float64_t t, CKernel* k, CLabels* lab)
00027 : CKernelMachine()
00028 {
00029 tau=t;
00030 set_labels(lab);
00031 set_kernel(k);
00032 alpha=NULL;
00033 }
00034
00035
00036 CKRR::~CKRR()
00037 {
00038 delete[] alpha;
00039 }
00040
00041 bool CKRR::train()
00042 {
00043 delete[] alpha;
00044
00045 ASSERT(labels);
00046 ASSERT(kernel && kernel->has_features());
00047
00048
00049 int32_t m=0;
00050 int32_t n=0;
00051 float64_t *K = kernel->get_kernel_matrix_real(m, n, NULL);
00052 ASSERT(K && m>0 && n>0);
00053
00054 for(int32_t i=0; i < n; i++)
00055 K[i+i*n]+=tau;
00056
00057
00058 int32_t numlabels=0;
00059 alpha=labels->get_labels(numlabels);
00060 ASSERT(alpha && numlabels==n);
00061
00062 clapack_dposv(CblasRowMajor,CblasUpper, n, 1, K, n, alpha, n);
00063
00064 delete[] K;
00065 return true;
00066 }
00067
00068 bool CKRR::load(FILE* srcfile)
00069 {
00070 return false;
00071 }
00072
00073 bool CKRR::save(FILE* dstfile)
00074 {
00075 return false;
00076 }
00077
00078 CLabels* CKRR::classify(CLabels* output)
00079 {
00080 if (labels)
00081 {
00082 ASSERT(output==NULL);
00083 ASSERT(kernel);
00084
00085
00086 int32_t m=0;
00087 int32_t n=0;
00088 float64_t* K=kernel->get_kernel_matrix_real(m, n, NULL);
00089 ASSERT(K && m>0 && n>0);
00090 float64_t* Yh=new float64_t[n];
00091
00092
00093
00094
00095
00096 int m_int = (int) m;
00097 int n_int = (int) n;
00098 cblas_dgemv(CblasColMajor, CblasTrans, m_int, n_int, 1.0, (double*) K,
00099 m_int, (double*) alpha, 1, 0.0, (double*) Yh, 1);
00100
00101 delete[] K;
00102
00103 output=new CLabels(n);
00104 output->set_labels(Yh, n);
00105
00106 delete[] Yh;
00107
00108 return output;
00109 }
00110
00111 return NULL;
00112 }
00113
00114 float64_t CKRR::classify_example(int32_t num)
00115 {
00116 ASSERT(kernel);
00117
00118
00119 int32_t m=0;
00120 int32_t n=0;
00121
00122 float64_t* K=kernel->get_kernel_matrix_real(m, n, NULL);
00123 ASSERT(K && m>0 && n>0);
00124 float64_t Yh;
00125
00126
00127 Yh = CMath::dot(K + m*num, alpha, m);
00128
00129 delete[] K;
00130 return Yh;
00131 }
00132
00133 #endif