Tutorial Linalg1: implement conversion to the LLVM Dialect
authorAlex Zinenko <zinenko@google.com>
Wed, 3 Apr 2019 17:38:47 +0000 (10:38 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Thu, 4 Apr 2019 01:36:09 +0000 (18:36 -0700)
    Implement conversion from the Linalg dialect to the LLVM dialect using a simple
    set of DialectOpConverters and by plugging them into the dialect conversion
    infrastructure.  View and Range Linalg types are converted into descriptors
    that store the dynamic values in an LLVM aggregate type, similarly to memrefs.
    Slice operations create new descriptors based on the original descriptors and
    thus remove the constraint on ViewTypes not being acceptable as function
    arguments.

--

PiperOrigin-RevId: 241760189

mlir/tutorial/Linalg1/Conversion.cpp [new file with mode: 0644]
mlir/tutorial/Linalg1/Example.cpp
mlir/tutorial/Linalg1/include/linalg1/ConvertToLLVMDialect.h [new file with mode: 0644]
mlir/tutorial/Linalg1/include/linalg1/SliceOp.h
mlir/tutorial/Linalg1/lib/ConvertToLLVMDialect.cpp [new file with mode: 0644]
mlir/tutorial/Linalg1/lib/SliceOp.cpp

diff --git a/mlir/tutorial/Linalg1/Conversion.cpp b/mlir/tutorial/Linalg1/Conversion.cpp
new file mode 100644 (file)
index 0000000..343bf0d
--- /dev/null
@@ -0,0 +1,307 @@
+//===- Conversion.cpp - Linalg to LLVM conversion driver ------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+// RUN: %p/conversion | FileCheck %s
+
+#include "TestHarness.h"
+
+#include "linalg1/Common.h"
+#include "linalg1/ConvertToLLVMDialect.h"
+#include "linalg1/Intrinsics.h"
+#include "linalg1/Ops.h"
+#include "linalg1/Types.h"
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/Function.h"
+
+using namespace linalg;
+using namespace linalg::common;
+using namespace linalg::intrinsics;
+using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
+
+TEST_FUNC(rangeConversion) {
+  // Define the MLIR context, create a Module in this context, and a Builder to
+  // facilitate type construction.
+  MLIRContext context;
+  Module module(&context);
+  Builder builder(&module);
+
+  // Declare a function called "rangeConversion" with type:
+  //   (index, index, index) -> ()
+  // define it, and add it to the module.
+  FunctionType funcType = builder.getFunctionType(
+      {builder.getIndexType(), builder.getIndexType(), builder.getIndexType()},
+      {});
+  Function *f =
+      new Function(builder.getUnknownLoc(), "rangeConversion", funcType);
+  f->addEntryBlock();
+  module.getFunctions().push_back(f);
+
+  // Construct a linalg::RangeOp taking function arguments as operands.
+  ScopedContext scope(f);
+  ValueHandle arg0(f->getArgument(0)), arg1(f->getArgument(1)),
+      arg2(f->getArgument(2));
+  {
+    range(arg0, arg1, arg2);
+    ret();
+  }
+
+  // clang-format off
+  // CHECK-LABEL: @rangeConversion
+  // CHECK-NEXT: %0 = llvm.undef : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %1 = llvm.insertvalue %arg0, %0[0] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %2 = llvm.insertvalue %arg1, %1[1] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %3 = llvm.insertvalue %arg2, %2[2] : !llvm<"{ i64, i64, i64 }">
+  // clang-format on
+  convertToLLVM(module);
+  module.print(llvm::outs());
+}
+
+TEST_FUNC(viewRangeConversion) {
+  // Define the MLIR context, create a Module in this context, and a Builder to
+  // facilitate type construction.
+  MLIRContext context;
+  Module module(&context);
+  Builder builder(&module);
+
+  // Declare a function called "viewRangeConversion" with type:
+  //   (memref<?x?xf32>, !linalg.range, !linalg.range) -> ()
+  // define it, and add it to the module.
+  FunctionType funcType = builder.getFunctionType(
+      {builder.getMemRefType({-1, -1}, builder.getF32Type(), {}, 0),
+       builder.getType<RangeType>(), builder.getType<RangeType>()},
+      {});
+  Function *f =
+      new Function(builder.getUnknownLoc(), "viewRangeConversion", funcType);
+  f->addEntryBlock();
+  module.getFunctions().push_back(f);
+
+  // Construct a linalg::ViewOp taking function arguments as operands.
+  ScopedContext scope(f);
+  ValueHandle memref(f->getArgument(0)), range1(f->getArgument(1)),
+      range2(f->getArgument(2));
+  {
+    view(memref, {range1, range2});
+    ret();
+  }
+
+  // clang-format off
+  // CHECK-LABEL: @viewRangeConversion
+  // CHECK-NEXT: %0 = llvm.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %1 = llvm.extractvalue %arg0[0] : !llvm<"{ float*, i64, i64 }">
+  // CHECK-NEXT: %2 = llvm.insertvalue %1, %0[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %3 = llvm.extractvalue %arg0[2] : !llvm<"{ float*, i64, i64 }">
+  // CHECK-NEXT: %4 = llvm.constant(1 : index) : !llvm<"i64">
+  // CHECK-NEXT: %5 = llvm.mul %4, %3 : !llvm<"i64">
+  // CHECK-NEXT: %6 = llvm.constant(0 : index) : !llvm<"i64">
+  // CHECK-NEXT: %7 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %8 = llvm.mul %7, %5 : !llvm<"i64">
+  // CHECK-NEXT: %9 = llvm.add %6, %8 : !llvm<"i64">
+  // CHECK-NEXT: %10 = llvm.extractvalue %arg2[0] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %11 = llvm.mul %10, %4 : !llvm<"i64">
+  // CHECK-NEXT: %12 = llvm.add %9, %11 : !llvm<"i64">
+  // CHECK-NEXT: %13 = llvm.insertvalue %12, %2[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %14 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %15 = llvm.extractvalue %arg1[1] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %16 = llvm.sub %15, %14 : !llvm<"i64">
+  // CHECK-NEXT: %17 = llvm.insertvalue %16, %13[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %18 = llvm.extractvalue %arg2[0] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %19 = llvm.extractvalue %arg2[1] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %20 = llvm.sub %19, %18 : !llvm<"i64">
+  // CHECK-NEXT: %21 = llvm.insertvalue %20, %17[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %22 = llvm.extractvalue %arg1[2] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %23 = llvm.mul %5, %22 : !llvm<"i64">
+  // CHECK-NEXT: %24 = llvm.insertvalue %23, %21[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %25 = llvm.extractvalue %arg2[2] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %26 = llvm.mul %4, %25 : !llvm<"i64">
+  // CHECK-NEXT: %27 = llvm.insertvalue %26, %24[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // clang-format on
+  convertToLLVM(module);
+  module.print(llvm::outs());
+}
+
+TEST_FUNC(viewNonRangeConversion) {
+  // Define the MLIR context, create a Module in this context, and a Builder to
+  // facilitate type construction.
+  MLIRContext context;
+  Module module(&context);
+  Builder builder(&module);
+
+  // Declare a function called "viewNonRangeConversion" with type:
+  //   (memref<?x?xf32>, !linalg.range, index) -> ()
+  // define it, and add it to the module.
+  FunctionType funcType = builder.getFunctionType(
+      {builder.getMemRefType({-1, -1}, builder.getF32Type(), {}, 0),
+       builder.getType<RangeType>(), builder.getIndexType()},
+      {});
+  Function *f =
+      new Function(builder.getUnknownLoc(), "viewNonRangeConversion", funcType);
+  f->addEntryBlock();
+  module.getFunctions().push_back(f);
+
+  // Construct a linalg::ViewOp taking function arguments as operands.
+  ScopedContext scope(f);
+  ValueHandle memref(f->getArgument(0)), range(f->getArgument(1)),
+      index(f->getArgument(2));
+  {
+    view(memref, {range, index});
+    ret();
+  }
+
+  // clang-format off
+  // CHECK-LABEL: @viewNonRangeConversion
+  // CHECK-NEXT: %0 = llvm.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  // CHECK-NEXT: %1 = llvm.extractvalue %arg0[0] : !llvm<"{ float*, i64, i64 }">
+  // CHECK-NEXT: %2 = llvm.insertvalue %1, %0[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  // CHECK-NEXT: %3 = llvm.extractvalue %arg0[2] : !llvm<"{ float*, i64, i64 }">
+  // CHECK-NEXT: %4 = llvm.constant(1 : index) : !llvm<"i64">
+  // CHECK-NEXT: %5 = llvm.mul %4, %3 : !llvm<"i64">
+  // CHECK-NEXT: %6 = llvm.constant(0 : index) : !llvm<"i64">
+  // CHECK-NEXT: %7 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %8 = llvm.mul %7, %5 : !llvm<"i64">
+  // CHECK-NEXT: %9 = llvm.add %6, %8 : !llvm<"i64">
+  // CHECK-NEXT: %10 = llvm.mul %arg2, %4 : !llvm<"i64">
+  // CHECK-NEXT: %11 = llvm.add %9, %10 : !llvm<"i64">
+  // CHECK-NEXT: %12 = llvm.insertvalue %11, %2[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  // CHECK-NEXT: %13 = llvm.extractvalue %arg1[0] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %14 = llvm.extractvalue %arg1[1] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %15 = llvm.sub %14, %13 : !llvm<"i64">
+  // CHECK-NEXT: %16 = llvm.insertvalue %15, %12[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  // CHECK-NEXT: %17 = llvm.extractvalue %arg1[2] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %18 = llvm.mul %5, %17 : !llvm<"i64">
+  // CHECK-NEXT: %19 = llvm.insertvalue %18, %16[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  // clang-format on
+  convertToLLVM(module);
+  module.print(llvm::outs());
+}
+
+TEST_FUNC(sliceRangeConversion) {
+  // Define the MLIR context, create a Module in this context, and a Builder to
+  // facilitate type construction.
+  MLIRContext context;
+  Module module(&context);
+  Builder builder(&module);
+
+  // Declare a function called "sliceRangeConversion" with type:
+  //   (memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.range) -> ()
+  // define it, and add it to the module.
+  FunctionType funcType = builder.getFunctionType(
+      {builder.getMemRefType({-1, -1}, builder.getF32Type(), {}, 0),
+       builder.getType<RangeType>(), builder.getType<RangeType>(),
+       builder.getType<RangeType>()},
+      {});
+  Function *f =
+      new Function(builder.getUnknownLoc(), "sliceRangeConversion", funcType);
+  f->addEntryBlock();
+  module.getFunctions().push_back(f);
+
+  // Construct a linalg::SliceOp based on the result of a linalg::ViewOp.
+  // Note: SliceOp builder does not support ViewOps that are not defined by
+  // a dominating ViewOp.
+  ScopedContext scope(f);
+  ValueHandle memref(f->getArgument(0)), range1(f->getArgument(1)),
+      range2(f->getArgument(2)), range3(f->getArgument(3));
+  {
+    slice(view(memref, {range1, range2}), range3, 0);
+    ret();
+  }
+
+  // clang-format off
+  // CHECK-LABEL: @sliceRangeConversion
+  // CHECK:      %28 = llvm.undef : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %29 = llvm.extractvalue %27[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %30 = llvm.insertvalue %29, %28[0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %31 = llvm.extractvalue %27[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %32 = llvm.extractvalue %arg3[0] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %33 = llvm.extractvalue %27[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %34 = llvm.mul %32, %33 : !llvm<"i64">
+  // CHECK-NEXT: %35 = llvm.add %31, %34 : !llvm<"i64">
+  // CHECK-NEXT: %36 = llvm.insertvalue %35, %30[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %37 = llvm.extractvalue %arg3[1] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %38 = llvm.extractvalue %arg3[0] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %39 = llvm.sub %37, %38 : !llvm<"i64">
+  // CHECK-NEXT: %40 = llvm.extractvalue %27[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %41 = llvm.extractvalue %arg3[2] : !llvm<"{ i64, i64, i64 }">
+  // CHECK-NEXT: %42 = llvm.mul %40, %41 : !llvm<"i64">
+  // CHECK-NEXT: %43 = llvm.insertvalue %39, %36[2, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %44 = llvm.insertvalue %42, %43[3, 0] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %45 = llvm.extractvalue %27[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %46 = llvm.extractvalue %27[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %47 = llvm.insertvalue %45, %44[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %48 = llvm.insertvalue %46, %47[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // clang-format on
+  convertToLLVM(module);
+  module.print(llvm::outs());
+}
+
+TEST_FUNC(sliceNonRangeConversion) {
+  // Define the MLIR context, create a Module in this context, and a Builder to
+  // facilitate type construction.
+  MLIRContext context;
+  Module module(&context);
+  Builder builder(&module);
+
+  // Declare a function called "sliceNonRangeConversion" with type:
+  //   (memref<?x?xf32>, !linalg.range, !linalg.range, !linalg.range) -> ()
+  // define it, and add it to the module.
+  FunctionType funcType = builder.getFunctionType(
+      {builder.getMemRefType({-1, -1}, builder.getF32Type(), {}, 0),
+       builder.getType<RangeType>(), builder.getType<RangeType>(),
+       builder.getIndexType()},
+      {});
+  Function *f = new Function(builder.getUnknownLoc(), "sliceNonRangeConversion",
+                             funcType);
+  f->addEntryBlock();
+  module.getFunctions().push_back(f);
+
+  // Construct a linalg::SliceOp based on the result of a linalg::ViewOp.
+  // Note: SliceOp builder does not support ViewOps that are not defined by
+  // a dominating ViewOp.
+  ScopedContext scope(f);
+  ValueHandle memref(f->getArgument(0)), range1(f->getArgument(1)),
+      range2(f->getArgument(2)), index(f->getArgument(3));
+  {
+    slice(view(memref, {range1, range2}), index, 0);
+    ret();
+  }
+
+  // CHECK-LABEL: @sliceNonRangeConversion
+  // CHECK:      %28 = llvm.undef : !llvm<"{ float*, i64, [1 x i64], [1 x i64]
+  // }"> CHECK-NEXT: %29 = llvm.extractvalue %27[0] : !llvm<"{ float*, i64, [2 x
+  // i64], [2 x i64] }"> CHECK-NEXT: %30 = llvm.insertvalue %29, %28[0] :
+  // !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> CHECK-NEXT: %31 =
+  // llvm.extractvalue %27[1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
+  // CHECK-NEXT: %32 = llvm.extractvalue %27[3, 0] : !llvm<"{ float*, i64, [2 x
+  // i64], [2 x i64] }"> CHECK-NEXT: %33 = llvm.mul %arg3, %32 : !llvm<"i64">
+  // CHECK-NEXT: %34 = llvm.add %31, %33 : !llvm<"i64">
+  // CHECK-NEXT: %35 = llvm.insertvalue %34, %30[1] : !llvm<"{ float*, i64, [1 x
+  // i64], [1 x i64] }"> CHECK-NEXT: %36 = llvm.extractvalue %27[2, 1] :
+  // !llvm<"{ float*, i64, [2 x i64], [2 x i64] }"> CHECK-NEXT: %37 =
+  // llvm.extractvalue %27[3, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64]
+  // }"> CHECK-NEXT: %38 = llvm.insertvalue %36, %35[2, 0] : !llvm<"{ float*,
+  // i64, [1 x i64], [1 x i64] }"> CHECK-NEXT: %39 = llvm.insertvalue %37,
+  // %38[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+  convertToLLVM(module);
+  module.print(llvm::outs());
+}
+
+int main() {
+  RUN_TESTS();
+  return 0;
+}
index 5899f82..5bbd8e4 100644 (file)
@@ -55,7 +55,7 @@ TEST_FUNC(view_op) {
     v2 = view(A2, ArrayRef<ValueHandle>{r0, r0});
   some_consumer(ArrayRef<ValueHandle>{v0, v1, v2});
   ret();
-  // CHECK-LABEL: func @view_op(%arg0: index, %arg1: index, %arg2: index) {
+  // CHECK-LABEL: func @view_op
   //       CHECK:   %[[R:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range">
   //  CHECK-NEXT:  {{.*}} = linalg.view {{.*}}[] : !linalg<"view<0xf32>">
   //  CHECK-NEXT:  {{.*}} = linalg.view {{.*}}[%[[R]]] : !linalg<"view<f32>">
diff --git a/mlir/tutorial/Linalg1/include/linalg1/ConvertToLLVMDialect.h b/mlir/tutorial/Linalg1/include/linalg1/ConvertToLLVMDialect.h
new file mode 100644 (file)
index 0000000..bf3002f
--- /dev/null
@@ -0,0 +1,29 @@
+//===- ConvertToLLVMDialect.h - conversion from Linalg to LLVM --*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#ifndef LINALG_CONVERTTOLLVMDIALECT_H_
+#define LINALG_CONVERTTOLLVMDIALECT_H_
+
+namespace mlir {
+class Module;
+} // end namespace mlir
+
+namespace linalg {
+void convertToLLVM(mlir::Module &module);
+} // end namespace linalg
+
+#endif // LINALG_CONVERTTOLLVMDIALECT_H_
index 8fa88cb..1d79784 100644 (file)
@@ -78,6 +78,10 @@ public:
   /// Returns the element Type of the ViewType of `getParentView()`.
   mlir::Type getParentElementType();
 
+  /// Returns true if the rank of the part view is greater than the rank of
+  /// the child view.
+  bool isRankDecreasing();
+
   // Get all the indexings in this slice.
   mlir::Operation::operand_range getIndexings();
 };
diff --git a/mlir/tutorial/Linalg1/lib/ConvertToLLVMDialect.cpp b/mlir/tutorial/Linalg1/lib/ConvertToLLVMDialect.cpp
new file mode 100644 (file)
index 0000000..d7cb918
--- /dev/null
@@ -0,0 +1,522 @@
+//===- ConvertToLLVMDialect.cpp - conversion from Linalg to LLVM dialect --===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+
+#include "mlir/EDSC/Builders.h"
+#include "mlir/EDSC/Intrinsics.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/MLIRContext.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/Operation.h"
+#include "mlir/IR/PatternMatch.h"
+#include "mlir/IR/StandardTypes.h"
+#include "mlir/IR/Types.h"
+#include "mlir/LLVMIR/LLVMDialect.h"
+#include "mlir/LLVMIR/Transforms.h"
+#include "mlir/Pass/PassManager.h"
+#include "mlir/Support/LogicalResult.h"
+#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/Passes.h"
+
+#include "llvm/IR/DerivedTypes.h"
+#include "llvm/IR/Module.h"
+#include "llvm/IR/Type.h"
+#include "llvm/Support/Allocator.h"
+#include "llvm/Support/ErrorHandling.h"
+
+#include "linalg1/Common.h"
+#include "linalg1/ConvertToLLVMDialect.h"
+#include "linalg1/RangeOp.h"
+#include "linalg1/RangeType.h"
+#include "linalg1/SliceOp.h"
+#include "linalg1/Types.h"
+#include "linalg1/ViewOp.h"
+#include "linalg1/ViewType.h"
+
+using namespace mlir;
+
+// Convert the given type to the LLVM IR Dialect type.  The following
+// conversions are supported:
+//   - an Index type is converted into an LLVM integer type with pointer
+//     bitwidth (analogous to intptr_t in C);
+//   - an Integer type is converted into an LLVM integer type of the same width;
+//   - an F32 type is converted into an LLVM float type
+//   - a Memref, Range, or View is converted into an LLVM structure type
+//     containing the respective dynamic values.
+LLVM::LLVMType convertType(Type t) {
+  auto *context = t.getContext();
+  auto *dialect =
+      static_cast<LLVM::LLVMDialect *>(context->getRegisteredDialect("llvm"));
+
+  // Simple conversions.
+  if (t.isa<IndexType>()) {
+    int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
+    auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
+    return LLVM::LLVMType::get(context, integerTy);
+  }
+  if (auto intTy = t.dyn_cast<IntegerType>()) {
+    int width = intTy.getWidth();
+    auto *integerTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
+    return LLVM::LLVMType::get(context, integerTy);
+  }
+  if (t.isF32()) {
+    auto *floatTy = llvm::Type::getFloatTy(dialect->getLLVMContext());
+    return LLVM::LLVMType::get(context, floatTy);
+  }
+
+  // Memref descriptor contains the pointer to the data buffer, followed by
+  // as many 64-bit integers as the memref has dynamic sizes.  These integers
+  // store the actual value of the dynamic size.
+  //
+  // template <typename Elem, size_t NumDynamicRanks>
+  // struct {
+  //   Elem *ptr;
+  //   int64_t dynRank_0, dynRank_1, ... dynRank_#NumDynamicRanks
+  // };
+  if (auto memrefTy = t.dyn_cast<MemRefType>()) {
+    auto *elementTy =
+        convertType(memrefTy.getElementType()).getUnderlyingType();
+    if (memrefTy.hasStaticShape())
+      return LLVM::LLVMType::get(context, elementTy->getPointerTo());
+
+    int width = dialect->getLLVMModule().getDataLayout().getPointerSizeInBits();
+    auto *sizeTy = llvm::IntegerType::get(dialect->getLLVMContext(), width);
+    SmallVector<llvm::Type *, 4> types(1 + memrefTy.getNumDynamicDims(),
+                                       sizeTy);
+    types[0] = elementTy->getPointerTo();
+    return LLVM::LLVMType::get(
+        context, llvm::StructType::get(dialect->getLLVMContext(), types));
+  }
+
+  // Range descriptor contains the range bounds and the step as 64-bit integers.
+  //
+  // struct {
+  //   int64_t min;
+  //   int64_t max;
+  //   int64_t step;
+  // };
+  if (auto rangeTy = t.dyn_cast<linalg::RangeType>()) {
+    auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
+    auto *structTy = llvm::StructType::get(int64Ty, int64Ty, int64Ty);
+    return LLVM::LLVMType::get(context, structTy);
+  }
+
+  // View descriptor contains the pointer to the data buffer, followed by a
+  // 64-bit integer containing the distance between the beginning of the buffer
+  // and the first element to be accessed through the view, followed by two
+  // arrays, each containing as many 64-bit integers as the rank of the View.
+  // The first array represents the size, in number of original elements, of the
+  // view along the given dimension.  When taking the view, the size is the
+  // difference between the upper and the lower bound of the range.  The second
+  // array represents the "stride" (in tensor abstraction sense), i.e. the
+  // number of consecutive elements of the underlying buffer that separate two
+  // consecutive elements addressable through the view along the given
+  // dimension.  When taking the view, the strides are constructed as products
+  // of the original sizes along the trailing dimensions, multiplied by the view
+  // step.  For example, a view of a MxN memref with ranges {0:M:1}, {0:N:1},
+  // i.e. the view of a complete memref, will have strides N and 1.  A view with
+  // ranges {0:M:2}, {0:N:3} will have strides 2*N and 3.
+  //
+  // template <typename Elem, size_t Rank>
+  // struct {
+  //   Elem *ptr;
+  //   int64_t offset;
+  //   int64_t sizes[Rank];
+  //   int64_t strides[Rank];
+  // };
+  if (auto viewTy = t.dyn_cast<linalg::ViewType>()) {
+    auto *elemTy = convertType(viewTy.getElementType())
+                       .getUnderlyingType()
+                       ->getPointerTo();
+    auto *int64Ty = llvm::Type::getInt64Ty(dialect->getLLVMContext());
+    auto *arrayTy = llvm::ArrayType::get(int64Ty, viewTy.getRank());
+    auto *structTy = llvm::StructType::get(elemTy, int64Ty, arrayTy, arrayTy);
+    return LLVM::LLVMType::get(context, structTy);
+  }
+
+  llvm_unreachable("unsupported type");
+}
+
+// Create an array attribute containing integer attributes with values provided
+// in `position`.
+static ArrayAttr makePositionAttr(FuncBuilder &builder,
+                                  ArrayRef<int> position) {
+  SmallVector<Attribute, 4> attrs;
+  attrs.reserve(position.size());
+  for (auto p : position)
+    attrs.push_back(builder.getIntegerAttr(builder.getIntegerType(64), p));
+  return builder.getArrayAttr(attrs);
+}
+
+// Expose some LLVM IR instructions to declarative builders.
+namespace intrinsics {
+using undef = edsc::intrinsics::ValueBuilder<LLVM::UndefOp>;
+using insertvalue = edsc::intrinsics::ValueBuilder<LLVM::InsertValueOp>;
+using extractvalue = edsc::intrinsics::ValueBuilder<LLVM::ExtractValueOp>;
+using constant = edsc::intrinsics::ValueBuilder<LLVM::ConstantOp>;
+using add = edsc::intrinsics::ValueBuilder<LLVM::AddOp>;
+using sub = edsc::intrinsics::ValueBuilder<LLVM::SubOp>;
+using mul = edsc::intrinsics::ValueBuilder<LLVM::MulOp>;
+} // end namespace intrinsics
+
+// RangeOp creates a new range descriptor.
+class RangeOpConversion : public DialectOpConversion {
+public:
+  explicit RangeOpConversion(MLIRContext *context)
+      : DialectOpConversion(linalg::RangeOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult match(Operation *op) const override {
+    if (op->isa<linalg::RangeOp>())
+      return matchSuccess();
+    return matchFailure();
+  }
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto rangeOp = op->cast<linalg::RangeOp>();
+    auto rangeDescriptorType = convertType(rangeOp.getResult()->getType());
+
+    using namespace intrinsics;
+    auto context = edsc::ScopedContext(rewriter, op->getLoc());
+
+    // Fill in an aggregate value of the descriptor.
+    Value *rangeDescriptor = undef(rangeDescriptorType);
+    rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
+                                  operands[0], makePositionAttr(rewriter, 0));
+    rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
+                                  operands[1], makePositionAttr(rewriter, 1));
+    rangeDescriptor = insertvalue(rangeDescriptorType, rangeDescriptor,
+                                  operands[2], makePositionAttr(rewriter, 2));
+    return {rangeDescriptor};
+  }
+};
+
+class ViewOpConversion : public DialectOpConversion {
+public:
+  explicit ViewOpConversion(MLIRContext *context)
+      : DialectOpConversion(linalg::ViewOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult match(Operation *op) const override {
+    if (op->isa<linalg::ViewOp>())
+      return matchSuccess();
+    return matchFailure();
+  }
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto viewOp = op->cast<linalg::ViewOp>();
+    auto viewDescriptorType = convertType(viewOp.getViewType());
+    auto memrefType =
+        viewOp.getSupportingMemRef()->getType().cast<MemRefType>();
+    auto int64Ty = convertType(rewriter.getIntegerType(64));
+
+    // Helper function to create an integer array attribute out of a list of
+    // values.
+    auto pos = [&rewriter](ArrayRef<int> values) {
+      return makePositionAttr(rewriter, values);
+    };
+
+    // Helper function to emit an LLVMIR Dialect 64-bit integer constant given
+    // its value.
+    auto i64cst = [&rewriter, int64Ty](int64_t value) {
+      return intrinsics::constant(
+          int64Ty, IntegerAttr::get(rewriter.getIndexType(), value));
+    };
+
+    // Helper function to obtain the size of the given `memref` along the
+    // dimension `dim`.  For static dimensions, emits a constant; for dynamic
+    // dimensions, extracts the size from the memref descriptor.
+    auto memrefSize = [int64Ty, pos, i64cst](MemRefType type, Value *memref,
+                                             int dim) -> Value * {
+      assert(dim < type.getRank());
+      if (type.getShape()[dim] != -1) {
+        return i64cst(type.getShape()[dim]);
+      }
+      int dynamicDimPos = 0;
+      for (int i = 0; i < dim; ++i)
+        if (type.getShape()[i] == -1)
+          ++dynamicDimPos;
+      return intrinsics::extractvalue(int64Ty, memref, pos(1 + dynamicDimPos));
+    };
+
+    // Helper function to obtain the data pointer of the given `memref`.
+    auto memrefPtr = [pos](MemRefType type, Value *memref) -> Value * {
+      if (type.hasStaticShape())
+        return memref;
+
+      auto elementTy = LLVM::LLVMType::get(type.getContext(),
+                                           convertType(type.getElementType())
+                                               .getUnderlyingType()
+                                               ->getPointerTo());
+      return intrinsics::extractvalue(elementTy, memref, pos(0));
+    };
+
+    using namespace intrinsics;
+    auto context = edsc::ScopedContext(rewriter, op->getLoc());
+
+    // Declare the view descriptor.
+    Value *viewDescriptor = undef(viewDescriptorType);
+    // Insert the data pointer.
+    Value *bufferPtr = memrefPtr(memrefType, operands[0]);
+    viewDescriptor =
+        insertvalue(viewDescriptorType, viewDescriptor, bufferPtr, pos(0));
+
+    // Collect all memref sizes but the first, which are needed for further
+    // computation.
+    SmallVector<Value *, 4> trueSizes(memrefType.getRank());
+    for (int i = 1, e = memrefType.getRank(); i < e; ++i) {
+      trueSizes[i] = memrefSize(memrefType, operands[0], i);
+    }
+
+    // Compute all strides of the memref.
+    SmallVector<Value *, 4> trueStrides(memrefType.getRank());
+    if (viewOp.getRank() != 0)
+      trueStrides[memrefType.getRank() - 1] = i64cst(1);
+    for (int i = memrefType.getRank() - 2; i >= 0; --i)
+      trueStrides[i] = mul(trueStrides[i + 1], trueSizes[i + 1]);
+
+    // Compute and insert the base offset.
+    Value *baseOffset = i64cst(0);
+    for (int j = 0, e = memrefType.getRank(); j < e; ++j) {
+      Value *indexing = operands[1 + j];
+      Value *min = viewOp.getIndexing(j)->getType().isa<linalg::RangeType>()
+                       ? (Value *)extractvalue(int64Ty, indexing, pos(0))
+                       : indexing;
+      Value *product = mul(min, trueStrides[j]);
+      baseOffset = add(baseOffset, product);
+    }
+    viewDescriptor =
+        insertvalue(viewDescriptorType, viewDescriptor, baseOffset, pos(1));
+
+    // Compute and insert view sizes (max - min along the range).  Skip the
+    // non-range operands as they will be projected away from the view.
+    int i = 0;
+    for (Value *index : viewOp.getIndexings()) {
+      if (!index->getType().isa<linalg::RangeType>())
+        continue;
+
+      Value *rangeDescriptor = operands[1 + i];
+      Value *min = extractvalue(int64Ty, rangeDescriptor, pos(0));
+      Value *max = extractvalue(int64Ty, rangeDescriptor, pos(1));
+      Value *size = sub(max, min);
+
+      viewDescriptor =
+          insertvalue(viewDescriptorType, viewDescriptor, size, pos({2, i}));
+      ++i;
+    }
+
+    // Compute and insert view strides.  Step over the strides that correspond
+    // to non-range operands as they are projected away from the view.
+    i = 0;
+    for (int j = 0, e = trueStrides.size(); j < e; ++j) {
+      if (!viewOp.getIndexing(j)->getType().isa<linalg::RangeType>())
+        continue;
+      Value *step = extractvalue(int64Ty, operands[1 + j], pos(2));
+      Value *stride = mul(trueStrides[j], step);
+      viewDescriptor =
+          insertvalue(viewDescriptorType, viewDescriptor, stride, pos({3, i}));
+      ++i;
+    }
+
+    return {viewDescriptor};
+  }
+};
+
+class SliceOpConversion : public DialectOpConversion {
+public:
+  explicit SliceOpConversion(MLIRContext *context)
+      : DialectOpConversion(linalg::SliceOp::getOperationName(), 1, context) {}
+
+  PatternMatchResult match(Operation *op) const override {
+    if (op->isa<linalg::SliceOp>())
+      return matchSuccess();
+    return matchFailure();
+  }
+
+  SmallVector<Value *, 4> rewrite(Operation *op, ArrayRef<Value *> operands,
+                                  FuncBuilder &rewriter) const override {
+    auto sliceOp = op->cast<linalg::SliceOp>();
+    auto newViewDescriptorType = convertType(sliceOp.getViewType());
+    auto elementType =
+        rewriter.getType<LLVM::LLVMType>(convertType(sliceOp.getElementType())
+                                             .getUnderlyingType()
+                                             ->getPointerTo());
+    auto int64Ty = convertType(rewriter.getIntegerType(64));
+
+    auto pos = [&rewriter](ArrayRef<int> values) {
+      return makePositionAttr(rewriter, values);
+    };
+
+    // First operand to `slice` is the old view descriptor.
+    Value *oldViewDescriptor = operands[0];
+
+    // Properties of the slice.
+    bool isRankDecreasing = sliceOp.isRankDecreasing();
+    int dim = sliceOp.getSlicingDim();
+    assert(isRankDecreasing ^
+           sliceOp.getIndexing()->getType().isa<linalg::RangeType>());
+
+    // Declare the descriptor of the new view.
+    using namespace intrinsics;
+    auto edscContext = edsc::ScopedContext(rewriter, op->getLoc());
+    Value *newViewDescriptor = undef(newViewDescriptorType);
+
+    // Copy the buffer pointer from the old descriptor to the new one.
+    Value *buffer = extractvalue(elementType, oldViewDescriptor, pos(0));
+    newViewDescriptor =
+        insertvalue(newViewDescriptorType, newViewDescriptor, buffer, pos(0));
+
+    // Update the base offset:
+    //   base_offset' = base_offset + min_d * stride_d
+    // where d is the dimension being sliced, min_d is the minimum value of the
+    // range (in case of a single-value slice, that value), stride_d is the
+    // stride along this dimension.
+    Value *baseOffset = extractvalue(int64Ty, oldViewDescriptor, pos(1));
+    Value *slicingValue = operands[1];
+    // If `slice` is not rank-decreasing, we need to extract the "min" value
+    // from the range descriptor.  Otherwise, we take the value directly.
+    Value *min = !isRankDecreasing
+                     ? (Value *)extractvalue(int64Ty, slicingValue, pos(0))
+                     : slicingValue;
+    Value *stride = extractvalue(int64Ty, oldViewDescriptor, pos({3, dim}));
+    baseOffset = add(baseOffset, mul(min, stride));
+    newViewDescriptor = insertvalue(newViewDescriptorType, newViewDescriptor,
+                                    baseOffset, pos(1));
+
+    // Copy the sizes and strides into the new descriptor, updating or dropping
+    // the affected dimension.  If the `slice` is rank-decreasing, the resulting
+    // view will no longer one of the dimensions, its size and stride become
+    // unnecessary and can be dropped.  Otherwise, the size of the affected
+    // updated to the size of the range and its stride is multiplied with the
+    // step of the range.
+    for (int i = 0, e = sliceOp.getRank(); i < e; ++i) {
+      int originalPos = (isRankDecreasing && i >= dim) ? i + 1 : i;
+      Value *size;
+      Value *stride;
+      if (!isRankDecreasing && i == dim) {
+        Value *upper = extractvalue(int64Ty, slicingValue, pos(1));
+        Value *lower = extractvalue(int64Ty, slicingValue, pos(0));
+        size = sub(upper, lower);
+
+        Value *previousStride =
+            extractvalue(int64Ty, oldViewDescriptor, pos({3, originalPos}));
+        Value *step = extractvalue(int64Ty, slicingValue, pos(2));
+        stride = mul(previousStride, step);
+      } else {
+        size = extractvalue(int64Ty, oldViewDescriptor, pos({2, originalPos}));
+        stride =
+            extractvalue(int64Ty, oldViewDescriptor, pos({3, originalPos}));
+      }
+      newViewDescriptor = insertvalue(newViewDescriptorType, newViewDescriptor,
+                                      size, pos({2, i}));
+      newViewDescriptor = insertvalue(newViewDescriptorType, newViewDescriptor,
+                                      stride, pos({3, i}));
+    }
+
+    return {newViewDescriptor};
+  }
+};
+
+// When converting the "some_consumer" operation, don't emit anything and
+// effectively drop it.
+class DropConsumer : public DialectOpConversion {
+public:
+  explicit DropConsumer(MLIRContext *context)
+      : DialectOpConversion("some_consumer", 1, context) {}
+
+  PatternMatchResult match(Operation *op) const override {
+    if (op->getName().getStringRef() == "some_consumer")
+      return matchSuccess();
+    return matchFailure();
+  }
+
+  SmallVector<Value *, 4> rewrite(Operation *, ArrayRef<Value *>,
+                                  FuncBuilder &) const override {
+    return {};
+  }
+};
+
+// The conversion class from Linalg to LLVMIR.
+class Lowering : public DialectConversion {
+public:
+  Lowering() {}
+
+protected:
+  // Initialize the list of converters.
+  llvm::DenseSet<DialectOpConversion *>
+  initConverters(MLIRContext *context) override {
+    converterSotrage.Reset();
+    return ConversionListBuilder<DropConsumer, RangeOpConversion,
+                                 SliceOpConversion,
+                                 ViewOpConversion>::build(&converterSotrage,
+                                                          context);
+  }
+
+  // This gets called for block and region arguments, and attributes.
+  Type convertType(Type t) override { return ::convertType(t); }
+
+  // This gets called for function signatures.  Convert function arguments and
+  // results to the LLVM types, but keep the outer function type as built-in
+  // MLIR function type.  This does not support multi-result functions because
+  // LLVM does not.
+  FunctionType convertFunctionSignatureType(
+      FunctionType t, ArrayRef<NamedAttributeList> argAttrs,
+      SmallVectorImpl<NamedAttributeList> &convertedArgAttrs) override {
+    convertedArgAttrs.reserve(argAttrs.size());
+    convertedArgAttrs.insert(convertedArgAttrs.end(), argAttrs.begin(),
+                             argAttrs.end());
+
+    SmallVector<Type, 4> argTypes;
+    argTypes.reserve(t.getNumInputs());
+    for (auto ty : t.getInputs())
+      argTypes.push_back(convertType(ty));
+
+    SmallVector<Type, 1> resultTypes;
+    resultTypes.reserve(t.getNumResults());
+    for (auto ty : t.getResults())
+      resultTypes.push_back(convertType(ty));
+    assert(t.getNumResults() <= 1 && "NYI: multi-result functions");
+
+    return FunctionType::get(argTypes, resultTypes, t.getContext());
+  }
+
+private:
+  // Storage for individual converters.
+  llvm::BumpPtrAllocator converterSotrage;
+};
+
+void linalg::convertToLLVM(mlir::Module &module) {
+  // Remove affine constructs if any by using an existing pass.
+  PassManager pm;
+  pm.addPass(createLowerAffinePass());
+  auto rr = pm.run(&module);
+  (void)rr;
+  assert(succeeded(rr) && "affine loop lowering failed");
+
+  // Convert Linalg ops to the LLVM IR dialect using the converter defined
+  // above.
+  auto r = Lowering().convert(&module);
+  (void)r;
+  assert(succeeded(r) && "conversion failed");
+
+  // Convert the remaining standard MLIR operations to the LLVM IR dialect using
+  // the default converter.
+  auto converter = createStdToLLVMConverter();
+  r = converter->convert(&module);
+  (void)r;
+  assert(succeeded(r) && "second conversion failed");
+}
index 4383743..a3bdcee 100644 (file)
@@ -65,8 +65,7 @@ mlir::LogicalResult linalg::SliceOp::verify() {
     return emitOpError(
         "first operand must be of ViewType (i.e. a ViewOp or a SliceOp)");
   auto type = getOperand(1)->getType().dyn_cast<IndexType>();
-  auto *op = getOperand(1)->getDefiningOp();
-  auto range = op ? op->dyn_cast<RangeOp>() : RangeOp();
+  auto range = getOperand(1)->getType().dyn_cast<RangeType>();
   if (!range && !type)
     return emitOpError(
         "second operand must be of RangeType (i.e. a RangeOp) or IndexType");
@@ -126,6 +125,10 @@ mlir::Type linalg::SliceOp::getParentElementType() {
   return getParentViewType().getElementType();
 }
 
+bool linalg::SliceOp::isRankDecreasing() {
+  return getParentRank() != getRank();
+}
+
 mlir::Operation::operand_range linalg::SliceOp::getIndexings() {
   return {this->getOperation()->operand_begin() + SliceOp::FirstIndexingOperand,
           this->getOperation()->operand_end()};