From 40ac8b9d65957a7a00b3d363506fc06e1d96b5fb Mon Sep 17 00:00:00 2001 From: ridethepig Date: Thu, 15 Jun 2023 22:14:00 +0800 Subject: [PATCH] SCCP --- .vscode/launch.json | 2 +- include/common.h | 1 + include/llir_instruction.h | 52 +++--- include/llir_value.h | 32 ++-- include/pass.h | 54 ++++++- include/visitor.h | 2 +- scripts/mytester.py | 20 +-- src/main.cpp | 5 +- src/pass_const_prop.cpp | 313 +++++++++++++++++++++++++++++++++++++ src/pass_mem2reg.cpp | 78 ++++----- src/visitor_llir_gen.cpp | 9 +- 11 files changed, 464 insertions(+), 104 deletions(-) create mode 100644 src/pass_const_prop.cpp diff --git a/.vscode/launch.json b/.vscode/launch.json index b67a59a..cd1dcff 100644 --- a/.vscode/launch.json +++ b/.vscode/launch.json @@ -10,7 +10,7 @@ "name": "Debug", "program": "${workspaceFolder}/build/sysy", // "args": ["../sysytests/functional_2021/069_greatest_common_divisor.sy", "-S", "-o", "build/my.s", "-O1", "-emit-llvm"], - "args" : ["-S", "../sysytests/functional_2022/88_many_params2.sy", "-o", "build/my.s", "-emit-llvm", "-O1"], + "args" : ["-S", "../sysytests/functional_2022/21_if_test2.sy", "-o", "build/my.s", "-emit-llvm", "-O1",], "cwd": "${workspaceFolder}" }, ] diff --git a/include/common.h b/include/common.h index 6b79f9e..1d0217a 100644 --- a/include/common.h +++ b/include/common.h @@ -34,6 +34,7 @@ inline sptr(DST) strict_shared_cast(SRC src) { } #define STD_FIND(container, val) std::find(container.begin(), container.end(), val) +#define GETINDEX(container, val) std::distance(container.begin(), STD_FIND(container, val)) #define STD_FOUND(container, val) (STD_FIND(container, val) != container.end()) #define INSET(cont, val) (cont.find(val) != cont.end()) #define BEGINEND(cont) cont.begin(), cont.end() diff --git a/include/llir_instruction.h b/include/llir_instruction.h index 6cc283c..3a2519d 100644 --- a/include/llir_instruction.h +++ b/include/llir_instruction.h @@ -41,10 +41,10 @@ enum class InstTag { GEP, Zext, Phi, - MemPhi, - LoadDep, - InsertEle, - ExtractEle + // MemPhi, + // LoadDep, + // InsertEle, + // ExtractEle }; class Instruction : public User { @@ -78,11 +78,12 @@ public: case InstTag::Store: return "store"; case InstTag::GEP: return "GEP"; case InstTag::Zext: return "zext"; - case InstTag::Phi: return "Phi"; - case InstTag::MemPhi: return "MemPhi"; - case InstTag::LoadDep: return "LoadDep"; - case InstTag::InsertEle: return "InsertEle"; - case InstTag::ExtractEle: return "ExtractEle"; + case InstTag::Phi: + return "Phi"; + // case InstTag::MemPhi: return "MemPhi"; + // case InstTag::LoadDep: return "LoadDep"; + // case InstTag::InsertEle: return "InsertEle"; + // case InstTag::ExtractEle: return "ExtractEle"; } } virtual std::string to_string() override { @@ -101,7 +102,7 @@ public: static sptr(InstAlloca) New(const std::string &name, TypePtr_t type, sptr(BasicBlock) parent_bb) { auto inst = std::make_shared(type, parent_bb); inst->name = name; - auto func_head_bb = parent_bb->parent->bb_list.front(); + auto func_head_bb = parent_bb->parent_func->bb_list.front(); // put all alloca in the head of each basic block inst->inst_itr = func_head_bb->inst_list.insert(func_head_bb->inst_list.begin(), inst); return inst; @@ -113,8 +114,8 @@ public: InstStore(std::shared_ptr value, std::shared_ptr pointer, std::shared_ptr parent_bb) : Instruction(InstTag::Store, TypeHelper::TYPE_VOID, parent_bb) { assert(value); - Add_Operand(value); - Add_Operand(pointer); + add_operand(value); + add_operand(pointer); } static sptr(InstStore) New(sptr(Value) value, sptr(Value) pointer, sptr(BasicBlock) parent_bb) { auto inst = std::make_shared(value, pointer, parent_bb); @@ -127,7 +128,7 @@ class InstLoad : public Instruction { public: InstLoad(std::shared_ptr value, TypePtr_t type, std::shared_ptr parent_bb) : Instruction(InstTag::Load, type, parent_bb) { - Add_Operand(value); + add_operand(value); } virtual std::string to_IR_string() override { std::string str = type->to_IR_string() + " %" + std::to_string(ir_seqno); @@ -150,8 +151,8 @@ public: std::shared_ptr parent_bb ) : Instruction(inst_tag, val_type, parent_bb) { - Add_Operand(op1); - Add_Operand(op2); + add_operand(op1); + add_operand(op2); } virtual std::string to_IR_string() override { std::string str = type->to_IR_string() + " %" + std::to_string(ir_seqno); @@ -193,7 +194,7 @@ class InstZext : public Instruction { public: InstZext(std::shared_ptr op, std::shared_ptr parent_bb) : Instruction(InstTag::Zext, TypeHelper::TYPE_I32, parent_bb) { - Add_Operand(op); + add_operand(op); } virtual std::string to_IR_string() override { std::string str = type->to_IR_string() + " %" + std::to_string(ir_seqno); @@ -206,14 +207,15 @@ public: } }; +// Operands: 0:cond, 1:true, 2:false | 0:target class InstBranch : public Instruction { public: // conditional branch InstBranch(ValuePtr_t cond, BasicBlockPtr_t true_block, BasicBlockPtr_t false_block, BasicBlockPtr_t parent_bb) : Instruction(InstTag::Br, TypeHelper::TYPE_VOID, parent_bb) { - this->Add_Operand(cond); - this->Add_Operand(true_block); - this->Add_Operand(false_block); + this->add_operand(cond); + this->add_operand(true_block); + this->add_operand(false_block); } static sptr(InstBranch) New(ValuePtr_t cond, BasicBlockPtr_t true_block, BasicBlockPtr_t false_block, BasicBlockPtr_t parent_bb) { @@ -224,7 +226,7 @@ public: // unconditional branch InstBranch(BasicBlockPtr_t target_block, BasicBlockPtr_t parent_bb) : Instruction(InstTag::Br, TypeHelper::TYPE_VOID, parent_bb) { - this->Add_Operand(target_block); + this->add_operand(target_block); } static sptr(InstBranch) New(BasicBlockPtr_t target_block, BasicBlockPtr_t parent_bb) { auto inst = std::make_shared(target_block, parent_bb); @@ -237,7 +239,7 @@ class InstReturn : public Instruction { public: InstReturn(ValuePtr_t ret_val, BasicBlockPtr_t parent_bb) : Instruction(InstTag::Ret, TypeHelper::TYPE_VOID, parent_bb) { - this->Add_Operand(ret_val); + this->add_operand(ret_val); } InstReturn(BasicBlockPtr_t parent_bb) : Instruction(InstTag::Ret, TypeHelper::TYPE_VOID, parent_bb) {} static sptr(InstReturn) New(ValuePtr_t ret_val, BasicBlockPtr_t parent_bb) { @@ -257,9 +259,9 @@ class InstCall : public Instruction { public: InstCall(FunctionPtr_t func, const std::vector &args, BasicBlockPtr_t parent_bb) : Instruction(InstTag::Call, func->get_type()->return_type, parent_bb) { - Add_Operand(func); + add_operand(func); for (auto arg : args) { - Add_Operand(arg); + add_operand(arg); } } virtual std::string to_IR_string() override { @@ -305,9 +307,9 @@ public: } assert(indices.size() <= 2); element_type = extract_type(pointer, indices); - Add_Operand(pointer); + add_operand(pointer); for (auto index : indices) { - Add_Operand(index); + add_operand(index); } } diff --git a/include/llir_value.h b/include/llir_value.h index c85af14..c8808d2 100644 --- a/include/llir_value.h +++ b/include/llir_value.h @@ -71,28 +71,28 @@ public: std::vector operand_list; User(const std::string &name, TypePtr_t type) : Value(name, type) {} - void Add_Operand(ValuePtr_t op) { + /* + 把一个Value加入到operands里面,同时维护use信息 + */ + void add_operand(ValuePtr_t op) { // value, use, op_index op->use_list.push_back({/*op.get(),*/ this, (unsigned)operand_list.size()}); operand_list.push_back(op); } - // make anything that use this value use the new value + /* + 把所有this的use给换掉,然后this就没人用了 + this拥有一个use_list,里面记录了所有使用this的指令 + */ void u_replace_users(ValuePtr_t value) { - if (value == nullptr) { - assert(!use_list.size() && "No one should use this"); - return; - } for (auto use : use_list) { - auto user = use.user; - auto index = use.op_index; - user->operand_list[index] = value; assert(value); - value->use_list.push_back({/*value.get(),*/ user, index}); + use.user->operand_list[use.op_index] = value; + value->use_list.push_back({/*value.get(),*/ use.user, use.op_index}); } // all original uses are gone use_list.clear(); } - // remove this user from its operands + // 更新所有被use到的指令,因为this不再使用它,也就是从每个operands的use_list里面把自己删掉 void u_remove_from_usees() { for (auto op : operand_list) { assert(op); @@ -140,7 +140,7 @@ class BasicBlock : public Value { public: int ir_seqno = -1; std::list inst_list; - std::shared_ptr parent; + std::shared_ptr parent_func; BasicBlockListNode_t itr; std::list succ_list; std::list pred_list; @@ -155,7 +155,7 @@ public: int dom_dfs_out; BasicBlock(const std::string &name, std::shared_ptr parent) : Value(name, TypeHelper::TYPE_LABEL) { - this->parent = parent; + this->parent_func = parent; } static sptr(BasicBlock) @@ -175,9 +175,9 @@ public: class ConstantInt : public Constant { public: int value; - ConstantInt(const std::string &name, int value) : Constant(name, TypeHelper::TYPE_I32), value(value) {} - static std::shared_ptr New(int value) { - return std::make_shared("", value); + ConstantInt(const std::string &name, int value, TypePtr_t type) : Constant(name, type), value(value) {} + static std::shared_ptr New(int value, TypePtr_t type = TypeHelper::TYPE_I32) { + return std::make_shared("", value, type); } virtual std::string to_string() override { std::string str = type->to_string() + " " + std::to_string(value); diff --git a/include/pass.h b/include/pass.h index ea2184e..e598310 100644 --- a/include/pass.h +++ b/include/pass.h @@ -11,20 +11,68 @@ class Pass { public: std::string pass_name; Pass(const std::string &name) : pass_name(name) {} - + virtual ~Pass() = default; virtual void run(const Module &module) = 0; }; +class PassSCCP final : public Pass { +public: + PassSCCP() : Pass("const fold") {} + void run(const Module &module) override; + enum class ConstLatTag { Top = 0, Const, Bottom }; + struct ConstLat { + ConstLatTag tag = ConstLatTag::Top; + int value = ~0; + bool is_top() const { + return tag == ConstLatTag::Top; + } + bool is_bot() const { + return tag == ConstLatTag::Bottom; + } + bool is_const() const { + return tag == ConstLatTag::Const; + } + bool operator==(const ConstLat &op2) const { + return tag == op2.tag && value == op2.value; + } + bool operator!=(const ConstLat &op2) const { + return !(tag == op2.tag && value == op2.value); + } + static ConstLat get_bot() { + return {ConstLatTag::Bottom, ~0}; + } + }; + +private: + typedef std::pair edge_type; + std::set FLowWL; + std::set SSAWL; + std::map ExecFlag; + std::map LatCell; + + std::unordered_map inst_id; + std::unordered_map id_inst; + std::set> edge_set; + std::unordered_map> edge_list; + void build_single_inst_block(Function *func); + void Initialize(Function *); + void SCCP(Function *); + void Visit_Phi(InstPhi *); + void Visit_Inst(Instruction *); + ConstLat Lat_Eval(Instruction *inst); + void post_sccp(); +}; + class PassMem2Reg final : public Pass { public: PassMem2Reg() : Pass("mem2reg") {} - virtual void run(const Module &module) override; + void run(const Module &module) override; }; class PassBuildCFG final : public Pass { public: PassBuildCFG() : Pass("build_cfg") {} - virtual void run(const Module &module) override; + void run(const Module &module) override; }; class MCPass { diff --git a/include/visitor.h b/include/visitor.h index f45f57b..41a01cd 100644 --- a/include/visitor.h +++ b/include/visitor.h @@ -43,7 +43,7 @@ private: ScopeTable> _scope_tab; ScopeTable> _func_tab; // var can have same name as func VisitorState _state = {}; - inline static std::shared_ptr CONST0 = std::make_shared("CONST0", 0); + inline static std::shared_ptr CONST0 = ConstantInt::New(0); public: Module module = {}; diff --git a/scripts/mytester.py b/scripts/mytester.py index 7e85d4d..5581259 100644 --- a/scripts/mytester.py +++ b/scripts/mytester.py @@ -52,11 +52,11 @@ class Compiler: log = self.compile_log.format(testcase=testcase) log_file = open(log, "a+") - Print_C.print_procedure(f"Generating {ir} from {sy}") + Print_C.print_procedure(f"Gen {ir} from {sy}") completed = subprocess.run(frontend_instr.format(sy=sy, ir=ir).split(), stdout=log_file, stderr=log_file, bufsize=1) log_file.close() if completed.returncode != 0: - Print_C.print_error(f"Generate {ir} failed! See {log}") + Print_C.print_error(f"Gen {ir} failed! See {log}") self.count_error += 1 return False return True @@ -68,11 +68,11 @@ class Compiler: log = self.compile_log.format(testcase=testcase) log_file = open(log, "a+") - Print_C.print_procedure(f"Generating {asm} from {ir}") + Print_C.print_procedure(f"Gen {asm} from {ir}") completed = subprocess.run(ir_asm_instr.format(ir=ir, asm=asm).split(), stdout=log_file, stderr=log_file, bufsize=1) log_file.close() if completed.returncode != 0: - Print_C.print_error(f"Generate {asm} failed! See {log}") + Print_C.print_error(f"Gen {asm} failed! See {log}") self.count_error += 1 return False return True @@ -84,11 +84,11 @@ class Compiler: log = self.compile_log.format(testcase=testcase) log_file = open(log, "a+") - Print_C.print_procedure(f"Generating {obj} from {asm}") + Print_C.print_procedure(f"Gen {obj} from {asm}") completed = subprocess.run(asm_obj_instr.format(asm=asm,obj=obj).split(), stdout=log_file, stderr=log_file, bufsize=1) log_file.close() if completed.returncode != 0: - Print_C.print_error(f"Generate {obj} failed! See {log}") + Print_C.print_error(f"Gen {obj} failed! See {log}") self.count_error += 1 return False return True @@ -100,11 +100,11 @@ class Compiler: log = self.compile_log.format(testcase=testcase) log_file = open(log, "a+") - Print_C.print_procedure(f"Generating {bin}") + Print_C.print_procedure(f"Gen {bin}") completed = subprocess.run(obj_bin_instr.format(obj=obj,bin=bin).split(), stdout=log_file, stderr=log_file, bufsize=1) log_file.close() if completed.returncode != 0: - Print_C.print_error(f"Generate {bin} failed! See {log}") + Print_C.print_error(f"Gen {bin} failed! See {log}") self.count_error += 1 return False return True @@ -116,12 +116,12 @@ class Compiler: log = self.compile_log.format(testcase=testcase) log_file = open(log, "a+") - Print_C.print_procedure(f"Generating {asm}") + Print_C.print_procedure(f"Gen {asm} from {sy}") completed = subprocess.run(sy_asm_instr.format(asm=asm, sy=sy).split(), stdout=log_file, stderr=log_file, bufsize=1) # print(sy_asm_instr.format(asm=asm, sy=sy)) log_file.close() if completed.returncode != 0: - Print_C.print_error(f"Generate {bin} failed! See {log}") + Print_C.print_error(f"Gen {bin} failed! See {log}") self.count_error += 1 return False return True diff --git a/src/main.cpp b/src/main.cpp index d1030a4..6ef73a2 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -113,7 +113,10 @@ int main(int argc, const char **argv) { } } - std::vector optpasses = {}; + std::vector optpasses = { + std::make_shared(), + }; + if (flg_O1) { for (auto pass : optpasses) { pass->run(visitor.module); diff --git a/src/pass_const_prop.cpp b/src/pass_const_prop.cpp new file mode 100644 index 0000000..98daee2 --- /dev/null +++ b/src/pass_const_prop.cpp @@ -0,0 +1,313 @@ +#include "pass.h" + +namespace CompSysY { + +using ConstLat = PassSCCP::ConstLat; +using ConstLatTag = PassSCCP::ConstLatTag; + +static InstBranch *is_cond_br(Instruction *inst) { + auto inst_br = dynamic_cast(inst); + if (inst_br && inst_br->operand_list.size() == 3) { + return inst_br; + } + return nullptr; +} +static bool Has_Left(Instruction *inst) { + return inst->type->type_tag != Type::TypeTag::VoidType || is_cond_br(inst); +} + +void PassSCCP::build_single_inst_block(Function *func) { + int id_cnt = 1; + inst_id.clear(); + id_inst.clear(); + edge_list.clear(); + edge_set.clear(); + // dummy entry node + edge_set.insert({0, 1}); + edge_list[0].insert(1); + std::vector stk = {func->bb_list.front().get()}; + std::unordered_map visit; + while (stk.size()) { + auto bb = stk.back(); + stk.pop_back(); + if (visit[bb]) continue; + visit[bb] = true; + for (auto inst : bb->inst_list) { + inst_id[inst.get()] = id_cnt; + id_inst[id_cnt] = inst.get(); + ++id_cnt; + } + for (auto succ : bb->succ_list) { + stk.push_back(succ.get()); + } + } + stk.push_back(func->bb_list.front().get()); + visit.clear(); + while (stk.size()) { + auto bb = stk.back(); + stk.pop_back(); + if (visit[bb]) continue; + visit[bb] = true; + int begin_id = inst_id[bb->inst_list.front().get()]; + int end_id = inst_id[bb->inst_list.back().get()]; + for (int i = begin_id; i < end_id; ++i) { + edge_set.insert({i, i + 1}); + edge_list[i].insert(i + 1); + } + for (auto succ : bb->succ_list) { + stk.push_back(succ.get()); + edge_set.insert({end_id, inst_id.at(succ->inst_list.front().get())}); + edge_list[end_id].insert(inst_id.at(succ->inst_list.front().get())); + } + } +} + +void PassSCCP::Initialize(Function *func) { + build_single_inst_block(func); + SSAWL.clear(); + FLowWL.clear(); + ExecFlag.clear(); + for (auto e : edge_set) { + ExecFlag[e] = false; + } + LatCell.clear(); + for (auto [a, inst] : id_inst) { + if (Has_Left(inst)) { + LatCell[inst] = ConstLat(); + } + } + // for ExecFlag, LatCell, [] operator gives default value on first use + auto entry_bb = func->bb_list.front(); + auto entry_inst = inst_id[entry_bb->inst_list.front().get()]; + // dummy entry node, anyways, it won't get used + FLowWL.insert({0, entry_inst}); +} + +/* +According to Wegman,Zadeck's paper, the rules for \union: + 1. any \union Top = any + 2. any \union Bot = Bot + 3. v1 \union v2 = v1, if v1==v2 + 4. v1 \union v2 = Bot, if v1!=v2 +*/ +static ConstLat lattice_meet(const ConstLat &op1, const ConstLat &op2) { + // 1. + if (op1.is_top()) return op2; + if (op2.is_top()) return op1; + // 2. + if (op1.is_bot() || op2.is_bot()) return ConstLat::get_bot(); + // 3. + if (op1 == op2) return op1; + // 4. + return ConstLat::get_bot(); +} + +/* +expression ebaluation rules: the value of the operands of an expression corresponds to +the value of the variables at the entrance to the node, +and the result corresponds to the value of variables that +change during the execution of the node. + +Usually, +1. if the node is an assignment and any of the variables used in its expression portion has a value of \bot, the value +exiting the assignment statement for that variable is \bot. +2. If all values used in its expression portion are constant,the value of the assigned variable is the value of the +expression when evaluated with those constant values. +3. Otherwise, the value assigned is \bot +*/ +ConstLat PassSCCP::Lat_Eval(Instruction *inst) { + bool evaluatable = false; + if (InstTag::Add <= inst->tag && inst->tag <= InstTag::Ne || inst->tag == InstTag::Zext) evaluatable = true; + if (is_cond_br(inst)) { + // directly use its cond's value as its result + return LatCell[inst->operand_list[0].get()]; + } + if (!evaluatable) { + return ConstLat::get_bot(); + } + for (auto op : inst->operand_list) { + if (dynamic_cast(op.get()) || LatCell[op.get()].is_const()) continue; + return ConstLat::get_bot(); + } + auto val = [&](int op) { + if (auto const_op = dynamic_cast(inst->operand_list[op].get())) { + return const_op->value; + } + else { + return LatCell[inst->operand_list[op].get()].value; + } + }; + int result = 0; + switch (inst->tag) { + case InstTag::Add: result = val(0) + val(1); break; + case InstTag::Sub: result = val(0) - val(1); break; + case InstTag::Mod: result = val(0) % val(1); break; + case InstTag::Mul: result = val(0) * val(1); break; + case InstTag::Div: result = val(0) / val(1); break; + case InstTag::Lt: result = val(0) < val(1); break; + case InstTag::Le: result = val(0) <= val(1); break; + case InstTag::Ge: result = val(0) >= val(1); break; + case InstTag::Gt: result = val(0) > val(1); break; + case InstTag::Eq: result = val(0) == val(1); break; + case InstTag::Ne: result = val(0) != val(1); break; + case InstTag::Zext: result = val(0); break; + default: break; + } + return {ConstLatTag::Const, result}; +} + +void PassSCCP::SCCP(Function *func) { + Initialize(func); + auto Edge_Count = [&](int b) { + int cnt = 0; + for (int i = 0; i < inst_id.size(); ++i) { + auto e = std::make_pair(i, b); + if (INSET(edge_set, e) && ExecFlag[e]) { + cnt++; + } + } + return cnt; + }; + while (!FLowWL.empty() || !SSAWL.empty()) { + // propagate along CFG edges + if (!FLowWL.empty()) { + int a = FLowWL.begin()->first; + int b = FLowWL.begin()->second; + FLowWL.erase(FLowWL.begin()); + if (!ExecFlag[{a, b}]) { + ExecFlag[{a, b}] = true; + auto inst = id_inst[b]; + if (auto inst_phi = dynamic_cast(inst)) { + Visit_Phi(inst_phi); + } + else if (Edge_Count(b) == 1) { + Visit_Inst(inst); + } + } + } + // propagate along SSA edges(def-use) + if (!SSAWL.empty()) { + int a = SSAWL.begin()->first; + int b = SSAWL.begin()->second; + SSAWL.erase(SSAWL.begin()); + auto inst = id_inst[b]; + if (auto inst_phi = dynamic_cast(inst)) { + Visit_Phi(inst_phi); + } + else if (Edge_Count(b) >= 1) { + Visit_Inst(inst); + } + } + } +} + +void PassSCCP::Visit_Phi(InstPhi *inst_phi) { + for (int i = 0; i < inst_phi->operand_list.size(); ++i) { + auto inst_vars = inst_phi->operand_list[i].get(); + LatCell[inst_phi] = lattice_meet(LatCell[inst_phi], LatCell[inst_vars]); + } +} + +void PassSCCP::Visit_Inst(Instruction *inst) { + auto EL = [&](int v) -> bool { + auto inst_br = dynamic_cast(inst); + assert(inst_br && inst_br->operand_list.size() == 3); + auto true_bb = dynamic_cast(inst_br->operand_list[1].get()); + if (id_inst[v] == true_bb->inst_list.front().get()) { + return true; + } + return false; + }; + auto val = Lat_Eval(inst); + if (Has_Left(inst)) { + if (val != LatCell[inst]) { + LatCell[inst] = lattice_meet(LatCell[inst], val); + // SSAWL UNION SSASucc(inst) + for (auto &use : inst->use_list) { + auto user = dynamic_cast(use.user); + SSAWL.insert({inst_id[inst], inst_id[user]}); + } + } + } + if (InstTag::Add <= inst->tag && inst->tag <= InstTag::Ne || inst->tag == InstTag::Zext || is_cond_br(inst)) { + int k = inst_id[inst]; + if (val.is_top()) { + for (auto i : edge_list[k]) { + FLowWL.insert({k, i}); + } + } + else if (!val.is_bot()) { + // imply: val is const + assert(edge_list[k].size() <= 2); + if (edge_list[k].size() == 2) { + for (auto i : edge_list[k]) { + // val & EL(k,i)=Y OR !val & EL(k,i)=N + if ((val.value && EL(i)) || (!val.value && !EL(i))) { + FLowWL.insert({k, i}); + } + } + } + else if (edge_list[k].size() == 1) { + // FlowWL UNION k->Succ(k) + FLowWL.insert({k, *edge_list[k].begin()}); + } + } + } +} + +/* +pred is no more bb's pred, so maintain cfg and rewrite bb's phi nodes +*/ +static void rewrite_bb(sptr(BasicBlock) bb, sptr(BasicBlock) pred) { + pred->succ_list.remove(bb); + auto pred_index = GETINDEX(bb->pred_list, pred); + bb->pred_list.remove(pred); + for (auto itr = bb->inst_list.begin(); itr != bb->inst_list.end();) { + auto inst = *itr++; + if (!shared_cast(inst)) break; + auto inst_phi = shared_cast(inst); + assert(inst_phi->operand_list.size() > pred_index); + inst_phi->operand_list.erase(inst_phi->operand_list.begin() + pred_index); + } +} + +void PassSCCP::post_sccp() { + for (auto lat_pair : LatCell) { + if (lat_pair.second.is_const()) { + auto inst = dynamic_cast(lat_pair.first); + inst->u_replace_users(ConstantInt::New(lat_pair.second.value, inst->type)); + inst->u_remove_from_usees(); + VLOG(6) << fmt::format("[SCCP] Replace vreg ${} with const {}", inst_id[inst], lat_pair.second.value); + // Though it should be removed here... leave this work to dce part + } + } + // rewrite branches + for (const auto &_p : inst_id) { + auto inst = _p.first; + if (auto inst_br = is_cond_br(inst)) { + if (auto const_cond = dynamic_cast(inst_br->operand_list[0].get())) { + auto cond = const_cond->value; + VLOG(6) << fmt::format("[SCCP] Replace branch cond ${} with const {}", inst_id[inst], const_cond->value); + inst_br->u_remove_from_usees(); + auto target = shared_cast(cond ? inst->operand_list[1] : inst->operand_list[2]); + auto dead_target = shared_cast(!cond ? inst->operand_list[1] : inst->operand_list[2]); + assert(target && dead_target); + inst_br->operand_list.clear(); + inst_br->add_operand(target); + // maintain cfg + rewrite_bb(dead_target, inst_br->parent_bb); + } + } + } +} + +void PassSCCP::run(const Module &module) { + LOG(INFO) << "Run pass " << pass_name; + for (auto func : module.function_list) { + if (func->is_libfunc()) continue; + SCCP(func.get()); + post_sccp(); + } +} + +} // namespace CompSysY \ No newline at end of file diff --git a/src/pass_mem2reg.cpp b/src/pass_mem2reg.cpp index 4ff0f53..3cf45bc 100644 --- a/src/pass_mem2reg.cpp +++ b/src/pass_mem2reg.cpp @@ -234,70 +234,58 @@ static void _mem_2_reg(FunctionPtr_t func) { if (alloca_list.empty()) return; VLOG(6) << "[mem2reg] Variable renaming"; + /* + From LLVM, 编译器设计书上写的算法比较抽象, 所以抄了个别的(乐) + 本质上是一个先序dfs,从entry开始沿着succ往后走,在遍历中更新每个alloca的值,然后替换 + */ std::vector _init_values; for (int i = 0; i < alloca_list.size(); ++i) _init_values.push_back(ConstantInt::New(0)); - std::vector rename_list = {{func->bb_list.front(), nullptr, _init_values}}; + std::vector rename_list; + rename_list.push_back({func->bb_list.front(), nullptr, _init_values}); std::vector visited(bb_to_id.size(), 0); while (!rename_list.empty()) { auto rename_info = rename_list.back(); rename_list.pop_back(); - // replace block with more specific alloca + // 将phi指令中对应pred的值更新,因为它在处理pred的时候已经被重新定值 for (auto inst : rename_info.bb->inst_list) { // phi only appear at block head if (!shared_cast(inst)) break; - auto phi = shared_cast(inst); - auto alloca_index = phi_to_allocaid.at(phi); - int pred_index = -1; - for (auto pred : rename_info.bb->pred_list) { - pred_index++; - if (pred == rename_info.pred) break; - } - phi->set_incoming_val(pred_index, rename_info.value_list[alloca_index]); + auto phi = shared_cast(inst); + int pred_index = GETINDEX(rename_info.bb->pred_list, rename_info.pred); + phi->set_incoming_val(pred_index, rename_info.value_list[phi_to_allocaid.at(phi)]); } // already processed, skip if (visited[bb_to_id.at(rename_info.bb)]) continue; visited[bb_to_id.at(rename_info.bb)] = true; // process instruction + // 这里其实十分的清楚,就是把所有load alloca的load指令删掉,把这条load的use换成alloca的值(而不是alloca地址) + // 然后把store alloca的store也删掉,同时更新当前的alloca的值 + // 遇到phi指令要更新alloca的值为phi + // 如果是alloca,记得直接删掉(当然只删掉前面处理过的int类型的) for (auto itr = rename_info.bb->inst_list.begin(); itr != rename_info.bb->inst_list.end();) { - auto inst = *itr++; + auto inst = *itr++; // increase itr first, it will get invalidated by remove // we skip non-integer alloca, they are not in our alloca_list - if (shared_cast(inst) && alloca_to_id.count(shared_cast(inst))) { - rename_info.bb->inst_list.remove(inst); + if (shared_cast(inst)) { + if (alloca_to_id.count(shared_cast(inst))) rename_info.bb->inst_list.remove(inst); } - else if (shared_cast(inst)) { - auto li = shared_cast(inst); - if (!(shared_cast(li->operand_list[0]))) { - continue; - } - auto ai = shared_cast(li->operand_list[0]); - if (!Type::isType(get_pointed_type(ai->type))) { - continue; - } - int alloca_index = alloca_to_id.at(ai); - rename_info.bb->inst_list.remove(inst); - li->u_replace_users(rename_info.value_list[alloca_index]); - inst->u_remove_from_usees(); + else if (auto inst_ld = shared_cast(inst)) { + if (!(shared_cast(inst_ld->operand_list[0]))) continue; + auto ai = shared_cast(inst_ld->operand_list[0]); + if (!Type::isType(get_pointed_type(ai->type))) continue; + rename_info.bb->inst_list.remove(inst_ld); + inst_ld->u_replace_users(rename_info.value_list[alloca_to_id.at(ai)]); + inst_ld->u_remove_from_usees(); } - else if (shared_cast(inst)) { - auto si = shared_cast(inst); - if (!(shared_cast(si->operand_list[1]))) { - continue; - } - auto ai = shared_cast(si->operand_list[1]); - if (!Type::isType(get_pointed_type(ai->type))) { - continue; - } - int alloca_index = alloca_to_id.at(ai); - rename_info.value_list[alloca_index] = si->operand_list[0]; - inst->u_remove_from_usees(); - // I dont think anyone will use a store? - si->u_replace_users(nullptr); - rename_info.bb->inst_list.remove(inst); + else if (auto inst_st = shared_cast(inst)) { + if (!(shared_cast(inst_st->operand_list[1]))) continue; + auto ai = shared_cast(inst_st->operand_list[1]); + if (!Type::isType(get_pointed_type(ai->type))) continue; + rename_info.value_list[alloca_to_id.at(ai)] = inst_st->operand_list[0]; + inst_st->u_remove_from_usees(); + rename_info.bb->inst_list.remove(inst_st); } - else if (shared_cast(inst)) { - auto phi = shared_cast(inst); - int alloca_index = phi_to_allocaid.at(phi); - rename_info.value_list[alloca_index] = phi; + else if (auto phi = shared_cast(inst)) { + rename_info.value_list[phi_to_allocaid.at(phi)] = phi; } } for (auto succ : rename_info.bb->succ_list) { diff --git a/src/visitor_llir_gen.cpp b/src/visitor_llir_gen.cpp index a5c80db..c158dd6 100644 --- a/src/visitor_llir_gen.cpp +++ b/src/visitor_llir_gen.cpp @@ -168,6 +168,11 @@ static void _gen_blocks(std::ostream &ostr, const std::list &bl assert(bb_dest->ir_seqno >= 0); ostr << "label %" << bb_dest->ir_seqno; } + else if (auto const_cond = shared_cast(inst->operand_list[0])) { + //! temporary patch after sccp, but it is often incorrect for phi + auto target_bb = const_cond->value ? inst->operand_list[1] : inst->operand_list[2]; + ostr << "label %" << shared_cast(target_bb)->ir_seqno; + } else { assert(shared_cast(inst->operand_list[0])); assert(Type::isType(inst->operand_list[0]->type)); @@ -416,11 +421,11 @@ static void _gen_blocks(std::ostream &ostr, const std::list &bl } else if (shared_cast(inst->operand_list[0])) { auto op0 = shared_cast(inst->operand_list[0]); - ostr << op0->to_IR_string() << "to "; + ostr << op0->to_IR_string() << " to "; } else if (shared_cast(inst->operand_list[0])) { auto op0 = shared_cast(inst->operand_list[0]); - ostr << op0->to_IR_string() << "to "; + ostr << op0->to_IR_string() << " to "; } else { LOG(ERROR) << "Unexpected type of op0: " << inst->operand_list[0]->to_string();