dead code elimination

This commit is contained in:
ridethepig 2023-06-18 00:54:52 +08:00
parent 07a078b80e
commit 247d92b0ab
9 changed files with 488 additions and 27 deletions

2
.vscode/launch.json vendored
View File

@ -10,7 +10,7 @@
"name": "Debug", "name": "Debug",
"program": "${workspaceFolder}/build/sysy", "program": "${workspaceFolder}/build/sysy",
// "args": ["../sysytests/functional_2021/069_greatest_common_divisor.sy", "-S", "-o", "build/my.s", "-O1", "-emit-llvm"], // "args": ["../sysytests/functional_2021/069_greatest_common_divisor.sy", "-S", "-o", "build/my.s", "-O1", "-emit-llvm"],
"args" : ["-S", "../sysytests/functional_2022/21_if_test2.sy", "-o", "build/my.s", "-emit-llvm", "-O1",], "args" : ["-S", "../sysytests/functional_2022/38_op_priority4.sy", "-o", "build/dbg.s", "-emit-llvm", "-O1",],
"cwd": "${workspaceFolder}" "cwd": "${workspaceFolder}"
}, },
] ]

View File

@ -8,4 +8,88 @@ void gen_dominance(FunctionPtr_t func);
void gen_dominance_frontier(FunctionPtr_t func); void gen_dominance_frontier(FunctionPtr_t func);
void update_dfs_numbers(BasicBlockPtr_t bb, bool rst); void update_dfs_numbers(BasicBlockPtr_t bb, bool rst);
/// One vertex of the graph managed by CFG.
///
/// A node either wraps a BasicBlock (`bb != nullptr`, `special_node == 0`) or is
/// one of the two synthetic nodes created by CFG::build_from_func
/// (`special_node == 1` for ENTRY, `special_node == 2` for EXIT, `bb == nullptr`).
/// All CFGNode pointers stored here are non-owning; CFG::node_pool owns the nodes.
struct CFGNode {
  BasicBlock *bb = nullptr;  // wrapped basic block, null for the synthetic nodes
  int special_node = 0;      // 0: ordinary block, 1: ENTRY, 2: EXIT
  std::vector<CFGNode *> pred_list = {};
  std::vector<CFGNode *> succ_list = {};
  std::unordered_set<CFGNode *> DF = {};  // dominance frontier, filled by CFG::calculate
  // Explicitly null so that a default-initialized node never exposes an
  // indeterminate pointer to print_dom() or to code walking the idom chain.
  CFGNode *idom = nullptr;                  // immediate dominator of this node
  std::unordered_set<CFGNode *> dom;        // dominators of this node
  std::unordered_set<CFGNode *> dominated;  // nodes immediately dominated by this node

  /// Creates an ordinary node wrapping `bb`.
  static sptr(CFGNode) New(BasicBlock *bb = nullptr) {
    auto newnode = std::make_shared<CFGNode>();
    newnode->bb = bb;
    return newnode;
  }

  /// Creates a synthetic node; `special_node` is 1 for ENTRY and 2 for EXIT.
  static sptr(CFGNode) New(int special_node) {
    auto newnode = std::make_shared<CFGNode>();
    newnode->bb = nullptr;
    newnode->special_node = special_node;
    return newnode;
  }

  /// Printable name: "ENTRY", "EXIT", or the name of the wrapped block.
  std::string to_string() const {
    if (is_entry()) return "ENTRY";
    if (is_exit()) return "EXIT";
    // An ordinary node without a block has nothing to name; avoid the null dereference.
    if (!bb) return "<null>";
    return bb->name;
  }

  bool is_entry() const {
    return special_node == 1;
  }

  bool is_exit() const {
    return special_node == 2;
  }

  /// Writes `label`, then every node of `nodes` followed by ", ", then "]\n".
  template <typename Container>
  static void print_node_list(std::ostream &ostr, const char *label, const Container &nodes) {
    ostr << label;
    for (const CFGNode *node : nodes) {
      ostr << node->to_string() << ", ";
    }
    ostr << "]\n";
  }

  /// Dumps this node's name plus its predecessor and successor lists to `ostr`.
  void print_cfg(std::ostream &ostr) const {
    ostr << "---" << to_string() << "---\n";
    print_node_list(ostr, " pred: [", pred_list);
    print_node_list(ostr, " succ: [", succ_list);
    ostr.flush();
  }

  /// Dumps everything print_cfg() prints plus the dominance information.
  void print_dom(std::ostream &ostr) const {
    print_cfg(ostr);
    ostr << " idom: " << (!idom ? "NULL" : idom->to_string()) << "\n";
    print_node_list(ostr, " domer: [", dom);
    print_node_list(ostr, " domee: [", dominated);
    print_node_list(ostr, " DF: [", DF);
    ostr.flush();
  }
};
/// Control-flow graph of a single function, with optional edge reversal so that
/// the same dominance code can be used to obtain post-dominance information.
///
/// Typical use: build_from_func(), then calculate(), then read the per-node
/// results through bb_node / entry_node.
class CFG {
public:
  std::list<sptr(CFGNode)> node_pool;  // owns every node; all raw CFGNode* below point into it
  // Null until build_from_func() has run, so that use-before-build is detectable
  // instead of reading an indeterminate pointer.
  CFGNode *entry_node = nullptr;
  CFGNode *exit_node = nullptr;
  std::unordered_map<BasicBlock *, CFGNode *> bb_node;  // basic block -> its node
  /// Builds the graph for `func`; with `reversed` set every edge is flipped.
  void build_from_func(Function *func, bool reversed);
  /// Prints every node reachable from entry_node (name, preds, succs) to std::cout.
  void print_cfg();
  /// Same traversal as print_cfg(), additionally printing dominance information.
  void print_dom();
  /// Computes dom / idom / dominated / DF for the nodes of the built graph.
  void calculate();

private:
  std::unordered_set<CFGNode *> visit;  // scratch "already seen" set of the print traversals
  void _print_cfg(CFGNode *);
  void _print_dom(CFGNode *);
};
} // namespace CompSysY } // namespace CompSysY

View File

@ -42,10 +42,7 @@ inline sptr(DST) strict_shared_cast(SRC src) {
#define INF (0x3f3f3f3f) #define INF (0x3f3f3f3f)
#define BTWN(v, l, r) (l <= v && v <= r) #define BTWN(v, l, r) (l <= v && v <= r)
#define panic(message) \ #define panic(message) __assert_fail(message, __FILE__, __LINE__, __ASSERT_FUNCTION)
do { \
throw GrammarException(__FILE__, __LINE__, (message)); \
} while (0)
#define DEF_PTR_T(type) \ #define DEF_PTR_T(type) \
class type; \ class type; \

View File

@ -76,4 +76,4 @@ private:
std::set<MInstMove *, MvCmp> active_moves; std::set<MInstMove *, MvCmp> active_moves;
}; };
} } // namespace CompSysY

View File

@ -4,7 +4,7 @@ namespace CompSysY {
class PassSCCP final : public Pass { class PassSCCP final : public Pass {
public: public:
PassSCCP() : Pass("Constant Propagation") {} PassSCCP() : Pass("constant propagation") {}
void run(const Module &module) override; void run(const Module &module) override;
enum class ConstLatTag { Top = 0, Const, Bottom }; enum class ConstLatTag { Top = 0, Const, Bottom };
struct ConstLat { struct ConstLat {
@ -58,13 +58,19 @@ public:
class PassBuildCFG final : public Pass { class PassBuildCFG final : public Pass {
public: public:
PassBuildCFG() : Pass("build_cfg") {} PassBuildCFG() : Pass("build control flow graph") {}
void run(const Module &module) override; void run(const Module &module) override;
}; };
class PassConstFold final : public Pass { class PassConstFold final : public Pass {
public: public:
PassConstFold() : Pass("const fold") {} PassConstFold() : Pass("constant fold") {}
void run(const Module &module) override;
};
class PassDCE final : public Pass {
public:
PassDCE() : Pass("dead code elimination") {}
void run(const Module &module) override; void run(const Module &module) override;
}; };

View File

@ -1,3 +1,4 @@
#include "algos.h"
#include "llir.h" #include "llir.h"
#include "visitor.h" #include "visitor.h"
@ -132,4 +133,164 @@ void update_dfs_numbers(BasicBlockPtr_t bb, bool rst) {
bb->dom_dfs_out = dfs_num++; bb->dom_dfs_out = dfs_num++;
} }
/// Builds the graph for `func`.
///
/// One node is created per basic block plus two synthetic nodes (ENTRY, EXIT).
/// ENTRY is connected to the first block of `func->bb_list`; every block ending
/// in a return is connected to EXIT; branch terminators contribute one edge
/// (1 operand) or two edges (3 operands: condition, true target, false target).
///
/// @param func      function whose blocks are turned into nodes
/// @param reversed  when true every edge is flipped and the ENTRY/EXIT roles are
///                  swapped, so that entry_node is the node reached from the returns
void CFG::build_from_func(Function *func, bool reversed) {
  // Drop whatever a previous build left behind; otherwise bb_node.insert() below
  // would keep stale nodes for blocks that are already present in the map.
  node_pool.clear();
  bb_node.clear();
  visit.clear();

  // front() on an empty list is undefined behaviour, so fail loudly instead
  if (func->bb_list.empty()) {
    panic("cannot build a CFG for a function without basic blocks");
  }

  // special node entry&exit
  auto _entry_node = CFGNode::New(1);
  auto _exit_node = CFGNode::New(2);
  node_pool.push_back(_entry_node);
  node_pool.push_back(_exit_node);
  entry_node = _entry_node.get();
  exit_node = _exit_node.get();

  // create a cfgnode for each bb
  for (const auto &bb : func->bb_list) {
    auto _new_node = CFGNode::New(bb.get());
    node_pool.push_back(_new_node);
    bb_node.insert({bb.get(), _new_node.get()});
  }

  auto as_bb = [](const sptr(Value) &value) { return dynamic_cast<BasicBlock *>(value.get()); };
  // adds the edge pred -> succ, or succ -> pred when building the reversed graph
  auto connect = [&](CFGNode *pred, CFGNode *succ) {
    if (reversed) std::swap(pred, succ);
    pred->succ_list.push_back(succ);
    succ->pred_list.push_back(pred);
  };

  connect(entry_node, bb_node.at(func->bb_list.front().get()));
  for (const auto &bb : func->bb_list) {
    // an empty block has no terminator at all; back() would be undefined behaviour
    if (bb->inst_list.empty()) {
      panic("The last instruction of a basic block should be control transfer");
    }
    CFGNode *cur = bb_node.at(bb.get());
    auto _inst_br = bb->inst_list.back();
    if (auto inst_br = shared_cast<InstBranch>(_inst_br)) {
      if (inst_br->operand_list.size() == 1) {
        connect(cur, bb_node.at(as_bb(inst_br->operand_list[0])));
      }
      else if (inst_br->operand_list.size() == 3) {
        connect(cur, bb_node.at(as_bb(inst_br->operand_list[1])));
        connect(cur, bb_node.at(as_bb(inst_br->operand_list[2])));
      }
      else {
        panic("br should have either 1 or 3 operands");
      }
    }
    else if (shared_cast<InstReturn>(_inst_br)) {
      connect(cur, exit_node);
    }
    else {
      panic("The last instruction of a basic block should be control transfer");
    }
  }

  if (reversed) {
    // in the reversed graph the old EXIT is where traversals start
    exit_node->special_node = 1;
    entry_node->special_node = 2;
    std::swap(exit_node, entry_node);
  }
}
// Dump the pred/succ lists of every node reachable from entry_node to std::cout.
void CFG::print_cfg() {
  visit.clear();
  _print_cfg(entry_node);
}

// Depth-first helper of print_cfg(): prints `rt`, then descends into each
// successor that has not been recorded in `visit` yet.
void CFG::_print_cfg(CFGNode *rt) {
  rt->print_cfg(std::cout);
  for (CFGNode *next : rt->succ_list) {
    // insert() reports whether the node was new, which replaces count()+insert()
    const bool first_visit = visit.insert(next).second;
    if (first_visit) {
      _print_cfg(next);
    }
  }
}
// Dump the dominance information of every node reachable from entry_node to std::cout.
void CFG::print_dom() {
  visit.clear();
  _print_dom(entry_node);
}

// Depth-first helper of print_dom(): prints `rt`, then descends into each
// successor that has not been recorded in `visit` yet.
void CFG::_print_dom(CFGNode *rt) {
  rt->print_dom(std::cout);
  for (CFGNode *next : rt->succ_list) {
    // insert() reports whether the node was new, which replaces count()+insert()
    const bool first_visit = visit.insert(next).second;
    if (first_visit) {
      _print_dom(next);
    }
  }
}
/*
Computes, for every node of the built graph, its dominator set (dom), its
immediate dominator (idom), the nodes it immediately dominates (dominated) and
its dominance frontier (DF). On a graph built with reversed == true these are
the post-dominance counterparts.

Dominator sets use the classic iterative data-flow formulation:

  DOM(entry) = entry
  DOM(everything else) = all nodes
  change = true
  while change, do
    change = false
    for each BB (except the entry BB)
      TMP(BB) = BB + {intersect of DOM of all predecessor BBs}
      if (TMP(BB) != DOM(BB))
        DOM(BB) = TMP(BB)
        change = true

The intersection is done in place: an element is erased from DOM(BB) as soon as
one predecessor does not have it.
*/
void CFG::calculate() {
  // nothing was built, so there is nothing to compute (and entry_node is not usable)
  if (node_pool.empty()) return;

  // Reset every derived field first so that calling calculate() again does not
  // mix stale idom/dominated/DF entries into the new result, and so that nodes
  // for which no immediate dominator exists end up with idom == nullptr.
  std::unordered_set<CFGNode *> all_node_set;
  for (const auto &cfgnode : node_pool) {
    cfgnode->idom = nullptr;
    cfgnode->dominated.clear();
    cfgnode->DF.clear();
    all_node_set.insert(cfgnode.get());
  }

  // initial state: DOM(entry) = {entry}, DOM(other) = all nodes
  for (const auto &cfgnode : node_pool) {
    if (cfgnode->is_entry()) continue;
    cfgnode->dom = all_node_set;
  }
  entry_node->dom.clear();
  entry_node->dom.insert(entry_node);

  all_node_set.erase(entry_node);  // now, it is every node but entry
  bool changed = true;
  while (changed) {
    changed = false;
    for (auto node : all_node_set) {
      for (auto it = node->dom.begin(); it != node->dom.end();) {
        auto *x = *it;
        // x stays only if it is the node itself or dominates every predecessor
        auto check = [x](CFGNode *p) { return !INSET(p->dom, x); };
        if (x != node && std::any_of(BEGINEND(node->pred_list), check)) {
          changed = true;
          it = node->dom.erase(it);
        }
        else ++it;
      }
    }
  }

  // idom(bb) is the strict dominator of bb that dominates no other strict dominator of bb
  for (auto bb : all_node_set) {
    for (auto domer : bb->dom) {
      // given domer dom bb, apparently, idom(bb) != bb, i.e. domer sdom bb
      if (domer == bb) continue;
      // if domer not sdom x where x sdom bb
      auto check = [domer, bb](CFGNode *x) { return x == bb || x == domer || !INSET(x->dom, domer); };
      if (std::all_of(BEGINEND(bb->dom), check)) {
        bb->idom = domer;
        domer->dominated.insert(bb);
        break;
      }
    }
  }

  // calculate DF: for each join node n, walk up from every predecessor along the
  // idom chain until idom(n) is reached; every node passed has n in its frontier
  for (auto &bb : bb_node) {
    auto n = bb.second;
    if (n->pred_list.size() < 2) continue;
    for (auto pred : n->pred_list) {
      auto runner = pred;
      while (runner && runner != n->idom) {
        runner->DF.insert(n);
        runner = runner->idom;
      }
    }
  }
}
} // namespace CompSysY } // namespace CompSysY

View File

@ -116,6 +116,7 @@ int main(int argc, const char **argv) {
std::vector<sptr(Pass)> optpasses = { std::vector<sptr(Pass)> optpasses = {
std::make_shared<PassConstFold>(), std::make_shared<PassConstFold>(),
std::make_shared<PassSCCP>(), std::make_shared<PassSCCP>(),
std::make_shared<PassDCE>(),
}; };
if (flg_O1) { if (flg_O1) {

View File

@ -2,15 +2,12 @@
namespace CompSysY { namespace CompSysY {
static bool foldable(Instruction *inst) {
static bool foldable(Instruction* inst) {
if (InstTag::Add <= inst->tag && inst->tag <= InstTag::Ne) { if (InstTag::Add <= inst->tag && inst->tag <= InstTag::Ne) {
if (shared_cast<ConstantInt>(inst->operand_list[0]) && shared_cast<ConstantInt>(inst->operand_list[1])) if (shared_cast<ConstantInt>(inst->operand_list[0]) && shared_cast<ConstantInt>(inst->operand_list[1])) return true;
return true;
} }
else if (inst->tag == InstTag::Zext) { else if (inst->tag == InstTag::Zext) {
if (shared_cast<ConstantInt>(inst->operand_list[0])) if (shared_cast<ConstantInt>(inst->operand_list[0])) return true;
return true;
} }
return false; return false;
} }
@ -23,6 +20,7 @@ static InstBranch *is_cond_br(Instruction *inst) {
return nullptr; return nullptr;
} }
// not removing the block, this only means that pred no longer branch to bb
static void rewrite_bb(sptr(BasicBlock) bb, sptr(BasicBlock) pred) { static void rewrite_bb(sptr(BasicBlock) bb, sptr(BasicBlock) pred) {
pred->succ_list.remove(bb); pred->succ_list.remove(bb);
auto pred_index = GETINDEX(bb->pred_list, pred); auto pred_index = GETINDEX(bb->pred_list, pred);
@ -36,7 +34,7 @@ static void rewrite_bb(sptr(BasicBlock) bb, sptr(BasicBlock) pred) {
} }
} }
static void compute(Function* func) { static void compute(Function *func) {
bool changed = true; bool changed = true;
while (changed) { while (changed) {
changed = false; changed = false;
@ -50,8 +48,8 @@ static void compute(Function* func) {
} }
assert(0); assert(0);
}; };
changed = true; changed = true;
int result = 0; int result = 0;
TypePtr_t type = TypeHelper::TYPE_I32; TypePtr_t type = TypeHelper::TYPE_I32;
switch (inst->tag) { switch (inst->tag) {
case InstTag::Add: result = val(0) + val(1); break; case InstTag::Add: result = val(0) + val(1); break;
@ -59,12 +57,30 @@ static void compute(Function* func) {
case InstTag::Mod: result = val(0) % val(1); break; case InstTag::Mod: result = val(0) % val(1); break;
case InstTag::Mul: result = val(0) * val(1); break; case InstTag::Mul: result = val(0) * val(1); break;
case InstTag::Div: result = val(0) / val(1); break; case InstTag::Div: result = val(0) / val(1); break;
case InstTag::Lt: result = val(0) < val(1); type = TypeHelper::TYPE_I1; break; case InstTag::Lt:
case InstTag::Le: result = val(0) <= val(1); type = TypeHelper::TYPE_I1; break; result = val(0) < val(1);
case InstTag::Ge: result = val(0) >= val(1); type = TypeHelper::TYPE_I1; break; type = TypeHelper::TYPE_I1;
case InstTag::Gt: result = val(0) > val(1); type = TypeHelper::TYPE_I1; break; break;
case InstTag::Eq: result = val(0) == val(1); type = TypeHelper::TYPE_I1; break; case InstTag::Le:
case InstTag::Ne: result = val(0) != val(1); type = TypeHelper::TYPE_I1; break; result = val(0) <= val(1);
type = TypeHelper::TYPE_I1;
break;
case InstTag::Ge:
result = val(0) >= val(1);
type = TypeHelper::TYPE_I1;
break;
case InstTag::Gt:
result = val(0) > val(1);
type = TypeHelper::TYPE_I1;
break;
case InstTag::Eq:
result = val(0) == val(1);
type = TypeHelper::TYPE_I1;
break;
case InstTag::Ne:
result = val(0) != val(1);
type = TypeHelper::TYPE_I1;
break;
case InstTag::Zext: result = val(0); break; case InstTag::Zext: result = val(0); break;
default: break; default: break;
} }
@ -74,10 +90,10 @@ static void compute(Function* func) {
} }
} }
for (auto bb : func->bb_list) { for (auto bb : func->bb_list) {
for(auto inst : bb->inst_list) { for (auto inst : bb->inst_list) {
if (!is_cond_br(inst.get())) continue; if (!is_cond_br(inst.get())) continue;
if (!shared_cast<ConstantInt>(inst->operand_list[0])) continue; if (!shared_cast<ConstantInt>(inst->operand_list[0])) continue;
auto cond = shared_cast<ConstantInt>(inst->operand_list[0])->value; auto cond = shared_cast<ConstantInt>(inst->operand_list[0])->value;
auto target = shared_cast<BasicBlock>(cond ? inst->operand_list[1] : inst->operand_list[2]); auto target = shared_cast<BasicBlock>(cond ? inst->operand_list[1] : inst->operand_list[2]);
auto dead_target = shared_cast<BasicBlock>(!cond ? inst->operand_list[1] : inst->operand_list[2]); auto dead_target = shared_cast<BasicBlock>(!cond ? inst->operand_list[1] : inst->operand_list[2]);
assert(target && dead_target); assert(target && dead_target);
@ -96,4 +112,4 @@ void PassConstFold::run(const Module &module) {
} }
} }
} } // namespace CompSysY

196
src/pass_dce.cpp Normal file
View File

@ -0,0 +1,196 @@
#include "algos.h"
#include "pass.h"
namespace CompSysY {
// Adds `bb` and every block reachable from it through succ_list edges to
// `reachable`. Entries already present in `reachable` are kept and do not stop
// the traversal; a private `seen` set is what bounds the walk.
void get_reachable(BasicBlock *bb, std::unordered_set<BasicBlock *> &reachable) {
  std::unordered_set<BasicBlock *> seen;
  std::vector<BasicBlock *> pending{bb};
  while (!pending.empty()) {
    BasicBlock *cur = pending.back();
    pending.pop_back();
    // insert() tells us whether this block is encountered for the first time
    const bool first_time = seen.insert(cur).second;
    if (!first_time) continue;
    reachable.insert(cur);
    for (const auto &next : cur->succ_list) {
      pending.push_back(next.get());
    }
  }
}
/*
Detach `victim` from its function:
  1. drop it from the succ_list of each of its predecessors
  2. drop it from the pred_list of each of its successors
     2.1 and erase the matching operand of the phi instructions that lead the successor
  3. erase it from the parent function's bb_list
*/
void remove_block(sptr(BasicBlock) victim) {
  for (const auto &before : victim->pred_list) {
    before->succ_list.remove(victim);
  }
  for (const auto &after : victim->succ_list) {
    // the position must be looked up before the entry is removed from pred_list:
    // it is the operand slot erased from each leading phi below
    auto pred_index = GETINDEX(after->pred_list, victim);
    after->pred_list.remove(victim);
    for (const auto &inst : after->inst_list) {
      auto inst_phi = shared_cast<InstPhi>(inst);
      // stop at the first instruction that is not a phi
      if (!inst_phi) break;
      assert(inst_phi->operand_list.size() > pred_index);
      inst_phi->operand_list.erase(inst_phi->operand_list.begin() + pred_index);
    }
  }
  victim->parent_func->bb_list.erase(victim->itr);
}
// "Advanced Compiler" textbook, section 18.1 (presumably Muchnick's
// "Advanced Compiler Design and Implementation" - confirm).
//
// Removes every block of `func` that cannot be reached from the first block of
// bb_list via succ_list edges. The scan is repeated until a round removes
// nothing. Returns true when at least one block was removed.
bool Elim_Unreach_Code(Function *func) {
  bool changed = false;
  std::unordered_set<BasicBlock *> reachable;
  bool removed_any = false;
  do {
    removed_any = false;
    reachable.clear();
    get_reachable(func->bb_list.front().get(), reachable);
    auto itr = func->bb_list.begin();
    while (itr != func->bb_list.end()) {
      // advance first: remove_block() erases the current element from bb_list
      auto bb = *itr;
      ++itr;
      if (INSET(reachable, bb.get())) continue;
      VLOG(6) << "[DCE] remove unreachable block " << bb->to_string();
      remove_block(bb);
      removed_any = true;
    }
    changed = changed || removed_any;
  } while (removed_any);
  return changed;
}
// Mark-and-sweep dead code elimination driven by control dependence.
//
// Operations defined as critical: I/O statements, linkage code (entry & exit blocks), return values, calls to other
// procedures. From "Efficiently Computing Static Single Assignment Form and the Control Dependence Graph",
// Cytron, Ferrante, Rosen, Wegman, Zadeck, 1991.
//
// Mark:  calls, returns and stores seed the worklist. For every useful instruction its instruction operands
//        become useful, as do the terminators of the blocks in the reverse dominance frontier of its block
//        (the branches it is control dependent on). A useful phi also makes the terminator of every
//        predecessor block useful.
// Sweep: unmarked non-branch instructions are erased; an unmarked conditional branch is rewritten into a jump
//        to the nearest post-dominating block that still contains a useful instruction. Unconditional jumps
//        are always kept.
//
// Returns true when an instruction was erased or a branch was rewritten.
//
// NOTE(review): rewriting a branch changes its operands only; bb->succ_list / bb->pred_list are not updated
// here - confirm that a later step rebuilds them before they are relied on.
bool Elim_Dead_Code(Function *func) {
  CFG RCFG;
  RCFG.build_from_func(func, true);
  RCFG.calculate();
  // RCFG.print_dom();

  // Mark
  std::unordered_map<Instruction *, bool> mark;
  std::unordered_set<Instruction *> worklist;
  // marks `inst` useful and queues it, unless it has been marked before
  auto mark_useful = [&](Instruction *inst) {
    bool &flag = mark[inst];
    if (flag) return;
    flag = true;
    worklist.insert(inst);
  };
  // read-only query that does not grow the map
  auto is_marked = [&](Instruction *inst) {
    auto found = mark.find(inst);
    return found != mark.end() && found->second;
  };

  for (auto bb : func->bb_list) {
    for (auto inst : bb->inst_list) {
      switch (inst->tag) {
        case InstTag::Call:   // though not all functions are critical, we dont impl anything to get that info
        case InstTag::Ret:    // useful for its caller, though could be useless, but we can't either know anything about
                              // this
        case InstTag::Store:  // though not every store is critical, mark the operands conservatively
          // case InstTag::Br: // branch, well, controls the flow
          break;
        default: continue;
      }
      mark_useful(inst.get());
    }
  }

  while (!worklist.empty()) {
    auto inst = *worklist.begin();
    worklist.erase(worklist.begin());
    if (dynamic_cast<InstPhi *>(inst)) {
      // the value a phi selects depends on which predecessor was left, so keep their terminators
      for (auto pred : inst->parent_bb->pred_list) {
        mark_useful(pred->inst_list.back().get());
      }
    }
    for (auto op : inst->operand_list) {
      // we ignore globals and constants
      if (auto inst_op = shared_cast<Instruction>(op)) {
        mark_useful(inst_op.get());
      }
    }
    auto cfgnode = RCFG.bb_node.at(inst->parent_bb.get());
    for (auto rdf : cfgnode->DF) {
      assert(rdf->bb);
      auto block_end_branch = rdf->bb->inst_list.back().get();
      assert(dynamic_cast<InstBranch *>(block_end_branch));
      mark_useful(block_end_branch);
    }
  }

  auto bb_useful = [&](BasicBlock *bb) {
    for (auto inst : bb->inst_list) {
      if (is_marked(inst.get())) return true;
    }
    return false;
  };

  // sweep
  bool changed = false;
  for (auto bb : func->bb_list) {
    for (auto itr = bb->inst_list.begin(); itr != bb->inst_list.end();) {
      auto inst = *itr++;
      if (is_marked(inst.get())) continue;
      // deal with unmarked
      if (auto inst_br = shared_cast<InstBranch>(inst)) {
        // in this algo, jump and branch are different
        if (inst_br->operand_list.size() == 1) continue;
        // Still unknown what this would be like, the only case by now is
        // true_branch = false_branch
        auto cfgnode = RCFG.bb_node.at(bb.get());
        // walk up through the dom-tree(idom), and check if there is any useful insts in it.
        // The chain can end in nullptr (no immediate post-dominator was found, e.g. for a block
        // that never reaches a return) and synthetic nodes carry no block, so guard both.
        auto useful_pdom = cfgnode->idom;
        while (useful_pdom && !useful_pdom->is_entry()) {
          if (useful_pdom->bb && bb_useful(useful_pdom->bb)) {
            break;
          }
          useful_pdom = useful_pdom->idom;
        }
        // bb_list owns the blocks; look up the shared pointer that matches the raw one
        sptr(BasicBlock) target = nullptr;
        if (useful_pdom && useful_pdom->bb) {
          for (auto _bb : func->bb_list) {
            if (_bb.get() == useful_pdom->bb) {
              target = _bb;
              break;
            }
          }
        }
        if (!target) {
          // never emit a jump to a null block; the dead condition would be swept away as well,
          // so there is no consistent way to keep the branch either
          panic("[DCE] dead conditional branch has no useful post-dominator block to jump to");
        }
        changed = true;
        LOG(WARNING) << "rewrite dead branch";
        // rewrite branch to jmp
        inst_br->u_remove_from_usees();
        inst_br->operand_list.clear();
        inst_br->add_operand(target);
      }
      else {
        changed = true;
        // inst->u_replace_users(nullptr);
        inst->u_remove_from_usees();
        bb->inst_list.erase(inst->inst_itr);
        VLOG(6) << "remove dead instruction";
      }
    }
  }
  return changed;
}
// "Advanced Compiler" textbook, section 18.2 (presumably Muchnick's
// "Advanced Compiler Design and Implementation" - confirm).
//
// Placeholder: no block fusion is performed yet, so `func` is left untouched
// and the function always reports "nothing changed".
bool Fuse_Block(Function *func) {
  (void)func;  // unused until the transformation is implemented
  constexpr bool changed = false;
  return changed;
}
// Runs dead code elimination on every non-library function of `module`.
//
// For each function the three sub-passes (dead instruction elimination,
// unreachable block elimination, block fusion) are repeated until a full round
// changes nothing, because each one can expose new work for the others.
void PassDCE::run(const Module &module) {
  LOG(INFO) << "Run pass " << pass_name;
  for (auto func : module.function_list) {
    if (func->is_libfunc()) continue;
    auto again = true;
    while (again) {
      // Every sub-pass runs exactly once per round, in this order; the results
      // are OR-ed so that progress made by any of them triggers another round.
      // (The unreachable-code result used to be AND-ed in, which threw away the
      // progress of one sub-pass unless the other one had changed something too.)
      const bool dead_changed = Elim_Dead_Code(func.get());
      const bool unreach_changed = Elim_Unreach_Code(func.get());
      const bool fuse_changed = Fuse_Block(func.get());
      again = dead_changed || unreach_changed || fuse_changed;
    }
  }
}
} // namespace CompSysY