KRR.cpp
Go to the documentation of this file.00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012 #include "lib/config.h"
00013
00014 #ifdef HAVE_LAPACK
00015 #include "regression/KRR.h"
00016 #include "lib/lapack.h"
00017 #include "lib/Mathematics.h"
00018
00019 CKRR::CKRR()
00020 : CKernelMachine()
00021 {
00022 alpha=NULL;
00023 tau=1e-6;
00024 }
00025
00026 CKRR::CKRR(DREAL t, CKernel* k, CLabels* lab)
00027 : CKernelMachine()
00028 {
00029 tau=t;
00030 set_labels(lab);
00031 set_kernel(k);
00032 alpha=NULL;
00033 }
00034
00035
00036 CKRR::~CKRR()
00037 {
00038 delete[] alpha;
00039 }
00040
00041 bool CKRR::train()
00042 {
00043 delete[] alpha;
00044
00045 ASSERT(labels);
00046 ASSERT(kernel && kernel->has_features());
00047
00048
00049 INT m=0;
00050 INT n=0;
00051 DREAL *K = kernel->get_kernel_matrix_real(m, n, NULL);
00052 ASSERT(K && m>0 && n>0);
00053
00054 for(int i=0; i < n; i++)
00055 K[i+i*n]+=tau;
00056
00057
00058 INT numlabels=0;
00059 alpha=labels->get_labels(numlabels);
00060 ASSERT(alpha && numlabels==n);
00061
00062 clapack_dposv(CblasRowMajor,CblasUpper, n, 1, K, n, alpha, n);
00063
00064 delete[] K;
00065 return true;
00066 }
00067
00068 bool CKRR::load(FILE* srcfile)
00069 {
00070 return false;
00071 }
00072
00073 bool CKRR::save(FILE* dstfile)
00074 {
00075 return false;
00076 }
00077
00078 CLabels* CKRR::classify(CLabels* output)
00079 {
00080 if (labels)
00081 {
00082 ASSERT(output==NULL);
00083 ASSERT(kernel);
00084
00085
00086 INT m=0;
00087 INT n=0;
00088 DREAL* K=kernel->get_kernel_matrix_real(m, n, NULL);
00089 ASSERT(K && m>0 && n>0);
00090 DREAL* Yh=new DREAL[n];
00091
00092
00093
00094
00095
00096 cblas_dgemv(CblasColMajor, CblasTrans, m, n, 1.0, K, m, alpha, 1, 0.0, Yh, 1);
00097
00098 delete[] K;
00099
00100 output=new CLabels(n);
00101 output->set_labels(Yh, n);
00102
00103 delete[] Yh;
00104
00105 return output;
00106 }
00107
00108 return NULL;
00109 }
00110
00111 DREAL CKRR::classify_example(INT num)
00112 {
00113 ASSERT(kernel);
00114
00115
00116 INT m=0;
00117 INT n=0;
00118
00119 DREAL* K=kernel->get_kernel_matrix_real(m, n, NULL);
00120 ASSERT(K && m>0 && n>0);
00121 DREAL Yh;
00122
00123
00124 Yh = CMath::dot(K + m*num, alpha, m);
00125
00126 delete[] K;
00127 return Yh;
00128 }
00129
00130 #endif