nested binds

This commit is contained in:
Ivan 2018-06-03 02:41:28 +03:00
parent 5c9cd62b30
commit 0184e07380
5 changed files with 146 additions and 14 deletions

View file

@ -31,6 +31,7 @@
#include "llvm/ExecutionEngine/JITSymbol.h" #include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/IR/Constants.h" #include "llvm/IR/Constants.h"
#include "llvm/IR/Mangler.h" #include "llvm/IR/Mangler.h"
#include "llvm/IR/GlobalVariable.h"
#include "llvm/Linker/Linker.h" #include "llvm/Linker/Linker.h"
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
#include "llvm/Transforms/Utils/Cloning.h" #include "llvm/Transforms/Utils/Cloning.h"
@ -181,22 +182,31 @@ void generateBind(const Context &context, JITContext &jitContext,
} }
return module.getFunction(funcDesc->name); return module.getFunction(funcDesc->name);
}; };
for (auto &&bind : jitContext.getBindInstances()) {
auto bindPtr = bind.first;
auto &bindDesc = bind.second;
assert(bindDesc.originalFunc != nullptr);
auto funcToInline = getIrFunc(bindDesc.originalFunc); std::unordered_map<const void *, llvm::Function *> bindFuncs;
bindFuncs.reserve(jitContext.getBindInstances().size() * 2);
auto genBind = [&](void *bindPtr, void *originalFunc, void *exampleFunc,
const llvm::ArrayRef<ParamSlice> &params) {
assert(bindPtr != nullptr);
assert(bindFuncs.end() == bindFuncs.find(bindPtr));
auto funcToInline = getIrFunc(originalFunc);
if (funcToInline != nullptr) { if (funcToInline != nullptr) {
auto exampleFunc = getIrFunc(bindDesc.exampleFunc); auto exampleIrFunc = getIrFunc(exampleFunc);
assert(exampleFunc != nullptr); assert(exampleIrFunc != nullptr);
auto errhandler = [&](const std::string &str) { auto errhandler = [&](const std::string &str) {
fatal(context, str); fatal(context, str);
}; };
auto overrideHandler = auto overrideHandler =
[&](llvm::Type &type, const void* data, size_t size)-> [&](llvm::Type &type, const void *data, size_t size)->
llvm::Constant *{ llvm::Constant *{
if (type.isPointerTy()) { if (type.isPointerTy()) {
auto getBindFunc = [&]() {
auto handle = *static_cast<void * const *>(data);
return handle != nullptr && jitContext.hasBindFunction(handle) ?
handle : nullptr;
};
auto ptype = llvm::cast<llvm::PointerType>(&type); auto ptype = llvm::cast<llvm::PointerType>(&type);
auto elemType = ptype->getElementType(); auto elemType = ptype->getElementType();
if (elemType->isFunctionTy()) { if (elemType->isFunctionTy()) {
@ -210,17 +220,35 @@ void generateBind(const Context &context, JITContext &jitContext,
} }
return ret; return ret;
} }
} else if (auto handle = getBindFunc()) {
auto it = bindFuncs.find(handle);
assert(bindFuncs.end() != it);
auto bindIrFunc = it->second;
auto funcPtrType = bindIrFunc->getType();
auto globalVar1 = new llvm::GlobalVariable(
module, funcPtrType, true,
llvm::GlobalValue::PrivateLinkage,
bindIrFunc, ".jit_bind_handle");
return llvm::ConstantExpr::getBitCast(globalVar1, &type);
} }
} }
return nullptr; return nullptr;
}; };
auto func = bindParamsToFunc(module, *funcToInline, *exampleFunc, auto func = bindParamsToFunc(module, *funcToInline, *exampleIrFunc,
bindDesc.params, errhandler, params, errhandler,
BindOverride(overrideHandler)); BindOverride(overrideHandler));
moduleInfo.addBindHandle(func->getName(), bindPtr); moduleInfo.addBindHandle(func->getName(), bindPtr);
bindFuncs.insert({bindPtr, func});
} else { } else {
// TODO: ignore for now, user must explicitly check BindPtr // TODO: ignore for now, user must explicitly check BindPtr
} }
};
for (auto &&bind : jitContext.getBindInstances()) {
auto bindPtr = bind.first;
auto &bindDesc = bind.second;
assert(bindDesc.originalFunc != nullptr);
genBind(bindPtr, bindDesc.originalFunc, bindDesc.exampleFunc,
bindDesc.params);
} }
} }

View file

@ -167,6 +167,12 @@ void JITContext::unregisterBind(void *handle) {
bindInstances.erase(handle); bindInstances.erase(handle);
} }
bool JITContext::hasBindFunction(const void *handle) const {
assert(handle != nullptr);
auto it = bindInstances.find(const_cast<void*>(handle));
return it != bindInstances.end();
}
void JITContext::removeModule(const ModuleHandleT &handle) { void JITContext::removeModule(const ModuleHandleT &handle) {
cantFail(compileLayer.removeModule(handle)); cantFail(compileLayer.removeModule(handle));
#if LDC_LLVM_VER >= 700 #if LDC_LLVM_VER >= 700

View file

@ -128,6 +128,8 @@ public:
void unregisterBind(void *handle); void unregisterBind(void *handle);
bool hasBindFunction(const void *handle) const;
const llvm::MapVector<void*, BindDesc> &getBindInstances() const { const llvm::MapVector<void*, BindDesc> &getBindInstances() const {
return bindInstances; return bindInstances;
} }

View file

@ -165,9 +165,28 @@ template UnbindTypes(int[] Index, Args...)
struct BindPayloadBase(F) struct BindPayloadBase(F)
{ {
void function(ref BindPayloadBase!F) dtor; alias FuncParams = Parameters!(F);
alias Ret = ReturnType!F;
F func = null; F func = null;
static assert(func.offsetof == 0, "func must be fist");
void function(ref BindPayloadBase!F) dtor;
int counter = 1; int counter = 1;
auto isCallable() const
{
return func !is null;
}
auto opCall(FuncParams args)
{
assert(isCallable());
return func(args);
}
auto toDelegate() @nogc
{
return &opCall;
}
} }
struct BindPayload(OF, F, int[] Index, Args...) struct BindPayload(OF, F, int[] Index, Args...)
@ -246,6 +265,8 @@ struct BindPayload(OF, F, int[] Index, Args...)
registerBindPayload(&base.func, cast(void*)originalFunc, cast(void*)&exampleFunc, desc.ptr, desc.length); registerBindPayload(&base.func, cast(void*)originalFunc, cast(void*)&exampleFunc, desc.ptr, desc.length);
registered = true; registered = true;
} }
alias toDelegate = base.toDelegate;
} }
struct BindPtr(F) struct BindPtr(F)
@ -308,6 +329,12 @@ package:
} }
} }
import ldc.attributes;
@dynamicCompile auto dummyDel()
{
return toDelegate();
}
public: public:
this(this) this(this)
{ {
@ -331,16 +358,16 @@ public:
return _payload !is null && _payload.func !is null; return _payload !is null && _payload.func !is null;
} }
Ret opCall(FuncParams args) auto opCall(FuncParams args)
{ {
assert(_payload !is null);
assert(isCallable()); assert(isCallable());
return _payload.func(args); return _payload.func(args);
} }
auto toDelegate() @nogc auto toDelegate() @nogc
{ {
return &opCall; assert(_payload !is null);
return _payload.toDelegate();
} }
} }

View file

@ -0,0 +1,69 @@
// RUN: %ldc -enable-dynamic-compile -run %s
import std.array;
import std.stdio;
import std.string;
import ldc.attributes;
import ldc.dynamic_compile;
@dynamicCompile
{
int foo(int delegate() a, int delegate() b)
{
return a() + b();
}
int bar(int delegate() a, int delegate() b)
{
return a() + b();
}
int getVal(int val)
{
return val;
}
}
void main(string[] args)
{
auto dump = appender!string();
CompilerSettings settings;
settings.optLevel = 3;
settings.dumpHandler = (DumpStage stage, in char[] str)
{
if (DumpStage.FinalAsm == stage ||
DumpStage.MergedModule == stage ||
DumpStage.OptimizedModule == stage)
{
write(str);
dump.put(str);
}
};
writeln("===========================================");
compileDynamicCode(settings);
writeln();
writeln("===========================================");
stdout.flush();
auto v1 = ldc.dynamic_compile.bind(&getVal, 1001);
auto v2 = ldc.dynamic_compile.bind(&getVal, 1002);
auto v3 = ldc.dynamic_compile.bind(&getVal, 1003);
auto d1 = v1.toDelegate();
auto d2 = v2.toDelegate();
auto d3 = v3.toDelegate();
auto f1 = ldc.dynamic_compile.bind(&foo, d1, d2);
auto d4 = f1.toDelegate();
auto f2 = ldc.dynamic_compile.bind(&bar, d3, d4);
compileDynamicCode(settings);
assert(2003 == f1());
assert(3006 == f2());
assert(indexOf(dump.data, "2003") != -1);
assert(indexOf(dump.data, "3006") != -1);
}