optimize bind functions

This commit is contained in:
Ivan 2018-06-03 02:41:28 +03:00
parent eb211c899a
commit 0c8c8e8067
4 changed files with 141 additions and 14 deletions

View file

@ -71,25 +71,29 @@ llvm::Function *createBindFunc(llvm::Module &module,
llvm::Value *allocParam(
llvm::IRBuilder<> &builder, llvm::Type &srcType, const
llvm::DataLayout &layout, const ParamSlice& param,
llvm::function_ref<void(const std::string &)> errHandler) {
llvm::function_ref<void(const std::string &)> errHandler,
const BindOverride &override) {
if (param.type == ParamType::Aggregate && srcType.isPointerTy()) {
auto elemType = llvm::cast<llvm::PointerType>(&srcType)->getElementType();
auto stackArg = builder.CreateAlloca(elemType);
stackArg->setAlignment(layout.getABITypeAlignment(elemType));
auto init = parseInitializer(layout, *elemType, param.data, errHandler);
auto init = parseInitializer(layout, *elemType, param.data, errHandler,
override);
builder.CreateStore(init, stackArg);
return stackArg;
}
auto stackArg = builder.CreateAlloca(&srcType);
stackArg->setAlignment(layout.getABITypeAlignment(&srcType));
auto init = parseInitializer(layout, srcType, param.data, errHandler);
auto init = parseInitializer(layout, srcType, param.data, errHandler,
override);
builder.CreateStore(init, stackArg);
return builder.CreateLoad(stackArg);
}
void doBind(llvm::Module &module, llvm::Function &dstFunc,
llvm::Function &srcFunc, const llvm::ArrayRef<ParamSlice> &params,
llvm::function_ref<void(const std::string &)> errHandler) {
llvm::function_ref<void(const std::string &)> errHandler,
const BindOverride &override) {
auto& context = dstFunc.getContext();
auto bb = llvm::BasicBlock::Create(context, "", &dstFunc);
@ -107,7 +111,7 @@ void doBind(llvm::Module &module, llvm::Function &dstFunc,
++currentArg;
} else {
auto type = funcType->getParamType(static_cast<unsigned>(i));
arg = allocParam(builder, *type, layout, param, errHandler);
arg = allocParam(builder, *type, layout, param, errHandler, override);
}
assert(arg != nullptr);
args.push_back(arg);
@ -115,6 +119,10 @@ void doBind(llvm::Module &module, llvm::Function &dstFunc,
assert(currentArg == dstFunc.arg_end());
auto ret = builder.CreateCall(&srcFunc, args);
if (!srcFunc.isDeclaration()) {
ret->addAttribute(llvm::AttributeList::FunctionIndex,
llvm::Attribute::AlwaysInline);
}
ret->setCallingConv(srcFunc.getCallingConv());
ret->setAttributes(srcFunc.getAttributes());
if (dstFunc.getReturnType()->isVoidTy()) {
@ -125,14 +133,14 @@ void doBind(llvm::Module &module, llvm::Function &dstFunc,
}
}
llvm::Function *bindParamsToFunc(
llvm::Module &module, llvm::Function &srcFunc, llvm::Function &exampleFunc,
llvm::Function *bindParamsToFunc(llvm::Module &module, llvm::Function &srcFunc, llvm::Function &exampleFunc,
const llvm::ArrayRef<ParamSlice> &params,
llvm::function_ref<void(const std::string &)> errHandler) {
llvm::function_ref<void(const std::string &)> errHandler,
const BindOverride &override) {
auto srcType = srcFunc.getFunctionType();
auto dstType = getDstFuncType(*srcType, params);
auto newFunc = createBindFunc(module, srcFunc, exampleFunc, *dstType, params);
doBind(module, *newFunc, srcFunc, params, errHandler);
doBind(module, *newFunc, srcFunc, params, errHandler, override);
return newFunc;
}

View file

@ -18,16 +18,25 @@
#include "param_slice.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
namespace llvm {
class Constant;
class Type;
class Module;
class Function;
}
using BindOverride =
llvm::Optional<llvm::function_ref<llvm::Constant*(
llvm::Type &, const void *, size_t)>>;
llvm::Function *bindParamsToFunc(
llvm::Module &module, llvm::Function &srcFunc,llvm::Function &exampleFunc,
const llvm::ArrayRef<ParamSlice> &params,
llvm::function_ref<void(const std::string &)> errHandler);
llvm::function_ref<void(const std::string &)> errHandler,
const BindOverride &override = BindOverride{});
#endif // BIND_H

View file

@ -29,6 +29,7 @@
#include "llvm/Bitcode/BitcodeReader.h"
#include "llvm/ExecutionEngine/JITSymbol.h"
#include "llvm/IR/Constants.h"
#include "llvm/IR/Mangler.h"
#include "llvm/Linker/Linker.h"
#include "llvm/Support/raw_ostream.h"
@ -189,11 +190,33 @@ void generateBind(const Context &context, JITContext &jitContext,
if (funcToInline != nullptr) {
auto exampleFunc = getIrFunc(bindDesc.exampleFunc);
assert(exampleFunc != nullptr);
auto func = bindParamsToFunc(module, *funcToInline, *exampleFunc,
bindDesc.params,
[&](const std::string &str) {
auto errhandler = [&](const std::string &str) {
fatal(context, str);
});
};
auto overrideHandler =
[&](llvm::Type &type, const void* data, size_t size)->
llvm::Constant *{
if (type.isPointerTy()) {
auto ptype = llvm::cast<llvm::PointerType>(&type);
auto elemType = ptype->getElementType();
if (elemType->isFunctionTy()) {
(void)size;
assert(size == sizeof(void*));
auto val = *reinterpret_cast<void * const *>(data);
if (val != nullptr) {
auto ret = getIrFunc(val);
if (ret != nullptr && ret->getType() != &type) {
return llvm::ConstantExpr::getBitCast(ret, &type);
}
return ret;
}
}
}
return nullptr;
};
auto func = bindParamsToFunc(module, *funcToInline, *exampleFunc,
bindDesc.params, errhandler,
BindOverride(overrideHandler));
moduleInfo.addBindHandle(func->getName(), bindPtr);
} else {
// TODO: ignore for now, user must explicitly check BindPtr
@ -362,6 +385,9 @@ void rtCompileProcessImplSoInternal(const RtCompileModuleList *modlist_head,
JitFinaliser jitFinalizer(myJit);
interruptPoint(context, "Resolve functions");
for (auto &&fun : moduleInfo.functions()) {
if (fun.thunkVar == nullptr) {
continue;
}
auto decorated = decorate(fun.name, layout);
auto symbol = myJit.findSymbol(decorated);
auto addr = resolveSymbol(symbol);

View file

@ -0,0 +1,84 @@
// RUN: %ldc -enable-dynamic-compile -run %s
import std.array;
import std.stdio;
import std.string;
import ldc.attributes;
import ldc.dynamic_compile;
@dynamicCompile
{
int foo(int function() a, int function() b, int function() c)
{
return a() + b() + c();
}
int bar(int delegate() a, int delegate() b, int delegate() c)
{
return a() + b() + c();
}
int get1001()
{
return 1001;
}
int get1002()
{
return 1002;
}
int get1003()
{
return 1003;
}
}
void main(string[] args)
{
auto dump = appender!string();
CompilerSettings settings;
settings.optLevel = 3;
settings.dumpHandler = (DumpStage stage, in char[] str)
{
if (DumpStage.FinalAsm == stage)
{
write(str);
dump.put(str);
}
};
writeln("===========================================");
compileDynamicCode(settings);
writeln();
writeln("===========================================");
stdout.flush();
@dynamicCompile
int get1001d()
{
return 1001;
}
@dynamicCompile
int get1002d()
{
return 1002;
}
@dynamicCompile
int get1004d()
{
return 1004;
}
auto f = ldc.dynamic_compile.bind(&foo, &get1001, &get1002, &get1003);
auto b = ldc.dynamic_compile.bind(&bar, &get1001d, &get1002d, &get1004d);
compileDynamicCode(settings);
assert(3006 == f());
assert(3007 == b());
assert(indexOf(dump.data, "3006") != -1);
assert(indexOf(dump.data, "3007") != -1);
}