From d7296a4ae34b880edc68de329059983a840acb2b Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Mon, 1 Apr 2019 15:15:09 -0700 Subject: [PATCH] Linalg portion of the tutorial - part 1 The first part of the Linalg tutorial introduces: 1. the RangeType and ViewType; 2. operations on those, namely RangeOp, ViewOp and SliceOp; 3. programmatic examples to test MLIR construction involving these types, ops and affine.for loops (with a mock custom op called "some_consumer"). -- PiperOrigin-RevId: 241409949 --- mlir/tutorial/Linalg1/Example.cpp | 110 +++++++++++ mlir/tutorial/Linalg1/TestHarness.h | 67 +++++++ mlir/tutorial/Linalg1/include/linalg/Common.h | 151 +++++++++++++++ mlir/tutorial/Linalg1/include/linalg/Dialect.h | 42 +++++ mlir/tutorial/Linalg1/include/linalg/Ops.h | 59 ++++++ mlir/tutorial/Linalg1/include/linalg/RangeOp.h | 56 ++++++ mlir/tutorial/Linalg1/include/linalg/RangeType.h | 47 +++++ mlir/tutorial/Linalg1/include/linalg/SliceOp.h | 102 ++++++++++ mlir/tutorial/Linalg1/include/linalg/Types.h | 37 ++++ mlir/tutorial/Linalg1/include/linalg/ViewOp.h | 71 +++++++ mlir/tutorial/Linalg1/include/linalg/ViewType.h | 54 ++++++ mlir/tutorial/Linalg1/lib/Common.cpp | 70 +++++++ mlir/tutorial/Linalg1/lib/Dialect.cpp | 83 ++++++++ mlir/tutorial/Linalg1/lib/DialectRegistration.cpp | 42 +++++ mlir/tutorial/Linalg1/lib/RangeOp.cpp | 68 +++++++ mlir/tutorial/Linalg1/lib/SliceOp.cpp | 220 ++++++++++++++++++++++ mlir/tutorial/Linalg1/lib/ViewOp.cpp | 156 +++++++++++++++ mlir/tutorial/Linalg1/lib/ViewType.cpp | 79 ++++++++ 18 files changed, 1514 insertions(+) create mode 100644 mlir/tutorial/Linalg1/Example.cpp create mode 100644 mlir/tutorial/Linalg1/TestHarness.h create mode 100644 mlir/tutorial/Linalg1/include/linalg/Common.h create mode 100644 mlir/tutorial/Linalg1/include/linalg/Dialect.h create mode 100644 mlir/tutorial/Linalg1/include/linalg/Ops.h create mode 100644 mlir/tutorial/Linalg1/include/linalg/RangeOp.h create mode 100644 mlir/tutorial/Linalg1/include/linalg/RangeType.h create mode 100644 mlir/tutorial/Linalg1/include/linalg/SliceOp.h create mode 100644 mlir/tutorial/Linalg1/include/linalg/Types.h create mode 100644 mlir/tutorial/Linalg1/include/linalg/ViewOp.h create mode 100644 mlir/tutorial/Linalg1/include/linalg/ViewType.h create mode 100644 mlir/tutorial/Linalg1/lib/Common.cpp create mode 100644 mlir/tutorial/Linalg1/lib/Dialect.cpp create mode 100644 mlir/tutorial/Linalg1/lib/DialectRegistration.cpp create mode 100644 mlir/tutorial/Linalg1/lib/RangeOp.cpp create mode 100644 mlir/tutorial/Linalg1/lib/SliceOp.cpp create mode 100644 mlir/tutorial/Linalg1/lib/ViewOp.cpp create mode 100644 mlir/tutorial/Linalg1/lib/ViewType.cpp diff --git a/mlir/tutorial/Linalg1/Example.cpp b/mlir/tutorial/Linalg1/Example.cpp new file mode 100644 index 0000000..4457fed --- /dev/null +++ b/mlir/tutorial/Linalg1/Example.cpp @@ -0,0 +1,110 @@ +//===- Example.cpp - Our running example ----------------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +// RUN: %p/test | FileCheck %s + +#include "TestHarness.h" + +#include "linalg/Common.h" +#include "linalg/Ops.h" +#include "linalg/RangeOp.h" +#include "linalg/RangeType.h" +#include "linalg/SliceOp.h" +#include "linalg/Types.h" +#include "linalg/ViewOp.h" +#include "linalg/ViewType.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/Function.h" + +using namespace linalg; +using namespace linalg::common; +using namespace linalg::intrinsics; +using namespace mlir; +using namespace mlir::edsc; +using namespace mlir::edsc::intrinsics; + +TEST_FUNC(view_op) { + Function *f = makeFunction("view_op", {}); + + ScopedContext scope(f); + + // Let's be lazy and define some custom ops that prevent DCE. + CustomOperation some_consumer("some_consumer"); + + // clang-format off + ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), + A0 = alloc(floatMemRefType<0>()), + A1 = alloc(floatMemRefType<1>(), ArrayRef{M}), + A2 = alloc(floatMemRefType<2>(), ArrayRef{M, N}), + r0 = range(constant_index(3), constant_index(17), constant_index(1)), + v0 = view(A0), + v1 = view(A1, ArrayRef{r0}), + v2 = view(A2, ArrayRef{r0, r0}); + some_consumer(ArrayRef{v0, v1, v2}); + ret(); + // CHECK-LABEL: func @view_op(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { + // CHECK: %[[R:.*]] = linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg<"range"> + // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[] : !linalg<"view<0xf32>"> + // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]]] : !linalg<"view"> + // CHECK-NEXT: {{.*}} = linalg.view {{.*}}[%[[R]], %[[R]]] : !linalg<"view"> + // clang-format on + + cleanupAndPrintFunction(f); +} + +TEST_FUNC(slice_op) { + Function *f = makeFunction("slice_op", {}); + + ScopedContext scope(f); + + // Let's be lazy and define some custom op that prevents DCE. + CustomOperation some_consumer("some_consumer"); + + // clang-format off + ValueHandle M(f->getArgument(0)), N(f->getArgument(1)), + A = alloc(floatMemRefType<2>(), {M, N}), + r1 = range(constant_index(3), constant_index(17), constant_index(1)), + r2 = range(constant_index(0), N, constant_index(1)); + ViewOp vA = view(A, {r1, r2}).getValue()->getDefiningOp()->cast(); + IndexHandle i, j; + LoopNestRangeBuilder({&i, &j}, vA.getRanges())({ + some_consumer(slice(vA, i, 1)), + some_consumer(slice(slice(vA, j, 0), i, 0)), + }); + ret(); + // CHECK-LABEL: func @slice_op(%arg0: index, %arg1: index, %arg2: index, %arg3: index) { + // CHECK: %[[ALLOC:.*]] = alloc(%arg0, %arg1) : memref + // CHECK-NEXT: %[[R1:.*]] = linalg.range {{.*}}:{{.*}}:{{.*}} : !linalg<"range"> + // CHECK-NEXT: %[[R2:.*]] = linalg.range {{.*}}:%arg1:{{.*}} : !linalg<"range"> + // CHECK-NEXT: %[[V:.*]] = linalg.view %0[%[[R1]], %[[R2]]] : !linalg<"view"> + // CHECK-NEXT: for %i0 = 3 to 17 { + // CHECK-NEXT: for %i1 = 0 to (d0) -> (d0)(%arg1) { + // CHECK-NEXT: %[[S1:.*]] = linalg.slice %[[V]][*, %i0] { dim : 1 } : !linalg<"view"> + // CHECK-NEXT: "some_consumer"(%[[S1]]) : (!linalg<"view">) -> () + // CHECK-NEXT: %[[S2:.*]] = linalg.slice %[[V]][%i1, *] { dim : 0 } : !linalg<"view"> + // CHECK-NEXT: %[[S3:.*]] = linalg.slice %[[S2]][%i0] { dim : 0 } : !linalg<"view<0xf32>"> + // CHECK-NEXT: "some_consumer"(%[[S3]]) : (!linalg<"view<0xf32>">) -> () + // clang-format on + + cleanupAndPrintFunction(f); +} + +int main() { + RUN_TESTS(); + return 0; +} diff --git a/mlir/tutorial/Linalg1/TestHarness.h b/mlir/tutorial/Linalg1/TestHarness.h new file mode 100644 index 0000000..a2b5765 --- /dev/null +++ b/mlir/tutorial/Linalg1/TestHarness.h @@ -0,0 +1,67 @@ +//===- TestHarness.h - Minimal test harness for exercising the linalg API -===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_TEST_HARNESS_H +#define LINALG_TEST_HARNESS_H + +#include +#include + +namespace test_detail { +// Returns a mutable list of known test functions. Used internally by test +// macros to add and run tests. This function is static to ensure it creates a +// new list in each test file. +static std::vector> &tests() { + static std::vector> list; + return list; +} + +// Test registration class. Used internally by test macros to register tests +// during static allocation. +struct TestRegistration { + explicit TestRegistration(std::function func) { + test_detail::tests().push_back(func); + } +}; +} // end namespace test_detail + +/// Declares a test function with the given name and adds it to the list of +/// known tests. The body of the function must follow immediately. Example: +/// +/// TEST_FUNC(mytest) { +/// // CHECK: expected-output-here +/// emitSomethingToStdOut(); +/// } +/// +#define TEST_FUNC(name) \ + void name(); \ + static test_detail::TestRegistration name##Registration(name); \ + void name() + +/// Runs all registered tests. Example: +/// +/// int main() { +/// RUN_TESTS(); +/// return 0; +/// } +#define RUN_TESTS \ + []() { \ + for (auto f : test_detail::tests()) \ + f(); \ + } + +#endif // LINALG_TEST_HARNESS_H diff --git a/mlir/tutorial/Linalg1/include/linalg/Common.h b/mlir/tutorial/Linalg1/include/linalg/Common.h new file mode 100644 index 0000000..d77aad0 --- /dev/null +++ b/mlir/tutorial/Linalg1/include/linalg/Common.h @@ -0,0 +1,151 @@ +//===- Common.h - Linalg dialect RangeOp operation -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_COMMON_H_ +#define LINALG_COMMON_H_ + +#include "mlir/AffineOps/AffineOps.h" +#include "mlir/Analysis/SliceAnalysis.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/AffineExpr.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/Identifier.h" +#include "mlir/IR/MLIRContext.h" +#include "mlir/IR/Module.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/StandardTypes.h" +#include "mlir/IR/Types.h" +#include "mlir/Pass/Pass.h" +#include "mlir/Pass/PassManager.h" +#include "mlir/StandardOps/Ops.h" +#include "mlir/Support/LogicalResult.h" +#include "mlir/Transforms/LoopUtils.h" +#include "mlir/Transforms/Passes.h" + +namespace linalg { +namespace common { + +//////////////////////////////////////////////////////////////////////////////// +// Define a few boilerplate objects used across all linalg examples. +//////////////////////////////////////////////////////////////////////////////// + +// The unique MLIRContext, similar to an llvm::Context. +inline mlir::MLIRContext &globalContext() { + static mlir::MLIRContext context; + return context; +} + +// The unique Module, similar to an llvm::Module. +inline mlir::Module &globalModule() { + static mlir::Module module(&globalContext()); + return module; +} + +/// Shortcut notation for types that we use globally. +/// The index type is the type that must be used with affine operations: +/// (`affine.apply`, `affine.for`, `affine.load`, `affine.store`). +inline mlir::IndexType indexType() { + return mlir::IndexType::get(&globalContext()); +} + +/// Common f32 type. +inline mlir::FloatType f32Type() { + return mlir::FloatType::getF32(&globalContext()); +} + +/// A 2-D abstraction over a flat contiguous memory region of f32 with symbolic +/// sizes. +template +inline mlir::MemRefType floatMemRefType(unsigned memorySpace = 0) { + llvm::SmallVector shape(N, -1); + return mlir::MemRefType::get(shape, f32Type(), {}, memorySpace); +} + +/// The simple function, taking 4 parameters of type index, that we will use +/// throughout this tutorial: +/// +/// ```mlir +/// func @name(%M: index, %N: index, %K: index, %P: index) +/// ``` +inline mlir::Function *makeFunction(llvm::StringRef name, + llvm::ArrayRef resultTypes) { + auto &ctx = globalContext(); + auto *function = + new mlir::Function(mlir::UnknownLoc::get(&ctx), name, + mlir::FunctionType::get({indexType(), indexType(), + indexType(), indexType()}, + resultTypes, &ctx)); + function->addEntryBlock(); + globalModule().getFunctions().push_back(function); + return function; +} + +/// A basic pass manager pre-populated with cleanup passes. +inline mlir::PassManager &cleanupPassManager() { + static bool inited = false; + static mlir::PassManager pm; + if (!inited) { + pm.addPass(mlir::createCanonicalizerPass()); + pm.addPass(mlir::createSimplifyAffineStructuresPass()); + pm.addPass(mlir::createCSEPass()); + pm.addPass(mlir::createCanonicalizerPass()); + inited = true; + } + return pm; +} + +/// A simple function to verify and cleanup the IR before printing it to +/// llvm::outs() for FileCheck'ing. +/// If an error occurs, dump to llvm::errs() and do not print to llvm::outs() +/// which will make the associated FileCheck test fail. +inline void cleanupAndPrintFunction(mlir::Function *f) { + bool printToOuts = true; + auto check = [f, &printToOuts](mlir::LogicalResult result) { + if (failed(result)) { + f->dump(); + llvm::errs() << "Failure!\n"; + printToOuts = false; + } + }; + check(mlir::failure(f->getModule()->verify())); + check(cleanupPassManager().run(f->getModule())); + if (printToOuts) + f->print(llvm::outs()); +} + +/// Helper class to sugar building loop nests from indexings that appear in +/// ViewOp and SliceOp. +class LoopNestRangeBuilder { +public: + LoopNestRangeBuilder(llvm::ArrayRef ivs, + llvm::ArrayRef indexings); + LoopNestRangeBuilder(llvm::ArrayRef ivs, + llvm::ArrayRef indexings); + mlir::edsc::ValueHandle + operator()(llvm::ArrayRef stmts); + +private: + llvm::SmallVector loops; +}; + +} // namespace common +} // namespace linalg + +#endif // LINALG_COMMON_H_ diff --git a/mlir/tutorial/Linalg1/include/linalg/Dialect.h b/mlir/tutorial/Linalg1/include/linalg/Dialect.h new file mode 100644 index 0000000..6a02bf0 --- /dev/null +++ b/mlir/tutorial/Linalg1/include/linalg/Dialect.h @@ -0,0 +1,42 @@ +//===- Dialect.h - Definition of the Linalg dialect -----------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_DIALECT_H_ +#define LINALG_DIALECT_H_ + +#include "mlir/IR/Dialect.h" + +namespace linalg { + +/// The Linalg Dialect is not exposed to the outside world. It is registered by +/// linking and accessed via generic MLIR accessors. +class LinalgDialect : public mlir::Dialect { +public: + /// Create a new Dialect that is registered on construction and adds the + /// relevant types and operations. + explicit LinalgDialect(mlir::MLIRContext *context); + + /// Parse a type registered to this dialect. + mlir::Type parseType(llvm::StringRef spec, mlir::Location loc) const override; + + /// Print a type registered to this dialect. + void printType(mlir::Type type, llvm::raw_ostream &os) const override; +}; + +} // namespace linalg + +#endif // LINALG_DIALECT_H_ diff --git a/mlir/tutorial/Linalg1/include/linalg/Ops.h b/mlir/tutorial/Linalg1/include/linalg/Ops.h new file mode 100644 index 0000000..00c97d3 --- /dev/null +++ b/mlir/tutorial/Linalg1/include/linalg/Ops.h @@ -0,0 +1,59 @@ +//===- Ops.h - Linalg Ops forward declarations ------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_OPS_H_ +#define LINALG_OPS_H_ + +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/IR/OpDefinition.h" + +namespace linalg { + +class MatmulOp; +class RangeOp; +class SliceOp; +class ViewOp; +class ViewType; + +struct ViewOrSliceOp { +public: + ViewOrSliceOp(mlir::Value *v) : v(v) {} + ViewOp view(); + SliceOp slice(); + operator bool(); + unsigned getRank(); + ViewType getViewType(); + /// Get the indexing at `dim` by recursing into the parent. + /// Returns the indexing as well as its actual dimension, which may have + /// shifted from the originally requested `dim`. + std::pair getRootIndexing(unsigned dim); + // Get all the indexings without recursing. + mlir::Operation::operand_range getIndexings(); + mlir::Value *getSupportingMemRef(); + +private: + mlir::Value *v; +}; + +namespace intrinsics { +using range = mlir::edsc::intrinsics::ValueBuilder; +using slice = mlir::edsc::intrinsics::ValueBuilder; +using view = mlir::edsc::intrinsics::ValueBuilder; +} // namespace intrinsics +} // namespace linalg + +#endif // LINALG_OPS_H_ diff --git a/mlir/tutorial/Linalg1/include/linalg/RangeOp.h b/mlir/tutorial/Linalg1/include/linalg/RangeOp.h new file mode 100644 index 0000000..eaa266b --- /dev/null +++ b/mlir/tutorial/Linalg1/include/linalg/RangeOp.h @@ -0,0 +1,56 @@ +//===- RangeOp.h - Linalg dialect RangeOp operation definition ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_RANGEOP_H_ +#define LINALG_RANGEOP_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +/// A RangeOp is used to create a value of RangeType from 3 values of type index +/// that represent the min, max and step values of the range. +/// Note: step must be an mlir::ConstantIndexOp for now due to current +/// `affine.for` limitations. +class RangeOp : public mlir::Op::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.range"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *min, mlir::Value *max, mlir::Value *step); + bool verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + mlir::Value *getMin() { return getOperand(0); } + mlir::Value *getMax() { return getOperand(1); } + mlir::Value *getStep() { return getOperand(2); } +}; + +} // namespace linalg + +#endif // LINALG_RANGEOP_H_ diff --git a/mlir/tutorial/Linalg1/include/linalg/RangeType.h b/mlir/tutorial/Linalg1/include/linalg/RangeType.h new file mode 100644 index 0000000..5a9be66 --- /dev/null +++ b/mlir/tutorial/Linalg1/include/linalg/RangeType.h @@ -0,0 +1,47 @@ +//===- RangeType.h - Linalg RangeType definition --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_RANGETYPE_H_ +#define LINALG_RANGETYPE_H_ + +#include "linalg/Types.h" + +namespace mlir { +class MLIRContext; +} + +namespace linalg { +/// A RangeType is the simplest possible form of a type in MLIR. It represents +/// a minimal range abstraction (min, max, step). Since RangeType is constructed +/// without any additional argument, this example illustrates the minimal +/// amount of information required to implement a new custom MLIR type. +class RangeType : public mlir::Type::TypeBase { +public: + // Used to implement llvm-style cast. + using Base::Base; + /// Construction hook. + static RangeType get(mlir::MLIRContext *context) { + /// Custom, uniqu'ed construction in the mlir::MLIRContext. + return Base::get(context, LinalgTypes::Range); + } + /// Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; } +}; + +} // namespace linalg + +#endif // LINALG_RANGETYPE_H_ diff --git a/mlir/tutorial/Linalg1/include/linalg/SliceOp.h b/mlir/tutorial/Linalg1/include/linalg/SliceOp.h new file mode 100644 index 0000000..3f73b78 --- /dev/null +++ b/mlir/tutorial/Linalg1/include/linalg/SliceOp.h @@ -0,0 +1,102 @@ +//===- SliceOp.h - Linalg dialect SliceOp operation definition ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_SLICEOP_H_ +#define LINALG_SLICEOP_H_ + +#include "linalg/Types.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +/// A SliceOp is used to create a "sub-View" from a ViewType. It results in a +/// new ViewType which is contained within its parent ViewType. +class SliceOp : public mlir::Op::Impl, + mlir::OpTrait::OneResult, + mlir::OpTrait::HasNoSideEffect> { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.slice"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *view, mlir::Value *indexing, unsigned dim); + bool verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + enum { FirstIndexingOperand = 1 }; + /// Returns the attribute name that describes which dimension of the input + /// view that this SliceOp slices. + static llvm::StringRef getSlicingDimAttrName() { return "dim"; } + /// Returns the unique result of the parent SliceOp of ViewOp instruction that + /// created the view on which this SliceOp operates. + mlir::Value *getParentView() { return getOperand(0); } + /// Returns the indexing operand of the current SliceOp. + /// This operands may either be: + /// 1. A range, in which case the operand comes from a RangeOp. This SliceOp + /// does not reduce the dimension of the input ViewType. + /// 2. An index, in which case the operand comes from any possible producer + /// of an index. This SliceOp reduces the dimension of the input ViewType + /// by 1. + mlir::Value *getIndexing() { return getOperand(1); } + /// Returns the dim of the parent ViewType that is sliced by this SliceOp. + unsigned getSlicingDim() { + return getAttrOfType(getSlicingDimAttrName()).getInt(); + } + /// Returns the ViewType resulting from this SliceOp. + ViewType getViewType(); + /// Returns the rank of the current ViewType. + unsigned getRank(); + /// Return the element type of the current ViewType. + mlir::Type getElementType(); + + /// Returns the ViewType of `getParentView()`. + ViewType getParentViewType(); + /// Returns the rank of the ViewType of `getParentView()`. + unsigned getParentRank(); + /// Returns the element Type of the ViewType of `getParentView()`. + mlir::Type getParentElementType(); + + /// Walks the SliceOp chain until it encounters the base ViewOp. + /// Returns the single return value of the ViewOp. + mlir::Value *getBaseView(); + + /// Returns the MemRef backing the base ViewOp. + // May be another data type than a MemRef in the future. + mlir::Value *getSupportingMemRef(); + + /// Extracts the indexing from the original ViewOp that this slice restricts + /// along `dim`. Walks back the chain of SliceOp and determines the first + /// slice that constrains `dim`. + /// Returns the indexing as well as its actual dimension which may have + /// shifted from the originally requested `dim`. + std::pair getRootIndexing(unsigned dim); + + // Get all the indexings in this slice. + mlir::Operation::operand_range getIndexings(); +}; + +} // namespace linalg + +#endif // LINALG_SLICEOP_H_ diff --git a/mlir/tutorial/Linalg1/include/linalg/Types.h b/mlir/tutorial/Linalg1/include/linalg/Types.h new file mode 100644 index 0000000..5a04ad2 --- /dev/null +++ b/mlir/tutorial/Linalg1/include/linalg/Types.h @@ -0,0 +1,37 @@ +//===- Types.h - Linalg Types forward declarations ------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_TYPES_H_ +#define LINALG_TYPES_H_ + +#include "mlir/IR/Types.h" + +namespace linalg { + +class RangeType; +class ViewType; +class ViewTypeStorage; + +enum LinalgTypes { + Range = mlir::Type::FIRST_LINALG_TYPE, + View, + LAST_USED_LINALG_TYPE = View, +}; + +} // namespace linalg + +#endif // LINALG_TYPES_H_ diff --git a/mlir/tutorial/Linalg1/include/linalg/ViewOp.h b/mlir/tutorial/Linalg1/include/linalg/ViewOp.h new file mode 100644 index 0000000..2a96cb5 --- /dev/null +++ b/mlir/tutorial/Linalg1/include/linalg/ViewOp.h @@ -0,0 +1,71 @@ +//===- ViewOp.h - Linalg dialect ViewOp operation definition ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_VIEWOP_H_ +#define LINALG_VIEWOP_H_ + +#include "mlir/IR/OpDefinition.h" +#include "mlir/Support/LLVM.h" + +namespace linalg { + +class ViewType; + +/// A `ViewOp` produces a `ViewType` which is a multi-dimensional range +/// abstraction on top of an underlying data type. For now we use the existing +/// mlir::MemRef for the underlying data type. +class ViewOp : public mlir::Op { +public: + using Op::Op; + + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this op. + ////////////////////////////////////////////////////////////////////////////// + static llvm::StringRef getOperationName() { return "linalg.view"; } + static void build(mlir::Builder *b, mlir::OperationState *result, + mlir::Value *memRef, + llvm::ArrayRef indexings = {}); + bool verify(); + static bool parse(mlir::OpAsmParser *parser, mlir::OperationState *result); + void print(mlir::OpAsmPrinter *p); + + ////////////////////////////////////////////////////////////////////////////// + // Op-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + enum { FirstIndexingOperand = 1 }; + unsigned getRank(); + mlir::Type getElementType(); + ViewType getViewType(); + // May be something else than a MemRef in the future. + mlir::Value *getSupportingMemRef(); + // Get the underlying indexing at a given rank. + mlir::Value *getIndexing(unsigned rank); + // A ViewOp is a root, its root indexing is trivial. + std::pair getRootIndexing(unsigned rank) { + return std::make_pair(getIndexing(rank), rank); + } + // Get all the indexings of type RangeOp. + llvm::SmallVector getRanges(); + // Get all the indexings in this view. + mlir::Operation::operand_range getIndexings(); +}; + +} // namespace linalg + +#endif // LINALG_VIEWOP_H_ diff --git a/mlir/tutorial/Linalg1/include/linalg/ViewType.h b/mlir/tutorial/Linalg1/include/linalg/ViewType.h new file mode 100644 index 0000000..8cfed55 --- /dev/null +++ b/mlir/tutorial/Linalg1/include/linalg/ViewType.h @@ -0,0 +1,54 @@ +//===- ViewType.h - Linalg ViewType definition --------------------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= + +#ifndef LINALG_VIEWTYPE_H_ +#define LINALG_VIEWTYPE_H_ + +#include "linalg/Types.h" + +namespace linalg { + +/// A ViewType represents a range abstraction on top of an underlying storage +/// type. It is parameterizable by the underlying element type and the rank of +/// the view. +class ViewType + : public mlir::Type::TypeBase { +public: + ////////////////////////////////////////////////////////////////////////////// + // Hooks to customize the behavior of this type. + ////////////////////////////////////////////////////////////////////////////// + // Used to implement llvm-style cast. + using Base::Base; + // Used to implement llvm-style cast. + static bool kindof(unsigned kind) { return kind == LinalgTypes::View; } + /// Construction hook. + static ViewType get(mlir::MLIRContext *context, mlir::Type elementType, + unsigned rank); + + ////////////////////////////////////////////////////////////////////////////// + // Type-specific functionality. + ////////////////////////////////////////////////////////////////////////////// + /// Return the underlying elemental type. + mlir::Type getElementType(); + /// Return the rank of the view. + /// This is the number of indexings needed to reach an underlying element. + unsigned getRank(); +}; + +} // namespace linalg + +#endif // LINALG_VIEWTYPE_H_ diff --git a/mlir/tutorial/Linalg1/lib/Common.cpp b/mlir/tutorial/Linalg1/lib/Common.cpp new file mode 100644 index 0000000..d619219 --- /dev/null +++ b/mlir/tutorial/Linalg1/lib/Common.cpp @@ -0,0 +1,70 @@ +//===- Common.cpp - Implementation of common supporting functions ---------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple IR operation to create a new RangeType in the +// linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "linalg/Common.h" +#include "linalg/Ops.h" +#include "linalg/RangeOp.h" +#include "linalg/ViewOp.h" +#include "linalg/ViewType.h" +#include "mlir/EDSC/Builders.h" +#include "mlir/EDSC/Intrinsics.h" +#include "mlir/StandardOps/Ops.h" + +using llvm::ArrayRef; +using mlir::ConstantIndexOp; +using mlir::edsc::CapturableHandle; +using mlir::edsc::ValueHandle; +using mlir::edsc::intrinsics::alloc; +using mlir::edsc::intrinsics::ret; + +using namespace linalg; + +linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder( + llvm::ArrayRef ivs, llvm::ArrayRef indexings) { + assert(ivs.size() == indexings.size()); + for (unsigned i = 0, e = indexings.size(); i < e; ++i) { + auto rangeOp = + indexings[i].getValue()->getDefiningOp()->dyn_cast(); + if (!rangeOp) { + continue; + } + auto lb = rangeOp.getMin(); + auto ub = rangeOp.getMax(); + // This must be a constexpr index until we relax the affine.for constraint + auto step = + rangeOp.getStep()->getDefiningOp()->cast().getValue(); + loops.emplace_back(ivs[i], ValueHandle(lb), ValueHandle(ub), step); + } +} + +linalg::common::LoopNestRangeBuilder::LoopNestRangeBuilder( + llvm::ArrayRef ivs, llvm::ArrayRef indexings) + : LoopNestRangeBuilder(ivs, llvm::SmallVector( + indexings.begin(), indexings.end())) {} + +ValueHandle linalg::common::LoopNestRangeBuilder::operator()( + llvm::ArrayRef stmts) { + for (auto &lit : llvm::reverse(loops)) { + lit({}); + } + return ValueHandle::null(); +} diff --git a/mlir/tutorial/Linalg1/lib/Dialect.cpp b/mlir/tutorial/Linalg1/lib/Dialect.cpp new file mode 100644 index 0000000..84f8b45a --- /dev/null +++ b/mlir/tutorial/Linalg1/lib/Dialect.cpp @@ -0,0 +1,83 @@ +//===- Dialect.cpp - Implementation of the linalg dialect -----------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple Linalg dialect to which we gradually add +// complexity. +// +//===----------------------------------------------------------------------===// + +#include "linalg/Dialect.h" +#include "linalg/RangeOp.h" +#include "linalg/RangeType.h" +#include "linalg/SliceOp.h" +#include "linalg/ViewOp.h" +#include "linalg/ViewType.h" +#include "mlir/IR/Dialect.h" +#include "llvm/Support/raw_ostream.h" + +using llvm::raw_ostream; +using llvm::StringRef; +using mlir::Location; +using mlir::Type; + +using namespace linalg; + +Type LinalgDialect::parseType(StringRef spec, Location loc) const { + llvm_unreachable("Unhandled linalg dialect parsing"); + return Type(); +} + +/// RangeType prints as just "range". +static void print(RangeType rt, raw_ostream &os) { os << "range"; } + +/// ViewType prints as: +/// +/// ```{.mlir} +/// view +/// ``` +/// +/// or +/// +/// ```{.mlir} +/// view<0xf32> +/// ``` +/// +/// for 0-D views (a.k.a pointer to a scalar value). +static void print(linalg::ViewType rt, raw_ostream &os) { + os << "view<"; + if (rt.getRank() > 0) { + for (unsigned i = 0, e = rt.getRank(); i < e; ++i) { + os << rt.getElementType() << ((i == e - 1) ? "" : "x"); + } + } else { + os << "0x" << rt.getElementType(); + } + os << ">"; +} + +void LinalgDialect::printType(Type type, raw_ostream &os) const { + switch (type.getKind()) { + default: + llvm_unreachable("Unhandled linalg type"); + case LinalgTypes::Range: + print(type.cast(), os); + break; + case linalg::LinalgTypes::View: + print(type.cast(), os); + break; + } +} diff --git a/mlir/tutorial/Linalg1/lib/DialectRegistration.cpp b/mlir/tutorial/Linalg1/lib/DialectRegistration.cpp new file mode 100644 index 0000000..5b007c0 --- /dev/null +++ b/mlir/tutorial/Linalg1/lib/DialectRegistration.cpp @@ -0,0 +1,42 @@ +//===- DialectRegistration.cpp - Registration of the Linalg dialect -------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file registers the Linalg dialect and should live in a standalone +// library. Linking with this library will create a static global object that +// performs dialect registration. +// +//===----------------------------------------------------------------------===// + +#include "linalg/Dialect.h" +#include "linalg/RangeOp.h" +#include "linalg/RangeType.h" +#include "linalg/SliceOp.h" +#include "linalg/ViewOp.h" +#include "linalg/ViewType.h" + +using namespace mlir; +using namespace linalg; + +LinalgDialect::LinalgDialect(MLIRContext *context) + : Dialect("linalg", context) { + addTypes(); + addOperations(); +} + +// Dialect registration triggers the creation of a `LinalgDialect` object which +// adds the proper types and operations to the dialect. +static mlir::DialectRegistration LinalgOps; diff --git a/mlir/tutorial/Linalg1/lib/RangeOp.cpp b/mlir/tutorial/Linalg1/lib/RangeOp.cpp new file mode 100644 index 0000000..28d2e43af --- /dev/null +++ b/mlir/tutorial/Linalg1/lib/RangeOp.cpp @@ -0,0 +1,68 @@ +//===- RangeOp.cpp - Implementation of the linalg RangeOp operation -------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple IR operation to create a new RangeType in the +// linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "linalg/RangeOp.h" +#include "linalg/RangeType.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +using mlir::Builder; +using mlir::IndexType; +using mlir::OpAsmParser; +using mlir::OpAsmPrinter; +using mlir::OperationState; +using mlir::Value; + +// Minimal example for a new RangeOp operating on RangeType. +void linalg::RangeOp::build(Builder *b, OperationState *result, Value *min, + Value *max, Value *step) { + result->addOperands({min, max, step}); + result->addTypes({linalg::RangeType::get(b->getContext())}); +} + +// Verification is simply that a RangeOp takes 3 index ssa-value. +bool linalg::RangeOp::verify() { + if (!getMin() || !getMin()->getType().isa()) + return emitOpError("first operand should be of type index"); + if (!getMax() || !getMax()->getType().isa()) + return emitOpError("second operand should be of type index"); + if (!getStep() || !getStep()->getType().isa()) + return emitOpError("third operand should be of type index"); + return false; +} + +// Parsing of the linalg dialect is not supported in this tutorial. +bool linalg::RangeOp::parse(OpAsmParser *parser, OperationState *result) { + assert(false && "NYI"); + return false; +} + +// A RangeOp prints as: +// +// ```{.mlir} +// linalg.range %arg0:%arg1:%c42 : !linalg<"range"> +// ``` +void linalg::RangeOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getMin() << ":" << *getMax() << ":" + << *getStep() << " : " << getType(); +} diff --git a/mlir/tutorial/Linalg1/lib/SliceOp.cpp b/mlir/tutorial/Linalg1/lib/SliceOp.cpp new file mode 100644 index 0000000..5264a07 --- /dev/null +++ b/mlir/tutorial/Linalg1/lib/SliceOp.cpp @@ -0,0 +1,220 @@ +//===- SliceOp.cpp - Implementation of the linalg SliceOp operation -------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements an IR operation to extract a "sub-View" from a ViewType +// in the Linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "linalg/SliceOp.h" +#include "linalg/Ops.h" +#include "linalg/RangeOp.h" +#include "linalg/RangeType.h" +#include "linalg/ViewOp.h" +#include "linalg/ViewType.h" +#include "mlir/IR/Attributes.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpDefinition.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" + +using mlir::Builder; +using mlir::IndexType; +using mlir::OpAsmParser; +using mlir::OpAsmPrinter; +using mlir::OperationState; +using mlir::Type; +using mlir::Value; + +using namespace linalg; + +ViewOp linalg::ViewOrSliceOp::view() { + return v->getDefiningOp()->dyn_cast(); +} +SliceOp linalg::ViewOrSliceOp::slice() { + return v->getDefiningOp()->dyn_cast(); +} +linalg::ViewOrSliceOp::operator bool() { + return static_cast(view()) || static_cast(slice()); +} +unsigned linalg::ViewOrSliceOp::getRank() { + assert(*this && "Not a ViewOp or a SliceOp!"); + return view() ? view().getRank() : slice().getRank(); +} +ViewType linalg::ViewOrSliceOp::getViewType() { + assert(*this && "Not a ViewOp or a SliceOp!"); + return view() ? view().getViewType() : slice().getViewType(); +} +std::pair +linalg::ViewOrSliceOp::getRootIndexing(unsigned dim) { + assert(*this && "Not a ViewOp or a SliceOp!"); + return view() ? view().getRootIndexing(dim) : slice().getRootIndexing(dim); +} +llvm::iterator_range +linalg::ViewOrSliceOp::getIndexings() { + assert(*this && "Not a ViewOp or a SliceOp!"); + return view() ? view().getIndexings() : slice().getIndexings(); +} +Value *linalg::ViewOrSliceOp::getSupportingMemRef() { + assert(*this && "Not a ViewOp or a SliceOp!"); + return view() ? view().getSupportingMemRef() : slice().getSupportingMemRef(); +} + +// A view may itself be coming either from a ViewOp or from a SliceOp. +// TODO assert statically or dynamically that indexing is within the bounds of +// view. +void linalg::SliceOp::build(Builder *b, OperationState *result, Value *view, + Value *indexing, unsigned dim) { + // Early sanity checks + extract rank. + ViewOrSliceOp op(view); + unsigned rank = op.getRank(); + ViewType viewType = op.getViewType(); + Type elementType = viewType.getElementType(); + + result->addOperands({view, indexing}); + result->addAttribute(getSlicingDimAttrName(), + b->getIntegerAttr(b->getIndexType(), dim)); + if (indexing->getType().isa()) { + // Taking a range slice does not decrease the rank, the view has the same + // type. + result->addTypes({viewType}); + } else { + assert(indexing->getType().cast()); + result->addTypes( + {linalg::ViewType::get(b->getContext(), elementType, rank - 1)}); + } +} + +bool linalg::SliceOp::verify() { + unsigned dim = getSlicingDim(); + if (dim >= getParentRank()) + return emitOpError("slicing dim must be in the [0 .. parent_rank) range"); + ViewOrSliceOp op(getOperand(0)); + if (!op) + return emitOpError( + "first operand must be of ViewType (i.e. a ViewOp or a SliceOp)"); + auto type = getOperand(1)->getType().dyn_cast(); + auto *inst = getOperand(1)->getDefiningOp(); + auto range = inst ? inst->dyn_cast() : RangeOp(); + if (!range && !type) + return emitOpError( + "second operand must be of RangeType (i.e. a RangeOp) or IndexType"); + return false; +} + +// Parsing of the linalg dialect is not supported in this tutorial. +bool linalg::SliceOp::parse(OpAsmParser *parser, OperationState *result) { + assert(false && "NYI"); + return false; +} + +// A SliceOp prints as: +// +// ```{.mlir} +// linalg.slice %0[*, %i0] { dim : 1 } : !linalg<"view"> +// ``` +// +// Where %0 is an ssa-value holding a `view`, %i0 is an ssa-value +// holding an index. +void linalg::SliceOp::print(OpAsmPrinter *p) { + unsigned dim = getSlicingDim(); + *p << getOperationName() << " " << *getParentView() << "["; + for (unsigned idx = 0, rank = getParentRank(); idx < rank; ++idx) { + if (idx != dim) { + *p << "*"; + } else { + auto *v = getIndexing(); + if (v->getDefiningOp() && v->getDefiningOp()->isa()) { + *p << *v << ".."; + } else { + *p << *v; + } + } + *p << ((idx == rank - 1) ? "" : ", "); + } + *p << "] { " << getSlicingDimAttrName() << " : " << dim << " }" + << " : " << getViewType(); +} + +ViewType linalg::SliceOp::getViewType() { return getType().cast(); } + +unsigned linalg::SliceOp::getRank() { return getViewType().getRank(); } + +mlir::Type linalg::SliceOp::getElementType() { + return getViewType().getElementType(); +} + +ViewType linalg::SliceOp::getParentViewType() { + ViewOrSliceOp op(getParentView()); + return op.getViewType(); +} + +unsigned linalg::SliceOp::getParentRank() { + return getParentViewType().getRank(); +} + +mlir::Type linalg::SliceOp::getParentElementType() { + return getParentViewType().getElementType(); +} + +Value *linalg::SliceOp::getBaseView() { + Value *parent = getParentView(); + while (!parent->getDefiningOp()->isa()) { + parent = parent->getDefiningOp()->cast().getParentView(); + } + assert(parent && "null parent"); + return parent; +} + +// We want to extract the range from the original ViewOp that this slice +// captures along `dim`. To achieve this, we want to walk back the chain of +// SliceOp and determine the first slice that constrains `dim`. +std::pair linalg::SliceOp::getRootIndexing(unsigned dim) { + assert(dim < getRank()); + auto *view = getParentView(); + unsigned sliceDim = getSlicingDim(); + auto *indexing = getIndexing(); + if (indexing->getDefiningOp()) { + if (auto rangeOp = indexing->getDefiningOp()->cast()) { + // If I sliced with a range and I sliced at this dim, then I'm it. + if (dim == sliceDim) { + return make_pair(rangeOp.getResult(), dim); + } + // Otherwise, I did not change the rank, just go look for `dim` into my + // parent. + ViewOrSliceOp op(view); + return op.getRootIndexing(dim); + } + } + assert(indexing->getType().isa()); + // If I get here, I indexed and reduced along the dim `sliceDim` from my + // parent. I need to query my parent for `dim` or `dim + 1` depending on + // whether dim > sliceDim or not. + unsigned parentDim = dim > sliceDim ? dim + 1 : dim; + ViewOrSliceOp op(view); + return op.getRootIndexing(parentDim); +} + +Value *linalg::SliceOp::getSupportingMemRef() { + auto view = getBaseView()->getDefiningOp()->cast(); + return view.getSupportingMemRef(); +} + +mlir::Operation::operand_range linalg::SliceOp::getIndexings() { + return {this->getOperation()->operand_begin() + SliceOp::FirstIndexingOperand, + this->getOperation()->operand_end()}; +} diff --git a/mlir/tutorial/Linalg1/lib/ViewOp.cpp b/mlir/tutorial/Linalg1/lib/ViewOp.cpp new file mode 100644 index 0000000..d80e86a --- /dev/null +++ b/mlir/tutorial/Linalg1/lib/ViewOp.cpp @@ -0,0 +1,156 @@ +//===- ViewOp.cpp - Implementation of the linalg ViewOp operation -------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a simple IR operation to create a new ViewType in the +// linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "linalg/ViewOp.h" +#include "linalg/Ops.h" +#include "linalg/RangeOp.h" +#include "linalg/RangeType.h" +#include "linalg/ViewType.h" +#include "mlir/EDSC/Helpers.h" +#include "mlir/IR/Builders.h" +#include "mlir/IR/OpImplementation.h" +#include "mlir/IR/StandardTypes.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/Twine.h" + +using llvm::ArrayRef; +using llvm::SmallVector; +using llvm::Twine; +using mlir::Builder; +using mlir::IndexType; +using mlir::MemRefType; +using mlir::OpAsmParser; +using mlir::OpAsmPrinter; +using mlir::OperationState; +using mlir::Type; +using mlir::Value; + +using namespace linalg; + +void linalg::ViewOp::build(Builder *b, OperationState *result, Value *memRef, + ArrayRef indexings) { + MemRefType memRefType = memRef->getType().cast(); + result->addOperands({memRef}); + assert(indexings.size() == memRefType.getRank() && + "unexpected number of indexings (must match the memref rank)"); + + result->addOperands(indexings); + unsigned rank = memRefType.getRank(); + for (auto *v : indexings) { + if (!v->getType().isa()) { + rank--; + } + } + Type elementType = memRefType.getElementType(); + result->addTypes({linalg::ViewType::get(b->getContext(), elementType, rank)}); +} + +bool linalg::ViewOp::verify() { + if (llvm::empty(getOperands())) + return emitOpError( + "requires at least a memref operand followed by 'rank' indices"); + auto memrefType = getOperand(0)->getType().dyn_cast(); + unsigned memrefRank = memrefType.getRank(); + if (!memrefType) + return emitOpError("first operand must be of MemRefType"); + unsigned index = 0; + for (auto indexing : getIndexings()) { + if (!indexing->getType().isa() && + !indexing->getType().isa()) { + return emitOpError(Twine(index) + + "^th index must be of range or index type"); + } + ++index; + } + if (llvm::size(getIndexings()) != memrefRank) { + return emitOpError("requires at least a memref operand followed by " + + Twine(memrefRank) + " indices"); + } + unsigned rank = memrefRank; + for (auto *v : getIndexings()) { + if (!v->getType().isa()) { + rank--; + } + } + if (getRank() != rank) { + return emitOpError("the rank of the view must be the number of its range " + "indices: " + + Twine(rank)); + } + return false; +} + +// Parsing of the linalg dialect is not supported in this tutorial. +bool linalg::ViewOp::parse(OpAsmParser *parser, OperationState *result) { + assert(false && "NYI"); + return false; +} + +// A ViewOp prints as: +// +// ```{.mlir} +// linalg.view %0[%1, %2] : !linalg<"view"> +// ``` +// +// Where %0 is an ssa-value holding a MemRef, %1 and %2 are ssa-value each +// holding a range. +void linalg::ViewOp::print(OpAsmPrinter *p) { + *p << getOperationName() << " " << *getSupportingMemRef() << "["; + unsigned numRanges = llvm::size(getIndexings()); + unsigned index = 0; + for (auto indexing : getIndexings()) { + *p << *indexing << ((index++ == numRanges - 1) ? "" : ", "); + } + *p << "] : " << getType(); +} + +Type linalg::ViewOp::getElementType() { return getViewType().getElementType(); } + +ViewType linalg::ViewOp::getViewType() { return getType().cast(); } + +unsigned linalg::ViewOp::getRank() { return getViewType().getRank(); } + +// May be something else than a MemRef in the future. +Value *linalg::ViewOp::getSupportingMemRef() { + auto *res = getOperand(0); + assert(res->getType().isa()); + return res; +} + +SmallVector linalg::ViewOp::getRanges() { + llvm::SmallVector res; + for (auto *operand : getIndexings()) { + if (!operand->getType().isa()) { + res.push_back(operand); + } + } + return res; +} + +Value *linalg::ViewOp::getIndexing(unsigned rank) { + SmallVector ranges(getIndexings().begin(), getIndexings().end()); + return ranges[rank]; +} + +mlir::Operation::operand_range linalg::ViewOp::getIndexings() { + return {operand_begin() + ViewOp::FirstIndexingOperand, operand_end()}; +} diff --git a/mlir/tutorial/Linalg1/lib/ViewType.cpp b/mlir/tutorial/Linalg1/lib/ViewType.cpp new file mode 100644 index 0000000..1560692 --- /dev/null +++ b/mlir/tutorial/Linalg1/lib/ViewType.cpp @@ -0,0 +1,79 @@ +//===- ViewType.h - Implementation of the ViewType custom type ------------===// +// +// Copyright 2019 The MLIR Authors. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +// ============================================================================= +// +// This file implements a custom ViewType in the linalg dialect. +// +//===----------------------------------------------------------------------===// + +#include "linalg/ViewType.h" + +using mlir::MLIRContext; +using mlir::Type; +using mlir::TypeStorage; +using mlir::TypeStorageAllocator; + +namespace linalg { + +struct ViewTypeStorage : public mlir::TypeStorage { + /// Underlying Key type to transport the payload needed to construct a custom + /// type in a generic way. + struct Key { + Key(Type elementType, unsigned rank) + : elementType(elementType), rank(rank) {} + Type elementType; + unsigned rank; + }; + /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing. + using KeyTy = Key; + + /// Construction in the llvm::BumpPtrAllocator given a key. + static ViewTypeStorage *construct(TypeStorageAllocator &allocator, + const Key &key) { + return new (allocator.allocate()) ViewTypeStorage(key); + } + + /// Equality operator for hashing. + bool operator==(const Key &key) const { + return elementType == key.elementType && rank == key.rank; + } + + /// Hashing for unique'ing. + static unsigned hashKey(const Key &key) { + return llvm::hash_combine(key.elementType, key.rank); + } + + unsigned getRank() { return rank; }; + Type getElementType() { return elementType; }; + +private: + ViewTypeStorage(const Key &key) + : elementType(key.elementType), rank(key.rank) {} + + Type elementType; + unsigned rank; +}; + +ViewType linalg::ViewType::get(MLIRContext *context, Type elementType, + unsigned rank) { + return Base::get(context, LinalgTypes::View, elementType, rank); +} + +Type linalg::ViewType::getElementType() { return getImpl()->getElementType(); } + +unsigned linalg::ViewType::getRank() { return getImpl()->getRank(); } + +} // namespace linalg -- 2.7.4