From fe256be9e930b5b850b3a90bd299a607120dade0 Mon Sep 17 00:00:00 2001 From: Nick Gasson Date: Tue, 1 Nov 2022 17:55:37 +0000 Subject: [PATCH] Add some simple peephole optimisations for JIT IR --- src/jit/jit-core.c | 26 ++++++++---- src/jit/jit-irgen.c | 3 ++ src/jit/jit-optim.c | 99 +++++++++++++++++++++++++++++++++++++++++++++ src/jit/jit-priv.h | 2 + test/test_jit.c | 33 +++++++++++++++ 5 files changed, 156 insertions(+), 7 deletions(-) diff --git a/src/jit/jit-core.c b/src/jit/jit-core.c index 900845f5..780ac590 100644 --- a/src/jit/jit-core.c +++ b/src/jit/jit-core.c @@ -1053,15 +1053,27 @@ jit_handle_t jit_assemble(jit_t *j, ident_t name, const char *text) case CCSIZE: { - int cpos = 0; - for (; cpos < ARRAY_LEN(optab) - && strcmp(cctab[cpos].name, tok); cpos++) - ; + if (isdigit((int)tok[0])) { + switch (atoi(tok)) { + case 8: ir->size = JIT_SZ_8; break; + case 16: ir->size = JIT_SZ_16; break; + case 32: ir->size = JIT_SZ_32; break; + case 64: ir->size = JIT_SZ_64; break; + default: + fatal_trace("illegal operation size %s", tok); + } + } + else { + int cpos = 0; + for (; cpos < ARRAY_LEN(cctab) + && strcmp(cctab[cpos].name, tok); cpos++) + ; - if (cpos == ARRAY_LEN(cctab)) - fatal_trace("illegal condition code %s", tok); + if (cpos == ARRAY_LEN(cctab)) + fatal_trace("illegal condition code %s", tok); - ir->cc = cctab[cpos].cc; + ir->cc = cctab[cpos].cc; + } state = nresult > 0 ? RESULT : (nargs > 0 ? ARG1 : NEWLINE); } diff --git a/src/jit/jit-irgen.c b/src/jit/jit-irgen.c index eb43cae7..12c6b1db 100644 --- a/src/jit/jit-irgen.c +++ b/src/jit/jit-irgen.c @@ -3716,6 +3716,9 @@ void jit_irgen(jit_func_t *f) } g->labels = NULL; + jit_do_lvn(f); + jit_free_cfg(f); + if (opt_get_verbose(OPT_JIT_VERBOSE, istr(f->name))) { #ifdef DEBUG jit_dump_interleaved(f); diff --git a/src/jit/jit-optim.c b/src/jit/jit-optim.c index 413e80ee..5e2b28fb 100644 --- a/src/jit/jit-optim.c +++ b/src/jit/jit-optim.c @@ -242,3 +242,102 @@ int jit_get_edge(jit_edge_list_t *list, int nth) else return list->u.external[nth]; } + +//////////////////////////////////////////////////////////////////////////////// +// Local value numbering and simple peepholes + +#define FOR_ALL_SIZES(size, macro) do { \ + switch (size) { \ + case JIT_SZ_8: macro(int8_t); break; \ + case JIT_SZ_16: macro(int16_t); break; \ + case JIT_SZ_32: macro(int32_t); break; \ + case JIT_SZ_UNSPEC: \ + case JIT_SZ_64: macro(int64_t); break; \ + } \ + } while (0) + +static inline bool lvn_can_fold(jit_ir_t *ir) +{ + return ir->arg1.kind == JIT_VALUE_INT64 + && (ir->arg2.kind == JIT_VALUE_INVALID + || ir->arg2.kind == JIT_VALUE_INT64); +} + +static void lvn_convert_mov(jit_ir_t *ir, int64_t cval) +{ + ir->op = J_MOV; + ir->size = JIT_SZ_UNSPEC; + ir->arg1.kind = JIT_VALUE_INT64; + ir->arg1.int64 = cval; + ir->arg2.kind = JIT_VALUE_INVALID; +} + +static void jit_lvn_mul(jit_ir_t *ir) +{ + if (lvn_can_fold(ir)) { + const int64_t lhs = ir->arg1.int64; + const int64_t rhs = ir->arg2.int64; + +#define FOLD_MUL(type) do { \ + type result; \ + if (!__builtin_mul_overflow(lhs, rhs, &result)) { \ + lvn_convert_mov(ir, result); \ + return; \ + } \ + } while (0) + + FOR_ALL_SIZES(ir->size, FOLD_MUL); +#undef FOLD_MUL + } +} + +static void jit_lvn_add(jit_ir_t *ir) +{ + if (lvn_can_fold(ir)) { + const int64_t lhs = ir->arg1.int64; + const int64_t rhs = ir->arg2.int64; + +#define FOLD_ADD(type) do { \ + type result; \ + if (!__builtin_add_overflow(lhs, rhs, &result)) { \ + lvn_convert_mov(ir, result); \ + return; \ + } \ + } while (0) + + FOR_ALL_SIZES(ir->size, FOLD_ADD); +#undef FOLD_ADD + } +} + +static void jit_lvn_sub(jit_ir_t *ir) +{ + if (lvn_can_fold(ir)) { + const int64_t lhs = ir->arg1.int64; + const int64_t rhs = ir->arg2.int64; + +#define FOLD_SUB(type) do { \ + type result; \ + if (!__builtin_sub_overflow(lhs, rhs, &result)) { \ + lvn_convert_mov(ir, result); \ + return; \ + } \ + } while (0) + + FOR_ALL_SIZES(ir->size, FOLD_SUB); +#undef FOLD_SUB + } +} + +void jit_do_lvn(jit_func_t *f) +{ + for (int i = 0; i < f->nirs; i++) { + jit_ir_t *ir = &(f->irbuf[i]); + switch (ir->op) { + case J_MUL: jit_lvn_mul(ir); break; + case J_ADD: jit_lvn_add(ir); break; + case J_SUB: jit_lvn_sub(ir); break; + default: break; + } + } +} diff --git a/src/jit/jit-priv.h b/src/jit/jit-priv.h index abc75b2e..51dc47a7 100644 --- a/src/jit/jit-priv.h +++ b/src/jit/jit-priv.h @@ -297,6 +297,8 @@ void jit_free_cfg(jit_func_t *f); jit_block_t *jit_block_for(jit_cfg_t *cfg, int pos); int jit_get_edge(jit_edge_list_t *list, int nth); +void jit_do_lvn(jit_func_t *f); + void __nvc_do_exit(jit_exit_t which, jit_anchor_t *anchor, jit_scalar_t *args); void __nvc_do_fficall(jit_foreign_t *ff, jit_anchor_t *anchor, jit_scalar_t *args); diff --git a/test/test_jit.c b/test/test_jit.c index 8f78c45e..e0c31289 100644 --- a/test/test_jit.c +++ b/test/test_jit.c @@ -1280,6 +1280,38 @@ START_TEST(test_cfg1) } END_TEST +START_TEST(test_lvn1) +{ + jit_t *j = jit_new(); + + const char *text1 = + " MUL R0, #2, #3 \n" + " MUL.8 R0, #100, #5 \n" + " ADD R0, #5, #-7 \n" + " SUB R0, #12, #30 \n" + " RET \n"; + + jit_handle_t h1 = jit_assemble(j, ident_new("myfunc"), text1); + + jit_func_t *f = jit_get_func(j, h1); + jit_do_lvn(f); + + ck_assert_int_eq(f->irbuf[0].op, J_MOV); + ck_assert_int_eq(f->irbuf[0].arg1.int64, 6); + + ck_assert_int_eq(f->irbuf[1].op, J_MUL); + ck_assert_int_eq(f->irbuf[1].size, JIT_SZ_8); + + ck_assert_int_eq(f->irbuf[2].op, J_MOV); + ck_assert_int_eq(f->irbuf[2].arg1.int64, -2); + + ck_assert_int_eq(f->irbuf[3].op, J_MOV); + ck_assert_int_eq(f->irbuf[3].arg1.int64, -18); + + jit_free(j); +} +END_TEST + Suite *get_jit_tests(void) { Suite *s = suite_create("jit"); @@ -1316,6 +1348,7 @@ Suite *get_jit_tests(void) tcase_add_test(tc, test_assemble1); tcase_add_test(tc, test_assemble2); tcase_add_test(tc, test_cfg1); + tcase_add_test(tc, test_lvn1); suite_add_tcase(s, tc); return s; -- 2.39.2