blas_proto.hpp

Go to the documentation of this file.
00001 // Copyright (C) 2010 NICTA and the authors listed below
00002 // http://nicta.com.au
00003 // 
00004 // Authors:
00005 // - Conrad Sanderson (conradsand at ieee dot org)
00006 // 
00007 // This file is part of the Armadillo C++ library.
00008 // It is provided without any warranty of fitness
00009 // for any purpose. You can redistribute this file
00010 // and/or modify it under the terms of the GNU
00011 // Lesser General Public License (LGPL) as published
00012 // by the Free Software Foundation, either version 3
00013 // of the License or (at your option) any later version.
00014 // (see http://www.opensource.org/licenses for more info)
00015 
00016 
00017 #ifdef ARMA_USE_BLAS
00018 
00019 //! \namespace blas namespace for BLAS functions
00020 namespace blas
00021   {
00022   extern "C"
00023     {
00024     float  sdot_(const int* n, const float*  x, const int* incx, const float*  y, const int* incy);
00025     double ddot_(const int* n, const double* x, const int* incx, const double* y, const int* incy);
00026     
00027     void sgemv_(const char* transA, const int* m, const int* n, const float*  alpha, const float*  A, const int* ldA, const float*  x, const int* incx, const float*  beta, float*  y, const int* incy);
00028     void dgemv_(const char* transA, const int* m, const int* n, const double* alpha, const double* A, const int* ldA, const double* x, const int* incx, const double* beta, double* y, const int* incy);
00029     void cgemv_(const char* transA, const int* m, const int* n, const void*   alpha, const void*   A, const int* ldA, const void*   x, const int* incx, const void*   beta, void*   y, const int* incy);
00030     void zgemv_(const char* transA, const int* m, const int* n, const void*   alpha, const void*   A, const int* ldA, const void*   x, const int* incx, const void*   beta, void*   y, const int* incy);
00031     
00032     void sgemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const float*  alpha, const float*  A, const int* ldA, const float*  B, const int* ldB, const float*  beta, float*  C, const int* ldC);
00033     void dgemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const double* alpha, const double* A, const int* ldA, const double* B, const int* ldB, const double* beta, double* C, const int* ldC);
00034     void cgemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const void*   alpha, const void*   A, const int* ldA, const void*   B, const int* ldB, const void*   beta, void*   C, const int* ldC);
00035     void zgemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const void*   alpha, const void*   A, const int* ldA, const void*   B, const int* ldB, const void*   beta, void*   C, const int* ldC);
00036 
00037     // void   dswap_(const int* n, double* x, const int* incx, double* y, const int* incy);
00038     // void   dscal_(const int* n, const double* alpha, double* x, const int* incx);
00039     // void   dcopy_(const int* n, const double* x, const int* incx, double* y, const int* incy);
00040     // void   daxpy_(const int* n, const double* alpha, const double* x, const int* incx, double* y, const int* incy);
00041     // void    dger_(const int* m, const int* n, const double* alpha, const double* x, const int* incx, const double* y, const int* incy, double* A, const int* ldA);
00042     }
00043   
00044   
00045   
00046   template<typename eT>
00047   arma_inline
00048   eT
00049   dot_(const int* n, const eT* x, const eT* y)
00050     {
00051     arma_type_check<is_supported_blas_type<eT>::value == false>::apply();
00052     
00053     const int inc = 1;
00054     
00055     if(is_float<eT>::value == true)
00056       {
00057       typedef float T;
00058       return eT( sdot_(n, (const T*)x, &inc, (const T*)y, &inc) );
00059       }
00060     else
00061     if(is_double<eT>::value == true)
00062       {
00063       typedef double T;
00064       return eT( ddot_(n, (const T*)x, &inc, (const T*)y, &inc) );
00065       }
00066     else
00067       {
00068       return eT(0);  // prevent compiler warnings
00069       }
00070     }
00071   
00072   
00073   
00074   template<typename eT>
00075   inline
00076   void
00077   gemv_(const char* transA, const int* m, const int* n, const eT* alpha, const eT* A, const int* ldA, const eT* x, const int* incx, const eT* beta, eT* y, const int* incy)
00078     {
00079     arma_type_check<is_supported_blas_type<eT>::value == false>::apply();
00080     
00081     if(is_float<eT>::value == true)
00082       {
00083       typedef float T;
00084       sgemv_(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
00085       }
00086     else
00087     if(is_double<eT>::value == true)
00088       {
00089       typedef double T;
00090       dgemv_(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
00091       }
00092     else
00093     if(is_supported_complex_float<eT>::value == true)
00094       {
00095       typedef std::complex<float> T;
00096       cgemv_(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
00097       }
00098     else
00099     if(is_supported_complex_double<eT>::value == true)
00100       {
00101       typedef std::complex<double> T;
00102       zgemv_(transA, m, n, (const T*)alpha, (const T*)A, ldA, (const T*)x, incx, (const T*)beta, (T*)y, incy);
00103       }
00104     
00105     }
00106   
00107   
00108   
00109   template<typename eT>
00110   inline
00111   void
00112   gemm_(const char* transA, const char* transB, const int* m, const int* n, const int* k, const eT* alpha, const eT* A, const int* ldA, const eT* B, const int* ldB, const eT* beta, eT* C, const int* ldC)
00113     {
00114     arma_type_check<is_supported_blas_type<eT>::value == false>::apply();
00115     
00116     if(is_float<eT>::value == true)
00117       {
00118       typedef float T;
00119       sgemm_(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
00120       }
00121     else
00122     if(is_double<eT>::value == true)
00123       {
00124       typedef double T;
00125       dgemm_(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
00126       }
00127     else
00128     if(is_supported_complex_float<eT>::value == true)
00129       {
00130       typedef std::complex<float> T;
00131       cgemm_(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
00132       }
00133     else
00134     if(is_supported_complex_double<eT>::value == true)
00135       {
00136       typedef std::complex<double> T;
00137       zgemm_(transA, transB, m, n, k, (const T*)alpha, (const T*)A, ldA, (const T*)B, ldB, (const T*)beta, (T*)C, ldC);
00138       }
00139     
00140     }
00141   
00142   }
00143 
00144 #endif