basic bind works!

This commit is contained in:
Ivan 2018-04-10 00:48:35 +03:00
parent 6e52f7fde4
commit 28abe2c975
5 changed files with 312 additions and 44 deletions

View file

@ -0,0 +1,118 @@
//===-- bind.cpp ----------------------------------------------------------===//
//
// LDC the LLVM D compiler
//
// This file is distributed under the Boost Software License. See the LICENSE
// file for details.
//
//===----------------------------------------------------------------------===//
#include "bind.h"
#include "llvm/IR/BasicBlock.h"
#include "llvm/IR/Function.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/IR/Module.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/Transforms/Utils/Cloning.h"
#include "valueparser.h"
namespace {
enum {
SmallParamsCount = 5
};
llvm::FunctionType *getDstFuncType(llvm::FunctionType &srcType,
const llvm::ArrayRef<Slice> &params) {
assert(!srcType.isVarArg());
llvm::SmallVector<llvm::Type*, SmallParamsCount> newParams;
const auto srcParamsCount = srcType.params().size();
assert(params.size() == srcParamsCount);
for (size_t i = 0; i < srcParamsCount; ++i) {
if (params[i].data == nullptr) {
newParams.push_back(srcType.getParamType(static_cast<unsigned>(i)));
}
}
auto retType = srcType.getReturnType();
return llvm::FunctionType::get(retType, newParams, /*isVarArg*/false);
}
llvm::Function *createBindFunc(llvm::Module &module,
llvm::Function &srcFunc,
llvm::FunctionType &funcType,
const llvm::ArrayRef<Slice> &params) {
auto newFunc = llvm::Function::Create(
&funcType, llvm::GlobalValue::ExternalLinkage, "\1.jit_bind",
&module);
newFunc->setCallingConv(srcFunc.getCallingConv());
auto srcAttributes = srcFunc.getAttributes();
newFunc->addAttributes(llvm::AttributeList::ReturnIndex,
srcAttributes.getRetAttributes());
newFunc->addAttributes(llvm::AttributeList::FunctionIndex,
srcAttributes.getFnAttributes());
unsigned dstInd = 0;
for (size_t i = 0; i < params.size(); ++i) {
if (params[i].data == nullptr) {
newFunc->addAttributes(llvm::AttributeList::FirstArgIndex + dstInd,
srcAttributes.getParamAttributes(
static_cast<unsigned>(i)));
++dstInd;
}
}
assert(dstInd == funcType.getNumParams());
return newFunc;
}
void doBind(llvm::Module &module, llvm::Function &dstFunc,
llvm::Function &srcFunc, const llvm::ArrayRef<Slice> &params,
llvm::function_ref<void(const std::string &)> errHandler) {
auto& context = dstFunc.getContext();
auto bb = llvm::BasicBlock::Create(context, "", &dstFunc);
llvm::IRBuilder<> builder(context);
builder.SetInsertPoint(bb);
llvm::SmallVector<llvm::Value*, SmallParamsCount> args;
auto currentArg = dstFunc.arg_begin();
for (size_t i = 0; i < params.size(); ++i) {
auto type = srcFunc.getFunctionType()->getParamType(static_cast<unsigned>(i));
llvm::Value* arg = nullptr;
if (params[i].data == nullptr) {
arg = currentArg;
++currentArg;
} else {
auto &layout = module.getDataLayout();
auto stackArg = builder.CreateAlloca(type);
stackArg->setAlignment(layout.getABITypeAlignment(type));
const auto& param = params[i];
auto init = parseInitializer(layout, *type, param.data, errHandler);
builder.CreateStore(init, stackArg);
arg = builder.CreateLoad(stackArg);
}
assert(arg != nullptr);
args.push_back(arg);
}
assert(currentArg == dstFunc.arg_end());
auto ret = builder.CreateCall(&srcFunc, args);
ret->setCallingConv(srcFunc.getCallingConv());
if (dstFunc.getReturnType()->isVoidTy()) {
builder.CreateRetVoid();
} else {
builder.CreateRet(ret);
}
}
}
llvm::Function *bindParamsToFunc(
llvm::Module &module, llvm::Function &srcFunc,
const llvm::ArrayRef<Slice> &params,
llvm::function_ref<void(const std::string &)> errHandler) {
auto srcType = srcFunc.getFunctionType();
auto dstType = getDstFuncType(*srcType, params);
auto newFunc = createBindFunc(module, srcFunc, *dstType, params);
doBind(module, *newFunc, srcFunc, params, errHandler);
return newFunc;
}

View file

@ -0,0 +1,33 @@
//===-- bind.h - jit support ------------------------------------*- C++ -*-===//
//
// LDC the LLVM D compiler
//
// This file is distributed under the Boost Software License. See the LICENSE
// file for details.
//
//===----------------------------------------------------------------------===//
//
// Jit runtime - support routines for bind, allow to dynamically create
// specialized functions for each bind instance.
//
//===----------------------------------------------------------------------===//
#ifndef BIND_H
#define BIND_H
#include "context.h" // Slice
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/STLExtras.h"
namespace llvm {
class Module;
class Function;
}
llvm::Function *bindParamsToFunc(
llvm::Module &module, llvm::Function &srcFunc,
const llvm::ArrayRef<Slice> &params,
llvm::function_ref<void(const std::string &)> errHandler);
#endif // BIND_H

View file

@ -20,6 +20,7 @@
#include <type_traits> #include <type_traits>
#include <unordered_map> #include <unordered_map>
#include "bind.h"
#include "callback_ostream.h" #include "callback_ostream.h"
#include "context.h" #include "context.h"
#include "jit_context.h" #include "jit_context.h"
@ -88,9 +89,18 @@ void enumModules(const RtCompileModuleList *modlist_head,
} }
} }
std::string decorate(const std::string &name,
const llvm::DataLayout &datalayout) {
assert(!name.empty());
llvm::SmallVector<char, 64> ret;
llvm::Mangler::getNameWithPrefix(ret, name, datalayout);
assert(!ret.empty());
return std::string(ret.data(), ret.size());
}
struct JitModuleInfo final { struct JitModuleInfo final {
private: private:
struct Func { struct Func final {
llvm::StringRef name; llvm::StringRef name;
void **thunkVar; void **thunkVar;
void *originalFunc; void *originalFunc;
@ -98,6 +108,12 @@ private:
std::vector<Func> funcs; std::vector<Func> funcs;
mutable std::unordered_map<const void *, const Func *> funcsMap; mutable std::unordered_map<const void *, const Func *> funcsMap;
struct BindHandle final {
std::string name;
void* handle = nullptr;
};
std::vector<BindHandle> bindHandles;
public: public:
JitModuleInfo(const Context &context, JitModuleInfo(const Context &context,
const RtCompileModuleList *modlist_head) { const RtCompileModuleList *modlist_head) {
@ -129,14 +145,75 @@ public:
} }
return nullptr; return nullptr;
} }
const std::vector<BindHandle> &getBindHandles() const {
return bindHandles;
}
void addBindHandle(llvm::StringRef name, void *handle) {
assert(!name.empty());
assert(handle != nullptr);
BindHandle h;
h.name = name.str();
h.handle = handle;
bindHandles.emplace_back(std::move(h));
}
}; };
std::string decorate(const std::string &name, void *resolveSymbol(llvm::JITSymbol &symbol) {
const llvm::DataLayout &datalayout) { auto addr = symbol.getAddress();
llvm::SmallVector<char, 64> ret; if (!addr) {
llvm::Mangler::getNameWithPrefix(ret, name, datalayout); consumeError(addr.takeError());
assert(!ret.empty()); return nullptr;
return std::string(ret.data(), ret.size()); } else {
return reinterpret_cast<void *>(addr.get());
}
}
void generateBind(const Context &context, JITContext &jitContext,
JitModuleInfo &moduleInfo, llvm::Module &module) {
auto getIrFunc = [&](const void *ptr)->llvm::Function * {
assert(ptr != nullptr);
auto funcDesc = moduleInfo.getFunc(ptr);
if (funcDesc == nullptr) {
return nullptr;
}
return module.getFunction(funcDesc->name);
};
for (auto &&bind : jitContext.getBindInstances()) {
auto bindPtr = bind.first;
auto &bindDesc = bind.second;
assert(bindDesc.originalFunc != nullptr);
auto funcToInline = getIrFunc(bindDesc.originalFunc);
if (funcToInline != nullptr) {
auto func = bindParamsToFunc(module, *funcToInline, bindDesc.params,
[&](const std::string &str) {
fatal(context, str);
});
moduleInfo.addBindHandle(func->getName(), bindPtr);
} else {
// TODO: ignore for now, user must explicitly check BindPtr
}
}
}
void applyBind(const Context &context, JITContext &jitContext,
const JitModuleInfo &moduleInfo) {
auto &layout = jitContext.getDataLayout();
for (auto& elem : moduleInfo.getBindHandles()) {
auto decorated = decorate(elem.name, layout);
auto symbol = jitContext.findSymbol(decorated);
auto addr = resolveSymbol(symbol);
if (nullptr == addr) {
std::string desc = std::string("Symbol not found in jitted code: \"") +
elem.name + "\" (\"" + decorated + "\")";
fatal(context, desc);
} else {
auto handle = static_cast<void**>(elem.handle);
*handle = addr;
}
}
} }
JITContext &getJit() { JITContext &getJit() {
@ -151,16 +228,6 @@ void setRtCompileVars(const Context &context, llvm::Module &module,
} }
} }
void *resolveSymbol(llvm::JITSymbol &symbol) {
auto addr = symbol.getAddress();
if (!addr) {
consumeError(addr.takeError());
return nullptr;
} else {
return reinterpret_cast<void *>(addr.get());
}
}
void dumpModule(const Context &context, const llvm::Module &module, void dumpModule(const Context &context, const llvm::Module &module,
DumpStage stage) { DumpStage stage) {
if (nullptr != context.dumpHandler) { if (nullptr != context.dumpHandler) {
@ -260,6 +327,9 @@ void rtCompileProcessImplSoInternal(const RtCompileModuleList *modlist_head,
}); });
assert(nullptr != finalModule); assert(nullptr != finalModule);
interruptPoint(context, "Generate bind functions");
generateBind(context, myJit, moduleInfo, *finalModule);
dumpModule(context, *finalModule, DumpStage::MergedModule); dumpModule(context, *finalModule, DumpStage::MergedModule);
interruptPoint(context, "Optimize final module"); interruptPoint(context, "Optimize final module");
optimizeModule(context, myJit.getTargetMachine(), settings, *finalModule); optimizeModule(context, myJit.getTargetMachine(), settings, *finalModule);
@ -307,6 +377,8 @@ void rtCompileProcessImplSoInternal(const RtCompileModuleList *modlist_head,
interruptPoint(context, "Resolved", str.c_str()); interruptPoint(context, "Resolved", str.c_str());
} }
} }
interruptPoint(context, "Update bind handles");
applyBind(context, myJit, moduleInfo);
jitFinalizer.finalze(); jitFinalizer.finalze();
} }

View file

@ -271,15 +271,7 @@ package:
return ret; return ret;
} }
public: void decPayload()
this(this)
{
if (_payload !is null)
{
++_payload.counter;
}
}
~this()
{ {
if (_payload !is null) if (_payload !is null)
{ {
@ -289,24 +281,53 @@ public:
{ {
_payload.dtor(*_payload); _payload.dtor(*_payload);
pureFree(_payload); pureFree(_payload);
}
_payload = null; _payload = null;
} }
} }
void incPayload()
{
if (_payload !is null)
{
++_payload.counter;
}
}
public:
this(this)
{
incPayload();
}
~this()
{
decPayload();
} }
void opAssign(typeof(this) rhs) void opAssign(typeof(this) rhs)
{ {
import std.algorithm.mutation : swap; import std.algorithm.mutation : swap;
decPayload();
_payload = rhs._payload;
incPayload();
}
swap(_payload, rhs._payload); bool isCallable() const pure nothrow @safe @nogc
{
return _payload.func !is null;
} }
Ret opCall(FuncParams args) Ret opCall(FuncParams args)
{ {
assert(_payload !is null); assert(_payload !is null);
assert(_payload.func !is null); assert(isCallable());
return _payload.func(args); return _payload.func(args);
} }
auto toDelegate()
{
return &opCall;
}
} }
extern(C) extern(C)

View file

@ -9,23 +9,47 @@ import ldc.dynamic_compile;
return a + b + c; return a + b + c;
} }
int bar(int a, int b, int c)
{
return a + b + c;
}
void main(string[] args) void main(string[] args)
{
foreach (i; 0..4)
{ {
CompilerSettings settings; CompilerSettings settings;
settings.dumpHandler = (DumpStage stage, in char[] str) settings.optLevel = i;
{
if (DumpStage.OriginalModule == stage)
{
import std.stdio;
//write(str);
//stdout.flush();
}
};
auto f1 = ldc.dynamic_compile.bind(&foo, placeholder, placeholder, placeholder); auto f1 = ldc.dynamic_compile.bind(&foo, placeholder, placeholder, placeholder);
auto f2 = ldc.dynamic_compile.bind(&foo, 1, placeholder, 3); auto f2 = ldc.dynamic_compile.bind(&foo, 1, placeholder, 3);
auto f3 = ldc.dynamic_compile.bind(&foo, 1, 2, 3); auto f3 = ldc.dynamic_compile.bind(&foo, 1, 2, 3);
auto f4 = f3;
int delegate(int,int,int) fd1 = f1.toDelegate();
int delegate(int) fd2 = f2.toDelegate();
int delegate() fd3 = f3.toDelegate();
int delegate() fd4 = f4.toDelegate();
auto b1 = ldc.dynamic_compile.bind(&bar, placeholder, placeholder, placeholder);
auto b2 = ldc.dynamic_compile.bind(&bar, 1, placeholder, 3);
auto b3 = ldc.dynamic_compile.bind(&bar, 1, 2, 3);
auto b4 = b3;
compileDynamicCode(settings); compileDynamicCode(settings);
//h(2); assert(6 == f1(1,2,3));
assert(false); assert(6 == f2(2));
assert(6 == f3());
assert(6 == f4());
assert(6 == fd1(1,2,3));
assert(6 == fd2(2));
assert(6 == fd3());
assert(6 == fd4());
assert(!b1.isCallable());
assert(!b2.isCallable());
assert(!b3.isCallable());
assert(!b4.isCallable());
}
} }