\(\newcommand{\W}[1]{ \; #1 \; }\) \(\newcommand{\R}[1]{ {\rm #1} }\) \(\newcommand{\B}[1]{ {\bf #1} }\) \(\newcommand{\D}[2]{ \frac{\partial #1}{\partial #2} }\) \(\newcommand{\DD}[3]{ \frac{\partial^2 #1}{\partial #2 \partial #3} }\) \(\newcommand{\Dpow}[2]{ \frac{\partial^{#1}}{\partial {#2}^{#1}} }\) \(\newcommand{\dpow}[2]{ \frac{ {\rm d}^{#1}}{{\rm d}\, {#2}^{#1}} }\)
sacado_det_lu.cpp¶
View page sourceSacado Speed: Gradient of Determinant Using Lu Factorization¶
Specifications¶
See link_det_lu .
Implementation¶
// suppress conversion warnings before other includes
# include <cppad/wno_conversion.hpp>
//
# include <Sacado.hpp>
# include <cppad/speed/det_by_lu.hpp>
# include <cppad/speed/uniform_01.hpp>
# include <cppad/utility/vector.hpp>
// list of possible options
# include <map>
extern std::map<std::string, bool> global_option;
bool link_det_lu(
size_t size ,
size_t repeat ,
CppAD::vector<double> &matrix ,
CppAD::vector<double> &gradient )
{
// speed test global option values
if( global_option["onetape"] || global_option["atomic"] )
return false;
if( global_option["memory"] || global_option["optimize"] )
return false;
// -----------------------------------------------------
// setup
//
// object for computing determinant
typedef Sacado::Rad::ADvar<double> ADScalar;
typedef CppAD::vector<ADScalar> ADVector;
CppAD::det_by_lu<ADScalar> Det(size);
size_t i; // temporary index
size_t n = size * size; // number of independent variables
ADScalar detA; // AD value of the determinant
ADVector A(n); // AD version of matrix
// ------------------------------------------------------
while(repeat--)
{ // get the next matrix
CppAD::uniform_01(n, matrix);
// set independent variable values
for(i = 0; i < n; i++)
A[i] = matrix[i];
// compute the determinant
detA = Det(A);
// compute the gradient of detA
ADScalar::Gradcomp();
// evaluate and return gradient using reverse mode
for(i =0; i < n; i++)
gradient[i] = A[i].adj(); // partial detA w.r.t A[i]
}
// ---------------------------------------------------------
return true;
}