test-autodiff/test/forward/4multi_variable_function_wi...

60 lines
2.1 KiB
C++
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

// NOTE: Derivatives of a multi-variable function with parameters (带参数的多变量函数的导数).
// C++ includes
#include <iostream>
// autodiff include
#include <autodiff/forward/dual.hpp>
using namespace autodiff;
// Parameters of the function of interest f (see below).
// All members are `dual` rather than `double`: that is what allows derivatives
// with respect to a, b and c to be requested, e.g. via wrt(params.a).
struct Params
{
    dual a; // coefficient of sin(x) in f
    dual b; // coefficient of cos(x) in f
    dual c; // coefficient of sin(x)*cos(x) in f
};
// Function of interest, parameterized by Params:
//     u(x; a, b, c) = a*sin(x) + b*cos(x) + c*sin(x)*cos(x)
// Because x and the coefficients are all `dual`, derivatives of u can be
// taken with respect to x as well as with respect to any of a, b, c.
dual f(dual x, const Params& params)
{
    // Read-only aliases for params.a, params.b, params.c (no copies are made).
    const auto& [a, b, c] = params;
    return a * sin(x) + b * cos(x) + c * sin(x) * cos(x);
}
int main()
{
Params params; // initialize the parameter variables
params.a = 1.0; // the parameter a of type dual, not double!
params.b = 2.0; // the parameter b of type dual, not double!
params.c = 3.0; // the parameter c of type dual, not double!
dual x = 0.5; // the input variable x
dual u = f(x, params); // the output variable u
double dudx = derivative(f, wrt(x), at(x, params)); // evaluate the derivative du/dx
double duda = derivative(f, wrt(params.a), at(x, params)); // evaluate the derivative du/da
double dudb = derivative(f, wrt(params.b), at(x, params)); // evaluate the derivative du/db
double dudc = derivative(f, wrt(params.c), at(x, params)); // evaluate the derivative du/dc
std::cout << "u = " << u << std::endl; // print the evaluated output u
std::cout << "du/dx = " << dudx << std::endl; // print the evaluated derivative du/dx
std::cout << "du/da = " << duda << std::endl; // print the evaluated derivative du/da
std::cout << "du/db = " << dudb << std::endl; // print the evaluated derivative du/db
std::cout << "du/dc = " << dudc << std::endl; // print the evaluated derivative du/dc
}
/* NOTE:
This example would also work if real was used instead of dual. Should you
need higher-order cross derivatives, however, e.g.:
    double d2udxda = derivative(f, wrt(x, params.a), at(x, params));
then higher-order dual types (e.g. dual2nd, used for x and for the members of
Params) are the right choice, since real types are optimally designed for
higher-order directional derivatives.
如果需要计算更高阶的交叉导数（例如 d2udxda），应该使用更高阶的 dual 类型，而不是 real 类型。
(If higher-order cross derivatives such as d2udxda are needed, use higher-order
dual types rather than real types.)
*/