From 82fb1f05f48d10995d89d730b001b14a2c2fe017 Mon Sep 17 00:00:00 2001 From: Nick Gasson Date: Sat, 25 Mar 2023 16:12:02 +0000 Subject: [PATCH] Add a dedicated JIT memset opcode --- src/jit/jit-core.c | 1 + src/jit/jit-dump.c | 2 +- src/jit/jit-interp.c | 22 +++++++- src/jit/jit-irgen.c | 45 ++++++---------- src/jit/jit-llvm.c | 124 ++++++++++++++++++++++++++++++++++++++++++- src/jit/jit-optim.c | 54 ++++++++++++++++++- src/jit/jit-priv.h | 1 + test/test_jit.c | 65 +++++++++++++++++++++++ 8 files changed, 282 insertions(+), 32 deletions(-) diff --git a/src/jit/jit-core.c b/src/jit/jit-core.c index 2b3cf66e..7985fc9b 100644 --- a/src/jit/jit-core.c +++ b/src/jit/jit-core.c @@ -1172,6 +1172,7 @@ jit_handle_t jit_assemble(jit_t *j, ident_t name, const char *text) { "$SALLOC", MACRO_SALLOC, 1, 2 }, { "$LALLOC", MACRO_LALLOC, 1, 1 }, { "$BZERO", MACRO_BZERO, 1, 1 }, + { "$MEMSET", MACRO_MEMSET, 1, 2 }, { "$EXP", MACRO_EXP, 1, 2 }, { "$FEXP", MACRO_FEXP, 1, 2 }, }; diff --git a/src/jit/jit-dump.c b/src/jit/jit-dump.c index cf144ed0..d4a780ad 100644 --- a/src/jit/jit-dump.c +++ b/src/jit/jit-dump.c @@ -42,7 +42,7 @@ const char *jit_op_name(jit_op_t op) static const char *names[] = { "$COPY", "$GALLOC", "$EXIT", "$FEXP", "$EXP", "$BZERO", "$FFICALL", "$GETPRIV", "$PUTPRIV", "$LALLOC", "$SALLOC", - "$CASE", "$TRIM", "$MOVE", + "$CASE", "$TRIM", "$MOVE", "$MEMSET", }; assert(op - __MACRO_BASE < ARRAY_LEN(names)); return names[op - __MACRO_BASE]; diff --git a/src/jit/jit-interp.c b/src/jit/jit-interp.c index 07218c0a..145d0c91 100644 --- a/src/jit/jit-interp.c +++ b/src/jit/jit-interp.c @@ -52,7 +52,7 @@ typedef struct _jit_interp { #define JIT_ASSERT(expr) do { \ if (unlikely(!(expr))) { \ interp_dump(state); \ - fatal_trace("assertion '" #expr "' failed"); \ + fatal_trace("assertion '%s' failed", #expr); \ } \ } while (0) #define CANNOT_HANDLE(value) do { \ @@ -795,6 +795,23 @@ static void interp_bzero(jit_interp_t *state, jit_ir_t *ir) memset(dest, '\0', count); } +static void interp_memset(jit_interp_t *state, jit_ir_t *ir) +{ + const size_t bytes = state->regs[ir->result].integer; + void *dest = interp_get_pointer(state, ir->arg1); + const uint64_t value = interp_get_int(state, ir->arg2); + +#define MEMSET_LOOP(type) do { \ + JIT_ASSERT((type)value == value); \ + JIT_ASSERT(bytes % sizeof(type) == 0); \ + type *eptr = dest + bytes; \ + for (type *p = dest; p < eptr; p++) \ + *p = value; \ + } while (0) + + FOR_EACH_SIZE(ir->size, MEMSET_LOOP); +} + static void interp_galloc(jit_interp_t *state, jit_ir_t *ir) { jit_thread_local_t *thread = jit_thread_local(); @@ -999,6 +1016,9 @@ static void interp_loop(jit_interp_t *state) case MACRO_BZERO: interp_bzero(state, ir); break; + case MACRO_MEMSET: + interp_memset(state, ir); + break; case MACRO_GALLOC: interp_galloc(state, ir); break; diff --git a/src/jit/jit-irgen.c b/src/jit/jit-irgen.c index e15923fa..ed27433e 100644 --- a/src/jit/jit-irgen.c +++ b/src/jit/jit-irgen.c @@ -599,6 +599,13 @@ static void macro_bzero(jit_irgen_t *g, jit_value_t dest, jit_reg_t count) count, dest); } +static void macro_memset(jit_irgen_t *g, jit_size_t sz, jit_value_t dest, + jit_value_t value, jit_reg_t count) +{ + assert(jit_value_is_addr(dest)); + irgen_emit_binary(g, MACRO_MEMSET, sz, JIT_CC_NONE, count, dest, value); +} + static jit_value_t macro_galloc(jit_irgen_t *g, jit_value_t bytes) { jit_reg_t r = irgen_alloc_reg(g); @@ -1163,39 +1170,21 @@ static void irgen_op_memset(jit_irgen_t *g, int op) jit_value_t length = irgen_get_arg(g, op, 2); vcode_type_t vtype = vcode_reg_type(vcode_get_arg(op, 1)); + + jit_value_t addr = jit_addr_from_value(base, 0); jit_value_t scale = jit_value_from_int64(irgen_size_bytes(vtype)); jit_value_t bytes = j_mul(g, length, scale); - if (value.kind == JIT_VALUE_INT64 && value.int64 == 0) { - jit_value_t addr = jit_addr_from_value(base, 0); - macro_bzero(g, addr, jit_value_as_reg(bytes)); - } - else { - irgen_label_t *l_exit = irgen_alloc_label(g); - irgen_label_t *l_head = irgen_alloc_label(g); - - jit_value_t base_ptr = irgen_lea(g, base); + jit_reg_t bytes_r = jit_value_as_reg(bytes); - j_cmp(g, JIT_CC_LE, length, jit_value_from_int64(0)); - j_jump(g, JIT_CC_T, l_exit); - - jit_value_t end = j_add(g, base_ptr, bytes); - - jit_reg_t ptr_r = irgen_alloc_reg(g); - j_mov(g, ptr_r, base_ptr); - - irgen_bind_label(g, l_head); - - jit_value_t ptr = jit_value_from_reg(ptr_r); - j_store(g, irgen_jit_size(vtype), value, jit_addr_from_value(ptr, 0)); - - jit_value_t next = j_add(g, ptr, scale); - j_mov(g, ptr_r, next); - j_cmp(g, JIT_CC_LT, next, end); - j_jump(g, JIT_CC_T, l_head); - - irgen_bind_label(g, l_exit); + if (value.kind == JIT_VALUE_INT64 && value.int64 == 0) + macro_bzero(g, addr, bytes_r); + else if (value.kind == JIT_VALUE_DOUBLE) { + jit_value_t bits = jit_value_from_int64(value.int64); + macro_memset(g, irgen_jit_size(vtype), addr, bits, bytes_r); } + else + macro_memset(g, irgen_jit_size(vtype), addr, value, bytes_r); } static void irgen_send_args(jit_irgen_t *g, int op, int first) diff --git a/src/jit/jit-llvm.c b/src/jit/jit-llvm.c index b2577d27..f6a23019 100644 --- a/src/jit/jit-llvm.c +++ b/src/jit/jit-llvm.c @@ -112,6 +112,11 @@ typedef enum { LLVM_EXP_OVERFLOW_U32, LLVM_EXP_OVERFLOW_U64, + LLVM_MEMSET_U8, + LLVM_MEMSET_U16, + LLVM_MEMSET_U32, + LLVM_MEMSET_U64, + LLVM_POW_F64, LLVM_COPYSIGN_F64, LLVM_MEMSET_INLINE, @@ -686,7 +691,6 @@ static LLVMValueRef llvm_get_fn(llvm_obj_t *obj, llvm_fn_t which) jit_size_t sz = which - LLVM_EXP_OVERFLOW_U8; LLVMTypeRef int_type = obj->types[LLVM_INT8 + sz]; LLVMTypeRef pair_type = obj->types[LLVM_PAIR_I8_I1 + sz]; - printf("sz=%d pair_type=%p \n", sz, pair_type); LLVMTypeRef args[] = { int_type, int_type }; obj->fntypes[which] = LLVMFunctionType(pair_type, args, ARRAY_LEN(args), false); @@ -701,6 +705,30 @@ static LLVMValueRef llvm_get_fn(llvm_obj_t *obj, llvm_fn_t which) } break; + case LLVM_MEMSET_U8: + case LLVM_MEMSET_U16: + case LLVM_MEMSET_U32: + case LLVM_MEMSET_U64: + { + jit_size_t sz = which - LLVM_MEMSET_U8; + LLVMTypeRef args[] = { + obj->types[LLVM_PTR], + obj->types[LLVM_INT8 + sz], + obj->types[LLVM_INT64] + }; + obj->fntypes[which] = LLVMFunctionType(obj->types[LLVM_VOID], args, + ARRAY_LEN(args), false); + + static const char *names[] = { + NULL, // Use regular memset instead + "nvc.memset.i16", + "nvc.memset.i32", + "nvc.memset.i64" + }; + fn = llvm_add_fn(obj, names[sz], obj->fntypes[which]); + } + break; + case LLVM_POW_F64: { LLVMTypeRef args[] = { @@ -1957,6 +1985,34 @@ static void cgen_macro_bzero(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) LLVMBuildMemSet(obj->builder, PTR(dest), llvm_int8(obj, 0), count, 0); } +static void cgen_macro_memset(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef count = cgb->outregs[ir->result]; + LLVMValueRef dest = cgen_coerce_value(obj, cgb, ir->arg1, LLVM_PTR); + LLVMValueRef value = cgen_coerce_value(obj, cgb, ir->arg2, + LLVM_INT8 + ir->size); + + if (ir->size == JIT_SZ_8) { +#ifdef LLVM_HAS_OPAQUE_POINTERSS + if (LLVMIsConstant(count) && LLVMConstIntGetZExtValue(count) <= 64) { + LLVMValueRef args[] = { + dest, + value, + count, + llvm_int1(obj, 0), + }; + llvm_call_fn(obj, LLVM_MEMSET_INLINE, args, ARRAY_LEN(args)); + } + else +#endif + LLVMBuildMemSet(obj->builder, PTR(dest), value, count, 0); + } + else { + LLVMValueRef args[] = { dest, value, count, }; + llvm_call_fn(obj, LLVM_MEMSET_U8 + ir->size, args, ARRAY_LEN(args)); + } +} + static void cgen_macro_exit(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) { cgen_sync_irpos(obj, cgb, ir); @@ -2283,6 +2339,9 @@ static void cgen_ir(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) case MACRO_BZERO: cgen_macro_bzero(obj, cgb, ir); break; + case MACRO_MEMSET: + cgen_macro_memset(obj, cgb, ir); + break; case MACRO_EXIT: cgen_macro_exit(obj, cgb, ir); break; @@ -3055,6 +3114,66 @@ static void cgen_exp_overflow_body(llvm_obj_t *obj, llvm_fn_t which, LLVMBuildRet(obj->builder, pair); } +static void cgen_memset_body(llvm_obj_t *obj, llvm_fn_t which, + jit_size_t sz, llvm_fn_t mulbase) +{ + LLVMValueRef fn = obj->fns[which]; + LLVMSetLinkage(fn, LLVMPrivateLinkage); + + LLVMBasicBlockRef entry = llvm_append_block(obj, fn, "entry"); + + LLVMPositionBuilderAtEnd(obj->builder, entry); + + LLVMValueRef dest = LLVMGetParam(fn, 0); + LLVMSetValueName(dest, "dest"); + + LLVMValueRef value = LLVMGetParam(fn, 1); + LLVMSetValueName(value, "value"); + + LLVMValueRef bytes = LLVMGetParam(fn, 2); + LLVMSetValueName(bytes, "bytes"); + + LLVMValueRef indexes1[] = { bytes }; + LLVMValueRef eptr = LLVMBuildGEP2(obj->builder, obj->types[LLVM_INT8], + dest, indexes1, 1, "eptr"); + + LLVMBasicBlockRef body = llvm_append_block(obj, fn, "body"); + LLVMBasicBlockRef exit = llvm_append_block(obj, fn, "exit"); + + LLVMTypeRef type = obj->types[LLVM_INT8 + sz]; + + LLVMValueRef zero = LLVMBuildIsNull(obj->builder, bytes, ""); + LLVMBuildCondBr(obj->builder, zero, exit, body); + + LLVMPositionBuilderAtEnd(obj->builder, body); + + LLVMValueRef ptr = LLVMBuildPhi(obj->builder, obj->types[LLVM_PTR], "ptr"); + +#ifndef LLVM_HAS_OPAQUE_POINTERS + LLVMTypeRef ptr_type = LLVMPointerType(type, 0); + LLVMValueRef cast = LLVMBuildPointerCast(obj->builder, ptr, ptr_type, ""); +#else + LLVMValueRef cast = ptr; +#endif + + LLVMBuildStore(obj->builder, value, cast); + + LLVMValueRef indexes2[] = { llvm_intptr(obj, 1) }; + LLVMValueRef next = LLVMBuildGEP2(obj->builder, type, cast, indexes2, 1, ""); + + LLVMValueRef ptr_in[] = { dest, PTR(next) }; + LLVMBasicBlockRef ptr_bb[] = { entry, body }; + LLVMAddIncoming(ptr, ptr_in, ptr_bb, 2); + + LLVMValueRef cont = + LLVMBuildICmp(obj->builder, LLVMIntULT, PTR(next), eptr, ""); + LLVMBuildCondBr(obj->builder, cont, body, exit); + + LLVMPositionBuilderAtEnd(obj->builder, exit); + + LLVMBuildRetVoid(obj->builder); +} + //////////////////////////////////////////////////////////////////////////////// // JIT plugin interface @@ -3262,6 +3381,9 @@ void llvm_obj_finalise(llvm_obj_t *obj, llvm_opt_level_t olevel) if (obj->fns[LLVM_EXP_OVERFLOW_U8 + sz] != NULL) cgen_exp_overflow_body(obj, LLVM_EXP_OVERFLOW_U8 + sz, sz, LLVM_MUL_OVERFLOW_U8); + if (obj->fns[LLVM_MEMSET_U8 + sz] != NULL) + cgen_memset_body(obj, LLVM_MEMSET_U8 + sz, sz, + LLVM_MUL_OVERFLOW_U8); } DWARF_ONLY(LLVMDIBuilderFinalize(obj->debuginfo)); diff --git a/src/jit/jit-optim.c b/src/jit/jit-optim.c index 9932e5d5..d03dc97a 100644 --- a/src/jit/jit-optim.c +++ b/src/jit/jit-optim.c @@ -77,7 +77,7 @@ static jit_reg_t cfg_get_reg(jit_value_t value) static inline bool cfg_reads_result(jit_ir_t *ir) { return ir->op == MACRO_COPY || ir->op == MACRO_CASE || ir->op == MACRO_BZERO - || ir->op == MACRO_MOVE; + || ir->op == MACRO_MOVE || ir->op == MACRO_MEMSET; } static inline bool cfg_writes_result(jit_ir_t *ir) @@ -357,6 +357,7 @@ static void lvn_convert_nop(jit_ir_t *ir) ir->op = J_NOP; ir->size = JIT_SZ_UNSPEC; ir->cc = JIT_CC_NONE; + ir->result = JIT_REG_INVALID; ir->arg1.kind = JIT_VALUE_INVALID; ir->arg2.kind = JIT_VALUE_INVALID; } @@ -785,6 +786,56 @@ static void jit_lvn_bzero(jit_ir_t *ir, lvn_state_t *state) state->regvn[ir->result] = lvn_new_value(state); } +static void jit_lvn_memset(jit_ir_t *ir, lvn_state_t *state) +{ + int64_t value; + if (lvn_is_const(ir->arg2, state, &value)) { + static const uint64_t compress[JIT_SZ_UNSPEC] = { + 0x01, 0x0101, 0x01010101, 0x0101010101010101 + }; + static const uint64_t expand[JIT_SZ_UNSPEC] = { + 0x0101010101010101, 0x0001000100010001, 0x100000001, 1, + }; + static const uint64_t mask[JIT_SZ_UNSPEC] = { + 0xff, 0xffff, 0xffffffff, 0xffffffffffffffff + }; + + uint64_t bits = value; + if (value == 0) { + ir->op = MACRO_BZERO; + ir->size = JIT_SZ_UNSPEC; + ir->arg2.kind = JIT_VALUE_INVALID; + jit_lvn_bzero(ir, state); + return; + } + else if (bits == (bits & 0xff) * compress[ir->size]) { + ir->size = JIT_SZ_8; + ir->arg2 = LVN_CONST((bits &= 0xff)); + } + + int64_t bytes; + if (lvn_get_const(state->regvn[ir->result], state, &bytes)) { + if (bytes == 0) { + lvn_convert_nop(ir); + return; + } + else if (bytes > 0 && bytes <= 8 && is_power_of_2(bytes)) { + // Convert to a store + const jit_size_t newsz = ilog2(bytes); + ir->op = J_STORE; + ir->arg2 = ir->arg1; + ir->arg1 = LVN_CONST((bits * expand[ir->size]) & mask[newsz]); + ir->result = JIT_REG_INVALID; + ir->size = newsz; + return; + } + } + } + + // Clobbers the count register + state->regvn[ir->result] = lvn_new_value(state); +} + static void jit_lvn_exp(jit_ir_t *ir, lvn_state_t *state) { int64_t base, exp; @@ -841,6 +892,7 @@ void jit_do_lvn(jit_func_t *f) case MACRO_MOVE: case MACRO_COPY: jit_lvn_copy(ir, &state); break; case MACRO_BZERO: jit_lvn_bzero(ir, &state); break; + case MACRO_MEMSET: jit_lvn_memset(ir, &state); break; case MACRO_EXP: jit_lvn_exp(ir, &state); break; default: break; } diff --git a/src/jit/jit-priv.h b/src/jit/jit-priv.h index b7ecec26..54039459 100644 --- a/src/jit/jit-priv.h +++ b/src/jit/jit-priv.h @@ -82,6 +82,7 @@ typedef enum { MACRO_CASE, MACRO_TRIM, MACRO_MOVE, + MACRO_MEMSET, } jit_op_t; typedef enum { diff --git a/test/test_jit.c b/test/test_jit.c index 71f7fe17..5e28796e 100644 --- a/test/test_jit.c +++ b/test/test_jit.c @@ -1873,6 +1873,70 @@ START_TEST(test_lvn8) } END_TEST +START_TEST(test_lvn9) +{ + jit_t *j = jit_new(); + + const char *text1 = + " MOV R1, #32 \n" + " $MEMSET.16 R1, [R2], #0xffff \n" + " $MEMSET.32 R1, [R2], #0xabababab \n" + " $MEMSET.64 R1, [R2], #0x0101010101010101 \n" + " $MEMSET.32 R3, [R4], #0 \n" + " MOV R3, #0 \n" + " $MEMSET.8 R3, [R2], #0xf0 \n" + " MOV R4, #4 \n" + " $MEMSET.8 R4, [R0+20], #1 \n" + " $MEMSET.16 R4, [R0+20], #0x2233 \n" + " MOV R4, #8 \n" + " $MEMSET.16 R4, [R0+20], #0x2233 \n" + " $MEMSET.32 R4, [R0+20], #0x11223344 \n" + " $MEMSET.64 R4, [R0+20], #0xbeef \n"; + + jit_handle_t h1 = jit_assemble(j, ident_new("myfunc1"), text1); + + jit_func_t *f = jit_get_func(j, h1); + jit_do_lvn(f); + + ck_assert_int_eq(f->irbuf[1].size, JIT_SZ_8); + ck_assert_int_eq(f->irbuf[1].arg2.int64, 0xff); + + ck_assert_int_eq(f->irbuf[2].size, JIT_SZ_8); + ck_assert_int_eq(f->irbuf[2].arg2.int64, 0xab); + + ck_assert_int_eq(f->irbuf[3].size, JIT_SZ_8); + ck_assert_int_eq(f->irbuf[3].arg2.int64, 1); + + ck_assert_int_eq(f->irbuf[4].op, MACRO_BZERO); + ck_assert_int_eq(f->irbuf[4].size, JIT_SZ_UNSPEC); + ck_assert_int_eq(f->irbuf[4].arg2.kind, JIT_VALUE_INVALID); + + ck_assert_int_eq(f->irbuf[6].op, J_NOP); + + ck_assert_int_eq(f->irbuf[8].op, J_STORE); + ck_assert_int_eq(f->irbuf[8].size, JIT_SZ_32); + ck_assert_int_eq(f->irbuf[8].arg1.int64, 0x1010101); + + ck_assert_int_eq(f->irbuf[9].op, J_STORE); + ck_assert_int_eq(f->irbuf[9].size, JIT_SZ_32); + ck_assert_int_eq(f->irbuf[9].arg1.int64, 0x22332233); + + ck_assert_int_eq(f->irbuf[11].op, J_STORE); + ck_assert_int_eq(f->irbuf[11].size, JIT_SZ_64); + ck_assert_int_eq(f->irbuf[11].arg1.int64, 0x2233223322332233); + + ck_assert_int_eq(f->irbuf[12].op, J_STORE); + ck_assert_int_eq(f->irbuf[12].size, JIT_SZ_64); + ck_assert_int_eq(f->irbuf[12].arg1.int64, 0x1122334411223344); + + ck_assert_int_eq(f->irbuf[13].op, J_STORE); + ck_assert_int_eq(f->irbuf[13].size, JIT_SZ_64); + ck_assert_int_eq(f->irbuf[13].arg1.int64, 0xbeef); + + jit_free(j); +} +END_TEST + Suite *get_jit_tests(void) { Suite *s = suite_create("jit"); @@ -1924,6 +1988,7 @@ Suite *get_jit_tests(void) tcase_add_test(tc, test_tlab1); tcase_add_test(tc, test_lvn7); tcase_add_test(tc, test_lvn8); + tcase_add_test(tc, test_lvn9); suite_add_tcase(s, tc); return s; -- 2.39.2