一、Lower 主邏輯
在 OpLower::Lower()
接口中,主要分爲兩大類:
- Elementwise類,主要涉及的
OpPattern
包括:kElementwise
、kBroadcast
、kInjective
- Reduce 類,主要涉及的OpPattern包括:
kReduction
std::vector<ir::LoweredFunc> OpLowerer::Lower(GroupPtr& group) {
VLOG(3) << "Lowering Group : " << group->group_id << " , Op Pattern : " << group->op_pattern_kind;
group->input_names.clear();
group->output_names.clear();
if (FLAGS_cinn_ir_schedule) {
switch (group->op_pattern_kind) {
case framework::kElementWise:
case framework::kBroadcast:
case framework::kInjective:
return IRLowerOp(&OpLowerer::IRElementwiseCompute, &OpLowerer::IRElementwiseSchedule, group); // << --- 第一大類 Elementwise相關
case framework::kReduction:
return IRLowerOp(&OpLowerer::IRReduceCompute, &OpLowerer::IRReduceSchedule, group); // << --- 第二大類 Reduce 相關
case framework::kOutFusible:
LOG(FATAL) << "Group Pattern Kind kOutFusible Is Not Implemented!";
case framework::kNonFusible:
return IRLowerNonFusibleOp(group, /*apply_impl_schedule = */ true);
default:
LOG(FATAL) << "Group Pattern Kind Is Unknown!";
}
} else {
LOG(FATAL) << "Previous IR Schedule Is Not Implemented!";
}
}
二、Optimize 邏輯
在 op_lowering.cc
中的 IRLowerOp
的最後,會創建一個 LoweredFunc
對象,並對其調用 optim::Optimize()
函數。
std::vector<ir::LoweredFunc> OpLowerer::IRLowerOp(compute, schedule, group){
// .... 省略
auto temp_buffers = lang::GetTempBuffers(arg_tensors, stages, func_body);
auto func =
ir::_LoweredFunc_::Make(group->GetFuncName(), func_args, ir_sch.GetModule().GetExprs().at(0), temp_buffers);
func = optim::Optimize(Expr(func), target_, false).as_lowered_func_ref(); // <----在函數最後會調用 Optimizer 對函數表達式進行優化,注意與 target 相關
return {func};
}
其接口實現是在 optimize.cc
文件中,主要是對 LoweredFunc
對應的Expr應用各種優化 pass:
// 中層是 optimize.cc 中的:
Expr Optimize(Expr e, Target target, bool runtime_debug_info, bool remove_gpu_for_loops) {
auto copied = IRCopy(e);
// 與 target 無關的通用優化
FoldCINNCallArguments(&copied);
TransformPolyForToFor(&copied);
ReplaceConstParamToInteger(&copied);
CastSimplify(&copied);
Simplify(&copied);
UnrollLoop(&copied);
// 與 target 有關的優化
VectorizeLoops(&copied, target);
MapExternCall(&copied, target); // <---- 此處是這裏要關注和討論的 MapExternCall 優化
// 僅在 CUDA 上的優化
ir::SetCudaAxisInfo(&copied);
RemoveGpuForloopsAxis(&copied);
CudaSyncThreadsDropIfThenElse(&copied);
// 又是與 target 無關的通用優化
RemoveNestedBlock(&copied);
ExternCallMultiOutputShallowStore(&copied);
CastSimplify(&copied);
Simplify(&copied);
IfSimplify(&copied);
// 與調試相關通用優化
InsertDebugLogCallee(&copied);
}
三、各個優化Pass
接下來,我們逐個來研究每個 pass 的角色和作用。
3.1 FoldCINNCallArguments
此 Pass 的功能是通過 FoldCINNCallArgumentsMutator
來實現的:
void FoldCINNCallArguments(Expr* expr) { FoldCINNCallArgumentsMutator()(expr); }
此 Mutator 只關心ir::Block和ir::Store兩種類型節點:
struct FoldCINNCallArgumentsMutator : public ir::IRMutator<> {
void operator()(Expr* expr) { ir::IRMutator<>::Visit(expr, expr); }
private:
void Visit(const ir::Block* op, Expr* expr); // <----- Block
void Visit(const ir::Store* op, Expr* expr); // <----- Store
void MutateCall(ir::Call* call);
private:
// To avoid the same call triggered duplicately.
std::unordered_set<std::string> visited_call_;
};
其中對於 ir::Block
類型的節點,找到所有的 CallType::CINN
的 Expr,然後判斷是否已經存在過,若是,則刪除 Block 中的此 statement 語句。
這裏補充下 ir::Call
節點中的 CallType
枚舉類型的值都有哪些:CINN
、Intrinsic
、Extern
、ISL
void Visit(const ir::Block* op, Expr* expr) override {
auto* node = expr->As<ir::Block>();
for (auto it = node->stmts.begin(); it != node->stmts.end();) {
if (it->As<ir::Store>()) {
auto* call = it->As<ir::Store>()->value.As<ir::Call>(); // <---- 針對 x = cinn_call_func(args) 場景?
if (call && call->is_cinn_call()) {
// remove the duplicate calls.
std::string key = utils::GetStreamCnt(Expr(call));
if (visited_call_.count(key)) {
it = node->stmts.erase(it);
continue;
}
ir::IRMutator<>::Visit(&(*it), &(*it)); // <--- 這裏會觸發下面的 ir::Store 的處理邏輯
visited_call_.insert(key);
continue;
}
}
ir::IRMutator<>::Visit(&(*it), &(*it));
++it;
}
}
對於 ir::Store
類型的節點,僅針對 CallType::CINN
類型的節點調用 MutateCall
函數進行修改和替換;
void Visit(const ir::Store* op, Expr* expr) override {
auto* node = expr->As<ir::Store>();
if (node->value.As<ir::Call>()) {
auto* call = node->value.As<ir::Call>();
switch (call->call_type) {
case ir::CallType::CINN:
MutateCall(call);
*expr = node->value;
break;
case ir::CallType::Intrinsic:
break;
case ir::CallType::Extern:
break;
default:
CINN_NOT_IMPLEMENTED
}
}
}
MuteCall
函數是此 Pass 的最核心邏輯,其作用是 call 節點中所有的輸入、輸出 args 中的 Tensor 類型,確認其都 defined 了 buffer ,並將 buffer 作爲真正的 args 替換原來的 read_args 和 write_args 。
思考:爲什麼要單獨對CINN類型的CallType多做這樣一件事情呢?背景是什麼?
void MutateCall(ir::Call* call) {
if (call->call_type == ir::CallType::Extern) return;
std::vector<Expr> read_args;
std::vector<Expr> write_args;
for (auto& arg : call->read_args) {
if (arg.as_tensor()) {
CHECK(arg.as_tensor()->buffer.defined()) << "arg tensor [" << arg.as_tensor()->name << "] not has buffer";
read_args.push_back(arg.as_tensor()->buffer);
} else {
read_args.push_back(arg);
}
}
for (auto& arg : call->write_args) {
if (arg.as_tensor()) {
write_args.push_back(arg.as_tensor()->buffer);
} else {
write_args.push_back(arg);
}
}
call->read_args = read_args;
call->write_args = write_args;
}
3.2 ReplaceConstParamToInteger
這個 Pass 比較簡單,只針對 ir::Var
類型的節點,如果其 name 是以 _const_
開頭的,則取其具體的值,轉爲Expr
(如 Intmm)
static const char* kIslParamConstPrefix = "_const_";
struct Mutator : public ir::IRMutator<> {
using ir::IRMutator<>::Visit;
void Visit(const ir::_Var_* op, Expr* expr) override {
if (utils::Startswith(op->name, poly::kIslParamConstPrefix)) {
std::string value = op->name.substr(strlen(poly::kIslParamConstPrefix));
*expr = Expr(std::stoi(value)); // <----- 這裏強轉爲 int 類型,是隻存在類似 _const_12 這種情況
}
}
};
} // namespace
void ReplaceConstParamToInteger(Expr* e) {
Mutator mutator;
mutator.Visit(e, e);
}
那這個const
前綴字符串拼接是在哪裏做的呢?是在 cinn::poly::ast_gen
中做的,相關邏輯代碼如下:
isl::set TransIdentityExtentToContextId(isl::set set) {
std::vector<std::tuple<int, int>> iden_dim_offsets;
for (int i = 0; i < isl_set_dim(set.get(), isl_dim_set); i++) {
if (isl_set_axis_has_noparam_constant_bound(set.get(), i)) {
auto range = isl_set_get_axis_range(set.get(), i);
auto& minv = std::get<0>(range);
auto& maxv = std::get<1>(range);
int min_iv = minv.get_num_si();
int max_iv = maxv.get_num_si();
if (max_iv == min_iv) {
iden_dim_offsets.emplace_back(i, max_iv);
}
}
}
isl::set res_set = set;
for (auto offset_val : iden_dim_offsets) {
auto& offset = std::get<0>(offset_val);
auto& val = std::get<1>(offset_val); // <---- 是個 int 類型
res_set = isl::manage(isl_set_drop_constraints_involving_dims(res_set.copy(), isl_dim_set, offset, 1));
std::string const_param_name = llvm::formatv("{0}{1}", kIslParamConstPrefix, val); //<---- 在此處進行拼接的
std::string cond_str = llvm::formatv(
"{0} <= {1} <= {2}", val, isl_set_get_dim_name(res_set.get(), isl_dim_set, offset), const_param_name);
std::string param_cond_str = llvm::formatv("{0} <= {1} < {2}", val, const_param_name, val + 2);
std::string set_repr = llvm::formatv("[{0}] -> { {1}[{2}]: {3} and {4} }",
const_param_name,
isl_set_get_tuple_name(res_set.get()),
utils::Join(isl_get_dim_names(res_set.get()), ","),
cond_str,
param_cond_str);
VLOG(4) << "repr: " << set_repr;
isl::set new_set(res_set.ctx(), set_repr);
res_set = res_set.intersect(new_set);
}
return res_set;
}
注:通過檢索Bert 模型中的GLOG_v=10 的日誌,並沒有發現
ReplaceConstParamToInteger
生效的地方。
如下是一個 CINN 框架裏的單測,可以輔助幫助理解上面這個函數的作用效果,樣例中 j=0
,其中會把 0 這個常量值先創建一個 _const_0
,然後做了變換?
TEST(TransIdentityExtentToContextId, basic) {
isl_ctx* ctx = isl_ctx_alloc();
isl::set set(ctx, "{ s[i,j=0,k] : 0<=i<12 and 12<k<32 }");
auto new_set = TransIdentityExtentToContextId(set);
LOG(INFO) << new_set;
ASSERT_EQ(utils::GetStreamCnt(new_set),
"[_const_0] -> { s[i, j, k] : _const_0 <= 1 and 0 <= i <= 11 and 0 <= j <= _const_0 and 13 <= k <= 31 }");
}
3.3 CastSimplify
此Pass 僅會對 constant 的Expr進行處理,比如 IntImm、FloatImm、UIntImm,作用是將其持有的value值按照 ir::Cast.type()
進行數值類型轉換,然後包裹一個Expr重新返回。
void CastSimplify(Expr* e) {
Mutator mutator;
mutator.Visit(e, e);
}
struct Mutator : ir::IRMutator<> {
using ir::IRMutator<>::Visit;
void Visit(const ir::Cast* op, Expr* expr) {
auto* node = expr->As<ir::Cast>();
Visit(&node->v(), &node->v()); // <<--- 類似 AST 的 generic_visit,深度優先遞歸處理 node->v() 節點
if (op->type() == op->v().type()) {
*expr = op->v(); // Caset 1: 如果 value 類型已經與dst_type 一致,則直接返回 node->v() 以替換當前節點。
return;
}
#define __CAST_TO_TYPE(type__) \
if (auto* i = op->v().As<ir::IntImm>()) { \
*expr = Expr(static_cast<type__>(i->value)); \
} else if (auto* f = op->v().As<ir::FloatImm>()) { \
*expr = Expr(static_cast<type__>(NormCastValue<type__>(f->value))); \ // <<----- 注意:這裏對Float類型進行了特殊處理,因爲存在轉Float16的場景
} else if (auto* u = op->v().As<ir::UIntImm>()) { \
*expr = Expr(static_cast<type__>(u->value)); \
} else { \
CINN_NOT_IMPLEMENTED \
}
if (op->v().is_constant()) { // <----- 注意:此pass僅支持 ir::Cast->v()爲常量類型的場景
if (op->type() == type_of<int8_t>()) {
__CAST_TO_TYPE(int8_t)
} else if (op->type() == type_of<int16_t>()) {
__CAST_TO_TYPE(int16_t)
} else if (op->type() == type_of<int32_t>()) {
__CAST_TO_TYPE(int32_t)
} else if (op->type() == type_of<int64_t>()) {
__CAST_TO_TYPE(int64_t)
} else if (op->type() == type_of<uint8_t>()) {
__CAST_TO_TYPE(uint8_t)
} else if (op->type() == type_of<uint16_t>()) {
__CAST_TO_TYPE(uint16_t)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<float>()) {
__CAST_TO_TYPE(float)
} else if (op->type() == type_of<double>()) {
__CAST_TO_TYPE(double)
} else if (op->type() == type_of<bool>()) {
__CAST_TO_TYPE(bool)
} else if (op->type() == type_of<uint32_t>()) {
__CAST_TO_TYPE(uint32_t)
} else if (op->type() == type_of<uint64_t>()) {
__CAST_TO_TYPE(uint64_t)
} else if (op->type() == type_of<float16>()) {
// Cannot simplify!!! pass
__CAST_TO_TYPE(float16)
} else {
CINN_NOT_IMPLEMENTED
}
}
#undef __CAST_TO_TYPE
}
};
在上面流程代碼中,我們可以看出對於 FloatImm
類型的處理額外借助了 NormCastValue
這個函數,原因是對於 Float32 到 Float16 的轉寫,要考慮上溢、下溢、Nan
、Inf
的場景:
template <typename CastType, typename T>
CastType NormCastValue(T value) {
if (type_of<CastType>().is_uint() || type_of<T>().is_uint()) {
// not support uint
return static_cast<CastType>(value);
}
if (std::isinf(value)) {
return std::numeric_limits<CastType>::infinity();
} else if (std::isnan(value)) {
return std::numeric_limits<CastType>::signaling_NaN();
} else if (value >= static_cast<T>(std::numeric_limits<CastType>::max())) {
return std::numeric_limits<CastType>::max();
} else if (value <= static_cast<T>(std::numeric_limits<CastType>::lowest())) {
return std::numeric_limits<CastType>::lowest();
}
return static_cast<CastType>(value);
}
3.4 Simplify
這個 Pass 包括的子邏輯比較多,單測文件 ir_simplify_test.cc
裏可以幫助理解效果:
void Simplify(Expr* expr) {
optim::CastSimplify(expr); // 先調用了 CastsSimplify,這個似乎會比較多餘?在遞歸調用時更會導致效率低下
SimplifyRampMutator()(expr);
SimplifyLoadMutator()(expr);
SimplifyStoreMutator()(expr);
SimplifyIfThenElseMutator()(expr);
common::cas_intervals_t var_intervals;
SimplifyButStoreLoadMutator mutator(var_intervals); // 又額外來了一遍,這裏似乎也比較低效?
mutator(expr);
ReplaceFracWithDivMutator()(expr); // 這裏將 ir::Frac 替換爲了 ir::Div,似乎也不是必要的,沒有看到哪裏構造了 ir::Frac
}
效果:
// case 1:
C = 1. //shape = [100, 20]
B = C[i, 0] + 1 * 0 + 100 + 24.5
// 經過此 Pass 後:
B = C[i, 0] + 124.5
// case 2:
{
serial for (i, 0, 100)
{
serial for (j, 0, 20)
{
B[i, j] = (X[i + 0, j + 0] + Y[i, j * 0] * 1.f + 0.f * X[i, j] + 25.f + 100.f - 0.f +
9.f * 10000.f * 1.f * 1.f * 0.f
)
}
}
}
// 經過此 Pass 後:
{
serial for (i, 0, 100)
{
serial for (j, 0, 20)
{
B[i, j] = (125.000000f + (X[i, j] + y[i, 0]))
}
}
}
首先看 SimplifyRampMutator
的角色作用,從源碼上來看,只關心兩種節點:ir::Ramp
和 ir::Add
。
- 對於
ir::Add
節點,如果兩個操作數都是ir::Ramp
類型,且其 lanes 屬性值是一樣的話,則會構建一個ir::Ramp
節點來替換掉ir::Add
節點 - 對於
ir::Ramp
節點,則遞歸對其base
和stride
屬性調用Simplify
函數。
struct SimplifyRampMutator : public ir::IRMutator<Expr*> {
void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }
void Visit(const Ramp* op, Expr* expr) override {
auto* node = expr->As<ir::Ramp>();
CHECK(common::IsPureMath(node->base)) << node->base << "is not a pure math!";
CHECK(common::IsPureMath(node->stride)) << node->stride << "is not a pure math!";
;
Simplify(&node->base);
Simplify(&node->stride);
}
// ramp + ramp
void Visit(const Add* op, Expr* expr) override {
auto* node = expr->As<ir::Add>();
Expr a = node->a();
Expr b = node->b();
auto a_ramp = a.As<ir::Ramp>();
auto b_ramp = b.As<ir::Ramp>();
if (a_ramp && b_ramp && a_ramp->lanes == b_ramp->lanes) {
Expr base_add = common::AutoSimplify(a_ramp->base + b_ramp->base); // 這裏會做CAS
Expr stride_add = common::AutoSimplify(a_ramp->stride + b_ramp->stride);
*expr = ir::Ramp::Make(base_add, stride_add, a_ramp->lanes);
}
}
};
我們這裏瞅一眼 ir::Ramp
節點是什麼樣子的:
//! A linear ramp node.
struct Ramp : public ExprNode<Ramp> {
Expr base, stride;
int lanes;
static Expr Make(Expr base, Expr stride, int lanes);
void Verify() const override;
static const IrNodeTy _node_type_ = IrNodeTy::Ramp;
};
接下來看第二個 SimplifyLoadMutator
的角色,簡單理解就是對 X[i+0, j+0]
以及 Y[i, j*0]
進行優化,得到 X[i, j]
,Y[i, 0]
:
struct SimplifyLoadMutator : public ir::IRMutator<ir::Expr*> {
void operator()(Expr* x) { ir::IRMutator<ir::Expr*>::Visit(x, x); }
void Visit(const Load* expr, Expr* op) override {
auto* node = op->As<Load>();
for (auto& idx : node->indices) {
if (common::IsPureMath(idx)) {
PartialSimplify(&idx, var_intervals_); // << 也是藉助了CAS了
} else {
SimplifyButStoreLoadMutator mutator(var_intervals_); // 根據節點類型,分發調用 PartialSimplify 函數
mutator(&idx);
}
}
}
void Visit(const For* op, Expr* expr) override {
auto* min_i = op->min.As<IntImm>();
auto* extent_i = op->extent.As<IntImm>();
if (min_i && extent_i && extent_i->value > min_i->value) {
var_intervals_.emplace(op->loop_var->name, common::CasInterval{min_i->value, extent_i->value - 1});
}
auto* node = expr->As<For>();
operator()(&node->body);
operator()(&node->extent);
if (min_i && extent_i) {
var_intervals_.erase(op->loop_var->name);
}
}
common::cas_intervals_t var_intervals_;
};
第三個 SimplifyStoreMutator
的代碼邏輯基本與 SimplifyLoadMutator
一致,這裏我們不再贅述。
第四個 SimplifyIfThenElseMutator
,這個也很好理解,對 condition
調用 CAS 邏輯:
struct SimplifyIfThenElseMutator : public ir::IRMutator<> {
void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
using ir::IRMutator<>::Visit;
void Visit(const IfThenElse* op, Expr* expr) override {
auto* node = expr->As<ir::IfThenElse>();
node->condition = common::AutoSimplify(node->condition); // 核心點
if (node->true_case.defined()) Visit(&node->true_case, &node->true_case); // 訪問者模式分發
if (node->false_case.defined()) Visit(&node->false_case, &node->false_case); // 訪問者模式分發
}
};
第五個 SimplifyButStoreLoadMutator
本來在第二、三個子邏輯會局部觸發,這裏爲何單獨觸發了一遍?從函數實現了是對其他節點都遍歷一遍進行簡化處理,唯獨除了 Store
和 Load
節點(因爲這兩個節點主要出現在 ir::For
節點中)
第六個 ReplaceFracWithDivMutator
,這個很有意思,是把所有的 ir::FracOp
替換爲 ir::Div
,這兩個不一樣麼?仔細看了下,在一些 CodeGen
模塊裏,如 codegen_llvm.cc
中,是沒有實現 ir::FracOp
裏的代碼生成邏輯的,只有 ir::Div
實現了。那爲什麼不直接把 ir::FracOp
節點刪除呢?
struct ReplaceFracWithDivMutator : public ir::IRMutator<> {
void operator()(Expr* x) { ir::IRMutator<>::Visit(x, x); }
void Visit(const FracOp* op, Expr* expr) override {
auto* node = expr->As<ir::FracOp>();
ir::IRMutator<>::Visit(&node->operand(0), &node->operand(0));
ir::IRMutator<>::Visit(&node->operand(1), &node->operand(1));
*expr = ir::Div::Make(node->operand(0), node->operand(1));
}
};
llvm::Value *CodeGenLLVM::Visit(const ir::FracOp *op) { __IR_EMITTER_NOT_IMPLEMENTED(op); }
在CINN裏我檢索了 lang/pe
等模塊源碼,沒有看到在 IR 層面直接使用或構造 ir::Frac
節點的代碼,只有單測和不相關模塊:
3.5 MapExternCall 邏輯
從調用棧來看, 底層是 map_extern_call.cc
中具體的 MapExternCall
的實現
void MapExternCall(Expr *e, Target target) {
Mutator m(target);
m(e);
}
所有的工作都是交給基於 Ast 的 Mutator 來做的,原理:藉助「訪問者模式」僅識別和處理 ir::Call
對象:
void Visit(const ir::Call *op, Expr *expr) override {
auto *node = expr->As<ir::Call>();
CHECK(node);
OptimizeConstantPow(node);
if (target.arch == Target::Arch::NVGPU) {
DealWithNvGpuintrinsics(node, expr);
} else {
DealWithCpuintrinsics(node, expr);
}
}
我們比較關心 CUDA 上的變換,進一步看 DealWithNvGpuintrinsics
函數:
void DealWithNvGpuintrinsics(ir::Call *node, Expr *expr) {
auto arg_size = node->read_args.size();
if (arg_size == 0UL) {
// some node like __syncthreads hasn't arguments
return;
}
const auto &dtype = node->read_args.front().type();
const auto &name = node->name;
bool node_in_extern_fp32 = kExternFp32CallsGPU.count(name);
bool node_in_extern_int32 = kExternInt32CallsGPU.count(name);
if (!node_in_extern_fp32 && !node_in_extern_int32) {
return;
}
std::string suffix;
if (dtype.is_int() && node_in_extern_int32) {
if (dtype.is_int(32)) {
suffix = "_int32";
} else if (dtype.is_int(64)) {
suffix = "_int64";
}
} else if (dtype.is_float() && node_in_extern_fp32) {
if (dtype.is_float(64)) {
suffix = "_fp64";
} else if (dtype.is_float(32)) {
suffix = "_fp32";
} else if (dtype.is_float(16)) {
suffix = "_fp16";
}
}
CHECK(!suffix.empty()) << name << " not support data type " << dtype;
std::string extern_func = "cinn_nvgpu_" + name + suffix; // <------ 主要是按照OpNode白名單+dtype拼接要替換的外部 API (其實也是在CINN層裏wrapper註冊了一層)
*expr = lang::CallExtern(extern_func, node->read_args); // 直接替換 ir::Call 對應的 Expr 對象
}