From 55abb7b96e6f12372e1c44be5c522f20212a34b8 Mon Sep 17 00:00:00 2001 From: Nick Gasson Date: Mon, 5 Dec 2022 14:40:10 +0000 Subject: [PATCH] Optimise $GETPRIV implementation in LLVM JIT --- src/jit/jit-core.c | 10 ++-------- src/jit/jit-exits.c | 6 +++--- src/jit/jit-interp.c | 4 ++-- src/jit/jit-llvm.c | 33 ++++++++++++++++++++++++++++----- src/jit/jit-priv.h | 3 +-- 5 files changed, 36 insertions(+), 20 deletions(-) diff --git a/src/jit/jit-core.c b/src/jit/jit-core.c index fa07b82b..e7a8b90b 100644 --- a/src/jit/jit-core.c +++ b/src/jit/jit-core.c @@ -411,18 +411,12 @@ void *jit_link(jit_t *j, jit_handle_t handle) return result.pointer; } -void *jit_get_privdata(jit_t *j, jit_func_t *f) +void **jit_get_privdata_ptr(jit_t *j, jit_func_t *f) { if (f->privdata == MPTR_INVALID) f->privdata = mptr_new(j->mspace, "privdata"); - return *mptr_get(f->privdata); -} - -void jit_put_privdata(jit_t *j, jit_func_t *f, void *ptr) -{ - assert(f->privdata != MPTR_INVALID); - *mptr_get(f->privdata) = ptr; + return mptr_get(f->privdata); } void *jit_get_frame_var(jit_t *j, jit_handle_t handle, uint32_t var) diff --git a/src/jit/jit-exits.c b/src/jit/jit-exits.c index 4407fbb5..3cdce73b 100644 --- a/src/jit/jit-exits.c +++ b/src/jit/jit-exits.c @@ -1473,12 +1473,12 @@ jit_handle_t __nvc_get_handle(const char *func, ffi_spec_t spec) } DLLEXPORT -void *__nvc_getpriv(jit_handle_t handle) +void **__nvc_getpriv(jit_handle_t handle) { jit_t *j = jit_thread_local()->jit; jit_func_t *f = jit_get_func(j, handle); - return jit_get_privdata(j, f); + return jit_get_privdata_ptr(j, f); } DLLEXPORT @@ -1487,7 +1487,7 @@ void __nvc_putpriv(jit_handle_t handle, void *data) jit_t *j = jit_thread_local()->jit; jit_func_t *f = jit_get_func(j, handle); - jit_put_privdata(j, f, data); + *jit_get_privdata_ptr(j, f) = data; } DLLEXPORT diff --git a/src/jit/jit-interp.c b/src/jit/jit-interp.c index 72799b69..d5f14158 100644 --- a/src/jit/jit-interp.c +++ b/src/jit/jit-interp.c @@ -701,7 +701,7 @@ static void interp_getpriv(jit_interp_t *state, jit_ir_t *ir) { JIT_ASSERT(ir->arg1.kind == JIT_VALUE_HANDLE); jit_func_t *f = jit_get_func(state->func->jit, ir->arg1.handle); - void *ptr = jit_get_privdata(state->func->jit, f); + void *ptr = *jit_get_privdata_ptr(state->func->jit, f); state->regs[ir->result].pointer = ptr; } @@ -710,7 +710,7 @@ static void interp_putpriv(jit_interp_t *state, jit_ir_t *ir) JIT_ASSERT(ir->arg1.kind == JIT_VALUE_HANDLE); jit_func_t *f = jit_get_func(state->func->jit, ir->arg1.handle); void *ptr = interp_get_value(state, ir->arg2).pointer; - jit_put_privdata(state->func->jit, f, ptr); + *jit_get_privdata_ptr(state->func->jit, f) = ptr; } static void interp_loop(jit_interp_t *state) diff --git a/src/jit/jit-llvm.c b/src/jit/jit-llvm.c index 958dd105..cdf3566d 100644 --- a/src/jit/jit-llvm.c +++ b/src/jit/jit-llvm.c @@ -1740,12 +1740,35 @@ static void cgen_macro_lalloc(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) static void cgen_macro_getpriv(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) { - // TODO: this needs some kind of fast-path + jit_func_t *f = jit_get_func(cgb->func->source->jit, ir->arg1.handle); - LLVMValueRef args[] = { - cgen_get_value(obj, cgb, ir->arg1) - }; - LLVMValueRef ptr = llvm_call_fn(obj, LLVM_GETPRIV, args, ARRAY_LEN(args)); + LOCAL_TEXT_BUF tb = tb_new(); + tb_istr(tb, f->name); + tb_cat(tb, ".privdata"); + + LLVMValueRef global = LLVMGetNamedGlobal(obj->module, tb_get(tb)); + if (global == NULL) { + global = LLVMAddGlobal(obj->module, obj->types[LLVM_PTR], tb_get(tb)); + LLVMSetUnnamedAddr(global, true); + LLVMSetLinkage(global, LLVMPrivateLinkage); + LLVMSetInitializer(global, llvm_ptr(obj, NULL)); + + LLVMBasicBlockRef old_bb = cgen_add_ctor(obj, 1); + + LLVMValueRef args[] = { + cgen_get_value(obj, cgb, ir->arg1) + }; + LLVMValueRef init = + llvm_call_fn(obj, LLVM_GETPRIV, args, ARRAY_LEN(args)); + LLVMBuildStore(obj->builder, init, global); + + LLVMPositionBuilderAtEnd(obj->builder, old_bb); + } + + LLVMValueRef ptrptr = + LLVMBuildLoad2(obj->builder, obj->types[LLVM_PTR], global, ""); + LLVMValueRef ptr = + LLVMBuildLoad2(obj->builder, obj->types[LLVM_PTR], ptrptr, ""); cgb->outregs[ir->result] = LLVMBuildPtrToInt(obj->builder, ptr, obj->types[LLVM_INT64], diff --git a/src/jit/jit-priv.h b/src/jit/jit-priv.h index d63432d4..cf6262a2 100644 --- a/src/jit/jit-priv.h +++ b/src/jit/jit-priv.h @@ -312,8 +312,7 @@ void jit_interp(jit_func_t *f, jit_anchor_t *caller, jit_scalar_t *args, jit_func_t *jit_get_func(jit_t *j, jit_handle_t handle); void jit_hexdump(const unsigned char *data, size_t sz, int blocksz, const void *highlight, const char *prefix); -void *jit_get_privdata(jit_t *j, jit_func_t *f); -void jit_put_privdata(jit_t *j, jit_func_t *f, void *ptr); +void **jit_get_privdata_ptr(jit_t *j, jit_func_t *f); bool jit_has_runtime(jit_t *j); int jit_backedge_limit(jit_t *j); void jit_tier_up(jit_func_t *f); -- 2.39.2