IT++ Logo

mog_diag_em.cpp

Go to the documentation of this file.
00001 
#include <itpp/stat/mog_diag_em.h>
#include <itpp/base/math/log_exp.h>
#include <itpp/base/timing.h>

#include <cmath>
#include <iomanip>
#include <iostream>
#include <limits>
00037 
00038 namespace itpp
00039 {
00040 
00042 void inline MOG_diag_EM_sup::update_internals()
00043 {
00044 
00045   double Ddiv2_log_2pi = D / 2.0 * std::log(m_2pi);
00046 
00047   for (int k = 0;k < K;k++)  c_log_weights[k] = std::log(c_weights[k]);
00048 
00049   for (int k = 0;k < K;k++) {
00050     double acc = 0.0;
00051     double * c_diag_cov = c_diag_covs[k];
00052     double * c_diag_cov_inv_etc = c_diag_covs_inv_etc[k];
00053 
00054     for (int d = 0;d < D;d++) {
00055       double tmp = c_diag_cov[d];
00056       c_diag_cov_inv_etc[d] = 1.0 / (2.0 * tmp);
00057       acc += std::log(tmp);
00058     }
00059 
00060     c_log_det_etc[k] = -Ddiv2_log_2pi - 0.5 * acc;
00061   }
00062 
00063 }
00064 
00065 
00067 void inline MOG_diag_EM_sup::sanitise_params()
00068 {
00069 
00070   double acc = 0.0;
00071   for (int k = 0;k < K;k++) {
00072     if (c_weights[k] < weight_floor)  c_weights[k] = weight_floor;
00073     if (c_weights[k] > 1.0)  c_weights[k] = 1.0;
00074     acc += c_weights[k];
00075   }
00076   for (int k = 0;k < K;k++)  c_weights[k] /= acc;
00077 
00078   for (int k = 0;k < K;k++)
00079     for (int d = 0;d < D;d++)
00080       if (c_diag_covs[k][d] < var_floor)  c_diag_covs[k][d] = var_floor;
00081 
00082 }
00083 
00085 double MOG_diag_EM_sup::ml_update_params()
00086 {
00087 
00088   double acc_loglhood = 0.0;
00089 
00090   for (int k = 0;k < K;k++)  {
00091     c_acc_loglhood_K[k] = 0.0;
00092 
00093     double * c_acc_mean = c_acc_means[k];
00094     double * c_acc_cov  = c_acc_covs[k];
00095 
00096     for (int d = 0;d < D;d++) { c_acc_mean[d] = 0.0; c_acc_cov[d] = 0.0; }
00097   }
00098 
00099   for (int n = 0;n < N;n++) {
00100     double * c_x =  c_X[n];
00101 
00102     bool danger = paranoid;
00103     for (int k = 0;k < K;k++)  {
00104       double tmp = c_log_weights[k] + MOG_diag::log_lhood_single_gaus_internal(c_x, k);
00105       c_tmpvecK[k] = tmp;
00106       if (tmp >= log_max_K)  danger = true;
00107     }
00108 
00109     if (danger) {
00110 
00111       double log_sum = c_tmpvecK[0];
00112       for (int k = 1;k < K;k++)  log_sum = log_add(log_sum, c_tmpvecK[k]);
00113       acc_loglhood += log_sum;
00114 
00115       for (int k = 0;k < K;k++) {
00116 
00117         double * c_acc_mean = c_acc_means[k];
00118         double * c_acc_cov = c_acc_covs[k];
00119 
00120         double tmp_k = trunc_exp(c_tmpvecK[k] - log_sum);
00121         acc_loglhood_K[k] += tmp_k;
00122 
00123         for (int d = 0;d < D;d++) {
00124           double tmp_x = c_x[d];
00125           c_acc_mean[d] +=  tmp_k * tmp_x;
00126           c_acc_cov[d] += tmp_k * tmp_x * tmp_x;
00127         }
00128       }
00129     }
00130     else {
00131 
00132       double sum = 0.0;
00133       for (int k = 0;k < K;k++) { double tmp = std::exp(c_tmpvecK[k]); c_tmpvecK[k] = tmp; sum += tmp; }
00134       acc_loglhood += std::log(sum);
00135 
00136       for (int k = 0;k < K;k++) {
00137 
00138         double * c_acc_mean = c_acc_means[k];
00139         double * c_acc_cov = c_acc_covs[k];
00140 
00141         double tmp_k = c_tmpvecK[k] / sum;
00142         c_acc_loglhood_K[k] += tmp_k;
00143 
00144         for (int d = 0;d < D;d++) {
00145           double tmp_x = c_x[d];
00146           c_acc_mean[d] +=  tmp_k * tmp_x;
00147           c_acc_cov[d] += tmp_k * tmp_x * tmp_x;
00148         }
00149       }
00150     }
00151   }
00152 
00153   for (int k = 0;k < K;k++) {
00154 
00155     double * c_mean = c_means[k];
00156     double * c_diag_cov = c_diag_covs[k];
00157 
00158     double * c_acc_mean = c_acc_means[k];
00159     double * c_acc_cov = c_acc_covs[k];
00160 
00161     double tmp_k = c_acc_loglhood_K[k];
00162 
00163     c_weights[k] = tmp_k / N;
00164 
00165     for (int d = 0;d < D;d++) {
00166       double tmp_mean = c_acc_mean[d] / tmp_k;
00167       c_mean[d] = tmp_mean;
00168       c_diag_cov[d] = c_acc_cov[d] / tmp_k - tmp_mean * tmp_mean;
00169     }
00170   }
00171 
00172   return(acc_loglhood / N);
00173 
00174 }
00175 
00176 
00177 void MOG_diag_EM_sup::ml_iterate()
00178 {
00179   using std::cout;
00180   using std::endl;
00181   using std::setw;
00182   using std::showpos;
00183   using std::noshowpos;
00184   using std::scientific;
00185   using std::fixed;
00186   using std::flush;
00187   using std::setprecision;
00188 
00189   double avg_log_lhood_old = -1.0 * std::numeric_limits<double>::max();
00190 
00191   Real_Timer tt;
00192 
00193   if (verbose) {
00194     cout << "MOG_diag_EM_sup::ml_iterate()" << endl;
00195     cout << setw(14) << "iteration";
00196     cout << setw(14) << "avg_loglhood";
00197     cout << setw(14) << "delta";
00198     cout << setw(10) << "toc";
00199     cout << endl;
00200   }
00201 
00202   for (int i = 0; i < max_iter; i++) {
00203     sanitise_params();
00204     update_internals();
00205 
00206     if (verbose) tt.tic();
00207     double avg_log_lhood_new = ml_update_params();
00208 
00209     if (verbose) {
00210       double delta = avg_log_lhood_new - avg_log_lhood_old;
00211 
00212       cout << noshowpos << fixed;
00213       cout << setw(14) << i;
00214       cout << showpos << scientific << setprecision(3);
00215       cout << setw(14) << avg_log_lhood_new;
00216       cout << setw(14) << delta;
00217       cout << noshowpos << fixed;
00218       cout << setw(10) << tt.toc();
00219       cout << endl << flush;
00220     }
00221 
00222     if (avg_log_lhood_new <= avg_log_lhood_old)  break;
00223 
00224     avg_log_lhood_old = avg_log_lhood_new;
00225   }
00226 }
00227 
00228 
00229 void MOG_diag_EM_sup::ml(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in)
00230 {
00231 
00232   it_assert(model_in.is_valid(), "MOG_diag_EM_sup::ml(): initial model not valid");
00233   it_assert(check_array_uniformity(X_in), "MOG_diag_EM_sup::ml(): 'X' is empty or contains vectors of varying dimensionality");
00234   it_assert((max_iter_in > 0), "MOG_diag_EM_sup::ml(): 'max_iter' needs to be greater than zero");
00235 
00236   verbose = verbose_in;
00237 
00238   N = X_in.size();
00239 
00240   Array<vec> means_in = model_in.get_means();
00241   Array<vec> diag_covs_in = model_in.get_diag_covs();
00242   vec weights_in = model_in.get_weights();
00243 
00244   init(means_in, diag_covs_in, weights_in);
00245 
00246   means_in.set_size(0);
00247   diag_covs_in.set_size(0);
00248   weights_in.set_size(0);
00249 
00250   if (K > N) {
00251     it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N");
00252   }
00253   else {
00254     if (K > N / 10) {
00255       it_warning("MOG_diag_EM_sup::ml(): WARNING: K > N/10");
00256     }
00257   }
00258 
00259   var_floor = var_floor_in;
00260   weight_floor = weight_floor_in;
00261 
00262   const double tiny = std::numeric_limits<double>::min();
00263   if (var_floor < tiny) var_floor = tiny;
00264   if (weight_floor < tiny) weight_floor = tiny;
00265   if (weight_floor > 1.0 / K) weight_floor = 1.0 / K;
00266 
00267   max_iter = max_iter_in;
00268 
00269   tmpvecK.set_size(K);
00270   tmpvecD.set_size(D);
00271   acc_loglhood_K.set_size(K);
00272 
00273   acc_means.set_size(K);
00274   for (int k = 0;k < K;k++) acc_means(k).set_size(D);
00275   acc_covs.set_size(K);
00276   for (int k = 0;k < K;k++) acc_covs(k).set_size(D);
00277 
00278   c_X = enable_c_access(X_in);
00279   c_tmpvecK = enable_c_access(tmpvecK);
00280   c_tmpvecD = enable_c_access(tmpvecD);
00281   c_acc_loglhood_K = enable_c_access(acc_loglhood_K);
00282   c_acc_means = enable_c_access(acc_means);
00283   c_acc_covs = enable_c_access(acc_covs);
00284 
00285   ml_iterate();
00286 
00287   model_in.init(means, diag_covs, weights);
00288 
00289   disable_c_access(c_X);
00290   disable_c_access(c_tmpvecK);
00291   disable_c_access(c_tmpvecD);
00292   disable_c_access(c_acc_loglhood_K);
00293   disable_c_access(c_acc_means);
00294   disable_c_access(c_acc_covs);
00295 
00296 
00297   tmpvecK.set_size(0);
00298   tmpvecD.set_size(0);
00299   acc_loglhood_K.set_size(0);
00300   acc_means.set_size(0);
00301   acc_covs.set_size(0);
00302 
00303   cleanup();
00304 
00305 }
00306 
//! Placeholder for MAP training (presumably maximum a-posteriori adaptation
//! of a model towards a prior model). Not implemented: every call reports
//! "MOG_diag_EM_sup::map(): not implemented yet" through it_error().
//! Parameters are left unnamed because they are unused; their meaning is
//! given by the declaration in the header (not visible here).
void MOG_diag_EM_sup::map(MOG_diag &, MOG_diag &, Array<vec> &, int, double,
                          double, double, bool)
{
  it_error("MOG_diag_EM_sup::map(): not implemented yet");
}
00312 
00313 
00314 //
00315 // convenience functions
00316 
00317 void MOG_diag_ML(MOG_diag &model_in, Array<vec> &X_in, int max_iter_in, double var_floor_in, double weight_floor_in, bool verbose_in)
00318 {
00319   MOG_diag_EM_sup EM;
00320   EM.ml(model_in, X_in, max_iter_in, var_floor_in, weight_floor_in, verbose_in);
00321 }
00322 
//! Convenience wrapper for MAP training (presumably maximum a-posteriori).
//! Not implemented: every call reports "MOG_diag_MAP(): not implemented yet"
//! through it_error(). Parameters are unnamed because they are unused.
void MOG_diag_MAP(MOG_diag &, MOG_diag &, Array<vec> &, int, double, double,
                  double, bool)
{
  it_error("MOG_diag_MAP(): not implemented yet");
}
00328 
00329 }
00330 
SourceForge Logo

Generated on Fri May 1 11:09:19 2009 for IT++ by Doxygen 1.5.8