CompilerSysY/src/visitor.cpp
2023-05-07 21:59:04 +08:00

906 lines
33 KiB
C++

#include "visitor.h"
#include "3rdparty/easylogging++.h"
#include "common.h"
#include "llir_type.h"
#include "llir_value.h"
#include <any>
#include <memory>
#include <string>
#include <vector>
// Some virtual methods are commented out in the class definition because they
// do nothing beyond what the base-class implementation already does.
namespace antlrSysY {
#define ANY2VALUE(_my_type_) \
if (fuck_any.type() == typeid(std::shared_ptr<_my_type_>)) { \
auto value = std::any_cast<std::shared_ptr<_my_type_>>(fuck_any); \
return value; \
}
static ValuePtr_t any_to_Value(const std::any &fuck_any) {
ANY2VALUE(Value)
ANY2VALUE(User)
ANY2VALUE(FParam)
ANY2VALUE(Function)
ANY2VALUE(BasicBlock)
ANY2VALUE(ConstantInt)
ANY2VALUE(ConstantArr)
ANY2VALUE(Constant)
ANY2VALUE(GlobalVar)
ANY2VALUE(Instruction)
ANY2VALUE(InstAlloca)
ANY2VALUE(InstStore)
ANY2VALUE(InstLoad)
ANY2VALUE(InstBinary)
ANY2VALUE(InstZext)
ANY2VALUE(InstBranch)
ANY2VALUE(InstReturn)
ANY2VALUE(InstCall)
LOG(ERROR) << fuck_any.type().name();
panic("Unreachable");
}
// Visits every constDef child of the declaration in source order.
// Returns an empty std::any.
std::any Visitor::visitConstDecl(SysyParser::ConstDeclContext *ctx) {
    const auto definitions = ctx->constDef();
    for (size_t idx = 0; idx < definitions.size(); ++idx) {
        visitConstDef(definitions[idx]);
    }
    return std::any();
}
// constDef : IDENT ('[' constExp ']')* '=' constInitVal ';'
// The initializer is mandatory for a constant definition.
//
// Scalar constants are registered in the scope table as the folded ConstantInt.
// Constant arrays are flattened to a 1-dim value list; at a non-zero scope
// level they are materialised with alloca + per-element stores, at level 0
// they become a GlobalVar wrapping a ConstantArr.
// Throws GrammarException when the name already exists at the current level.
// Returns an empty std::any.
std::any Visitor::visitConstDef(SysyParser::ConstDefContext *ctx) {
    const auto const_name = ctx->IDENT()->getText();
    LOG(DEBUG) << "Visiting ConstDef " << const_name;
    const auto previous_def = _scope_tab.get_name(const_name, _scope_tab.get_level());
    if (previous_def.has_value())
        throw GrammarException("Duplicate const def");

    auto dim_exps = ctx->constExp();
    if (dim_exps.empty()) {
        // scalar constant: the initializer folds to a single ConstantInt
        auto scalar = std::any_cast<std::shared_ptr<ConstantInt>>(visitConstInitVal(ctx->constInitVal()));
        _scope_tab.push_name(const_name, scalar);
        return std::any();
    }

    // constant array: evaluate the extent of every dimension first
    std::vector<int> extents;
    for (auto dim_exp : dim_exps) {
        auto extent = std::any_cast<std::shared_ptr<ConstantInt>>(visitConstExp(dim_exp));
        extents.push_back(extent->value);
    }
    auto array_type = ArrayType::build_from_list(extents);

    // collect the flattened initial values; the dim list is only published to
    // the nested initializer visits for the duration of this call
    _state.arr_dim_index = 0;
    _state.arr_dim_list = &extents;
    auto flat_values = std::any_cast<std::vector<ValuePtr_t>>(visitConstInitVal(ctx->constInitVal()));
    _state.arr_dim_list = nullptr;
    sysy_assert(_state.arr_dim_index == 0);

    if (_scope_tab.get_level() == 0) {
        // global const array
        auto const_arr = ConstantArr::make_shared("const_arr", flat_values, array_type);
        auto global_var = GlobalVar::make_shared(const_name, const_arr, true);
        _scope_tab.push_name(const_name, global_var);
        return std::any();
    }

    // local const array: the alloca is the handle stored in the scope table
    auto array_slot = build_InstAlloca(array_type, _state.current_bb);
    _scope_tab.push_name(const_name, array_slot);
    // peel one dimension per GEP until we hold &a[0][0]...[0]
    auto first_elem = build_InstGEP(array_slot, {CONST0, CONST0}, _state.current_bb);
    for (int depth = 1; depth < extents.size(); ++depth) {
        first_elem = build_InstGEP(first_elem, {CONST0, CONST0}, _state.current_bb);
    }
    // elements are stored through the flat (1-dim) view
    build_InstStore(flat_values[0], first_elem, _state.current_bb);
    for (int flat_idx = 1; flat_idx < flat_values.size(); ++flat_idx) {
        auto elem_ptr = build_InstGEP(first_elem, {ConstantInt::make_shared(flat_idx)}, _state.current_bb);
        build_InstStore(flat_values[flat_idx], elem_ptr, _state.current_bb);
    }
    return std::any();
}
// varDef: IDENT ('[' constExp ']')* ('=' initVal)?;
// The initializer is optional.
//
// Registers the variable in the scope table:
//   - global scalar: GlobalVar holding the folded initializer, or CONST0
//   - local scalar : alloca of i32, plus a store when an initializer exists
//   - local array  : alloca of the array type, plus per-element stores of the
//                    flattened initializer when one exists
//   - global array : GlobalVar wrapping a ConstantArr; without an initializer
//                    the array is filled with CONST0
// Panics when the name is already defined at the current scope level.
// Returns an empty std::any.
std::any Visitor::visitVarDef(SysyParser::VarDefContext *ctx) {
    auto var_name = ctx->IDENT()->getText();
    LOG(DEBUG) << "Visiting VarDef " << var_name;
    if (_scope_tab.get_name(var_name, _scope_tab.get_level()).has_value())
        panic("Duplicate var def");
    // Not Array
    if (ctx->constExp().empty()) {
        // global variable
        if (_scope_tab.get_level() == 0) {
            // global variable must be initialized, either specified or zeroed
            if (ctx->initVal()) {
                _state.isGlobalIint = true;
                auto result = std::any_cast<std::shared_ptr<ConstantInt>>(visitInitVal(ctx->initVal()));
                _state.isGlobalIint = false;
                auto global_var_int = std::make_shared<GlobalVar>(var_name, result, false);
                _scope_tab.push_name(var_name, global_var_int);
            }
            else {
                auto global_var_int = std::make_shared<GlobalVar>(var_name, CONST0, false);
                _scope_tab.push_name(var_name, global_var_int);
            }
        }
        // local variable
        else {
            auto alloca_ = build_InstAlloca(TypeHelper::TYPE_I32, _state.current_bb);
            _scope_tab.push_name(var_name, alloca_);
            if (ctx->initVal()) {
                auto init_val = any_to_Value(visitInitVal(ctx->initVal()));
                build_InstStore(init_val, alloca_, _state.current_bb);
            }
        }
    }
    else {
        std::vector<int> dim_list;
        auto const_exp_list = ctx->constExp();
        // collect size of each dimension
        for (auto const_exp : const_exp_list) {
            auto n_elem = std::any_cast<std::shared_ptr<ConstantInt>>(visitConstExp(const_exp));
            dim_list.push_back(n_elem->value);
        }
        auto array_type = ArrayType::build_from_list(dim_list);
        // local array
        if (_scope_tab.get_level()) {
            auto alloca_ = build_InstAlloca(array_type, _state.current_bb);
            _scope_tab.push_name(var_name, alloca_);
            if (ctx->initVal()) {
                _state.arr_dim_index = 0;
                _state.arr_dim_list = &dim_list;
                auto array_value = std::any_cast<std::vector<ValuePtr_t>>(visitInitVal(ctx->initVal()));
                _state.arr_dim_list = nullptr;
                sysy_assert(_state.arr_dim_index == 0);
                // Build GEP down to &a[0][0]...[0]
                auto base_ptr = build_InstGEP(alloca_, {CONST0, CONST0}, _state.current_bb);
                for (int i = 1; i < dim_list.size(); ++i) {
                    base_ptr = build_InstGEP(base_ptr, {CONST0, CONST0}, _state.current_bb);
                }
                // TODO: BAAU-2021 calls memset in `libc` (not sysylib) to optimize this process
                build_InstStore(array_value[0], base_ptr, _state.current_bb);
                for (int i = 1; i < array_value.size(); ++i) {
                    auto ptr = build_InstGEP(base_ptr, {ConstantInt::make_shared(i)}, _state.current_bb);
                    build_InstStore(array_value[i], ptr, _state.current_bb);
                }
            }
            // If there is no init expr, let it be
        }
        // global array, generally the same as constant global array
        else {
            if (ctx->initVal()) {
                // first collect init value
                _state.isGlobalIint = true;
                _state.arr_dim_index = 0;
                _state.arr_dim_list = &dim_list;
                auto array_value = std::any_cast<std::vector<ValuePtr_t>>(visitInitVal(ctx->initVal()));
                _state.arr_dim_list = nullptr;
                sysy_assert(_state.arr_dim_index == 0);
                _state.isGlobalIint = false;
                auto const_arr = ConstantArr::make_shared("var_arr", array_value, array_type);
                auto global_var = GlobalVar::make_shared(var_name, const_arr, false);
                _scope_tab.push_name(var_name, global_var);
            }
            else {
                // No initializer: still register an *array* global (previously a
                // scalar CONST0 global was created here), filled with zeros so it
                // has the same shape as the initialized case above.
                int total_elems = 1;
                for (int dim : dim_list) {
                    total_elems *= dim;
                }
                std::vector<ValuePtr_t> zero_values;
                for (int i = 0; i < total_elems; ++i) {
                    zero_values.push_back(CONST0);
                }
                auto const_arr = ConstantArr::make_shared("var_arr", zero_values, array_type);
                auto global_var = GlobalVar::make_shared(var_name, const_arr, false);
                _scope_tab.push_name(var_name, global_var);
            }
        }
    }
    return {};
}
// Evaluates an initializer.
//   - Scalar form (`exp`): returns whatever visitExp yields; while a global
//     initializer is being evaluated (_state.isGlobalIint) the expression is
//     folded in const-int mode.
//   - Brace form: returns a std::vector<ValuePtr_t> holding the flattened
//     contents of the sub-array at dimension _state.arr_dim_index, padded with
//     CONST0 up to cur_dim * elem_size entries. Requires _state.arr_dim_list
//     to be set by the caller.
std::any Visitor::visitInitVal(SysyParser::InitValContext *ctx) {
    if (ctx->exp()) {
        if (_state.isGlobalIint)
            _state.isConstInt = true;
        auto retval = visitExp(ctx->exp());
        if (_state.isGlobalIint)
            _state.isConstInt = false;
        return retval;
    }
    // Array
    else {
        sysy_assert(_state.arr_dim_list);
        int cur_dim = _state.arr_dim_list->at(_state.arr_dim_index);
        // number of scalars in one element of the current dimension
        int elem_size = 1;
        std::vector<ValuePtr_t> cur_arr;
        for (int i = _state.arr_dim_index + 1; i < _state.arr_dim_list->size(); ++i) {
            elem_size *= _state.arr_dim_list->at(i);
        }
        for (auto init_val : ctx->initVal()) {
            if (init_val->exp()) {
                auto elem_value = any_to_Value(visitInitVal(init_val));
                if (_state.isGlobalIint) {
                    // A global initializer must fold to a ConstantInt. The check is
                    // done on the dynamic type: typeid() of the shared_ptr variable
                    // only reflects its static type and could never match.
                    sysy_assert(std::dynamic_pointer_cast<ConstantInt>(elem_value) != nullptr);
                }
                cur_arr.push_back(elem_value);
            }
            else {
                // evaluate sub-array
                // before evaluate the new sub-array, first fill up the last one
                const int pos = cur_arr.size();
                // the additional `%elem_size`: what if the last dim is {}?
                for (int i = 0; i < (elem_size - (pos % elem_size)) % elem_size; ++i) {
                    cur_arr.push_back(CONST0);
                }
                _state.arr_dim_index++;
                auto sub_array = std::any_cast<std::vector<ValuePtr_t>>(visitInitVal(init_val));
                _state.arr_dim_index--;
                cur_arr.insert(cur_arr.end(), sub_array.begin(), sub_array.end());
            }
        }
        // fill up the rest
        for (int i = cur_arr.size(); i < cur_dim * elem_size; ++i) {
            cur_arr.push_back(CONST0);
        }
        return cur_arr;
    }
}
// TODO: Replace all any<int> to ConstantInt
// constInitVal: constExp | ('{' (constInitVal (',' constInitVal)*)? '}');
//
// Scalar form: forwards the result of visitConstExp.
// Brace form : returns a std::vector<ValuePtr_t> with the flattened contents of
// the sub-array at dimension _state.arr_dim_index, padded with CONST0 up to
// cur_dim * elem_size entries. Requires _state.arr_dim_list to be set.
std::any Visitor::visitConstInitVal(SysyParser::ConstInitValContext *ctx) {
    if (ctx->constExp() != nullptr) {
        return visitConstExp(ctx->constExp());
    }

    sysy_assert(_state.arr_dim_list);
    const int dim_extent = _state.arr_dim_list->at(_state.arr_dim_index);
    // number of scalars held by one element of the current dimension
    int stride = 1;
    for (int d = _state.arr_dim_index + 1; d < _state.arr_dim_list->size(); ++d) {
        stride *= _state.arr_dim_list->at(d);
    }

    std::vector<ValuePtr_t> flattened;
    for (auto child : ctx->constInitVal()) {
        if (!child->constExp()) {
            // nested brace list: align to the next element boundary first
            const int filled = flattened.size();
            // the extra `% stride` keeps the padding at 0 when already aligned
            const int padding = (stride - (filled % stride)) % stride;
            for (int k = 0; k < padding; ++k) {
                flattened.push_back(CONST0);
            }
            _state.arr_dim_index++;
            auto nested = std::any_cast<std::vector<ValuePtr_t>>(visitConstInitVal(child));
            _state.arr_dim_index--;
            flattened.insert(flattened.end(), nested.begin(), nested.end());
            continue;
        }
        // scalar entry: must evaluate to an integer-typed value
        auto scalar = any_to_Value(visitConstInitVal(child));
        sysy_assert(TypeHelper::isIntegerType(scalar->type));
        flattened.push_back(scalar);
    }

    // pad the tail with zeros
    const int wanted = dim_extent * stride;
    while (static_cast<int>(flattened.size()) < wanted) {
        flattened.push_back(CONST0);
    }
    return flattened;
}
// @retval: std::shared_ptr<ConstantInt>
// Folds the addExp child with _state.isConstInt switched on, then switches the
// flag off again (it is not restored to a previous value).
std::any Visitor::visitConstExp(SysyParser::ConstExpContext *ctx) {
    auto *add_ctx = ctx->addExp();
    if (!add_ctx)
        panic("Unreachable");
    _state.isConstInt = true;
    auto folded = std::any_cast<std::shared_ptr<ConstantInt>>(visitAddExp(add_ctx));
    _state.isConstInt = false;
    LOG(DEBUG) << "ConstExp Eval to " << folded->value;
    return std::any(folded);
}
// addExp: mulExp | addExp ('+' | '-') mulExp;
//
// Const mode (_state.isConstInt): folds in place into the ConstantInt produced
// for the right-hand mulExp and returns it.
// Otherwise: emits an Add/Sub instruction (i1 operands are zero-extended
// first) and returns the resulting value as a ValuePtr_t.
// In both modes the right operand (mulExp) is visited before the left one.
std::any Visitor::visitAddExp(SysyParser::AddExpContext *ctx) {
    if (!_state.isConstInt) {
        auto rhs = any_to_Value(visitMulExp(ctx->mulExp()));
        if (!ctx->addExp()) {
            return rhs;
        }
        auto lhs = any_to_Value(visitAddExp(ctx->addExp()));
        if (std::dynamic_pointer_cast<IntegerType>(lhs->type)->isI1()) {
            lhs = build_InstZext(lhs, _state.current_bb);
        }
        if (std::dynamic_pointer_cast<IntegerType>(rhs->type)->isI1()) {
            rhs = build_InstZext(rhs, _state.current_bb);
        }
        if (ctx->ADD()) {
            rhs = build_InstBinary(InstTag::Add, lhs, rhs, _state.current_bb);
        }
        else if (ctx->SUB()) {
            rhs = build_InstBinary(InstTag::Sub, lhs, rhs, _state.current_bb);
        }
        else {
            panic("Unreachable");
        }
        return rhs;
    }

    // constant folding
    auto folded = std::any_cast<std::shared_ptr<ConstantInt>>(visitMulExp(ctx->mulExp()));
    if (!ctx->addExp()) {
        return std::any(folded);
    }
    auto left = std::any_cast<std::shared_ptr<ConstantInt>>(visitAddExp(ctx->addExp()));
    if (ctx->ADD()) {
        folded->value = left->value + folded->value;
    }
    else if (ctx->SUB()) {
        folded->value = left->value - folded->value;
    }
    else {
        panic("missing operator");
    }
    return std::any(folded);
}
// mulExp: unaryExp | mulExp ('*' | '/' | '%') unaryExp;
//
// Const mode (_state.isConstInt): folds in place into the ConstantInt produced
// for the right-hand unaryExp and returns it. Division or modulo by a zero
// constant panics instead of executing an undefined host-side division.
// Otherwise: emits a Mul/Div/Mod instruction (i1 operands are zero-extended
// first) and returns the resulting value.
// In both modes the right operand (unaryExp) is visited before the left one.
std::any Visitor::visitMulExp(SysyParser::MulExpContext *ctx) {
    if (_state.isConstInt) {
        auto result = std::any_cast<std::shared_ptr<ConstantInt>>(visitUnaryExp(ctx->unaryExp()));
        if (ctx->mulExp()) {
            auto mul_result = std::any_cast<std::shared_ptr<ConstantInt>>(visitMulExp(ctx->mulExp()));
            if (ctx->MUL()) {
                result->value = mul_result->value * result->value;
            }
            else if (ctx->DIV()) {
                // `result` still holds the divisor at this point
                if (result->value == 0)
                    panic("Division by zero in const expr");
                result->value = mul_result->value / result->value;
            }
            else if (ctx->MOD()) {
                if (result->value == 0)
                    panic("Modulo by zero in const expr");
                result->value = mul_result->value % result->value;
            }
            else
                panic("Unreachable");
        }
        return {result};
    }
    else {
        auto unary_exp = any_to_Value(visitUnaryExp(ctx->unaryExp()));
        if (ctx->mulExp()) {
            auto mul_exp = any_to_Value(visitMulExp(ctx->mulExp()));
            if (std::dynamic_pointer_cast<IntegerType>(unary_exp->type)->isI1()) {
                unary_exp = build_InstZext(unary_exp, _state.current_bb);
            }
            if (std::dynamic_pointer_cast<IntegerType>(mul_exp->type)->isI1()) {
                mul_exp = build_InstZext(mul_exp, _state.current_bb);
            }
            if (ctx->MUL()) {
                unary_exp = build_InstBinary(InstTag::Mul, mul_exp, unary_exp, _state.current_bb);
            }
            else if (ctx->DIV()) {
                unary_exp = build_InstBinary(InstTag::Div, mul_exp, unary_exp, _state.current_bb);
            }
            else if (ctx->MOD()) {
                unary_exp = build_InstBinary(InstTag::Mod, mul_exp, unary_exp, _state.current_bb);
            }
            else
                panic("Unreachable");
        }
        return {unary_exp};
    }
}
// relExp: addExp | relExp ('<' | '>' | '<=' | '>=') addExp;
//
// Visits the right-hand addExp first, then (if present) the left-hand relExp,
// and emits the comparison selected by the operator token of *this* node.
// Both operands are asserted to be i32.
// Returns the addExp value unchanged when there is no comparison.
std::any Visitor::visitRelExp(SysyParser::RelExpContext *ctx) {
    auto add_exp = any_to_Value(visitAddExp(ctx->addExp()));
    // should always has type I32
    sysy_assert(TypeHelper::isIntegerTypeI32(add_exp->type));
    if (ctx->relExp()) {
        auto rel_exp = any_to_Value(visitRelExp(ctx->relExp()));
        sysy_assert(TypeHelper::isIntegerTypeI32(rel_exp->type));
        // The operator token is a child of `ctx`, not of the nested relExp:
        // for `a < b` the nested relExp is just `a` and carries no operator.
        if (ctx->LE()) {
            add_exp = build_InstBinary(InstTag::Le, rel_exp, add_exp, _state.current_bb);
        }
        else if (ctx->LT()) {
            add_exp = build_InstBinary(InstTag::Lt, rel_exp, add_exp, _state.current_bb);
        }
        else if (ctx->GE()) {
            add_exp = build_InstBinary(InstTag::Ge, rel_exp, add_exp, _state.current_bb);
        }
        else if (ctx->GT()) {
            add_exp = build_InstBinary(InstTag::Gt, rel_exp, add_exp, _state.current_bb);
        }
        else
            panic("Unreachable");
    }
    return add_exp;
}
// eqExp: relExp | eqExp ('==' | '!=') relExp;
//
// Visits the right-hand relExp first, then (if present) the left-hand eqExp,
// and emits Eq/Ne. Both operands are asserted to be i1.
// NOTE(review): the i1 assertions look like they reject plain i32 operands
// such as `a == b` — confirm the intended operand types.
std::any Visitor::visitEqExp(SysyParser::EqExpContext *ctx) {
    auto rhs = any_to_Value(visitRelExp(ctx->relExp()));
    sysy_assert(TypeHelper::isIntegerTypeI1(rhs->type));
    if (!ctx->eqExp()) {
        return rhs;
    }

    auto lhs = any_to_Value(visitEqExp(ctx->eqExp()));
    sysy_assert(TypeHelper::isIntegerTypeI1(lhs->type));
    if (ctx->EQ()) {
        rhs = build_InstBinary(InstTag::Eq, lhs, rhs, _state.current_bb);
    }
    else if (ctx->NE()) {
        rhs = build_InstBinary(InstTag::Ne, lhs, rhs, _state.current_bb);
    }
    else {
        panic("Unreachable");
    }
    return rhs;
}
// Notes about side effects: except for || and &&, sub-expression evaluations
// are unsequenced as long as they are calculated before the operator.
// lAndExp: eqExp | lAndExp '&&' eqExp;
// There is only one path to lOrExp, which is `stmt` -> `cond` -> `lOrExp`,
// so these expressions never show up inside an arithmetic expression.
//
// For every eqExp operand: a fresh block is created, the operand is evaluated
// in the current block, and a conditional branch goes to the fresh block when
// the operand is non-zero or to ctx->false_block otherwise. The last fresh
// block jumps unconditionally to ctx->true_block.
// Returns an empty std::any.
std::any Visitor::visitLAndExp(SysyParser::LAndExpContext *ctx) {
    const auto operands = ctx->eqExp();
    for (auto *operand_ctx : operands) {
        auto continue_block = build_BasicBlock("", _state.current_func);
        auto operand = any_to_Value(visitEqExp(operand_ctx));
        sysy_assert(TypeHelper::isIntegerTypeI1(operand->type)); // expect a boolean
        auto is_true = build_InstBinary(InstTag::Ne, operand, CONST0, _state.current_bb);
        build_InstBranch(is_true, continue_block, ctx->false_block, _state.current_bb);
        _state.current_bb = continue_block;
    }
    build_InstBranch(ctx->true_block, _state.current_bb);
    return std::any();
}
// lOrExp: lAndExp ('||' lAndExp)*;
//
// Lowers a short-circuit OR. Every operand except the last jumps to
// ctx->true_block on success and falls into a fresh block (where the next
// operand is evaluated) on failure. The last operand uses ctx->true_block and
// ctx->false_block directly.
// Returns an empty std::any; the result is expressed purely through branches.
std::any Visitor::visitLOrExp(SysyParser::LOrExpContext *ctx) {
    const size_t n_and_exp = ctx->lAndExp().size();
    sysy_assert(n_and_exp > 0);
    const size_t last = n_and_exp - 1;
    for (size_t i = 0; i < last; ++i) {
        auto next_block = build_BasicBlock("", _state.current_func);
        ctx->lAndExp(i)->true_block = ctx->true_block;
        ctx->lAndExp(i)->false_block = next_block;
        visitLAndExp(ctx->lAndExp(i));
        _state.current_bb = next_block;
    }
    ctx->lAndExp(last)->true_block = ctx->true_block;
    ctx->lAndExp(last)->false_block = ctx->false_block;
    // The final operand must be lowered too; previously only its targets were
    // assigned, so no code was emitted for it (or for a condition without `||`).
    visitLAndExp(ctx->lAndExp(last));
    return {};
}
// unaryExp: primaryExp | IDENT '(' (funcRParams)? ')' | unaryOp unaryExp;
//
// Const mode (_state.isConstInt): applies +, - or ! in place to the folded
// ConstantInt operand; a function call panics.
// Otherwise:
//   - unaryOp: the operand must be integer typed; i1 is zero-extended, then
//     `!x` becomes (x == 0), `+x` is x and `-x` becomes (0 - x)
//   - call   : looks the callee up in _func_tab, evaluates the arguments left
//     to right and emits an InstCall
//   - primaryExp: forwarded to visitPrimaryExp
std::any Visitor::visitUnaryExp(SysyParser::UnaryExpContext *ctx) {
    if (_state.isConstInt) {
        if (ctx->unaryExp()) {
            auto result = std::any_cast<std::shared_ptr<ConstantInt>>(visitUnaryExp(ctx->unaryExp()));
            if (ctx->unaryOp()->ADD())
                result->value = +result->value;
            else if (ctx->unaryOp()->SUB())
                result->value = -result->value;
            else if (ctx->unaryOp()->NOT())
                result->value = !result->value;
            else
                panic("Unreachable");
            return {result};
        }
        else if (ctx->primaryExp()) {
            return visitPrimaryExp(ctx->primaryExp());
        }
        else if (ctx->IDENT()) {
            panic("Unexpected func call in const expr");
        }
        panic("Unreachable");
    }
    else {
        if (ctx->unaryExp()) {
            auto _result = visitUnaryExp(ctx->unaryExp());
            auto unary_exp = any_to_Value(_result);
            sysy_assert(unary_exp->type->type_tag == Type::TypeTag::IntegerType);
            if (std::dynamic_pointer_cast<IntegerType>(unary_exp->type)->isI1()) {
                unary_exp = build_InstZext(unary_exp, _state.current_bb);
            }
            if (ctx->unaryOp()->NOT()) {
                // should eval to i1
                sysy_assert(_state.isCondExp);
                return build_InstBinary(InstTag::Eq, unary_exp, CONST0, _state.current_bb);
            }
            else if (ctx->unaryOp()->ADD()) {
                return unary_exp;
            }
            else if (ctx->unaryOp()->SUB()) {
                return build_InstBinary(InstTag::Sub, CONST0, unary_exp, _state.current_bb);
            }
            // any other operator falls through to the panic below
        }
        else if (ctx->IDENT()) {
            // Function call
            // TODO: buildCall/isRealParam
            // TODO: Handle string & putf()
            auto func_name = ctx->IDENT()->getText();
            LOG(DEBUG) << "Calling Func: " << func_name;
            auto _result = _func_tab.get_name(func_name);
            sysy_assert(_result.has_value());
            auto func = _result.value();
            std::vector<ValuePtr_t> args;
            // Directly parse RParams. The formal parameter list is not consulted
            // here; the previous unused `fparam_list[i]` read could index past
            // the end when a call passes more arguments than were declared.
            if (ctx->funcRParams()) {
                const auto rparams = ctx->funcRParams()->funcRParam();
                for (auto rparam : rparams) {
                    auto exp = any_to_Value(visitExp(rparam->exp()));
                    args.push_back(exp);
                }
            }
            return build_InstCall(func, args, _state.current_bb);
        }
        else if (ctx->primaryExp()) {
            return visitPrimaryExp(ctx->primaryExp());
        }
        panic("Unreachable");
    }
}
// primaryExp: ('(' exp ')') | lVal | number;
//
// Const mode (_state.isConstInt): every alternative yields a
// std::shared_ptr<ConstantInt>, which is what the const-folding callers
// any_cast to. An lVal must resolve to a ConstantInt; its value is copied into
// a fresh ConstantInt because the folding visitors modify their operands in
// place and must not alter the named constant stored in the scope table.
// Otherwise: an lVal of integer type is returned as is; an lVal of pointer
// type is loaded (unless _state.isRealParam is set, in which case the flag is
// cleared and the raw visitLVal result is returned).
std::any Visitor::visitPrimaryExp(SysyParser::PrimaryExpContext *ctx) {
    if (_state.isConstInt) {
        if (ctx->exp()) {
            return visitExp(ctx->exp());
        }
        else if (ctx->lVal()) {
            auto value = any_to_Value(visitLVal(ctx->lVal()));
            auto constint = std::dynamic_pointer_cast<ConstantInt>(value);
            if (!constint)
                panic("Non-constant lVal in const expr");
            // @retval: ConstantInt (fresh copy, see above)
            return build_ConstantInt("", constint->value);
        }
        else if (ctx->number()) {
            return visitNumber(ctx->number());
        }
        panic("Unreachable");
    }
    else {
        if (ctx->exp()) {
            return visitExp(ctx->exp());
        }
        else if (ctx->lVal()) {
            if (_state.isRealParam) {
                _state.isRealParam = false;
                LOG(WARNING) << "isRealParam";
                return visitLVal(ctx->lVal());
            }
            else {
                auto child_ret = visitLVal(ctx->lVal());
                auto lval = any_to_Value(child_ret);
                // @retval: ConstantInt
                if (lval->type->type_tag == Type::TypeTag::IntegerType) {
                    return lval;
                }
                // @retval: InstLoad
                else {
                    LOG(WARNING) << "lval type is pointer";
                    // should be InstAlloca
                    auto ptr_type = std::dynamic_pointer_cast<PointerType>(lval->type);
                    return build_InstLoad(lval, ptr_type->pointed_type, _state.current_bb);
                }
            }
        }
        // @retval: ConstantInt
        else if (ctx->number()) {
            return visitNumber(ctx->number());
        }
        panic("Unreachable");
    }
}
// @retval: ConstantInt
// A number is always an integer literal; the work is done by visitIntConst.
std::any Visitor::visitNumber(SysyParser::NumberContext *ctx) {
    auto *literal = ctx->intConst();
    return visitIntConst(literal);
}
// @retval: ConstantInt
// Parses a decimal, hexadecimal or octal literal token into a ConstantInt.
// The text is parsed as an unsigned 64-bit number and then narrowed to 32
// bits, so literals such as 2147483648 (the operand of `-2147483648`) or
// 0xFFFFFFFF wrap around instead of making std::stoi throw
// std::out_of_range. When none of the three token kinds is present the value
// is 0.
std::any Visitor::visitIntConst(SysyParser::IntConstContext *ctx) {
    std::string literal;
    int base = 10;
    if (ctx->DECIMAL_CONST()) {
        literal = ctx->DECIMAL_CONST()->getText();
        base = 10;
    }
    else if (ctx->HEXADECIMAL_CONST()) {
        // std::stoull accepts the leading 0x/0X when base is 16
        literal = ctx->HEXADECIMAL_CONST()->getText();
        base = 16;
    }
    else if (ctx->OCTAL_CONST()) {
        literal = ctx->OCTAL_CONST()->getText();
        base = 8;
    }
    int const_int = 0;
    if (!literal.empty()) {
        const unsigned long long parsed = std::stoull(literal, nullptr, base);
        const_int = static_cast<int>(static_cast<unsigned int>(parsed));
    }
    return build_ConstantInt("", const_int);
}
// lVal: IDENT ('[' exp ']')*;
//
// Resolves a name through the scope table and returns, depending on the type
// of the stored value:
//   - integer-typed value: the value itself
//   - pointer to integer: the pointer when there are no subscripts, otherwise
//     a chain of GEPs, one per subscript
//   - pointer to pointer: a load when there are no subscripts, otherwise the
//     address of the addressed element
//   - pointer to array: the address of the addressed element; the subscripts
//     are folded into one flat offset
// Element addresses are returned as ValuePtr_t so that any_to_Value() can
// unpack them regardless of the concrete type build_InstGEP returns.
std::any Visitor::visitLVal(SysyParser::LValContext *ctx) {
    auto name = ctx->IDENT()->getText();
    LOG(DEBUG) << "Eval to lVal " << name;
    auto _lval = _scope_tab.get_name(name);
    sysy_assert(_lval.has_value());
    auto lval = _lval.value();
    // @retval: ConstantInt
    if (lval->type->type_tag == Type::TypeTag::IntegerType) {
        return {lval};
    }
    if (lval->type->type_tag == Type::TypeTag::PointerType) {
        auto ptr_type = std::dynamic_pointer_cast<PointerType>(lval->type);
        switch (ptr_type->pointed_type->type_tag) {
        case Type::TypeTag::IntegerType: {
            // Int
            if (ctx->exp().empty()) {
                // int ref
                // @retval: InstAlloca
                return lval;
            }
            else {
                LOG(WARNING) << "Unexpected array referece";
                // array index, perhaps
                // @retval: InstGEP
                auto gep = lval;
                for (auto exp_ctx : ctx->exp()) {
                    auto exp = any_to_Value(visitExp(exp_ctx));
                    gep = build_InstGEP(gep, {exp}, _state.current_bb);
                }
                return gep;
            }
        };
        case Type::TypeTag::PointerType: {
            // NOTE(review): both branches below cast the *value* `lval` to
            // PointerType rather than its type (`lval->type`); this looks like
            // it yields a null pointer at run time — confirm the intent.
            if (ctx->exp().empty()) {
                // ??? Pointer
                // @retval: InstLoad
                LOG(WARNING) << "Unexpected pointer";
                auto pointed_type = std::dynamic_pointer_cast<PointerType>(lval)->pointed_type;
                auto inst_load = build_InstLoad(lval, pointed_type, _state.current_bb);
                return inst_load;
            }
            else {
                // fparam array, whose first dim is represented by a pointer
                auto pointed_type = std::dynamic_pointer_cast<PointerType>(lval)->pointed_type;
                sysy_assert(pointed_type->type_tag == Type::TypeTag::ArrayType);
                auto inst_load = build_InstLoad(lval, pointed_type, _state.current_bb);
                auto ptr = build_InstGEP(inst_load, {CONST0}, _state.current_bb);
                pointed_type = std::dynamic_pointer_cast<PointerType>(ptr->type)->pointed_type;
                ValuePtr_t offset = ConstantInt::make_shared(0);
                auto exp_list = ctx->exp();
                // calculate offset by hand: offset = (offset + exp[i]) * dim_size[i]
                for (int i = 0; i < exp_list.size() - 1; ++i) {
                    // Check the dynamic type: typeid() of a shared_ptr variable only
                    // reflects its static type, so the old comparison never held.
                    auto array_type = std::dynamic_pointer_cast<ArrayType>(pointed_type);
                    sysy_assert(array_type != nullptr);
                    auto exp_val = any_to_Value(visitExp(exp_list[i]));
                    auto dim_size = ConstantInt::make_shared(array_type->element_count);
                    auto inst_add = build_InstBinary(InstTag::Add, offset, exp_val, _state.current_bb);
                    auto inst_mul = build_InstBinary(InstTag::Mul, inst_add, dim_size, _state.current_bb);
                    offset = inst_mul;
                    pointed_type = array_type->element_type;
                    ptr = build_InstGEP(ptr, {CONST0, CONST0}, _state.current_bb);
                }
                // visit the last dimension, mul is not needed
                auto exp_val = any_to_Value(visitExp(exp_list.back()));
                auto inst_add = build_InstBinary(InstTag::Add, offset, exp_val, _state.current_bb);
                offset = inst_add; // finally, we get the offset
                pointed_type = std::dynamic_pointer_cast<PointerType>(ptr->type)->pointed_type;
                if (TypeHelper::isIntegerType(pointed_type)) {
                    // return the address of the array element
                    ValuePtr_t arr_elem_ptr = build_InstGEP(ptr, {offset}, _state.current_bb);
                    return arr_elem_ptr;
                }
                else {
                    panic("Should be int");
                }
            }
        }
        case Type::TypeTag::ArrayType: {
            sysy_assert(!ctx->exp().empty());
            // get &array[0]
            auto ptr = build_InstGEP(lval, {CONST0, CONST0}, _state.current_bb);
            auto pointed_type = std::dynamic_pointer_cast<PointerType>(ptr->type)->pointed_type;
            ValuePtr_t offset = ConstantInt::make_shared(0);
            auto exp_list = ctx->exp();
            // calculate offset by hand: offset = (offset + exp[i]) * dim_size[i]
            for (int i = 0; i < exp_list.size() - 1; ++i) {
                // dynamic-type check, see the fparam branch above
                auto array_type = std::dynamic_pointer_cast<ArrayType>(pointed_type);
                sysy_assert(array_type != nullptr);
                auto exp_val = any_to_Value(visitExp(exp_list[i]));
                auto dim_size = ConstantInt::make_shared(array_type->element_count);
                auto inst_add = build_InstBinary(InstTag::Add, offset, exp_val, _state.current_bb);
                auto inst_mul = build_InstBinary(InstTag::Mul, inst_add, dim_size, _state.current_bb);
                offset = inst_mul;
                pointed_type = array_type->element_type;
                ptr = build_InstGEP(ptr, {CONST0, CONST0}, _state.current_bb);
            }
            // the last subscript is added without a trailing multiplication
            auto exp_val = any_to_Value(visitExp(exp_list.back()));
            auto inst_add = build_InstBinary(InstTag::Add, offset, exp_val, _state.current_bb);
            offset = inst_add; // finally, we get the offset
            pointed_type = std::dynamic_pointer_cast<PointerType>(ptr->type)->pointed_type;
            if (TypeHelper::isIntegerType(pointed_type)) {
                // return the address of the array element
                ValuePtr_t arr_elem_ptr = build_InstGEP(ptr, {offset}, _state.current_bb);
                return arr_elem_ptr;
            }
            else {
                panic("Should be int");
            }
        }
        default:
            panic("Unreachable");
        }
    }
    panic("Unreachable");
}
// Creates the Function object, registers it in the module and the function
// table, opens its entry block and scope, then lowers the parameters and the
// body. A trailing return is always appended to the current block afterwards
// (void return for void functions, `return 0` otherwise).
// Returns an empty std::any.
std::any Visitor::visitFuncDef(SysyParser::FuncDefContext *ctx) {
    const auto func_name = ctx->IDENT()->getText();
    LOG(DEBUG) << "Visit FuncDef " << func_name;

    // the return type is void unless the INT token is present
    auto ret_type = TypeHelper::TYPE_VOID;
    if (ctx->funcType()->INT()) {
        ret_type = TypeHelper::TYPE_I32;
    }

    // the parameter list is collected (and locally allocated) in visitFuncFParams
    auto function = std::make_shared<Function>(func_name, ret_type);
    _module.functions.push_back(function);
    _func_tab.push_name(func_name, function);
    auto entry_block = build_BasicBlock(func_name + "_ENTRY", function);

    _scope_tab.enter_scope(true);
    _state.current_func = function;
    _state.current_bb = entry_block;

    if (auto *params_ctx = ctx->funcFParams()) {
        visitFuncFParams(params_ctx);
    }
    visitBlock(ctx->block());

    // add return
    // _scope_tab.leave_scope();
    // TODO: avoid duplicate ret
    if (ret_type->type_tag != Type::TypeTag::VoidType) {
        build_InstReturn(CONST0, _state.current_bb);
    }
    else {
        build_InstReturn(_state.current_bb);
    }
    return std::any();
}
// @retval: empty std::any
// Adds every formal parameter directly to _state.current_func instead of
// returning a list. For each parameter an FParam is created and appended to
// the function's fparam_list exactly once, a stack slot is allocated, the
// incoming value is stored into it, and the slot is registered in the scope
// table under the parameter's identifier.
std::any Visitor::visitFuncFParams(SysyParser::FuncFParamsContext *ctx) {
    for (auto fparam_ctx : ctx->funcFParam()) {
        auto fparam_type = std::any_cast<TypePtr_t>(visitFuncFParam(fparam_ctx));
        // Use the IDENT token only: getText() on the whole context concatenates
        // the type keyword, the name and any brackets (e.g. "inta").
        auto fparam_name = fparam_ctx->IDENT()->getText();
        auto fparam = std::make_shared<FParam>(fparam_name, fparam_type);
        _state.current_func->fparam_list.push_back(fparam);
        auto alloca_ = build_InstAlloca(fparam_type, _state.current_bb);
        build_InstStore(fparam, alloca_, _state.current_bb);
        _scope_tab.push_name(fparam_name, alloca_);
    }
    return {};
}
// funcFParam: bType IDENT ('[' ']' ('[' exp ']')*)?;
//
// @retval: TypePtr_t — the parameter's type. Both branches box the result as
// TypePtr_t, because the caller unpacks it with std::any_cast<TypePtr_t>,
// which needs an exact type match.
//   - no brackets: TypeHelper::TYPE_I32
//   - brackets   : a PointerType to the array type built from the explicitly
//     sized dimensions (the first, empty dimension is the pointer itself)
// NOTE(review): for a one-dimensional parameter (`int a[]`) the dimension list
// is empty; what ArrayType::build_from_list returns then is not visible here.
std::any Visitor::visitFuncFParam(SysyParser::FuncFParamContext *ctx) {
    if (ctx->LBRACKET().empty()) {
        // int type
        TypePtr_t int_type = TypeHelper::TYPE_I32;
        return int_type;
    }
    else {
        // array type
        std::vector<int> dim_list; // the first dim must be empty, though
        for (auto exp_ctx : ctx->exp()) {
            _state.isConstInt = true;
            auto exp_val = std::dynamic_pointer_cast<ConstantInt>(any_to_Value(visitExp(exp_ctx)));
            _state.isConstInt = false;
            if (!exp_val)
                panic("Array dimension of a parameter must be a constant");
            dim_list.push_back(exp_val->value);
        }
        auto array_type = ArrayType::build_from_list(dim_list);
        TypePtr_t true_array_type = std::make_shared<PointerType>(array_type);
        return true_array_type;
    }
}
// Lowers a `{ ... }` block: opens a scope, visits each block item in order,
// and closes the scope again. Returns an empty std::any.
std::any Visitor::visitBlock(SysyParser::BlockContext *ctx) {
    _scope_tab.enter_scope();
    const auto items = ctx->blockItem();
    for (size_t idx = 0; idx < items.size(); ++idx) {
        visitBlockItem(items[idx]);
    }
    _scope_tab.leave_scope();
    return std::any();
}
/*
 * Statement visitors return an empty std::any: nobody consumes the value of a
 * statement.
 stmt: lVal '=' exp ';' # assignStmt
 | (exp)? ';' # expStmt // do nothing
 | block # blockStmt // do nothing
 | 'if' '(' cond ')' stmt ('else' stmt)? # ifStmt
 | 'while' '(' cond ')' stmt # whileStmt
 | 'break' ';' # breakStmt
 | 'continue' ';' # continueStmt
 | 'return' (exp)? ';' # returnStmt;
 */
// Evaluates the target first, then the right-hand side, and emits a store of
// the right-hand side through the target.
std::any Visitor::visitAssignStmt(SysyParser::AssignStmtContext *ctx) {
    auto target = any_to_Value(visitLVal(ctx->lVal()));
    auto source = any_to_Value(visitExp(ctx->exp()));
    build_InstStore(source, target, _state.current_bb);
    return std::any();
}
// TODO: Remove RETURN in else stmt
// Lowers if / if-else. Blocks are created in the order _then, _next and, when
// an else branch exists, _else. The condition branches to _then or to
// _else/_next; each lowered branch ends with a jump to _next, where emission
// continues. Returns an empty std::any.
std::any Visitor::visitIfStmt(SysyParser::IfStmtContext *ctx) {
    auto then_block = build_BasicBlock("_then", _state.current_func);
    auto join_block = build_BasicBlock("_next", _state.current_func);
    const bool has_else = ctx->ELSE() != nullptr;
    // without an else branch a false condition goes straight to the join block
    auto else_block = join_block;
    if (has_else) {
        else_block = build_BasicBlock("_else", _state.current_func);
    }

    auto *condition = ctx->cond()->lOrExp();
    condition->true_block = then_block;
    condition->false_block = else_block;
    visitCond(ctx->cond());

    _state.current_bb = then_block;
    visit(ctx->stmt(0));
    // the branch may have moved current_bb, so jump from wherever it ended
    build_InstBranch(join_block, _state.current_bb);

    if (has_else) {
        _state.current_bb = else_block;
        visit(ctx->stmt(1));
        build_InstBranch(join_block, _state.current_bb);
    }

    _state.current_bb = join_block;
    return std::any();
}
// TODO: backpatching? I am not sure whether it is necessary
// Lowers a while loop into three blocks (condition, body, exit) whose names
// carry the current loop counter. The loop is pushed on _state.loop_stack for
// the duration of the condition and body so that break/continue can find
// their targets. Emission continues in the exit block.
// Returns an empty std::any.
std::any Visitor::visitWhileStmt(SysyParser::WhileStmtContext *ctx) {
    const auto suffix = std::to_string(_state.loop_stmt_count);
    auto header = build_BasicBlock("_loop_cond_" + suffix, _state.current_func);
    auto body = build_BasicBlock("_loop_body_" + suffix, _state.current_func);
    auto exit = build_BasicBlock("_loop_exit_" + suffix, _state.current_func);

    // fall into the condition check from the preceding code
    build_InstBranch(header, _state.current_bb);
    _state.loop_stack.push_back({header, body, exit, _state.loop_stmt_count++});

    // condition
    auto *condition = ctx->cond()->lOrExp();
    condition->true_block = body;
    condition->false_block = exit;
    _state.current_bb = header;
    visitCond(ctx->cond());

    // body, followed by the back edge to the condition
    _state.current_bb = body;
    visit(ctx->stmt());
    build_InstBranch(header, _state.current_bb);

    // exit
    _state.loop_stack.pop_back();
    _state.current_bb = exit;
    return std::any();
}
// Jumps to the `next` block of the innermost loop on _state.loop_stack, then
// continues emission in a fresh "_after_break" block.
// Asserts that a loop is active. Returns an empty std::any.
std::any Visitor::visitBreakStmt(SysyParser::BreakStmtContext *ctx) {
    sysy_assert(!_state.loop_stack.empty());
    auto loop_exit = _state.loop_stack.back().next;
    build_InstBranch(loop_exit, _state.current_bb);
    auto after_break = build_BasicBlock("_after_break", _state.current_func);
    _state.current_bb = after_break;
    return std::any();
}
// Jumps to the `cond` block of the innermost loop on _state.loop_stack, then
// continues emission in a fresh "_after_continue" block.
// Asserts that a loop is active. Returns an empty std::any.
std::any Visitor::visitContinueStmt(SysyParser::ContinueStmtContext *ctx) {
    sysy_assert(!_state.loop_stack.empty());
    auto loop_header = _state.loop_stack.back().cond;
    build_InstBranch(loop_header, _state.current_bb);
    auto after_continue = build_BasicBlock("_after_continue", _state.current_func);
    _state.current_bb = after_continue;
    return std::any();
}
// Emits a return in the current block: with the evaluated expression when one
// is present, otherwise a bare return. Returns an empty std::any.
std::any Visitor::visitReturnStmt(SysyParser::ReturnStmtContext *ctx) {
    if (!ctx->exp()) {
        build_InstReturn(_state.current_bb);
        return std::any();
    }
    auto ret_value = any_to_Value(visitExp(ctx->exp()));
    build_InstReturn(ret_value, _state.current_bb);
    return std::any();
}
} // namespace antlrSysY