From 9ec1942741d3d3d9e25235febdba5cc74ed57f0b Mon Sep 17 00:00:00 2001
From: Nick Gasson
Date: Thu, 27 Oct 2022 12:05:40 +0100
Subject: [PATCH] Improve type coercion in LLVM JIT

---
Review notes (free text between the "---" marker and the diff is not part
of the commit message and does not affect the diff):

 * llvm_int8() builds its constant from LLVM_INT8; it previously used
   LLVM_INT1 although cgen_macro_bzero() passes the result as the fill
   value of LLVMBuildMemSet().
 * cgen_op_cmp(): the second width test was "bits2 > bits1", which is the
   same condition as the first test, so arg2 was never sign-extended when
   arg1 was the wider operand.  It now tests "bits1 > bits2".
 * cgen_macro_exp(): the exponent is read from ir->arg2 (ir->arg1 was read
   twice), and cgen_get_fn() is called before req->fntypes[] is read.
   cgen_get_fn() is what stores fntypes[which] and C does not specify the
   order in which the arguments of a single call are evaluated.
 * NOTE(review): in cgen_coerce_value() the integer cases apply
   LLVMBuildIntToPtr() when the incoming value has pointer kind; a
   PtrToInt to the requested type may have been intended -- confirm
   against the declaration of lltype, which is outside this hunk.
 * Line breaks and indentation were restored from a whitespace-collapsed
   copy of this patch, so context and removed lines may differ from the
   tree in whitespace only (try "git apply --ignore-whitespace" if a hunk
   is rejected).  The post-image hash on the "index" line predates the
   fixes above.

 src/jit/jit-llvm.c | 258 ++++++++++++++++++++++++++++-----------------
 1 file changed, 162 insertions(+), 96 deletions(-)

diff --git a/src/jit/jit-llvm.c b/src/jit/jit-llvm.c
index 76613b01..ea031fce 100644
--- a/src/jit/jit-llvm.c
+++ b/src/jit/jit-llvm.c
@@ -46,6 +46,7 @@ typedef enum {
    LLVM_INT32,
    LLVM_INT64,
    LLVM_INTPTR,
+   LLVM_DOUBLE,
 
    LLVM_PAIR_I8_I1,
    LLVM_PAIR_I16_I1,
@@ -89,6 +90,8 @@ typedef enum {
    LLVM_MUL_OVERFLOW_U32,
    LLVM_MUL_OVERFLOW_U64,
 
+   LLVM_POW_F64,
+
    LLVM_DO_EXIT,
    LLVM_GETPRIV,
    LLVM_PUTPRIV,
@@ -147,6 +150,11 @@ static LLVMValueRef llvm_int1(cgen_req_t *req, bool b)
    return LLVMConstInt(req->types[LLVM_INT1], b, false);
 }
 
+static LLVMValueRef llvm_int8(cgen_req_t *req, int8_t i)
+{
+   return LLVMConstInt(req->types[LLVM_INT8], i, false);
+}
+
 static LLVMValueRef llvm_int32(cgen_req_t *req, int32_t i)
 {
    return LLVMConstInt(req->types[LLVM_INT32], i, false);
@@ -336,6 +344,19 @@ static LLVMValueRef cgen_get_fn(cgen_req_t *req, llvm_fn_t which)
       }
       break;
 
+   case LLVM_POW_F64:
+      {
+         LLVMTypeRef args[] = {
+            req->types[LLVM_DOUBLE],
+            req->types[LLVM_DOUBLE]
+         };
+         req->fntypes[which] = LLVMFunctionType(req->types[LLVM_DOUBLE],
+                                                args, ARRAY_LEN(args), false);
+
+         fn = LLVMAddFunction(req->module, "llvm.pow.f64", req->fntypes[which]);
+      }
+      break;
+
    case LLVM_DO_EXIT:
       {
          LLVMTypeRef args[] = {
@@ -526,35 +547,38 @@ static LLVMValueRef cgen_coerce_value(cgen_req_t *req, cgen_block_t *cgb,
       else
          return raw;
 
+   case LLVM_INTPTR:
+   case LLVM_INT64:
+   case LLVM_INT32:
+   case LLVM_INT16:
+   case LLVM_INT8:
+      if (LLVMGetTypeKind(lltype) == LLVMPointerTypeKind)
+         return LLVMBuildIntToPtr(req->builder, raw, req->types[LLVM_PTR], "");
+      else {
+         const int bits1 = LLVMGetIntTypeWidth(lltype);
+         const int bits2 = LLVMGetIntTypeWidth(req->types[type]);
+
+         if (bits1 < bits2)
+            return LLVMBuildSExt(req->builder, raw, req->types[type], "");
+         else if (bits1 == bits2)
+            return raw;
+         else
+            return LLVMBuildTrunc(req->builder, raw, req->types[type], "");
+      }
+
    default:
       return raw;
    }
 }
 
-static void cgen_coerece(cgen_req_t *req, jit_size_t size, LLVMValueRef *arg1,
-                         LLVMValueRef *arg2)
+static llvm_type_t cgen_int_type_for(jit_ir_t *ir)
 {
-   if (size != JIT_SZ_UNSPEC) {
-      LLVMTypeRef int_type = req->types[LLVM_INT8 + size];
-      *arg1 = LLVMBuildTrunc(req->builder, *arg1, int_type, "");
-      *arg2 = LLVMBuildTrunc(req->builder, *arg2, int_type, "");
-   }
-   else {
-      LLVMTypeRef type1 = LLVMTypeOf(*arg1);
-      LLVMTypeRef type2 = LLVMTypeOf(*arg2);
-
-      if (LLVMGetTypeKind(type1) == LLVMPointerTypeKind)
-         *arg2 = LLVMBuildIntToPtr(req->builder, *arg2,
-                                   req->types[LLVM_PTR], "");
-      else {
-         const int bits1 = LLVMGetIntTypeWidth(type1);
-         const int bits2 = LLVMGetIntTypeWidth(type2);
-
-         if (bits1 < bits2)
-            *arg1 = LLVMBuildSExt(req->builder, *arg1, type2, "");
-         else
-            *arg2 = LLVMBuildSExt(req->builder, *arg2, type1, "");
-      }
+   switch (ir->size) {
+   case JIT_SZ_8: return LLVM_INT8;
+   case JIT_SZ_16: return LLVM_INT16;
+   case JIT_SZ_32: return LLVM_INT32;
+   case JIT_SZ_64: return LLVM_INT64;
+   default: return LLVM_INTPTR;
    }
 }
 
@@ -615,26 +639,11 @@ static void cgen_op_send(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 
 static void cgen_op_store(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 {
-   LLVMValueRef value = cgen_get_value(req, cgb, ir->arg1);
+   llvm_type_t type = cgen_int_type_for(ir);
+   LLVMValueRef value = cgen_coerce_value(req, cgb, ir->arg1, type);
    LLVMValueRef ptr = cgen_coerce_value(req, cgb, ir->arg2, LLVM_PTR);
-   LLVMValueRef trunc = value;
 
-   switch (ir->size) {
-   case JIT_SZ_8:
-      trunc = LLVMBuildTrunc(req->builder, value, req->types[LLVM_INT8], "");
-      break;
-   case JIT_SZ_16:
-      trunc = LLVMBuildTrunc(req->builder, value, req->types[LLVM_INT16], "");
-      break;
-   case JIT_SZ_32:
-      trunc = LLVMBuildTrunc(req->builder, value, req->types[LLVM_INT32], "");
-      break;
-   case JIT_SZ_64:
-   case JIT_SZ_UNSPEC:
-      break;
-   }
-
-   LLVMBuildStore(req->builder, trunc, ptr);
+   LLVMBuildStore(req->builder, value, ptr);
 }
 
 static void cgen_op_load(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
@@ -663,10 +672,12 @@ static void cgen_op_load(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 
 static void cgen_op_add(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 {
-   LLVMValueRef arg1 = cgen_get_value(req, cgb, ir->arg1);
-   LLVMValueRef arg2 = cgen_get_value(req, cgb, ir->arg2);
+   llvm_type_t type = req->regtypes[ir->result];
+
+   if (type == LLVM_PTR) {
+      LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, LLVM_PTR);
+      LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, LLVM_INTPTR);
 
-   if (LLVMGetTypeKind(LLVMTypeOf(arg1)) == LLVMPointerTypeKind) {
       LLVMValueRef indexes[] = { cgen_zext_to_intptr(req, arg2) };
       LLVMValueRef result = LLVMBuildGEP2(req->builder, req->types[LLVM_INT8],
                                           arg1, indexes, ARRAY_LEN(indexes),
@@ -674,7 +685,8 @@ static void cgen_op_add(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
       cgb->outregs[ir->result] = result;
    }
    else {
-      cgen_coerece(req, ir->size, &arg1, &arg2);
+      LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, type);
+      LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, type);
 
       llvm_fn_t fn = LLVM_LAST_FN;
       if (ir->cc == JIT_CC_O)
@@ -699,10 +711,9 @@ static void cgen_op_add(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 
 static void cgen_op_sub(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 {
-   LLVMValueRef arg1 = cgen_get_value(req, cgb, ir->arg1);
-   LLVMValueRef arg2 = cgen_get_value(req, cgb, ir->arg2);
-
-   cgen_coerece(req, ir->size, &arg1, &arg2);
+   llvm_type_t type = cgen_int_type_for(ir);
+   LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, type);
+   LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, type);
 
    llvm_fn_t fn = LLVM_LAST_FN;
    if (ir->cc == JIT_CC_O)
@@ -726,16 +737,9 @@ static void cgen_op_sub(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 
 static void cgen_op_mul(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 {
-   LLVMValueRef arg1 = cgen_get_value(req, cgb, ir->arg1);
-   LLVMValueRef arg2 = cgen_get_value(req, cgb, ir->arg2);
-
-   cgen_coerece(req, ir->size, &arg1, &arg2);
-
-   if (ir->size != JIT_SZ_UNSPEC) {
-      LLVMTypeRef int_type = req->types[LLVM_INT8 + ir->size];
-      arg1 = LLVMBuildTrunc(req->builder, arg1, int_type, "");
-      arg2 = LLVMBuildTrunc(req->builder, arg2, int_type, "");
-   }
+   llvm_type_t type = cgen_int_type_for(ir);
+   LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, type);
+   LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, type);
 
    llvm_fn_t fn = LLVM_LAST_FN;
    if (ir->cc == JIT_CC_O)
@@ -757,6 +761,26 @@ static void cgen_op_mul(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
    cgb->outregs[ir->result] = result;
 }
 
+static void cgen_op_div(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
+{
+   llvm_type_t type = cgen_int_type_for(ir);
+   LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, type);
+   LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, type);
+
+   cgb->outregs[ir->result] = LLVMBuildSDiv(req->builder, arg1, arg2,
+                                            cgen_reg_name(ir->result));
+}
+
+static void cgen_op_rem(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
+{
+   llvm_type_t type = cgen_int_type_for(ir);
+   LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, type);
+   LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, type);
+
+   cgb->outregs[ir->result] = LLVMBuildSRem(req->builder, arg1, arg2,
+                                            cgen_reg_name(ir->result));
+}
+
 static void cgen_op_not(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 {
    LLVMValueRef arg1 = cgen_get_value(req, cgb, ir->arg1);
@@ -798,7 +822,20 @@ static void cgen_op_cmp(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
    LLVMValueRef arg1 = cgen_get_value(req, cgb, ir->arg1);
    LLVMValueRef arg2 = cgen_get_value(req, cgb, ir->arg2);
 
-   cgen_coerece(req, JIT_SZ_UNSPEC, &arg1, &arg2);
+   LLVMTypeRef type1 = LLVMTypeOf(arg1);
+   LLVMTypeRef type2 = LLVMTypeOf(arg2);
+
+   if (LLVMGetTypeKind(type1) == LLVMPointerTypeKind)
+      arg2 = LLVMBuildIntToPtr(req->builder, arg2, req->types[LLVM_PTR], "");
+   else {
+      const int bits1 = LLVMGetIntTypeWidth(type1);
+      const int bits2 = LLVMGetIntTypeWidth(type2);
+
+      if (bits1 < bits2)
+         arg1 = LLVMBuildSExt(req->builder, arg1, type2, "");
+      else if (bits1 > bits2)
+         arg2 = LLVMBuildSExt(req->builder, arg2, type1, "");
+   }
 
    switch (ir->cc) {
    case JIT_CC_EQ:
@@ -816,6 +853,9 @@ static void cgen_op_cmp(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
    case JIT_CC_LE:
       cgb->outflags = LLVMBuildICmp(req->builder, LLVMIntSLE, arg1, arg2, "");
       break;
+   case JIT_CC_GE:
+      cgb->outflags = LLVMBuildICmp(req->builder, LLVMIntSGE, arg1, arg2, "");
+      break;
    default:
       jit_dump_with_mark(req->func, ir - req->func->irbuf, false);
       fatal_trace("unhandled cmp condition code");
@@ -884,6 +924,24 @@ static void cgen_op_neg(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
    cgb->outregs[ir->result] = neg;
 }
 
+static void cgen_macro_exp(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
+{
+   LLVMValueRef arg1 = cgen_get_value(req, cgb, ir->arg1);
+   LLVMValueRef arg2 = cgen_get_value(req, cgb, ir->arg2);
+
+   // TODO: implement this without the cast
+   LLVMValueRef cast[] = {
+      LLVMBuildUIToFP(req->builder, arg1, req->types[LLVM_DOUBLE], ""),
+      LLVMBuildUIToFP(req->builder, arg2, req->types[LLVM_DOUBLE], "")
+   };
+   LLVMValueRef fn = cgen_get_fn(req, LLVM_POW_F64);   // Sets req->fntypes
+   LLVMValueRef real = LLVMBuildCall2(req->builder, req->fntypes[LLVM_POW_F64],
+                                      fn, cast, 2, "");
+   cgb->outregs[ir->result] = LLVMBuildFPToUI(
+      req->builder, real, req->types[req->regtypes[ir->result]],
+      cgen_reg_name(ir->result));
+}
+
 static void cgen_macro_copy(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 {
    LLVMValueRef count = cgb->outregs[ir->result];
@@ -893,6 +951,14 @@ static void cgen_macro_copy(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
    LLVMBuildMemMove(req->builder, dest, 0, src, 0, count);
 }
 
+static void cgen_macro_bzero(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
+{
+   LLVMValueRef count = cgb->outregs[ir->result];
+   LLVMValueRef dest = cgen_coerce_value(req, cgb, ir->arg1, LLVM_PTR);
+
+   LLVMBuildMemSet(req->builder, dest, llvm_int8(req, 0), count, 0);
+}
+
 static void cgen_macro_exit(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
 {
    const unsigned irpos = ir - req->func->irbuf;
@@ -1002,6 +1068,12 @@ static void cgen_ir(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
    case J_MUL:
       cgen_op_mul(req, cgb, ir);
       break;
+   case J_DIV:
+      cgen_op_div(req, cgb, ir);
+      break;
+   case J_REM:
+      cgen_op_rem(req, cgb, ir);
+      break;
    case J_NOT:
       cgen_op_not(req, cgb, ir);
       break;
@@ -1034,9 +1106,15 @@ static void cgen_ir(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir)
    case J_NEG:
       cgen_op_neg(req, cgb, ir);
       break;
+   case MACRO_EXP:
+      cgen_macro_exp(req, cgb, ir);
+      break;
    case MACRO_COPY:
       cgen_macro_copy(req, cgb, ir);
       break;
+   case MACRO_BZERO:
+      cgen_macro_bzero(req, cgb, ir);
+      break;
    case MACRO_EXIT:
       cgen_macro_exit(req, cgb, ir);
       break;
@@ -1105,34 +1183,6 @@ static void cgen_set_reg_type(cgen_req_t *req, jit_reg_t reg,
    }
 }
 
-static void cgen_hint_value_type(cgen_req_t *req, jit_value_t value,
-                                 llvm_type_t type)
-{
-   switch (value.kind) {
-   case JIT_VALUE_REG:
-   case JIT_ADDR_REG:
-      cgen_set_reg_type(req, value.reg, type);
-      break;
-
-   default:
-      break;
-   }
-}
-
-static void cgen_hint_reg_size(cgen_req_t *req, jit_reg_t reg, jit_size_t sz)
-{
-   llvm_type_t type = LLVM_INTPTR;
-   switch (sz) {
-   case JIT_SZ_8: type = LLVM_INT8; break;
-   case JIT_SZ_16: type = LLVM_INT16; break;
-   case JIT_SZ_32: type = LLVM_INT32; break;
-   case JIT_SZ_64: type = LLVM_INT64; break;
-   case JIT_SZ_UNSPEC: break;
-   }
-
-   cgen_set_reg_type(req, reg, type);
-}
-
 static void cgen_hint_reg_type(cgen_req_t *req, jit_reg_t reg,
                                jit_value_t value)
 {
@@ -1146,10 +1196,15 @@ static void cgen_hint_reg_type(cgen_req_t *req, jit_reg_t reg,
       break;
 
    default:
-      fatal_trace("cannot handle kind %d in cgen_constrain_type", value.kind);
+      fatal_trace("cannot handle kind %d in cgen_hint_reg_type", value.kind);
    }
 }
 
+static inline bool cgen_is_ptr(cgen_req_t *req, jit_value_t value)
+{
+   return value.kind == JIT_VALUE_REG && req->regtypes[value.reg] == LLVM_PTR;
+}
+
 static void cgen_reg_types(cgen_req_t *req)
 {
    req->regtypes = xcalloc_array(req->func->nregs, sizeof(llvm_type_t));
@@ -1163,23 +1218,27 @@ static void cgen_reg_types(cgen_req_t *req)
 
       case J_LOAD:
       case J_ULOAD:
-         cgen_hint_reg_size(req, ir->result, ir->size);
+         cgen_set_reg_type(req, ir->result, cgen_int_type_for(ir));
          break;
 
-      case J_MUL:
       case J_ADD:
+         if (cgen_is_ptr(req, ir->arg1))
+            cgen_set_reg_type(req, ir->result, LLVM_PTR);
+         else
+            cgen_set_reg_type(req, ir->result, cgen_int_type_for(ir));
+         break;
+
+      case J_MUL:
       case J_SUB:
-         cgen_hint_reg_size(req, ir->result, ir->size);
+      case J_DIV:
+      case J_REM:
+         cgen_set_reg_type(req, ir->result, cgen_int_type_for(ir));
          break;
 
       case J_CSET:
          cgen_set_reg_type(req, ir->result, LLVM_INT1);
          break;
 
-      case J_STORE:
-         cgen_hint_value_type(req, ir->arg2, LLVM_PTR);
-         break;
-
       case J_LEA:
       case MACRO_GETPRIV:
       case MACRO_GALLOC:
@@ -1203,12 +1262,18 @@ static void cgen_reg_types(cgen_req_t *req)
       case J_DEBUG:
       case J_RET:
       case J_CALL:
+      case J_STORE:
+         break;
+
+      case MACRO_EXP:
+         cgen_hint_reg_type(req, ir->result, ir->arg1);
          break;
 
       case MACRO_EXIT:
       case MACRO_COPY:
       case MACRO_PUTPRIV:
       case MACRO_FFICALL:
+      case MACRO_BZERO:
          break;
 
       default:
@@ -1284,6 +1349,7 @@ static void cgen_module(cgen_req_t *req)
    req->types[LLVM_INT32] = LLVMInt32TypeInContext(req->context);
    req->types[LLVM_INT64] = LLVMInt64TypeInContext(req->context);
    req->types[LLVM_INTPTR] = LLVMIntPtrTypeInContext(req->context, data_ref);
+   req->types[LLVM_DOUBLE] = LLVMDoubleTypeInContext(req->context);
 
    LLVMTypeRef atypes[] = {
       req->types[LLVM_PTR],    // Function
-- 
2.39.2