cond_exp.cpp

View page source

Conditional Expressions: Example and Test

See Also

optimize_conditional_skip.cpp

Description

Use CondExp to compute

\[f(x) = \sum_{j=0}^{n-1} x_j \log( x_j )\]

and its derivative at various argument values (where \(x_j \geq 0\)) without having to re-tape; i.e., using only one ADFun object. Note that \(x_j \log ( x_j ) \rightarrow 0\) as \(x_j \downarrow 0\), so we need to handle the case \(x_j = 0\) in a special way to avoid returning zero times minus infinity.

# include <cppad/cppad.hpp>
# include <limits>

bool CondExp(void)
{   bool ok = true;

    using CppAD::isnan;
    using CppAD::AD;
    using CppAD::NearEqual;
    using CppAD::log;
    double eps  = 100. * CppAD::numeric_limits<double>::epsilon();

    // domain space vector
    size_t n = 5;
    CPPAD_TESTVECTOR(AD<double>) ax(n);
    size_t j;
    for(j = 0; j < n; j++)
        ax[j] = 1.;

    // declare independent variables and start tape recording
    CppAD::Independent(ax);

    AD<double> asum  = 0.;
    AD<double> azero = 0.;
    for(j = 0; j < n; j++)
    {   // if x_j > 0, add x_j * log( x_j ) to the sum
        asum += CppAD::CondExpGt(ax[j], azero, ax[j] * log(ax[j]), azero);
    }

    // range space vector
    size_t m = 1;
    CPPAD_TESTVECTOR(AD<double>) ay(m);
    ay[0] = asum;

    // create f: x -> ay and stop tape recording
    CppAD::ADFun<double> f(ax, ay);

    // vectors for arguments to the function object f
    CPPAD_TESTVECTOR(double) x(n);   // argument values
    CPPAD_TESTVECTOR(double) y(m);   // function values
    CPPAD_TESTVECTOR(double) w(m);   // function weights
    CPPAD_TESTVECTOR(double) dw(n);  // derivative of weighted function

    // a case where x[j] > 0 for all j
    double check  = 0.;
    for(j = 0; j < n; j++)
    {   x[j]   = double(j + 1);
        check += x[j] * log( x[j] );
    }

    // function value
    y  = f.Forward(0, x);
    ok &= NearEqual(y[0], check, eps, eps);

    // compute derivative of y[0]
    w[0] = 1.;
    dw   = f.Reverse(1, w);
    for(j = 0; j < n; j++)
        ok &= NearEqual(dw[j], log(x[j]) + 1., eps, eps);

    // a case where x[3] is equal to zero
    check -= x[3] * log( x[3] );
    x[3]   = 0.;
    ok &= std::isnan( x[3] * log( x[3] ) );

    // function value
    y   = f.Forward(0, x);
    ok &= NearEqual(y[0], check, eps, eps);

    // check derivative of y[0]
    w[0] = 1.;
    dw   = f.Reverse(1, w);
    for(j = 0; j < n; j++)
    {   if( x[j] > 0 )
            ok &= NearEqual(dw[j], log(x[j]) + 1., eps, eps);
        else
            ok &= NearEqual(dw[j], 0.0, eps, eps);
    }

    return ok;
}