Add transformation of the NVVM dialect to an LLVM module. Only handles
authorStephan Herhut <herhut@google.com>
Tue, 30 Apr 2019 13:08:21 +0000 (06:08 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 6 May 2019 15:22:14 +0000 (08:22 -0700)
    the generation of intrinsics out of NVVM index ops for now.

--

PiperOrigin-RevId: 245933152

mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h [new file with mode: 0644]
mlir/include/mlir/Target/NVVMIR.h [new file with mode: 0644]
mlir/lib/Target/CMakeLists.txt
mlir/lib/Target/LLVMIR/CMakeLists.txt [deleted file]
mlir/lib/Target/LLVMIR/ConvertToLLVMIR.cpp
mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp [new file with mode: 0644]
mlir/lib/Target/LLVMIR/ModuleTranslation.cpp [new file with mode: 0644]
mlir/test/Target/nvvmir.mlir [new file with mode: 0644]
mlir/tools/mlir-cpu-runner/CMakeLists.txt
mlir/tools/mlir-translate/CMakeLists.txt
mlir/tools/mlir-translate/mlir-translate.cpp

diff --git a/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h b/mlir/include/mlir/Target/LLVMIR/ModuleTranslation.h
new file mode 100644 (file)
index 0000000..5e35aab
--- /dev/null
@@ -0,0 +1,98 @@
+//===- ModuleTranslation.h - MLIR to LLVM conversion ------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the translation between an MLIR LLVM dialect module and
+// the corresponding LLVMIR module. It only handles core LLVM IR operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
+#define MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
+
+#include "mlir/IR/Block.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Value.h"
+
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Function.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/Value.h"
+
+namespace mlir {
+class Attribute;
+class Location;
+class Module;
+class Operation;
+
+namespace LLVM {
+
+// Implementation class for module translation.  Holds a reference to the module
+// being translated, and the mappings between the original and the translated
+// functions, basic blocks and values.  It is practically easier to hold these
+// mappings in one class since the conversion of control flow operations
+// needs to look up block and function mappings.
+class ModuleTranslation {
+public:
+  template <typename T = ModuleTranslation>
+  static std::unique_ptr<llvm::Module> translateModule(Module &m) {
+    auto llvmModule = prepareLLVMModule(m);
+
+    T translator(m);
+    translator.llvmModule = std::move(llvmModule);
+    if (translator.convertFunctions())
+      return nullptr;
+
+    return std::move(translator.llvmModule);
+  }
+
+protected:
+  // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
+  // LLVM IR module.  The MLIR LLVM IR dialect holds a pointer to an
+  // LLVMContext, the LLVM IR module will be created in that context.
+  explicit ModuleTranslation(Module &module) : mlirModule(module) {}
+  virtual ~ModuleTranslation() {}
+
+  virtual bool convertOperation(Operation &op, llvm::IRBuilder<> &builder);
+  static std::unique_ptr<llvm::Module> prepareLLVMModule(Module &m);
+
+private:
+
+  bool convertFunctions();
+  bool convertOneFunction(Function &func);
+  void connectPHINodes(Function &func);
+  bool convertBlock(Block &bb, bool ignoreArguments);
+
+  template <typename Range>
+  SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
+
+  llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
+                                  Location loc);
+
+  // Original and translated module.
+  Module &mlirModule;
+  std::unique_ptr<llvm::Module> llvmModule;
+
+  // Mappings between original and translated values, used for lookups.
+  llvm::DenseMap<Function *, llvm::Function *> functionMapping;
+  llvm::DenseMap<Value *, llvm::Value *> valueMapping;
+  llvm::DenseMap<Block *, llvm::BasicBlock *> blockMapping;
+};
+
+} // namespace LLVM
+} // namespace mlir
+
+#endif // MLIR_TARGET_LLVMIR_MODULETRANSLATION_H
diff --git a/mlir/include/mlir/Target/NVVMIR.h b/mlir/include/mlir/Target/NVVMIR.h
new file mode 100644 (file)
index 0000000..27f964e
--- /dev/null
@@ -0,0 +1,45 @@
+//===- NVVMIR.h - MLIR to LLVM + NVVM IR conversion -------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file declares the entry point for the MLIR to LLVM + NVVM IR conversion.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TARGET_NVVMIR_H
+#define MLIR_TARGET_NVVMIR_H
+
+#include <memory>
+
+// Forward-declare LLVM classses.
+namespace llvm {
+class Module;
+} // namespace llvm
+
+namespace mlir {
+
+class Module;
+
+/// Convert the given MLIR module into NVVM IR. This conversion requires the
+/// registration of the LLVM IR dialect and will extract the LLVM context
+/// from the registered LLVM IR dialect.  In case of error, report it
+/// to the error handler registered with the MLIR context, if any (obtained from
+/// the MLIR module), and return `nullptr`.
+std::unique_ptr<llvm::Module> translateModuleToNVVMIR(Module &m);
+
+} // namespace mlir
+
+#endif // MLIR_TARGET_NVVMIR_H
index 39d31dc..2652f67 100644 (file)
@@ -1 +1,25 @@
-add_subdirectory(LLVMIR)
+add_llvm_library(MLIRTargetLLVMIRModuleTranslation
+  LLVMIR/ModuleTranslation.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+  DEPENDS
+  intrinsics_gen
+  )
+target_link_libraries(MLIRTargetLLVMIRModuleTranslation MLIRLLVMIR LLVMCore LLVMSupport LLVMTransformUtils MLIRTranslation)
+add_llvm_library(MLIRTargetLLVMIR
+  LLVMIR/ConvertToLLVMIR.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+  )
+target_link_libraries(MLIRTargetLLVMIR MLIRTargetLLVMIRModuleTranslation)
+add_llvm_library(MLIRTargetNVVMIR
+  LLVMIR/ConvertToNVVMIR.cpp
+
+  ADDITIONAL_HEADER_DIRS
+  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
+  DEPENDS
+  intrinsics_gen
+  )
+target_link_libraries(MLIRTargetNVVMIR MLIRNVVMIR MLIRTargetLLVMIRModuleTranslation)
diff --git a/mlir/lib/Target/LLVMIR/CMakeLists.txt b/mlir/lib/Target/LLVMIR/CMakeLists.txt
deleted file mode 100644 (file)
index 26534dc..0000000
+++ /dev/null
@@ -1,9 +0,0 @@
-add_llvm_library(MLIRTargetLLVMIR
-  ConvertToLLVMIR.cpp
-
-  ADDITIONAL_HEADER_DIRS
-  ${MLIR_MAIN_INCLUDE_DIR}/mlir/Target/LLVMIR
-  DEPENDS
-  intrinsics_gen
-  )
-target_link_libraries(MLIRTargetLLVMIR MLIRLLVMIR MLIRTranslation LLVMCore LLVMSupport LLVMTransformUtils)
index 6a4ab56..ff9aa99 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-#include "mlir/IR/Attributes.h"
-#include "mlir/IR/Module.h"
-#include "mlir/LLVMIR/LLVMDialect.h"
-#include "mlir/StandardOps/Ops.h"
-#include "mlir/Support/FileUtilities.h"
-#include "mlir/Support/LLVM.h"
 #include "mlir/Target/LLVMIR.h"
+
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
 #include "mlir/Translation.h"
 
-#include "llvm/ADT/SetVector.h"
-#include "llvm/IR/BasicBlock.h"
-#include "llvm/IR/Constants.h"
-#include "llvm/IR/DerivedTypes.h"
-#include "llvm/IR/IRBuilder.h"
-#include "llvm/IR/LLVMContext.h"
+#include "llvm/ADT/StringRef.h"
 #include "llvm/IR/Module.h"
 #include "llvm/Support/ToolOutputFile.h"
-#include "llvm/Transforms/Utils/Cloning.h"
 
 using namespace mlir;
 
-namespace {
-// Implementation class for module translation.  Holds a reference to the module
-// being translated, and the mappings between the original and the translated
-// functions, basic blocks and values.  It is practically easier to hold these
-// mappings in one class since the conversion of control flow operations
-// needs to look up block and function mappings.
-class ModuleTranslation {
-public:
-  // Translate the given MLIR module expressed in MLIR LLVM IR dialect into an
-  // LLVM IR module.  The MLIR LLVM IR dialect holds a pointer to an
-  // LLVMContext, the LLVM IR module will be created in that context.
-  static std::unique_ptr<llvm::Module> translateModule(Module &m);
-
-private:
-  explicit ModuleTranslation(Module &module) : mlirModule(module) {}
-
-  bool convertFunctions();
-  bool convertOneFunction(Function &func);
-  void connectPHINodes(Function &func);
-  bool convertBlock(Block &bb, bool ignoreArguments);
-  bool convertOperation(Operation &op, llvm::IRBuilder<> &builder);
-
-  template <typename Range>
-  SmallVector<llvm::Value *, 8> lookupValues(Range &&values);
-
-  llvm::Constant *getLLVMConstant(llvm::Type *llvmType, Attribute attr,
-                                  Location loc);
-
-  // Original and translated module.
-  Module &mlirModule;
-  std::unique_ptr<llvm::Module> llvmModule;
-
-  // Mappings between original and translated values, used for lookups.
-  llvm::DenseMap<Function *, llvm::Function *> functionMapping;
-  llvm::DenseMap<Value *, llvm::Value *> valueMapping;
-  llvm::DenseMap<Block *, llvm::BasicBlock *> blockMapping;
-};
-} // end anonymous namespace
-
-// Convert an MLIR function type to LLVM IR.  Arguments of the function must of
-// MLIR LLVM IR dialect types.  Use `loc` as a location when reporting errors.
-// Return nullptr on errors.
-static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
-                                               FunctionType type, Location loc,
-                                               bool isVarArgs) {
-  assert(type && "expected non-null type");
-
-  auto context = type.getContext();
-  if (type.getNumResults() > 1)
-    return context->emitError(loc,
-                              "LLVM functions can only have 0 or 1 result"),
-           nullptr;
-
-  SmallVector<llvm::Type *, 8> argTypes;
-  argTypes.reserve(type.getNumInputs());
-  for (auto t : type.getInputs()) {
-    auto wrappedLLVMType = t.dyn_cast<LLVM::LLVMType>();
-    if (!wrappedLLVMType)
-      return context->emitError(loc, "non-LLVM function argument type"),
-             nullptr;
-    argTypes.push_back(wrappedLLVMType.getUnderlyingType());
-  }
-
-  if (type.getNumResults() == 0)
-    return llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), argTypes,
-                                   isVarArgs);
-
-  auto wrappedResultType = type.getResult(0).dyn_cast<LLVM::LLVMType>();
-  if (!wrappedResultType)
-    return context->emitError(loc, "non-LLVM function result"), nullptr;
-
-  return llvm::FunctionType::get(wrappedResultType.getUnderlyingType(),
-                                 argTypes, isVarArgs);
-}
-
-// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
-// This currently supports integer, floating point, splat and dense element
-// attributes and combinations thereof.  In case of error, report it to `loc`
-// and return nullptr.
-llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
-                                                   Attribute attr,
-                                                   Location loc) {
-  if (auto intAttr = attr.dyn_cast<IntegerAttr>())
-    return llvm::ConstantInt::get(llvmType, intAttr.getValue());
-  if (auto floatAttr = attr.dyn_cast<FloatAttr>())
-    return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
-  if (auto funcAttr = attr.dyn_cast<FunctionAttr>())
-    return functionMapping.lookup(funcAttr.getValue());
-  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
-    auto *vectorType = cast<llvm::VectorType>(llvmType);
-    auto *child = getLLVMConstant(vectorType->getElementType(),
-                                  splatAttr.getValue(), loc);
-    return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
-  }
-  if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
-    auto *vectorType = cast<llvm::VectorType>(llvmType);
-    SmallVector<llvm::Constant *, 8> constants;
-    uint64_t numElements = vectorType->getNumElements();
-    constants.reserve(numElements);
-    SmallVector<Attribute, 8> nested;
-    denseAttr.getValues(nested);
-    for (auto n : nested) {
-      constants.push_back(
-          getLLVMConstant(vectorType->getElementType(), n, loc));
-      if (!constants.back())
-        return nullptr;
-    }
-    return llvm::ConstantVector::get(constants);
-  }
-  mlirModule.getContext()->emitError(loc, "unsupported constant value");
-  return nullptr;
-}
-
-// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
-static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) {
-  switch (p) {
-  case CmpIPredicate::EQ:
-    return llvm::CmpInst::Predicate::ICMP_EQ;
-  case CmpIPredicate::NE:
-    return llvm::CmpInst::Predicate::ICMP_NE;
-  case CmpIPredicate::SLT:
-    return llvm::CmpInst::Predicate::ICMP_SLT;
-  case CmpIPredicate::SLE:
-    return llvm::CmpInst::Predicate::ICMP_SLE;
-  case CmpIPredicate::SGT:
-    return llvm::CmpInst::Predicate::ICMP_SGT;
-  case CmpIPredicate::SGE:
-    return llvm::CmpInst::Predicate::ICMP_SGE;
-  case CmpIPredicate::ULT:
-    return llvm::CmpInst::Predicate::ICMP_ULT;
-  case CmpIPredicate::ULE:
-    return llvm::CmpInst::Predicate::ICMP_ULE;
-  case CmpIPredicate::UGT:
-    return llvm::CmpInst::Predicate::ICMP_UGT;
-  case CmpIPredicate::UGE:
-    return llvm::CmpInst::Predicate::ICMP_UGE;
-  default:
-    llvm_unreachable("incorrect comparison predicate");
-  }
-}
-
-// A helper to look up remapped operands in the value remapping table.
-template <typename Range>
-SmallVector<llvm::Value *, 8> ModuleTranslation::lookupValues(Range &&values) {
-  SmallVector<llvm::Value *, 8> remapped;
-  remapped.reserve(llvm::size(values));
-  for (Value *v : values) {
-    remapped.push_back(valueMapping.lookup(v));
-  }
-  return remapped;
-}
-
-// Given a single MLIR operation, create the corresponding LLVM IR operation
-// using the `builder`.  LLVM IR Builder does not have a generic interface so
-// this has to be a long chain of `if`s calling different functions with a
-// different number of arguments.
-bool ModuleTranslation::convertOperation(Operation &opInst,
-                                         llvm::IRBuilder<> &builder) {
-  auto extractPosition = [](ArrayAttr attr) {
-    SmallVector<unsigned, 4> position;
-    position.reserve(attr.size());
-    for (Attribute v : attr)
-      position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
-    return position;
-  };
-
-#include "mlir/LLVMIR/LLVMConversions.inc"
-
-  // Emit function calls.  If the "callee" attribute is present, this is a
-  // direct function call and we also need to look up the remapped function
-  // itself.  Otherwise, this is an indirect call and the callee is the first
-  // operand, look it up as a normal value.  Return the llvm::Value representing
-  // the function result, which may be of llvm::VoidTy type.
-  auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
-    auto operands = lookupValues(op.getOperands());
-    ArrayRef<llvm::Value *> operandsRef(operands);
-    if (auto attr = op.getAttrOfType<FunctionAttr>("callee")) {
-      return builder.CreateCall(functionMapping.lookup(attr.getValue()),
-                                operandsRef);
-    } else {
-      return builder.CreateCall(operandsRef.front(), operandsRef.drop_front());
-    }
-  };
-
-  // Emit calls.  If the called function has a result, remap the corresponding
-  // value.  Note that LLVM IR dialect CallOp has either 0 or 1 result.
-  if (opInst.isa<LLVM::CallOp>()) {
-    llvm::Value *result = convertCall(opInst);
-    if (opInst.getNumResults() != 0) {
-      valueMapping[opInst.getResult(0)] = result;
-      return false;
-    }
-    // Check that LLVM call returns void for 0-result functions.
-    return !result->getType()->isVoidTy();
-  }
-
-  // Emit branches.  We need to look up the remapped blocks and ignore the block
-  // arguments that were transformed into PHI nodes.
-  if (auto brOp = opInst.dyn_cast<LLVM::BrOp>()) {
-    builder.CreateBr(blockMapping[brOp.getSuccessor(0)]);
-    return false;
-  }
-  if (auto condbrOp = opInst.dyn_cast<LLVM::CondBrOp>()) {
-    builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
-                         blockMapping[condbrOp.getSuccessor(0)],
-                         blockMapping[condbrOp.getSuccessor(1)]);
-    return false;
-  }
-
-  opInst.emitError("unsupported or non-LLVM operation: " +
-                   opInst.getName().getStringRef());
-  return true;
-}
-
-// Convert block to LLVM IR.  Unless `ignoreArguments` is set, emit PHI nodes
-// to define values corresponding to the MLIR block arguments.  These nodes
-// are not connected to the source basic blocks, which may not exist yet.
-bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
-  llvm::IRBuilder<> builder(blockMapping[&bb]);
-
-  // Before traversing operations, make block arguments available through
-  // value remapping and PHI nodes, but do not add incoming edges for the PHI
-  // nodes just yet: those values may be defined by this or following blocks.
-  // This step is omitted if "ignoreArguments" is set.  The arguments of the
-  // first block have been already made available through the remapping of
-  // LLVM function arguments.
-  if (!ignoreArguments) {
-    auto predecessors = bb.getPredecessors();
-    unsigned numPredecessors =
-        std::distance(predecessors.begin(), predecessors.end());
-    for (auto *arg : bb.getArguments()) {
-      auto wrappedType = arg->getType().dyn_cast<LLVM::LLVMType>();
-      if (!wrappedType) {
-        arg->getType().getContext()->emitError(
-            bb.front().getLoc(), "block argument does not have an LLVM type");
-        return true;
-      }
-      llvm::Type *type = wrappedType.getUnderlyingType();
-      llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
-      valueMapping[arg] = phi;
-    }
-  }
-
-  // Traverse operations.
-  for (auto &op : bb) {
-    if (convertOperation(op, builder))
-      return true;
-  }
-
-  return false;
-}
-
-// Get the SSA value passed to the current block from the terminator operation
-// of its predecessor.
-static Value *getPHISourceValue(Block *current, Block *pred,
-                                unsigned numArguments, unsigned index) {
-  auto &terminator = *pred->getTerminator();
-  if (terminator.isa<LLVM::BrOp>()) {
-    return terminator.getOperand(index);
-  }
-
-  // For conditional branches, we need to check if the current block is reached
-  // through the "true" or the "false" branch and take the relevant operands.
-  auto condBranchOp = terminator.dyn_cast<LLVM::CondBrOp>();
-  assert(condBranchOp &&
-         "only branch operations can be terminators of a block that "
-         "has successors");
-  assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
-         "successors with arguments in LLVM conditional branches must be "
-         "different blocks");
-
-  return condBranchOp.getSuccessor(0) == current
-             ? terminator.getSuccessorOperand(0, index)
-             : terminator.getSuccessorOperand(1, index);
-}
-
-void ModuleTranslation::connectPHINodes(Function &func) {
-  // Skip the first block, it cannot be branched to and its arguments correspond
-  // to the arguments of the LLVM function.
-  for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
-    Block *bb = &*it;
-    llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
-    auto phis = llvmBB->phis();
-    auto numArguments = bb->getNumArguments();
-    assert(numArguments == std::distance(phis.begin(), phis.end()));
-    for (auto &numberedPhiNode : llvm::enumerate(phis)) {
-      auto &phiNode = numberedPhiNode.value();
-      unsigned index = numberedPhiNode.index();
-      for (auto *pred : bb->getPredecessors()) {
-        phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
-                                bb, pred, numArguments, index)),
-                            blockMapping.lookup(pred));
-      }
-    }
-  }
-}
-
-// TODO(mlir-team): implement an iterative version
-static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
-  blocks.insert(b);
-  for (Block *bb : b->getSuccessors()) {
-    if (blocks.count(bb) == 0)
-      topologicalSortImpl(blocks, bb);
-  }
-}
-
-// Sort function blocks topologically.
-static llvm::SetVector<Block *> topologicalSort(Function &f) {
-  // For each blocks that has not been visited yet (i.e. that has no
-  // predecessors), add it to the list and traverse its successors in DFS
-  // preorder.
-  llvm::SetVector<Block *> blocks;
-  for (Block &b : f.getBlocks()) {
-    if (blocks.count(&b) == 0)
-      topologicalSortImpl(blocks, &b);
-  }
-  assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
-
-  return blocks;
-}
-
-bool ModuleTranslation::convertOneFunction(Function &func) {
-  // Clear the block and value mappings, they are only relevant within one
-  // function.
-  blockMapping.clear();
-  valueMapping.clear();
-  llvm::Function *llvmFunc = functionMapping.lookup(&func);
-  // Add function arguments to the value remapping table.
-  // If there was noalias info then we decorate each argument accordingly.
-  unsigned int argIdx = 0;
-  for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) {
-    llvm::Argument &llvmArg = std::get<1>(kvp);
-    BlockArgument *mlirArg = std::get<0>(kvp);
-
-    if (auto attr = func.getArgAttrOfType<BoolAttr>(argIdx, "llvm.noalias")) {
-      // NB: Attribute already verified to be boolean, so check if we can indeed
-      // attach the attribute to this argument, based on its type.
-      auto argTy = mlirArg->getType().dyn_cast<LLVM::LLVMType>();
-      if (!argTy.getUnderlyingType()->isPointerTy())
-        return argTy.getContext()->emitError(
-            func.getLoc(),
-            "llvm.noalias attribute attached to LLVM non-pointer argument");
-      if (attr.getValue())
-        llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias);
-    }
-    valueMapping[mlirArg] = &llvmArg;
-    argIdx++;
-  }
-
-  // First, create all blocks so we can jump to them.
-  llvm::LLVMContext &llvmContext = llvmFunc->getContext();
-  for (auto &bb : func) {
-    auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
-    llvmBB->insertInto(llvmFunc);
-    blockMapping[&bb] = llvmBB;
-  }
-
-  // Then, convert blocks one by one in topological order to ensure defs are
-  // converted before uses.
-  auto blocks = topologicalSort(func);
-  for (auto indexedBB : llvm::enumerate(blocks)) {
-    auto *bb = indexedBB.value();
-    if (convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))
-      return true;
-  }
-
-  // Finally, after all blocks have been traversed and values mapped, connect
-  // the PHI nodes to the results of preceding blocks.
-  connectPHINodes(func);
-  return false;
-}
-
-bool ModuleTranslation::convertFunctions() {
-  // Declare all functions first because there may be function calls that form a
-  // call graph with cycles.
-  for (Function &function : mlirModule) {
-    Function *functionPtr = &function;
-    mlir::BoolAttr isVarArgsAttr =
-        function.getAttrOfType<BoolAttr>("std.varargs");
-    bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
-    llvm::FunctionType *functionType =
-        convertFunctionType(llvmModule->getContext(), function.getType(),
-                            function.getLoc(), isVarArgs);
-    if (!functionType)
-      return true;
-    llvm::FunctionCallee llvmFuncCst =
-        llvmModule->getOrInsertFunction(function.getName(), functionType);
-    assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
-    functionMapping[functionPtr] =
-        cast<llvm::Function>(llvmFuncCst.getCallee());
-  }
-
-  // Convert functions.
-  for (Function &function : mlirModule) {
-    // Ignore external functions.
-    if (function.isExternal())
-      continue;
-
-    if (convertOneFunction(function))
-      return true;
-  }
-
-  return false;
-}
-
-std::unique_ptr<llvm::Module> ModuleTranslation::translateModule(Module &m) {
-  Dialect *dialect = m.getContext()->getRegisteredDialect("llvm");
-  assert(dialect && "LLVM dialect must be registered");
-  auto *llvmDialect = static_cast<LLVM::LLVMDialect *>(dialect);
-
-  auto llvmModule = llvm::CloneModule(llvmDialect->getLLVMModule());
-  if (!llvmModule)
-    return nullptr;
-
-  llvm::LLVMContext &llvmContext = llvmModule->getContext();
-  llvm::IRBuilder<> builder(llvmContext);
-
-  // Inject declarations for `malloc` and `free` functions that can be used in
-  // memref allocation/deallocation coming from standard ops lowering.
-  llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(),
-                                  builder.getInt64Ty());
-  llvmModule->getOrInsertFunction("free", builder.getVoidTy(),
-                                  builder.getInt8PtrTy());
-
-  ModuleTranslation translator(m);
-  translator.llvmModule = std::move(llvmModule);
-  if (translator.convertFunctions())
-    return nullptr;
-
-  return std::move(translator.llvmModule);
-}
-
 std::unique_ptr<llvm::Module> mlir::translateModuleToLLVMIR(Module &m) {
-  return ModuleTranslation::translateModule(m);
+  return LLVM::ModuleTranslation::translateModule<>(m);
 }
 
 static TranslateFromMLIRRegistration registration(
@@ -481,7 +40,7 @@ static TranslateFromMLIRRegistration registration(
       if (!module)
         return true;
 
-      auto llvmModule = ModuleTranslation::translateModule(*module);
+      auto llvmModule = LLVM::ModuleTranslation::translateModule<>(*module);
       if (!llvmModule)
         return true;
 
diff --git a/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp b/mlir/lib/Target/LLVMIR/ConvertToNVVMIR.cpp
new file mode 100644 (file)
index 0000000..2f7ce41
--- /dev/null
@@ -0,0 +1,84 @@
+//===- ConvertToNVVMIR.cpp - MLIR to LLVM IR conversion ---------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a translation between the MLIR LLVM + NVVM dialects and
+// LLVM IR with NVVM intrinsics and metadata.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/NVVMIR.h"
+
+#include "mlir/LLVMIR/NVVMDialect.h"
+#include "mlir/Support/FileUtilities.h"
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+#include "mlir/Translation.h"
+
+#include "llvm/ADT/StringRef.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Support/ToolOutputFile.h"
+
+using namespace mlir;
+
+namespace {
+static void createIntrinsicCall(llvm::IRBuilder<> &builder,
+                                llvm::Intrinsic::ID intrinsic) {
+  llvm::Module *module = builder.GetInsertBlock()->getModule();
+  llvm::Function *fn = llvm::Intrinsic::getDeclaration(module, intrinsic, {});
+  builder.CreateCall(fn);
+}
+
+class ModuleTranslation : public LLVM::ModuleTranslation {
+
+public:
+  explicit ModuleTranslation(Module &module)
+      : LLVM::ModuleTranslation(module) {}
+  ~ModuleTranslation() override {}
+
+protected:
+  bool convertOperation(Operation &opInst,
+                        llvm::IRBuilder<> &builder) override {
+
+#include "mlir/LLVMIR/NVVMConversions.inc"
+
+    return LLVM::ModuleTranslation::convertOperation(opInst, builder);
+  }
+};
+} // namespace
+
+std::unique_ptr<llvm::Module> mlir::translateModuleToNVVMIR(Module &m) {
+  ModuleTranslation translation(m);
+  return LLVM::ModuleTranslation::translateModule<ModuleTranslation>(m);
+}
+
+static TranslateFromMLIRRegistration registration(
+    "mlir-to-nvvmir", [](Module *module, llvm::StringRef outputFilename) {
+      if (!module)
+        return true;
+
+      auto llvmModule =
+          LLVM::ModuleTranslation::translateModule<ModuleTranslation>(*module);
+      if (!llvmModule)
+        return true;
+
+      auto file = openOutputFile(outputFilename);
+      if (!file)
+        return true;
+
+      llvmModule->print(file->os(), nullptr);
+      file->keep();
+      return false;
+    });
diff --git a/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp b/mlir/lib/Target/LLVMIR/ModuleTranslation.cpp
new file mode 100644 (file)
index 0000000..58bbc9c
--- /dev/null
@@ -0,0 +1,432 @@
+//===- ModuleTranslation.cpp - MLIR to LLVM conversion ----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements the translation between an MLIR LLVM dialect module and
+// the corresponding LLVMIR module. It only handles core LLVM IR operations.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Target/LLVMIR/ModuleTranslation.h"
+
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Module.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/StandardOps/Ops.h"
+#include "mlir/Support/LLVM.h"
+
+#include "llvm/ADT/SetVector.h"
+#include "llvm/IR/BasicBlock.h"
+#include "llvm/IR/Constants.h"
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/LLVMContext.h"
+#include "llvm/IR/Module.h"
+#include "llvm/Transforms/Utils/Cloning.h"
+
+namespace mlir {
+namespace LLVM {
+
+// Convert an MLIR function type to LLVM IR.  Arguments of the function must of
+// MLIR LLVM IR dialect types.  Use `loc` as a location when reporting errors.
+// Return nullptr on errors.
+static llvm::FunctionType *convertFunctionType(llvm::LLVMContext &llvmContext,
+                                               FunctionType type, Location loc,
+                                               bool isVarArgs) {
+  assert(type && "expected non-null type");
+
+  auto context = type.getContext();
+  if (type.getNumResults() > 1)
+    return context->emitError(loc,
+                              "LLVM functions can only have 0 or 1 result"),
+           nullptr;
+
+  SmallVector<llvm::Type *, 8> argTypes;
+  argTypes.reserve(type.getNumInputs());
+  for (auto t : type.getInputs()) {
+    auto wrappedLLVMType = t.dyn_cast<LLVM::LLVMType>();
+    if (!wrappedLLVMType)
+      return context->emitError(loc, "non-LLVM function argument type"),
+             nullptr;
+    argTypes.push_back(wrappedLLVMType.getUnderlyingType());
+  }
+
+  if (type.getNumResults() == 0)
+    return llvm::FunctionType::get(llvm::Type::getVoidTy(llvmContext), argTypes,
+                                   isVarArgs);
+
+  auto wrappedResultType = type.getResult(0).dyn_cast<LLVM::LLVMType>();
+  if (!wrappedResultType)
+    return context->emitError(loc, "non-LLVM function result"), nullptr;
+
+  return llvm::FunctionType::get(wrappedResultType.getUnderlyingType(),
+                                 argTypes, isVarArgs);
+}
+
+// Create an LLVM IR constant of `llvmType` from the MLIR attribute `attr`.
+// This currently supports integer, floating point, splat and dense element
+// attributes and combinations thereof.  In case of error, report it to `loc`
+// and return nullptr.
+llvm::Constant *ModuleTranslation::getLLVMConstant(llvm::Type *llvmType,
+                                                   Attribute attr,
+                                                   Location loc) {
+  if (auto intAttr = attr.dyn_cast<IntegerAttr>())
+    return llvm::ConstantInt::get(llvmType, intAttr.getValue());
+  if (auto floatAttr = attr.dyn_cast<FloatAttr>())
+    return llvm::ConstantFP::get(llvmType, floatAttr.getValue());
+  if (auto funcAttr = attr.dyn_cast<FunctionAttr>())
+    return functionMapping.lookup(funcAttr.getValue());
+  if (auto splatAttr = attr.dyn_cast<SplatElementsAttr>()) {
+    auto *vectorType = cast<llvm::VectorType>(llvmType);
+    auto *child = getLLVMConstant(vectorType->getElementType(),
+                                  splatAttr.getValue(), loc);
+    return llvm::ConstantVector::getSplat(vectorType->getNumElements(), child);
+  }
+  if (auto denseAttr = attr.dyn_cast<DenseElementsAttr>()) {
+    auto *vectorType = cast<llvm::VectorType>(llvmType);
+    SmallVector<llvm::Constant *, 8> constants;
+    uint64_t numElements = vectorType->getNumElements();
+    constants.reserve(numElements);
+    SmallVector<Attribute, 8> nested;
+    denseAttr.getValues(nested);
+    for (auto n : nested) {
+      constants.push_back(
+          getLLVMConstant(vectorType->getElementType(), n, loc));
+      if (!constants.back())
+        return nullptr;
+    }
+    return llvm::ConstantVector::get(constants);
+  }
+  mlirModule.getContext()->emitError(loc, "unsupported constant value");
+  return nullptr;
+}
+
+// Convert MLIR integer comparison predicate to LLVM IR comparison predicate.
+static llvm::CmpInst::Predicate getLLVMCmpPredicate(CmpIPredicate p) {
+  switch (p) {
+  case CmpIPredicate::EQ:
+    return llvm::CmpInst::Predicate::ICMP_EQ;
+  case CmpIPredicate::NE:
+    return llvm::CmpInst::Predicate::ICMP_NE;
+  case CmpIPredicate::SLT:
+    return llvm::CmpInst::Predicate::ICMP_SLT;
+  case CmpIPredicate::SLE:
+    return llvm::CmpInst::Predicate::ICMP_SLE;
+  case CmpIPredicate::SGT:
+    return llvm::CmpInst::Predicate::ICMP_SGT;
+  case CmpIPredicate::SGE:
+    return llvm::CmpInst::Predicate::ICMP_SGE;
+  case CmpIPredicate::ULT:
+    return llvm::CmpInst::Predicate::ICMP_ULT;
+  case CmpIPredicate::ULE:
+    return llvm::CmpInst::Predicate::ICMP_ULE;
+  case CmpIPredicate::UGT:
+    return llvm::CmpInst::Predicate::ICMP_UGT;
+  case CmpIPredicate::UGE:
+    return llvm::CmpInst::Predicate::ICMP_UGE;
+  default:
+    llvm_unreachable("incorrect comparison predicate");
+  }
+}
+
+// A helper to look up remapped operands in the value remapping table.
+template <typename Range>
+SmallVector<llvm::Value *, 8> ModuleTranslation::lookupValues(Range &&values) {
+  SmallVector<llvm::Value *, 8> remapped;
+  remapped.reserve(llvm::size(values));
+  for (Value *v : values) {
+    remapped.push_back(valueMapping.lookup(v));
+  }
+  return remapped;
+}
+
+// Given a single MLIR operation, create the corresponding LLVM IR operation
+// using the `builder`.  LLVM IR Builder does not have a generic interface so
+// this has to be a long chain of `if`s calling different functions with a
+// different number of arguments.
+bool ModuleTranslation::convertOperation(Operation &opInst,
+                                         llvm::IRBuilder<> &builder) {
+  auto extractPosition = [](ArrayAttr attr) {
+    SmallVector<unsigned, 4> position;
+    position.reserve(attr.size());
+    for (Attribute v : attr)
+      position.push_back(v.cast<IntegerAttr>().getValue().getZExtValue());
+    return position;
+  };
+
+#include "mlir/LLVMIR/LLVMConversions.inc"
+
+  // Emit function calls.  If the "callee" attribute is present, this is a
+  // direct function call and we also need to look up the remapped function
+  // itself.  Otherwise, this is an indirect call and the callee is the first
+  // operand, look it up as a normal value.  Return the llvm::Value representing
+  // the function result, which may be of llvm::VoidTy type.
+  auto convertCall = [this, &builder](Operation &op) -> llvm::Value * {
+    auto operands = lookupValues(op.getOperands());
+    ArrayRef<llvm::Value *> operandsRef(operands);
+    if (auto attr = op.getAttrOfType<FunctionAttr>("callee")) {
+      return builder.CreateCall(functionMapping.lookup(attr.getValue()),
+                                operandsRef);
+    } else {
+      return builder.CreateCall(operandsRef.front(), operandsRef.drop_front());
+    }
+  };
+
+  // Emit calls.  If the called function has a result, remap the corresponding
+  // value.  Note that LLVM IR dialect CallOp has either 0 or 1 result.
+  if (opInst.isa<LLVM::CallOp>()) {
+    llvm::Value *result = convertCall(opInst);
+    if (opInst.getNumResults() != 0) {
+      valueMapping[opInst.getResult(0)] = result;
+      return false;
+    }
+    // Check that LLVM call returns void for 0-result functions.
+    return !result->getType()->isVoidTy();
+  }
+
+  // Emit branches.  We need to look up the remapped blocks and ignore the block
+  // arguments that were transformed into PHI nodes.
+  if (auto brOp = opInst.dyn_cast<LLVM::BrOp>()) {
+    builder.CreateBr(blockMapping[brOp.getSuccessor(0)]);
+    return false;
+  }
+  if (auto condbrOp = opInst.dyn_cast<LLVM::CondBrOp>()) {
+    builder.CreateCondBr(valueMapping.lookup(condbrOp.getOperand(0)),
+                         blockMapping[condbrOp.getSuccessor(0)],
+                         blockMapping[condbrOp.getSuccessor(1)]);
+    return false;
+  }
+
+  opInst.emitError("unsupported or non-LLVM operation: " +
+                   opInst.getName().getStringRef());
+  return true;
+}
+
+// Convert block to LLVM IR.  Unless `ignoreArguments` is set, emit PHI nodes
+// to define values corresponding to the MLIR block arguments.  These nodes
+// are not connected to the source basic blocks, which may not exist yet.
+bool ModuleTranslation::convertBlock(Block &bb, bool ignoreArguments) {
+  llvm::IRBuilder<> builder(blockMapping[&bb]);
+
+  // Before traversing operations, make block arguments available through
+  // value remapping and PHI nodes, but do not add incoming edges for the PHI
+  // nodes just yet: those values may be defined by this or following blocks.
+  // This step is omitted if "ignoreArguments" is set.  The arguments of the
+  // first block have been already made available through the remapping of
+  // LLVM function arguments.
+  if (!ignoreArguments) {
+    auto predecessors = bb.getPredecessors();
+    unsigned numPredecessors =
+        std::distance(predecessors.begin(), predecessors.end());
+    for (auto *arg : bb.getArguments()) {
+      auto wrappedType = arg->getType().dyn_cast<LLVM::LLVMType>();
+      if (!wrappedType) {
+        arg->getType().getContext()->emitError(
+            bb.front().getLoc(), "block argument does not have an LLVM type");
+        return true;
+      }
+      llvm::Type *type = wrappedType.getUnderlyingType();
+      llvm::PHINode *phi = builder.CreatePHI(type, numPredecessors);
+      valueMapping[arg] = phi;
+    }
+  }
+
+  // Traverse operations.
+  for (auto &op : bb) {
+    if (convertOperation(op, builder))
+      return true;
+  }
+
+  return false;
+}
+
+// Get the SSA value passed to the current block from the terminator operation
+// of its predecessor.
+static Value *getPHISourceValue(Block *current, Block *pred,
+                                unsigned numArguments, unsigned index) {
+  auto &terminator = *pred->getTerminator();
+  if (terminator.isa<LLVM::BrOp>()) {
+    return terminator.getOperand(index);
+  }
+
+  // For conditional branches, we need to check if the current block is reached
+  // through the "true" or the "false" branch and take the relevant operands.
+  auto condBranchOp = terminator.dyn_cast<LLVM::CondBrOp>();
+  assert(condBranchOp &&
+         "only branch operations can be terminators of a block that "
+         "has successors");
+  assert((condBranchOp.getSuccessor(0) != condBranchOp.getSuccessor(1)) &&
+         "successors with arguments in LLVM conditional branches must be "
+         "different blocks");
+
+  return condBranchOp.getSuccessor(0) == current
+             ? terminator.getSuccessorOperand(0, index)
+             : terminator.getSuccessorOperand(1, index);
+}
+
+void ModuleTranslation::connectPHINodes(Function &func) {
+  // Skip the first block, it cannot be branched to and its arguments correspond
+  // to the arguments of the LLVM function.
+  for (auto it = std::next(func.begin()), eit = func.end(); it != eit; ++it) {
+    Block *bb = &*it;
+    llvm::BasicBlock *llvmBB = blockMapping.lookup(bb);
+    auto phis = llvmBB->phis();
+    auto numArguments = bb->getNumArguments();
+    assert(numArguments == std::distance(phis.begin(), phis.end()));
+    for (auto &numberedPhiNode : llvm::enumerate(phis)) {
+      auto &phiNode = numberedPhiNode.value();
+      unsigned index = numberedPhiNode.index();
+      for (auto *pred : bb->getPredecessors()) {
+        phiNode.addIncoming(valueMapping.lookup(getPHISourceValue(
+                                bb, pred, numArguments, index)),
+                            blockMapping.lookup(pred));
+      }
+    }
+  }
+}
+
+// TODO(mlir-team): implement an iterative version
+static void topologicalSortImpl(llvm::SetVector<Block *> &blocks, Block *b) {
+  blocks.insert(b);
+  for (Block *bb : b->getSuccessors()) {
+    if (blocks.count(bb) == 0)
+      topologicalSortImpl(blocks, bb);
+  }
+}
+
+// Sort function blocks topologically.
+static llvm::SetVector<Block *> topologicalSort(Function &f) {
+  // For each blocks that has not been visited yet (i.e. that has no
+  // predecessors), add it to the list and traverse its successors in DFS
+  // preorder.
+  llvm::SetVector<Block *> blocks;
+  for (Block &b : f.getBlocks()) {
+    if (blocks.count(&b) == 0)
+      topologicalSortImpl(blocks, &b);
+  }
+  assert(blocks.size() == f.getBlocks().size() && "some blocks are not sorted");
+
+  return blocks;
+}
+
+bool ModuleTranslation::convertOneFunction(Function &func) {
+  // Clear the block and value mappings, they are only relevant within one
+  // function.
+  blockMapping.clear();
+  valueMapping.clear();
+  llvm::Function *llvmFunc = functionMapping.lookup(&func);
+  // Add function arguments to the value remapping table.
+  // If there was noalias info then we decorate each argument accordingly.
+  unsigned int argIdx = 0;
+  for (const auto &kvp : llvm::zip(func.getArguments(), llvmFunc->args())) {
+    llvm::Argument &llvmArg = std::get<1>(kvp);
+    BlockArgument *mlirArg = std::get<0>(kvp);
+
+    if (auto attr = func.getArgAttrOfType<BoolAttr>(argIdx, "llvm.noalias")) {
+      // NB: Attribute already verified to be boolean, so check if we can indeed
+      // attach the attribute to this argument, based on its type.
+      auto argTy = mlirArg->getType().dyn_cast<LLVM::LLVMType>();
+      if (!argTy.getUnderlyingType()->isPointerTy())
+        return argTy.getContext()->emitError(
+            func.getLoc(),
+            "llvm.noalias attribute attached to LLVM non-pointer argument");
+      if (attr.getValue())
+        llvmArg.addAttr(llvm::Attribute::AttrKind::NoAlias);
+    }
+    valueMapping[mlirArg] = &llvmArg;
+    argIdx++;
+  }
+
+  // First, create all blocks so we can jump to them.
+  llvm::LLVMContext &llvmContext = llvmFunc->getContext();
+  for (auto &bb : func) {
+    auto *llvmBB = llvm::BasicBlock::Create(llvmContext);
+    llvmBB->insertInto(llvmFunc);
+    blockMapping[&bb] = llvmBB;
+  }
+
+  // Then, convert blocks one by one in topological order to ensure defs are
+  // converted before uses.
+  auto blocks = topologicalSort(func);
+  for (auto indexedBB : llvm::enumerate(blocks)) {
+    auto *bb = indexedBB.value();
+    if (convertBlock(*bb, /*ignoreArguments=*/indexedBB.index() == 0))
+      return true;
+  }
+
+  // Finally, after all blocks have been traversed and values mapped, connect
+  // the PHI nodes to the results of preceding blocks.
+  connectPHINodes(func);
+  return false;
+}
+
+bool ModuleTranslation::convertFunctions() {
+  // Declare all functions first because there may be function calls that form a
+  // call graph with cycles.
+  for (Function &function : mlirModule) {
+    Function *functionPtr = &function;
+    mlir::BoolAttr isVarArgsAttr =
+        function.getAttrOfType<BoolAttr>("std.varargs");
+    bool isVarArgs = isVarArgsAttr && isVarArgsAttr.getValue();
+    llvm::FunctionType *functionType =
+        convertFunctionType(llvmModule->getContext(), function.getType(),
+                            function.getLoc(), isVarArgs);
+    if (!functionType)
+      return true;
+    llvm::FunctionCallee llvmFuncCst =
+        llvmModule->getOrInsertFunction(function.getName(), functionType);
+    assert(isa<llvm::Function>(llvmFuncCst.getCallee()));
+    functionMapping[functionPtr] =
+        cast<llvm::Function>(llvmFuncCst.getCallee());
+  }
+
+  // Convert functions.
+  for (Function &function : mlirModule) {
+    // Ignore external functions.
+    if (function.isExternal())
+      continue;
+
+    if (convertOneFunction(function))
+      return true;
+  }
+
+  return false;
+}
+
+std::unique_ptr<llvm::Module> ModuleTranslation::prepareLLVMModule(Module &m) {
+  Dialect *dialect = m.getContext()->getRegisteredDialect("llvm");
+  assert(dialect && "LLVM dialect must be registered");
+  auto *llvmDialect = static_cast<LLVM::LLVMDialect *>(dialect);
+
+  auto llvmModule = llvm::CloneModule(llvmDialect->getLLVMModule());
+  if (!llvmModule)
+    return nullptr;
+
+  llvm::LLVMContext &llvmContext = llvmModule->getContext();
+  llvm::IRBuilder<> builder(llvmContext);
+
+  // Inject declarations for `malloc` and `free` functions that can be used in
+  // memref allocation/deallocation coming from standard ops lowering.
+  llvmModule->getOrInsertFunction("malloc", builder.getInt8PtrTy(),
+                                  builder.getInt64Ty());
+  llvmModule->getOrInsertFunction("free", builder.getVoidTy(),
+                                  builder.getInt8PtrTy());
+
+  return llvmModule;
+}
+
+} // namespace LLVM
+} // namespace mlir
diff --git a/mlir/test/Target/nvvmir.mlir b/mlir/test/Target/nvvmir.mlir
new file mode 100644 (file)
index 0000000..cbb2f16
--- /dev/null
@@ -0,0 +1,29 @@
+// RUN: mlir-translate -mlir-to-nvvmir %s | FileCheck %s
+
+func @nvvm_special_regs() -> !llvm.i32 {
+  // CHECK: %1 = call i32 @llvm.nvvm.read.ptx.sreg.tid.x()
+  %1 = nvvm.read.ptx.sreg.tid.x : !llvm.i32
+  // CHECK: %2 = call i32 @llvm.nvvm.read.ptx.sreg.tid.y()
+  %2 = nvvm.read.ptx.sreg.tid.y : !llvm.i32
+  // CHECK: %3 = call i32 @llvm.nvvm.read.ptx.sreg.tid.z()
+  %3 = nvvm.read.ptx.sreg.tid.z : !llvm.i32
+  // CHECK: %4 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.x()
+  %4 = nvvm.read.ptx.sreg.ntid.x : !llvm.i32
+  // CHECK: %5 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.y()
+  %5 = nvvm.read.ptx.sreg.ntid.y : !llvm.i32
+  // CHECK: %6 = call i32 @llvm.nvvm.read.ptx.sreg.ntid.z()
+  %6 = nvvm.read.ptx.sreg.ntid.z : !llvm.i32
+  // CHECK: %7 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.x()
+  %7 = nvvm.read.ptx.sreg.ctaid.x : !llvm.i32
+  // CHECK: %8 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.y()
+  %8 = nvvm.read.ptx.sreg.ctaid.y : !llvm.i32
+  // CHECK: %9 = call i32 @llvm.nvvm.read.ptx.sreg.ctaid.z()
+  %9 = nvvm.read.ptx.sreg.ctaid.z : !llvm.i32
+  // CHECK: %10 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.x()
+  %10 = nvvm.read.ptx.sreg.nctaid.x : !llvm.i32
+  // CHECK: %11 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.y()
+  %11 = nvvm.read.ptx.sreg.nctaid.y : !llvm.i32
+  // CHECK: %12 = call i32 @llvm.nvvm.read.ptx.sreg.nctaid.z()
+  %12 = nvvm.read.ptx.sreg.nctaid.z : !llvm.i32
+  llvm.return %1 : !llvm.i32
+}
\ No newline at end of file
index 1c59ae5..eff9409 100644 (file)
@@ -5,6 +5,7 @@ set(LIBS
   MLIRExecutionEngine
   MLIRIR
   MLIRParser
+  MLIRTargetLLVMIR
   MLIRTransforms
   MLIRSupport
   LLVMCore
index 330c9ad..3d4e96a 100644 (file)
@@ -6,6 +6,7 @@ set(LIBS
   MLIRPass
   MLIRStandardOps
   MLIRTargetLLVMIR
+  MLIRTargetNVVMIR
   MLIRTransforms
   MLIRTranslation
   MLIRSupport
index 3d67d40..6e27150 100644 (file)
@@ -73,7 +73,7 @@ using TranslateFunction =
     std::function<bool(StringRef, StringRef, MLIRContext *)>;
 
 // Storage for the translation function wrappers that survive the parser.
-static llvm::SmallVector<TranslateFunction, 8> wrapperStorage;
+static llvm::SmallVector<TranslateFunction, 16> wrapperStorage;
 
 // Custom parser for TranslateFunction.
 // Wraps TranslateToMLIRFunctions and TranslateFromMLIRFunctions into