From 4b285a20c081ecbb002f3bdad1972dc79e767486 Mon Sep 17 00:00:00 2001 From: Nick Gasson Date: Wed, 9 Nov 2022 19:54:27 +0000 Subject: [PATCH] Rematerialise JIT functions and FFI handles in AOT mode --- src/jit/jit-exits.c | 20 +++++++ src/jit/jit-ffi.c | 5 ++ src/jit/jit-ffi.h | 1 + src/jit/jit-llvm.c | 139 +++++++++++++++++++++++++++++++++----------- src/symbols.txt | 2 + 5 files changed, 133 insertions(+), 34 deletions(-) diff --git a/src/jit/jit-exits.c b/src/jit/jit-exits.c index 92e90cf8..ede575a1 100644 --- a/src/jit/jit-exits.c +++ b/src/jit/jit-exits.c @@ -1488,3 +1488,23 @@ void __nvc_trampoline(jit_func_t *f, jit_anchor_t *caller, jit_scalar_t *args) printf("trampoline! %s\n", istr(f->name)); (*f->entry)(f, caller, args); } + +DLLEXPORT +jit_func_t *__nvc_get_func(const char *name) +{ + printf("get_func %s\n", name); + + jit_t *j = jit_thread_local()->jit; + jit_handle_t handle = jit_lazy_compile(j, ident_new(name)); + printf("...handle = %d\n", handle); + + return jit_get_func(j, handle); +} + +DLLEXPORT +jit_foreign_t *__nvc_get_foreign(const char *name, ffi_spec_t spec) +{ + printf("get_foreign %s %lx\n", name, spec); + ident_t id = ident_new(name); + return jit_ffi_get(id) ?: jit_ffi_bind(id, spec, NULL); +} diff --git a/src/jit/jit-ffi.c b/src/jit/jit-ffi.c index 06e14bb1..5bd7a26b 100644 --- a/src/jit/jit-ffi.c +++ b/src/jit/jit-ffi.c @@ -324,3 +324,8 @@ ident_t ffi_get_sym(jit_foreign_t *ff) { return ff->sym; } + +ffi_spec_t ffi_get_spec(jit_foreign_t *ff) +{ + return ff->spec; +} diff --git a/src/jit/jit-ffi.h b/src/jit/jit-ffi.h index 2a6a33c1..9d28df84 100644 --- a/src/jit/jit-ffi.h +++ b/src/jit/jit-ffi.h @@ -62,6 +62,7 @@ jit_foreign_t *jit_ffi_get(ident_t sym); void jit_ffi_call(jit_foreign_t *ff, jit_scalar_t *args); ident_t ffi_get_sym(jit_foreign_t *ff); +ffi_spec_t ffi_get_spec(jit_foreign_t *ff); int ffi_count_args(ffi_spec_t spec); ffi_uarray_t ffi_wrap_str(char *buf, size_t len); diff --git a/src/jit/jit-llvm.c b/src/jit/jit-llvm.c index b4b07849..22d3ff5d 100644 --- a/src/jit/jit-llvm.c +++ b/src/jit/jit-llvm.c @@ -107,6 +107,8 @@ typedef enum { LLVM_DO_FFICALL, LLVM_TRAMPOLINE, LLVM_REGISTER, + LLVM_GET_FUNC, + LLVM_GET_FOREIGN, LLVM_LAST_FN, } llvm_fn_t; @@ -124,7 +126,6 @@ typedef struct _llvm_obj { LLVMValueRef fns[LLVM_LAST_FN]; LLVMTypeRef fntypes[LLVM_LAST_FN]; LLVMValueRef ctor; - text_buf_t *textbuf; shash_t *string_pool; } llvm_obj_t; @@ -609,6 +610,27 @@ static LLVMValueRef llvm_get_fn(llvm_obj_t *obj, llvm_fn_t which) } break; + case LLVM_GET_FUNC: + { + LLVMTypeRef args[] = { obj->types[LLVM_PTR] }; + obj->fntypes[which] = LLVMFunctionType(obj->types[LLVM_PTR], args, + ARRAY_LEN(args), false); + fn = llvm_add_fn(obj, "__nvc_get_func", obj->fntypes[which]); + } + break; + + case LLVM_GET_FOREIGN: + { + LLVMTypeRef args[] = { + obj->types[LLVM_PTR], + obj->types[LLVM_INT64] + }; + obj->fntypes[which] = LLVMFunctionType(obj->types[LLVM_PTR], args, + ARRAY_LEN(args), false); + fn = llvm_add_fn(obj, "__nvc_get_foreign", obj->fntypes[which]); + } + break; + default: fatal_trace("cannot generate prototype for function %d", which); } @@ -683,13 +705,6 @@ static const char *cgen_arg_name(int nth) #endif } -static const char *cgen_istr(llvm_obj_t *obj, ident_t id) -{ - tb_rewind(obj->textbuf); - tb_istr(obj->textbuf, id); - return tb_get(obj->textbuf); -} - __attribute__((noreturn)) static void cgen_abort(cgen_block_t *cgb, jit_ir_t *ir, const char *fmt, ...) { @@ -865,6 +880,14 @@ static void cgen_sync_irpos(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) LLVMBuildStore(obj->builder, llvm_int32(obj, irpos), irpos_ptr); } +static LLVMBasicBlockRef cgen_add_ctor(llvm_obj_t *obj) +{ + assert(obj->ctor != NULL); + LLVMBasicBlockRef old_bb = LLVMGetInsertBlock(obj->builder); + LLVMPositionBuilderAtEnd(obj->builder, LLVMGetLastBasicBlock(obj->ctor)); + return old_bb; +} + static void cgen_op_recv(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) { assert(ir->arg1.kind == JIT_VALUE_INT64); @@ -1264,36 +1287,53 @@ static void cgen_op_call(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) jit_func_t *callee = jit_get_func(cgb->func->source->jit, ir->arg1.handle); - LLVMValueRef fnptr = NULL; - if (obj->ctor != NULL) - fnptr = llvm_get_fn(obj, LLVM_TRAMPOLINE); - else { - const char *name = cgen_istr(obj, callee->name); - LLVMValueRef global = LLVMGetNamedGlobal(obj->module, name); + LLVMValueRef entry = NULL, fptr = NULL; + if (obj->ctor != NULL) { + entry = llvm_get_fn(obj, LLVM_TRAMPOLINE); + + LOCAL_TEXT_BUF tb = tb_new(); + tb_istr(tb, callee->name); + tb_cat(tb, ".func"); + + LLVMValueRef global = LLVMGetNamedGlobal(obj->module, tb_get(tb)); if (global == NULL) { - global = LLVMAddGlobal(obj->module, obj->types[LLVM_PTR], name); - LLVMSetGlobalConstant(global, true); - LLVMSetLinkage(global, LLVMPrivateLinkage); + global = LLVMAddGlobal(obj->module, obj->types[LLVM_PTR], tb_get(tb)); LLVMSetUnnamedAddr(global, true); - LLVMSetInitializer(global, llvm_ptr(obj, callee->entry)); + LLVMSetLinkage(global, LLVMPrivateLinkage); + LLVMSetInitializer(global, llvm_ptr(obj, NULL)); + + LLVMBasicBlockRef old_bb = cgen_add_ctor(obj); + + tb_trim(tb, ident_len(callee->name)); // Strip .func + + LLVMValueRef args[] = { + llvm_const_string(obj, tb_get(tb)), + }; + LLVMValueRef init = + llvm_call_fn(obj, LLVM_GET_FUNC, args, ARRAY_LEN(args)); + LLVMBuildStore(obj->builder, init, global); + + LLVMPositionBuilderAtEnd(obj->builder, old_bb); } -#ifdef LLVM_HAS_OPAQUE_POINTERS - fnptr = LLVMBuildLoad2(obj->builder, obj->types[LLVM_PTR], global, name); -#else + fptr = LLVMBuildLoad2(obj->builder, obj->types[LLVM_PTR], global, ""); + } + else { + entry = llvm_ptr(obj, callee->entry); + fptr = llvm_ptr(obj, callee); + +#ifndef LLVM_HAS_OPAQUE_POINTERS LLVMTypeRef ptr_type = LLVMPointerType(obj->types[LLVM_ENTRY_FN], 0); - LLVMValueRef cast = - LLVMBuildPointerCast(obj->builder, global, ptr_type, ""); - fnptr = LLVMBuildLoad2(obj->builder, ptr_type, cast, name); + entry = LLVMBuildPointerCast(obj->builder, entry, ptr_type, ""); #endif } LLVMValueRef args[] = { - llvm_ptr(obj, callee), + fptr, PTR(cgb->func->anchor), cgb->func->args }; - LLVMBuildCall2(obj->builder, obj->types[LLVM_ENTRY_FN], fnptr, + LLVMBuildCall2(obj->builder, obj->types[LLVM_ENTRY_FN], entry, args, ARRAY_LEN(args), ""); } @@ -1386,10 +1426,44 @@ static void cgen_macro_fficall(llvm_obj_t *obj, cgen_block_t *cgb, jit_ir_t *ir) { cgen_sync_irpos(obj, cgb, ir); - LLVMValueRef ff = cgen_get_value(obj, cgb, ir->arg1); + LLVMValueRef ffptr; + if (obj->ctor != NULL) { + assert(ir->arg1.kind == JIT_VALUE_FOREIGN); + ident_t sym = ffi_get_sym(ir->arg1.foreign); + + LOCAL_TEXT_BUF tb = tb_new(); + tb_istr(tb, sym); + tb_cat(tb, ".ffi"); + + LLVMValueRef global = LLVMGetNamedGlobal(obj->module, tb_get(tb)); + if (global == NULL) { + global = LLVMAddGlobal(obj->module, obj->types[LLVM_PTR], tb_get(tb)); + LLVMSetUnnamedAddr(global, true); + LLVMSetLinkage(global, LLVMPrivateLinkage); + LLVMSetInitializer(global, llvm_ptr(obj, NULL)); + + LLVMBasicBlockRef old_bb = cgen_add_ctor(obj); + + tb_trim(tb, ident_len(sym)); // Strip .ffi + + LLVMValueRef args[] = { + llvm_const_string(obj, tb_get(tb)), + llvm_int64(obj, ffi_get_spec(ir->arg1.foreign)), + }; + LLVMValueRef init = + llvm_call_fn(obj, LLVM_GET_FOREIGN, args, ARRAY_LEN(args)); + LLVMBuildStore(obj->builder, init, global); + + LLVMPositionBuilderAtEnd(obj->builder, old_bb); + } + + ffptr = LLVMBuildLoad2(obj->builder, obj->types[LLVM_PTR], global, ""); + } + else + ffptr = cgen_get_value(obj, cgb, ir->arg1); LLVMValueRef args[] = { - ff, + ffptr, PTR(cgb->func->anchor), cgb->func->args }; @@ -1618,16 +1692,17 @@ static void cgen_function(llvm_obj_t *obj, cgen_func_t *func) { func->llvmfn = LLVMAddFunction(obj->module, func->name, obj->types[LLVM_ENTRY_FN]); - LLVMSetLinkage(func->llvmfn, LLVMPrivateLinkage); if (obj->ctor != NULL) { - LLVMPositionBuilderAtEnd(obj->builder, LLVMGetLastBasicBlock(obj->ctor)); + cgen_add_ctor(obj); LLVMValueRef args[] = { llvm_const_string(obj, func->name), PTR(func->llvmfn) }; llvm_call_fn(obj, LLVM_REGISTER, args, ARRAY_LEN(args)); + + LLVMSetLinkage(func->llvmfn, LLVMPrivateLinkage); } LLVMBasicBlockRef entry_bb = llvm_append_block(obj, func->llvmfn, "entry"); @@ -1798,7 +1873,6 @@ static void jit_llvm_cgen(jit_t *j, jit_handle_t handle, void *context) llvm_obj_t obj = { .context = LLVMOrcThreadSafeContextGetContext(state->context), .target = state->target, - .textbuf = tb_new(), }; LOCAL_TEXT_BUF tb = tb_new(); @@ -1836,7 +1910,6 @@ static void jit_llvm_cgen(jit_t *j, jit_handle_t handle, void *context) LLVMDisposeTargetData(obj.data_ref); LLVMDisposeBuilder(obj.builder); - tb_free(obj.textbuf); free(func.name); } @@ -1879,7 +1952,6 @@ llvm_obj_t *llvm_obj_new(const char *name) obj->module = LLVMModuleCreateWithNameInContext(name, obj->context); obj->builder = LLVMCreateBuilderInContext(obj->context); obj->target = llvm_target_machine(LLVMRelocPIC, LLVMCodeModelDefault); - obj->textbuf = tb_new(); obj->data_ref = LLVMCreateTargetDataLayout(obj->target); char *triple = LLVMGetTargetMachineTriple(obj->target); @@ -1967,7 +2039,6 @@ void llvm_obj_emit(llvm_obj_t *obj, const char *path) LLVMDisposeModule(obj->module); LLVMContextDispose(obj->context); - tb_free(obj->textbuf); shash_free(obj->string_pool); free(obj); diff --git a/src/symbols.txt b/src/symbols.txt index 5e9e9771..b02b56c8 100644 --- a/src/symbols.txt +++ b/src/symbols.txt @@ -42,6 +42,8 @@ __nvc_exponent_fail; __nvc_flush; __nvc_force; + __nvc_get_func; + __nvc_get_foreign; __nvc_get_handle; __nvc_getpriv; __nvc_index_fail; -- 2.39.2