// ode_evaluate.hpp
//
// View page source
//
// Source: ode_evaluate

# ifndef CPPAD_ODE_EVALUATE_HPP
# define CPPAD_ODE_EVALUATE_HPP
# include <cppad/utility/vector.hpp>
# include <cppad/utility/ode_err_control.hpp>
# include <cppad/utility/runge_45.hpp>

namespace CppAD {

   /*
   Right-hand side of the ordinary differential equation

      y_0'(t) = 0
      y_k'(t) = y_{k-1}(t)      for k = 1 , ... , n-1

   where n is the size of the state vector y.
   Given that y_i (0) = x_i, the following y_i (t) satisfy this ODE:

      y_0 (t) = x[0]
      y_1 (t) = x[1] + x[0] * t
      y_2 (t) = x[2] + x[1] * t + x[0] * t^2/2
      y_3 (t) = x[3] + x[2] * t + x[1] * t^2/2 + x[0] * t^3 / 3!
      ...
   */
   template <class Float>
   class ode_evaluate_fun {
   public:
      // Evaluate f = y'(t) for the system described above.
      //
      // t : the current time. It is not used because the right-hand side
      //     does not depend on t; it is kept in the signature so the
      //     argument list stays (t, y, f).
      // y : the current state; its size n is the number of components.
      // f : the output. Elements 0 , ... , n-1 are overwritten. This routine
      //     does not resize f, so it must already have at least n elements.
      void Ode(
         const Float&                    t,
         const CppAD::vector<Float>&     y,
         CppAD::vector<Float>&           f)
      {  const size_t n = y.size();

         // An empty system has no components to set; writing f[0] in that
         // case would index past the end of f.
         if( n == 0 )
            return;

         // the first component is constant in time
         f[0] = 0.;

         // every other component is driven by the one before it
         for(size_t k = 1; k < n; k++)
            f[k] = y[k-1];
      }
   };
   //
   /*
   Evaluate y(x, 1), or its derivative with respect to x, where y(x, t)
   is the solution of the ODE defined by ode_evaluate_fun with the initial
   condition y(x, 0) = x.

   x  : the initial value y(x, 0); its size n is the number of components.

   p  : the order of the result; must be zero or one.
        p == 0 : fp is set to the value CppAD::Runge45 returns for y(x, 1).
        p == 1 : fp is set to the derivative of y(x, 1) w.r.t x using the
                 closed form below (no ODE solve is done).

   fp : the output. Its size is not changed by this routine and must be
        n when p == 0 and n * n when p == 1. For p == 1 the derivative is
        stored in row-major order; i.e., fp[ i * n + j ] is the partial of
        y_i (x, 1) w.r.t x[j].

   The conditions on p and fp.size() are checked with CPPAD_ASSERT_KNOWN.
   */
   template <class Float>
   void ode_evaluate(
      const CppAD::vector<Float>& x  ,
      size_t                      p  ,
      CppAD::vector<Float>&       fp )
   {  using CppAD::vector;
      typedef vector<Float> FloatVector;

      const size_t n = x.size();
      CPPAD_ASSERT_KNOWN( p == 0 || p == 1,
         "ode_evaluate: p is not zero or one"
      );
      CPPAD_ASSERT_KNOWN(
         ((p==0) && (fp.size()==n)) || ((p==1) && (fp.size()==n*n)),
         "ode_evaluate: the size of fp is not correct"
      );
      if( p == 0 )
      {  // function that defines the ode
         ode_evaluate_fun<Float> F;

         // number of Runge45 steps to use
         size_t M = 10;

         // initial and final time
         Float ti = 0.0;
         Float tf = 1.0;

         // initial value for y(x, t); i.e. y(x, 0)
         // (is a reference to x)
         const FloatVector& yi = x;

         // final value for y(x, t); i.e., y(x, 1)
         // (is a reference to fp)
         FloatVector& yf = fp;

         // Use Runge45 with M steps to approximate the ODE solution at tf
         yf = CppAD::Runge45(F, M, ti, tf, yi);

         return;
      }
      /* Compute derivative of y(x, 1) w.r.t x
      y_0 (x, t) = x[0]
      y_1 (x, t) = x[1] + x[0] * t
      y_2 (x, t) = x[2] + x[1] * t + x[0] * t^2/2
      y_3 (x, t) = x[3] + x[2] * t + x[1] * t^2/2 + x[0] * t^3 / 3!
      ...
      Hence, at t = 1, the partial of y_i w.r.t x[j] is 1 / (i-j)! for
      j <= i and zero for j > i.
      */

      // start with the zero matrix; entries with j > i are left at zero
      for(size_t ell = 0; ell < n * n; ell++)
         fp[ell] = 0.0;

      // factorial = k! is accumulated in double instead of size_t.
      // An integer accumulator silently wraps around once k! no longer
      // fits (k >= 21 for a 64 bit size_t) and eventually becomes zero.
      // In double, every k! that fits in 64 bits is still represented
      // exactly, and larger values just lose precision (and 1 / k! tends
      // to zero) instead of wrapping.
      double factorial = 1.0;
      for(size_t k = 0; k < n; k++)
      {  if( k > 1 )
            factorial *= static_cast<double>(k);

         // partial w.r.t x[i-k] of x[i-k] * t^k / k! at t = 1
         // (does not depend on i, so it is computed once for each k)
         const Float one_over_factorial = 1.0 / Float(factorial);

         // Each pair (i, i-k) occurs for exactly one value of k,
         // so every entry on or below the diagonal is set exactly once.
         for(size_t i = k; i < n; i++)
            fp[ i * n + (i - k) ] = one_over_factorial;
      }
   }
}

# endif