From a1b4cae30a9951709c82edc4a972c9ecb9dbb63b Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Fri, 5 Apr 2019 13:20:05 -0700 Subject: [PATCH] Post commit cleanups to the Linalg dialect -- PiperOrigin-RevId: 242181687 --- mlir/examples/Linalg/Linalg1/Conversion.cpp | 2 + mlir/examples/Linalg/Linalg1/Example.cpp | 49 ++++++++++++---------- .../Linalg/Linalg1/include/linalg1/Common.h | 27 ++++++------ .../Linalg/Linalg1/include/linalg1/Utils.h | 5 +++ .../Linalg/Linalg1/include/linalg1/ViewOp.h | 2 +- mlir/examples/Linalg/Linalg1/lib/Dialect.cpp | 9 ++-- ...ectRegistration.cpp => DialectConstruction.cpp} | 12 ++---- mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp | 7 ++-- mlir/examples/Linalg/Linalg1/lib/Utils.cpp | 17 ++++++++ mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp | 2 +- mlir/examples/Linalg/Linalg2/Example.cpp | 22 +++++----- .../Linalg/Linalg2/include/linalg2/TensorOps-inl.h | 28 +++++++++---- .../Linalg/Linalg2/include/linalg2/TensorOps.h | 1 + ...ectRegistration.cpp => DialectConstruction.cpp} | 12 ++---- mlir/examples/Linalg/Linalg3/Conversion.cpp | 2 + mlir/examples/Linalg/Linalg3/Example.cpp | 46 ++++++++++---------- .../Linalg/Linalg3/include/linalg3/Transforms.h | 4 +- ...ectRegistration.cpp => DialectConstruction.cpp} | 12 ++---- mlir/examples/Linalg/Linalg4/Example.cpp | 24 ++++++----- mlir/examples/Linalg/Linalg4/lib/Transforms.cpp | 12 +----- 20 files changed, 158 insertions(+), 137 deletions(-) rename mlir/examples/Linalg/Linalg1/lib/{DialectRegistration.cpp => DialectConstruction.cpp} (70%) rename mlir/examples/Linalg/Linalg2/lib/{DialectRegistration.cpp => DialectConstruction.cpp} (70%) rename mlir/examples/Linalg/Linalg3/lib/{DialectRegistration.cpp => DialectConstruction.cpp} (71%) diff --git a/mlir/examples/Linalg/Linalg1/Conversion.cpp b/mlir/examples/Linalg/Linalg1/Conversion.cpp index 343bf0d..21a9645 100644 --- a/mlir/examples/Linalg/Linalg1/Conversion.cpp +++ b/mlir/examples/Linalg/Linalg1/Conversion.cpp @@ -21,6 +21,7 @@ #include "linalg1/Common.h" #include "linalg1/ConvertToLLVMDialect.h" +#include "linalg1/Dialect.h" #include "linalg1/Intrinsics.h" #include "linalg1/Ops.h" #include "linalg1/Types.h" @@ -302,6 +303,7 @@ TEST_FUNC(sliceNonRangeConversion) { } int main() { + mlir::registerDialect(); RUN_TESTS(); return 0; } diff --git a/mlir/examples/Linalg/Linalg1/Example.cpp b/mlir/examples/Linalg/Linalg1/Example.cpp index 5bbd8e4..c5fdda8 100644 --- a/mlir/examples/Linalg/Linalg1/Example.cpp +++ b/mlir/examples/Linalg/Linalg1/Example.cpp @@ -20,9 +20,11 @@ #include "TestHarness.h" #include "linalg1/Common.h" +#include "linalg1/Dialect.h" #include "linalg1/Intrinsics.h" #include "linalg1/Ops.h" #include "linalg1/Types.h" +#include "linalg1/Utils.h" #include "mlir/IR/Function.h" using namespace linalg; @@ -47,19 +49,19 @@ TEST_FUNC(view_op) { // clang-format off ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), A0 = alloc(floatMemRefType<0>(&context)), - A1 = alloc(floatMemRefType<1>(&context), ArrayRef{M}), - A2 = alloc(floatMemRefType<2>(&context), ArrayRef{M, N}), + A1 = alloc(floatMemRefType<1>(&context), {M}), + A2 = alloc(floatMemRefType<2>(&context), {M, N}), r0 = range(constant_index(3), constant_index(17), constant_index(1)), - v0 = view(A0), - v1 = view(A1, ArrayRef{r0}), - v2 = view(A2, ArrayRef{r0, r0}); - some_consumer(ArrayRef{v0, v1, v2}); + v0 = view(A0, {}), + v1 = view(A1, {r0}), + v2 = view(A2, {r0, r0}); + some_consumer({v0, v1, v2}); ret(); // CHECK-LABEL: func @view_op // CHECK: %[[R:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range"> - // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[] : !linalg<"view<0xf32>"> - // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]]] : !linalg<"view"> - // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]], %[[R]]] : !linalg<"view"> + // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[] : !linalg<"view"> + // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]]] : !linalg<"view"> + // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]], %[[R]]] : !linalg<"view"> // clang-format on cleanupAndPrintFunction(f); @@ -79,10 +81,8 @@ TEST_FUNC(slice_op) { // clang-format off ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), - A = alloc(floatMemRefType<2>(&context), {M, N}), - r1 = range(constant_index(3), constant_index(17), constant_index(1)), - r2 = range(constant_index(0), N, constant_index(1)); - ViewOp vA = view(A, {r1, r2}).getValue()->getDefiningOp()->cast(); + A = alloc(floatMemRefType<2>(&context), {M, N}); + ViewOp vA = emitAndReturnViewOpFromMemRef(A); IndexHandle i, j; LoopNestRangeBuilder({&i, &j}, vA.getRanges())({ some_consumer(slice(vA, i, 1)), @@ -91,22 +91,25 @@ TEST_FUNC(slice_op) { ret(); // CHECK-LABEL: func @slice_op(%arg0: index, %arg1: index, %arg2: index) { // CHECK: %[[ALLOC:.*]] = alloc(%arg0, %arg1) : memref - // CHECK-NEXT: %[[R1:.*]] = linalg.range {{.*}}:{{.*}}:{{.*}} : !linalg<"range"> - // CHECK-NEXT: %[[R2:.*]] = linalg.range {{.*}}:%arg1:{{.*}} : !linalg<"range"> - // CHECK-NEXT: %[[V:.*]] = linalg.view %0[%[[R1]], %[[R2]]] : !linalg<"view"> - // CHECK-NEXT: for %i0 = 3 to 17 { - // CHECK-NEXT: for %i1 = 0 to (d0) -> (d0)(%arg1) { - // CHECK-NEXT: %[[S1:.*]] = linalg.slice %[[V]][*, %i0] { dim : 1 } : !linalg<"view"> - // CHECK-NEXT: "some_consumer"(%[[S1]]) : (!linalg<"view">) -> () - // CHECK-NEXT: %[[S2:.*]] = linalg.slice %[[V]][%i1, *] { dim : 0 } : !linalg<"view"> - // CHECK-NEXT: %[[S3:.*]] = linalg.slice %[[S2]][%i0] { dim : 0 } : !linalg<"view<0xf32>"> - // CHECK-NEXT: "some_consumer"(%[[S3]]) : (!linalg<"view<0xf32>">) -> () + // CHECK-NEXT: %[[M:.*]] = dim %0, 0 : memref + // CHECK-NEXT: %[[N:.*]] = dim %0, 1 : memref + // CHECK-NEXT: %[[R1:.*]] = linalg.range {{.*}}:%[[M]]:{{.*}} : !linalg<"range"> + // CHECK-NEXT: %[[R2:.*]] = linalg.range {{.*}}:%[[N]]:{{.*}} : !linalg<"range"> + // CHECK-NEXT: %[[V:.*]] = linalg.view %0[%[[R1]], %[[R2]]] : !linalg<"view"> + // CHECK-NEXT: for %i0 = 0 to (d0) -> (d0)(%[[M]]) { + // CHECK-NEXT: for %i1 = 0 to (d0) -> (d0)(%[[N]]) { + // CHECK-NEXT: %[[S1:.*]] = linalg.slice %[[V]][*, %i0] : !linalg<"view"> + // CHECK-NEXT: "some_consumer"(%[[S1]]) : (!linalg<"view">) -> () + // CHECK-NEXT: %[[S2:.*]] = linalg.slice %[[V]][%i1, *] : !linalg<"view"> + // CHECK-NEXT: %[[S3:.*]] = linalg.slice %[[S2]][%i0] : !linalg<"view"> + // CHECK-NEXT: "some_consumer"(%[[S3]]) : (!linalg<"view">) -> () // clang-format on cleanupAndPrintFunction(f); } int main() { + mlir::registerDialect(); RUN_TESTS(); return 0; } diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h index f4c50f2..6573c72 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Common.h @@ -49,8 +49,8 @@ namespace common { /// A 2-D abstraction over a flat contiguous memory region of f32 with symbolic /// sizes. template -inline mlir::MemRefType floatMemRefType( - mlir::MLIRContext *context, unsigned memorySpace = 0) { +inline mlir::MemRefType floatMemRefType(mlir::MLIRContext *context, + unsigned memorySpace = 0) { llvm::SmallVector shape(N, -1); auto f32 = mlir::FloatType::getF32(context); return mlir::MemRefType::get(shape, f32, {}, memorySpace); @@ -70,16 +70,12 @@ inline mlir::Function *makeFunction(mlir::Module &module, llvm::StringRef name, } /// A basic pass manager pre-populated with cleanup passes. -inline mlir::PassManager &cleanupPassManager() { - static bool inited = false; - static mlir::PassManager pm; - if (!inited) { - pm.addPass(mlir::createCanonicalizerPass()); - pm.addPass(mlir::createSimplifyAffineStructuresPass()); - pm.addPass(mlir::createCSEPass()); - pm.addPass(mlir::createCanonicalizerPass()); - inited = true; - } +inline std::unique_ptr cleanupPassManager() { + std::unique_ptr pm(new mlir::PassManager()); + pm->addPass(mlir::createCanonicalizerPass()); + pm->addPass(mlir::createSimplifyAffineStructuresPass()); + pm->addPass(mlir::createCSEPass()); + pm->addPass(mlir::createCanonicalizerPass()); return pm; } @@ -91,13 +87,14 @@ inline void cleanupAndPrintFunction(mlir::Function *f) { bool printToOuts = true; auto check = [f, &printToOuts](mlir::LogicalResult result) { if (failed(result)) { - f->dump(); - llvm::errs() << "Failure!\n"; + f->getContext()->emitError(f->getLoc(), + "Verification and cleanup passes failed"); printToOuts = false; } }; + auto pm = cleanupPassManager(); check(f->getModule()->verify()); - check(cleanupPassManager().run(f->getModule())); + check(pm->run(f->getModule())); if (printToOuts) f->print(llvm::outs()); } diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/Utils.h b/mlir/examples/Linalg/Linalg1/include/linalg1/Utils.h index cb6b285..3f7bb76 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/Utils.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/Utils.h @@ -23,10 +23,15 @@ class Value; } // namespace mlir namespace linalg { +class ViewOp; /// Asserts `view` is of ViewType and returns its rank. unsigned getViewRank(mlir::Value *view); +/// Helper function to emit and return a new ViewOp from `memRef` that is +/// assumed to be of MemRefType. This needs to be called under a ScopedContext. +ViewOp emitAndReturnViewOpFromMemRef(mlir::Value *memRef); + } // namespace linalg #endif // LINALG1_UTILS_H_ diff --git a/mlir/examples/Linalg/Linalg1/include/linalg1/ViewOp.h b/mlir/examples/Linalg/Linalg1/include/linalg1/ViewOp.h index fec3396..fcda553 100644 --- a/mlir/examples/Linalg/Linalg1/include/linalg1/ViewOp.h +++ b/mlir/examples/Linalg/Linalg1/include/linalg1/ViewOp.h @@ -40,7 +40,7 @@ public: static llvm::StringRef getOperationName() { return "linalg.view"; } static void build(mlir::Builder *b, mlir::OperationState *result, mlir::Value *memRef, - llvm::ArrayRef indexings = {}); + llvm::ArrayRef indexings); mlir::LogicalResult verify(); static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); void print(mlir::OpAsmPrinter *p); diff --git a/mlir/examples/Linalg/Linalg1/lib/Dialect.cpp b/mlir/examples/Linalg/Linalg1/lib/Dialect.cpp index b6dfeb3..7b1dca4 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Dialect.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Dialect.cpp @@ -44,13 +44,13 @@ static void print(RangeType rt, raw_ostream &os) { os << "range"; } /// ViewType prints as: /// /// ```{.mlir} -/// view +/// view /// ``` /// /// or /// /// ```{.mlir} -/// view<0xf32> +/// view /// ``` /// /// for 0-D views (a.k.a pointer to a scalar value). @@ -58,11 +58,10 @@ static void print(linalg::ViewType rt, raw_ostream &os) { os << "view<"; if (rt.getRank() > 0) { for (unsigned i = 0, e = rt.getRank(); i < e; ++i) { - os << rt.getElementType() << ((i == e - 1) ? "" : "x"); + os << "?x"; } - } else { - os << "0x" << rt.getElementType(); } + os << rt.getElementType(); os << ">"; } diff --git a/mlir/examples/Linalg/Linalg1/lib/DialectRegistration.cpp b/mlir/examples/Linalg/Linalg1/lib/DialectConstruction.cpp similarity index 70% rename from mlir/examples/Linalg/Linalg1/lib/DialectRegistration.cpp rename to mlir/examples/Linalg/Linalg1/lib/DialectConstruction.cpp index 3c05827..0eaab6a 100644 --- a/mlir/examples/Linalg/Linalg1/lib/DialectRegistration.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/DialectConstruction.cpp @@ -1,4 +1,4 @@ -//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===// +//===- DialectConstruction.cpp - Construction of the Linalg dialect -------===// // // Copyright 2019 The MLIR Authors. // @@ -15,9 +15,9 @@ // limitations under the License. // ============================================================================= // -// This file registers the Linalg dialect and should live in a standalone -// library. Linking with this library will create a static global object that -// performs dialect registration. +// This file implements the constructor for the Linalg Dialect. This is +// explicitly separated from the core library to allow incremental buildup of +// the codebase for the tutorial. // //===----------------------------------------------------------------------===// @@ -33,7 +33,3 @@ LinalgDialect::LinalgDialect(MLIRContext *context) addTypes(); addOperations(); } - -// Dialect registration triggers the creation of a `LinalgDialect` object which -// adds the proper types and operations to the dialect. -static mlir::DialectRegistration LinalgOps; diff --git a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp index 0b68f20..4f30fb1 100644 --- a/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/SliceOp.cpp @@ -80,10 +80,10 @@ bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) { // A SliceOp prints as: // // ```{.mlir} -// linalg.slice %0[*, %i0] { dim : 1 } : !linalg<"view"> +// linalg.slice %0[*, %i0] : !linalg<"view"> // ``` // -// Where %0 is an ssa-value holding a `view`, %i0 is an ssa-value +// Where %0 is an ssa-value holding a `view`, %i0 is an ssa-value // holding an index. void linalg::SliceOp::print(OpAsmPrinter *p) { unsigned dim = getSlicingDim(); @@ -101,8 +101,7 @@ void linalg::SliceOp::print(OpAsmPrinter *p) { } *p << ((idx == rank - 1) ? "" : ", "); } - *p << "] { " << getSlicingDimAttrName() << " : " << dim << " }" - << " : " << getViewType(); + *p << "] : " << getViewType(); } ViewType linalg::SliceOp::getViewType() { return getType().cast(); } diff --git a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp index f81930a..372c08f 100644 --- a/mlir/examples/Linalg/Linalg1/lib/Utils.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/Utils.cpp @@ -20,11 +20,16 @@ //===----------------------------------------------------------------------===// #include "linalg1/Utils.h" +#include "linalg1/Intrinsics.h" #include "linalg1/Ops.h" +#include "mlir/EDSC/Helpers.h" #include "mlir/IR/StandardTypes.h" using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; using namespace linalg; +using namespace linalg::intrinsics; unsigned linalg::getViewRank(Value *view) { assert(view->getType().isa() && "expected a ViewType"); @@ -32,3 +37,15 @@ unsigned linalg::getViewRank(Value *view) { return viewOp.getRank(); return view->getDefiningOp()->cast().getRank(); } + +ViewOp linalg::emitAndReturnViewOpFromMemRef(Value *memRef) { + // Syntactic sugar helper to extract and emit view-like information from an + // mlir::MemRef without boilerplate. + mlir::edsc::MemRefView v(memRef); + SmallVector indices(v.rank()); + for (unsigned i = 0; i < v.rank(); ++i) { + indices[i] = range(v.lb(i), v.ub(i), constant_index(v.step(i))); + } + return ScopedContext::getBuilder()->create( + ScopedContext::getLocation(), memRef, indices); +} diff --git a/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp b/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp index 97779f4..6564391 100644 --- a/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp +++ b/mlir/examples/Linalg/Linalg1/lib/ViewOp.cpp @@ -97,7 +97,7 @@ bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { // A ViewOp prints as: // // ```{.mlir} -// linalg.view %0[%1, %2] : !linalg<"view"> +// linalg.view %0[%1, %2] : !linalg<"view"> // ``` // // Where %0 is an ssa-value holding a MemRef, %1 and %2 are ssa-value each diff --git a/mlir/examples/Linalg/Linalg2/Example.cpp b/mlir/examples/Linalg/Linalg2/Example.cpp index 717de06..c239cf3 100644 --- a/mlir/examples/Linalg/Linalg2/Example.cpp +++ b/mlir/examples/Linalg/Linalg2/Example.cpp @@ -19,6 +19,7 @@ #include "TestHarness.h" #include "linalg1/Common.h" +#include "linalg1/Dialect.h" #include "linalg2/Intrinsics.h" #include "linalg2/Ops.h" #include "linalg2/Transforms.h" @@ -57,13 +58,13 @@ TEST_FUNC(linalg_ops) { dot(sA, sB, ssC); ret(); // CHECK-LABEL: func @linalg_ops(%arg0: index, %arg1: index, %arg2: index) { - // CHECK: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] { dim : 1 } : !linalg<"view"> - // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] { dim : 1 } : !linalg<"view"> - // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}, *] { dim : 0 } : !linalg<"view"> - // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] { dim : 0 } : !linalg<"view<0xf32>"> - // CHECK-NEXT: linalg.matmul {{{.*}}, {{.*}}} -> {{{.*}}} - // CHECK-NEXT: linalg.matvec {{{.*}}, {{.*}}} -> {{{.*}}} - // CHECK-NEXT: linalg.dot {{{.*}}, {{.*}}} -> {{{.*}}} + // CHECK: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] : !linalg<"view"> + // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[*, {{.*}}] : !linalg<"view"> + // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}, *] : !linalg<"view"> + // CHECK-NEXT: {{.*}} = linalg.slice {{.*}}[{{.*}}] : !linalg<"view"> + // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> + // CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> + // CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> // clang-format on cleanupAndPrintFunction(f); @@ -96,9 +97,9 @@ TEST_FUNC(linalg_ops_folded_slices) { ret(); // CHECK-LABEL: func @linalg_ops_folded_slices(%arg0: index, %arg1: index, %arg2: index) { // CHECK-NOT: linalg.slice - // CHECK: linalg.matmul {{{.*}}, {{.*}}} -> {{{.*}}} - // CHECK-NEXT: linalg.matvec {{{.*}}, {{.*}}} -> {{{.*}}} - // CHECK-NEXT: linalg.dot {{{.*}}, {{.*}}} -> {{{.*}}} + // CHECK: linalg.matmul({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> + // CHECK-NEXT: linalg.matvec({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> + // CHECK-NEXT: linalg.dot({{.*}}, {{.*}}, {{.*}}) : !linalg<"view"> // clang-format on f->walk([](SliceOp slice) { @@ -111,6 +112,7 @@ TEST_FUNC(linalg_ops_folded_slices) { cleanupAndPrintFunction(f); } int main() { + mlir::registerDialect(); RUN_TESTS(); return 0; } diff --git a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h index 1a08e9e..940f8d7 100644 --- a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h +++ b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps-inl.h @@ -45,6 +45,12 @@ linalg::TensorContractionBase::getOutputs() { } template +mlir::Operation::operand_range +linalg::TensorContractionBase::getInputsAndOutputs() { + return {getInputs().begin(), getOutputs().end()}; +} + +template mlir::LogicalResult linalg::TensorContractionBase::verify() { auto *concreteOp = static_cast(this)->getOperation(); if (getNumInputs() <= 0) @@ -85,21 +91,27 @@ bool linalg::TensorContractionBase::parse( // A TensorContraction prints as: // // ```{.mlir} -// concrete_op_name {%0, %1} -> {%2} +// concrete_op_name (ssa-inputs, ssa-outputs) : output-view-types +// ``` +// +// for example: +// +// ``` +// linalg.matmul(%0, %1, %2) : view // ``` // -// Where %0, %1 is an ssa-value holding a View, %2 is an ssa-value holding a -// view. +// Where %0, %1 and %2 are ssa-values of type ViewType. template void linalg::TensorContractionBase::print(mlir::OpAsmPrinter *p) { - *p << static_cast(this)->getOperationName() << " {"; - auto *lastInput = *std::prev(getInputs().end()); - for (auto *i : getInputs()) { - *p << *i << ((i == lastInput) ? "} -> {" : ", "); + *p << static_cast(this)->getOperationName() << "("; + auto *last = *std::prev(getInputsAndOutputs().end()); + for (auto *i : getInputsAndOutputs()) { + *p << *i << ((i == last) ? "" : ", "); } + *p << ") : "; auto *lastOutput = *std::prev(getOutputs().end()); for (auto *o : getOutputs()) { - *p << *o << ((o == lastOutput) ? "}" : ","); + *p << o->getType() << ((o == lastOutput) ? "" : ","); } } diff --git a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h index b8d9f8f..39e51f0 100644 --- a/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h +++ b/mlir/examples/Linalg/Linalg2/include/linalg2/TensorOps.h @@ -48,6 +48,7 @@ public: TensorContractionBase() = default; mlir::Operation::operand_range getInputs(); mlir::Operation::operand_range getOutputs(); + mlir::Operation::operand_range getInputsAndOutputs(); /// These are better as methods calling into the ConcreteOp instead of /// template parameters because methods allow more generic behavior and avoid diff --git a/mlir/examples/Linalg/Linalg2/lib/DialectRegistration.cpp b/mlir/examples/Linalg/Linalg2/lib/DialectConstruction.cpp similarity index 70% rename from mlir/examples/Linalg/Linalg2/lib/DialectRegistration.cpp rename to mlir/examples/Linalg/Linalg2/lib/DialectConstruction.cpp index 64cba7c..0b9c9bf 100644 --- a/mlir/examples/Linalg/Linalg2/lib/DialectRegistration.cpp +++ b/mlir/examples/Linalg/Linalg2/lib/DialectConstruction.cpp @@ -1,4 +1,4 @@ -//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===// +//===- DialectConstruction.cpp - Construction of the Linalg dialect -------===// // // Copyright 2019 The MLIR Authors. // @@ -15,9 +15,9 @@ // limitations under the License. // ============================================================================= // -// This file registers the Linalg dialect and should live in a standalone -// library. Linking with this library will create a static global object that -// performs dialect registration. +// This file implements the constructor for the Linalg Dialect. This is +// explicitly separated from the core library to allow incremental buildup of +// the codebase for the tutorial. // //===----------------------------------------------------------------------===// @@ -31,7 +31,3 @@ LinalgDialect::LinalgDialect(mlir::MLIRContext *context) addTypes(); addOperations(); } - -// Dialect registration triggers the creation of a `LinalgDialect` object which -// adds the proper types and operations to the dialect. -static mlir::DialectRegistration LinalgOps; diff --git a/mlir/examples/Linalg/Linalg3/Conversion.cpp b/mlir/examples/Linalg/Linalg3/Conversion.cpp index bdb9d69..1f0f001 100644 --- a/mlir/examples/Linalg/Linalg3/Conversion.cpp +++ b/mlir/examples/Linalg/Linalg3/Conversion.cpp @@ -22,6 +22,7 @@ #include "linalg3/ConvertToLLVMDialect.h" #include "linalg1/Common.h" +#include "linalg1/Dialect.h" #include "linalg2/Intrinsics.h" #include "linalg3/Ops.h" #include "linalg3/Transforms.h" @@ -106,6 +107,7 @@ TEST_FUNC(foo) { } int main() { + mlir::registerDialect(); RUN_TESTS(); return 0; } diff --git a/mlir/examples/Linalg/Linalg3/Example.cpp b/mlir/examples/Linalg/Linalg3/Example.cpp index 5f5760e..e92199c 100644 --- a/mlir/examples/Linalg/Linalg3/Example.cpp +++ b/mlir/examples/Linalg/Linalg3/Example.cpp @@ -19,6 +19,7 @@ #include "TestHarness.h" #include "linalg1/Common.h" +#include "linalg1/Dialect.h" #include "linalg2/Intrinsics.h" #include "linalg3/Ops.h" #include "linalg3/Transforms.h" @@ -68,11 +69,11 @@ TEST_FUNC(matmul_as_matvec) { // clang-format off // CHECK-LABEL: func @matmul_as_matvec(%arg0: memref, %arg1: memref, %arg2: memref) { // CHECK: %[[N:.*]] = dim %arg2, 1 : memref - // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view"> // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) { - // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view"> - // CHECK: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view"> - // CHECK: linalg.matvec {%[[vA]], %[[vB]]} -> {%[[vC]]} + // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK: linalg.matvec(%[[vA]], %[[vB]], %[[vC]]) : !linalg<"view"> // clang-format on cleanupAndPrintFunction(f); } @@ -89,11 +90,11 @@ TEST_FUNC(matmul_as_dot) { // CHECK: %[[M:.*]] = dim %arg0, 0 : memref // CHECK: %[[N:.*]] = dim %arg2, 1 : memref // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) { - // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK: %[[vB:.*]] = linalg.view %arg1[%{{.*}}, %{{.*}}] : !linalg<"view"> // CHECK-NEXT: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) { - // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view"> - // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view<0xf32>"> - // CHECK-NEXT: linalg.dot {%[[vA]], %[[vB]]} -> {%[[vC]]} + // CHECK: %[[vA:.*]] = linalg.view %arg0[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%{{.*}}, %{{.*}}] : !linalg<"view"> + // CHECK-NEXT: linalg.dot(%[[vA]], %[[vB]], %[[vC]]) : !linalg<"view"> // clang-format on cleanupAndPrintFunction(f); } @@ -112,20 +113,20 @@ TEST_FUNC(matmul_as_loops) { // CHECK: %[[rM:.*]] = linalg.range %c0:%[[M]]:%c1 : !linalg<"range"> // CHECK: %[[rN:.*]] = linalg.range %c0:%[[N]]:%c1 : !linalg<"range"> // CHECK: %[[rK:.*]] = linalg.range %c0:%[[K]]:%c1 : !linalg<"range"> - // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view"> - // CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view"> - // CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[rM]], %[[rK]]] : !linalg<"view"> + // CHECK: %[[vB:.*]] = linalg.view %arg1[%[[rK]], %[[rN]]] : !linalg<"view"> + // CHECK: %[[vC:.*]] = linalg.view %arg2[%[[rM]], %[[rN]]] : !linalg<"view"> // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[M]]) { // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[N]]) { // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) { // CHECK: %{{.*}} = cmpi "eq", %{{.*}} : index - // CHECK: %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view"> + // CHECK: %{{.*}} = linalg.load %[[vC]][%i0, %i1] : !linalg<"view"> // CHECK: %{{.*}} = select {{.*}} : f32 - // CHECK: %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view"> - // CHECK: %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view"> + // CHECK: %{{.*}} = linalg.load %[[vB]][%i2, %i1] : !linalg<"view"> + // CHECK: %{{.*}} = linalg.load %[[vA]][%i0, %i2] : !linalg<"view"> // CHECK: %{{.*}} = mulf {{.*}} : f32 // CHECK: %{{.*}} = addf {{.*}} : f32 - // CHECK: linalg.store {{.*}}[%i0, %i1] : !linalg<"view"> + // CHECK: linalg.store {{.*}}[%i0, %i1] : !linalg<"view"> // clang-format on cleanupAndPrintFunction(f); } @@ -143,20 +144,20 @@ TEST_FUNC(matmul_as_matvec_as_loops) { // CHECK: %[[M:.*]] = dim %arg0, 0 : memref // CHECK: %[[N:.*]] = dim %arg2, 1 : memref // CHECK: %[[K:.*]] = dim %arg0, 1 : memref - // CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[{{.*}}, {{.*}}] : !linalg<"view"> // CHECK: affine.for %i0 = 0 to (d0) -> (d0)(%[[N]]) { - // CHECK: %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view"> - // CHECK: %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view"> + // CHECK: %[[vB:.*]] = linalg.view %arg1[{{.*}}, {{.*}}] : !linalg<"view"> + // CHECK: %[[vC:.*]] = linalg.view %arg2[{{.*}}, {{.*}}] : !linalg<"view"> // CHECK: affine.for %i1 = 0 to (d0) -> (d0)(%[[M]]) { // CHECK: affine.for %i2 = 0 to (d0) -> (d0)(%[[K]]) { // CHECK: %{{.*}} = cmpi "eq", %i2, %{{.*}} : index - // CHECK: %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view"> + // CHECK: %[[C:.*]] = linalg.load %[[vC]][%i1] : !linalg<"view"> // CHECK: %[[C2:.*]] = select %{{.*}}, %{{.*}}, %[[C]] : f32 - // CHECK: %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view"> - // CHECK: %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view"> + // CHECK: %[[B:.*]] = linalg.load %[[vB]][%i2] : !linalg<"view"> + // CHECK: %[[A:.*]] = linalg.load %[[vA]][%i1, %i2] : !linalg<"view"> // CHECK: %{{.*}} = mulf %[[A]], %[[B]] : f32 // CHECK: %{{.*}} = addf %[[C2]], %{{.*}} : f32 - // CHECK: linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view"> + // CHECK: linalg.store %{{.*}}, %{{.*}}[%i1] : !linalg<"view"> // clang-format on cleanupAndPrintFunction(f); } @@ -197,6 +198,7 @@ TEST_FUNC(matmul_as_matvec_as_affine) { } int main() { + mlir::registerDialect(); RUN_TESTS(); return 0; } diff --git a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h index f54c76b..9af528e 100644 --- a/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h +++ b/mlir/examples/Linalg/Linalg3/include/linalg3/Transforms.h @@ -57,8 +57,8 @@ makeGenericLoopRanges(mlir::AffineMap operandRangesToLoopMaps, /// to only use linalg.view operations. void composeSliceOps(mlir::Function *f); -/// Traverses `f` and rewrites linalg.matmul (resp. linalg.matvec) -/// as linalg.matvec (resp. linalg.dot). +/// Traverses `f` and rewrites linalg.matmul(resp. linalg.matvec) +/// as linalg.matvec(resp. linalg.dot). void lowerToFinerGrainedTensorContraction(mlir::Function *f); /// Operation-wise writing of linalg operations to loop form. diff --git a/mlir/examples/Linalg/Linalg3/lib/DialectRegistration.cpp b/mlir/examples/Linalg/Linalg3/lib/DialectConstruction.cpp similarity index 71% rename from mlir/examples/Linalg/Linalg3/lib/DialectRegistration.cpp rename to mlir/examples/Linalg/Linalg3/lib/DialectConstruction.cpp index 1ab2751..ded48f5 100644 --- a/mlir/examples/Linalg/Linalg3/lib/DialectRegistration.cpp +++ b/mlir/examples/Linalg/Linalg3/lib/DialectConstruction.cpp @@ -1,4 +1,4 @@ -//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===// +//===- DialectConstruction.cpp - Construction of the Linalg dialect -------===// // // Copyright 2019 The MLIR Authors. // @@ -15,9 +15,9 @@ // limitations under the License. // ============================================================================= // -// This file registers the Linalg dialect and should live in a standalone -// library. Linking with this library will create a static global object that -// performs dialect registration. +// This file implements the constructor for the Linalg Dialect. This is +// explicitly separated from the core library to allow incremental buildup of +// the codebase for the tutorial. // //===----------------------------------------------------------------------===// @@ -33,7 +33,3 @@ LinalgDialect::LinalgDialect(mlir::MLIRContext *context) addOperations(); } - -// Dialect registration triggers the creation of a `LinalgDialect` object which -// adds the proper types and operations to the dialect. -static mlir::DialectRegistration LinalgOps; diff --git a/mlir/examples/Linalg/Linalg4/Example.cpp b/mlir/examples/Linalg/Linalg4/Example.cpp index 91cc0f9..f7f1607 100644 --- a/mlir/examples/Linalg/Linalg4/Example.cpp +++ b/mlir/examples/Linalg/Linalg4/Example.cpp @@ -19,6 +19,7 @@ #include "TestHarness.h" #include "linalg1/Common.h" +#include "linalg1/Dialect.h" #include "linalg2/Intrinsics.h" #include "linalg3/Ops.h" #include "linalg4/Transforms.h" @@ -113,13 +114,13 @@ TEST_FUNC(matmul_tiled_views) { // CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0) // CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg<"range"> // CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range"> - // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view"> // CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1) // CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1) // CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg<"range"> - // CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg<"view"> - // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg<"view"> - // CHECK-NEXT: linalg.matmul {%[[vA]], %[[vB]]} -> {%[[vC]]} + // CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg<"view"> + // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg<"view"> + // CHECK-NEXT: linalg.matmul(%[[vA]], %[[vB]], %[[vC]]) : !linalg<"view"> // clang-format on cleanupAndPrintFunction(f); } @@ -149,28 +150,29 @@ TEST_FUNC(matmul_tiled_views_as_loops) { // CHECK-NEXT: %[[i0max:.*]] = affine.apply (d0) -> (d0 + 8)(%i0) // CHECK-NEXT: %[[ri0:.*]] = linalg.range %[[i0min]]:%[[i0max]]:{{.*}} : !linalg<"range"> // CHECK: %[[rK:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range"> - // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view"> + // CHECK: %[[vA:.*]] = linalg.view %arg0[%[[ri0]], %[[rK]]] : !linalg<"view"> // CHECK: %[[i1min:.*]] = affine.apply (d0) -> (d0)(%i1) // CHECK-NEXT: %[[i1max:.*]] = affine.apply (d0) -> (d0 + 9)(%i1) // CHECK-NEXT: %[[ri1:.*]] = linalg.range %[[i1min]]:%[[i1max]]:%{{.*}} : !linalg<"range"> - // CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg<"view"> - // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg<"view"> + // CHECK-NEXT: %[[vB:.*]] = linalg.view %arg1[%10, %13] : !linalg<"view"> + // CHECK-NEXT: %[[vC:.*]] = linalg.view %arg2[%5, %13] : !linalg<"view"> // CHECK-NEXT: affine.for %i2 = (d0) -> (d0)(%i0) to (d0) -> (d0)(%[[i0max]]) { // CHECK-NEXT: affine.for %i3 = (d0) -> (d0)(%i1) to (d0) -> (d0)(%[[i1max]]) { // CHECK-NEXT: affine.for %i4 = 0 to (d0) -> (d0)(%[[K]]) { // CHECK-NEXT: %{{.*}} = cmpi "eq", %i4, %c0 : index - // CHECK-NEXT: %{{.*}} = linalg.load %[[vC]][%i2, %i3] : !linalg<"view"> + // CHECK-NEXT: %{{.*}} = linalg.load %[[vC]][%i2, %i3] : !linalg<"view"> // CHECK-NEXT: %{{.*}} = select %{{.*}}, %cst, %{{.*}} : f32 - // CHECK-NEXT: %{{.*}} = linalg.load %[[vB]][%i4, %i3] : !linalg<"view"> - // CHECK-NEXT: %{{.*}} = linalg.load %[[vA]][%i2, %i4] : !linalg<"view"> + // CHECK-NEXT: %{{.*}} = linalg.load %[[vB]][%i4, %i3] : !linalg<"view"> + // CHECK-NEXT: %{{.*}} = linalg.load %[[vA]][%i2, %i4] : !linalg<"view"> // CHECK-NEXT: %{{.*}} = mulf %{{.*}}, %{{.*}} : f32 // CHECK-NEXT: %{{.*}} = addf %{{.*}}, %{{.*}} : f32 - // CHECK-NEXT: linalg.store %{{.*}}, %[[vC]][%i2, %i3] : !linalg<"view"> + // CHECK-NEXT: linalg.store %{{.*}}, %[[vC]][%i2, %i3] : !linalg<"view"> // clang-format on cleanupAndPrintFunction(f); } int main() { + mlir::registerDialect(); RUN_TESTS(); return 0; } diff --git a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp index ece8598..05865e9 100644 --- a/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp +++ b/mlir/examples/Linalg/Linalg4/lib/Transforms.cpp @@ -51,16 +51,6 @@ void linalg::lowerToTiledLoops(mlir::Function *f, }); } -template -static Operation::operand_range -getInputsAndOutputs(TensorContractionBase &contraction) { - auto *inst = static_cast(&contraction)->getOperation(); - auto begin = inst->operand_begin(); - auto end = inst->operand_begin() + contraction.getNumInputs() + - contraction.getNumOutputs(); - return {begin, end}; -} - static bool isZeroIndex(Value *v) { return v->getDefiningOp() && v->getDefiningOp()->isa() && v->getDefiningOp()->dyn_cast().getValue() == 0; @@ -138,7 +128,7 @@ makeTiledViews(linalg::TensorContractionBase &contraction, makeTiledRanges(contraction, getRanges(contraction), ivs, tileSizes); SmallVector res; unsigned currentRange = 0; - for (auto *in : getInputsAndOutputs(contraction)) { + for (auto *in : contraction.getInputsAndOutputs()) { unsigned runningSliceDim = 0; auto *runningSlice = in; assert(runningSlice->getType().template isa()); -- 2.7.4