jit: refactor valueparser (#2744)

This commit is contained in:
Ivan Butygin 2018-06-13 21:24:04 +03:00 committed by GitHub
parent 65337ed170
commit d2c55491c4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
3 changed files with 129 additions and 51 deletions

View file

@ -159,7 +159,8 @@ void setRtCompileVar(const Context &context, llvm::Module &module,
if (nullptr != var) { if (nullptr != var) {
auto type = var->getType()->getElementType(); auto type = var->getType()->getElementType();
auto initializer = auto initializer =
parseInitializer(context, module.getDataLayout(), type, init); parseInitializer(module.getDataLayout(), *type, init,
[&](const std::string &str) { fatal(context, str); });
var->setConstant(true); var->setConstant(true);
var->setInitializer(initializer); var->setInitializer(initializer);
var->setLinkage(llvm::GlobalValue::PrivateLinkage); var->setLinkage(llvm::GlobalValue::PrivateLinkage);

View file

@ -21,70 +21,96 @@
#include "llvm/Support/raw_ostream.h" #include "llvm/Support/raw_ostream.h"
namespace { namespace {
void checkOverrideType(
llvm::Type &type, llvm::Constant &val,
const llvm::function_ref<void(const std::string &)> &errHandler) {
auto retType = val.getType();
if (retType != &type) {
std::string str;
llvm::raw_string_ostream ss(str);
ss << "Override type mismatch, expected \"";
type.print(ss);
ss << "\", got \"";
retType->print(ss);
ss << "\"";
ss.flush();
errHandler(str);
}
}
template <typename T> template <typename T>
llvm::ConstantInt *getInt(llvm::LLVMContext &context, const void *data) { llvm::Constant *
callOverride(const ParseInitializerOverride &override, llvm::Type &type,
const T &val,
const llvm::function_ref<void(const std::string &)> &errHandler) {
if (override) {
auto ptr = reinterpret_cast<const char *>(&val);
auto ret = (*override)(type, ptr, sizeof(val));
if (ret != nullptr) {
checkOverrideType(type, *ret, errHandler);
}
return ret;
}
return nullptr;
}
template <typename T>
llvm::Constant *
getInt(llvm::LLVMContext &context, const void *data, llvm::Type &type,
const llvm::function_ref<void(const std::string &)> &errHandler,
const ParseInitializerOverride &override) {
assert(nullptr != data); assert(nullptr != data);
const T val = *static_cast<const T *>(data); const T val = *static_cast<const T *>(data);
if (auto ret = callOverride(override, type, val, errHandler)) {
return ret;
}
return llvm::ConstantInt::get(context, llvm::APInt(sizeof(T) * 8, val, true)); return llvm::ConstantInt::get(context, llvm::APInt(sizeof(T) * 8, val, true));
} }
template <typename T> template <typename T>
llvm::ConstantFP *getFloat(llvm::LLVMContext &context, const void *data) { llvm::Constant *
getFloat(llvm::LLVMContext &context, const void *data, llvm::Type &type,
const llvm::function_ref<void(const std::string &)> &errHandler,
const ParseInitializerOverride &override) {
assert(nullptr != data); assert(nullptr != data);
const T val = *static_cast<const T *>(data); const T val = *static_cast<const T *>(data);
if (auto ret = callOverride(override, type, val, errHandler)) {
return ret;
}
return llvm::ConstantFP::get(context, llvm::APFloat(val)); return llvm::ConstantFP::get(context, llvm::APFloat(val));
} }
llvm::Constant *getPtr(llvm::LLVMContext &context, llvm::Type *targetType, llvm::Constant *
const void *data) { getPtr(llvm::LLVMContext &context, const void *data, llvm::Type &type,
assert(nullptr != targetType); const llvm::function_ref<void(const std::string &)> &errHandler,
const ParseInitializerOverride &override) {
assert(nullptr != data); assert(nullptr != data);
const auto val = *static_cast<const uintptr_t *>(data); const auto val = *static_cast<const uintptr_t *>(data);
if (auto ret = callOverride(override, type, val, errHandler)) {
return ret;
}
return llvm::ConstantExpr::getIntToPtr( return llvm::ConstantExpr::getIntToPtr(
llvm::ConstantInt::get(context, llvm::APInt(sizeof(val) * 8, val)), llvm::ConstantInt::get(context, llvm::APInt(sizeof(val) * 8, val)),
targetType); &type);
}
} }
llvm::Constant *parseInitializer(const Context &context, llvm::Constant *
const llvm::DataLayout &dataLayout, getStruct(const void *data, const llvm::DataLayout &dataLayout,
llvm::Type *type, const void *data) { llvm::Type &type,
assert(nullptr != type); const llvm::function_ref<void(const std::string &)> &errHandler,
const ParseInitializerOverride &override) {
assert(nullptr != data); assert(nullptr != data);
auto &llcontext = type->getContext(); if (override) {
if (type->isIntegerTy()) { auto size = dataLayout.getTypeStoreSize(&type);
const auto width = type->getIntegerBitWidth(); auto ptr = static_cast<const char *>(data);
switch (width) { auto ret = (*override)(type, ptr, size);
case 8: if (ret != nullptr) {
return getInt<uint8_t>(llcontext, data); checkOverrideType(type, *ret, errHandler);
case 16: return ret;
return getInt<uint16_t>(llcontext, data);
case 32:
return getInt<uint32_t>(llcontext, data);
case 64:
return getInt<uint64_t>(llcontext, data);
default:
fatal(context,
std::string("Invalid int bit width: ") + std::to_string(width));
} }
} }
if (type->isFloatingPointTy()) { auto stype = llvm::cast<llvm::StructType>(&type);
const auto width = type->getPrimitiveSizeInBits();
switch (width) {
case 32:
return getFloat<float>(llcontext, data);
case 64:
return getFloat<double>(llcontext, data);
default:
fatal(context,
std::string("Invalid fp bit width: ") + std::to_string(width));
}
}
if (type->isPointerTy()) {
return getPtr(llcontext, type, data);
}
if (type->isStructTy()) {
auto stype = llvm::cast<llvm::StructType>(type);
auto slayout = dataLayout.getStructLayout(stype); auto slayout = dataLayout.getStructLayout(stype);
auto numElements = stype->getNumElements(); auto numElements = stype->getNumElements();
llvm::SmallVector<llvm::Constant *, 16> elements(numElements); llvm::SmallVector<llvm::Constant *, 16> elements(numElements);
@ -92,25 +118,71 @@ llvm::Constant *parseInitializer(const Context &context,
const auto elemType = stype->getElementType(i); const auto elemType = stype->getElementType(i);
const auto elemOffset = slayout->getElementOffset(i); const auto elemOffset = slayout->getElementOffset(i);
const auto elemPtr = static_cast<const char *>(data) + elemOffset; const auto elemPtr = static_cast<const char *>(data) + elemOffset;
elements[i] = parseInitializer(context, dataLayout, elemType, elemPtr); elements[i] =
parseInitializer(dataLayout, *elemType, elemPtr, errHandler, override);
} }
return llvm::ConstantStruct::get(stype, elements); return llvm::ConstantStruct::get(stype, elements);
} }
if (type->isArrayTy()) { }
auto elemType = type->getArrayElementType();
llvm::Constant *
parseInitializer(const llvm::DataLayout &dataLayout, llvm::Type &type,
const void *data,
llvm::function_ref<void(const std::string &)> errHandler,
const ParseInitializerOverride &override) {
assert(nullptr != data);
auto &llcontext = type.getContext();
if (type.isIntegerTy()) {
const auto width = type.getIntegerBitWidth();
switch (width) {
case 8:
return getInt<uint8_t>(llcontext, data, type, errHandler, override);
case 16:
return getInt<uint16_t>(llcontext, data, type, errHandler, override);
case 32:
return getInt<uint32_t>(llcontext, data, type, errHandler, override);
case 64:
return getInt<uint64_t>(llcontext, data, type, errHandler, override);
default:
errHandler(std::string("Invalid int bit width: ") +
std::to_string(width));
return nullptr;
}
}
if (type.isFloatingPointTy()) {
const auto width = type.getPrimitiveSizeInBits();
switch (width) {
case 32:
return getFloat<float>(llcontext, data, type, errHandler, override);
case 64:
return getFloat<double>(llcontext, data, type, errHandler, override);
default:
errHandler(std::string("Invalid fp bit width: ") + std::to_string(width));
return nullptr;
}
}
if (type.isPointerTy()) {
return getPtr(llcontext, data, type, errHandler, override);
}
if (type.isStructTy()) {
return getStruct(data, dataLayout, type, errHandler, override);
}
if (type.isArrayTy()) {
auto elemType = type.getArrayElementType();
const auto step = dataLayout.getTypeAllocSize(elemType); const auto step = dataLayout.getTypeAllocSize(elemType);
const auto numElements = type->getArrayNumElements(); const auto numElements = type.getArrayNumElements();
llvm::SmallVector<llvm::Constant *, 16> elements(numElements); llvm::SmallVector<llvm::Constant *, 16> elements(numElements);
for (uint64_t i = 0; i < numElements; ++i) { for (uint64_t i = 0; i < numElements; ++i) {
const auto elemPtr = static_cast<const char *>(data) + step * i; const auto elemPtr = static_cast<const char *>(data) + step * i;
elements[i] = parseInitializer(context, dataLayout, elemType, elemPtr); elements[i] = parseInitializer(dataLayout, *elemType, elemPtr, errHandler,
override);
} }
return llvm::ConstantArray::get(llvm::cast<llvm::ArrayType>(type), return llvm::ConstantArray::get(llvm::cast<llvm::ArrayType>(&type),
elements); elements);
} }
std::string tname; std::string tname;
llvm::raw_string_ostream os(tname); llvm::raw_string_ostream os(tname);
type->print(os, true); type.print(os, true);
fatal(context, std::string("Unhandled type: ") + os.str()); errHandler(std::string("Unhandled type: ") + os.str());
return nullptr; return nullptr;
} }

View file

@ -16,16 +16,21 @@
#ifndef VALUEPARSER_H #ifndef VALUEPARSER_H
#define VALUEPARSER_H #define VALUEPARSER_H
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/STLExtras.h"
namespace llvm { namespace llvm {
class Constant; class Constant;
class Type; class Type;
class DataLayout; class DataLayout;
} }
struct Context; using ParseInitializerOverride = llvm::Optional<
llvm::function_ref<llvm::Constant *(llvm::Type &, const void *, size_t)>>;
llvm::Constant *parseInitializer(const Context &context, llvm::Constant *parseInitializer(
const llvm::DataLayout &dataLayout, const llvm::DataLayout &dataLayout, llvm::Type &type, const void *data,
llvm::Type *type, const void *data); llvm::function_ref<void(const std::string &)> errHandler,
const ParseInitializerOverride &override = ParseInitializerOverride{});
#endif // VALUEPARSER_H #endif // VALUEPARSER_H