det_by_lu.hpp

View page source

Source: det_by_lu

# ifndef CPPAD_DET_BY_LU_HPP
# define CPPAD_DET_BY_LU_HPP
# include <cppad/utility/vector.hpp>
# include <cppad/utility/lu_solve.hpp>

// BEGIN CppAD namespace
namespace CppAD {

template <class Scalar>
class det_by_lu {
private:
   // number of right-hand-side columns passed to LuSolve (always zero here)
   const size_t m_;
   // row (and column) dimension of the square matrix
   const size_t n_;
   // scratch storage for the n_ * n_ matrix entries handed to LuSolve
   CppAD::vector<Scalar> A_;
   // right-hand-side and solution arguments for LuSolve; never resized,
   // because m_ is zero
   CppAD::vector<Scalar> B_;
   CppAD::vector<Scalar> X_;
public:
   /// Prepare to evaluate determinants of n by n matrices.
   /// @param n  matrix dimension; the scratch vector A_ gets n * n elements.
   det_by_lu(size_t n)
   : m_(0)
   , n_(n)
   , A_(n * n)
   { }

   /// Evaluate the determinant of the matrix whose entries are in x.
   /// @param x  indexable container; elements x[0] ... x[n*n-1] are read.
   /// @return   Scalar(signdet) * exp(logdet), where signdet is the value
   ///           returned by CppAD::LuSolve and logdet is the value it stores.
   template <class Vector>
   Scalar operator()(const Vector &x)
   {
      // Scalar locals are default-constructed up front and assigned later,
      // keeping the same construction order for AD scalar types.
      Scalar logdet;
      Scalar det;
      const size_t n_sq = n_ * n_;

      // copy the caller's matrix into the scratch vector A_, which is the
      // vector LuSolve operates on
      for(size_t k = 0; k < n_sq; ++k)
         A_[k] = x[k];

      // compute the log of the determinant; the determinant itself is
      // signdet * exp(logdet)
      const int signdet = CppAD::LuSolve(
         n_, m_, A_, B_, X_, logdet);

      // Convert to the determinant. A zero sign is intentionally NOT
      // special-cased here, e.g.
      //    if( signdet == 0 ) det = 0; else det = Scalar(signdet) * exp(logdet);
      // because, for this speed test, such a branch would make the floating
      // point operation sequence very simple.
      det = Scalar( signdet ) * exp( logdet );

# ifdef FADBAD
      // Fadbad requires temporaries to be set to constants
      for(size_t k = 0; k < n_sq; ++k)
         A_[k] = 0;
# endif

      return det;
   }
};
} // END CppAD namespace

# endif