Add -output-mlir and prepare for MLIR emission (#3313)

This commit is contained in:
Roberto Rosmaninho 2020-05-22 07:31:24 -03:00 committed by GitHub
parent 4eaa2fd864
commit 6274217c39
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
11 changed files with 151 additions and 26 deletions

View file

@ -28,7 +28,11 @@ else()
set(MLIR_LIB_DIR ${MLIR_ROOT_DIR}/lib)
# To be done: add the required MLIR libraries. Hopefully we don't have to manually list all MLIR libs.
set(MLIR_LIBRARIES "")
if(EXISTS "${MLIR_LIB_DIR}/MLIRIR.lib")
set(MLIR_LIBRARIES ${MLIR_LIB_DIR}/MLIRIR.lib ${MLIR_LIB_DIR}/MLIRSupport.lib)
elseif(EXISTS "${MLIR_LIB_DIR}/libMLIRIR.a")
set(MLIR_LIBRARIES ${MLIR_LIB_DIR}/libMLIRIR.a ${MLIR_LIB_DIR}/libMLIRSupport.a)
endif()
# XXX: This function is untested and will need adjustment.
function(mlir_tablegen)

View file

@ -550,6 +550,8 @@ version (IN_LLVM)
objExt = global.ll_ext;
else if (global.params.output_s)
objExt = global.s_ext;
else if (global.params.output_mlir)
objExt = global.mlir_ext;
if (objExt)
objfile = setOutfilename(global.params.objname, global.params.objdir, filename, objExt);

View file

@ -306,6 +306,7 @@ version (IN_LLVM)
// LDC stuff
OUTPUTFLAG output_ll;
OUTPUTFLAG output_mlir;
OUTPUTFLAG output_bc;
OUTPUTFLAG output_s;
OUTPUTFLAG output_o;
@ -345,6 +346,7 @@ extern (C++) struct Global
version (IN_LLVM)
{
const(char)[] ll_ext;
const(char)[] mlir_ext;
const(char)[] bc_ext;
const(char)[] s_ext;
const(char)[] ldc_version;
@ -510,6 +512,7 @@ else
vendor = "LDC";
obj_ext = "o";
ll_ext = "ll";
mlir_ext = "mlir";
bc_ext = "bc";
s_ext = "s";

View file

@ -268,6 +268,7 @@ struct Param
// LDC stuff
OUTPUTFLAG output_ll;
OUTPUTFLAG output_mlir;
OUTPUTFLAG output_bc;
OUTPUTFLAG output_s;
OUTPUTFLAG output_o;
@ -304,6 +305,7 @@ struct Global
DString obj_ext;
#if IN_LLVM
DString ll_ext;
DString mlir_ext; //MLIR code
DString bc_ext;
DString s_ext;
DString ldc_version;

View file

@ -204,6 +204,9 @@ cl::opt<bool> output_bc("output-bc", cl::desc("Write LLVM bitcode"),
cl::opt<bool> output_ll("output-ll", cl::desc("Write LLVM IR"), cl::ZeroOrMore);
cl::opt<bool> output_mlir("output-mlir", cl::desc("Write MLIR"),
cl::ZeroOrMore);
cl::opt<bool> output_s("output-s", cl::desc("Write native assembly"),
cl::ZeroOrMore);

View file

@ -52,6 +52,7 @@ extern cl::opt<std::string> objectDir;
extern cl::opt<std::string> soname;
extern cl::opt<bool> output_bc;
extern cl::opt<bool> output_ll;
extern cl::opt<bool> output_mlir;
extern cl::opt<bool> output_s;
extern cl::opt<cl::boolOrDefault> output_o;
extern cl::opt<std::string> ddocDir;

View file

@ -219,8 +219,16 @@ void inlineAsmDiagnosticHandler(const llvm::SMDiagnostic &d, void *context,
} // anonymous namespace
namespace ldc {
CodeGenerator::CodeGenerator(llvm::LLVMContext &context, bool singleObj)
: context_(context), moduleCount_(0), singleObj_(singleObj), ir_(nullptr) {
CodeGenerator::CodeGenerator(llvm::LLVMContext &context,
#if LDC_MLIR_ENABLED
mlir::MLIRContext &mlirContext,
#endif
bool singleObj)
: context_(context),
#if LDC_MLIR_ENABLED
mlirContext_(mlirContext),
#endif
moduleCount_(0), singleObj_(singleObj), ir_(nullptr) {
// Set the context to discard value names when not generating textual IR.
if (!global.params.output_ll) {
context_.setDiscardValueNames(true);
@ -274,7 +282,6 @@ void CodeGenerator::finishLLModule(Module *m) {
if (moduleCount_ == 1) {
insertBitcodeFiles(ir_->module, ir_->context(), global.params.bitcodeFiles);
}
writeAndFreeLLModule(m->objfile.toChars());
}
@ -341,4 +348,58 @@ void CodeGenerator::emit(Module *m) {
Logger::disable();
}
}
#if LDC_MLIR_ENABLED
void CodeGenerator::emitMLIR(Module *m) {
bool const loggerWasEnabled = Logger::enabled();
if (m->llvmForceLogging && !loggerWasEnabled) {
Logger::enable();
}
IF_LOG Logger::println("CodeGenerator::emitMLIR(%s)", m->toPrettyChars());
LOG_SCOPE;
if (global.params.verbose_cg) {
printf("codegen: %s (%s)\n", m->toPrettyChars(), m->srcfile.toChars());
}
if (global.errors) {
Logger::println("Aborting because of errors");
fatal();
}
mlir::OwningModuleRef module;
/*module = mlirGen(mlirContext, m, irs);
if(!module){
IF_LOG Logger::println("Error generating MLIR:'%s'", llpath.c_str());
fatal();
}*/
writeMLIRModule(&module, m->objfile.toChars());
if (m->llvmForceLogging && !loggerWasEnabled) {
Logger::disable();
}
}
void CodeGenerator::writeMLIRModule(mlir::OwningModuleRef *module,
const char *filename) {
// Write MLIR
if (global.params.output_mlir) {
const auto llpath = replaceExtensionWith(global.mlir_ext, filename);
Logger::println("Writting MLIR to %s\n", llpath.c_str());
std::error_code errinfo;
llvm::raw_fd_ostream aos(llpath, errinfo, llvm::sys::fs::F_None);
if (aos.has_error()) {
error(Loc(), "Cannot write MLIR file '%s': %s", llpath.c_str(),
errinfo.message().c_str());
fatal();
}
// module->print(aos);
}
}
#endif
}

View file

@ -20,21 +20,40 @@
#pragma once
#include "gen/irstate.h"
#if LDC_MLIR_ENABLED
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/Module.h"
#endif
namespace ldc {
class CodeGenerator {
public:
CodeGenerator(llvm::LLVMContext &context, bool singleObj);
CodeGenerator(llvm::LLVMContext &context,
#if LDC_MLIR_ENABLED
mlir::MLIRContext &mlirContext,
#endif
bool singleObj);
~CodeGenerator();
void emit(Module *m);
#if LDC_MLIR_ENABLED
void emitMLIR(Module *m);
#endif
private:
void prepareLLModule(Module *m);
void finishLLModule(Module *m);
void writeAndFreeLLModule(const char *filename);
#if LDC_MLIR_ENABLED
void writeMLIRModule(mlir::OwningModuleRef *module, const char *filename);
#endif
llvm::LLVMContext &context_;
#if LDC_MLIR_ENABLED
mlir::MLIRContext &mlirContext_;
#endif
int moduleCount_;
bool const singleObj_;
IRState *ir_;

View file

@ -439,11 +439,13 @@ void parseCommandLine(Strings &sourceFiles) {
global.params.output_o =
(opts::output_o == cl::BOU_UNSET &&
!(opts::output_bc || opts::output_ll || opts::output_s))
!(opts::output_bc || opts::output_ll || opts::output_s ||
opts::output_mlir))
? OUTPUTFLAGdefault
: opts::output_o == cl::BOU_TRUE ? OUTPUTFLAGset : OUTPUTFLAGno;
global.params.output_bc = opts::output_bc ? OUTPUTFLAGset : OUTPUTFLAGno;
global.params.output_ll = opts::output_ll ? OUTPUTFLAGset : OUTPUTFLAGno;
global.params.output_mlir = opts::output_mlir ? OUTPUTFLAGset : OUTPUTFLAGno;
global.params.output_s = opts::output_s ? OUTPUTFLAGset : OUTPUTFLAGno;
global.params.cov = (global.params.covPercent <= 100);
@ -509,9 +511,20 @@ void parseCommandLine(Strings &sourceFiles) {
strcmp(ext, global.s_ext.ptr) == 0) {
global.params.output_s = OUTPUTFLAGset;
global.params.output_o = OUTPUTFLAGno;
} else if (opts::output_mlir.getNumOccurrences() == 0 &&
strcmp(ext, global.mlir_ext.ptr) == 0) {
global.params.output_mlir = OUTPUTFLAGset;
global.params.output_o = OUTPUTFLAGno;
}
}
#ifndef LDC_MLIR_ENABLED
if (global.params.output_mlir == OUTPUTFLAGset) {
error(Loc(), "MLIR output requested but this LDC was built without MLIR support");
fatal();
}
#endif
if (soname.getNumOccurrences() > 0 && !global.params.dll) {
error(Loc(), "-soname can be used only when building a shared library");
}
@ -1094,7 +1107,13 @@ int cppmain() {
void codegenModules(Modules &modules) {
// Generate one or more object/IR/bitcode files/dcompute kernels.
if (global.params.obj && !modules.empty()) {
#if LDC_MLIR_ENABLED
mlir::MLIRContext mlircontext;
ldc::CodeGenerator cg(getGlobalContext(), mlircontext,
global.params.oneobj);
#else
ldc::CodeGenerator cg(getGlobalContext(), global.params.oneobj);
#endif
DComputeCodeGenManager dccg(getGlobalContext());
std::vector<Module *> computeModules;
// When inlining is enabled, we are calling semantic3 on function
@ -1117,6 +1136,11 @@ void codegenModules(Modules &modules) {
const auto atCompute = hasComputeAttr(m);
if (atCompute == DComputeCompileFor::hostOnly ||
atCompute == DComputeCompileFor::hostAndDevice) {
#if LDC_MLIR_ENABLED
if (global.params.output_mlir == OUTPUTFLAGset)
cg.emitMLIR(m);
else
#endif
cg.emit(m);
}
if (atCompute != DComputeCompileFor::hostOnly) {

View file

@ -309,6 +309,23 @@ bool shouldDoLTO(llvm::Module *m) {
}
} // end of anonymous namespace
std::string replaceExtensionWith(const DArray<const char> &ext,
const char *filename) {
const auto outputFlags = {global.params.output_o, global.params.output_bc,
global.params.output_ll, global.params.output_s,
global.params.output_mlir};
const auto numOutputFiles =
std::count_if(outputFlags.begin(), outputFlags.end(),
[](OUTPUTFLAG flag) { return flag != 0; });
if (numOutputFiles == 1)
return filename;
llvm::SmallString<128> buffer(filename);
llvm::sys::path::replace_extension(buffer,
llvm::StringRef(ext.ptr, ext.length));
return {buffer.data(), buffer.size()};
}
void writeModule(llvm::Module *m, const char *filename) {
const bool doLTO = shouldDoLTO(m);
const bool outputObj = shouldOutputObjectFile();
@ -349,29 +366,13 @@ void writeModule(llvm::Module *m, const char *filename) {
}
}
const auto outputFlags = {global.params.output_o, global.params.output_bc,
global.params.output_ll, global.params.output_s};
const auto numOutputFiles =
std::count_if(outputFlags.begin(), outputFlags.end(),
[](OUTPUTFLAG flag) { return flag != 0; });
const auto replaceExtensionWith =
[=](const DArray<const char> &ext) -> std::string {
if (numOutputFiles == 1)
return filename;
llvm::SmallString<128> buffer(filename);
llvm::sys::path::replace_extension(buffer,
llvm::StringRef(ext.ptr, ext.length));
return {buffer.data(), buffer.size()};
};
// write LLVM bitcode
const bool emitBitcodeAsObjectFile =
doLTO && outputObj && !global.params.output_bc;
if (global.params.output_bc || emitBitcodeAsObjectFile) {
std::string bcpath = emitBitcodeAsObjectFile
? filename
: replaceExtensionWith(global.bc_ext);
: replaceExtensionWith(global.bc_ext, filename);
Logger::println("Writing LLVM bitcode to: %s\n", bcpath.c_str());
std::error_code errinfo;
llvm::raw_fd_ostream bos(bcpath.c_str(), errinfo, llvm::sys::fs::F_None);
@ -413,7 +414,7 @@ void writeModule(llvm::Module *m, const char *filename) {
// write LLVM IR
if (global.params.output_ll) {
const auto llpath = replaceExtensionWith(global.ll_ext);
const auto llpath = replaceExtensionWith(global.ll_ext, filename);
Logger::println("Writing LLVM IR to: %s\n", llpath.c_str());
std::error_code errinfo;
llvm::raw_fd_ostream aos(llpath.c_str(), errinfo, llvm::sys::fs::F_None);
@ -435,7 +436,7 @@ void writeModule(llvm::Module *m, const char *filename) {
llvm::sys::fs::createUniqueFile("ldc-%%%%%%%.s", buffer);
spath = {buffer.data(), buffer.size()};
} else {
spath = replaceExtensionWith(global.s_ext);
spath = replaceExtensionWith(global.s_ext, filename);
}
Logger::println("Writing asm to: %s\n", spath.c_str());

View file

@ -12,9 +12,14 @@
//===----------------------------------------------------------------------===//
#pragma once
#include <string>
#include "dmd/root/dcompat.h"
namespace llvm {
class Module;
}
void writeModule(llvm::Module *m, const char *filename);
std::string replaceExtensionWith(const DArray<const char> &ext,
const char *filename);