From 1213df0456fc24391498797027547c7df25426f1 Mon Sep 17 00:00:00 2001 From: Nick Gasson Date: Fri, 4 Nov 2022 22:03:06 +0000 Subject: [PATCH] Add a simple local value numbering pass --- src/jit/jit-core.c | 25 +++-- src/jit/jit-dump.c | 2 +- src/jit/jit-irgen.c | 8 +- src/jit/jit-optim.c | 238 +++++++++++++++++++++++++++++++++++++++++--- src/jit/jit-priv.h | 1 + src/mask.c | 3 + test/test_jit.c | 85 ++++++++++++++++ test/test_misc.c | 12 +++ 8 files changed, 348 insertions(+), 26 deletions(-) diff --git a/src/jit/jit-core.c b/src/jit/jit-core.c index 780ac590..b0d7af10 100644 --- a/src/jit/jit-core.c +++ b/src/jit/jit-core.c @@ -964,15 +964,16 @@ jit_handle_t jit_assemble(jit_t *j, ident_t name, const char *text) int nresult; int nargs; } optab[] = { - { "MOV", J_MOV, 1, 1 }, - { "ADD", J_ADD, 1, 2 }, - { "SUB", J_SUB, 1, 2 }, - { "MUL", J_MUL, 1, 2 }, - { "RECV", J_RECV, 1, 1 }, - { "SEND", J_SEND, 0, 2 }, - { "RET", J_RET, 0, 0 }, - { "CMP", J_CMP, 0, 2 }, - { "JUMP", J_JUMP, 0, 1 }, + { "MOV", J_MOV, 1, 1 }, + { "ADD", J_ADD, 1, 2 }, + { "SUB", J_SUB, 1, 2 }, + { "MUL", J_MUL, 1, 2 }, + { "RECV", J_RECV, 1, 1 }, + { "SEND", J_SEND, 0, 2 }, + { "RET", J_RET, 0, 0 }, + { "CMP", J_CMP, 0, 2 }, + { "JUMP", J_JUMP, 0, 1 }, + { "$COPY", MACRO_COPY, 1, 2 }, }; static const struct { @@ -1106,6 +1107,12 @@ jit_handle_t jit_assemble(jit_t *j, ident_t name, const char *text) arg.kind = JIT_VALUE_LABEL; arg.label = atoi(tok + 1); } + else if (tok[0] == '[' && tok[1] == 'R') { + arg.kind = JIT_ADDR_REG; + arg.reg = atoi(tok + 2); + arg.disp = 0; + f->nregs = MAX(arg.reg + 1, f->nregs); + } else if (tok[0] == '\n') fatal_trace("got newline, expecting argument"); else diff --git a/src/jit/jit-dump.c b/src/jit/jit-dump.c index 3532ac7f..6432af46 100644 --- a/src/jit/jit-dump.c +++ b/src/jit/jit-dump.c @@ -51,7 +51,7 @@ const char *jit_op_name(jit_op_t op) "SEND", "RECV", "ADD", "RET", "TRAP", "ULOAD", "STORE", "JUMP", "CMP", "CSET", "SUB", "MOV", "FADD", "MUL", "FMUL", "CALL", "NEG", "LOAD", "CSEL", "LEA", "NOT", "DIV", "FDIV", "SCVTF", "FNEG", "FCVTNS", - "FCMP", "AND", "OR", "XOR", "FSUB", "REM", "DEBUG" + "FCMP", "AND", "OR", "XOR", "FSUB", "REM", "DEBUG", "NOP" }; assert(op < ARRAY_LEN(names)); return names[op]; diff --git a/src/jit/jit-irgen.c b/src/jit/jit-irgen.c index 12c6b1db..59a9e85d 100644 --- a/src/jit/jit-irgen.c +++ b/src/jit/jit-irgen.c @@ -2644,7 +2644,6 @@ static void irgen_op_report(jit_irgen_t *g, int op) static void irgen_op_assert(jit_irgen_t *g, int op) { - jit_value_t value = irgen_get_arg(g, op, 0); jit_value_t severity = irgen_get_arg(g, op, 1); jit_value_t locus = irgen_get_arg(g, op, 4); @@ -2669,7 +2668,12 @@ static void irgen_op_assert(jit_irgen_t *g, int op) irgen_label_t *l_pass = irgen_alloc_label(g); - j_cmp(g, JIT_CC_NE, value, jit_value_from_int64(0)); + vcode_reg_t arg0 = vcode_get_arg(op, 0); + if (arg0 != g->flags) { + jit_value_t test = irgen_get_arg(g, op, 0); + j_cmp(g, JIT_CC_NE, test, jit_value_from_int64(0)); + } + j_jump(g, JIT_CC_T, l_pass); j_send(g, 0, msg); diff --git a/src/jit/jit-optim.c b/src/jit/jit-optim.c index 5e2b28fb..5549cc23 100644 --- a/src/jit/jit-optim.c +++ b/src/jit/jit-optim.c @@ -16,11 +16,13 @@ // #include "util.h" +#include "array.h" #include "jit/jit-priv.h" #include #include #include +#include //////////////////////////////////////////////////////////////////////////////// // Control flow graph construction @@ -256,6 +258,28 @@ int jit_get_edge(jit_edge_list_t *list, int nth) } \ } while (0) +typedef unsigned valnum_t; + +#define VN_INVALID UINT_MAX +#define SMALL_CONST 100 + +typedef struct _lvn_tab lvn_tab_t; + +typedef struct { + jit_func_t *func; + valnum_t *regvn; + valnum_t nextvn; + lvn_tab_t *hashtab; + size_t tabsz; +} lvn_state_t; + +typedef struct _lvn_tab { + jit_ir_t *ir; + valnum_t vn; +} lvn_tab_t; + +static void jit_lvn_mov(jit_ir_t *ir, lvn_state_t *state); + static inline bool lvn_can_fold(jit_ir_t *ir) { return ir->arg1.kind == JIT_VALUE_INT64 @@ -263,16 +287,134 @@ static inline bool lvn_can_fold(jit_ir_t *ir) || ir->arg2.kind == JIT_VALUE_INT64); } -static void lvn_convert_mov(jit_ir_t *ir, int64_t cval) +static void lvn_mov_const(jit_ir_t *ir, lvn_state_t *state, int64_t cval) { ir->op = J_MOV; ir->size = JIT_SZ_UNSPEC; + ir->cc = JIT_CC_NONE; ir->arg1.kind = JIT_VALUE_INT64; ir->arg1.int64 = cval; ir->arg2.kind = JIT_VALUE_INVALID; + + jit_lvn_mov(ir, state); +} + +static void lvn_mov_reg(jit_ir_t *ir, lvn_state_t *state, jit_reg_t reg) +{ + ir->op = J_MOV; + ir->size = JIT_SZ_UNSPEC; + ir->cc = JIT_CC_NONE; + ir->arg1.kind = JIT_VALUE_REG; + ir->arg1.reg = reg; + ir->arg2.kind = JIT_VALUE_INVALID; + + jit_lvn_mov(ir, state); +} + +static valnum_t lvn_value_num(jit_value_t value, lvn_state_t *state) +{ + switch (value.kind) { + case JIT_VALUE_REG: + if (state->regvn[value.reg] != VN_INVALID) + return state->regvn[value.reg]; + else + return (state->regvn[value.reg] = state->nextvn++); + + case JIT_VALUE_INT64: + if (value.int64 >= 0 && value.int64 < SMALL_CONST) + return value.int64; + else + return state->nextvn++; // TODO: add constant table + + case JIT_VALUE_INVALID: + return VN_INVALID; + + case JIT_VALUE_HANDLE: + return state->nextvn++; + + default: + fatal_trace("cannot handle value kind %d in lvn_value_num", value.kind); + } +} + +static inline bool lvn_is_commutative(jit_op_t op) +{ + return op == J_ADD || op == J_MUL; +} + +static void lvn_commute_const(jit_ir_t *ir) +{ + assert(lvn_is_commutative(ir->op)); + + if (ir->arg1.kind == JIT_VALUE_INT64) { + jit_value_t tmp = ir->arg1; + ir->arg1 = ir->arg2; + ir->arg2 = tmp; + } +} + +static void lvn_get_tuple(jit_ir_t *ir, lvn_state_t *state, int tuple[3]) +{ + tuple[0] = ir->op | ir->size << 8 | ir->cc << 11; + + const valnum_t vn1 = lvn_value_num(ir->arg1, state); + const valnum_t vn2 = lvn_value_num(ir->arg2, state); + + if (lvn_is_commutative(ir->op) && vn1 > vn2) { + tuple[1] = vn2; + tuple[2] = vn1; + } + else { + tuple[1] = vn1; + tuple[2] = vn2; + } +} + +static void jit_lvn_generic(jit_ir_t *ir, lvn_state_t *state) +{ + assert(ir->result != JIT_REG_INVALID); + + int tuple[3]; + lvn_get_tuple(ir, state, tuple); + + const unsigned hash = tuple[0]*29 + tuple[1]*1093 + tuple[2]*6037; + + for (int idx = hash & (state->tabsz - 1), limit = 0; limit < 10; + idx = (idx + 1) & (state->tabsz - 1), limit++) { + + lvn_tab_t *tab = &(state->hashtab[idx]); + if (tab->ir == NULL) { + tab->ir = ir; + tab->vn = state->regvn[ir->result] = state->nextvn++; + break; + } + else if (tab->vn != state->regvn[tab->ir->result]) { + // Stale entry + tab->ir = ir; + tab->vn = state->regvn[ir->result] = state->nextvn++; + break; + } + else { + int cmp[3]; + lvn_get_tuple(tab->ir, state, cmp); + + if (cmp[0] == tuple[0] && cmp[1] == tuple[1] && cmp[2] == tuple[2]) { + assert(tab->ir->result != JIT_REG_INVALID); + + ir->op = J_MOV; + ir->size = JIT_SZ_UNSPEC; + ir->arg1.kind = JIT_VALUE_REG; + ir->arg1.reg = tab->ir->result; + ir->arg2.kind = JIT_VALUE_INVALID; + + state->regvn[ir->result] = tab->vn; + break; + } + } + } } -static void jit_lvn_mul(jit_ir_t *ir) +static void jit_lvn_mul(jit_ir_t *ir, lvn_state_t *state) { if (lvn_can_fold(ir)) { const int64_t lhs = ir->arg1.int64; @@ -281,7 +423,7 @@ static void jit_lvn_mul(jit_ir_t *ir) #define FOLD_MUL(type) do { \ type result; \ if (!__builtin_mul_overflow(lhs, rhs, &result)) { \ - lvn_convert_mov(ir, result); \ + lvn_mov_const(ir, state, result); \ return; \ } \ } while (0) @@ -289,9 +431,24 @@ static void jit_lvn_mul(jit_ir_t *ir) FOR_ALL_SIZES(ir->size, FOLD_MUL); #undef FOLD_MUL } + + lvn_commute_const(ir); + + if (ir->arg2.kind == JIT_VALUE_INT64) { + if (ir->arg2.int64 == 0) { + lvn_mov_const(ir, state, 0); + return; + } + else if (ir->arg2.int64 == 1) { + lvn_mov_reg(ir, state, ir->arg1.reg); + return; + } + } + + jit_lvn_generic(ir, state); } -static void jit_lvn_add(jit_ir_t *ir) +static void jit_lvn_add(jit_ir_t *ir, lvn_state_t *state) { if (lvn_can_fold(ir)) { const int64_t lhs = ir->arg1.int64; @@ -300,7 +457,7 @@ static void jit_lvn_add(jit_ir_t *ir) #define FOLD_ADD(type) do { \ type result; \ if (!__builtin_add_overflow(lhs, rhs, &result)) { \ - lvn_convert_mov(ir, result); \ + lvn_mov_const(ir, state, result); \ return; \ } \ } while (0) @@ -308,9 +465,18 @@ static void jit_lvn_add(jit_ir_t *ir) FOR_ALL_SIZES(ir->size, FOLD_ADD); #undef FOLD_ADD } + + lvn_commute_const(ir); + + if (ir->arg2.kind == JIT_VALUE_INT64 && ir->arg2.int64 == 0) { + lvn_mov_reg(ir, state, ir->arg1.reg); + return; + } + + jit_lvn_generic(ir, state); } -static void jit_lvn_sub(jit_ir_t *ir) +static void jit_lvn_sub(jit_ir_t *ir, lvn_state_t *state) { if (lvn_can_fold(ir)) { const int64_t lhs = ir->arg1.int64; @@ -319,7 +485,7 @@ static void jit_lvn_sub(jit_ir_t *ir) #define FOLD_SUB(type) do { \ type result; \ if (!__builtin_sub_overflow(lhs, rhs, &result)) { \ - lvn_convert_mov(ir, result); \ + lvn_mov_const(ir, state, result); \ return; \ } \ } while (0) @@ -327,17 +493,61 @@ static void jit_lvn_sub(jit_ir_t *ir) FOR_ALL_SIZES(ir->size, FOLD_SUB); #undef FOLD_SUB } + + jit_lvn_generic(ir, state); +} + +static void jit_lvn_mov(jit_ir_t *ir, lvn_state_t *state) +{ + if (ir->arg1.kind == JIT_VALUE_REG && ir->arg1.reg == ir->result) { + ir->op = J_NOP; + ir->result = JIT_REG_INVALID; + ir->arg1.kind = JIT_VALUE_INVALID; + ir->arg2.kind = JIT_VALUE_INVALID; + return; + } + + jit_lvn_generic(ir, state); +} + +static void jit_lvn_copy(jit_ir_t *ir, lvn_state_t *state) +{ + // Clobbers the count register + state->regvn[ir->result] = state->nextvn++; } void jit_do_lvn(jit_func_t *f) { - for (int i = 0; i < f->nirs; i++) { - jit_ir_t *ir = &(f->irbuf[i]); - switch (ir->op) { - case J_MUL: jit_lvn_mul(ir); break; - case J_ADD: jit_lvn_add(ir); break; - case J_SUB: jit_lvn_sub(ir); break; - default: break; + lvn_state_t state = { + .tabsz = next_power_of_2(f->nirs), + .func = f, + .nextvn = SMALL_CONST + }; + + state.regvn = xmalloc_array(f->nregs, sizeof(valnum_t)); + state.hashtab = xcalloc_array(state.tabsz, sizeof(lvn_tab_t)); + + jit_cfg_t *cfg = jit_get_cfg(f); + + for (int i = 0; i < cfg->nblocks; i++) { + jit_block_t *bb = &(cfg->blocks[i]); + + for (int j = 0; j < f->nregs; j++) + state.regvn[j] = VN_INVALID; + + for (int j = bb->first; j <= bb->last; j++) { + jit_ir_t *ir = &(f->irbuf[j]); + switch (ir->op) { + case J_MUL: jit_lvn_mul(ir, &state); break; + case J_ADD: jit_lvn_add(ir, &state); break; + case J_SUB: jit_lvn_sub(ir, &state); break; + case J_MOV: jit_lvn_mov(ir, &state); break; + case MACRO_COPY: jit_lvn_copy(ir, &state); break; + default: break; + } } } + + free(state.regvn); + free(state.hashtab); } diff --git a/src/jit/jit-priv.h b/src/jit/jit-priv.h index 51dc47a7..750f71b1 100644 --- a/src/jit/jit-priv.h +++ b/src/jit/jit-priv.h @@ -61,6 +61,7 @@ typedef enum { J_FSUB, J_REM, J_DEBUG, + J_NOP, __MACRO_BASE = 0x80, MACRO_COPY = __MACRO_BASE, diff --git a/src/mask.c b/src/mask.c index b9a8160e..630bfce9 100644 --- a/src/mask.c +++ b/src/mask.c @@ -42,6 +42,9 @@ void mask_free(bit_mask_t *m) static inline uint64_t mask_for_range(int low, int high) { + if (high < low) + return 0; + assert(low >= 0 && low < 64); assert(high >= 0 && high < 64); assert(low <= high); diff --git a/test/test_jit.c b/test/test_jit.c index e0c31289..132a549a 100644 --- a/test/test_jit.c +++ b/test/test_jit.c @@ -1289,6 +1289,15 @@ START_TEST(test_lvn1) " MUL.8 R0, #100, #5 \n" " ADD R0, #5, #-7 \n" " SUB R0, #12, #30 \n" + " ADD R1, R2, R3 \n" + " ADD R4, R2, R3 \n" + " MUL R1, R2, #3 \n" + " MUL R4, #3, R2 \n" + " MOV R1, #666 \n" + " MUL R5, #3, R2 \n" + " MUL R1, #1, R2 \n" + " MUL R1, #0, R2 \n" + " ADD R1, R1, #0 \n" " RET \n"; jit_handle_t h1 = jit_assemble(j, ident_new("myfunc"), text1); @@ -1308,6 +1317,80 @@ START_TEST(test_lvn1) ck_assert_int_eq(f->irbuf[3].op, J_MOV); ck_assert_int_eq(f->irbuf[3].arg1.int64, -18); + ck_assert_int_eq(f->irbuf[5].op, J_MOV); + ck_assert_int_eq(f->irbuf[5].arg1.reg, f->irbuf[4].result); + + ck_assert_int_eq(f->irbuf[7].op, J_MOV); + ck_assert_int_eq(f->irbuf[7].arg1.reg, f->irbuf[6].result); + + ck_assert_int_eq(f->irbuf[9].op, J_MUL); + ck_assert_int_eq(f->irbuf[9].arg1.reg, 2); + ck_assert_int_eq(f->irbuf[9].arg2.int64, 3); + + ck_assert_int_eq(f->irbuf[10].op, J_MOV); + ck_assert_int_eq(f->irbuf[10].arg1.reg, 2); + + ck_assert_int_eq(f->irbuf[11].op, J_MOV); + ck_assert_int_eq(f->irbuf[11].arg1.int64, 0); + + ck_assert_int_eq(f->irbuf[12].op, J_NOP); + + jit_free(j); +} +END_TEST + +START_TEST(test_lvn2) +{ + jit_t *j = jit_new(); + + const char *text1 = + " MOV R0, #2 \n" + " ADD R1, R0, #1 \n" + "L1: ADD R1, R0, #1 \n" + " ADD R0, R0, #1 \n" + " JUMP L1 \n"; + + jit_handle_t h1 = jit_assemble(j, ident_new("myfunc"), text1); + + jit_func_t *f = jit_get_func(j, h1); + jit_do_lvn(f); + + ck_assert_int_eq(f->irbuf[2].op, J_ADD); + ck_assert_int_eq(f->irbuf[2].arg1.reg, 0); + ck_assert_int_eq(f->irbuf[2].arg2.int64, 1); + + ck_assert_int_eq(f->irbuf[3].op, J_MOV); + ck_assert_int_eq(f->irbuf[3].arg1.reg, 1); + + jit_free(j); +} +END_TEST + +START_TEST(test_lvn3) +{ + jit_t *j = jit_new(); + + const char *text1 = + " MOV R0, #2 \n" + " $COPY R0, [R1], [R2] \n" + " MOV R0, #2 \n" + " $COPY R0, [R3], [R4] \n" + " SUB R46, #1, R41 \n" + " ADD R58, R41, #1 \n"; + + jit_handle_t h1 = jit_assemble(j, ident_new("myfunc"), text1); + + jit_func_t *f = jit_get_func(j, h1); + jit_do_lvn(f); + + ck_assert_int_eq(f->irbuf[2].op, J_MOV); + ck_assert_int_eq(f->irbuf[2].arg1.kind, JIT_VALUE_INT64); + ck_assert_int_eq(f->irbuf[2].arg1.int64, 2); + + ck_assert_int_eq(f->irbuf[5].op, J_ADD); + ck_assert_int_eq(f->irbuf[5].arg2.kind, JIT_VALUE_INT64); + ck_assert_int_eq(f->irbuf[5].arg2.int64, 1); + jit_free(j); } END_TEST @@ -1349,6 +1432,8 @@ Suite *get_jit_tests(void) tcase_add_test(tc, test_assemble2); tcase_add_test(tc, test_cfg1); tcase_add_test(tc, test_lvn1); + tcase_add_test(tc, test_lvn2); + tcase_add_test(tc, test_lvn3); suite_add_tcase(s, tc); return s; diff --git a/test/test_misc.c b/test/test_misc.c index d8533838..d8858831 100644 --- a/test/test_misc.c +++ b/test/test_misc.c @@ -441,6 +441,17 @@ START_TEST(test_subtract) } END_TEST +START_TEST(test_empty_mask) +{ + bit_mask_t m; + mask_init(&m, 0); + + mask_clearall(&m); + + mask_free(&m); +} +END_TEST + static volatile int counter = 0; static nvc_lock_t lock = 0; @@ -504,6 +515,7 @@ Suite *get_misc_tests(void) tcase_add_loop_test(tc_mask, test_count_clear, 0, ARRAY_LEN(mask_size)); tcase_add_loop_test(tc_mask, test_scan_backwards, 0, ARRAY_LEN(mask_size)); tcase_add_loop_test(tc_mask, test_subtract, 0, ARRAY_LEN(mask_size)); + tcase_add_test(tc_mask, test_empty_mask); suite_add_tcase(s, tc_mask); TCase *tc_thread = tcase_create("thread"); -- 2.39.2