Jit TLS workaround

This commit is contained in:
Ivan 2017-10-16 14:54:42 +03:00
parent 5bbfb93ec2
commit 3a34845395
5 changed files with 117 additions and 17 deletions

View file

@ -461,6 +461,12 @@ cl::opt<bool> enableRuntimeCompile(
"enable-runtime-compile", "enable-runtime-compile",
cl::desc("Enable runtime compilation"), cl::desc("Enable runtime compilation"),
cl::init(false)); cl::init(false));
cl::opt<bool> runtimeCompileTlsWorkaround(
"runtime-compile-tls-workaround",
cl::desc("Enable runtime compilation TLS workaround"),
cl::init(true),
cl::Hidden);
#endif #endif
static cl::extrahelp footer( static cl::extrahelp footer(

View file

@ -124,6 +124,7 @@ extern cl::opt<std::string> dcomputeFilePrefix;
#if defined(LDC_RUNTIME_COMPILE) #if defined(LDC_RUNTIME_COMPILE)
extern cl::opt<bool> enableRuntimeCompile; extern cl::opt<bool> enableRuntimeCompile;
extern cl::opt<bool> runtimeCompileTlsWorkaround;
#else #else
constexpr bool enableRuntimeCompile = false; constexpr bool enableRuntimeCompile = false;
#endif #endif

View file

@ -81,7 +81,9 @@ using GlobalValsMap =
void getPredefinedSymbols(IRState *irs, GlobalValsMap &symList) { void getPredefinedSymbols(IRState *irs, GlobalValsMap &symList) {
assert(nullptr != irs); assert(nullptr != irs);
const llvm::Triple *triple = global.params.targetTriple; const llvm::Triple *triple = global.params.targetTriple;
if (triple->isWindowsMSVCEnvironment() || triple->isWindowsGNUEnvironment()) { if (!opts::runtimeCompileTlsWorkaround) {
if (triple->isWindowsMSVCEnvironment() ||
triple->isWindowsGNUEnvironment()) {
symList.insert(std::make_pair( symList.insert(std::make_pair(
getPredefinedSymbol(irs->module, "_tls_index", getPredefinedSymbol(irs->module, "_tls_index",
llvm::Type::getInt32Ty(irs->context())), llvm::Type::getInt32Ty(irs->context())),
@ -93,6 +95,7 @@ void getPredefinedSymbols(IRState *irs, GlobalValsMap &symList) {
GlobalValVisibility::Declaration)); GlobalValVisibility::Declaration));
} }
} }
}
} }
GlobalValsMap createGlobalValsFilter(IRState *irs) { GlobalValsMap createGlobalValsFilter(IRState *irs) {
@ -174,6 +177,85 @@ void fixRtMudule(llvm::Module &newModule,
assert((thunk2func.size() + externalFuncs.size()) == objectsFixed); assert((thunk2func.size() + externalFuncs.size()) == objectsFixed);
} }
llvm::Function *createGlobalVarLoadFun(llvm::Module &module,
llvm::GlobalVariable *var,
const llvm::Twine &funcName) {
assert(nullptr != var);
auto &context = module.getContext();
auto varType = var->getType();
auto funcType = llvm::FunctionType::get(varType, false);
auto func = llvm::Function::Create(
funcType, llvm::GlobalValue::WeakODRLinkage, funcName, &module);
auto bb = llvm::BasicBlock::Create(context, "", func);
llvm::IRBuilder<> builder(context);
builder.SetInsertPoint(bb);
builder.CreateRet(var);
return func;
}
void replaceDynamicThreadLocals(llvm::Module &oldModule,
llvm::Module &newModule,
GlobalValsMap &valsMap) {
// Wrap all thread locals access in dynamic code by function calls
// to 'normal' code
std::unordered_map<llvm::GlobalVariable *, llvm::Function *>
threadLocalAccessors;
auto getAccessor = [&](llvm::GlobalVariable *var) {
assert(nullptr != var);
auto it = threadLocalAccessors.find(var);
if (threadLocalAccessors.end() != it) {
return it->second;
}
auto srcVar = oldModule.getGlobalVariable(var->getName());
assert(nullptr != srcVar);
auto srcFunc = createGlobalVarLoadFun(oldModule, srcVar,
"." + var->getName() + "_accessor");
srcFunc->addFnAttr(llvm::Attribute::NoInline);
auto dstFunc = llvm::Function::Create(srcFunc->getFunctionType(),
llvm::GlobalValue::ExternalLinkage,
srcFunc->getName(), &newModule);
threadLocalAccessors.insert({var, dstFunc});
valsMap.insert({srcFunc, GlobalValVisibility::Declaration});
return dstFunc;
};
for (auto &&fun : newModule.functions()) {
for (auto &&bb : fun) {
// We can change bb contents in this loop
// so we reiterate it from start after each change
bool bbChanged = true;
while (bbChanged) {
bbChanged = false;
for (auto &&instr : bb) {
for (unsigned int i = 0; i < instr.getNumOperands(); ++i) {
auto op = instr.getOperand(i);
if (auto globalVar = llvm::dyn_cast<llvm::GlobalVariable>(op)) {
if (globalVar->isThreadLocal()) {
auto accessor = getAccessor(globalVar);
assert(nullptr != accessor);
auto callResult = llvm::CallInst::Create(accessor);
callResult->insertBefore(&instr);
instr.setOperand(i, callResult);
bbChanged = true;
}
}
}
if (bbChanged)
break;
} // for (auto &&instr : bb)
} // while (bbChanged)
} // for (auto &&bb : fun)
} // for (auto &&fun : newModule.functions())
for (auto &&it : threadLocalAccessors) {
it.first->eraseFromParent();
}
}
// void hideExternalSymbols(llvm::Module &newModule, const GlobalValsMap // void hideExternalSymbols(llvm::Module &newModule, const GlobalValsMap
// &filter) { // &filter) {
// std::set<std::string> externalSymbols; // std::set<std::string> externalSymbols;
@ -595,6 +677,9 @@ void generateBitcodeForRuntimeCompile(IRState *irs) {
return filter.end() != it && return filter.end() != it &&
it->second != GlobalValVisibility::Declaration; it->second != GlobalValVisibility::Declaration;
}); });
if (opts::runtimeCompileTlsWorkaround) {
replaceDynamicThreadLocals(irs->module, *newModule, filter);
}
fixRtMudule(*newModule, irs->runtimeCompiledFunctions); fixRtMudule(*newModule, irs->runtimeCompiledFunctions);
// hideExternalSymbols(*newModule, filter); // hideExternalSymbols(*newModule, filter);

View file

@ -4,12 +4,5 @@ import platform
config.available_features.add('JitExceptions') config.available_features.add('JitExceptions')
config.available_features.add('JitThreadLocal') config.available_features.add('JitThreadLocal')
if (platform.system() == 'Linux'):
config.available_features.remove('JitThreadLocal')
if (platform.system() == 'Windows'): if (platform.system() == 'Windows'):
config.available_features.remove('JitExceptions') config.available_features.remove('JitExceptions')
config.available_features.remove('JitThreadLocal')
if (platform.system() == 'Darwin'):
config.available_features.remove('JitThreadLocal')

View file

@ -8,15 +8,30 @@ import ldc.runtimecompile;
ThreadID threadId; //thread local ThreadID threadId; //thread local
@runtimeCompile void foo() @runtimeCompile void set_val()
{ {
threadId = Thread.getThis().id(); threadId = Thread.getThis().id();
} }
@runtimeCompile ThreadID get_val()
{
return threadId;
}
@runtimeCompile ThreadID* get_ptr()
{
auto ptr = &threadId;
return ptr;
}
void bar() void bar()
{ {
foo(); set_val();
assert(threadId == Thread.getThis().id()); auto id = Thread.getThis().id();
assert(id == threadId);
assert(id == get_val());
assert(&threadId is get_ptr());
assert(id == *get_ptr());
} }
void main(string[] args) void main(string[] args)