Automatic Derivatives

现在我们将考虑自动微分。这是一种可以快速计算精确导数的技术,而用户所需的工作量与使用数值微分法大致相同。下面的代码片段为 Rat43 实现了一个自动微分的 CostFunction。

struct Rat43CostFunctor {
  Rat43CostFunctor(const double x, const double y) : x_(x), y_(y) {}

  template <typename T>
  bool operator()(const T* parameters, T* residuals) const {
    const T b1 = parameters[0];
    const T b2 = parameters[1];
    const T b3 = parameters[2];
    const T b4 = parameters[3];
    residuals[0] = b1 * pow(1.0 + exp(b2 -  b3 * x_), -1.0 / b4) - y_;
    return true;
  }

  private:
    const double x_;
    const double y_;
};


auto* cost_function = new AutoDiffCostFunction<Rat43CostFunctor, 1, 4>(x, y);

请注意,与数值微分相比,在定义用于自动微分的 functor 时,唯一的区别在于 operator() 的重载。在数值微分的情况下,它是

bool operator()(const double* parameters, double* residuals) const;

对于自动微分,它是一个模板函数

template <typename T> bool operator()(const T* parameters, T* residuals) const;

那么,这个微小的变化能给我们带来什么呢?下表比较了使用各种方法计算 Rat43 的残差和 Jacobian 所需的时间。

CostFunctionTime (ns)

Rat43Analytic

255

Rat43AnalyticOptimized

92

Rat43NumericDiffForward

262

Rat43NumericDiffCentral

517

Rat43NumericDiffRidders

3760

Rat43AutomaticDiff

129

我们可以使用自动微分 (Rat43AutomaticDiff)获得精确导数,所需的工作量与编写数值微分代码差不多,但只比优化后的解析导数慢 40%。那么它是如何工作的呢?为此,我们必须了解 Dual Numbers and Jets

Dual Numbers & Jets

Reading this and the next section on implementing Jets is not necessary to use automatic differentiation in Ceres Solver. But knowing the basics of how Jets work is useful when debugging and reasoning about the performance of automatic differentiation.

二元数是实数的扩展,类似于复数:复数通过引入虚数单位 ι\iota (使得 ι2=1\iota^{2}=-1 )来扩展实数,而二元数则引入无穷小单位 ϵ\epsilon (使得 ϵ2=0\epsilon^{2}=0 )。对偶数 a+vϵa+v \epsilon 有两个分量,即实数分量 aa 和无穷小分量 vv。这一简单的变化带来了计算精确导数的便捷方法,而无需操作复杂的符号表达式。例如如下函数

f(x)=x2f(x)=x^{2}

有:

f(10+ϵ)=(10+ϵ)2=100+20ϵ+ϵ2=100+20ϵ\begin{aligned} f(10+\epsilon) & =(10+\epsilon)^{2} \\ & =100+20 \epsilon+\epsilon^{2} \\ & =100+20 \epsilon \end{aligned}

请注意,ϵ\epsilon 的系数是 Df(10)=20D f(10)=20 。事实上,这也适用于非多项式函数。考虑一个任意可微函数 f(x)f(x) 。那么我们可以通过考虑 ffxx 附近的泰勒展开求出 f(x+ϵ)f(x+\epsilon),从而得到无穷级数

f(x+ϵ)=f(x)+Df(x)ϵ+D2f(x)ϵ22+D3f(x)ϵ36+f(x+ϵ)=f(x)+Df(x)ϵ\begin{array}{l} f(x+\epsilon)=f(x)+D f(x) \epsilon+D^{2} f(x) \frac{\epsilon^{2}}{2}+D^{3} f(x) \frac{\epsilon^{3}}{6}+\cdots \\ f(x+\epsilon)=f(x)+D f(x) \epsilon \end{array}

这里我们使用了 ϵ2=0\epsilon^{2}=0 这一假设。

Jet 是一个 nn 维的对偶数,我们用 nn 个无穷小单位 ϵi,i=1,,n\epsilon_{i}, i=1, \ldots, n 来扩展实数,其性质是: i,jϵiϵj=0\forall i, j: \epsilon_{i} \epsilon_{j}=0 。那么一个 Jet 由实部 aann 维无穷小部 v\mathbf{v} 组成,即:

x=a+jvjϵjx=a+\sum_{j} v_{j} \epsilon_{j}

求和符号会变得繁琐,因此我们也只需写出

x=a+vx=a+\mathbf{v} \text {. }

进行泰勒展开可得

f(a+v)=f(a)+Df(a)v.f(a+\mathbf{v})=f(a)+D f(a) \mathbf{v} .

类似的对于多元函数 f:RnRmf: \mathbb{R}^{n} \rightarrow \mathbb{R}^{m} , xi=ai+vi,i=1,,nx_{i}=a_{i}+\mathbf{v}_{i}, \forall i=1, \ldots, n :

f(x1,,xn)=f(a1,,an)+iDif(a1,,an)vif\left(x_{1}, \ldots, x_{n}\right)=f\left(a_{1}, \ldots, a_{n}\right)+\sum_{i} D_{i} f\left(a_{1}, \ldots, a_{n}\right) \mathbf{v}_{i}

因此,如果每个 vi=ei\mathbf{v}_{i}=e_{i} 都是第 ith i^{\text {th }} 个标准基向量,那么上述表达式将简化为

f(x1,,xn)=f(a1,,an)+iDif(a1,,an)ϵif\left(x_{1}, \ldots, x_{n}\right)=f\left(a_{1}, \ldots, a_{n}\right)+\sum_{i} D_{i} f\left(a_{1}, \ldots, a_{n}\right) \epsilon_{i}

我们可以通过检查 ϵi\epsilon_{i} 的系数来提取 Jacobian 。

Implementing Jets

为了使上述公式在实际应用中发挥作用,我们需要对任意函数 ff 进行求值,这不仅包括对实数的求值,还包括对二元数的求值,但我们通常不会通过对函数的泰勒展开式进行求值。这就是 C++ 模板和操作符重载发挥作用的地方。下面的代码片段简单地实现了 Jet 和一些对其进行操作的运算符/函数。

template<int N> struct Jet {
  double a;
  Eigen::Matrix<double, 1, N> v;
};

template<int N> Jet<N> operator+(const Jet<N>& f, const Jet<N>& g) {
  return Jet<N>(f.a + g.a, f.v + g.v);
}

template<int N> Jet<N> operator-(const Jet<N>& f, const Jet<N>& g) {
  return Jet<N>(f.a - g.a, f.v - g.v);
}

template<int N> Jet<N> operator*(const Jet<N>& f, const Jet<N>& g) {
  return Jet<N>(f.a * g.a, f.a * g.v + f.v * g.a);
}

template<int N> Jet<N> operator/(const Jet<N>& f, const Jet<N>& g) {
  return Jet<N>(f.a / g.a, f.v / g.a - f.a * g.v / (g.a * g.a));
}

template <int N> Jet<N> exp(const Jet<N>& f) {
  return Jet<T, N>(exp(f.a), exp(f.a) * f.v);
}

// This is a simple implementation for illustration purposes, the
// actual implementation of pow requires careful handling of a number
// of corner cases.
template <int N>  Jet<N> pow(const Jet<N>& f, const Jet<N>& g) {
  return Jet<N>(pow(f.a, g.a),
                g.a * pow(f.a, g.a - 1.0) * f.v +
                pow(f.a, g.a) * log(f.a); * g.v);
}

有了这些重载函数,我们现在就可以使用 Jets 数组而不是 double 数组来调用 Rat43CostFunctor 了。再加上适当初始化的 Jets,我们就可以计算出 Jacobian 如下:

class Rat43Automatic : public ceres::SizedCostFunction<1,4> {
 public:
  Rat43Automatic(const Rat43CostFunctor* functor) : functor_(functor) {}
  virtual ~Rat43Automatic() {}
  virtual bool Evaluate(double const* const* parameters,
                        double* residuals,
                        double** jacobians) const {
    // Just evaluate the residuals if Jacobians are not required.
    if (!jacobians) return (*functor_)(parameters[0], residuals);

    // Initialize the Jets
    ceres::Jet<4> jets[4];
    for (int i = 0; i < 4; ++i) {
      jets[i].a = parameters[0][i];
      jets[i].v.setZero();
      jets[i].v[i] = 1.0;
    }

    ceres::Jet<4> result;
    (*functor_)(jets, &result);

    // Copy the values out of the Jet.
    residuals[0] = result.a;
    for (int i = 0; i < 4; ++i) {
      jacobians[0][i] = result.v[i];
    }
    return true;
  }

 private:
  std::unique_ptr<const Rat43CostFunctor> functor_;
};

事实上,自动求导就是这么工作的。

Pitfalls

自动微分将用户从计算和推理雅可比符号表达式的负担中解脱出来,但这是有代价的。例如,请看下面这个简单的函数:

y=1x02+x12y =1-\sqrt{x_{0}^{2}+x_{1}^{2}} \\

对应的 Functor 构建为

struct Functor {
  template <typename T> bool operator()(const T* x, T* residual) const {
    residual[0] = 1.0 - sqrt(x[0] * x[0] + x[1] * x[1]);
    return true;
  }
};

从残差计算的代码来看,我们并没有发现到任何问题。但是,如果我们看一下 Jacobian 的解析表达式,就会发现:

D1y=x0x02+x12,D2y=x1x02+x12\begin{aligned} D_{1} y & =-\frac{x_{0}}{\sqrt{x_{0}^{2}+x_{1}^{2}}}, D_{2} y=-\frac{x_{1}}{\sqrt{x_{0}^{2}+x_{1}^{2}}} \end{aligned}

我们发现它在 x1=0,x2=0x_1=0,x_2=0 时,导数是无法计算的。这个问题没有固定的解决方案。在某些情况下,我们需要明确推理可能出现不确定性的点,并使用 L’Hôpital’s rule 来替代表达式(例如,请参阅 rotation.h 中的一些转换例程)。在其他情况下,可能需要对表达式进行正则化处理,以消除这些点。

Last updated