OnnxSimplifier是一个用于简化onnx模型的工具,主要工具就是:拥有折叠常量(FoldConstant)的功能、自动调度OnnxOptimizer,最为重要而且核心的是FixedPointFn这个简化调度算法。
OnnxOptimizer是一个onnx官方的一个onnx模型优化库,内部包含很多模型简化/优化的功能。用户也可直接通过python/c++/c api执行调用,但是需要比较了解内部的opt优化手段,才能够得到理想的结果。
从上述的描述来看,似乎OnnxSimplifier也没有干什么事情,因为OnnxOptimizer才是干简化模型的主要工具。但是OnnxSimplifier主要有以下的几点主要优点和必要性让其比较突出:
基本原理:就是通过两个优化函数,反复迭代优化中得到了最终无法继续优化的最终模型。
FixedPointFn的原始代码如下:
template <typename T>
std::function<T(const T&)> FixedPointFn(const std::function<T(const T&)>& f1,
const std::function<T(const T&)>& f2,
size_t max_iters, bool* converged) {
return [f1, f2, max_iters, converged](const T& x) {
size_t _max_iters = max_iters;
T tmp1 = f1(x);
T tmp2 = f2(tmp1);
T& y1 = tmp1;
T& y2 = tmp2;
while (_max_iters-- > 0) {
// 超出迭代次数则跳出
if (google::protobuf::util::MessageDifferencer::Equals(y1, y2)) {
// f1(x) == f2(f1(x))时,则无法继续优化,直接返回f2(f1(x))
if (converged) {
*converged = true;
}
return y2;
}
y1 = f1(y2);
if (google::protobuf::util::MessageDifferencer::Equals(y1, y2)) {
if (converged) {
*converged = true;
}
return y1;
}
y2 = f2(y1);
}
if (converged) {
*converged = false;
}
return y2;
};
}
FixedPointFn的流程图如下所示:
这次主要是介绍了OnnxSimplifier简化原理,重点介绍了FoldConstant功能和FixedPointFn迭代优化函数,这是该简化包的核心部分了。但是对于其他的OptimizeFixed,也就是OnnxOptimizer函数库的内部简化细节却没有具体的说明。后续将会具体介绍OnnxOptimizer的模型优化细节。