00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifndef _SSL_H
00021 #define _SSL_H
00022
00023
/*
 * Iteration caps, tolerances and annealing schedules for the solvers
 * declared in this header. The meanings below are inferred from the names;
 * the code that consumes them lives in the implementation file — confirm
 * there before relying on them.
 */

/* Conjugate-gradient least squares (CGLS). */
#define CGITERMAX         10000 /* presumably the default CG iteration cap */
#define SMALL_CGITERMAX   10    /* presumably a reduced CG iteration cap */

/* Convergence tolerances. */
#define EPSILON           1e-6
#define BIG_EPSILON       0.01
#define RELATIVE_STOP_EPS 1e-9

/* L2_SVM_MFN outer loop. */
#define MFNITERMAX        50

/* Transductive SVM (TSVM_MFN). */
#define TSVM_ANNEALING_RATE 1.5
#define TSVM_LAMBDA_SMALL   1e-5

/* Deterministic annealing (presumably DA_S3VM). Note DA_INIT_TEMP is an
 * integer literal; it only behaves as a real temperature in a floating
 * point expression. */
#define DA_ANNEALING_RATE 1.5
#define DA_INIT_TEMP      10
#define DA_INNER_ITERMAX  100
#define DA_OUTER_ITERMAX  30
00036
00037 #include "lib/common.h"
00038 #include "features/SparseFeatures.h"
00039
00041 struct data
00042 {
00044 int32_t m;
00046 int32_t l;
00048 int32_t u;
00050 int32_t n;
00052 int32_t nz;
00053
00055 CSparseFeatures<float64_t>* features;
00057 float64_t *Y;
00059 float64_t *C;
00060 };
00061
00063 struct vector_double
00064 {
00066 int32_t d;
00068 float64_t *vec;
00069 };
00070
/**
 * An int32_t buffer paired with its element count (for example the Subset
 * argument of CGLS). Who allocates and frees vec is not visible from this
 * header.
 */
struct vector_int
{
	int32_t d;    /**< number of entries, presumably valid in vec[0..d-1] */
	int32_t *vec; /**< element storage */
};
00079
/**
 * Solver identifiers. They are presumably stored in options::algo (an
 * int32_t) to choose the training algorithm — confirm in ssl_train.
 */
enum {
	RLS    = 0,
	SVM    = 1,
	TSVM   = 2,
	DA_SVM = 3
};
00081
00083 struct options
00084 {
00085
00087 int32_t algo;
00089 float64_t lambda;
00091 float64_t lambda_u;
00093 int32_t S;
00095 float64_t R;
00097 float64_t Cp;
00099 float64_t Cn;
00100
00101
00103 float64_t epsilon;
00105 int32_t cgitermax;
00107 int32_t mfnitermax;
00108
00110 float64_t bias;
00111 };
00112
00114 class Delta {
00115 public:
00117 Delta() { delta=0.0; index=0;s=0; }
00118
00120 float64_t delta;
00122 int32_t index;
00124 int32_t s;
00125 };
00126
00127 inline bool operator<(const Delta& a , const Delta& b)
00128 {
00129 return (a.delta < b.delta);
00130 }
00131
00132 void initialize(struct vector_double *A, int32_t k, float64_t a);
00133
00134 void initialize(struct vector_int *A, int32_t k);
00135
/**
 * Populate Data_Labeled from Data (Data is not modified). Presumably only
 * the labeled examples are kept; the ownership of anything written into
 * Data_Labeled is defined by the implementation.
 */
void GetLabeledData(
	struct data *Data_Labeled,
	const struct data *Data);
00137
00138 float64_t norm_square(const vector_double *A);
00139
00140
00141
00142
/**
 * Top-level training entry point.
 *
 * @param Data    training set
 * @param Options solver settings; presumably Options->algo picks the solver
 * @param W       receives the learned weight vector (presumably)
 * @param O       receives per-example outputs (presumably)
 */
void ssl_train(struct data *Data, struct options *Options,
	struct vector_double *W, struct vector_double *O);
00148
00149
00150
00151
00152
00153
/**
 * Presumably a conjugate-gradient least-squares solve restricted to the
 * examples listed in Subset. Data, Options and Subset are read-only;
 * Weights and Outputs are updated in place (presumably warm start in,
 * solution out). The meaning of the return code is defined by the
 * implementation — TODO document.
 */
int32_t CGLS(const struct data *Data, const struct options *Options,
	const struct vector_int *Subset,
	struct vector_double *Weights, struct vector_double *Outputs);
00160
00161
00162
/**
 * Train a linear L2-loss SVM (per the name; "MFN" presumably a modified
 * finite Newton method). Options is non-const, so the solver may update
 * it. Weights and Outputs are updated in place. ini presumably selects
 * whether the incoming Weights/Outputs are used as a starting point —
 * confirm in the implementation.
 */
int32_t L2_SVM_MFN(const struct data *Data, struct options *Options,
	struct vector_double *Weights, struct vector_double *Outputs,
	int32_t ini);
00169
00170 float64_t line_search(
00171 float64_t *w,
00172 float64_t *w_bar,
00173 float64_t lambda,
00174 float64_t *o,
00175 float64_t *o_bar,
00176 float64_t *Y,
00177 float64_t *C,
00178 int32_t d,
00179 int32_t l);
00180
00181
00182
00183
/**
 * Transductive SVM training (per the name), presumably built on
 * L2_SVM_MFN. Options may be modified; Weights and Outputs are updated in
 * place. The return-code meaning is defined by the implementation.
 */
int32_t TSVM_MFN(const struct data *Data, struct options *Options,
	struct vector_double *Weights, struct vector_double *Outputs);
00189
00190 int32_t switch_labels(
00191 float64_t* Y,
00192 float64_t* o,
00193 int32_t* JU,
00194 int32_t u,
00195 int32_t S);
00196
00197
/**
 * Semi-supervised SVM via deterministic annealing (per the name; cf. the
 * DA_* constants). Unlike the other solvers, Data is non-const here, so it
 * may be modified. Weights and Outputs are updated in place.
 */
int32_t DA_S3VM(struct data *Data, struct options *Options,
	struct vector_double *Weights, struct vector_double *Outputs);
00203
00204 void optimize_p(
00205 const float64_t* g, int32_t u, float64_t T, float64_t r, float64_t*p);
00206
00207 int32_t optimize_w(
00208 const struct data *Data,
00209 const float64_t *p,
00210 struct options *Options,
00211 struct vector_double *Weights,
00212 struct vector_double *Outputs,
00213 int32_t ini);
00214
00215 float64_t transductive_cost(
00216 float64_t normWeights,
00217 float64_t *Y,
00218 float64_t *Outputs,
00219 int32_t m,
00220 float64_t lambda,
00221 float64_t lambda_u);
00222
00223 float64_t entropy(const float64_t *p, int32_t u);
00224
00225
00226 float64_t KL(const float64_t *p, const float64_t *q, int32_t u);
00227
00228 #endif