lu_invert.hpp

View page source

Source: LuInvert

# ifndef CPPAD_LU_INVERT_HPP
# define CPPAD_LU_INVERT_HPP
# include <cppad/core/cppad_assert.hpp>
# include <cppad/utility/check_simple_vector.hpp>
# include <cppad/utility/check_numeric_type.hpp>

namespace CppAD { // BEGIN CppAD namespace

// LuInvert
/*
Overwrite B with the solution obtained from a permuted LU factorization:
a forward pass with L followed by a backward pass with U, for every
column of B.
NOTE(review): presumably this solves A * X = B where A is the matrix that
CppAD::LuFactor factored into (ip, jp, LU) -- confirm against caller.

ip [in]
   Row permutation; n = ip.size() is the order of the system.
jp [in]
   Column permutation; must have size n.
LU [in]
   n by n matrix stored by rows (size n * n). With r, c in [0, n):
      L(r, c) = LU[ ip[r] * n + jp[c] ]  for c <= r (diagonal included)
      U(r, c) = LU[ ip[r] * n + jp[c] ]  for c >  r
   U has an implicit unit diagonal; the backward pass never divides.
B [in,out]
   Size must be a multiple of n; with m = B.size() / n it is an n by m
   matrix stored by rows. On input it holds the right-hand sides; on
   output B[i * m + k] holds component i of the solution for column k.

A zero pivot LU[ ip[p] * n + jp[p] ] is not detected here and leads to a
division by zero in the forward pass.
*/
template <class SizeVector, class FloatVector>
void LuInvert(
   const SizeVector  &ip,
   const SizeVector  &jp,
   const FloatVector &LU,
   FloatVector       &B )
{  typedef typename FloatVector::value_type Float;

   // check numeric type specifications
   CheckNumericType<Float>();

   // check simple vector class specifications
   CheckSimpleVector<Float, FloatVector>();
   CheckSimpleVector<size_t, SizeVector>();

   size_t n = ip.size();
   CPPAD_ASSERT_KNOWN(
      size_t(jp.size()) == n,
      "Error in LuInvert: jp must have size equal to n"
   );
   CPPAD_ASSERT_KNOWN(
      size_t(LU.size()) == n * n,
      "Error in LuInvert: LU must have size equal to n * n"
   );

   // An empty system has nothing to solve. Returning here also avoids the
   // division by zero in the computation of m below (which would be
   // undefined behavior when the assertions are compiled out).
   if( n == 0 )
   {  CPPAD_ASSERT_KNOWN(
         size_t(B.size()) == 0,
         "Error in LuInvert: B must have size equal to a multiple of n"
      );
      return;
   }

   // number of right-hand-side columns in B
   size_t m = size_t(B.size()) / n;
   CPPAD_ASSERT_KNOWN(
      size_t(B.size()) == n * m,
      "Error in LuInvert: B must have size equal to a multiple of n"
   );

   // temporary storage for reordered solution
   FloatVector x(n);

   // loop over the columns of B
   for(size_t k = 0; k < m; k++)
   {  // forward pass: solve L * c = b; c[p] is stored in B at row ip[p]
      for(size_t p = 0; p < n; p++)
      {  // solve for c[p]
         Float etmp = B[ ip[p] * m + k ] / LU[ ip[p] * n + jp[p] ];
         B[ ip[p] * m + k ] = etmp;
         // subtract off effect on the remaining equations
         for(size_t i = p + 1; i < n; i++)
            B[ ip[i] * m + k ] -= etmp * LU[ ip[i] * n + jp[p] ];
      }

      // backward pass: solve U * x = c (unit diagonal, so no division);
      // the column permutation jp determines where each x component goes
      size_t p = n;
      while( p > 0 )
      {  --p;
         Float etmp = B[ ip[p] * m + k ];
         x[ jp[p] ] = etmp;
         for(size_t i = 0; i < p; i++)
            B[ ip[i] * m + k ] -= etmp * LU[ ip[i] * n + jp[p] ];
      }

      // copy reordered solution into column k of B
      for(size_t i = 0; i < n; i++)
         B[i * m + k] = x[i];
   }
   return;
}
} // END CppAD namespace

# endif