00001
00002
00003
00004
00005
00006
00007
00008
00009
00010
00011
00012
00013
00014
00015
00016
00017
00018
00019
00020 #ifndef _SSL_H
00021 #define _SSL_H
00022
00023
/* Iteration caps. */
#define CGITERMAX (10000)        /* presumably the default for options::cgitermax */
#define SMALL_CGITERMAX (10)
#define MFNITERMAX (50)          /* presumably the default for options::mfnitermax */

/* Tolerances / stopping thresholds. */
#define EPSILON (1e-6)           /* presumably the default for options::epsilon */
#define BIG_EPSILON (0.01)
#define RELATIVE_STOP_EPS (1e-9)

/* Constants prefixed TSVM_ (presumably used by TSVM_MFN). */
#define TSVM_ANNEALING_RATE (1.5)
#define TSVM_LAMBDA_SMALL (1e-5)

/* Constants prefixed DA_ (presumably used by DA_S3VM). */
#define DA_ANNEALING_RATE (1.5)
#define DA_INIT_TEMP (10)
#define DA_INNER_ITERMAX (100)
#define DA_OUTER_ITERMAX (30)
00036
00037 #include "lib/common.h"
00038 #include "features/SparseFeatures.h"
00039
/** Training data handed to ssl_train() and the solvers declared below.
 *
 * NOTE(review): the meanings of the integer counts are inferred from their
 * names and from how the prototypes in this header pass counts (m, l, u)
 * alongside arrays — confirm against the implementation.
 */
struct data
{
	/** presumably the total number of examples */
	int m;
	/** presumably the number of labeled examples */
	int l;
	/** presumably the number of unlabeled examples */
	int u;
	/** presumably the number of features (dimensionality) */
	int n;
	/** presumably the number of non-zero feature entries */
	int nz;

	/** sparse feature matrix holding the examples; ownership is not shown here */
	CSparseFeatures<DREAL>* features;
	/** label array — presumably one entry per example; confirm length */
	double *Y;
	/** per-example cost array — presumably one entry per example; confirm */
	double *C;
};
00061
/** Array of doubles paired with an integer count.
 *
 * NOTE(review): presumably d is the number of elements in vec and vec is
 * allocated by initialize(vector_double*, int, double) — confirm, including
 * who frees vec.
 */
struct vector_double
{
	/** element count — presumably the length of vec */
	int d;
	/** element storage */
	double *vec;
};
00070
/** Array of ints paired with an integer count.
 *
 * NOTE(review): presumably d is the number of elements in vec and vec is
 * allocated by initialize(vector_int*, int) — confirm, including who frees vec.
 */
struct vector_int
{
	/** element count — presumably the length of vec */
	int d;
	/** element storage */
	int *vec;
};
00079
/** Algorithm selectors — presumably the values stored in options::algo.
 *  The numbering is explicit so the mapping is visible at a glance.
 */
enum {
	RLS    = 0,
	SVM    = 1,
	TSVM   = 2, /* presumably selects TSVM_MFN */
	DA_SVM = 3  /* presumably selects DA_S3VM */
};
00081
/** Solver configuration passed to ssl_train() and the individual solvers
 * declared below (CGLS, L2_SVM_MFN, TSVM_MFN, DA_S3VM, optimize_w).
 *
 * NOTE(review): field meanings below are inferred from names and from
 * same-named parameters in this header — confirm against the implementation.
 */
struct options
{

	/** algorithm selector — presumably one of RLS, SVM, TSVM, DA_SVM */
	int algo;
	/** regularization weight; transductive_cost() takes a parameter of the same name */
	double lambda;
	/** presumably the weight of the unlabeled-data term; transductive_cost() takes a lambda_u parameter */
	double lambda_u;
	/** presumably the S passed to switch_labels() — confirm */
	int S;
	/** presumably the r passed to optimize_p() — confirm */
	double R;
	/** presumably the cost applied to positively labeled examples */
	double Cp;
	/** presumably the cost applied to negatively labeled examples */
	double Cn;


	/** convergence tolerance — presumably defaults to EPSILON */
	double epsilon;
	/** conjugate-gradient iteration cap — presumably defaults to CGITERMAX */
	int cgitermax;
	/** MFN solver iteration cap — presumably defaults to MFNITERMAX */
	int mfnitermax;

	/** bias setting; its exact role is not visible from this header — confirm */
	double bias;
};
00112
/** Record pairing a score (delta) with two integer fields (index, s);
 *  records are ordered by delta through the operator< defined just below.
 *
 *  NOTE(review): what index and s denote is not visible in this header —
 *  confirm against the solver sources that build Delta records.
 */
class Delta {
public:
	/** Construct with every field zeroed. */
	Delta() : delta(0.0), index(0), s(0) {}

	/** score used as the comparison key */
	double delta;
	/** integer index carried alongside the score */
	int index;
	/** integer tag carried alongside the score */
	int s;
};

/** Compare two Delta records by score: true iff lhs.delta < rhs.delta. */
inline bool operator<(const Delta& lhs, const Delta& rhs)
{
	return rhs.delta > lhs.delta;
}
00127
/** Initialize a vector_double.
 *
 * @param A vector to initialize
 * @param k presumably the number of elements to allocate (stored in A->d)
 * @param a presumably the value every element is set to
 *
 * NOTE(review): inferred from the signature only — confirm against the
 * definition, including who is responsible for freeing A->vec.
 */
void initialize(struct vector_double *A, int k, double a);

/** Initialize a vector_int.
 *
 * @param A vector to initialize
 * @param k presumably the number of elements to allocate (stored in A->d)
 *
 * NOTE(review): the initial element values are not visible from this
 * declaration — confirm against the definition.
 */
void initialize(struct vector_int *A, int k);

/** Extract the labeled part of a dataset (inferred from the name).
 *
 * @param Data_Labeled destination dataset (non-const; presumably filled in)
 * @param Data source dataset (read-only)
 *
 * NOTE(review): how labeled examples are selected is not visible here.
 */
void GetLabeledData(struct data *Data_Labeled, const struct data *Data);

/** Presumably the squared Euclidean norm of A's elements (inferred from the
 * name — confirm against the definition).
 *
 * @param A vector to measure (read-only)
 * @return the computed value
 */
double norm_square(const vector_double *A);
00135
00136
00137
00138
/** Presumably the top-level training entry point (inferred from the name).
 *
 * @param Data training data
 * @param Options solver configuration; non-const, so the call may modify it
 * @param W weight vector — presumably receives the learned weights
 * @param O output vector — presumably receives per-example outputs
 *
 * NOTE(review): presumably dispatches on Options->algo to CGLS, L2_SVM_MFN,
 * TSVM_MFN or DA_S3VM — confirm against the definition.
 */
void ssl_train(struct data *Data,
struct options *Options,
struct vector_double *W,
struct vector_double *O);
00143
00144
00145
00146
00147
00148
/** Presumably a conjugate-gradient least-squares (CGLS) solve restricted to
 * the examples listed in Subset (inferred from the names — confirm).
 *
 * @param Data training data (read-only)
 * @param Options solver configuration (read-only); presumably supplies
 *        cgitermax and epsilon
 * @param Subset presumably the indices of the examples to use
 * @param Weights weight vector (non-const; presumably initial guess and result)
 * @param Outputs output vector (non-const; presumably kept in sync with Weights)
 * @return status code; its meaning is not visible from this header
 */
int CGLS(const struct data *Data,
const struct options *Options,
const struct vector_int *Subset,
struct vector_double *Weights,
struct vector_double *Outputs);
00154
00155
00156
/** Presumably trains an L2-loss linear SVM with a modified finite Newton
 * (MFN) method — inferred from the name; confirm against the definition.
 *
 * @param Data training data (read-only)
 * @param Options solver configuration (non-const)
 * @param Weights weight vector (non-const; presumably in/out)
 * @param Outputs output vector (non-const; presumably in/out)
 * @param ini flag — presumably whether Weights/Outputs hold a warm start
 * @return status code; its meaning is not visible from this header
 */
int L2_SVM_MFN(const struct data *Data,
struct options *Options,
struct vector_double *Weights,
struct vector_double *Outputs,
int ini);
/** Presumably a step-size search between the current point (w, o) and a
 * candidate point (w_bar, o_bar) — inferred from names; confirm.
 *
 * @param w, w_bar weight arrays, presumably of length d
 * @param lambda regularization weight
 * @param o, o_bar output arrays, presumably of length l
 * @param Y, C label and cost arrays, presumably of length l
 * @param d presumably the weight dimension
 * @param l presumably the number of examples considered
 * @return presumably the chosen step size
 */
double line_search(double *w,
double *w_bar,
double lambda,
double *o,
double *o_bar,
double *Y,
double *C,
int d,
int l);
00171
00172
00173
00174
/** Presumably trains a transductive SVM (inferred from the name — confirm).
 *
 * @param Data training data (read-only)
 * @param Options solver configuration (non-const)
 * @param Weights weight vector (non-const; presumably in/out)
 * @param Outputs output vector (non-const; presumably in/out)
 * @return status code; its meaning is not visible from this header
 */
int TSVM_MFN(const struct data *Data,
struct options *Options,
struct vector_double *Weights,
struct vector_double *Outputs);
/** Presumably switches labels of unlabeled examples (inferred from the name).
 *
 * @param Y label array (non-const; presumably modified in place)
 * @param o output array
 * @param JU presumably the indices of the u unlabeled examples
 * @param u presumably the number of unlabeled examples
 * @param S presumably an upper bound on the number of switches
 * @return presumably the number of labels switched — confirm
 */
int switch_labels(double* Y, double* o, int* JU, int u, int S);
00180
00181
/** Presumably trains a semi-supervised SVM by annealing (inferred from the
 * name and the DA_* constants — confirm against the definition).
 *
 * @param Data training data (non-const)
 * @param Options solver configuration (non-const)
 * @param Weights weight vector (non-const; presumably in/out)
 * @param Outputs output vector (non-const; presumably in/out)
 * @return status code; its meaning is not visible from this header
 */
int DA_S3VM(struct data *Data,
struct options *Options,
struct vector_double *Weights,
struct vector_double *Outputs);
/** Presumably computes the probabilities p from g at temperature T
 * (inferred from names — confirm).
 *
 * @param g input array (read-only), presumably of length u
 * @param u number of entries
 * @param T presumably the annealing temperature
 * @param r presumably a target fraction constraining p
 * @param p output array (non-const), presumably of length u
 */
void optimize_p(const double* g, int u, double T, double r, double*p);
/** Presumably optimizes the weights for fixed probabilities p (inferred
 * from the name — confirm).
 *
 * @param Data training data (read-only)
 * @param p probabilities (read-only)
 * @param Options solver configuration (non-const)
 * @param Weights weight vector (non-const; presumably in/out)
 * @param Outputs output vector (non-const; presumably in/out)
 * @param ini flag — presumably whether Weights/Outputs hold a warm start
 * @return status code; its meaning is not visible from this header
 */
int optimize_w(const struct data *Data,
const double *p,
struct options *Options,
struct vector_double *Weights,
struct vector_double *Outputs,
int ini);
/** Presumably evaluates the transductive objective (inferred from the name).
 *
 * @param normWeights presumably the (squared) weight norm
 * @param Y label array, presumably of length m
 * @param Outputs output array, presumably of length m
 * @param m number of entries
 * @param lambda regularization weight
 * @param lambda_u presumably the weight of the unlabeled-data term
 * @return the computed cost
 */
double transductive_cost(double normWeights,double *Y, double *Outputs, int m, double lambda,double lambda_u);
/** Presumably the entropy of the u probabilities in p (inferred from name). */
double entropy(const double *p, int u);
/** Presumably the KL divergence between p and q, each with u entries
 * (inferred from the name — confirm argument order and log base).
 */
double KL(const double *p, const double *q, int u);
00196
00197 #endif