From f934cfb39c7526f609e6548e342d3330a0974057 Mon Sep 17 00:00:00 2001 From: Nick Gasson Date: Sun, 30 Oct 2022 17:23:38 +0000 Subject: [PATCH] Add LLVM JIT support for floating point types --- src/jit/jit-llvm.c | 208 ++++++++++++++++++++++++++++++++++++++----- test/perf/simple.vhd | 25 ++++++ 2 files changed, 213 insertions(+), 20 deletions(-) diff --git a/src/jit/jit-llvm.c b/src/jit/jit-llvm.c index fdad32d2..2203c0ad 100644 --- a/src/jit/jit-llvm.c +++ b/src/jit/jit-llvm.c @@ -91,6 +91,7 @@ typedef enum { LLVM_MUL_OVERFLOW_U64, LLVM_POW_F64, + LLVM_ROUND_F64, LLVM_DO_EXIT, LLVM_GETPRIV, @@ -182,6 +183,11 @@ static LLVMValueRef llvm_ptr(cgen_req_t *req, void *ptr) req->types[LLVM_PTR]); } +static LLVMValueRef llvm_real(cgen_req_t *req, double r) +{ + return LLVMConstReal(req->types[LLVM_DOUBLE], r); +} + static LLVMBasicBlockRef cgen_append_block(cgen_req_t *req, const char *name) { return LLVMAppendBasicBlockInContext(req->context, req->llvmfn, name); @@ -363,6 +369,17 @@ static LLVMValueRef cgen_get_fn(cgen_req_t *req, llvm_fn_t which) } break; + case LLVM_ROUND_F64: + { + LLVMTypeRef args[] = { req->types[LLVM_DOUBLE] }; + req->fntypes[which] = LLVMFunctionType(req->types[LLVM_DOUBLE], + args, ARRAY_LEN(args), false); + + fn = LLVMAddFunction(req->module, "llvm.round.f64", + req->fntypes[which]); + } + break; + case LLVM_DO_EXIT: { LLVMTypeRef args[] = { @@ -488,9 +505,8 @@ static LLVMValueRef cgen_get_value(cgen_req_t *req, cgen_block_t *cgb, return cgb->outregs[value.reg]; case JIT_VALUE_INT64: return llvm_int64(req, value.int64); - /* case JIT_VALUE_DOUBLE: - return (jit_scalar_t){ .real = value.dval };*/ + return llvm_real(req, value.dval); case JIT_ADDR_FRAME: { assert(value.int64 >= 0 && value.int64 < req->func->framesz); @@ -563,6 +579,9 @@ static LLVMValueRef cgen_coerce_value(cgen_req_t *req, cgen_block_t *cgb, return LLVMBuildTrunc(req->builder, raw, req->types[type], ""); } + case LLVM_DOUBLE: + return LLVMBuildBitCast(req->builder, raw, req->types[type], ""); + default: return raw; } @@ -572,30 +591,50 @@ static void cgen_sext_result(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir, LLVMValueRef value) { LLVMTypeRef type = LLVMTypeOf(value); - if (LLVMGetTypeKind(type) == LLVMIntegerTypeKind - && LLVMGetIntTypeWidth(type) == 64) { - DEBUG_ONLY(LLVMSetValueName(value, cgen_reg_name(ir->result))); - cgb->outregs[ir->result] = value; + switch (LLVMGetTypeKind(type)) { + case LLVMIntegerTypeKind: + if (LLVMGetIntTypeWidth(type) == 64) { + DEBUG_ONLY(LLVMSetValueName(value, cgen_reg_name(ir->result))); + cgb->outregs[ir->result] = value; + } + else + cgb->outregs[ir->result] = LLVMBuildSExt(req->builder, value, + req->types[LLVM_INT64], + cgen_reg_name(ir->result)); + break; + + case LLVMDoubleTypeKind: + cgb->outregs[ir->result] = LLVMBuildBitCast(req->builder, value, + req->types[LLVM_INT64], + cgen_reg_name(ir->result)); + break; + + default: + LLVMDumpType(type); + fatal_trace("unhandled LLVM type kind in cgen_sext_result"); } - else - cgb->outregs[ir->result] = LLVMBuildSExt(req->builder, value, - req->types[LLVM_INT64], - cgen_reg_name(ir->result)); } static void cgen_zext_result(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir, LLVMValueRef value) { LLVMTypeRef type = LLVMTypeOf(value); - if (LLVMGetTypeKind(type) == LLVMIntegerTypeKind - && LLVMGetIntTypeWidth(type) == 64) { - DEBUG_ONLY(LLVMSetValueName(value, cgen_reg_name(ir->result))); - cgb->outregs[ir->result] = value; + switch (LLVMGetTypeKind(type)) { + case LLVMIntegerTypeKind: + if (LLVMGetIntTypeWidth(type) == 64) { + DEBUG_ONLY(LLVMSetValueName(value, cgen_reg_name(ir->result))); + cgb->outregs[ir->result] = value; + } + else + cgb->outregs[ir->result] = LLVMBuildZExt(req->builder, value, + req->types[LLVM_INT64], + cgen_reg_name(ir->result)); + break; + + default: + LLVMDumpType(type); + fatal_trace("unhandled LLVM type kind in cgen_sext_result"); } - else - cgb->outregs[ir->result] = LLVMBuildZExt(req->builder, value, - req->types[LLVM_INT64], - cgen_reg_name(ir->result)); } static void cgen_op_recv(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) @@ -786,6 +825,74 @@ static void cgen_op_rem(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) cgen_reg_name(ir->result)); } +static void cgen_op_fadd(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, LLVM_DOUBLE); + LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, LLVM_DOUBLE); + + LLVMValueRef real = LLVMBuildFAdd(req->builder, arg1, arg2, ""); + cgen_sext_result(req, cgb, ir, real); +} + +static void cgen_op_fsub(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, LLVM_DOUBLE); + LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, LLVM_DOUBLE); + + LLVMValueRef real = LLVMBuildFSub(req->builder, arg1, arg2, ""); + cgen_sext_result(req, cgb, ir, real); +} + +static void cgen_op_fmul(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, LLVM_DOUBLE); + LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, LLVM_DOUBLE); + + LLVMValueRef real = LLVMBuildFMul(req->builder, arg1, arg2, ""); + cgen_sext_result(req, cgb, ir, real); +} + +static void cgen_op_fdiv(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, LLVM_DOUBLE); + LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, LLVM_DOUBLE); + + LLVMValueRef real = LLVMBuildFDiv(req->builder, arg1, arg2, ""); + cgen_sext_result(req, cgb, ir, real); +} + +static void cgen_op_fneg(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, LLVM_DOUBLE); + + LLVMValueRef real = LLVMBuildFNeg(req->builder, arg1, ""); + cgen_sext_result(req, cgb, ir, real); +} + +static void cgen_op_fcvtns(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, LLVM_DOUBLE); + + LLVMValueRef fn = cgen_get_fn(req, LLVM_ROUND_F64); + LLVMValueRef args[] = { arg1 }; + LLVMValueRef rounded = LLVMBuildCall2(req->builder, + req->fntypes[LLVM_ROUND_F64], fn, + args, ARRAY_LEN(args), ""); + + cgb->outregs[ir->result] = LLVMBuildFPToSI(req->builder, rounded, + req->types[LLVM_INT64], + cgen_reg_name(ir->result)); +} + +static void cgen_op_scvtf(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef arg1 = cgen_get_value(req, cgb, ir->arg1); + + LLVMValueRef real = LLVMBuildSIToFP(req->builder, arg1, + req->types[LLVM_DOUBLE], ""); + cgen_sext_result(req, cgb, ir, real); +} + static void cgen_op_not(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) { LLVMValueRef arg1 = cgen_get_value(req, cgb, ir->arg1); @@ -871,6 +978,27 @@ static void cgen_op_cmp(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) cgb->outflags = LLVMBuildICmp(req->builder, pred, arg1, arg2, "FLAGS"); } +static void cgen_op_fcmp(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef arg1 = cgen_coerce_value(req, cgb, ir->arg1, LLVM_DOUBLE); + LLVMValueRef arg2 = cgen_coerce_value(req, cgb, ir->arg2, LLVM_DOUBLE); + + LLVMRealPredicate pred; + switch (ir->cc) { + case JIT_CC_EQ: pred = LLVMRealUEQ; break; + case JIT_CC_NE: pred = LLVMRealUNE; break; + case JIT_CC_GT: pred = LLVMRealUGT; break; + case JIT_CC_LT: pred = LLVMRealULT; break; + case JIT_CC_LE: pred = LLVMRealULE; break; + case JIT_CC_GE: pred = LLVMRealUGE; break; + default: + jit_dump_with_mark(req->func, ir - req->func->irbuf, false); + fatal_trace("unhandled fcmp condition code"); + } + + cgb->outflags = LLVMBuildFCmp(req->builder, pred, arg1, arg2, "FLAGS"); +} + static void cgen_op_cset(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) { cgen_zext_result(req, cgb, ir, cgb->outflags); @@ -952,17 +1080,30 @@ static void cgen_macro_exp(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) LLVMValueRef arg2 = cgen_get_value(req, cgb, ir->arg1); // TODO: implement this without the cast + LLVMValueRef fn = cgen_get_fn(req, LLVM_POW_F64); LLVMValueRef cast[] = { LLVMBuildUIToFP(req->builder, arg1, req->types[LLVM_DOUBLE], ""), LLVMBuildUIToFP(req->builder, arg2, req->types[LLVM_DOUBLE], "") }; cgb->outregs[ir->result] = LLVMBuildFPToUI( req->builder, - LLVMBuildCall2(req->builder, req->fntypes[LLVM_POW_F64], - cgen_get_fn(req, LLVM_POW_F64), cast, 2, ""), + LLVMBuildCall2(req->builder, req->fntypes[LLVM_POW_F64], fn, cast, 2, ""), req->types[LLVM_INT64], cgen_reg_name(ir->result)); } +static void cgen_macro_fexp(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) +{ + LLVMValueRef arg1 = cgen_get_value(req, cgb, ir->arg1); + LLVMValueRef arg2 = cgen_get_value(req, cgb, ir->arg1); + + LLVMValueRef fn = cgen_get_fn(req, LLVM_POW_F64); + LLVMValueRef args[] = { arg1, arg2 }; + LLVMValueRef real = LLVMBuildCall2(req->builder, req->fntypes[LLVM_POW_F64], + fn, args, ARRAY_LEN(args), ""); + + cgen_sext_result(req, cgb, ir, real); +} + static void cgen_macro_copy(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) { LLVMValueRef count = cgb->outregs[ir->result]; @@ -1100,6 +1241,27 @@ static void cgen_ir(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) case J_REM: cgen_op_rem(req, cgb, ir); break; + case J_FADD: + cgen_op_fadd(req, cgb, ir); + break; + case J_FSUB: + cgen_op_fsub(req, cgb, ir); + break; + case J_FMUL: + cgen_op_fmul(req, cgb, ir); + break; + case J_FDIV: + cgen_op_fdiv(req, cgb, ir); + break; + case J_FNEG: + cgen_op_fneg(req, cgb, ir); + break; + case J_FCVTNS: + cgen_op_fcvtns(req, cgb, ir); + break; + case J_SCVTF: + cgen_op_scvtf(req, cgb, ir); + break; case J_NOT: cgen_op_not(req, cgb, ir); break; @@ -1121,6 +1283,9 @@ static void cgen_ir(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) case J_CMP: cgen_op_cmp(req, cgb, ir); break; + case J_FCMP: + cgen_op_fcmp(req, cgb, ir); + break; case J_CSET: cgen_op_cset(req, cgb, ir); break; @@ -1144,6 +1309,9 @@ static void cgen_ir(cgen_req_t *req, cgen_block_t *cgb, jit_ir_t *ir) case MACRO_EXP: cgen_macro_exp(req, cgb, ir); break; + case MACRO_FEXP: + cgen_macro_fexp(req, cgb, ir); + break; case MACRO_COPY: cgen_macro_copy(req, cgb, ir); break; diff --git a/test/perf/simple.vhd b/test/perf/simple.vhd index 977b2387..538de2f1 100644 --- a/test/perf/simple.vhd +++ b/test/perf/simple.vhd @@ -3,6 +3,7 @@ package simple is procedure test_fact; procedure test_sum; procedure test_int_image; + procedure test_sqrt; end package; package body simple is @@ -77,4 +78,28 @@ package body simple is end loop; end procedure; + --------------------------------------------------------------------------- + + function sqrt (n, limit : real) return real is + variable x : real := n; + variable root : real; + begin + loop + root := 0.5 * (x + (n / x)); + exit when abs(root - x) < limit; + x := root; + end loop; + return root; + end function; + + procedure test_sqrt is + variable sum : real := 0.0; + begin + assert abs(sqrt(4.0, 0.0001) - 2.0) < 0.0001; + for i in 1 to 100 loop + sum := sum + sqrt(real(i), 0.0001); + end loop; + assert integer(sum) = 671; + end procedure; + end package body; -- 2.39.2