From da9c7e793b11ced4808be954f213b16e61889468 Mon Sep 17 00:00:00 2001 From: Nick Gasson Date: Mon, 19 Feb 2024 22:29:32 +0000 Subject: [PATCH] Add memory-to-register promotion pass for JIT IR --- src/jit/jit-irgen.c | 1 + src/jit/jit-optim.c | 123 +++++++++++++++++++++++++++++++++++++++++++- src/jit/jit-priv.h | 1 + test/test_jit.c | 70 +++++++++++++++++++++++++ 4 files changed, 194 insertions(+), 1 deletion(-) diff --git a/src/jit/jit-irgen.c b/src/jit/jit-irgen.c index 2b0ea787..e65c85d0 100644 --- a/src/jit/jit-irgen.c +++ b/src/jit/jit-irgen.c @@ -4227,6 +4227,7 @@ void jit_irgen(jit_func_t *f) g->labels = NULL; if (kind != VCODE_UNIT_THUNK) { + jit_do_mem2reg(f); jit_do_lvn(f); jit_do_cprop(f); jit_do_dce(f); diff --git a/src/jit/jit-optim.c b/src/jit/jit-optim.c index 92884fea..cdf6f801 100644 --- a/src/jit/jit-optim.c +++ b/src/jit/jit-optim.c @@ -1022,7 +1022,7 @@ void jit_do_cprop(jit_func_t *f) ir->arg2 = copy; } - if (ir->op == J_MOV) + if (ir->op == J_MOV && ir->arg1.kind == JIT_VALUE_INT64) map[ir->result] = ir->arg1; else if (ir->result != JIT_REG_INVALID) map[ir->result].kind = JIT_VALUE_INVALID; @@ -1135,3 +1135,124 @@ void jit_delete_nops(jit_func_t *f) f->nirs = wptr; } + +//////////////////////////////////////////////////////////////////////////////// +// Memory to register promotion + +static inline jit_reg_t get_value_reg(jit_value_t value) +{ + switch (value.kind) { + case JIT_ADDR_REG: + case JIT_VALUE_REG: + return value.reg; + default: + return JIT_REG_INVALID; + } +} + +static bool try_promote_allocation(jit_func_t *f, jit_ir_t *alloc) +{ + assert(alloc->arg2.kind == JIT_VALUE_INT64); + + if (alloc->arg2.int64 > 8 || !is_power_of_2(alloc->arg2.int64)) + return false; + + jit_size_t sz = JIT_SZ_UNSPEC; + switch (alloc->arg2.int64) { + case 1: sz = JIT_SZ_8; break; + case 2: sz = JIT_SZ_16; break; + case 4: sz = JIT_SZ_32; break; + case 8: sz = JIT_SZ_64; break; + } + + jit_reg_t reg = alloc->result; + assert(reg != JIT_REG_INVALID); + + int first = -1, last = -1; + for (int i = 0; i < f->nirs; i++) { + jit_ir_t *ir = &(f->irbuf[i]); + if (ir->result == reg && ir != alloc) + return false; + else if (ir->op == J_LOAD || ir->op == J_STORE) { + jit_value_t addr = ir->op == J_LOAD ? ir->arg1 : ir->arg2; + if (get_value_reg(addr) != reg) + continue; + else if (addr.disp != 0 || ir->size != sz) + return false; + else if (ir->op == J_STORE && get_value_reg(ir->arg1) == reg) + return false; + else if (first == -1) + first = last = i; + else + last = i; + } + else if (get_value_reg(ir->arg1) == reg) + return false; + else if (get_value_reg(ir->arg2) == reg) + return false; + } + + if (first != -1) { + for (jit_ir_t *ir = f->irbuf + first; ir <= f->irbuf + last; ir++) { + if (ir->op == J_STORE && get_value_reg(ir->arg2) == reg) { + ir->op = J_MOV; + ir->size = JIT_SZ_UNSPEC; + ir->cc = JIT_CC_NONE; + ir->result = reg; + ir->arg2.kind = JIT_VALUE_INVALID; + } + else if (ir->op == J_LOAD && get_value_reg(ir->arg1) == reg) { + ir->op = J_MOV; + ir->size = JIT_SZ_UNSPEC; + ir->cc = JIT_CC_NONE; + ir->arg1.kind = JIT_VALUE_REG; + ir->arg1.reg = reg; + } + } + } + + alloc->op = J_NOP; + alloc->size = JIT_SZ_UNSPEC; + alloc->cc = JIT_CC_NONE; + alloc->result = JIT_REG_INVALID; + alloc->arg1.kind = JIT_VALUE_INVALID; + alloc->arg2.kind = JIT_VALUE_INVALID; + + return true; +} + +void jit_do_mem2reg(jit_func_t *f) +{ + if (f->framesz == 0) + return; + + int replaced = 0; + for (int i = 0; i < f->nirs; i++) { + jit_ir_t *ir = &(f->irbuf[i]); + if (ir->op == MACRO_SALLOC && try_promote_allocation(f, ir)) + replaced++; + else if (cfg_is_terminator(f, ir)) + break; + } + + if (replaced == 0) + return; + + int newsize = 0; + for (int i = 0; i < f->nirs; i++) { + jit_ir_t *ir = &(f->irbuf[i]); + if (ir->op == MACRO_SALLOC) { + assert(ir->arg1.kind == JIT_VALUE_INT64); + assert(ir->arg1.int64 >= newsize); + ir->arg1.int64 = newsize; + + assert(ir->arg2.kind == JIT_VALUE_INT64); + newsize += ALIGN_UP(ir->arg2.int64, 8); + } + else if (cfg_is_terminator(f, ir)) + break; + } + + assert(newsize < f->framesz); + f->framesz = newsize; +} diff --git a/src/jit/jit-priv.h b/src/jit/jit-priv.h index 9e4ea652..c604f764 100644 --- a/src/jit/jit-priv.h +++ b/src/jit/jit-priv.h @@ -392,6 +392,7 @@ void jit_do_lvn(jit_func_t *f); void jit_do_cprop(jit_func_t *f); void jit_do_dce(jit_func_t *f); void jit_delete_nops(jit_func_t *f); +void jit_do_mem2reg(jit_func_t *f); code_cache_t *code_cache_new(void); void code_cache_free(code_cache_t *code); diff --git a/test/test_jit.c b/test/test_jit.c index d6eafe67..c6340291 100644 --- a/test/test_jit.c +++ b/test/test_jit.c @@ -37,6 +37,8 @@ #define REG(r) ((jit_value_t){ .kind = JIT_VALUE_REG, .reg = (r) }) #define CONST(i) ((jit_value_t){ .kind = JIT_VALUE_INT64, .int64 = (i) }) #define LABEL(l) ((jit_value_t){ .kind = JIT_VALUE_LABEL, .label = (l) }) +#define ADDR(r, d) \ + ((jit_value_t){ .kind = JIT_ADDR_REG, .reg = (r), .disp = (d) }) static jit_handle_t compile_for_test(jit_t *j, const char *name) { @@ -53,6 +55,19 @@ static inline int64_t extend_value(jit_value_t value) } } +static void check_nullary(jit_func_t *f, int nth, jit_op_t expect) +{ + jit_ir_t *ir = &(f->irbuf[nth]); + if (ir->op == expect) + return; + + jit_dump_with_mark(f, nth, false); + + if (ir->op != expect) + ck_abort_msg("expected op %s but have %s", jit_op_name(expect), + jit_op_name(ir->op)); +} + static void check_unary(jit_func_t *f, int nth, jit_op_t expect, jit_value_t arg1) { @@ -1983,6 +1998,59 @@ START_TEST(test_lvn10) } END_TEST +START_TEST(test_cprop2) +{ + jit_t *j = jit_new(NULL); + + const char *text1 = + " MOV R1, R0 \n" + " ADD R0, R0, #1 \n" + " ADD R1, R1, #1 \n"; + + jit_handle_t h1 = jit_assemble(j, ident_new("myfunc1"), text1); + + jit_func_t *f = jit_get_func(j, h1); + jit_do_cprop(f); + + check_unary(f, 0, J_MOV, REG(0)); + check_binary(f, 1, J_ADD, REG(0), CONST(1)); + check_binary(f, 2, J_ADD, REG(1), CONST(1)); + + jit_free(j); +} +END_TEST + +START_TEST(test_mem2reg1) +{ + jit_t *j = jit_new(NULL); + + const char *text1 = + " $SALLOC R0, #0, #4 \n" + " $SALLOC R1, #8, #7 \n" + " $SALLOC R2, #16, #8 \n" + " LOAD.32 R3, [R0] \n" + " LOAD.32 R4, [R2] \n" + " STORE.32 #5, [R0] \n" + ; + + jit_handle_t h1 = jit_assemble(j, ident_new("myfunc1"), text1); + + jit_func_t *f = jit_get_func(j, h1); + jit_do_mem2reg(f); + + check_nullary(f, 0, J_NOP); + check_binary(f, 1, MACRO_SALLOC, CONST(0), CONST(7)); + check_binary(f, 2, MACRO_SALLOC, CONST(8), CONST(8)); + check_unary(f, 3, J_MOV, REG(0)); + check_unary(f, 4, J_LOAD, ADDR(2, 0)); + check_unary(f, 5, J_MOV, CONST(5)); + + ck_assert_int_eq(f->framesz, 16); + + jit_free(j); +} +END_TEST + Suite *get_jit_tests(void) { Suite *s = suite_create("jit"); @@ -2038,6 +2106,8 @@ Suite *get_jit_tests(void) tcase_add_test(tc, test_dce2); tcase_add_test(tc, test_code1); tcase_add_test(tc, test_lvn10); + tcase_add_test(tc, test_cprop2); + tcase_add_test(tc, test_mem2reg1); suite_add_tcase(s, tc); return s; -- 2.39.2