det_by_lu.hpp

View page source

Source: det_by_lu

# ifndef CPPAD_DET_BY_LU_HPP
# define CPPAD_DET_BY_LU_HPP
# include <cppad/utility/vector.hpp>
# include <cppad/utility/lu_solve.hpp>

// BEGIN CppAD namespace
namespace CppAD {

// Function object that evaluates the determinant of an n by n matrix
// by handing a private copy of the matrix to CppAD::LuSolve.
//
// Scalar: type used for the matrix elements, the intermediate
// values, and the determinant that is returned.
template <class Scalar>
class det_by_lu {
public:
    // n is the number of rows (and columns) of the matrices that will
    // be passed to operator(). The work matrix A_ gets its n * n
    // elements here, so it is not re-allocated for each evaluation.
    det_by_lu(size_t n) : m_(0), n_(n), A_(n * n)
    {   }

    // Returns the determinant of the matrix whose n * n elements are
    // x[0], ... , x[n * n - 1]. The argument itself is never modified.
    template <class Vector>
    Scalar operator()(const Vector &x)
    {
        // temporaries, default constructed before any other work
        Scalar log_det;
        Scalar result;

        // number of elements in the matrix
        const size_t n_elem = n_ * n_;

        // LuSolve works on A_, so give it a copy of the caller's matrix
        for(size_t k = 0; k < n_elem; ++k)
            A_[k] = x[k];

        // factor the matrix; the return value is the sign of the
        // determinant and log_det is set to its log
        const int sign = CppAD::LuSolve(n_, m_, A_, B_, X_, log_det);

        // Convert from log and sign to the determinant.
        // The case sign == 0 is deliberately not special cased, i.e.,
        //      if( sign == 0 ) result = 0;
        //      else            result = Scalar( sign ) * exp( log_det );
        // is not used, because for speed tests that would make the
        // floating point operation sequence very simple.
        result = Scalar( sign ) * exp( log_det );

# ifdef FADBAD
        // Fadbad requires temporaries to be set to constants
        for(size_t k = 0; k < n_elem; ++k)
            A_[k] = 0;
# endif

        return result;
    }

private:
    // number of right hand side columns passed to LuSolve (always zero)
    const size_t m_;
    // number of rows and columns in the matrix
    const size_t n_;
    // copy of the matrix; used as the LuSolve work space
    CppAD::vector<Scalar> A_;
    // right hand side for LuSolve (left empty)
    CppAD::vector<Scalar> B_;
    // solution for LuSolve (left empty)
    CppAD::vector<Scalar> X_;
};
} // END CppAD namespace

# endif