\(\newcommand{\W}[1]{ \; #1 \; }\) \(\newcommand{\R}[1]{ {\rm #1} }\) \(\newcommand{\B}[1]{ {\bf #1} }\) \(\newcommand{\D}[2]{ \frac{\partial #1}{\partial #2} }\) \(\newcommand{\DD}[3]{ \frac{\partial^2 #1}{\partial #2 \partial #3} }\) \(\newcommand{\Dpow}[2]{ \frac{\partial^{#1}}{\partial {#2}^{#1}} }\) \(\newcommand{\dpow}[2]{ \frac{ {\rm d}^{#1}}{{\rm d}\, {#2}^{#1}} }\)
lu_solve.hpp
View page source
Source: LuSolve
# ifndef CPPAD_LU_SOLVE_HPP
# define CPPAD_LU_SOLVE_HPP
# include <complex>
# include <vector>
// link exp for float and double cases
# include <cppad/base_require.hpp>
# include <cppad/core/cppad_assert.hpp>
# include <cppad/utility/check_simple_vector.hpp>
# include <cppad/utility/check_numeric_type.hpp>
# include <cppad/utility/lu_factor.hpp>
# include <cppad/utility/lu_invert.hpp>
namespace CppAD { // BEGIN CppAD namespace
// LeqZero
// Generic case: reports whether x is less than or equal to zero.
// Float must be constructible from the integer 0 and must support
// the <= comparison.
template <class Float>
bool LeqZero(const Float &x)
{
    const Float zero(0);
    return x <= zero;
}
// Double precision complex case: std::complex has no <= operator, so this
// overload reports true only when both the real and the imaginary part of x
// compare equal to zero.
inline bool LeqZero( const std::complex<double> &x )
{
    return x.real() == 0.0 && x.imag() == 0.0;
}
// Single precision complex case: std::complex has no <= operator, so this
// overload reports true only when both the real and the imaginary part of x
// compare equal to zero.
inline bool LeqZero( const std::complex<float> &x )
{
    return x.real() == 0.0f && x.imag() == 0.0f;
}
// LuSolve
// Factors a copy of A with LuFactor, accumulates the logarithm of the
// pivot elements, and then calls LuInvert on a copy of B stored in X.
//
// n      : A must contain n * n elements; the pivot lookup below indexes
//          it as (row index) * n + (column index).
// m      : B and X must each contain n * m elements.
// A      : input matrix; it is copied before factoring and is not changed.
// B      : input right hand side; it is copied into X and is not changed.
// X      : receives the copy of B and is then passed to LuInvert.
// logdet : set to the sum, over all pivots, of log(-pivot) when
//          LeqZero(pivot) is true and of log(pivot) otherwise.
//          It is set to zero when a pivot is equal to zero.
//
// Return value: the sign returned by LuFactor, negated once for every
// pivot for which LeqZero is true. If some pivot is equal to zero, the
// return value is 0, LuInvert is not called, and X is left as a copy of B.
template <class Float, class FloatVector>
int LuSolve(
    size_t             n      ,
    size_t             m      ,
    const FloatVector &A      ,
    const FloatVector &B      ,
    FloatVector       &X      ,
    Float             &logdet )
{
    // Float must meet the numeric type specifications
    CheckNumericType<Float>();

    // FloatVector must meet the simple vector class specifications
    CheckSimpleVector<Float, FloatVector>();

    // constant used for the singularity test and to reset logdet
    const Float zero(0);

    // pivot row and column order in the matrix
    // (passed to both LuFactor and LuInvert)
    std::vector<size_t> ip(n);
    std::vector<size_t> jp(n);

    // ---------------------------------------------------------------
    // argument size checks
    CPPAD_ASSERT_KNOWN(
        size_t(A.size()) == n * n,
        "Error in LuSolve: A must have size equal to n * n"
    );
    CPPAD_ASSERT_KNOWN(
        size_t(B.size()) == n * m,
        "Error in LuSolve: B must have size equal to n * m"
    );
    CPPAD_ASSERT_KNOWN(
        size_t(X.size()) == n * m,
        "Error in LuSolve: X must have size equal to n * m"
    );
    // ---------------------------------------------------------------

    // work on copies so that the caller's A and B are left untouched
    FloatVector Lu(A);
    X = B;

    // factor the copy of A; the result is the initial sign factor
    int signdet = LuFactor(ip, jp, Lu);

    // accumulate the log of the determinant one pivot at a time
    logdet = zero;
    Float pivot;
    for(size_t k = 0; k < n; ++k)
    {
        // pivot element for step k of the factorization
        pivot = Lu[ ip[k] * n + jp[k] ];

        // a zero pivot means a zero determinant: give up without solving
        if( pivot == zero )
        {
            logdet = zero;
            return 0;
        }

        // pivots that are not flagged by LeqZero contribute log(pivot)
        if( ! LeqZero( pivot ) )
        {
            logdet += log( pivot );
            continue;
        }

        // flagged pivot: take the log of its negation and flip the sign
        logdet += log( - pivot );
        signdet = - signdet;
    }

    // solve the linear equations; X holds B on input
    LuInvert(ip, jp, Lu, X);

    // sign factor for the determinant
    return signdet;
}
} // END CppAD namespace
# endif