From c97f3dfdace6e7dbeec4c7ae78a40896cab7415a Mon Sep 17 00:00:00 2001 From: ridethepig Date: Sun, 18 Jun 2023 18:55:28 +0800 Subject: [PATCH] block merge --- .vscode/launch.json | 2 +- include/common.h | 1 + src/main.cpp | 11 ++- src/pass_dce.cpp | 150 ++++++++++++++++++++++++++++++++++----- src/visitor_llir_gen.cpp | 5 -- 5 files changed, 140 insertions(+), 29 deletions(-) diff --git a/.vscode/launch.json b/.vscode/launch.json index b72ef30..52c21c8 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,7 +10,7 @@ "name": "Debug", "program": "${workspaceFolder}/build/sysy", // "args": ["../sysytests/functional_2021/069_greatest_common_divisor.sy", "-S", "-o", "build/my.s", "-O1", "-emit-llvm"], - "args" : ["-S", "../sysytests/functional_2022/38_op_priority4.sy", "-o", "build/dbg.s", "-emit-llvm", "-O1",], + "args" : ["-S", "../sysytests/functional_2022/21_if_test2.sy", "-o", "build/dbg.s", "-emit-llvm", "-O1",], "cwd": "${workspaceFolder}" }, ] diff --git a/include/common.h b/include/common.h index b970913..d52bd45 100644 --- a/include/common.h +++ b/include/common.h @@ -20,6 +20,7 @@ #define uptr(type) std::unique_ptr #define sptr(type) std::shared_ptr +#define wptr(type) std::weak_ptr template inline sptr(DST) shared_cast(SRC src) { diff --git a/src/main.cpp b/src/main.cpp index bfd49d9..4b1ee4b 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -113,14 +113,13 @@ int main(int argc, const char **argv) { } } - std::vector optpasses = { - std::make_shared(), - std::make_shared(), - std::make_shared(), - }; + std::vector optpasses; + optpasses.push_back(std::move(std::make_unique())); + optpasses.push_back(std::move(std::make_unique())); + optpasses.push_back(std::move(std::make_unique())); if (flg_O1) { - for (auto pass : optpasses) { + for (auto &pass : optpasses) { pass->run(visitor.module); } } diff --git a/src/pass_dce.cpp b/src/pass_dce.cpp index 0f3ec56..9b57239 100644 --- a/src/pass_dce.cpp +++ b/src/pass_dce.cpp @@ -67,6 +67,7 @@ bool Elim_Unreach_Code(Function *func) { // Operations defined as critical: I/O statements, linkage code (entry & exit blocks), return values, calls to other // procedures From Efficiently Computing SSA Form and the Control Dependence Graph, Zadeck 1991 +// comp 512 Lecture10 gives a clear description bool Elim_Dead_Code(Function *func) { CFG RCFG; RCFG.build_from_func(func, true); @@ -153,14 +154,7 @@ bool Elim_Dead_Code(Function *func) { // rewrite branch to jmp inst_br->u_remove_from_usees(); inst_br->operand_list.clear(); - sptr(BasicBlock) tmp = nullptr; - for (auto _bb : func->bb_list) { - if (_bb.get() == useful_pdom->bb) { - tmp = _bb; - break; - } - } - inst_br->add_operand(tmp); + inst_br->add_operand(*useful_pdom->bb->itr); } else { changed = true; @@ -174,22 +168,144 @@ bool Elim_Dead_Code(Function *func) { return changed; } -// 高级编译器18.2 +static std::vector post_order; +static std::unordered_set visit; + +static void get_post_order(BasicBlock* rt) { + visit.insert(rt); + for (auto succ : rt->succ_list) { + if (!INSET(visit, succ.get())) { + get_post_order(succ.get()); + } + } + post_order.push_back(rt); +} + + + +// 高级编译器18.2, comp512 Lecture10 gives a clearer version bool Fuse_Block(Function *func) { + auto is_branch = [](Instruction* inst) { + return dynamic_cast(inst) && inst->operand_list.size() == 3; + }; + visit.clear(); post_order.clear(); + get_post_order(func->bb_list.front().get()); + for (auto bb : post_order) { + LOG(DEBUG) << "visit " << bb->to_string(); + if (is_branch(bb->inst_list.back().get())) { + auto inst = bb->inst_list.back().get(); + if (inst->operand_list[1] == inst->operand_list[2]) { + auto target = shared_cast(inst->operand_list[1]); + if (shared_cast(target->inst_list.front())) continue; + auto pred_itr = STD_FIND(target->pred_list, *bb->itr); + target->pred_list.erase(pred_itr); + inst->u_remove_from_usees(); + inst->operand_list.clear(); + inst->add_operand(target); + bb->succ_list.clear(); + bb->succ_list.push_back(target); + } + } + else if (auto inst = dynamic_cast(bb->inst_list.back().get())) { + auto target = shared_cast(inst->operand_list[0]); + assert(target); + // branch has side-effect on phi nodes, dealing with that is beyond this shabby compiler's capability, thus simply ignore + if (shared_cast(target->inst_list.front())) continue; + if (bb->inst_list.size() == 1) { + /* bb_i -> bb_j + bb_i has only one jump, and bb_j can be any + 1. replace bb_i with bb_j in all of its preds' succ_list + 2. replace bb_i with bb_i's preds in bb_j's pred_list + 3. if bb_j starts with phi, replace entry associated with bb_i with the same value, copy may be needed + */ + if (bb == func->bb_list.front().get()) { + LOG(WARNING) << "block merge reach entry, skipped due to possible side effect"; + // bb->succ_list.clear(); + // bb->inst_list.back()->u_remove_from_usees(); + // bb->inst_list.erase(bb->inst_list.back()->inst_itr); + // bb->parent_func->bb_list.erase(bb->itr); + // func->bb_list.erase(target->itr); + // func->bb_list.push_front(target); + // target->pred_list.clear(); + // target->itr = func->bb_list.begin(); + continue; + } + auto bb_index_target_succ = GETINDEX(target->pred_list, *bb->itr); + for (auto pred : bb->pred_list) { + // find the pos of bb in pred's succlist, and replace with target, in-place + auto itr = STD_FIND(pred->succ_list, *bb->itr); + *itr = target; + // remember to rewrite pred's jump/branch + auto op_itr = STD_FIND(pred->inst_list.back()->operand_list, *bb->itr); + *op_itr = target; + // for target, there's no need to keep this in-place, we delete it and push new + target->pred_list.push_back(pred); + } + // remove bb from target + target->pred_list.remove(*bb->itr); + // clear bb + bb->pred_list.clear(); + bb->succ_list.clear(); + bb->inst_list.back()->u_remove_from_usees(); + assert(bb->inst_list.back()->use_list.empty()); + bb->inst_list.erase(bb->inst_list.back()->inst_itr); + func->bb_list.erase(bb->itr); + } + else if (target->pred_list.size() == 1) { + /* bb_i -> bb_j, where i is the only pred of j + for simplicity, merge j into i + */ + assert(!shared_cast(target->inst_list.front())); + bb->succ_list.clear(); + for (auto succ : target->succ_list) { + bb->succ_list.push_back(succ); + auto itr = STD_FIND(succ->pred_list, target); + *itr = *bb->itr; + } + // remove bb_i's last jump + bb->inst_list.back()->u_remove_from_usees(); + bb->inst_list.pop_back(); + // clear target and transfer insts + target->succ_list.clear(); + target->pred_list.clear(); + while (!target->inst_list.empty()) { + bb->inst_list.push_back(target->inst_list.front()); + target->inst_list.pop_front(); + } + func->bb_list.erase(target->itr); + } + } + } return false; } +/* +clear some strange pattern +*/ +static void other_clear(Function* func) { + for (auto bb : func->bb_list) { + for (auto itr = bb->inst_list.begin(); itr != bb->inst_list.end(); ) { + auto inst = *itr ++; + if (shared_cast(inst)) { + if (inst->operand_list.size() == 1) { + LOG(DEBUG) << "remove trivial phi"; + inst->u_replace_users(inst->operand_list[0]); + inst->u_remove_from_usees(); + bb->inst_list.erase(inst->inst_itr); + } + } + } + } +} + void PassDCE::run(const Module &module) { LOG(INFO) << "Run pass " << pass_name; for (auto func : module.function_list) { - if (func->is_libfunc()) continue; - auto again = true; - while (again) { - again = false; - again = Elim_Dead_Code(func.get()) || again; - again = Elim_Unreach_Code(func.get()) && again; - again = Fuse_Block(func.get()) || again; - } + if (func->is_libfunc()) continue; + Elim_Dead_Code(func.get()); + Elim_Unreach_Code(func.get()); + other_clear(func.get()); + // Fuse_Block(func.get()); } } diff --git a/src/visitor_llir_gen.cpp b/src/visitor_llir_gen.cpp index c158dd6..0fd44f9 100644 --- a/src/visitor_llir_gen.cpp +++ b/src/visitor_llir_gen.cpp @@ -168,11 +168,6 @@ static void _gen_blocks(std::ostream &ostr, const std::list &bl assert(bb_dest->ir_seqno >= 0); ostr << "label %" << bb_dest->ir_seqno; } - else if (auto const_cond = shared_cast(inst->operand_list[0])) { - //! temporary patch after sccp, but it is often incorrect for phi - auto target_bb = const_cond->value ? inst->operand_list[1] : inst->operand_list[2]; - ostr << "label %" << shared_cast(target_bb)->ir_seqno; - } else { assert(shared_cast(inst->operand_list[0])); assert(Type::isType(inst->operand_list[0]->type));