Main MRPT website > C++ reference
MRPT logo

CLevenbergMarquardt.h

Go to the documentation of this file.
00001 /* +---------------------------------------------------------------------------+
00002    |          The Mobile Robot Programming Toolkit (MRPT) C++ library          |
00003    |                                                                           |
00004    |                   http://mrpt.sourceforge.net/                            |
00005    |                                                                           |
00006    |   Copyright (C) 2005-2011  University of Malaga                           |
00007    |                                                                           |
00008    |    This software was written by the Machine Perception and Intelligent    |
00009    |      Robotics Lab, University of Malaga (Spain).                          |
00010    |    Contact: Jose-Luis Blanco  <jlblanco@ctima.uma.es>                     |
00011    |                                                                           |
00012    |  This file is part of the MRPT project.                                   |
00013    |                                                                           |
00014    |     MRPT is free software: you can redistribute it and/or modify          |
00015    |     it under the terms of the GNU General Public License as published by  |
00016    |     the Free Software Foundation, either version 3 of the License, or     |
00017    |     (at your option) any later version.                                   |
00018    |                                                                           |
00019    |   MRPT is distributed in the hope that it will be useful,                 |
00020    |     but WITHOUT ANY WARRANTY; without even the implied warranty of        |
00021    |     MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the         |
00022    |     GNU General Public License for more details.                          |
00023    |                                                                           |
00024    |     You should have received a copy of the GNU General Public License     |
00025    |     along with MRPT.  If not, see <http://www.gnu.org/licenses/>.         |
00026    |                                                                           |
00027    +---------------------------------------------------------------------------+ */
00028 #ifndef  CLevenbergMarquardt_H
00029 #define  CLevenbergMarquardt_H
00030 
00031 #include <mrpt/utils/CDebugOutputCapable.h>
00032 #include <mrpt/math/CMatrixD.h>
00033 #include <mrpt/math/utils.h>
00034 
00035 /*---------------------------------------------------------------
00036         Class
00037   ---------------------------------------------------------------*/
00038 namespace mrpt
00039 {
00040 namespace math
00041 {
00042         /** An implementation of the Levenberg-Marquardt algorithm for least-square minimization.
00043          *
00044          *  Refer to this <a href="http://www.mrpt.org/Levenberg%E2%80%93Marquardt_algorithm">page</a> for more details on the algorithm and its usage.
00045          *
00046          * \tparam NUMTYPE The numeric type for all the operations (float, double, or long double)
00047          * \tparam USERPARAM The type of the "y" input to the user supplied evaluation functor. Default type is a vector of NUMTYPE.
00048          */
00049         template <typename VECTORTYPE = mrpt::vector_double, class USERPARAM = VECTORTYPE >
00050         class CLevenbergMarquardtTempl : public mrpt::utils::CDebugOutputCapable
00051         {
00052         public:
00053                 typedef typename VECTORTYPE::value_type  NUMTYPE;
00054 
00055 
00056                 /** The type of the function passed to execute. The user must supply a function which evaluates the error of a given point in the solution space.
00057                   *  \param x The state point under examination.
00058                   *  \param y The same object passed to "execute" as the parameter "userParam".
00059                   *  \param out The vector of (non-squared) errors, of the average square root error, for the given "x". The functor code must set the size of this vector.
00060                   */
00061                 typedef void (*TFunctor)(
00062                         const VECTORTYPE &x,
00063                         const USERPARAM &y,
00064                         VECTORTYPE &out);
00065 
00066                 struct TResultInfo
00067                 {
00068                         NUMTYPE         final_sqr_err;
00069                         size_t          iterations_executed;
00070                         VECTORTYPE      last_err_vector;                //!< The last error vector returned by the user-provided functor.
00071                         CMatrixTemplateNumeric<NUMTYPE> path;   //!< Each row is the optimized value at each iteration.
00072 
00073                         /** This matrix can be used to obtain an estimate of the optimal parameters covariance matrix:
00074                           *  \f[ COV = H M H^\top \f]
00075                           *  With COV the covariance matrix of the optimal parameters, H this matrix, and M the covariance of the input (observations).
00076                           */
00077                         //CMatrixTemplateNumeric<NUMTYPE> H;
00078                 };
00079 
00080                 /** Executes the LM-method, with derivatives estimated from
00081                   *  "functor" Is a user-provided function which takes as input two vectors, in this order:
00082                   *             - x: The parameters to be optimized.
00083                   *             - userParam: The vector passed to the LM algorithm, unmodified.
00084                   *       and must return the "error vector", or the error (not squared) in each measured dimension, so the sum of the square of that output is the overall square error.
00085                   */
00086                 static void     execute(
00087                         VECTORTYPE                      &out_optimal_x,
00088                         const VECTORTYPE        &x0,
00089                         TFunctor                        functor,
00090                         const VECTORTYPE        &increments,
00091                         const USERPARAM         &userParam,
00092                         TResultInfo                     &out_info,
00093                         bool                            verbose = false,
00094                         const size_t            &maxIter = 200,
00095                         const NUMTYPE           tau = 1e-3,
00096                         const NUMTYPE           e1 = 1e-8,
00097                         const NUMTYPE           e2 = 1e-8,
00098                         bool returnPath=true
00099                         )
00100                 {
00101                         using namespace mrpt;
00102                         using namespace mrpt::utils;
00103                         using namespace mrpt::math;
00104                         using namespace std;
00105 
00106                         MRPT_START;
00107 
00108                         VECTORTYPE &x=out_optimal_x; // Var rename
00109 
00110                         // Asserts:
00111                         ASSERT_( increments.size() == x0.size() );
00112 
00113                         x=x0;                                                                   // Start with the starting point
00114                         VECTORTYPE      f_x;                                    // The vector error from the user function
00115                         CMatrixTemplateNumeric<NUMTYPE> AUX;
00116                         CMatrixTemplateNumeric<NUMTYPE> J;              // The Jacobian of "f"
00117                         CMatrixTemplateNumeric<NUMTYPE> H;              // The Hessian of "f"
00118                         VECTORTYPE      g;                                              // The gradient
00119 
00120                         // Compute the jacobian and the Hessian:
00121                         mrpt::math::estimateJacobian( x, functor, increments, userParam, J);
00122                         H.multiply_AtA(J);
00123 
00124                         const size_t  H_len = H.getColCount();
00125 
00126                         // Compute the gradient:
00127                         functor(x, userParam ,f_x);
00128                         J.multiply_Atb(f_x, g);
00129 
00130                         // Start iterations:
00131                         bool    found = math::norm_inf(g)<=e1;
00132                         if (verbose && found)   printf_debug("[LM] End condition: math::norm_inf(g)<=e1 :%f\n",math::norm_inf(g));
00133 
00134                         NUMTYPE lambda = tau * H.maximumDiagonal();
00135                         size_t  iter = 0;
00136                         NUMTYPE v = 2;
00137 
00138                         VECTORTYPE      h_lm;
00139                         VECTORTYPE      xnew, f_xnew ;
00140                         NUMTYPE                 F_x  = pow( math::norm( f_x ), 2);
00141 
00142                         const size_t    N = x.size();
00143 
00144                         if (returnPath) {
00145                                 out_info.path.setSize(maxIter,N+1);
00146                                 out_info.path.block(iter,0,1,N) = x.transpose();
00147                         }       else out_info.path = Eigen::Matrix<NUMTYPE,Eigen::Dynamic,Eigen::Dynamic>(); // Empty matrix
00148 
00149                         while (!found && ++iter<maxIter)
00150                         {
00151                                 // H_lm = -( H + \lambda I ) ^-1 * g
00152                                 for (size_t k=0;k<H_len;k++)
00153                                         H(k,k)+= lambda;
00154                                         //H(k,k) *= 1+lambda;
00155 
00156                                 H.inv_fast(AUX);
00157                                 AUX.multiply_Ab(g,h_lm);
00158                                 h_lm *= NUMTYPE(-1.0);
00159 
00160                                 double h_lm_n2 = math::norm(h_lm);
00161                                 double x_n2 = math::norm(x);
00162 
00163                                 if (verbose) printf_debug( (format("[LM] Iter: %u x:",(unsigned)iter)+ sprintf_vector(" %f",x) + string("\n")).c_str() );
00164 
00165                                 if (h_lm_n2<e2*(x_n2+e2))
00166                                 {
00167                                         // Done:
00168                                         found = true;
00169                                         if (verbose) printf_debug("[LM] End condition: %e < %e\n", h_lm_n2, e2*(x_n2+e2) );
00170                                 }
00171                                 else
00172                                 {
00173                                         // Improvement:
00174                                         xnew = x;
00175                                         xnew += h_lm;
00176                                         functor(xnew, userParam ,f_xnew );
00177                                         const double F_xnew = pow( math::norm(f_xnew), 2);
00178 
00179                                         // denom = h_lm^t * ( \lambda * h_lm - g )
00180                                         VECTORTYPE      tmp(h_lm);
00181                                         tmp *= lambda;
00182                                         tmp -= g;
00183                                         tmp.array() *=h_lm.array();
00184                                         double denom = tmp.sum();
00185                                         double l = (F_x - F_xnew) / denom;
00186 
00187                                         if (l>0) // There is an improvement:
00188                                         {
00189                                                 // Accept new point:
00190                                                 x   = xnew;
00191                                                 f_x = f_xnew;
00192                                                 F_x = F_xnew;
00193 
00194                                                 math::estimateJacobian( x, functor, increments, userParam, J);
00195                                                 H.multiply_AtA(J);
00196                                                 J.multiply_Atb(f_x, g);
00197 
00198                                                 found = math::norm_inf(g)<=e1;
00199                                                 if (verbose && found)   printf_debug("[LM] End condition: math::norm_inf(g)<=e1 : %e\n", math::norm_inf(g) );
00200 
00201                                                 lambda *= max(0.33, 1-pow(2*l-1,3) );
00202                                                 v = 2;
00203                                         }
00204                                         else
00205                                         {
00206                                                 // Nope...
00207                                                 lambda *= v;
00208                                                 v*= 2;
00209                                         }
00210 
00211 
00212                                         if (returnPath) {
00213                                                 out_info.path.block(iter,0,1,x.size()) = x.transpose();
00214                                                 out_info.path(iter,x.size()) = F_x;
00215                                         }
00216                                 }
00217                         } // end while
00218 
00219                         // Output info:
00220                         out_info.final_sqr_err = F_x;
00221                         out_info.iterations_executed = iter;
00222                         out_info.last_err_vector = f_x;
00223                         if (returnPath) out_info.path.setSize(iter,N+1);
00224 
00225                         MRPT_END;
00226                 }
00227 
00228         }; // End of class def.
00229 
00230 
00231         typedef CLevenbergMarquardtTempl<vector_double> CLevenbergMarquardt;  //!< The default name for the LM class is an instantiation for "double"
00232 
00233         } // End of namespace
00234 } // End of namespace
00235 #endif



Page generated by Doxygen 1.7.3 for MRPT 0.9.4 SVN:exported at Tue Jan 25 21:56:31 UTC 2011