// lu_solve.hpp
//
// Source listing for CppAD::LuSolve.

# ifndef CPPAD_LU_SOLVE_HPP
# define CPPAD_LU_SOLVE_HPP
# include <complex>
# include <vector>

// link exp for float and double cases
# include <cppad/base_require.hpp>

# include <cppad/core/cppad_assert.hpp>
# include <cppad/utility/check_simple_vector.hpp>
# include <cppad/utility/check_numeric_type.hpp>
# include <cppad/utility/lu_factor.hpp>
# include <cppad/utility/lu_invert.hpp>

namespace CppAD { // BEGIN CppAD namespace

// LeqZero: true when x is less than or equal to zero.
// For ordered (real) types this is the plain comparison x <= 0, so a NaN
// argument yields false.
template <class Float>
bool LeqZero(const Float &x)
{  const Float zero(0);
   return x <= zero;
}
// std::complex has no ordering, so for complex arguments "less than or
// equal to zero" is interpreted as "exactly equal to zero" (both the real
// and the imaginary part compare equal to zero).
inline bool LeqZero( const std::complex<double> &x )
{  return x.real() == 0.0 && x.imag() == 0.0; }
inline bool LeqZero( const std::complex<float> &x )
{  return x.real() == 0.0f && x.imag() == 0.0f; }

// LuSolve
// Solve the linear system A * X = B using LuFactor / LuInvert and
// accumulate the log of the LU pivots.
//
// n       number of rows and columns of the square matrix A.
// m       number of columns of B and of X.
// A       n * n matrix stored by rows: element (i, j) is A[i * n + j].
//         A is copied and is not modified.
// B       n * m right-hand side; copied into X and not modified.
// X       must already have size n * m. It is first set equal to B; when
//         every pivot is nonzero it is then overwritten by LuInvert (the
//         solution of A * X = B). If a zero pivot is found, X is left equal
//         to B.
// logdet  for real Float, the sum of log|pivot| over all pivots, i.e.
//         log|det(A)|. Set to zero when a zero pivot is found.
//         NOTE(review): for complex Float this is the sum of complex logs of
//         the pivots; confirm that is the intended meaning.
// return  0 when a zero pivot is found. Otherwise, the sign factor returned
//         by LuFactor, negated once for every pivot where LeqZero is true.
template <class Float, class FloatVector>
int LuSolve(
   size_t             n      ,
   size_t             m      ,
   const FloatVector &A      ,
   const FloatVector &B      ,
   FloatVector       &X      ,
   Float        &logdet      )
{
   // compile-time requirements on the template arguments
   CheckNumericType<Float>();
   CheckSimpleVector<Float, FloatVector>();

   // -------------------------------------------------------
   CPPAD_ASSERT_KNOWN(
      size_t(A.size()) == n * n,
      "Error in LuSolve: A must have size equal to n * n"
   );
   CPPAD_ASSERT_KNOWN(
      size_t(B.size()) == n * m,
      "Error in LuSolve: B must have size equal to n * m"
   );
   CPPAD_ASSERT_KNOWN(
      size_t(X.size()) == n * m,
      "Error in LuSolve: X must have size equal to n * m"
   );
   // -------------------------------------------------------

   // LuFactor works in place, so factor a copy and leave A untouched
   FloatVector Lu(A);

   // LuInvert overwrites its right-hand side, so seed X with B
   X = B;

   // row (rowOrder) and column (colOrder) pivot order chosen by LuFactor
   std::vector<size_t> rowOrder(n);
   std::vector<size_t> colOrder(n);
   int sign = LuFactor(rowOrder, colOrder, Lu);

   const Float zero(0);
   logdet = zero;
   for(size_t k = 0; k < n; ++k)
   {  // the k-th pivot selected during factorization
      const Float diag = Lu[ rowOrder[k] * n + colOrder[k] ];

      // a zero pivot means A is singular: give up before solving
      if( diag == zero )
      {  logdet = zero;
         return 0;
      }

      // accumulate log|pivot| and track the sign of the determinant
      if( LeqZero(diag) )
      {  sign    = - sign;
         logdet += log( - diag );
      }
      else
         logdet += log( diag );
   }

   // back-substitute to turn X (currently B) into the solution
   LuInvert(rowOrder, colOrder, Lu, X);

   return sign;
}
} // END CppAD namespace

# endif