From 247d92b0ab1748b4f2f9327b0a7f807402b115ef Mon Sep 17 00:00:00 2001 From: ridethepig Date: Sun, 18 Jun 2023 00:54:52 +0800 Subject: [PATCH] dead code elimination --- .vscode/launch.json | 2 +- include/algos.h | 84 +++++++++++++++++ include/common.h | 5 +- include/mc_pass.h | 2 +- include/pass.h | 12 ++- src/algo_dominance.cpp | 161 +++++++++++++++++++++++++++++++++ src/main.cpp | 1 + src/pass_const_fold.cpp | 52 +++++++---- src/pass_dce.cpp | 196 ++++++++++++++++++++++++++++++++++++++++ 9 files changed, 488 insertions(+), 27 deletions(-) create mode 100644 src/pass_dce.cpp diff --git a/.vscode/launch.json b/.vscode/launch.json index cd1dcff..b72ef30 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,7 +10,7 @@ "name": "Debug", "program": "${workspaceFolder}/build/sysy", // "args": ["../sysytests/functional_2021/069_greatest_common_divisor.sy", "-S", "-o", "build/my.s", "-O1", "-emit-llvm"], - "args" : ["-S", "../sysytests/functional_2022/21_if_test2.sy", "-o", "build/my.s", "-emit-llvm", "-O1",], + "args" : ["-S", "../sysytests/functional_2022/38_op_priority4.sy", "-o", "build/dbg.s", "-emit-llvm", "-O1",], "cwd": "${workspaceFolder}" }, ] diff --git a/include/algos.h b/include/algos.h index 7f31acf..ef26b29 100644 --- a/include/algos.h +++ b/include/algos.h @@ -8,4 +8,88 @@ void gen_dominance(FunctionPtr_t func); void gen_dominance_frontier(FunctionPtr_t func); void update_dfs_numbers(BasicBlockPtr_t bb, bool rst); +struct CFGNode { + BasicBlock *bb = nullptr; + int special_node = 0; + std::vector pred_list = {}; + std::vector succ_list = {}; + std::unordered_set DF = {}; + CFGNode *idom; // immediate dominator of this node + std::unordered_set dom; // dominators of this node + std::unordered_set dominated; // nodes immediately dominated by this node + static sptr(CFGNode) New(BasicBlock *bb = nullptr) { + auto newnode = std::make_shared(); + newnode->bb = bb; + return newnode; + } + static sptr(CFGNode) New(int special_node) { + auto newnode = 
std::make_shared(); + newnode->bb = nullptr; + newnode->special_node = special_node; + return newnode; + } + std::string to_string() const { + if (is_entry()) return "ENTRY"; + if (is_exit()) return "EXIT"; + return bb->name; + } + bool is_entry() const { + return special_node == 1; + } + bool is_exit() const { + return special_node == 2; + } + void print_cfg(std::ostream &ostr) const { + ostr << "---" << to_string() << "---\n"; + ostr << " pred: ["; + for (auto pred : pred_list) { + ostr << pred->to_string() << ", "; + } + ostr << "]\n"; + ostr << " succ: ["; + for (auto succ : succ_list) { + ostr << succ->to_string() << ", "; + } + ostr << "]\n"; + ostr.flush(); + } + void print_dom(std::ostream &ostr) const { + print_cfg(ostr); + ostr << " idom: " << (!idom ? "NULL" : idom->to_string()) << "\n"; + ostr << " domer: ["; + for (auto d : dom) { + ostr << d->to_string() << ", "; + } + ostr << "]\n"; + ostr << " domee: ["; + for (auto domee : dominated) { + ostr << domee->to_string() << ", "; + } + ostr << "]\n"; + ostr << " DF: ["; + for (auto df : DF) { + ostr << df->to_string() << ", "; + } + ostr << "]\n"; + ostr.flush(); + } +}; + +class CFG { +public: + std::list node_pool; + CFGNode *entry_node; + CFGNode *exit_node; + std::unordered_map bb_node; + void build_from_func(Function *func, bool reversed); + void print_cfg(); + void print_dom(); + void calculate(); + +private: + std::unordered_set visit; + void _print_cfg(CFGNode *); + void _print_dom(CFGNode *); +}; + } // namespace CompSysY diff --git a/include/common.h b/include/common.h index 1d0217a..b970913 100644 --- a/include/common.h +++ b/include/common.h @@ -42,10 +42,7 @@ inline sptr(DST) strict_shared_cast(SRC src) { #define INF (0x3f3f3f3f) #define BTWN(v, l, r) (l <= v && v <= r) -#define panic(message) \ - do { \ - throw GrammarException(__FILE__, __LINE__, (message)); \ - } while (0) +#define panic(message) __assert_fail(message, __FILE__, __LINE__, __ASSERT_FUNCTION) #define DEF_PTR_T(type) \ class 
type; \ diff --git a/include/mc_pass.h b/include/mc_pass.h index 7c255be..bf0237d 100644 --- a/include/mc_pass.h +++ b/include/mc_pass.h @@ -76,4 +76,4 @@ private: std::set active_moves; }; -} \ No newline at end of file +} // namespace CompSysY \ No newline at end of file diff --git a/include/pass.h b/include/pass.h index 019c143..c8f00df 100644 --- a/include/pass.h +++ b/include/pass.h @@ -4,7 +4,7 @@ namespace CompSysY { class PassSCCP final : public Pass { public: - PassSCCP() : Pass("Constant Propagation") {} + PassSCCP() : Pass("constant propagation") {} void run(const Module &module) override; enum class ConstLatTag { Top = 0, Const, Bottom }; struct ConstLat { @@ -58,13 +58,19 @@ public: class PassBuildCFG final : public Pass { public: - PassBuildCFG() : Pass("build_cfg") {} + PassBuildCFG() : Pass("build control flow graph") {} void run(const Module &module) override; }; class PassConstFold final : public Pass { public: - PassConstFold() : Pass("const fold") {} + PassConstFold() : Pass("constant fold") {} + void run(const Module &module) override; +}; + +class PassDCE final : public Pass { +public: + PassDCE() : Pass("dead code elimination") {} void run(const Module &module) override; }; diff --git a/src/algo_dominance.cpp b/src/algo_dominance.cpp index 340837c..dd5d137 100644 --- a/src/algo_dominance.cpp +++ b/src/algo_dominance.cpp @@ -1,3 +1,4 @@ +#include "algos.h" #include "llir.h" #include "visitor.h" @@ -132,4 +133,164 @@ void update_dfs_numbers(BasicBlockPtr_t bb, bool rst) { bb->dom_dfs_out = dfs_num++; } +void CFG::build_from_func(Function *func, bool reversed) { + // special node entry&exit + auto _entry_node = CFGNode::New(1); + auto _exit_node = CFGNode::New(2); + node_pool.push_back(_entry_node); + node_pool.push_back(_exit_node); + entry_node = _entry_node.get(); + exit_node = _exit_node.get(); + // create a cfgnode for each bb + for (auto bb : func->bb_list) { + auto _new_node = CFGNode::New(bb.get()); + node_pool.push_back(_new_node); + 
bb_node.insert({bb.get(), _new_node.get()}); + } + auto as_bb = [](sptr(Value) value) { return dynamic_cast(value.get()); }; + auto connect = [&](CFGNode *pred, CFGNode *succ) { + if (reversed) std::swap(pred, succ); + pred->succ_list.push_back(succ); + succ->pred_list.push_back(pred); + }; + + connect(entry_node, bb_node.at(func->bb_list.front().get())); + for (auto bb : func->bb_list) { + auto _inst_br = bb->inst_list.back(); + if (auto inst_br = shared_cast(_inst_br)) { + if (inst_br->operand_list.size() == 1) { + connect(bb_node.at(bb.get()), bb_node.at(as_bb(inst_br->operand_list[0]))); + } + else if (inst_br->operand_list.size() == 3) { + connect(bb_node.at(bb.get()), bb_node.at(as_bb(inst_br->operand_list[1]))); + connect(bb_node.at(bb.get()), bb_node.at(as_bb(inst_br->operand_list[2]))); + } + else { + panic("br should have either 1 or 3 operands"); + } + } + else if (auto inst_ret = shared_cast(_inst_br)) { + connect(bb_node.at(bb.get()), exit_node); + } + else { + panic("The last instruction of a basic block should be control transfer"); + } + } + if (reversed) { + exit_node->special_node = 1; + entry_node->special_node = 2; + std::swap(exit_node, entry_node); + } +} + +void CFG::print_cfg() { + visit.clear(); + _print_cfg(entry_node); +} + +void CFG::_print_cfg(CFGNode *rt) { + rt->print_cfg(std::cout); + for (auto succ : rt->succ_list) { + if (!visit.count(succ)) { + visit.insert(succ); + _print_cfg(succ); + } + } +} + +void CFG::print_dom() { + visit.clear(); + _print_dom(entry_node); +} + +void CFG::_print_dom(CFGNode *rt) { + rt->print_dom(std::cout); + for (auto succ : rt->succ_list) { + if (!visit.count(succ)) { + visit.insert(succ); + _print_dom(succ); + } + } +} +/* +DOM(entry) = entry +DOM(everything else) = all nodes +change = true +while change, do + change = false + for each BB (except the entry BB) + TMP(BB) = BB + {intersect of DOM of all predecessor BB’s} + if (TMP(BB) != DOM(BB)) + DOM(BB) = TMP(BB) + change = true +*/ + +void 
CFG::calculate() { + entry_node->dom.clear(); + entry_node->dom.insert(entry_node); + std::unordered_set all_node_set; + for (auto cfgnode : node_pool) { + all_node_set.insert(cfgnode.get()); + } + for (auto cfgnode : node_pool) { + if (cfgnode->is_entry()) continue; + cfgnode->dom = all_node_set; + } + all_node_set.erase(entry_node); // now, it is every node but entry + bool changed = true; + while (changed) { + changed = false; + for (auto node : all_node_set) { + // std::unordered_set tmp = {node}; + // for (auto pred : node->pred_list) { + // tmp.insert(BEGINEND(pred->dom)); + // } + // if (tmp != node->dom) { + // changed = true; + // node->dom = tmp; + // } + for (auto it = node->dom.begin(); it != node->dom.end();) { + auto *x = *it; + auto check = [x](CFGNode *p) { return !INSET(p->dom, x); }; + if (x != node && std::any_of(BEGINEND(node->pred_list), check)) { + changed = true; + it = node->dom.erase(it); + } + else ++it; + } + } + } + + entry_node->idom = nullptr; + for (auto bb : all_node_set) { + for (auto domer : bb->dom) { + // given domer dom bb, apparently, idom(bb) != bb, i.e. 
domer sdom bb + if (domer == bb) continue; + // if domer not sdom x where x sdom bb + auto check = [domer, bb](CFGNode *x) { return x == bb || x == domer || !INSET(x->dom, domer); }; + if (std::all_of(BEGINEND(bb->dom), check)) { + bb->idom = domer; + domer->dominated.insert(bb); + break; + } + } + } + // calculate DF + for (auto &bb : bb_node) { + bb.second->DF.clear(); + } + for (auto &bb : bb_node) { + auto n = bb.second; + if (n->pred_list.size() >= 2) { + for (auto pred : n->pred_list) { + auto runner = pred; + while (runner && runner != n->idom) { + runner->DF.insert(n); + runner = runner->idom; + } + } + } + } +} + } // namespace CompSysY \ No newline at end of file diff --git a/src/main.cpp b/src/main.cpp index 1ce783d..bfd49d9 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -116,6 +116,7 @@ int main(int argc, const char **argv) { std::vector optpasses = { std::make_shared(), std::make_shared(), + std::make_shared(), }; if (flg_O1) { diff --git a/src/pass_const_fold.cpp b/src/pass_const_fold.cpp index ea8e8d7..b21df97 100644 --- a/src/pass_const_fold.cpp +++ b/src/pass_const_fold.cpp @@ -2,15 +2,12 @@ namespace CompSysY { - -static bool foldable(Instruction* inst) { +static bool foldable(Instruction *inst) { if (InstTag::Add <= inst->tag && inst->tag <= InstTag::Ne) { - if (shared_cast(inst->operand_list[0]) && shared_cast(inst->operand_list[1])) - return true; + if (shared_cast(inst->operand_list[0]) && shared_cast(inst->operand_list[1])) return true; } else if (inst->tag == InstTag::Zext) { - if (shared_cast(inst->operand_list[0])) - return true; + if (shared_cast(inst->operand_list[0])) return true; } return false; } @@ -23,6 +20,7 @@ static InstBranch *is_cond_br(Instruction *inst) { return nullptr; } +// not removing the block, this only means that pred no longer branch to bb static void rewrite_bb(sptr(BasicBlock) bb, sptr(BasicBlock) pred) { pred->succ_list.remove(bb); auto pred_index = GETINDEX(bb->pred_list, pred); @@ -36,7 +34,7 @@ static void 
rewrite_bb(sptr(BasicBlock) bb, sptr(BasicBlock) pred) { } } -static void compute(Function* func) { +static void compute(Function *func) { bool changed = true; while (changed) { changed = false; @@ -50,8 +48,8 @@ static void compute(Function* func) { } assert(0); }; - changed = true; - int result = 0; + changed = true; + int result = 0; TypePtr_t type = TypeHelper::TYPE_I32; switch (inst->tag) { case InstTag::Add: result = val(0) + val(1); break; @@ -59,12 +57,30 @@ static void compute(Function* func) { case InstTag::Mod: result = val(0) % val(1); break; case InstTag::Mul: result = val(0) * val(1); break; case InstTag::Div: result = val(0) / val(1); break; - case InstTag::Lt: result = val(0) < val(1); type = TypeHelper::TYPE_I1; break; - case InstTag::Le: result = val(0) <= val(1); type = TypeHelper::TYPE_I1; break; - case InstTag::Ge: result = val(0) >= val(1); type = TypeHelper::TYPE_I1; break; - case InstTag::Gt: result = val(0) > val(1); type = TypeHelper::TYPE_I1; break; - case InstTag::Eq: result = val(0) == val(1); type = TypeHelper::TYPE_I1; break; - case InstTag::Ne: result = val(0) != val(1); type = TypeHelper::TYPE_I1; break; + case InstTag::Lt: + result = val(0) < val(1); + type = TypeHelper::TYPE_I1; + break; + case InstTag::Le: + result = val(0) <= val(1); + type = TypeHelper::TYPE_I1; + break; + case InstTag::Ge: + result = val(0) >= val(1); + type = TypeHelper::TYPE_I1; + break; + case InstTag::Gt: + result = val(0) > val(1); + type = TypeHelper::TYPE_I1; + break; + case InstTag::Eq: + result = val(0) == val(1); + type = TypeHelper::TYPE_I1; + break; + case InstTag::Ne: + result = val(0) != val(1); + type = TypeHelper::TYPE_I1; + break; case InstTag::Zext: result = val(0); break; default: break; } @@ -74,10 +90,10 @@ static void compute(Function* func) { } } for (auto bb : func->bb_list) { - for(auto inst : bb->inst_list) { + for (auto inst : bb->inst_list) { if (!is_cond_br(inst.get())) continue; if (!shared_cast(inst->operand_list[0])) continue; - 
auto cond = shared_cast(inst->operand_list[0])->value; + auto cond = shared_cast(inst->operand_list[0])->value; auto target = shared_cast(cond ? inst->operand_list[1] : inst->operand_list[2]); auto dead_target = shared_cast(!cond ? inst->operand_list[1] : inst->operand_list[2]); assert(target && dead_target); @@ -96,4 +112,4 @@ void PassConstFold::run(const Module &module) { } } -} \ No newline at end of file +} // namespace CompSysY \ No newline at end of file diff --git a/src/pass_dce.cpp b/src/pass_dce.cpp new file mode 100644 index 0000000..0f3ec56 --- /dev/null +++ b/src/pass_dce.cpp @@ -0,0 +1,196 @@ +#include "algos.h" +#include "pass.h" + +namespace CompSysY { + +void get_reachable(BasicBlock *bb, std::unordered_set &reachable) { + std::unordered_set visited; + std::vector stk = {bb}; + while (!stk.empty()) { + auto u = stk.back(); + stk.pop_back(); + if (INSET(visited, u)) continue; + reachable.insert(u); + visited.insert(u); + for (auto succ : u->succ_list) { + stk.push_back(succ.get()); + } + } +} + +/* +Remove block from the function +1. remove from pred's succ +2. remove from succ's pred + 2.1 remove from succ's phi-nodes +3. 
remove from function +*/ +void remove_block(sptr(BasicBlock) victim) { + for (auto pred : victim->pred_list) { + pred->succ_list.remove(victim); + } + for (auto succ : victim->succ_list) { + auto pred_index = GETINDEX(succ->pred_list, victim); + succ->pred_list.remove(victim); + for (auto itr = succ->inst_list.begin(); itr != succ->inst_list.end();) { + auto inst = *itr++; + if (!shared_cast(inst)) break; + auto inst_phi = shared_cast(inst); + assert(inst_phi->operand_list.size() > pred_index); + inst_phi->operand_list.erase(inst_phi->operand_list.begin() + pred_index); + } + } + victim->parent_func->bb_list.erase(victim->itr); +} + +// 高级编译器18.1 +bool Elim_Unreach_Code(Function *func) { + bool again = true; + bool changed = false; + std::unordered_set reachable; + do { + again = false; + reachable.clear(); + get_reachable(func->bb_list.front().get(), reachable); + for (auto itr = func->bb_list.begin(); itr != func->bb_list.end();) { + auto bb = *itr++; + if (!INSET(reachable, bb.get())) { + again = true; + changed = true; + VLOG(6) << "[DCE] remove unreachable block " << bb->to_string(); + remove_block(bb); + } + } + } while (again); + return changed; +} + +// Operations defined as critical: I/O statements, linkage code (entry & exit blocks), return values, calls to other +// procedures From Efficiently Computing SSA Form and the Control Dependence Graph, Zadeck 1991 +bool Elim_Dead_Code(Function *func) { + CFG RCFG; + RCFG.build_from_func(func, true); + RCFG.calculate(); + // RCFG.print_dom(); + // Mark + std::unordered_map mark; + std::unordered_set worklist; + for (auto bb : func->bb_list) { + for (auto inst : bb->inst_list) { + switch (inst->tag) { + case InstTag::Call: // though not all functions are critical, we dont impl anything to get that info + case InstTag::Ret: // useful for its caller, though could be useless, but we can't either know anything about + // this + case InstTag::Store: // though not every store is critical, mark the operands 
conservatively + // case InstTag::Br: // branch, well, controls the flow + break; + default: continue; + } + mark[inst.get()] = true; + worklist.insert(inst.get()); + } + } + while (!worklist.empty()) { + auto inst = *worklist.begin(); + worklist.erase(worklist.begin()); + if (auto inst_phi = dynamic_cast(inst)) { + for (auto pred : inst->parent_bb->pred_list) { + if (!mark[pred->inst_list.back().get()]) { + mark[pred->inst_list.back().get()]= true; + worklist.insert(pred->inst_list.back().get()); + } + } + } + for (auto op : inst->operand_list) { + // we ignore globals and constants + if (auto inst_op = shared_cast(op)) { + if (!mark[inst_op.get()]) { + mark[inst_op.get()] = true; + worklist.insert(inst_op.get()); + } + } + } + auto cfgnode = RCFG.bb_node.at(inst->parent_bb.get()); + for (auto rdf : cfgnode->DF) { + assert(rdf->bb); + auto block_end_branch = rdf->bb->inst_list.back().get(); + assert(dynamic_cast(block_end_branch)); + if (!mark[block_end_branch]) { + mark[block_end_branch] = true; + worklist.insert(block_end_branch); + } + } + } + auto bb_useful = [&](BasicBlock *bb) { + for (auto inst : bb->inst_list) { + if (mark[inst.get()]) return true; + } + return false; + }; + // sweep + bool changed = false; + for (auto bb : func->bb_list) { + for (auto itr = bb->inst_list.begin(); itr != bb->inst_list.end();) { + auto inst = *itr++; + if (mark[inst.get()]) continue; + // deal with unmakred + if (auto inst_br = shared_cast(inst)) { + // in this algo, jump and branch are different + if (inst_br->operand_list.size() == 1) continue; + // Still unknown what this would be like, the only case by now is + // true_branch = false_branch + changed = true; + LOG(WARNING) << "rewrite dead branch"; + auto cfgnode = RCFG.bb_node.at(bb.get()); + // walk up through the dom-tree(idom), and check if there is any useful insts in it + auto useful_pdom = cfgnode->idom; + while (!useful_pdom->is_entry()) { + if (bb_useful(useful_pdom->bb)) { + break; + } + useful_pdom = 
useful_pdom->idom; + } + // rewrite branch to jmp + inst_br->u_remove_from_usees(); + inst_br->operand_list.clear(); + sptr(BasicBlock) tmp = nullptr; + for (auto _bb : func->bb_list) { + if (_bb.get() == useful_pdom->bb) { + tmp = _bb; + break; + } + } + inst_br->add_operand(tmp); + } + else { + changed = true; + // inst->u_replace_users(nullptr); + inst->u_remove_from_usees(); + bb->inst_list.erase(inst->inst_itr); + VLOG(6) << "remove dead instruction"; + } + } + } + return changed; +} + +// "Advanced Compiler" 18.2 (translated; presumably a textbook section reference) +bool Fuse_Block(Function *func) { + return false; +} + +void PassDCE::run(const Module &module) { // repeat the DCE sub-passes on each non-library function until none reports a change + LOG(INFO) << "Run pass " << pass_name; + for (auto func : module.function_list) { + if (func->is_libfunc()) continue; + auto again = true; + while (again) { + again = false; + again = Elim_Dead_Code(func.get()) || again; + again = Elim_Unreach_Code(func.get()) || again; // OR-accumulate like the sibling passes: any change must trigger another round + again = Fuse_Block(func.get()) || again; + } + } +} + +} // namespace CompSysY