Post commit cleanups to the Linalg dialect
authorNicolas Vasilache <ntv@google.com>
Fri, 5 Apr 2019 20:20:05 +0000 (13:20 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 8 Apr 2019 01:20:19 +0000 (18:20 -0700)
--

PiperOrigin-RevId: 242181687

20 files changed:
mlir/examples/Linalg/Linalg1/Conversion.cpp
mlir/examples/Linalg/Linalg1/Example.cpp
mlir/examples/Linalg/Linalg1/include/linalg1/Common.h
mlir/examples/Linalg/Linalg1/include/linalg1/Utils.h
mlir/examples/Linalg/Linalg1/include/linalg1/ViewOp.h
mlir/examples/Linalg/Linalg1/lib/Dialect.cpp
mlir/examples/Linalg/Linalg1/lib/DialectConstruction.cpp [moved from mlir/examples/Linalg/Linalg1/lib/DialectRegistration.cpp with 70% similarity]
mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp
mlir/examples/Linalg/Linalg1/lib/Utils.cpp
mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp
mlir/examples/Linalg/Linalg2/Example.cpp
mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h
mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h
mlir/examples/Linalg/Linalg2/lib/DialectConstruction.cpp [moved from mlir/examples/Linalg/Linalg2/lib/DialectRegistration.cpp with 70% similarity]
mlir/examples/Linalg/Linalg3/Conversion.cpp
mlir/examples/Linalg/Linalg3/Example.cpp
mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h
mlir/examples/Linalg/Linalg3/lib/DialectConstruction.cpp [moved from mlir/examples/Linalg/Linalg3/lib/DialectRegistration.cpp with 71% similarity]
mlir/examples/Linalg/Linalg4/Example.cpp
mlir/examples/Linalg/Linalg4/lib/Transforms.cpp

index 343bf0d..21a9645 100644 (file)
@@ -21,6 +21,7 @@
 
 #include "linalg1/Common.h"
 #include "linalg1/ConvertToLLVMDialect.h"
+#include "linalg1/Dialect.h"
 #include "linalg1/Intrinsics.h"
 #include "linalg1/Ops.h"
 #include "linalg1/Types.h"
@@ -302,6 +303,7 @@ TEST_FUNC(sliceNonRangeConversion) {
 }
 
 int main() {
+  mlir::registerDialect<linalg::LinalgDialect>();
   RUN_TESTS();
   return 0;
 }
index 5bbd8e4..c5fdda8 100644 (file)
 #include "TestHarness.h"
 
 #include "linalg1/Common.h"
+#include "linalg1/Dialect.h"
 #include "linalg1/Intrinsics.h"
 #include "linalg1/Ops.h"
 #include "linalg1/Types.h"
+#include "linalg1/Utils.h"
 #include "mlir/IR/Function.h"
 
 using namespace linalg;
@@ -47,19 +49,19 @@ TEST_FUNC(view_op) {
   // clang-format off
   ValueHandle M(f->getArgument(0)), N(f->getArgument(1)),
     A0 = alloc(floatMemRefType<0>(&context)),
-    A1 = alloc(floatMemRefType<1>(&context), ArrayRef<ValueHandle>{M}),
-    A2 = alloc(floatMemRefType<2>(&context), ArrayRef<ValueHandle>{M, N}),
+    A1 = alloc(floatMemRefType<1>(&context), {M}),
+    A2 = alloc(floatMemRefType<2>(&context), {M, N}),
     r0 = range(constant_index(3), constant_index(17), constant_index(1)),
-    v0 = view(A0),
-    v1 = view(A1, ArrayRef<ValueHandle>{r0}),
-    v2 = view(A2, ArrayRef<ValueHandle>{r0, r0});
-  some_consumer(ArrayRef<ValueHandle>{v0, v1, v2});
+    v0 = view(A0, {}),
+    v1 = view(A1, {r0}),
+    v2 = view(A2, {r0, r0});
+  some_consumer({v0, v1, v2});
   ret();
   // CHECK-LABEL: func @view_op
   //       CHECK:   %[[R:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range">
-  //  CHECK-NEXT:  {{.*}} = linalg.view {{.*}}[] : !linalg<"view<0xf32>">
-  //  CHECK-NEXT:  {{.*}} = linalg.view {{.*}}[%[[R]]] : !linalg<"view<f32>">
-  //  CHECK-NEXT:  {{.*}} = linalg.view {{.*}}[%[[R]], %[[R]]] : !linalg<"view<f32xf32>">
+  //  CHECK-NEXT:  {{.*}} = linalg.view {{.*}}[] : !linalg<"view<f32>">
+  //  CHECK-NEXT:  {{.*}} = linalg.view {{.*}}[%[[R]]] : !linalg<"view<?xf32>">
+  //  CHECK-NEXT:  {{.*}} = linalg.view {{.*}}[%[[R]], %[[R]]] : !linalg<"view<?x?xf32>">
   // clang-format on
 
   cleanupAndPrintFunction(f);
@@ -79,10 +81,8 @@ TEST_FUNC(slice_op) {
 
   // clang-format off
   ValueHandle M(f->getArgument(0)), N(f->getArgument(1)),
-      A = alloc(floatMemRefType<2>(&context), {M, N}),
-      r1 = range(constant_index(3), constant_index(17), constant_index(1)),
-      r2 = range(constant_index(0), N, constant_index(1));
-  ViewOp vA = view(A, {r1, r2}).getValue()->getDefiningOp()->cast<ViewOp>();
+      A = alloc(floatMemRefType<2>(&context), {M, N});
+  ViewOp vA = emitAndReturnViewOpFromMemRef(A);
   IndexHandle i, j;
   LoopNestRangeBuilder({&i, &j}, vA.getRanges())({
     some_consumer(slice(vA, i, 1)),
@@ -91,22 +91,25 @@ TEST_FUNC(slice_op) {
   ret();
   // CHECK-LABEL: func @slice_op(%arg0: index, %arg1: index, %arg2: index) {
   //       CHECK: %[[ALLOC:.*]] = alloc(%arg0, %arg1) : memref<?x?xf32>
-  //  CHECK-NEXT: %[[R1:.*]] = linalg.range {{.*}}:{{.*}}:{{.*}} : !linalg<"range">
-  //  CHECK-NEXT: %[[R2:.*]] = linalg.range {{.*}}:%arg1:{{.*}} : !linalg<"range">
-  //  CHECK-NEXT: %[[V:.*]] = linalg.view %0[%[[R1]], %[[R2]]] : !linalg<"view<f32xf32>">
-  //  CHECK-NEXT: for %i0 = 3 to 17 {
-  //  CHECK-NEXT:   for %i1 = 0 to (d0) -> (d0)(%arg1) {
-  //  CHECK-NEXT:     %[[S1:.*]] = linalg.slice %[[V]][*, %i0] { dim : 1 } : !linalg<"view<f32>">
-  //  CHECK-NEXT:     "some_consumer"(%[[S1]]) : (!linalg<"view<f32>">) -> ()
-  //  CHECK-NEXT:     %[[S2:.*]] = linalg.slice %[[V]][%i1, *] { dim : 0 } : !linalg<"view<f32>">
-  //  CHECK-NEXT:     %[[S3:.*]] = linalg.slice %[[S2]][%i0] { dim : 0 } : !linalg<"view<0xf32>">
-  //  CHECK-NEXT:     "some_consumer"(%[[S3]]) : (!linalg<"view<0xf32>">) -> ()
+  //  CHECK-NEXT: %[[M:.*]] = dim %0, 0 : memref<?x?xf32>
+  //  CHECK-NEXT: %[[N:.*]] = dim %0, 1 : memref<?x?xf32>
+  //  CHECK-NEXT: %[[R1:.*]] = linalg.range {{.*}}:%[[M]]:{{.*}} : !linalg<"range">
+  //  CHECK-NEXT: %[[R2:.*]] = linalg.range {{.*}}:%[[N]]:{{.*}} : !linalg<"range">
+  //  CHECK-NEXT: %[[V:.*]] = linalg.view %0[%[[R1]], %[[R2]]] : !linalg<"view<?x?xf32>">
+  //  CHECK-NEXT: for %i0 = 0 to (d0) -> (d0)(%[[M]]) {
+  //  CHECK-NEXT:   for %i1 = 0 to (d0) -> (d0)(%[[N]]) {
+  //  CHECK-NEXT:     %[[S1:.*]] = linalg.slice %[[V]][*, %i0]  : !linalg<"view<?xf32>">
+  //  CHECK-NEXT:     "some_consumer"(%[[S1]]) : (!linalg<"view<?xf32>">) -> ()
+  //  CHECK-NEXT:     %[[S2:.*]] = linalg.slice %[[V]][%i1, *]  : !linalg<"view<?xf32>">
+  //  CHECK-NEXT:     %[[S3:.*]] = linalg.slice %[[S2]][%i0]  : !linalg<"view<f32>">
+  //  CHECK-NEXT:     "some_consumer"(%[[S3]]) : (!linalg<"view<f32>">) -> ()
   // clang-format on
 
   cleanupAndPrintFunction(f);
 }
 
 int main() {
+  mlir::registerDialect<linalg::LinalgDialect>();
   RUN_TESTS();
   return 0;
 }
index f4c50f2..6573c72 100644 (file)
@@ -49,8 +49,8 @@ namespace common {
 /// A 2-D abstraction over a flat contiguous memory region of f32 with symbolic
 /// sizes.
 template <int N>
-inline mlir::MemRefType floatMemRefType(
-    mlir::MLIRContext *context, unsigned memorySpace = 0) {
+inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context,
+                                        unsigned memorySpace = 0) {
   llvm::SmallVector<int64_t, 4> shape(N, -1);
   auto f32 = mlir::FloatType::getF32(context);
   return mlir::MemRefType::get(shape, f32, {}, memorySpace);
@@ -70,16 +70,12 @@ inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name,
 }
 
 /// A basic pass manager pre-populated with cleanup passes.
-inline mlir::PassManager &cleanupPassManager() {
-  static bool inited = false;
-  static mlir::PassManager pm;
-  if (!inited) {
-    pm.addPass(mlir::createCanonicalizerPass());
-    pm.addPass(mlir::createSimplifyAffineStructuresPass());
-    pm.addPass(mlir::createCSEPass());
-    pm.addPass(mlir::createCanonicalizerPass());
-    inited = true;
-  }
+inline std::unique_ptr<mlir::PassManager> cleanupPassManager() {
+  std::unique_ptr<mlir::PassManager> pm(new mlir::PassManager());
+  pm->addPass(mlir::createCanonicalizerPass());
+  pm->addPass(mlir::createSimplifyAffineStructuresPass());
+  pm->addPass(mlir::createCSEPass());
+  pm->addPass(mlir::createCanonicalizerPass());
   return pm;
 }
 
@@ -91,13 +87,14 @@ inline void cleanupAndPrintFunction(mlir::Function *f) {
   bool printToOuts = true;
   auto check = [f, &printToOuts](mlir::LogicalResult result) {
     if (failed(result)) {
-      f->dump();
-      llvm::errs() << "Failure!\n";
+      f->getContext()->emitError(f->getLoc(),
+                                 "Verification and cleanup passes failed");
       printToOuts = false;
     }
   };
+  auto pm = cleanupPassManager();
   check(f->getModule()->verify());
-  check(cleanupPassManager().run(f->getModule()));
+  check(pm->run(f->getModule()));
   if (printToOuts)
     f->print(llvm::outs());
 }
index cb6b285..3f7bb76 100644 (file)
@@ -23,10 +23,15 @@ class Value;
 } // namespace mlir
 
 namespace linalg {
+class ViewOp;
 
 /// Asserts `view` is of ViewType and returns its rank.
 unsigned getViewRank(mlir::Value *view);
 
+/// Helper function to emit and return a new ViewOp from `memRef` that is
+/// assumed to be of MemRefType. This needs to be called under a ScopedContext.
+ViewOp emitAndReturnViewOpFromMemRef(mlir::Value *memRef);
+
 } // namespace linalg
 
 #endif // LINALG1_UTILS_H_
index fec3396..fcda553 100644 (file)
@@ -40,7 +40,7 @@ public:
   static llvm::StringRef getOperationName() { return "linalg.view"; }
   static void build(mlir::Builder *b, mlir::OperationState *result,
                     mlir::Value *memRef,
-                    llvm::ArrayRef<mlir::Value *> indexings = {});
+                    llvm::ArrayRef<mlir::Value *> indexings);
   mlir::LogicalResult verify();
   static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result);
   void print(mlir::OpAsmPrinter *p);
index b6dfeb3..7b1dca4 100644 (file)
@@ -44,13 +44,13 @@ static void print(RangeType rt, raw_ostream &os) { os << "range"; }
 /// ViewType prints as:
 ///
 /// ```{.mlir}
-///   view<i8xf32xi1>
+///   view<?x?xf32>
 /// ```
 ///
 /// or
 ///
 /// ```{.mlir}
-///   view<0xf32>
+///   view<?xf32>
 /// ```
 ///
 /// for 0-D views (a.k.a pointer to a scalar value).
@@ -58,11 +58,10 @@ static void print(linalg::ViewType rt, raw_ostream &os) {
   os << "view<";
   if (rt.getRank() > 0) {
     for (unsigned i = 0, e = rt.getRank(); i < e; ++i) {
-      os << rt.getElementType() << ((i == e - 1) ? "" : "x");
+      os << "?x";
     }
-  } else {
-    os << "0x" << rt.getElementType();
   }
+  os << rt.getElementType();
   os << ">";
 }
 
@@ -1,4 +1,4 @@
-//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===//
+//===- DialectConstruction.cpp - Construction of the Linalg dialect -------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -15,9 +15,9 @@
 // limitations under the License.
 // =============================================================================
 //
-// This file registers the Linalg dialect and should live in a standalone
-// library. Linking with this library will create a static global object that
-// performs dialect registration.
+// This file implements the constructor for the Linalg Dialect. This is
+// explicitly separated from the core library to allow incremental buildup of
+// the codebase for the tutorial.
 //
 //===----------------------------------------------------------------------===//
 
@@ -33,7 +33,3 @@ LinalgDialect::LinalgDialect(MLIRContext *context)
   addTypes<RangeType, ViewType>();
   addOperations<RangeOp, SliceOp, ViewOp>();
 }
-
-// Dialect registration triggers the creation of a `LinalgDialect` object which
-// adds the proper types and operations to the dialect.
-static mlir::DialectRegistration<LinalgDialect> LinalgOps;
index 0b68f20..4f30fb1 100644 (file)
@@ -80,10 +80,10 @@ bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) {
 // A SliceOp prints as:
 //
 // ```{.mlir}
-//   linalg.slice %0[*, %i0] { dim : 1 } : !linalg<"view<f32>">
+//   linalg.slice %0[*, %i0]  : !linalg<"view<?xf32>">
 // ```
 //
-// Where %0 is an ssa-value holding a `view<f32xf32>`, %i0 is an ssa-value
+// Where %0 is an ssa-value holding a `view<?x?xf32>`, %i0 is an ssa-value
 // holding an index.
 void linalg::SliceOp::print(OpAsmPrinter *p) {
   unsigned dim = getSlicingDim();
@@ -101,8 +101,7 @@ void linalg::SliceOp::print(OpAsmPrinter *p) {
     }
     *p << ((idx == rank - 1) ? "" : ", ");
   }
-  *p << "] { " << getSlicingDimAttrName() << " : " << dim << " }"
-     << " : " << getViewType();
+  *p << "] : " << getViewType();
 }
 
 ViewType linalg::SliceOp::getViewType() { return getType().cast<ViewType>(); }
index f81930a..372c08f 100644 (file)
 //===----------------------------------------------------------------------===//
 
 #include "linalg1/Utils.h"
+#include "linalg1/Intrinsics.h"
 #include "linalg1/Ops.h"
+#include "mlir/EDSC/Helpers.h"
 #include "mlir/IR/StandardTypes.h"
 
 using namespace mlir;
+using namespace mlir::edsc;
+using namespace mlir::edsc::intrinsics;
 using namespace linalg;
+using namespace linalg::intrinsics;
 
 unsigned linalg::getViewRank(Value *view) {
   assert(view->getType().isa<ViewType>() && "expected a ViewType");
@@ -32,3 +37,15 @@ unsigned linalg::getViewRank(Value *view) {
     return viewOp.getRank();
   return view->getDefiningOp()->cast<SliceOp>().getRank();
 }
+
+ViewOp linalg::emitAndReturnViewOpFromMemRef(Value *memRef) {
+  // Syntactic sugar helper to extract and emit view-like information from an
+  // mlir::MemRef without boilerplate.
+  mlir::edsc::MemRefView v(memRef);
+  SmallVector<Value *, 8> indices(v.rank());
+  for (unsigned i = 0; i < v.rank(); ++i) {
+    indices[i] = range(v.lb(i), v.ub(i), constant_index(v.step(i)));
+  }
+  return ScopedContext::getBuilder()->create<ViewOp>(
+      ScopedContext::getLocation(), memRef, indices);
+}
index 97779f4..6564391 100644 (file)
@@ -97,7 +97,7 @@ bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) {
 // A ViewOp prints as:
 //
 // ```{.mlir}
-//   linalg.view %0[%1, %2] : !linalg<"view<f32xf32>">
+//   linalg.view %0[%1, %2] : !linalg<"view<?x?xf32>">
 // ```
 //
 // Where %0 is an ssa-value holding a MemRef, %1 and %2 are ssa-value each
index 717de06..c239cf3 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "TestHarness.h"
 #include "linalg1/Common.h"
+#include "linalg1/Dialect.h"
 #include "linalg2/Intrinsics.h"
 #include "linalg2/Ops.h"
 #include "linalg2/Transforms.h"
@@ -57,13 +58,13 @@ TEST_FUNC(linalg_ops) {
   dot(sA, sB, ssC);
   ret();
   // CHECK-LABEL: func @linalg_ops(%arg0: index, %arg1: index, %arg2: index) {
-  //       CHECK: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] { dim : 1 } : !linalg<"view<f32>">
-  //  CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] { dim : 1 } : !linalg<"view<f32>">
-  //  CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}, *] { dim : 0 } : !linalg<"view<f32>">
-  //  CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] { dim : 0 } : !linalg<"view<0xf32>">
-  //  CHECK-NEXT: linalg.matmul {{{.*}}, {{.*}}} -> {{{.*}}}
-  //  CHECK-NEXT: linalg.matvec {{{.*}}, {{.*}}} -> {{{.*}}}
-  //  CHECK-NEXT: linalg.dot {{{.*}}, {{.*}}} -> {{{.*}}}
+  //       CHECK: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] : !linalg<"view<?xf32>">
+  //  CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] : !linalg<"view<?xf32>">
+  //  CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}, *] : !linalg<"view<?xf32>">
+  //  CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}]  : !linalg<"view<f32>">
+  //       CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg<"view<?x?xf32>">
+  //  CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg<"view<?xf32>">
+  //  CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg<"view<f32>">
   // clang-format on
 
   cleanupAndPrintFunction(f);
@@ -96,9 +97,9 @@ TEST_FUNC(linalg_ops_folded_slices) {
   ret();
   // CHECK-LABEL: func @linalg_ops_folded_slices(%arg0: index, %arg1: index, %arg2: index) {
   //   CHECK-NOT: linalg.slice
-  //       CHECK: linalg.matmul {{{.*}}, {{.*}}} -> {{{.*}}}
-  //  CHECK-NEXT: linalg.matvec {{{.*}}, {{.*}}} -> {{{.*}}}
-  //  CHECK-NEXT: linalg.dot {{{.*}}, {{.*}}} -> {{{.*}}}
+  //       CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg<"view<?x?xf32>">
+  //  CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg<"view<?xf32>">
+  //  CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg<"view<f32>">
   // clang-format on
 
   f->walk<SliceOp>([](SliceOp slice) {
@@ -111,6 +112,7 @@ TEST_FUNC(linalg_ops_folded_slices) {
   cleanupAndPrintFunction(f);
 }
 int main() {
+  mlir::registerDialect<linalg::LinalgDialect>();
   RUN_TESTS();
   return 0;
 }
index 1a08e9e..940f8d7 100644 (file)
@@ -45,6 +45,12 @@ linalg::TensorContractionBase<ConcreteOp>::getOutputs() {
 }
 
 template <class ConcreteOp>
+mlir::Operation::operand_range
+linalg::TensorContractionBase<ConcreteOp>::getInputsAndOutputs() {
+  return {getInputs().begin(), getOutputs().end()};
+}
+
+template <class ConcreteOp>
 mlir::LogicalResult linalg::TensorContractionBase<ConcreteOp>::verify() {
   auto *concreteOp = static_cast<ConcreteOp *>(this)->getOperation();
   if (getNumInputs() <= 0)
@@ -85,21 +91,27 @@ bool linalg::TensorContractionBase<ConcreteOp>::parse(
 // A TensorContraction prints as:
 //
 // ```{.mlir}
-//   concrete_op_name {%0, %1} -> {%2}
+//   concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types
+// ```
+//
+// for example:
+//
+// ```
+//   linalg.matmul(%0, %1, %2) : view<?x?xf32>
 // ```
 //
-// Where %0, %1 is an ssa-value holding a View, %2 is an ssa-value holding a
-// view.
+// Where %0, %1 and %2 are ssa-values of type ViewType.
 template <class ConcreteOp>
 void linalg::TensorContractionBase<ConcreteOp>::print(mlir::OpAsmPrinter *p) {
-  *p << static_cast<ConcreteOp *>(this)->getOperationName() << " {";
-  auto *lastInput = *std::prev(getInputs().end());
-  for (auto *i : getInputs()) {
-    *p << *i << ((i == lastInput) ? "} -> {" : ", ");
+  *p << static_cast<ConcreteOp *>(this)->getOperationName() << "(";
+  auto *last = *std::prev(getInputsAndOutputs().end());
+  for (auto *i : getInputsAndOutputs()) {
+    *p << *i << ((i == last) ? "" : ", ");
   }
+  *p << ") : ";
   auto *lastOutput = *std::prev(getOutputs().end());
   for (auto *o : getOutputs()) {
-    *p << *o << ((o == lastOutput) ? "}" : ",");
+    *p << o->getType() << ((o == lastOutput) ? "" : ",");
   }
 }
 
index b8d9f8f..39e51f0 100644 (file)
@@ -48,6 +48,7 @@ public:
   TensorContractionBase() = default;
   mlir::Operation::operand_range getInputs();
   mlir::Operation::operand_range getOutputs();
+  mlir::Operation::operand_range getInputsAndOutputs();
 
   /// These are better as methods calling into the ConcreteOp instead of
   /// template parameters because methods allow more generic behavior and avoid
@@ -1,4 +1,4 @@
-//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===//
+//===- DialectConstruction.cpp - Construction of the Linalg dialect -------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -15,9 +15,9 @@
 // limitations under the License.
 // =============================================================================
 //
-// This file registers the Linalg dialect and should live in a standalone
-// library. Linking with this library will create a static global object that
-// performs dialect registration.
+// This file implements the constructor for the Linalg Dialect. This is
+// explicitly separated from the core library to allow incremental buildup of
+// the codebase for the tutorial.
 //
 //===----------------------------------------------------------------------===//
 
@@ -31,7 +31,3 @@ LinalgDialect::LinalgDialect(mlir::MLIRContext *context)
   addTypes<RangeType, ViewType>();
   addOperations<DotOp, MatvecOp, MatmulOp, RangeOp, SliceOp, ViewOp>();
 }
-
-// Dialect registration triggers the creation of a `LinalgDialect` object which
-// adds the proper types and operations to the dialect.
-static mlir::DialectRegistration<LinalgDialect> LinalgOps;
index bdb9d69..1f0f001 100644 (file)
@@ -22,6 +22,7 @@
 #include "linalg3/ConvertToLLVMDialect.h"
 
 #include "linalg1/Common.h"
+#include "linalg1/Dialect.h"
 #include "linalg2/Intrinsics.h"
 #include "linalg3/Ops.h"
 #include "linalg3/Transforms.h"
@@ -106,6 +107,7 @@ TEST_FUNC(foo) {
 }
 
 int main() {
+  mlir::registerDialect<linalg::LinalgDialect>();
   RUN_TESTS();
   return 0;
 }
index 5f5760e..e92199c 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "TestHarness.h"
 #include "linalg1/Common.h"
+#include "linalg1/Dialect.h"
 #include "linalg2/Intrinsics.h"
 #include "linalg3/Ops.h"
 #include "linalg3/Transforms.h"
@@ -68,11 +69,11 @@ TEST_FUNC(matmul_as_matvec) {
   // clang-format off
   // CHECK-LABEL: func @matmul_as_matvec(%arg0: memref<?x?xf32>, %arg1: memref<?x?xf32>, %arg2: memref<?x?xf32>) {
   //       CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
-  //       CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32xf32>">
+  //       CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<?x?xf32>">
   //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
-  //       CHECK:   %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
-  //       CHECK:   %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
-  //       CHECK:   linalg.matvec {%[[vA]], %[[vB]]} -> {%[[vC]]}
+  //       CHECK:   %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<?xf32>">
+  //       CHECK:   %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<?xf32>">
+  //       CHECK:   linalg.matvec(%[[vA]], %[[vB]], %[[vC]]) : !linalg<"view<?xf32>">
   // clang-format on
   cleanupAndPrintFunction(f);
 }
@@ -89,11 +90,11 @@ TEST_FUNC(matmul_as_dot) {
   //       CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
   //       CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
   //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
-  //       CHECK:   %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
+  //       CHECK:   %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view<?xf32>">
   //  CHECK-NEXT:   affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
-  //       CHECK:     %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
-  //  CHECK-NEXT:     %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<0xf32>">
-  //  CHECK-NEXT:     linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]}
+  //       CHECK:     %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view<?xf32>">
+  //  CHECK-NEXT:     %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<f32>">
+  //  CHECK-NEXT:     linalg.dot(%[[vA]], %[[vB]], %[[vC]]) : !linalg<"view<f32>">
   // clang-format on
   cleanupAndPrintFunction(f);
 }
@@ -112,20 +113,20 @@ TEST_FUNC(matmul_as_loops) {
   //       CHECK: %[[rM:.*]] = linalg.range %c0:%[[M]]:%c1 : !linalg<"range">
   //       CHECK: %[[rN:.*]] = linalg.range %c0:%[[N]]:%c1 : !linalg<"range">
   //       CHECK: %[[rK:.*]] = linalg.range %c0:%[[K]]:%c1 : !linalg<"range">
-  //       CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view<f32xf32>">
-  //       CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view<f32xf32>">
-  //       CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view<f32xf32>">
+  //       CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view<?x?xf32>">
+  //       CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view<?x?xf32>">
+  //       CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view<?x?xf32>">
   //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) {
   //       CHECK:   affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) {
   //       CHECK:     affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
   //       CHECK:       %{{.*}} = cmpi "eq", %{{.*}} : index
-  //       CHECK:       %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view<f32xf32>">
+  //       CHECK:       %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view<?x?xf32>">
   //       CHECK:       %{{.*}} = select {{.*}} : f32
-  //       CHECK:       %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view<f32xf32>">
-  //       CHECK:       %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view<f32xf32>">
+  //       CHECK:       %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view<?x?xf32>">
+  //       CHECK:       %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view<?x?xf32>">
   //       CHECK:       %{{.*}} = mulf {{.*}} : f32
   //       CHECK:       %{{.*}} = addf {{.*}} : f32
-  //       CHECK:       linalg.store {{.*}}[%i0, %i1] : !linalg<"view<f32xf32>">
+  //       CHECK:       linalg.store {{.*}}[%i0, %i1] : !linalg<"view<?x?xf32>">
   // clang-format on
   cleanupAndPrintFunction(f);
 }
@@ -143,20 +144,20 @@ TEST_FUNC(matmul_as_matvec_as_loops) {
   //       CHECK: %[[M:.*]] = dim %arg0, 0 : memref<?x?xf32>
   //       CHECK: %[[N:.*]] = dim %arg2, 1 : memref<?x?xf32>
   //       CHECK: %[[K:.*]] = dim %arg0, 1 : memref<?x?xf32>
-  //       CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view<f32xf32>">
+  //       CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view<?x?xf32>">
   //       CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) {
-  //       CHECK:   %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view<f32>">
-  //       CHECK:   %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view<f32>">
+  //       CHECK:   %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view<?xf32>">
+  //       CHECK:   %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view<?xf32>">
   //       CHECK:   affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) {
   //       CHECK:     affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) {
   //       CHECK:        %{{.*}} = cmpi "eq", %i2, %{{.*}} : index
-  //       CHECK:        %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view<f32>">
+  //       CHECK:        %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view<?xf32>">
   //       CHECK:        %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32
-  //       CHECK:        %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view<f32>">
-  //       CHECK:        %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view<f32xf32>">
+  //       CHECK:        %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view<?xf32>">
+  //       CHECK:        %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view<?x?xf32>">
   //       CHECK:        %{{.*}} = mulf %[[A]], %[[B]] : f32
   //       CHECK:        %{{.*}} = addf %[[C2]], %{{.*}} : f32
-  //       CHECK:        linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view<f32>">
+  //       CHECK:        linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view<?xf32>">
   // clang-format on
   cleanupAndPrintFunction(f);
 }
@@ -197,6 +198,7 @@ TEST_FUNC(matmul_as_matvec_as_affine) {
 }
 
 int main() {
+  mlir::registerDialect<linalg::LinalgDialect>();
   RUN_TESTS();
   return 0;
 }
index f54c76b..9af528e 100644 (file)
@@ -57,8 +57,8 @@ makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps,
 /// to only use linalg.view operations.
 void composeSliceOps(mlir::Function *f);
 
-/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec)
-/// as linalg.matvec (resp. linalg.dot).
+/// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec)
+/// as linalg.matvec(resp. linalg.dot).
 void lowerToFinerGrainedTensorContraction(mlir::Function *f);
 
 /// Operation-wise writing of linalg operations to loop form.
@@ -1,4 +1,4 @@
-//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===//
+//===- DialectConstruction.cpp - Construction of the Linalg dialect -------===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -15,9 +15,9 @@
 // limitations under the License.
 // =============================================================================
 //
-// This file registers the Linalg dialect and should live in a standalone
-// library. Linking with this library will create a static global object that
-// performs dialect registration.
+// This file implements the constructor for the Linalg Dialect. This is
+// explicitly separated from the core library to allow incremental buildup of
+// the codebase for the tutorial.
 //
 //===----------------------------------------------------------------------===//
 
@@ -33,7 +33,3 @@ LinalgDialect::LinalgDialect(mlir::MLIRContext *context)
   addOperations<DotOp, LoadOp, MatvecOp, MatmulOp, RangeOp, SliceOp, StoreOp,
                 ViewOp>();
 }
-
-// Dialect registration triggers the creation of a `LinalgDialect` object which
-// adds the proper types and operations to the dialect.
-static mlir::DialectRegistration<LinalgDialect> LinalgOps;
index 91cc0f9..f7f1607 100644 (file)
@@ -19,6 +19,7 @@
 
 #include "TestHarness.h"
 #include "linalg1/Common.h"
+#include "linalg1/Dialect.h"
 #include "linalg2/Intrinsics.h"
 #include "linalg3/Ops.h"
 #include "linalg4/Transforms.h"
@@ -113,13 +114,13 @@ TEST_FUNC(matmul_tiled_views) {
   //  CHECK-NEXT:     %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
   //  CHECK-NEXT:     %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg<"range">
   //       CHECK:     %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range">
-  //       CHECK:     %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view<f32xf32>">
+  //       CHECK:     %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view<?x?xf32>">
   //       CHECK:     %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
   //  CHECK-NEXT:     %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
   //  CHECK-NEXT:     %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg<"range">
-  //  CHECK-NEXT:     %[[vB:.*]]  = linalg.view %arg1[%10, %13] : !linalg<"view<f32xf32>">
-  //  CHECK-NEXT:     %[[vC:.*]]  = linalg.view %arg2[%5, %13] : !linalg<"view<f32xf32>">
-  //  CHECK-NEXT:     linalg.matmul {%[[vA]], %[[vB]]} -> {%[[vC]]}
+  //  CHECK-NEXT:     %[[vB:.*]]  = linalg.view %arg1[%10, %13] : !linalg<"view<?x?xf32>">
+  //  CHECK-NEXT:     %[[vC:.*]]  = linalg.view %arg2[%5, %13] : !linalg<"view<?x?xf32>">
+  //  CHECK-NEXT:     linalg.matmul(%[[vA]], %[[vB]], %[[vC]]) : !linalg<"view<?x?xf32>">
   // clang-format on
   cleanupAndPrintFunction(f);
 }
@@ -149,28 +150,29 @@ TEST_FUNC(matmul_tiled_views_as_loops) {
   //  CHECK-NEXT:     %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0)
   //  CHECK-NEXT:     %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg<"range">
   //       CHECK:     %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range">
-  //       CHECK:     %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view<f32xf32>">
+  //       CHECK:     %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view<?x?xf32>">
   //       CHECK:     %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1)
   //  CHECK-NEXT:     %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1)
   //  CHECK-NEXT:     %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg<"range">
-  //  CHECK-NEXT:     %[[vB:.*]]  = linalg.view %arg1[%10, %13] : !linalg<"view<f32xf32>">
-  //  CHECK-NEXT:     %[[vC:.*]]  = linalg.view %arg2[%5, %13] : !linalg<"view<f32xf32>">
+  //  CHECK-NEXT:     %[[vB:.*]]  = linalg.view %arg1[%10, %13] : !linalg<"view<?x?xf32>">
+  //  CHECK-NEXT:     %[[vC:.*]]  = linalg.view %arg2[%5, %13] : !linalg<"view<?x?xf32>">
   //  CHECK-NEXT:     affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0)(%[[i0max]]) {
   //  CHECK-NEXT:       affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0)(%[[i1max]]) {
   //  CHECK-NEXT:         affine.for %i4 = 0 to (d0) -> (d0)(%[[K]]) {
   //  CHECK-NEXT:           %{{.*}} = cmpi "eq", %i4, %c0 : index
-  //  CHECK-NEXT:           %{{.*}} = linalg.load %[[vC]][%i2, %i3] : !linalg<"view<f32xf32>">
+  //  CHECK-NEXT:           %{{.*}} = linalg.load %[[vC]][%i2, %i3] : !linalg<"view<?x?xf32>">
   //  CHECK-NEXT:           %{{.*}} = select %{{.*}}, %cst, %{{.*}} : f32
-  //  CHECK-NEXT:           %{{.*}} = linalg.load %[[vB]][%i4, %i3] : !linalg<"view<f32xf32>">
-  //  CHECK-NEXT:           %{{.*}} = linalg.load %[[vA]][%i2, %i4] : !linalg<"view<f32xf32>">
+  //  CHECK-NEXT:           %{{.*}} = linalg.load %[[vB]][%i4, %i3] : !linalg<"view<?x?xf32>">
+  //  CHECK-NEXT:           %{{.*}} = linalg.load %[[vA]][%i2, %i4] : !linalg<"view<?x?xf32>">
   //  CHECK-NEXT:           %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
   //  CHECK-NEXT:           %{{.*}} = addf %{{.*}}, %{{.*}} : f32
-  //  CHECK-NEXT:           linalg.store %{{.*}}, %[[vC]][%i2, %i3] : !linalg<"view<f32xf32>">
+  //  CHECK-NEXT:           linalg.store %{{.*}}, %[[vC]][%i2, %i3] : !linalg<"view<?x?xf32>">
   // clang-format on
   cleanupAndPrintFunction(f);
 }
 
 int main() {
+  mlir::registerDialect<linalg::LinalgDialect>();
   RUN_TESTS();
   return 0;
 }
index ece8598..05865e9 100644 (file)
@@ -51,16 +51,6 @@ void linalg::lowerToTiledLoops(mlir::Function *f,
   });
 }
 
-template <class ConcreteOp>
-static Operation::operand_range
-getInputsAndOutputs(TensorContractionBase<ConcreteOp> &contraction) {
-  auto *inst = static_cast<ConcreteOp *>(&contraction)->getOperation();
-  auto begin = inst->operand_begin();
-  auto end = inst->operand_begin() + contraction.getNumInputs() +
-             contraction.getNumOutputs();
-  return {begin, end};
-}
-
 static bool isZeroIndex(Value *v) {
   return v->getDefiningOp() && v->getDefiningOp()->isa<ConstantIndexOp>() &&
          v->getDefiningOp()->dyn_cast<ConstantIndexOp>().getValue() == 0;
@@ -138,7 +128,7 @@ makeTiledViews(linalg::TensorContractionBase<ConcreteOp> &contraction,
       makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes);
   SmallVector<Value *, 4> res;
   unsigned currentRange = 0;
-  for (auto *in : getInputsAndOutputs(contraction)) {
+  for (auto *in : contraction.getInputsAndOutputs()) {
     unsigned runningSliceDim = 0;
     auto *runningSlice = in;
     assert(runningSlice->getType().template isa<ViewType>());