From e36337a998a6be39d65872eab3e3e2291b6518b9 Mon Sep 17 00:00:00 2001
From: Nicolas Vasilache
Date: Tue, 1 Oct 2019 05:22:54 -0700
Subject: [PATCH] Unify Linalg types by using strided memrefs

This CL finishes the implementation of the Linalg + Affine type unification of
the [strided memref RFC](https://groups.google.com/a/tensorflow.org/forum/#!topic/mlir/MaL8m2nXuio).
As a consequence, the !linalg.view type, linalg::DimOp, linalg::LoadOp and
linalg::StoreOp can now disappear and Linalg can use standard types everywhere.

PiperOrigin-RevId: 272187165
---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td  |   9 +-
 .../mlir/Dialect/Linalg/IR/LinalgLibraryOps.td     | 106 ++++---
 mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h    |   9 +-
 mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td   | 164 +++-------
 mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h |  47 +--
 mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h  |  32 +-
 .../include/mlir/Dialect/Linalg/Utils/Intrinsics.h |   6 -
 mlir/include/mlir/Dialect/Linalg/Utils/Utils.h     |   4 +-
 mlir/include/mlir/IR/OpBase.td                     |  14 +
 mlir/include/mlir/IR/StandardTypes.h               |  29 +-
 .../StandardToLLVM/ConvertStandardToLLVM.cpp       |  70 ++---
 .../Dialect/Linalg/Analysis/DependenceAnalysis.cpp |   2 +-
 mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp           | 305 ++++++-------------
 mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp         |  93 +-----
 mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp      |   8 +-
 .../Linalg/Transforms/LowerToLLVMDialect.cpp       | 336 ++++++---------------
 .../lib/Dialect/Linalg/Transforms/LowerToLoops.cpp |  40 +--
 mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp      |  23 +-
 mlir/lib/IR/StandardTypes.cpp                      | 115 +++++--
 mlir/test/Dialect/Linalg/canonicalize.mlir         |  73 -----
 mlir/test/Dialect/Linalg/fusion-2-level.mlir       |  37 ++-
 mlir/test/Dialect/Linalg/fusion.mlir               | 235 +++++++-------
 mlir/test/Dialect/Linalg/invalid.mlir              | 112 +++----
 mlir/test/Dialect/Linalg/llvm.mlir                 |  60 ++--
 mlir/test/Dialect/Linalg/loops.mlir                | 257 ++++++++--------
 mlir/test/Dialect/Linalg/promote.mlir              |  46 +--
 mlir/test/Dialect/Linalg/roundtrip.mlir            | 152 +++++-----
 mlir/test/Dialect/Linalg/tile.mlir                 | 158 +++++-----
 mlir/test/Dialect/Linalg/tile_conv.mlir            |  38 +--
 .../lib/Transforms/TestMemRefStrideCalculation.cpp |   8 +-
 mlir/test/mlir-cpu-runner/cblas_interface.cpp      |  26 +-
 .../mlir-cpu-runner/linalg_integration_test.mlir   |  39 +--
 32 files changed, 1107 insertions(+), 1546 deletions(-)
 delete mode 100644 mlir/test/Dialect/Linalg/canonicalize.mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
index 5ca798e..1984056 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgBase.td
@@ -42,8 +42,9 @@ def Linalg_Dialect : Dialect {
     more generally at interface points across language boundaries (e.g. C++ /
     Python).

-    Generally, Linalg passes non-owning pointers to View data structures to
-    precompiled library calls linked externally.
+    Generally, Linalg passes non-owning pointers to strided memref data
+    structures to precompiled library calls linked externally. The name `view`
+    is used interchangeably in Linalg to signify strided memref.
   }];
 }

@@ -55,8 +56,4 @@
 def Buffer : Type<LinalgIsBufferTypePred, "buffer">;
 def LinalgIsRangeTypePred : CPred<"$_self.isa<RangeType>()">;
 def Range : Type<LinalgIsRangeTypePred, "range">;

-// Whether a type is a ViewType.
-def LinalgIsViewTypePred : CPred<"$_self.isa<ViewType>()">;
-def View : Type<LinalgIsViewTypePred, "view">;
-
 #endif // LINALG_BASE

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
index 67457b7..1e6384c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgLibraryOps.td
@@ -59,20 +59,12 @@ class NLoopTypes<int n_par, int n_red, int n_win> :
 LinalgParametricIntNativeOpTrait<"NLoopTypes", [n_par, n_red, n_win]>
 {}

-// The linalg `ViewRanks` trait the API for ops that are known to have a
-// specified list of view ranks.
-// See Linalg/LinalgTraits.h for implementation details an usage.
-class ViewRanks<list<int> ranks> :
-LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
-{}
-
 def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;

 // The linalg 'LinalgLibraryInterface' provides access to the 'LinalgOp'
 // interface.
 def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
   let methods = [
-    /// Query the number of inputs and outputs from the operation.
     InterfaceMethod<
       "Query the number of inputs from the current operation.",
       "unsigned", "getNumInputs"
@@ -97,8 +89,6 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
       "Query the input and output operands from the current operation.",
       "Operation::operand_range", "getInputsAndOutputs"
     >,
-
-    /// Query the number of each type of loop.
     InterfaceMethod<
       "Query the number of parallel loops within the current operation.",
       "unsigned", "getNumParallelLoops"
@@ -117,16 +107,12 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
       return op.getNumParallelLoops() + op.getNumReductionLoops() +
         op.getNumWindowLoops();
     }]>,
-
-    /// Get a specific input/output at the given index.
-    InterfaceMethod<"Query the input for the given index.",
+    InterfaceMethod<"Query the input view at the given index.",
       "Value *", "getInput", (ins "unsigned":$i)
    >,
-    InterfaceMethod<"Query the output for the given index.",
+    InterfaceMethod<"Query the output view at the given index.",
      "Value *", "getOutput", (ins "unsigned":$i)
    >,
-
-    /// Get the index of the given value, or None if the value is not an input.
     InterfaceMethod<[{
       Query the index of the given input value, or `None` if the value is not
       an input.
    }], "llvm::Optional<unsigned>", "getIndexOfInput", (ins "Value *":$view)
    >,
@@ -139,16 +125,13 @@ def LinalgLibraryInterface : OpInterface<"LinalgOp"> {
    }], "llvm::Optional<unsigned>", "getIndexOfOutput", (ins "Value *":$view)
    >,
+    InterfaceMethod<[{
+      Query the type of the input view at the given index.
+    }], "MemRefType", "getInputViewType", (ins "unsigned":$i)>,
+    InterfaceMethod<[{
+      Query the type of the output view at the given index.
+    }], "MemRefType", "getOutputViewType", (ins "unsigned":$i)>,

-    /// Get the view type of the input/output at the given index.
-    InterfaceMethod<"Query the view type for the given input.",
-      "ViewType", "getInputViewType", (ins "unsigned":$i)
-    >,
-    InterfaceMethod<"Query the view type for the given output.",
-      "ViewType", "getOutputViewType", (ins "unsigned":$i)
-    >,
-
-    /// Create an operation with the given location and operands.
     StaticInterfaceMethod<[{
       Create an operation of the current type with the given location,
       operands, and attributes.
    }],
@@ -211,13 +194,14 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
     Copies the data in the input view into the output view.
     Usage:
-      linalg.copy(%arg0, %arg1) : !linalg.view<?xf32>, !linalg.view<?xf32>
+      linalg.copy(%arg0, %arg1) : memref<?xf32, stride_specification>,
+                                  memref<?xf32, stride_specification>

     One possible lowering to loop form is:
       %0 = linalg.dim %arg0, 0 : index
       loop.for %i0 = %c0 to %0 step %c1 {
-        %1 = linalg.load %arg0[%i0] : !linalg.view<?xf32>
-        linalg.store %1, %arg1[%i0] : !linalg.view<?xf32>
+        %1 = linalg.load %arg0[%i0] : memref<?xf32, stride_specification>
+        linalg.store %1, %arg1[%i0] : memref<?xf32, stride_specification>
       }

     Optionally, can take `input_permutation` and `output_permutation` attributes
@@ -226,7 +210,8 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
     Usage:
       linalg.copy(%arg0, %arg1) {inputPermutation : (i, j, k) -> (i, k, j),
                                  outputPermutation : (i, j, k) -> (k, j, i)} :
-        !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+        memref<?x?x?xf32, stride_specification>,
+        memref<?x?x?xf32, stride_specification>

     One possible lowering to loop form is:
       %0 = linalg.dim %arg0, 0
@@ -235,15 +220,17 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
       loop.for %i0 = %c0 to %{{.*}} step %c1 {
         loop.for %i1 = %c0 to %{{.*}} step %c1 {
           loop.for %i2 = %c0 to %{{.*}} step %c1 {
-            %3 = linalg.load %arg0[%i0, %i2, %i1] : !linalg.view<?x?x?xf32>
-            linalg.store %3, %arg1[%i2, %i1, %i0] : !linalg.view<?x?x?xf32>
+            %3 = linalg.load %arg0[%i0, %i2, %i1] :
+                   memref<?x?x?xf32, stride_specification>
+            linalg.store %3, %arg1[%i2, %i1, %i0] :
+                   memref<?x?x?xf32, stride_specification>

     The views are expected to be compatible for correctness but this is not
     enforced at the moment.
   }];
   let arguments = (ins
-    View:$input,
-    View:$output,
+    AnyStridedMemRef:$input,
+    AnyStridedMemRef:$output,
     OptionalAttr<AffineMapAttr>:$inputPermutation,
     OptionalAttr<AffineMapAttr>:$outputPermutation);
   // TODO(ntv) this should go away once the usage of OptionalAttr triggers
@@ -256,7 +243,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
   let extraClassDeclaration = libraryCallName # [{
     unsigned getNumParallelLoops() {
       auto *view = *(getOperands().begin());
-      return view->getType().cast<ViewType>().getRank();
+      return view->getType().cast<MemRefType>().getRank();
     }
     unsigned getNumReductionLoops() { return 0; }
     unsigned getNumWindowLoops() { return 0; }
@@ -265,11 +252,12 @@
 }

 def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
-  let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
+  let arguments = (ins AnyStridedMemRef,
+                       AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
   let extraClassDeclaration = libraryCallName # [{
     unsigned getNumParallelLoops() {
       auto *view = *(getOperands().begin());
-      return view->getType().cast<ViewType>().getRank();
+      return view->getType().cast<MemRefType>().getRank();
     }
     unsigned getNumReductionLoops() { return 0; }
     unsigned getNumWindowLoops() { return 0; }
@@ -282,25 +270,28 @@ def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {

 def DotOp : LinalgLibrary_Op<"dot",
                              [NInputsAndOutputs<2, 1>,
-                              NLoopTypes<0, 1, 0>,
-                              ViewRanks<[1, 1, 0]>]> {
-  let arguments = (ins View, View, View);
+                              NLoopTypes<0, 1, 0>]> {
+  let arguments = (ins AnyStridedMemRefOfRank<1>,
+                       AnyStridedMemRefOfRank<1>,
+                       AnyStridedMemRefOfRank<0>);
   let extraClassDeclaration = libraryCallName;
 }

 def MatvecOp : LinalgLibrary_Op<"matvec",
                                 [NInputsAndOutputs<2, 1>,
-                                 NLoopTypes<1, 1, 0>,
-                                 ViewRanks<[2, 1, 1]>]> {
-  let arguments = (ins View, View, View);
+                                 NLoopTypes<1, 1, 0>]> {
+  let arguments = (ins AnyStridedMemRefOfRank<2>,
+                       AnyStridedMemRefOfRank<1>,
+                       AnyStridedMemRefOfRank<1>);
   let extraClassDeclaration = libraryCallName;
 }

 def MatmulOp : LinalgLibrary_Op<"matmul",
                                 [NInputsAndOutputs<2, 1>,
-                                 NLoopTypes<2, 1, 0>,
-                                 ViewRanks<[2, 2, 2]>]> {
-  let arguments = (ins View, View, View);
+                                 NLoopTypes<2, 1, 0>]> {
+  let arguments = (ins AnyStridedMemRefOfRank<2>,
+                       AnyStridedMemRefOfRank<2>,
+                       AnyStridedMemRefOfRank<2>);
   let extraClassDeclaration = libraryCallName;
 }

@@ -323,7 +314,8 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
   // TODO(ntv) padding.
   // Following the TF source of truth above, strides and dilations are integer
   // attributes of the same rank as the number of window dimensions.
-  let arguments = (ins View:$filter, View:$input, View:$output,
+  let arguments = (ins AnyStridedMemRef:$filter, AnyStridedMemRef:$input,
+                       AnyStridedMemRef:$output,
                        OptionalAttr<I64ArrayAttr>:$strides,
                        OptionalAttr<I64ArrayAttr>:$dilations);
   let extraClassDeclaration = libraryCallName # [{
@@ -373,7 +365,9 @@ def GenericOp : LinalgLibraryBase_Op<"generic", []> {
    ```
      linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
-       !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+       memref<?x?xf32, stride_specification>,
+       memref<?x?xf32, stride_specification>,
+       memref<?x?xf32, stride_specification>
    ```

    Where #trait_attributes is an alias of a dictionary attribute containing:
@@ -424,13 +418,17 @@ def GenericOp : LinalgLibraryBase_Op<"generic", []> {
    And can be reused in multiple places as:
    ```
      linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
-       !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+       memref<?x?xf32, stride_specification>,
+       memref<?x?xf32, stride_specification>,
+       memref<?x?xf32, stride_specification>
    ```

    This may lower to either:
    ```
      call @linalg_matmul(%A, %B, %C) :
-       (!linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>)
+       (memref<?x?xf32, stride_specification>,
+        memref<?x?xf32, stride_specification>,
+        memref<?x?xf32, stride_specification>)
          -> ()
    ```

@@ -439,17 +437,17 @@ def GenericOp : LinalgLibraryBase_Op<"generic", []> {
      loop.for %m = %c0 to %M step %c1 {
        loop.for %n = %c0 to %N step %c1 {
          loop.for %k = %c0 to %K step %c1 {
-           %a = linalg.load %A[%m, %k] : !linalg.view<?x?xf32>
-           %b = linalg.load %B[%k, %n] : !linalg.view<?x?xf32>
-           %c = linalg.load %C[%m, %n] : !linalg.view<?x?xf32>
+           %a = linalg.load %A[%m, %k] : memref<?x?xf32, stride_specification>
+           %b = linalg.load %B[%k, %n] : memref<?x?xf32, stride_specification>
+           %c = linalg.load %C[%m, %n] : memref<?x?xf32, stride_specification>
            %d = call @mac(%a, %b, %c) : (f32, f32, f32) -> (f32)
-           linalg.store %d, %C[%m, %n] : !linalg.view<?x?xf32>
+           linalg.store %d, %C[%m, %n] : memref<?x?xf32, stride_specification>
          }
        }
      }
    ```
  }];
-  let arguments = (ins Variadic<View>:$views,
+  let arguments = (ins Variadic<AnyStridedMemRef>:$views,
                       AffineMapArrayAttr:$indexing_maps,
                       I64ArrayAttr:$n_loop_types,
                       I64ArrayAttr:$n_views,
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
index f9bcf77..b1dbdf3 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.h
@@ -44,15 +44,18 @@ namespace linalg {
///
/// Examples:
///
-/// 1. linalg.fill(%A, %f) : !linalg.view<f32>, f32
+/// 1. linalg.fill(%A, %f) : memref<f32>, f32
///    name mangles into `linalg_fill_viewf32_f32_impl`
///
/// 2. linalg.dot(%A, %B, %C) :
-///      !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+///      memref<?xf32, stride_specification>,
+///      memref<?xf32, stride_specification>, memref<f32>
///    name mangles into `linalg_dot_viewxf32_viewxf32_viewf32_impl`
///
/// 3. linalg.matmul(...) :
-///      !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+///      memref<?x?xf32, stride_specification>,
+///      memref<?x?xf32, stride_specification>,
+///      memref<?x?xf32, stride_specification>
///    name mangles into `linalg_matmul_viewxxf32_viewxxf32_viewxxf32_impl`
 std::string generateLibraryCallName(Operation *op);

diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
index 15efbd6..cf6f36c 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgOps.td
@@ -142,76 +142,6 @@ def BufferSizeOp :
   let verifier = ?;
 }

-def DimOp : Linalg_Op<"dim", [NoSideEffect]>,
-    Arguments<(ins View:$view, APIntAttr:$index)>,
-    Results<(outs Index)> {
-  let summary = "dimension index operation";
-  let description = [{
-    The "linalg.dim" operation takes a linalg.view and returns an
-    "index". It requires a single integer attribute named "index". It
-    returns the size of the specified dimension.
-
-    Example:
-
-      %1 = linalg.dim %0, 2 : view<?x?x?xf32>
-  }];
-
-  let verifier = [{
-    if (getIndex() >= getViewType().getRank())
-      return emitOpError("index is out of range");
-    return success();
-  }];
-
-  let builders = [OpBuilder<
-    "Builder *builder, OperationState &result, Value *view, unsigned index",
-    [{
-      result.addOperands(view);
-      result.addAttribute(
-        "index", builder->getIntegerAttr(builder->getIndexType(), index));
-      result.types.push_back(builder->getIndexType());
-    }]>];
-
-  let extraClassDeclaration = [{
-    unsigned getIndex() {
-      return getAttrOfType<IntegerAttr>("index").getValue().getZExtValue();
-    }
-    ViewType getViewType() { return getOperand()->getType().cast<ViewType>(); }
-  }];
-
-  let hasCanonicalizer = 1;
-}
-
-def LoadOp :
-    Linalg_Op<"load"
-      // TODO(ntv): activate once ViewType can be made a ShapeType (i.e.
-      // shape type is extensible or standard adopts a reasonable view type).
-      // , [ PredOpTrait<"operand and result have same element type",
-      //                 TCresVTEtIsSameAsOpBase<0, 0>>]
-    >,
-    Arguments<(ins View:$view, Variadic<Index>:$indices)>,
-    Results<(outs AnyType:$value)> {
-  let summary = "Read an elemental value from a view at a certain index";
-  let description = [{
-    The `linalg.load` op reads an elemental value from a view at a certain
-    index. This is the counterpart of other load ops but operating on ViewType.
-
-    Example:
-
-      %0 = linalg.load %V[%c0] : !linalg.view<?xf32>
-  }];
-  let builders = [OpBuilder<
-    "Builder *builder, OperationState &result, Value *view, "
-    "ArrayRef<Value *> indices",
-    [{
-      auto viewType = view->getType().cast<ViewType>();
-      build(builder, result, viewType.getElementType(), view, indices);
-    }]>];
-  let extraClassDeclaration = [{
-    unsigned getRank() { return getViewType().getRank(); }
-    ViewType getViewType() { return view()->getType().cast<ViewType>(); }
-  }];
-}
-
 def RangeOp :
     Linalg_Op<"range", [NoSideEffect]>,
     Arguments<(ins Index:$min, Index:$max, Index:$step)>,
@@ -238,8 +168,8 @@
 }

 def SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
-    Arguments<(ins View:$view, Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
-    Results<(outs View)> {
+    Arguments<(ins AnyStridedMemRef:$view,
+                   Variadic<AnyTypeOf<[Range, Index]>>:$indexings)>,
+    Results<(outs AnyStridedMemRef)> {
   let summary = "Produce a linalg.view which is a subview of a base view.";
   let description = [{
     The "linalg.slice" op produces a linalg.view which is a subview of a given
@@ -259,18 +189,18 @@ def SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
     1. rank-preserving slice:

-      %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, !linalg.range,
-        !linalg.range, !linalg.view<?x?xf32>
+      %4 = linalg.slice %0[%1, %2] : memref<?x?xf32, stride_specification>,
+        !linalg.range, !linalg.range, memref<?x?xf32, stride_specification>

     2. rank-reducing slice (from 2-D to 1-D):

-      %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index,
-        !linalg.range, !linalg.view<?xf32>
+      %4 = linalg.slice %0[%1, %2] : memref<?x?xf32, stride_specification>,
+        index, !linalg.range, memref<?xf32, stride_specification>

     3. rank-reducing slice (from 2-D to 0-D):

-      %4 = linalg.slice %0[%1, %2] : !linalg.view<?x?xf32>, index, index,
-        !linalg.view<f32>
+      %4 = linalg.slice %0[%1, %2] : memref<?x?xf32, stride_specification>,
+        index, index, memref<f32, stride_specification>
   }];

   let builders = [OpBuilder<
@@ -281,9 +211,9 @@ def SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
     enum { FirstIndexingOperand = 1 };
     unsigned getRank() { return getViewType().getRank(); }
     Type getElementType() { return getViewType().getElementType(); }
-    ViewType getViewType() { return getType().cast<ViewType>(); }
+    MemRefType getViewType() { return getType().cast<MemRefType>(); }
     unsigned getBaseViewRank() { return getBaseViewType().getRank(); }
-    ViewType getBaseViewType() { return view()->getType().cast<ViewType>(); }
+    MemRefType getBaseViewType() { return view()->getType().cast<MemRefType>(); }

     // Get the underlying indexing at a given rank.
     Value *indexing(unsigned rank) { return *(indexings().begin() + rank); }
@@ -299,33 +229,9 @@ def SliceOp : Linalg_Op<"slice", [NoSideEffect]>,
  }];
 }

-def StoreOp :
-    Linalg_Op<"store"
-      // TODO(ntv): activate once ViewType can be made a ShapeType (i.e.
-      // shape type is extensible or standard adopts a reasonable view type).
-      // , [ PredOpTrait<"value to store and view have the same element type",
-      //                 TCopVTEtIsSameAs<0, 1>>]
-    >,
-    Arguments<(ins AnyType:$value, View:$view, Variadic<Index>:$indices)>,
-    Results<(outs)> {
-  let summary = "Write an elemental value in a view at a certain index";
-  let description = [{
-    The `linalg.store` op writes an elemental value in a view at a certain
-    index. This is the counterpart of other store ops but operating on ViewType.
-
-    Example:
-
-      linalg.store %f, %V[%c0] : !linalg.view<?xf32>
-  }];
-  let extraClassDeclaration = [{
-    unsigned getRank() { return getViewType().getRank(); }
-    ViewType getViewType() { return view()->getType().cast<ViewType>(); }
-  }];
-}
-
 def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
-    Arguments<(ins View:$view, Variadic<Index>:$ranges)>,
-    Results<(outs View)> {
+    Arguments<(ins AnyStridedMemRef:$view, Variadic<Index>:$ranges)>,
+    Results<(outs AnyStridedMemRef)> {
   let summary = "subview operation";
   let description = [{
     The "linalg.subview" op produces a linalg.view which is a subview of a given
@@ -344,32 +250,30 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
     Example:

-      %1 = linalg.subview %0[%1, %2, %3, %4, %5, %6] : view<?x?xf32>
+      %1 = linalg.subview %0[%1, %2, %3, %4, %5, %6] :
+        memref<?x?xf32, stride_specification>
  }];

   // TODO(ntv) evolve syntax towards:
-  //   linalg.subview %0[%1:%2:%3][%4:%5:%6] : view<?x?xf32>
+  //   linalg.subview %0[%1:%2:%3][%4:%5:%6] :
+  //     memref<?x?xf32, stride_specification>

   let builders = [OpBuilder<
-    "Builder *builder, OperationState &result, Value *view, "
-    "ArrayRef<Value *> ranges",
-    [{
-      result.addOperands(view);
-      result.addOperands(ranges);
-      result.types.push_back(view->getType());
-    }]>];
+    "Builder *b, OperationState &result, Value *buffer, "
+    "ArrayRef<Value *> ranges, Type resultType = Type(), "
+    "ArrayRef<NamedAttribute> attrs = {}">];

   let verifier = [{
     auto rank = getViewType().getRank();
     if (getNumOperands() != 3 * rank + 1)
-      return emitOpError("expected a view followed by ") << (3 * rank) <<
-          " indices specifying a range for each dimension";
+      return emitOpError("expected a strided memref followed by ") << (3 * rank)
+          << " indices specifying a range for each dimension";
     return success();
  }];

   let extraClassDeclaration = [{
     Value *getView() { return getOperand(0); }
-    ViewType getViewType() { return getView()->getType().cast<ViewType>(); }
+    MemRefType getViewType() { return getView()->getType().cast<MemRefType>(); }

     struct Range { Value *min; Value *max; Value *step; };

@@ -400,16 +304,17 @@ def SubViewOp : Linalg_Op<"subview", [NoSideEffect]>,
 }

 def TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,
-    Arguments<(ins View:$view, AffineMapAttr:$permutation)>,
-    Results<(outs View)> {
-  let summary = "transpose operation produces a new view (metadata-only)";
+    Arguments<(ins AnyStridedMemRef:$view, AffineMapAttr:$permutation)>,
+    Results<(outs AnyStridedMemRef)> {
+  let summary = "transpose operation produces a new strided memref (metadata-only)";
   let description = [{
-    The "linalg.transpose" op produces a linalg.view whose sizes and strides are
-    a permutation of the original. This is a pure metadata transformation.
+    The "linalg.transpose" op produces a strided memref whose sizes and strides
+    are a permutation of the original. This is a pure metadata transformation.

     Example:

-      %1 = linalg.transpose %0 (i, j) -> (j, i) : !linalg.view<?x?xf32>
+      %1 = linalg.transpose %0 (i, j) -> (j, i) :
+        memref<?x?xf32, stride_specification>
  }];

   let builders = [OpBuilder<
@@ -426,16 +331,16 @@ def TransposeOp : Linalg_Op<"transpose", [NoSideEffect]>,

   let extraClassDeclaration = [{
     static StringRef getPermutationAttrName() { return "permutation"; }
-    ViewType getViewType() { return view()->getType().cast<ViewType>(); }
+    MemRefType getViewType() { return view()->getType().cast<MemRefType>(); }
  }];
 }

 def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
     Arguments<(ins Buffer:$buffer, Variadic<Range>:$ranges)>,
-    Results<(outs View)> {
+    Results<(outs AnyStridedMemRef)> {
   let summary = "view operation";
   let description = [{
-    The "linalg.view" op produces a linalg.view which is a multi-dimensional
+    The "linalg.view" op produces a strided memref which is a multi-dimensional
     range abstraction on top of an underlying linalg.buffer. This gives an
     indexing structure to an otherwise non-indexable linalg.buffer.

@@ -447,7 +352,8 @@ def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
       %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
       %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-      %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xvector<4xf32>>
+      %3 = linalg.view %1[%2, %2] :
+        memref<?x?xvector<4xf32>, stride_specification>
  }];

   let builders = [OpBuilder<
@@ -465,7 +371,7 @@ def ViewOp : Linalg_Op<"view", [NoSideEffect]>,
     enum { FirstIndexingOperand = 1 };
     unsigned getRank() { return getViewType().getRank(); }
     Type getElementType() { return getViewType().getElementType(); }
-    ViewType getViewType() { return getType().cast<ViewType>(); }
+    MemRefType getViewType() { return getType().cast<MemRefType>(); }
     /// Get the underlying indexing at a given rank.
     Value *getRange(unsigned rank) {
       assert(rank < getRank() && "rank overflow");
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
index 593021d..5456deb 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTraits.h
@@ -18,8 +18,9 @@
 #ifndef MLIR_DIALECT_LINALG_LINALGTRAITS_H_
 #define MLIR_DIALECT_LINALG_LINALGTRAITS_H_

-#include "mlir/IR/OpDefinition.h"
 #include "mlir/Dialect/Linalg/IR/LinalgTypes.h"
+#include "mlir/IR/OpDefinition.h"
+#include "mlir/IR/StandardTypes.h"
 #include "mlir/Support/LLVM.h"

 namespace mlir {
@@ -81,8 +82,8 @@ public:
     return llvm::None;
   }
   /// Return the `i`-th input view type.
-  mlir::linalg::ViewType getInputViewType(unsigned i) {
-    return getInput(i)->getType().template cast<mlir::linalg::ViewType>();
+  MemRefType getInputViewType(unsigned i) {
+    return getInput(i)->getType().template cast<MemRefType>();
   }
   /// Return the range over input views.
   Operation::operand_range getInputs() {
@@ -102,8 +103,8 @@ public:
     return llvm::None;
   }
   /// Return the `i`-th output view type.
-  mlir::linalg::ViewType getOutputViewType(unsigned i) {
-    return getOutput(i)->getType().template cast<mlir::linalg::ViewType>();
+  MemRefType getOutputViewType(unsigned i) {
+    return getOutput(i)->getType().template cast<MemRefType>();
   }
   /// Return the range over output views.
   Operation::operand_range getOutputs() {
@@ -114,7 +115,7 @@ public:
   /// Return the number of input and output views.
   unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
   /// Return the `i`-th view type.
-  mlir::linalg::ViewType getViewType(unsigned i) {
+  MemRefType getViewType(unsigned i) {
     return (i < nInputs()) ? getInputViewType(i)
                            : getOutputViewType(i - nInputs());
   }
@@ -127,10 +128,6 @@ public:
     auto nViews = cast<ConcreteType>(op).getNumInputsAndOutputs();
     if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews)))
       return failure();
-    for (unsigned i = 0, e = nViews; i < e; ++i) {
-      if (!op->getOperand(i)->getType().dyn_cast<mlir::linalg::ViewType>())
-        return op->emitOpError("operand ") << i << " must have view type ";
-    }
     return success();
   }
 };
@@ -156,36 +153,6 @@ public:
   };
 };

-/// This class provides the API for ops that are known to have a specified
-/// list of view ranks. This is used as a trait like this:
-///
-///   class MatvecOp : public Op<MatvecOp, OpTrait::ViewRanks<2, 1, 1>::Impl> {
-///
-template <unsigned... Ranks> class ViewRanks {
-public:
-  template <typename ConcreteType>
-  class Impl
-      : public OpTrait::TraitBase<ConcreteType, ViewRanks<Ranks...>::Impl> {
-  public:
-    static LogicalResult verifyTrait(Operation *op) {
-      if (op->getNumOperands() != sizeof...(Ranks))
-        return op->emitError("expected ") << sizeof...(Ranks) << " operands";
-
-      unsigned ranks[]{Ranks...};
-      for (unsigned i = 0, e = op->getNumOperands(); i < e; ++i) {
-        auto viewType =
-            op->getOperand(i)->getType().dyn_cast<mlir::linalg::ViewType>();
-        if (!viewType)
-          return op->emitOpError("operand ") << i << " must have view type ";
-        if (ranks[i] != viewType.getRank())
-          return op->emitOpError("operand ")
-                 << i << " must have rank " << ranks[i];
-      }
-      return success();
-    }
-  };
-};
-
 } // namespace linalg
 } // namespace OpTrait
 } // namespace mlir
diff --git a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
index 86b77f1..1835073 100644
--- a/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
+++ b/mlir/include/mlir/Dialect/Linalg/IR/LinalgTypes.h
@@ -28,8 +28,7 @@ namespace linalg {
 enum LinalgTypes {
   Buffer = Type::FIRST_LINALG_TYPE,
   Range,
-  View,
-  LAST_USED_LINALG_TYPE = View,
+  LAST_USED_LINALG_TYPE = Range,
 };

 class LinalgDialect : public Dialect {
@@ -86,35 +85,6 @@ public:
   static bool kindof(unsigned kind) { return kind == LinalgTypes::Range; }
 };

-/// A ViewType represents a multi-dimensional range abstraction on top of an
-/// underlying storage type. It is parameterizable by the underlying element
-/// type and the rank of the view.
-/// A new value of ViewType is constructed from a buffer with a view op and
-/// passing it ranges:
-///
-/// ```{.mlir}
-///    %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
-///    %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-///    %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
-/// ```
-struct ViewTypeStorage;
-class ViewType : public Type::TypeBase<ViewType, Type, ViewTypeStorage> {
-public:
-  // Used for generic hooks in TypeBase.
-  using Base::Base;
-  /// Construction hook.
-  static ViewType get(MLIRContext *context, Type elementType, unsigned rank);
-  // Used to implement llvm-style cast.
-  static bool kindof(unsigned kind) { return kind == LinalgTypes::View; }
-
-  // Type-specific functionality.
-  /// Return the underlying elemental type.
-  Type getElementType();
-  /// Return the rank of the view.
-  /// This is the number of indexings needed to reach an underlying element.
-  unsigned getRank();
-};
-
 } // namespace linalg
 } // namespace mlir

diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h b/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h
index 014fa72..1c6bb68 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Intrinsics.h
@@ -25,22 +25,16 @@ namespace linalg {
 class BufferAllocOp;
 class BufferDeallocOp;
 class CopyOp;
-class DimOp;
 class FillOp;
-class LoadOp;
 class RangeOp;
 class SliceOp;
-class StoreOp;
 class ViewOp;
 namespace intrinsics {
 using buffer_alloc = mlir::edsc::intrinsics::ValueBuilder<BufferAllocOp>;
 using buffer_dealloc = mlir::edsc::intrinsics::OperationBuilder<BufferDeallocOp>;
 using copy = mlir::edsc::intrinsics::OperationBuilder<CopyOp>;
-using dim = mlir::edsc::intrinsics::ValueBuilder<linalg::DimOp>;
 using fill = mlir::edsc::intrinsics::OperationBuilder<FillOp>;
-using linalg_load = mlir::edsc::intrinsics::ValueBuilder<linalg::LoadOp>;
-using linalg_store = mlir::edsc::intrinsics::OperationBuilder<linalg::StoreOp>;
 using range = mlir::edsc::intrinsics::ValueBuilder<RangeOp>;
 using slice = mlir::edsc::intrinsics::ValueBuilder<SliceOp>;
 using view = mlir::edsc::intrinsics::ValueBuilder<ViewOp>;
diff --git a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
index ff46f6a..fce2f73 100644
--- a/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
+++ b/mlir/include/mlir/Dialect/Linalg/Utils/Utils.h
@@ -84,9 +84,9 @@ template <class ConcreteOp>
 SmallVector<Value *, 8> getViewSizes(ConcreteOp linalgOp) {
   SmallVector<Value *, 8> res;
   for (auto v : linalgOp.getInputsAndOutputs()) {
-    ViewType t = v->getType().template cast<ViewType>();
+    MemRefType t = v->getType().template cast<MemRefType>();
     for (unsigned i = 0; i < t.getRank(); ++i)
-      res.push_back(intrinsics::dim(v, i));
+      res.push_back(edsc::intrinsics::dim(v, i));
   }
   return res;
 }
diff --git a/mlir/include/mlir/IR/OpBase.td b/mlir/include/mlir/IR/OpBase.td
index c662576..4bb88a2 100644
--- a/mlir/include/mlir/IR/OpBase.td
+++ b/mlir/include/mlir/IR/OpBase.td
@@ -495,6 +495,20 @@ class StaticShapeMemRefOf<list<Type> allowedTypes>

 def AnyStaticShapeMemRef : StaticShapeMemRefOf<[AnyType]>;

+// For a MemRefType, verify that it has strides.
+def HasStridesPred : CPred<[{ isStrided($_self.cast<MemRefType>()) }]>;
+
+class StridedMemRefOf<list<Type> allowedTypes>
+    : Type<And<[MemRefOf<allowedTypes>.predicate, HasStridesPred]>,
+           "strided " # MemRefOf<allowedTypes>.description>;
+
+def AnyStridedMemRef : StridedMemRefOf<[AnyType]>;
+
+class AnyStridedMemRefOfRank<int rank> :
+  Type<And<[AnyStridedMemRef.predicate,
+            MemRefRankOf<[AnyType], [rank]>.predicate]>,
+       AnyStridedMemRef.description # " of rank " # rank>;
+
 // This represents a generic tuple without any constraints on element type.
 def AnyTuple : Type<IsTupleTypePred, "tuple">;

diff --git a/mlir/include/mlir/IR/StandardTypes.h b/mlir/include/mlir/IR/StandardTypes.h
index 225b717..55aa440 100644
--- a/mlir/include/mlir/IR/StandardTypes.h
+++ b/mlir/include/mlir/IR/StandardTypes.h
@@ -387,11 +387,11 @@ public:
   /// layout map, returns llvm::None.
   ///
   /// The convention is that the strides for dimensions d0, .. dn appear in
-  /// order followed by the constant offset, to make indexing intuitive into the
-  /// result.
+  /// order to make indexing intuitive into the result.
   static constexpr int64_t kDynamicStrideOrOffset =
       std::numeric_limits<int64_t>::min();
-  LogicalResult getStridesAndOffset(SmallVectorImpl<int64_t> &strides) const;
+  LogicalResult getStridesAndOffset(SmallVectorImpl<int64_t> &strides,
+                                    int64_t &offset) const;

   static bool kindof(unsigned kind) { return kind == StandardTypes::MemRef; }

@@ -491,6 +491,29 @@ public:
   static bool kindof(unsigned kind) { return kind == StandardTypes::None; }
 };

+
+/// Given a list of strides (in which MemRefType::kDynamicStrideOrOffset
+/// represents a dynamic value), return the single result AffineMap which
+/// represents the linearized strided layout map. Dimensions correspond to the
+/// offset followed by the strides in order. Symbols are inserted for each
+/// dynamic dimension in order. A stride cannot take value `0`.
+///
+/// Examples:
+/// =========
+///
+///   1. For offset: 0 strides: ?, ?, 1 return
+///         (i, j, k)[M, N]->(M * i + N * j + k)
+///
+///   2. For offset: 3 strides: 32, ?, 16 return
+///         (i, j, k)[M]->(3 + 32 * i + M * j + 16 * k)
+///
+///   3. For offset: ? strides: ?, ?, ? return
+///         (i, j, k)[off, M, N, P]->(off + M * i + N * j + P * k)
+AffineMap makeStridedLinearLayoutMap(ArrayRef<int64_t> strides, int64_t offset,
+                                     MLIRContext *context);
+
+bool isStrided(MemRefType t);
+
 } // end namespace mlir

 #endif // MLIR_IR_STANDARDTYPES_H
diff --git a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
index 663c302..1539f02 100644
--- a/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
+++ b/mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
@@ -150,8 +150,9 @@
 static unsigned kOffsetPosInMemRefDescriptor = 1;
 static unsigned kSizePosInMemRefDescriptor = 2;
 static unsigned kStridePosInMemRefDescriptor = 3;
 Type LLVMTypeConverter::convertMemRefType(MemRefType type) {
+  int64_t offset;
   SmallVector<int64_t, 4> strides;
-  bool strideSuccess = succeeded(type.getStridesAndOffset(strides));
+  bool strideSuccess = succeeded(type.getStridesAndOffset(strides, offset));
   assert(strideSuccess &&
          "Non-strided layout maps must have been normalized away");
   (void)strideSuccess;
@@ -568,15 +569,16 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
     if (isSupportedMemRefType(type))
       return matchSuccess();

-    SmallVector<int64_t, 4> stridesAndOffset;
-    auto successStrides = type.getStridesAndOffset(stridesAndOffset);
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides = type.getStridesAndOffset(strides, offset);
     if (failed(successStrides))
       return matchFailure();

     // Dynamic strides are ok if they can be deduced from dynamic sizes (which
     // is guaranteed when succeeded(successStrides)).
     // Dynamic offset however can never be alloc'ed.
-    if (stridesAndOffset.back() != MemRefType::kDynamicStrideOrOffset)
+    if (offset != MemRefType::kDynamicStrideOrOffset)
       return matchFailure();

     return matchSuccess();
@@ -648,15 +650,17 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
       allocated = rewriter.create<LLVM::BitcastOp>(op->getLoc(), elementPtrType,
                                                    ArrayRef<Value *>(allocated));

-    SmallVector<int64_t, 4> stridesAndOffset;
-    auto successStrides = type.getStridesAndOffset(stridesAndOffset);
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides = type.getStridesAndOffset(strides, offset);
     assert(succeeded(successStrides) && "unexpected non-strided memref");
     (void)successStrides;
+    assert(offset != MemRefType::kDynamicStrideOrOffset &&
+           "unexpected dynamic offset");

-    ArrayRef<int64_t> strides = ArrayRef<int64_t>(stridesAndOffset).drop_back();
     // 0-D memref corner case: they have size 1 ...
     assert((type.getRank() == 0 && strides.empty() && sizes.size() == 1) ||
-           (strides.size() == sizes.size()) && "unexpected number of stride");
+           (strides.size() == sizes.size()) && "unexpected number of strides");

     // Create the MemRef descriptor.
     auto structType = lowering.convertType(type);
@@ -668,14 +672,15 @@ struct AllocOpLowering : public LLVMLegalizationPattern<AllocOp> {
         rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
     memRefDescriptor = rewriter.create<LLVM::InsertValueOp>(
         op->getLoc(), structType, memRefDescriptor,
-        createIndexConstant(rewriter, op->getLoc(), stridesAndOffset.back()),
+        createIndexConstant(rewriter, op->getLoc(), offset),
         rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor));

     if (type.getRank() == 0)
-      // No size/stride arrays in 0-D memref, use the descriptor value.
+      // No size/stride descriptor in memref, return the descriptor value.
       return rewriter.replaceOp(op, memRefDescriptor);

-    // Store all sizes in the descriptor.
+    // Store all sizes in the descriptor. Only dynamic sizes are passed in as
+    // operands to AllocOp.
     Value *runningStride = nullptr;
     // Iterate strides in reverse order, compute runningStride and strideValues.
     auto nStrides = strides.size();
@@ -874,29 +879,25 @@ struct MemRefCastOpLowering : public LLVMLegalizationPattern<MemRefCastOp> {
 struct DimOpLowering : public LLVMLegalizationPattern<DimOp> {
   using LLVMLegalizationPattern<DimOp>::LLVMLegalizationPattern;

-  PatternMatchResult match(Operation *op) const override {
-    auto dimOp = cast<DimOp>(op);
-    MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();
-    return isSupportedMemRefType(type) ? matchSuccess() : matchFailure();
-  }
-
-  void rewrite(Operation *op, ArrayRef<Value *> operands,
-               ConversionPatternRewriter &rewriter) const override {
+  PatternMatchResult
+  matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
+                  ConversionPatternRewriter &rewriter) const override {
     auto dimOp = cast<DimOp>(op);
     OperandAdaptor<DimOp> transformed(operands);
     MemRefType type = dimOp.getOperand()->getType().cast<MemRefType>();

     auto shape = type.getShape();
     int64_t index = dimOp.getIndex();
-    // Extract dynamic size from the memref descriptor and define static size
-    // as a constant.
+    // Extract dynamic size from the memref descriptor.
     if (ShapedType::isDynamic(shape[index]))
       rewriter.replaceOpWithNewOp<LLVM::ExtractValueOp>(
           op, getIndexType(), transformed.memrefOrTensor(),
           rewriter.getI64ArrayAttr({kSizePosInMemRefDescriptor, index}));
     else
+      // Use constant for static size.
       rewriter.replaceOp(
           op, createIndexConstant(rewriter, op->getLoc(), shape[index]));
+    return matchSuccess();
   }
 };

@@ -945,19 +946,18 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
   Value *getStridedElementPtr(Location loc, Type elementTypePtr,
                               Value *memRefDescriptor,
                               ArrayRef<Value *> indices,
-                              ArrayRef<int64_t> stridesAndOffset,
+                              ArrayRef<int64_t> strides, int64_t offset,
                               ConversionPatternRewriter &rewriter) const {
     auto indexTy = this->getIndexType();
     Value *base = rewriter.create<LLVM::ExtractValueOp>(
         loc, elementTypePtr, memRefDescriptor,
        rewriter.getIndexArrayAttr(kPtrPosInMemRefDescriptor));
-    Value *offset =
-        stridesAndOffset.back() == MemRefType::kDynamicStrideOrOffset
+    Value *offsetValue =
+        offset == MemRefType::kDynamicStrideOrOffset
             ? rewriter.create<LLVM::ExtractValueOp>(
                   loc, indexTy, memRefDescriptor,
                  rewriter.getIndexArrayAttr(kOffsetPosInMemRefDescriptor))
-            : this->createIndexConstant(rewriter, loc, stridesAndOffset.back());
-    auto strides = stridesAndOffset.drop_back();
+            : this->createIndexConstant(rewriter, loc, offset);
     for (int i = 0, e = indices.size(); i < e; ++i) {
       Value *stride;
       if (strides[i] != MemRefType::kDynamicStrideOrOffset) {
@@ -973,9 +973,10 @@ struct LoadStoreOpLowering : public LLVMLegalizationPattern<Derived> {
       }
       Value *additionalOffset =
           rewriter.create<LLVM::MulOp>(loc, indices[i], stride);
-      offset = rewriter.create<LLVM::AddOp>(loc, offset, additionalOffset);
+      offsetValue =
+          rewriter.create<LLVM::AddOp>(loc, offsetValue, additionalOffset);
     }
-    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offset);
+    return rewriter.create<LLVM::GEPOp>(loc, elementTypePtr, base, offsetValue);
   }

   Value *getDataPtr(Location loc, MemRefType type, Value *memRefDesc,
                     ArrayRef<Value *> indices,
                     ConversionPatternRewriter &rewriter,
                     llvm::Module &module) const {
     auto ptrType = getMemRefElementPtrType(type, this->lowering);
-    SmallVector<int64_t, 4> stridesAndOffset;
-    auto res = type.getStridesAndOffset(stridesAndOffset);
-    assert(succeeded(res) && "expected strided MemRef");
-    (void)res;
-    return getStridedElementPtr(loc, ptrType, memRefDesc, indices,
-                                stridesAndOffset, rewriter);
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    auto successStrides = type.getStridesAndOffset(strides, offset);
+    assert(succeeded(successStrides) && "unexpected non-strided memref");
+    (void)successStrides;
+    return getStridedElementPtr(loc, ptrType, memRefDesc, indices, strides,
+                                offset, rewriter);
   }
 };

diff --git a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
index ed904e0..8e304be 100644
--- a/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
+++ b/mlir/lib/Dialect/Linalg/Analysis/DependenceAnalysis.cpp
@@ -55,7 +55,7 @@ Value *Aliases::find(Value *v) {
     auto it = aliases.find(v);
     if (it != aliases.end()) {
       assert(((isa<BlockArgument>(it->getSecond()) &&
-               it->getSecond()->getType().isa<ViewType>()) ||
+               it->getSecond()->getType().isa<MemRefType>()) ||
              it->getSecond()->getType().isa<BufferType>()) &&
             "Buffer or block argument expected");
       return it->getSecond();
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
index f3251fc..64a8013 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
@@ -45,80 +45,6 @@
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
 using namespace mlir::linalg;

-namespace {
-/// Fold constant dimensions into an alloc operation.
-struct SimplifyDimOp : public OpRewritePattern<linalg::DimOp> {
-  using OpRewritePattern<linalg::DimOp>::OpRewritePattern;
-
-  PatternMatchResult matchAndRewrite(linalg::DimOp dimOp,
-                                     PatternRewriter &rewriter) const override;
-};
-} // end namespace
-
-PatternMatchResult
-SimplifyDimOp::matchAndRewrite(linalg::DimOp dimOp,
-                               PatternRewriter &rewriter) const {
-  auto *viewProducingOp = dimOp.view()->getDefiningOp();
-  auto subView = dyn_cast_or_null<SubViewOp>(viewProducingOp);
-  auto slice = dyn_cast_or_null<SliceOp>(viewProducingOp);
-  auto view = dyn_cast_or_null<ViewOp>(viewProducingOp);
-  assert(subView || slice || view);
-
-  unsigned dim = dimOp.getIndex();
-  Value *min, *max, *step;
-  if (view) {
-    // Cannot traverse block arguments, fail.
-    if (isa<BlockArgument>(view.getRange(dim)))
-      return matchFailure();
-    // Record min, max, step for further processing.
-    auto range = cast<RangeOp>(view.getRange(dim)->getDefiningOp());
-    std::tie(min, max, step) =
-        std::make_tuple(range.min(), range.max(), range.step());
-  } else if (subView) {
-    // Record min, max, step for further processing.
-    auto range = subView.getRange(dim);
-    std::tie(min, max, step) =
-        std::make_tuple(range.min, range.max, range.step);
-  } else {
-    // Taking the dim of a slice must take a range (since other dims have been
-    // rank-reduced).
-    auto *rangeValue = slice.getRanges()[dim];
-    // Cannot traverse block arguments, fail.
-    if (isa<BlockArgument>(rangeValue))
-      return matchFailure();
-    auto range = cast<RangeOp>(rangeValue->getDefiningOp());
-    // Record min, max, step for further processing.
-    std::tie(min, max, step) =
-        std::make_tuple(range.min(), range.max(), range.step());
-  }
-
-  // Only support constant steps of 1 atm.
-  auto constant = dyn_cast_or_null<ConstantIndexOp>(step->getDefiningOp());
-  if (!constant || constant.getValue() != 1)
-    return matchFailure();
-
-  // Circumvent affine constraints:
-  //   emit an affine_apply when possible, otherwise emit a `subi`.
-  bool validAffineMin = isValidDim(min) || isValidSymbol(min) ||
-                        isa_and_nonnull<ConstantIndexOp>(min->getDefiningOp());
-  bool validAffineMax = isValidDim(max) || isValidSymbol(max) ||
-                        isa_and_nonnull<ConstantIndexOp>(max->getDefiningOp());
-
-  OpBuilder b(dimOp);
-  ScopedContext scope(b, dimOp.getLoc());
-  // Emit `subi`.
-  if (!validAffineMin || !validAffineMax) {
-    rewriter.replaceOp(dimOp, {subi(max, min)}, {dimOp.view()});
-    return matchSuccess();
-  }
-
-  // Emit affine_apply.
-  using edsc::op::operator-;
-  rewriter.replaceOp(dimOp, {ValueHandle(max) - ValueHandle(min)},
-                     {dimOp.view()});
-  return matchSuccess();
-}
-
 ///////////////////// Operations defined with Tablegen /////////////////////////
 // For such operations that do not correspond to library calls (i.e. defined in
 // LinalgOps.td), we define an overloaded `print` function and a
@@ -221,35 +147,6 @@ static ParseResult parseBufferSizeOp(OpAsmParser &parser,
 }

 //===----------------------------------------------------------------------===//
-// DimOp
-//===----------------------------------------------------------------------===//
-void mlir::linalg::DimOp::getCanonicalizationPatterns(
-    OwningRewritePatternList &results, MLIRContext *context) {
-  results.insert<SimplifyDimOp>(context);
-}
-
-static void print(OpAsmPrinter &p, linalg::DimOp op) {
-  p << op.getOperationName() << " " << *op.getOperand() << ", "
-    << op.getIndex();
-  p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"index"});
-  p << " : " << op.getOperand()->getType();
-}
-
-static ParseResult parseDimOp(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType operandInfo;
-  IntegerAttr indexAttr;
-  Type type;
-  Type indexType = parser.getBuilder().getIndexType();
-  return failure(
-      parser.parseOperand(operandInfo) || parser.parseComma() ||
-      parser.parseAttribute(indexAttr, indexType, "index", result.attributes) ||
-      parser.parseOptionalAttributeDict(result.attributes) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(operandInfo, type, result.operands) ||
-      parser.addTypeToList(indexType, result.types));
-}
-
-//===----------------------------------------------------------------------===//
 // GenericOp
 //===----------------------------------------------------------------------===//

@@ -390,41 +287,6 @@ static LogicalResult verify(GenericOp op) {
 }

 //===----------------------------------------------------------------------===//
-// LoadOp
-//===----------------------------------------------------------------------===//
-
-static void print(OpAsmPrinter &p, linalg::LoadOp op) {
-  p << op.getOperationName() << " " << *op.view() << '[';
-  p.printOperands(op.indices());
-  p << ']';
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getViewType();
-}
-
-static ParseResult parseLoadOp(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType viewInfo;
-  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  ViewType type;
-
-  auto indexTy = parser.getBuilder().getIndexType();
-  return failure(
-      parser.parseOperand(viewInfo) ||
-      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseOptionalAttributeDict(result.attributes) ||
-      parser.parseColonType(type) ||
-      parser.resolveOperand(viewInfo, type, result.operands) ||
-      parser.resolveOperands(indexInfo, indexTy, result.operands) ||
-      parser.addTypeToList(type.getElementType(), result.types));
-}
-
-static LogicalResult verify(linalg::LoadOp op) {
-  if (op.getRank() != llvm::size(op.indices()))
-    return op.emitOpError("expected ")
-           << op.getRank() << " indices, got " << llvm::size(op.indices());
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
 // RangeOp
 //===----------------------------------------------------------------------===//

@@ -457,13 +319,21 @@ void mlir::linalg::SliceOp::build(Builder *b, OperationState &result,
   result.addOperands(base);
   result.addOperands(indexings);

-  ViewType viewType = base->getType().cast<ViewType>();
-  unsigned rank = viewType.getRank();
-  for (auto *i : indexings)
-    if (!i->getType().isa<RangeType>())
-      rank--;
-  Type elementType = viewType.getElementType();
-  result.addTypes({ViewType::get(b->getContext(), elementType, rank)});
+  auto memRefType = base->getType().cast<MemRefType>();
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto res = memRefType.getStridesAndOffset(strides, offset);
+  assert(succeeded(res) && strides.size() == indexings.size());
+  (void)res;
+
+  unsigned rank = memRefType.getRank();
+  // TODO(ntv): propagate static size and stride information when available.
+  SmallVector<int64_t, 4> sizes(rank, -1); // -1 encodes dynamic size.
+  Type elementType = memRefType.getElementType();
+  result.addTypes({MemRefType::get(
+      sizes, elementType,
+      {makeStridedLinearLayoutMap(strides, offset, b->getContext())},
+      memRefType.getMemorySpace())});
 }

 static void print(OpAsmPrinter &p, SliceOp op) {
@@ -519,49 +389,27 @@ static LogicalResult verify(SliceOp op) {
 }

 //===----------------------------------------------------------------------===//
-// StoreOp
-//===----------------------------------------------------------------------===//
-
-static void print(OpAsmPrinter &p, linalg::StoreOp op) {
-  p << op.getOperationName() << " " << *op.value();
-  p << ", " << *op.view() << '[';
-  p.printOperands(op.indices());
-  p << ']';
-  p.printOptionalAttrDict(op.getAttrs());
-  p << " : " << op.getViewType();
-}
-
-static ParseResult parseStoreOp(OpAsmParser &parser, OperationState &result) {
-  OpAsmParser::OperandType storeValueInfo;
-  OpAsmParser::OperandType viewInfo;
-  SmallVector<OpAsmParser::OperandType, 4> indexInfo;
-  ViewType viewType;
-
-  auto indexTy = parser.getBuilder().getIndexType();
-  return failure(
-      parser.parseOperand(storeValueInfo) || parser.parseComma() ||
-      parser.parseOperand(viewInfo) ||
-      parser.parseOperandList(indexInfo, OpAsmParser::Delimiter::Square) ||
-      parser.parseOptionalAttributeDict(result.attributes) ||
-      parser.parseColonType(viewType) ||
-      parser.resolveOperand(storeValueInfo, viewType.getElementType(),
-                            result.operands) ||
-      parser.resolveOperand(viewInfo, viewType, result.operands) ||
-      parser.resolveOperands(indexInfo, indexTy, result.operands));
-}
-
-static LogicalResult verify(linalg::StoreOp op) {
-  if (op.value()->getType() != op.getViewType().getElementType())
-    return op.emitOpError("expected value type to match view element type");
-  if (op.getRank() != llvm::size(op.indices()))
-    return op.emitOpError("expected ")
-           << op.getRank() << " indices, got " << llvm::size(op.indices());
-  return success();
-}
-
-//===----------------------------------------------------------------------===//
 // SubViewOp
 //===----------------------------------------------------------------------===//

+void mlir::linalg::SubViewOp::build(Builder *b, OperationState &result,
+                                    Value *view, ArrayRef<Value *> ranges,
+                                    Type resultType,
+                                    ArrayRef<NamedAttribute> attrs) {
+  // If the result type is not specified, assume sizes are fully dynamic.
+  // Strides don't change though.
+  // TODO(ntv) for canonicalization it may be better to use a (min, size, step)
+  // instead of a (min, max, step) abstraction.
+  if (!resultType) {
+    auto rank = ranges.size();
+    SmallVector<int64_t, 4> sizes(rank, -1);
+    auto memRefType = view->getType().cast<MemRefType>();
+    Type elementType = memRefType.getElementType();
+    resultType = MemRefType::get(sizes, elementType, memRefType.getAffineMaps(),
+                                 memRefType.getMemorySpace());
+  }
+  build(b, result, resultType, view, ranges);
+  result.addAttributes(attrs);
+}

 static void print(OpAsmPrinter &p, SubViewOp op) {
   p << op.getOperationName() << " " << *op.getOperand(0) << "[";
@@ -576,7 +424,7 @@ static void print(OpAsmPrinter &p, SubViewOp op) {

 static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
   OpAsmParser::OperandType inputView, resultView;
-  Type viewType;
+  MemRefType memRefType;
   if (parser.parseOperand(inputView))
     return failure();

@@ -587,13 +435,14 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
   // linalg.subview %0[%1:%2:%3][%4:%5:%6]
   if (parser.parseOperandList(ops, OpAsmParser::Delimiter::Square) ||
       parser.parseOptionalAttributeDict(result.attributes) ||
-      parser.parseColonType(viewType))
+      parser.parseColonType(memRefType))
     return failure();

   auto indexTy = parser.getBuilder().getIndexType();
-  return failure(parser.resolveOperand(inputView, viewType, result.operands) ||
-                 parser.resolveOperands(ops, indexTy, result.operands) ||
-                 parser.addTypeToList(viewType, result.types));
+  return failure(
+      parser.resolveOperand(inputView, memRefType, result.operands) ||
+      parser.resolveOperands(ops, indexTy, result.operands) ||
+      parser.addTypeToList(memRefType, result.types));
 }

//===----------------------------------------------------------------------===//
@@ -602,8 +451,31 @@ static ParseResult parseSubViewOp(OpAsmParser &parser, OperationState &result) {
 void mlir::linalg::TransposeOp::build(Builder *b, OperationState &result,
                                       Value *view, AffineMapAttr permutation,
                                       ArrayRef<NamedAttribute> attrs) {
-  // TODO(ntv): once views have static dimensions, compute the permuted type.
-  build(b, result, view->getType(), view, attrs);
+  auto permutationMap = permutation.getValue();
+  assert(permutationMap);
+
+  auto memRefType = view->getType().cast<MemRefType>();
+  auto rank = memRefType.getRank();
+  auto originalSizes = memRefType.getShape();
+  // Compute permuted sizes.
+  SmallVector<int64_t, 4> sizes(rank, 0);
+  for (auto en : llvm::enumerate(permutationMap.getResults()))
+    sizes[en.index()] =
+        originalSizes[en.value().cast<AffineDimExpr>().getPosition()];
+
+  // Compute permuted strides.
+  int64_t offset;
+  SmallVector<int64_t, 4> strides;
+  auto res = memRefType.getStridesAndOffset(strides, offset);
+  assert(succeeded(res) && strides.size() == static_cast<unsigned>(rank));
+  (void)res;
+  auto map = makeStridedLinearLayoutMap(strides, offset, b->getContext());
+  map = permutationMap ? map.compose(permutationMap) : map;
+  // Compute result type.
+  auto resultType = MemRefType::get(sizes, memRefType.getElementType(), map,
+                                    memRefType.getMemorySpace());
+
+  build(b, result, resultType, view, attrs);
   result.addAttribute(TransposeOp::getPermutationAttrName(), permutation);
 }

@@ -618,7 +490,7 @@ static ParseResult parseTransposeOp(OpAsmParser &parser,
                                     OperationState &result) {
   OpAsmParser::OperandType view;
   AffineMapAttr permutation;
-  Type type;
+  MemRefType type;
   return failure(parser.parseOperand(view) ||
                  parser.parseAttribute(permutation,
                                        TransposeOp::getPermutationAttrName(),
@@ -636,9 +508,13 @@ void mlir::linalg::ViewOp::build(Builder *b, OperationState &result,
                                  Value *buffer, ArrayRef<Value *> ranges,
                                  Type resultType,
                                  ArrayRef<NamedAttribute> attrs) {
+  // If the result type is not specified, assume sizes are fully dynamic.
+  // Strides are set to match an empty layout map which means "contiguous view".
   if (!resultType) {
+    auto rank = ranges.size();
+    SmallVector<int64_t, 4> sizes(rank, -1);
     Type elementType = buffer->getType().cast<BufferType>().getElementType();
-    resultType = ViewType::get(b->getContext(), elementType, ranges.size());
+    resultType = MemRefType::get(sizes, elementType, {}, 0);
   }
   build(b, result, resultType, buffer, ranges);
   result.addAttributes(attrs);
@@ -664,18 +540,18 @@ static ParseResult parseViewOp(OpAsmParser &parser, OperationState &result) {
     return failure();
   }

-  ViewType viewType = vType.dyn_cast<ViewType>();
-  if (!viewType)
-    return parser.emitError(parser.getNameLoc(), "expected view type");
-  if (viewType.getRank() != rangesInfo.size())
+  MemRefType memRefType = vType.dyn_cast<MemRefType>();
+  if (!memRefType)
+    return parser.emitError(parser.getNameLoc(), "expected memref type");
+  if (static_cast<unsigned>(memRefType.getRank()) != rangesInfo.size())
     return parser.emitError(parser.getNameLoc(), "expected ")
-           << viewType.getRank() << " ranges";
+           << memRefType.getRank() << " ranges";
   return failure(
       parser.resolveOperand(bufferInfo, bType, result.operands) ||
       (!rangesInfo.empty() &&
        parser.resolveOperands(rangesInfo, RangeType::get(vType.getContext()),
                               result.operands)) ||
-      parser.addTypeToList(viewType, result.types));
+      parser.addTypeToList(memRefType, result.types));
 }

 //===----------------------------------------------------------------------===//
@@ -747,10 +623,12 @@ static LogicalResult verify(YieldOp op) {
 //
 // ```
 //   linalg.matmul(%0, %1, %2) :
-//     !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+//     memref<?x?xf32, stride_specification>,
+//     memref<?x?xf32, stride_specification>,
+//     memref<?x?xf32, stride_specification>
 // ```
 //
-// Where %0, %1 and %2 are ssa-values of type ViewType.
+// Where %0, %1 and %2 are ssa-values of type MemRefType with strides.
 static void printLinalgLibraryOp(OpAsmPrinter &p, Operation *op) {
   assert(op->getAbstractOperation() && "unregistered operation");
   p << op->getName().getStringRef() << "(";
@@ -829,14 +707,14 @@ verifyStrideOrDilation(ConvOp op, ArrayRef<Attribute> attrs, bool isStride) {
 }

 static LogicalResult verify(ConvOp op) {
-  auto oType = op.output()->getType().cast<ViewType>();
-  auto fType = op.filter()->getType().cast<ViewType>();
-  auto iType = op.input()->getType().cast<ViewType>();
+  auto oType = op.output()->getType().cast<MemRefType>();
+  auto fType = op.filter()->getType().cast<MemRefType>();
+  auto iType = op.input()->getType().cast<MemRefType>();
   if (oType.getElementType() != iType.getElementType() ||
       oType.getElementType() != fType.getElementType())
-    return op.emitOpError("expects view elemental types to match");
+    return op.emitOpError("expects memref elemental types to match");
   if (oType.getRank() != iType.getRank() || oType.getRank() != fType.getRank())
-    return op.emitOpError("expects view ranks to match");
+    return op.emitOpError("expects memref ranks to match");
   if (auto strides = op.strides()) {
     if (failed(
             verifyStrideOrDilation(op, strides->getValue(), /*isStride=*/true)))
@@ -1000,11 +878,14 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
 }

 static void appendMangledType(llvm::raw_string_ostream &ss, Type t) {
-  if (auto view = t.dyn_cast<ViewType>()) {
+  if (auto memref = t.dyn_cast<MemRefType>()) {
     ss << "view";
-    for (unsigned i = 0, e = view.getRank(); i < e; ++i)
-      ss << "x";
-    appendMangledType(ss, view.getElementType());
+    for (auto size : memref.getShape())
+      if (size < 0)
+        ss << "sx";
+      else
+        ss << size << "x";
+    appendMangledType(ss, memref.getElementType());
   } else if (auto vec = t.dyn_cast<VectorType>()) {
     ss << "vector";
     interleave(
diff --git a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
index 6fdd9ad..c09b75e 100644
--- a/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
+++ b/mlir/lib/Dialect/Linalg/IR/LinalgTypes.cpp
@@ -34,7 +34,7 @@ using namespace mlir::linalg;

 mlir::linalg::LinalgDialect::LinalgDialect(MLIRContext *context)
     : Dialect(getDialectNamespace(), context) {
-  addTypes<BufferType, RangeType, ViewType>();
+  addTypes<BufferType, RangeType>();
   addOperations<
 #define GET_OP_LIST
 #include "mlir/Dialect/Linalg/IR/LinalgOps.cpp.inc"
@@ -142,76 +142,10 @@ Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
       return (bufferSize == -1 ? BufferType::get(getContext(), t)
                                : BufferType::get(getContext(), t, bufferSize));
     }
-  } else if (spec.consume_front("view")) {
-    if (spec.consume_front("<") && spec.consume_back(">")) {
-      // Just count the number of ? to get the rank.
-      unsigned rank = 0;
-      for (unsigned i = 0, e = spec.size(); i < e; ++i) {
-        if (spec.consume_front("?")) {
-          ++rank;
-          if (!spec.consume_front("x")) {
-            emitError(loc, "expected a list of '?x' dimension specifiers: ")
-                << spec;
-            return Type();
-          }
-        }
-      }
-      if (auto t = mlir::parseType(spec, context))
-        return ViewType::get(context, t, rank);
-    }
   }
   return (emitError(loc, "unknown Linalg type: " + origSpec), Type());
 }

-struct mlir::linalg::ViewTypeStorage : public TypeStorage {
-  /// Underlying Key type to transport the payload needed to construct a custom
-  /// type in a generic way.
-  struct Key {
-    Key(Type elementType, unsigned rank)
-        : elementType(elementType), rank(rank) {}
-    Type elementType;
-    unsigned rank;
-  };
-  /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
-  using KeyTy = Key;
-
-  /// Construction in the llvm::BumpPtrAllocator given a key.
- static ViewTypeStorage *construct(TypeStorageAllocator &allocator, - const Key &key) { - return new (allocator.allocate()) ViewTypeStorage(key); - } - - /// Equality operator for hashing. - bool operator==(const Key &key) const { - return elementType == key.elementType && rank == key.rank; - } - - /// Hashing for unique'ing. - static unsigned hashKey(const Key &key) { - return llvm::hash_combine(key.elementType, key.rank); - } - - unsigned getRank() { return rank; }; - Type getElementType() { return elementType; }; - -private: - ViewTypeStorage(const Key &key) - : elementType(key.elementType), rank(key.rank) {} - - Type elementType; - unsigned rank; -}; - -ViewType mlir::linalg::ViewType::get(MLIRContext *context, Type elementType, - unsigned rank) { - return Base::get(context, LinalgTypes::View, elementType, rank); -} - -Type mlir::linalg::ViewType::getElementType() { - return getImpl()->getElementType(); -} - -unsigned mlir::linalg::ViewType::getRank() { return getImpl()->getRank(); } /// BufferType prints as "buffer". static void print(BufferType bt, raw_ostream &os) { @@ -228,28 +162,6 @@ static void print(BufferType bt, raw_ostream &os) { /// RangeType prints as just "range". static void print(RangeType rt, raw_ostream &os) { os << "range"; } -/// ViewType prints as: -/// -/// ```{.mlir} -/// view -/// ``` -/// -/// or -/// -/// ```{.mlir} -/// view -/// ``` -/// -/// for 0-D views (a.k.a pointer to a scalar value). -static void print(mlir::linalg::ViewType rt, raw_ostream &os) { - os << "view<"; - for (unsigned i = 0, e = rt.getRank(); i < e; ++i) { - os << "?x"; - } - os << rt.getElementType(); - os << ">"; -} - void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const { switch (type.getKind()) { default: @@ -260,8 +172,5 @@ void mlir::linalg::LinalgDialect::printType(Type type, raw_ostream &os) const { case LinalgTypes::Range: print(type.cast(), os); break; - case LinalgTypes::View: - print(type.cast(), os); - break; } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp index 42ff5ce..1f1f8ac 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Fusion.cpp @@ -186,10 +186,10 @@ static LinalgOp fuse(Value *producedView, LinalgOp producer, LinalgOp consumer, << "existing LoopRange: " << loopRanges[i] << "\n"); else { auto viewDim = getViewDefiningLoopRange(producer, i); - loopRanges[i] = SubViewOp::Range{ - state.create(b, loc, 0), - linalg::intrinsics::dim(viewDim.view, viewDim.dimension), - state.create(b, loc, 1)}; + loopRanges[i] = + SubViewOp::Range{state.create(b, loc, 0), + dim(viewDim.view, viewDim.dimension), + state.create(b, loc, 1)}; LLVM_DEBUG(llvm::dbgs() << "new LoopRange: " << loopRanges[i] << "\n"); } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp index 8a8e747..9e0a3eb 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp @@ -123,36 +123,6 @@ static Type convertLinalgType(Type t, LLVMTypeConverter &lowering) { if (t.isa()) return LLVMType::getStructTy(int64Ty, int64Ty, int64Ty); - // A linalg.view type converts to a view descriptor. 
The view descriptor - // contains the pointer to the data buffer, followed by a 64-bit integer - // containing the distance between the beginning of the buffer and the first - // element to be accessed through the view, followed by two arrays, each - // containing as many 64-bit integers as the rank of the View. The first array - // represents the size, in number of original elements, of the view along the - // given dimension. When taking the view, the size is the difference between - // the upper and the lower bound of the range. The second array represents the - // "stride" (in tensor abstraction sense), i.e. the number of consecutive - // elements of the underlying buffer that separate two consecutive elements - // addressable through the view along the given dimension. When taking the - // view, the strides are constructed as products of the original sizes along - // the trailing dimensions, multiplied by the view step. For example, a view - // of a MxN memref with ranges {0:M:1}, {0:N:1}, i.e. the view of a complete - // memref, will have strides N and 1. A view with ranges {0:M:2}, {0:N:3} - // will have strides 2*N and 3. - // - // template - // struct { - // Elem *ptr; - // int64_t offset; - // int64_t sizes[Rank]; - // int64_t strides[Rank]; - // }; - if (auto viewType = t.dyn_cast()) { - auto ptrTy = getPtrToElementType(viewType, lowering); - auto arrayTy = LLVMType::getArrayTy(int64Ty, viewType.getRank()); - return LLVMType::getStructTy(ptrTy, int64Ty, arrayTy, arrayTy); - } - return Type(); } @@ -171,24 +141,27 @@ namespace { /// 3. view descriptor construction `desc`. class BaseViewConversionHelper { public: - BaseViewConversionHelper(Operation *op, ViewType viewType, + BaseViewConversionHelper(Location loc, MemRefType memRefType, ConversionPatternRewriter &rewriter, LLVMTypeConverter &lowering) - : elementTy(getPtrToElementType(viewType, lowering)), + : zeroDMemRef(memRefType.getRank() == 0), + elementTy(getPtrToElementType(memRefType, lowering)), int64Ty( lowering.convertType(rewriter.getIntegerType(64)).cast()), - rewriter(rewriter) { - viewDescriptorTy = convertLinalgType(viewType, lowering).cast(); - desc = rewriter.create(op->getLoc(), viewDescriptorTy); + desc(nullptr), rewriter(rewriter) { + assert(isStrided(memRefType) && "expected strided memref type"); + viewDescriptorTy = lowering.convertType(memRefType).cast(); + desc = rewriter.create(loc, viewDescriptorTy); } ArrayAttr pos(ArrayRef values) const { return rewriter.getI64ArrayAttr(values); }; + bool zeroDMemRef; LLVMType elementTy, int64Ty, viewDescriptorTy; - ConversionPatternRewriter &rewriter; Value *desc; + ConversionPatternRewriter &rewriter; }; } // namespace @@ -325,83 +298,6 @@ public: } }; -// DimOp creates a new `index` value. 
-class DimOpConversion : public LLVMOpLowering { -public: - explicit DimOpConversion(MLIRContext *context, LLVMTypeConverter &lowering_) - : LLVMOpLowering(linalg::DimOp::getOperationName(), context, lowering_) {} - - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto dimOp = cast(op); - auto indexTy = lowering.convertType(rewriter.getIndexType()); - edsc::ScopedContext context(rewriter, op->getLoc()); - auto pos = rewriter.getI64ArrayAttr( - {kSizePosInView, static_cast(dimOp.getIndex())}); - linalg::DimOpOperandAdaptor adaptor(operands); - Value *viewDescriptor = adaptor.view(); - rewriter.replaceOp(op, {extractvalue(indexTy, viewDescriptor, pos)}); - return matchSuccess(); - } -}; - -namespace { -// Common functionality for Linalg LoadOp and StoreOp conversion to the -// LLVM IR Dialect. -template class LoadStoreOpConversion : public LLVMOpLowering { -public: - explicit LoadStoreOpConversion(MLIRContext *context, - LLVMTypeConverter &lowering_) - : LLVMOpLowering(Op::getOperationName(), context, lowering_) {} - using Base = LoadStoreOpConversion; - - // Compute the pointer to an element of the buffer underlying the view given - // current view indices. Use the base offset and strides stored in the view - // descriptor to emit IR iteratively computing the actual offset, followed by - // a getelementptr. This must be called under an edsc::ScopedContext. - Value *obtainDataPtr(Operation *op, Value *viewDescriptor, - ArrayRef indices, - ConversionPatternRewriter &rewriter) const { - auto loadOp = cast(op); - auto elementTy = getPtrToElementType(loadOp.getViewType(), lowering); - auto int64Ty = lowering.convertType(rewriter.getIntegerType(64)); - auto pos = [&rewriter](ArrayRef values) { - return rewriter.getI64ArrayAttr(values); - }; - - // Linearize subscripts as: - // base_offset + SUM_i index_i * stride_i. - Value *base = extractvalue(elementTy, viewDescriptor, pos(kPtrPosInView)); - Value *offset = - extractvalue(int64Ty, viewDescriptor, pos(kOffsetPosInView)); - for (int i = 0, e = loadOp.getRank(); i < e; ++i) { - Value *stride = - extractvalue(int64Ty, viewDescriptor, pos({kStridePosInView, i})); - Value *additionalOffset = mul(indices[i], stride); - offset = add(offset, additionalOffset); - } - return gep(elementTy, base, offset); - } -}; -} // namespace - -// A load is converted into the actual address computation, getelementptr and -// an LLVM IR load. -class LoadOpConversion : public LoadStoreOpConversion { - using Base::Base; - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - edsc::ScopedContext edscContext(rewriter, op->getLoc()); - auto elementTy = lowering.convertType(*op->result_type_begin()); - linalg::LoadOpOperandAdaptor adaptor(operands); - auto ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter); - rewriter.replaceOp(op, {llvm_load(elementTy, ptr)}); - return matchSuccess(); - } -}; - // RangeOp creates a new range descriptor. 
class RangeOpConversion : public LLVMOpLowering { public: @@ -448,39 +344,45 @@ public: Value *baseDesc = adaptor.view(); auto sliceOp = cast(op); - BaseViewConversionHelper helper(op, sliceOp.getViewType(), rewriter, - lowering); + auto memRefType = sliceOp.getBaseViewType(); + + BaseViewConversionHelper helper(op->getLoc(), sliceOp.getViewType(), + rewriter, lowering); LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty; Value *desc = helper.desc; - auto viewType = sliceOp.getBaseViewType(); - edsc::ScopedContext context(rewriter, op->getLoc()); - Value *zero = - constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); - - auto ptrPos = helper.pos(kPtrPosInView); - desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); // TODO(ntv): extract sizes and emit asserts. - SmallVector strides(viewType.getRank()); - for (int i = 0, e = viewType.getRank(); i < e; ++i) { + SmallVector strides(memRefType.getRank()); + for (int i = 0, e = memRefType.getRank(); i < e; ++i) strides[i] = extractvalue(int64Ty, baseDesc, helper.pos({kStridePosInView, i})); - } - // Compute and insert base offset. + // Compute base offset. Value *baseOffset = extractvalue(int64Ty, baseDesc, helper.pos(kOffsetPosInView)); - for (int i = 0, e = viewType.getRank(); i < e; ++i) { + for (int i = 0, e = memRefType.getRank(); i < e; ++i) { Value *indexing = adaptor.indexings()[i]; Value *min = indexing; if (sliceOp.indexing(i)->getType().isa()) min = extractvalue(int64Ty, indexing, helper.pos(0)); baseOffset = add(baseOffset, mul(min, strides[i])); } + + // Insert base pointer. + auto ptrPos = helper.pos(kPtrPosInView); + desc = insertvalue(desc, extractvalue(elementTy, baseDesc, ptrPos), ptrPos); + + // Insert base offset. desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView)); + // Corner case, no sizes or strides: early return the descriptor. + if (helper.zeroDMemRef) + return rewriter.replaceOp(op, desc), matchSuccess(); + + Value *zero = + constant(int64Ty, rewriter.getIntegerAttr(rewriter.getIndexType(), 0)); // Compute and insert view sizes (max - min along the range) and strides. // Skip the non-range operands as they will be projected away from the view. int numNewDims = 0; @@ -515,22 +417,6 @@ public: } }; -// A store is converted into the actual address computation, getelementptr and -// an LLVM IR store. -class StoreOpConversion : public LoadStoreOpConversion { - using Base::Base; - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - edsc::ScopedContext edscContext(rewriter, op->getLoc()); - linalg::StoreOpOperandAdaptor adaptor(operands); - Value *ptr = obtainDataPtr(op, adaptor.view(), adaptor.indices(), rewriter); - llvm_store(adaptor.value(), ptr); - rewriter.replaceOp(op, llvm::None); - return matchSuccess(); - } -}; - /// Conversion pattern that transforms a linalg.transpose op into: /// 1. A function entry `alloca` operation to allocate a ViewDescriptor. /// 2. A load of the ViewDescriptor from the pointer allocated in 1. 
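//
// For reference, the view descriptor that these conversion helpers index into
// via kPtrPosInView, kOffsetPosInView, kSizePosInView and kStridePosInView
// can be pictured as the following struct (a sketch only; `ViewDescriptor`
// and the statically known `Rank` are placeholder names for illustration):
//
// template <typename Elem, int Rank>
// struct ViewDescriptor {
//   Elem *ptr;             // kPtrPosInView
//   int64_t offset;        // kOffsetPosInView
//   int64_t sizes[Rank];   // kSizePosInView
//   int64_t strides[Rank]; // kStridePosInView
// };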
@@ -556,8 +442,8 @@ public: if (tranposeOp.permutation().isIdentity()) return rewriter.replaceOp(op, baseDesc), matchSuccess(); - BaseViewConversionHelper helper(op, tranposeOp.getViewType(), rewriter, - lowering); + BaseViewConversionHelper helper(op->getLoc(), tranposeOp.getViewType(), + rewriter, lowering); LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty; Value *desc = helper.desc; @@ -606,8 +492,8 @@ public: ViewOpOperandAdaptor adaptor(operands); auto viewOp = cast(op); - BaseViewConversionHelper helper(op, viewOp.getViewType(), rewriter, - lowering); + BaseViewConversionHelper helper(op->getLoc(), viewOp.getViewType(), + rewriter, lowering); LLVMType elementTy = helper.elementTy, int64Ty = helper.int64Ty; Value *desc = helper.desc; @@ -629,6 +515,12 @@ public: Value *baseOffset = constant(int64Ty, IntegerAttr::get(indexTy, 0)); desc = insertvalue(desc, baseOffset, helper.pos(kOffsetPosInView)); + // Corner case, no sizes or strides: early return the descriptor. + if (helper.zeroDMemRef) { + rewriter.replaceOp(op, desc); + return matchSuccess(); + } + // Compute and insert view sizes (max - min along the range). int numRanges = llvm::size(viewOp.ranges()); Value *runningStride = constant(int64Ty, IntegerAttr::get(indexTy, 1)); @@ -653,60 +545,11 @@ public: } }; -// Promote LLVM struct types to pointer to struct types to avoid ABI issues -// related to C struct packing. -static SmallVector -promoteStructTypes(Operation::operand_range operands, - LLVMTypeConverter &lowering) { - SmallVector res; - for (auto operand : operands) { - auto type = lowering.convertType(operand->getType()).cast(); - if (type.isStructTy()) - res.push_back(type.getPointerTo()); - else - res.push_back(type); - } - return res; -} - -// Promote LLVM struct to pointer to struct to avoid ABI issues related to -// C struct packing. -static SmallVector -promoteStructs(Location loc, ArrayRef operands, - ConversionPatternRewriter &rewriter, - LLVMTypeConverter &lowering) { - auto *context = rewriter.getContext(); - auto int64Ty = LLVM::LLVMType::getInt64Ty(lowering.getDialect()); - auto indexType = IndexType::get(context); - edsc::ScopedContext scope(rewriter, loc); - SmallVector promotedOperands; - promotedOperands.reserve(operands.size()); - for (auto *operand : operands) { - auto type = operand->getType().cast(); - if (!type.isStructTy()) { - promotedOperands.push_back(operand); - continue; - } - // Alloca with proper alignment. This is purely for solving ABI issues - // related to C struct packing across external library call boundaries. We - // do not expect optimizations of this alloca op and so we omit - // allocating at the entry block. - auto ptrType = type.cast().getPointerTo(); - Value *one = constant(int64Ty, IntegerAttr::get(indexType, 1)); - Value *allocated = llvm_alloca(ptrType, one, /*alignment=*/8); - // Store into the alloca'ed descriptor. - llvm_store(operand, allocated); - promotedOperands.push_back(allocated); - } - return promotedOperands; -} - // Get function definition for the LinalgOp. If it doesn't exist, insert a // definition.
template -static FuncOp -getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering, - ConversionPatternRewriter &rewriter) { +static FuncOp getLLVMLibraryCallDeclaration(Operation *op, + PatternRewriter &rewriter) { auto linalgOp = cast(op); auto fnName = linalgOp.getLibraryCallName(); if (fnName.empty()) { @@ -718,19 +561,19 @@ return f; } - // Get the Function type consistent with LLVM Lowering. - // Structs are automatically promoted to pointer to struct in order to avoid - // ABI issues related to C struct packing that we don't want to handle here. - auto inputTypes = promoteStructTypes(op->getOperands(), lowering); + SmallVector inputTypes(op->getOperandTypes()); assert(op->getNumResults() == 0 && "Library call for linalg operation can be generated only for ops that " "have void return types"); - auto libFnType = FunctionType::get(inputTypes, {}, op->getContext()); - auto libFn = FuncOp::create(op->getLoc(), fnName, libFnType); - module.push_back(libFn); - // Return after creating the function definition. The body will be created - // later. - return libFn; + auto libFnType = FunctionType::get(inputTypes, {}, rewriter.getContext()); + // fnName is a dynamic std::string; unique it via a SymbolRefAttr. + SymbolRefAttr fnNameAttr = rewriter.getSymbolRefAttr(fnName); + OpBuilder::InsertionGuard guard(rewriter); + // Insert before module terminator. + rewriter.setInsertionPoint(module.getBody(), + std::prev(module.getBody()->end())); + return rewriter.create(op->getLoc(), fnNameAttr.getValue(), libFnType, + ArrayRef{}); } namespace { @@ -751,56 +594,50 @@ public: // `LinalgOp::getLibraryCallName()` function. // The implementation of the function can be either in the same module or in an // externally linked library. -template class LinalgOpConversion : public LLVMOpLowering { +template +class LinalgOpConversion : public OpRewritePattern { public: - explicit LinalgOpConversion(MLIRContext *context, - LinalgTypeConverter &lowering_) - : LLVMOpLowering(LinalgOp::getOperationName(), context, lowering_) {} + using OpRewritePattern::OpRewritePattern; - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto f = getLLVMLibraryCallDeclaration(op, lowering, rewriter); + PatternMatchResult matchAndRewrite(LinalgOp op, + PatternRewriter &rewriter) const override { + auto f = getLLVMLibraryCallDeclaration(op, rewriter); if (!f) - return matchFailure(); + return this->matchFailure(); auto fAttr = rewriter.getSymbolRefAttr(f); - auto named = rewriter.getNamedAttr("callee", fAttr); - rewriter.replaceOpWithNewOp( - op, promoteStructs(op->getLoc(), operands, rewriter, lowering), - ArrayRef{named}); - return matchSuccess(); + SmallVector operands(op.getOperands().begin(), + op.getOperands().end()); + rewriter.replaceOpWithNewOp(op, fAttr.getValue(), + ArrayRef{}, operands); + return this->matchSuccess(); } }; /// Conversion pattern specialization for CopyOp. This kicks in when both input /// and output permutations are left unspecified or are the identity.
-template <> class LinalgOpConversion : public LLVMOpLowering { +template <> class LinalgOpConversion : public OpRewritePattern { public: - explicit LinalgOpConversion(MLIRContext *context, - LinalgTypeConverter &lowering_) - : LLVMOpLowering(CopyOp::getOperationName(), context, lowering_) {} + using OpRewritePattern::OpRewritePattern; - PatternMatchResult - matchAndRewrite(Operation *op, ArrayRef operands, - ConversionPatternRewriter &rewriter) const override { - auto copyOp = cast(op); - auto inputPerm = copyOp.inputPermutation(); + PatternMatchResult matchAndRewrite(CopyOp op, + PatternRewriter &rewriter) const override { + auto inputPerm = op.inputPermutation(); if (inputPerm.hasValue() && !inputPerm->isIdentity()) return matchFailure(); - auto outputPerm = copyOp.outputPermutation(); + auto outputPerm = op.outputPermutation(); if (outputPerm.hasValue() && !outputPerm->isIdentity()) return matchFailure(); - auto f = getLLVMLibraryCallDeclaration(op, lowering, rewriter); + auto f = getLLVMLibraryCallDeclaration(op, rewriter); if (!f) return matchFailure(); auto fAttr = rewriter.getSymbolRefAttr(f); - auto named = rewriter.getNamedAttr("callee", fAttr); - rewriter.replaceOpWithNewOp( - op, promoteStructs(op->getLoc(), operands, rewriter, lowering), - ArrayRef{named}); + SmallVector operands(op.getOperands().begin(), + op.getOperands().end()); + rewriter.replaceOpWithNewOp(op, fAttr.getValue(), + ArrayRef{}, operands); return matchSuccess(); } }; @@ -836,19 +673,23 @@ public: } }; +/// Populate the given list with patterns that convert from Linalg to Standard. +static void +populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns, + MLIRContext *ctx) { + patterns.insert, + LinalgOpConversion, LinalgOpConversion, + LinalgOpConversion>(ctx); +} + /// Populate the given list with patterns that convert from Linalg to LLVM. 
static void populateLinalgToLLVMConversionPatterns(LinalgTypeConverter &converter, OwningRewritePatternList &patterns, MLIRContext *ctx) { - patterns.insert(ctx); patterns.insert, LinalgOpConversion, - LinalgOpConversion, LinalgOpConversion, - LoadOpConversion, RangeOpConversion, SliceOpConversion, - StoreOpConversion, TransposeOpConversion, ViewOpConversion>( - ctx, converter); + BufferSizeOpConversion, RangeOpConversion, SliceOpConversion, + TransposeOpConversion, ViewOpConversion>(ctx, converter); } namespace { @@ -885,15 +726,16 @@ void LowerLinalgToLLVMPass::runOnModule() { populateAffineToStdConversionPatterns(patterns, &getContext()); populateLoopToStdConversionPatterns(patterns, &getContext()); populateStdToLLVMConversionPatterns(converter, patterns); + populateLinalgToStandardConversionPatterns(patterns, &getContext()); populateLinalgToLLVMConversionPatterns(converter, patterns, &getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); target.addDynamicallyLegalOp( [&](FuncOp op) { return converter.isSignatureLegal(op.getType()); }); - if (failed(applyPartialConversion(module, target, patterns, &converter))) { + target.addLegalOp(); + if (failed(applyFullConversion(module, target, patterns, &converter))) signalPassFailure(); - } } std::unique_ptr> @@ -902,5 +744,5 @@ mlir::linalg::createLowerLinalgToLLVMPass() { } static PassRegistration - pass("linalg-lower-to-llvm-dialect", + pass("linalg-convert-to-llvm", "Lower the operations from the linalg dialect into the LLVM dialect"); diff --git a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp index 6c5777c..7854df8 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp @@ -40,7 +40,7 @@ using namespace mlir::edsc::intrinsics; using namespace mlir::linalg; using namespace mlir::linalg::intrinsics; -using IndexedLinalgValue = TemplatedIndexedValue; +using IndexedLinalgValue = TemplatedIndexedValue; using edsc::op::operator+; using edsc::op::operator==; @@ -191,12 +191,12 @@ public: }; // Emits the MLIR for the scalar part of the generic op by: -// 1. Emitting linalg_load and linalg_store ops for each input and output +// 1. Emitting std_load and std_store ops for each input and output // view in order. This is achieved by applying the appropriate input or // output map to the enclosing induction variables. // 2. Emitting a call to `op.fun()` that takes as arguments the scalars // from point 1. above. -// 3. Emitting linalg_store to store the results of 2. to the output +// 3. Emitting std_store to store the results of 2. to the output // views. 
// // An example output may resemble: // @@ -205,12 +205,17 @@ public: // loop.for %i = %c0 to %0 step %c1 { // loop.for %j = %c0 to %1 step %c1 { // loop.for %k = %c0 to %4 step %c1 { -// %11 = linalg.load %arg0[%i, %j] : !linalg.view -// %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view -// %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view +// %11 = load %arg0[%i, %j] : +// memref +// %12 = load %arg1[%i, %j, %k] : +// memref +// %13 = load %arg2[%i, %k, %j] : +// memref // %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32) -// linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view -// linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view +// store %14#0, %arg1[%i, %j, %k] : +// memref +// store %14#1, %arg2[%i, %k, %j] : +// memref // } // } // } @@ -227,19 +232,18 @@ public: unsigned nOutputs = genericOp.getNumOutputs(); SmallVector indexedValues(nInputs + nOutputs); - // 1.a. Emit linalg_load from input views. + // 1.a. Emit std_load from input views. for (unsigned i = 0, e = nInputs; i < e; ++i) { ValueHandleArray indexing(foldedAffineApplies( b, loc, genericOp.getInputIndexingMap(i), allIvs, folder)); - indexedValues[i] = linalg_load(genericOp.getInput(i), indexing); + indexedValues[i] = std_load(genericOp.getInput(i), indexing); } - // 1.b. Emit linalg_load from output views. + // 1.b. Emit std_load from output views. for (unsigned i = 0, e = nOutputs; i < e; ++i) { ValueHandleArray indexing(foldedAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder)); - indexedValues[nInputs + i] = - linalg_load(genericOp.getOutput(i), indexing); + indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing); } auto funcOp = genericOp.getFunction(); @@ -248,11 +252,11 @@ public: Operation *callOp = call(funcOp, indexedValues); assert(callOp->getNumResults() == genericOp.getNumOutputs()); - // 3. Emit linalg_store. + // 3. Emit std_store. for (unsigned i = 0, e = nOutputs; i < e; ++i) { ValueHandleArray indexing(foldedAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder)); - linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing); + std_store(callOp->getResult(i), genericOp.getOutput(i), indexing); } } else { // TODO(ntv): When a region inliner exists, use it. @@ -271,14 +275,14 @@ public: map.map(std::get<0>(it), std::get<1>(it)); } - // 3. Emit linalg_store. + // 3. Emit std_store. auto *yieldOp = cast(block.back()).getOperation(); assert(yieldOp->getNumOperands() == nOutputs); for (unsigned i = 0, e = nOutputs; i < e; ++i) { ValueHandleArray indexing(foldedAffineApplies( b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder)); - linalg_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), - indexing); + std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i), + indexing); } } } diff --git a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp index 929e26b..aca9477 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Tiling.cpp @@ -177,7 +177,7 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, for (unsigned viewIndex = 0; viewIndex < linalgOp.getNumInputsAndOutputs(); ++viewIndex) { Value *view = *(viewIteratorBegin + viewIndex); - unsigned viewRank = view->getType().cast().getRank(); + unsigned rank = view->getType().cast().getRank(); auto map = loopToOperandRangesMaps(linalgOp)[viewIndex]; // If the view is not tiled, we can use it as is.
if (!isTiled(map, tileSizes)) { @@ -187,12 +187,12 @@ makeTiledViews(OpBuilder &b, Location loc, LinalgOp linalgOp, // Construct a new subview for the tile. SmallVector subViewOperands; - subViewOperands.reserve(viewRank * 3); - for (unsigned r = 0; r < viewRank; ++r) { + subViewOperands.reserve(rank * 3); + for (unsigned r = 0; r < rank; ++r) { if (!isTiled(map.getSubMap({r}), tileSizes)) { - subViewOperands.push_back(SubViewOp::Range{ - constant_index(folder, 0), linalg::intrinsics::dim(view, r), - constant_index(folder, 1)}); + subViewOperands.push_back(SubViewOp::Range{constant_index(folder, 0), + dim(view, r), + constant_index(folder, 1)}); continue; } @@ -261,15 +261,14 @@ static PromotionInfo promoteFullTileBuffer(OpBuilder &b, Location loc, auto rank = en.index(); auto rangeValue = en.value(); Value *d = - isa(rangeValue.max->getDefiningOp()) + isa(rangeValue.max->getDefiningOp()) ? rangeValue.max : applyMapToValues(b, loc, getAffineDifferenceMap(b.getContext()), {rangeValue.max, rangeValue.min}, folder) .front(); allocSize = muli(folder, allocSize, d).getValue(); fullRanges.push_back(range(folder, zero, d, one)); - partialRanges.push_back( - range(folder, zero, linalg::intrinsics::dim(subView, rank), one)); + partialRanges.push_back(range(folder, zero, dim(subView, rank), one)); } auto *buffer = allocBuffer(viewType.getElementType(), allocSize); auto fullLocalView = view(buffer, fullRanges); @@ -293,13 +292,13 @@ static PromotionInfo promotePartialTileBuffer(OpBuilder &b, Location loc, auto zero = constant_index(folder, 0); auto one = constant_index(folder, 1); - auto viewType = v->getType().cast(); + auto viewType = v->getType().cast(); auto rank = viewType.getRank(); Value *allocSize = one; SmallVector partialRanges; partialRanges.reserve(rank); for (unsigned r = 0; r < rank; ++r) { - Value *d = linalg::intrinsics::dim(v, r); + Value *d = dim(v, r); allocSize = muli(folder, allocSize, d).getValue(); partialRanges.push_back(range(folder, zero, d, one)); } @@ -333,7 +332,7 @@ mlir::linalg::promoteLinalgViews(OpBuilder &b, Location loc, auto info = promotionInfo.find(v); if (info == promotionInfo.end()) continue; - auto viewType = v->getType().cast(); + auto viewType = v->getType().cast(); // TODO(ntv): value to fill with should be related to the operation. // For now, just use APFloat(0.0f). auto t = viewType.getElementType().cast(); diff --git a/mlir/lib/IR/StandardTypes.cpp b/mlir/lib/IR/StandardTypes.cpp index 6d445cc..6ef18ce 100644 --- a/mlir/lib/IR/StandardTypes.cpp +++ b/mlir/lib/IR/StandardTypes.cpp @@ -452,14 +452,29 @@ static void accumulateStrides(MutableArrayRef strides, strides[pos] += val; } +// This sums multiple offsets as they are seen. In the particular case of +// accumulating a dynamic offset with either a static or a dynamic one, this +// saturates to MemRefType::kDynamicStrideOrOffset. +static void accumulateOffset(int64_t &offset, bool &seenOffset, int64_t val) { + if (!seenOffset) { + // Newly seen case: set the value. + offset = val; + seenOffset = true; + return; + } + if (offset != MemRefType::kDynamicStrideOrOffset) + // Already seen case: accumulate unless already saturated. + offset += val; +} + /// Takes a single AffineExpr `e` and populates the `strides` and `seen` arrays /// with the strides values for each dim position and whether a value exists at /// that position, respectively. /// The convention is that the strides for dimensions d0, .. dn appear in -/// order followed by the constant offset, to make indexing intuitive into the -/// result.
+/// order to make indexing intuitive into the result. static void extractStrides(AffineExpr e, MutableArrayRef strides, - MutableArrayRef seen, bool &failed) { + int64_t &offset, MutableArrayRef seen, + bool &seenOffset, bool &failed) { auto bin = e.dyn_cast(); if (!bin) return; @@ -485,11 +500,11 @@ static void extractStrides(AffineExpr e, MutableArrayRef strides, for (auto e : {bin.getLHS(), bin.getRHS()}) { if (auto cst = e.dyn_cast()) { // Independent constants cumulate. - accumulateStrides(strides, seen, seen.size() - 1, cst.getValue()); + accumulateOffset(offset, seenOffset, cst.getValue()); } else if (auto sym = e.dyn_cast()) { // Independent symbols saturate. - strides.back() = MemRefType::kDynamicStrideOrOffset; - seen.back() = true; + offset = MemRefType::kDynamicStrideOrOffset; + seenOffset = true; } else if (auto dim = e.dyn_cast()) { // Independent symbols cumulate 1. accumulateStrides(strides, seen, dim.getPosition(), 1); @@ -501,25 +516,27 @@ static void extractStrides(AffineExpr e, MutableArrayRef strides, llvm_unreachable("unexpected binary operation"); } -// Fallback cases for terminal dim/sym/cst that are not part of a binary op -// (i.e. single term). +// Fallback cases for terminal dim/sym/cst that are not part of a binary op ( +// i.e. single term). static void extractStridesFromTerm(AffineExpr e, MutableArrayRef strides, - MutableArrayRef seen) { + int64_t &offset, MutableArrayRef seen, + bool &seenOffset) { if (auto cst = e.dyn_cast()) { - assert(!seen.back() && "unexpected `seen` bit with single term"); - strides.back() = cst.getValue(); - seen.back() = true; + assert(!seenOffset && "unexpected `seen` bit with single term"); + offset = cst.getValue(); + seenOffset = true; return; } if (auto sym = e.dyn_cast()) { - assert(!seen.back() && "unexpected `seen` bit with single term"); - strides.back() = MemRefType::kDynamicStrideOrOffset; - seen.back() = true; + assert(!seenOffset && "unexpected `seen` bit with single term"); + offset = MemRefType::kDynamicStrideOrOffset; + seenOffset = true; return; } if (auto dim = e.dyn_cast()) { - assert(!seen.back() && "unexpected `seen` bit with single term"); + assert(!seen[dim.getPosition()] && + "unexpected `seen` bit with single term"); strides[dim.getPosition()] = 1; seen[dim.getPosition()] = true; return; @@ -527,8 +544,8 @@ static void extractStridesFromTerm(AffineExpr e, llvm_unreachable("unexpected binary operation"); } -LogicalResult -MemRefType::getStridesAndOffset(SmallVectorImpl &strides) const { +LogicalResult MemRefType::getStridesAndOffset(SmallVectorImpl &strides, + int64_t &offset) const { auto affineMaps = getAffineMaps(); // For now strides are only computed on a single affine map with a single // result (i.e. the closed subset of linearization maps that are compatible @@ -540,7 +557,7 @@ MemRefType::getStridesAndOffset(SmallVectorImpl &strides) const { if (affineMaps.empty() || affineMaps[0].isIdentity()) { if (getRank() == 0) { // Handle 0-D corner case. 
- strides.push_back(0); + offset = 0; return success(); } stridedExpr = makeCanonicalStridedLayoutExpr(getShape(), getContext()); @@ -551,27 +568,28 @@ MemRefType::getStridesAndOffset(SmallVectorImpl &strides) const { return failure(); bool failed = false; - strides = SmallVector(getRank() + 1, 0); - SmallVector seen(getRank() + 1, false); + strides = SmallVector(getRank(), 0); + bool seenOffset = false; + SmallVector seen(getRank(), false); if (stridedExpr.isa()) { stridedExpr.walk([&](AffineExpr e) { if (!failed) - extractStrides(e, strides, seen, failed); + extractStrides(e, strides, offset, seen, seenOffset, failed); }); } else { - extractStridesFromTerm(stridedExpr, strides, seen); + extractStridesFromTerm(stridedExpr, strides, offset, seen, seenOffset); } // Constant offset may not be present in `stridedExpr` which means it is // implicitly 0. - if (!seen.back()) { - seen.back() = true; - strides.back() = 0; - } + if (!seenOffset) + offset = 0; + if (failed || !llvm::all_of(seen, [](bool b) { return b; })) { strides.clear(); return failure(); } + return success(); } @@ -630,3 +648,46 @@ void TupleType::getFlattenedTypes(SmallVectorImpl &types) { /// Return the number of element types. size_t TupleType::size() const { return getImpl()->size(); } + +AffineMap mlir::makeStridedLinearLayoutMap(ArrayRef strides, + int64_t offset, + MLIRContext *context) { + AffineExpr expr; + unsigned nSymbols = 0; + + // AffineExpr for offset. + // Static case. + if (offset != MemRefType::kDynamicStrideOrOffset) { + auto cst = getAffineConstantExpr(offset, context); + expr = cst; + } else { + // Dynamic case, new symbol for the offset. + auto sym = getAffineSymbolExpr(nSymbols++, context); + expr = sym; + } + + // AffineExpr for strides. + for (auto en : llvm::enumerate(strides)) { + auto dim = en.index(); + auto stride = en.value(); + assert(stride != 0 && "Invalid stride specification"); + auto d = getAffineDimExpr(dim, context); + AffineExpr mult; + // Static case. + if (stride != MemRefType::kDynamicStrideOrOffset) + mult = getAffineConstantExpr(stride, context); + else + // Dynamic case, new symbol for each new stride. + mult = getAffineSymbolExpr(nSymbols++, context); + expr = expr + d * mult; + } + + return AffineMap::get(strides.size(), nSymbols, expr); +} + +bool mlir::isStrided(MemRefType t) { + int64_t offset; + SmallVector stridesAndOffset; + auto res = t.getStridesAndOffset(stridesAndOffset, offset); + return succeeded(res); +} diff --git a/mlir/test/Dialect/Linalg/canonicalize.mlir b/mlir/test/Dialect/Linalg/canonicalize.mlir deleted file mode 100644 index 65e1d54..0000000 --- a/mlir/test/Dialect/Linalg/canonicalize.mlir +++ /dev/null @@ -1,73 +0,0 @@ -// RUN: mlir-opt %s -canonicalize | FileCheck %s - -// CHECK-DAG: #[[SUB:.*]] = ()[s0, s1] -> (s0 - s1) - -func @fold_constants(%arg0: !linalg.buffer) -> (index, index, index, index, index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %c2 = constant 2 : index - %c3 = constant 3 : index - %c4 = constant 4 : index - %c5 = constant 5 : index - %R02 = linalg.range %c0:%c2:%c1 : !linalg.range - %R03 = linalg.range %c0:%c3:%c1 : !linalg.range - %R04 = linalg.range %c0:%c4:%c1 : !linalg.range - %R12 = linalg.range %c1:%c2:%c1 : !linalg.range - %R13 = linalg.range %c1:%c3:%c1 : !linalg.range - %R14 = linalg.range %c1:%c4:%c1 : !linalg.range - - %v = linalg.view %arg0[%R02, %R14] : !linalg.buffer -> !linalg.view - // Expected 2. - %v0 = linalg.dim %v, 0 : !linalg.view - // Expected 3. 
- %v1 = linalg.dim %v, 1 : !linalg.view - - %s = linalg.slice %v[%c1, %R12] : !linalg.view, index, !linalg.range, !linalg.view - // Expected 1. - %s0 = linalg.dim %s, 0 : !linalg.view - - %sv = linalg.subview %v[%v0, %v1, %c1, %c2, %c4, %c1] : !linalg.view - // Expected 1. - %sv0 = linalg.dim %sv, 0 : !linalg.view - // Expected 2. - %sv1 = linalg.dim %sv, 1 : !linalg.view - - return %v0, %v1, %s0, %sv0, %sv1 : index, index, index, index, index -} - -// CHECK-LABEL: fold_constants -// CHECK-DAG: %[[c1:.*]] = constant 1 : index -// CHECK-DAG: %[[c2:.*]] = constant 2 : index -// CHECK-DAG: %[[c3:.*]] = constant 3 : index -// CHECK: return %[[c2]], %[[c3]], %[[c1]], %[[c1]], %[[c2]] - - -func @fold_indices(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) -> (index, index, index, index) { - %c0 = constant 0 : index - %c1 = constant 1 : index - %R = linalg.range %arg1:%arg3:%c1 : !linalg.range - - %v = linalg.view %arg0[%R, %R] : !linalg.buffer -> !linalg.view - // Expected %arg3 - %arg1. - %v0 = linalg.dim %v, 0 : !linalg.view - // Expected %arg3 - %arg1. - %v1 = linalg.dim %v, 1 : !linalg.view - - %arg1_p_arg2 = addi %arg1, %arg2: index - %arg1_p_arg2_affine = affine.apply (i, j) -> (i + j) (%arg1, %arg2) - %sv = linalg.subview %v[%arg1, %arg1_p_arg2, %c1, %arg1, %arg1_p_arg2_affine, %c1] : !linalg.view - // Expected %arg2 but can't fold affine.apply with addi. - %sv0 = linalg.dim %sv, 0 : !linalg.view - // Expected %arg2. - %sv1 = linalg.dim %sv, 1 : !linalg.view - - return %v0, %v1, %sv0, %sv1 : index, index, index, index -} - -// CHECK-LABEL: fold_indices -// CHECK: (%[[arg0:.*]]: !linalg.buffer, %[[arg1:.*]]: index, %[[arg2:.*]]: index, %[[arg3:.*]]: index -// CHECK: %[[r0:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]] -// CHECK: %[[r1:.*]] = affine.apply #[[SUB]]()[%[[arg3]], %[[arg1]]] -// CHECK: %[[add:.*]] = addi %[[arg1]], %[[arg2]] : index -// CHECK: %[[aff:.*]] = affine.apply #[[SUB]]()[%[[add]], %[[arg1]]] -// CHECK: return %[[r0]], %[[r1]], %[[aff]], %[[arg2]] \ No newline at end of file diff --git a/mlir/test/Dialect/Linalg/fusion-2-level.mlir b/mlir/test/Dialect/Linalg/fusion-2-level.mlir index 1ef4a0b..632ba34 100644 --- a/mlir/test/Dialect/Linalg/fusion-2-level.mlir +++ b/mlir/test/Dialect/Linalg/fusion-2-level.mlir @@ -1,11 +1,16 @@ // RUN: mlir-opt %s -linalg-fusion | FileCheck %s + #map0 = (d0) -> (d0 + 20) #map1 = (d0) -> (d0 + 40) #map2 = (d0) -> (d0 + 30) #map3 = (d0) -> (d0 + 2) #map4 = (d0) -> (d0 + 4) #map5 = (d0) -> (d0 + 3) -func @f1(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { + +// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) +#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) + +func @f1(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index @@ -14,39 +19,39 @@ func @f1(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< %c40 = constant 40 : index %c30 = constant 30 : index %c20 = constant 20 : index - %0 = linalg.dim %C, 0 : !linalg.view - %1 = linalg.dim %C, 1 : !linalg.view - %2 = linalg.dim %D, 1 : !linalg.view - linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view + %0 = dim %C, 0 : memref + %1 = dim %C, 1 : memref + %2 = dim %D, 1 : memref + linalg.matmul(%A, %B, %C) : memref, memref, memref loop.for %arg5 = %c0 to %0 step %c20 { loop.for %arg6 = %c0 to %2 step %c30 { loop.for %arg7 = %c0 to %1 step %c40 { %3 = affine.apply #map0(%arg5) 
%4 = affine.apply #map1(%arg7) - %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view + %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : memref %6 = affine.apply #map2(%arg6) - %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view - %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view - %9 = linalg.dim %5, 0 : !linalg.view - %10 = linalg.dim %5, 1 : !linalg.view - %11 = linalg.dim %7, 1 : !linalg.view + %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : memref + %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref + %9 = dim %5, 0 : memref + %10 = dim %5, 1 : memref + %11 = dim %7, 1 : memref loop.for %arg8 = %c0 to %9 step %c2 { loop.for %arg9 = %c0 to %11 step %c3 { loop.for %B0 = %c0 to %10 step %c4 { %12 = affine.apply #map3(%arg8) %13 = affine.apply #map4(%B0) - %14 = linalg.subview %5[%arg8, %12, %c1, %B0, %13, %c1] : !linalg.view + %14 = linalg.subview %5[%arg8, %12, %c1, %B0, %13, %c1] : memref %15 = affine.apply #map5(%arg9) - %16 = linalg.subview %7[%B0, %13, %c1, %arg9, %15, %c1] : !linalg.view - %17 = linalg.subview %8[%arg8, %12, %c1, %arg9, %15, %c1] : !linalg.view - linalg.matmul(%14, %16, %17) : !linalg.view, !linalg.view, !linalg.view + %16 = linalg.subview %7[%B0, %13, %c1, %arg9, %15, %c1] : memref + %17 = linalg.subview %8[%arg8, %12, %c1, %arg9, %15, %c1] : memref + linalg.matmul(%14, %16, %17) : memref, memref, memref } } } } } } - return %E : !linalg.view + return %E : memref } // CHECK-LABEL: func @f1 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) diff --git a/mlir/test/Dialect/Linalg/fusion.mlir b/mlir/test/Dialect/Linalg/fusion.mlir index 65e8c72..d75ae64 100644 --- a/mlir/test/Dialect/Linalg/fusion.mlir +++ b/mlir/test/Dialect/Linalg/fusion.mlir @@ -4,30 +4,33 @@ #map1 = (d0) -> (d0 + 4) #map2 = (d0) -> (d0 + 3) -func @f1(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { +// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) +#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) + +func @f1(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - %0 = linalg.dim %A, 0 : !linalg.view - %1 = linalg.dim %A, 1 : !linalg.view - %2 = linalg.dim %B, 1 : !linalg.view - linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view + %0 = dim %A, 0 : memref + %1 = dim %A, 1 : memref + %2 = dim %B, 1 : memref + linalg.matmul(%A, %B, %C) : memref, memref, memref %c1 = constant 1 : index loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view + %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : memref %6 = affine.apply #map2(%arg6) - %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view - %8 = linalg.subview %C[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view - linalg.matmul(%5, %7, %8) : !linalg.view, !linalg.view, !linalg.view + %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : memref + %8 = linalg.subview %C[%arg5, %3, %c1, %arg6, %6, %c1] : memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } - return %E : !linalg.view + return %E : memref } // CHECK-LABEL: func @f1 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, 
%[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) @@ -38,109 +41,109 @@ func @f1(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // CHECK: loop.for // CHECK: linalg.matmul -func @f2(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { +func @f2(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view - %0 = linalg.dim %C, 0 : !linalg.view - %1 = linalg.dim %C, 1 : !linalg.view - %2 = linalg.dim %D, 1 : !linalg.view + linalg.matmul(%A, %B, %C) : memref, memref, memref + %0 = dim %C, 0 : memref + %1 = dim %C, 1 : memref + %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view + %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : memref %6 = affine.apply #map2(%arg6) - %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view - %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view - linalg.matmul(%5, %7, %8) : !linalg.view, !linalg.view, !linalg.view + %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : memref + %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } - return %E : !linalg.view + return %E : memref } // CHECK-LABEL: func @f2 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK-DAG: %[[C_0:.*]] = linalg.dim %[[C]], 0 : !linalg.view -// CHECK-DAG: %[[C_1:.*]] = linalg.dim %[[C]], 1 : !linalg.view -// CHECK-DAG: %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view +// CHECK-DAG: %[[C_0:.*]] = dim %[[C]], 0 : memref +// CHECK-DAG: %[[C_1:.*]] = dim %[[C]], 1 : memref +// CHECK-DAG: %[[D_1:.*]] = dim %[[D]], 1 : memref // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // CHECK: linalg.matmul // CHECK: linalg.matmul -func @f3(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { +func @f3(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view - %0 = linalg.dim %D, 0 : !linalg.view - %1 = linalg.dim %D, 1 : !linalg.view - %2 = linalg.dim %C, 1 : !linalg.view + linalg.matmul(%A, %B, %C) : memref, memref, memref + %0 = dim %D, 0 : memref + %1 = dim %D, 1 : memref + %2 = dim %C, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = linalg.subview %D[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view + %5 = linalg.subview %D[%arg5, %3, %c1, %arg7, %4, %c1] : memref %6 = affine.apply #map2(%arg6) - %7 = linalg.subview %C[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view - %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view - linalg.matmul(%5, %7, %8) : !linalg.view, !linalg.view, !linalg.view + %7 = 
linalg.subview %C[%arg7, %4, %c1, %arg6, %6, %c1] : memref + %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } - return %E : !linalg.view + return %E : memref } // CHECK-LABEL: func @f3 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK: %[[D_0:.*]] = linalg.dim %[[D]], 0 : !linalg.view -// CHECK: %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view -// CHECK: %[[C_1:.*]] = linalg.dim %[[C]], 1 : !linalg.view +// CHECK: %[[D_0:.*]] = dim %[[D]], 0 : memref +// CHECK: %[[D_1:.*]] = dim %[[D]], 1 : memref +// CHECK: %[[C_1:.*]] = dim %[[C]], 1 : memref // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: linalg.matmul // CHECK: linalg.matmul -func @f4(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { +func @f4(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view - linalg.matmul(%A, %B, %D) : !linalg.view, !linalg.view, !linalg.view - %0 = linalg.dim %C, 0 : !linalg.view - %1 = linalg.dim %C, 1 : !linalg.view - %2 = linalg.dim %D, 1 : !linalg.view + linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul(%A, %B, %D) : memref, memref, memref + %0 = dim %C, 0 : memref + %1 = dim %C, 1 : memref + %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view + %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : memref %6 = affine.apply #map2(%arg6) - %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view - %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view - linalg.matmul(%5, %7, %8) : !linalg.view, !linalg.view, !linalg.view + %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : memref + %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } - return %E : !linalg.view + return %E : memref } // CHECK-LABEL: func @f4 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK: %[[C_0:.*]] = linalg.dim %[[C]], 0 : !linalg.view -// CHECK: %[[C_1:.*]] = linalg.dim %[[C]], 1 : !linalg.view -// CHECK: %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view +// CHECK: %[[C_0:.*]] = dim %[[C]], 0 : memref +// CHECK: %[[C_1:.*]] = dim %[[C]], 1 : memref +// CHECK: %[[D_1:.*]] = dim %[[D]], 1 : memref // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_1]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { @@ -149,37 +152,37 @@ func @f4(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // CHECK: linalg.matmul // CHECK: linalg.matmul -func @f5(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { +func @f5(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = 
constant 3 : index %c2 = constant 2 : index - %0 = linalg.dim %B, 1 : !linalg.view - %1 = linalg.dim %D, 0 : !linalg.view - %2 = linalg.dim %D, 1 : !linalg.view - linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view - linalg.matmul(%C, %B, %D) : !linalg.view, !linalg.view, !linalg.view + %0 = dim %B, 1 : memref + %1 = dim %D, 0 : memref + %2 = dim %D, 1 : memref + linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul(%C, %B, %D) : memref, memref, memref loop.for %arg5 = %c0 to %1 step %c2 { loop.for %arg6 = %c0 to %0 step %c3 { loop.for %arg7 = %c0 to %2 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = linalg.subview %D[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view + %5 = linalg.subview %D[%arg5, %3, %c1, %arg7, %4, %c1] : memref %6 = affine.apply #map2(%arg6) - %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view - %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view - linalg.matmul(%5, %7, %8) : !linalg.view, !linalg.view, !linalg.view + %7 = linalg.subview %B[%arg7, %4, %c1, %arg6, %6, %c1] : memref + %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } - return %E : !linalg.view + return %E : memref } // CHECK-LABEL: func @f5 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK-DAG: %[[B_1:.*]] = linalg.dim %[[B]], 1 : !linalg.view -// CHECK-DAG: %[[D_0:.*]] = linalg.dim %[[D]], 0 : !linalg.view -// CHECK-DAG: %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view +// CHECK-DAG: %[[B_1:.*]] = dim %[[B]], 1 : memref +// CHECK-DAG: %[[D_0:.*]] = dim %[[D]], 0 : memref +// CHECK-DAG: %[[D_1:.*]] = dim %[[D]], 1 : memref // Don't fuse C due to false dependence, note that this is too conservative though. 
// CHECK: linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) // CHECK: loop.for %{{.*}} = %{{.*}} to %[[D_0]] step %{{.*}} { @@ -188,31 +191,31 @@ func @f5(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // CHECK: linalg.matmul // CHECK: linalg.matmul -func @f6(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { +func @f6(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - %0 = linalg.dim %C, 1 : !linalg.view - linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view - linalg.matmul(%A, %C, %E) : !linalg.view, !linalg.view, !linalg.view - %1 = linalg.dim %C, 0 : !linalg.view - %2 = linalg.dim %D, 1 : !linalg.view + %0 = dim %C, 1 : memref + linalg.matmul(%A, %B, %C) : memref, memref, memref + linalg.matmul(%A, %C, %E) : memref, memref, memref + %1 = dim %C, 0 : memref + %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %1 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %0 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view + %5 = linalg.subview %C[%arg5, %3, %c1, %arg7, %4, %c1] : memref %6 = affine.apply #map2(%arg6) - %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view - %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view - linalg.matmul(%5, %7, %8) : !linalg.view, !linalg.view, !linalg.view + %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : memref + %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } - return %E : !linalg.view + return %E : memref } // CHECK-LABEL: func @f6 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) @@ -226,29 +229,29 @@ func @f6(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // CHECK: linalg.matmul // CHECK-NOT: linalg.matmul -func @f7(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { +func @f7(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - %0 = linalg.dim %A, 0 : !linalg.view - %1 = linalg.dim %A, 1 : !linalg.view - %2 = linalg.dim %C, 1 : !linalg.view - %3 = linalg.dim %C, 0 : !linalg.view - %4 = linalg.dim %D, 1 : !linalg.view - linalg.matmul(%A, %C, %E) : !linalg.view, !linalg.view, !linalg.view - linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view + %0 = dim %A, 0 : memref + %1 = dim %A, 1 : memref + %2 = dim %C, 1 : memref + %3 = dim %C, 0 : memref + %4 = dim %D, 1 : memref + linalg.matmul(%A, %C, %E) : memref, memref, memref + linalg.matmul(%A, %B, %C) : memref, memref, memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %5 = affine.apply #map0(%arg5) %6 = affine.apply #map1(%arg7) - %7 = linalg.subview %A[%arg5, %5, %c1, %arg7, %6, %c1] : !linalg.view + %7 = linalg.subview %A[%arg5, %5, %c1, %arg7, %6, %c1] : memref %8 = affine.apply #map2(%arg6) - %9 = linalg.subview %C[%arg7, %6, %c1, %arg6, %8, %c1] : !linalg.view - %10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : !linalg.view - linalg.matmul(%7, %9, %10) : !linalg.view, !linalg.view, !linalg.view + %9 = 
linalg.subview %C[%arg7, %6, %c1, %arg6, %8, %c1] : memref + %10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : memref + linalg.matmul(%7, %9, %10) : memref, memref, memref } } } @@ -257,23 +260,23 @@ func @f7(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< loop.for %arg7 = %c0 to %2 step %c4 { %5 = affine.apply #map0(%arg5) %6 = affine.apply #map1(%arg7) - %7 = linalg.subview %C[%arg5, %5, %c1, %arg7, %6, %c1] : !linalg.view + %7 = linalg.subview %C[%arg5, %5, %c1, %arg7, %6, %c1] : memref %8 = affine.apply #map2(%arg6) - %9 = linalg.subview %D[%arg7, %6, %c1, %arg6, %8, %c1] : !linalg.view - %10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : !linalg.view - linalg.matmul(%7, %9, %10) : !linalg.view, !linalg.view, !linalg.view + %9 = linalg.subview %D[%arg7, %6, %c1, %arg6, %8, %c1] : memref + %10 = linalg.subview %E[%arg5, %5, %c1, %arg6, %8, %c1] : memref + linalg.matmul(%7, %9, %10) : memref, memref, memref } } } - return %E : !linalg.view + return %E : memref } // CHECK-LABEL: func @f7 // CHECK: (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}}) -// CHECK: %[[A_0:.*]] = linalg.dim %[[A]], 0 : !linalg.view -// CHECK: %[[A_1:.*]] = linalg.dim %[[A]], 1 : !linalg.view -// CHECK: %[[C_1:.*]] = linalg.dim %[[C]], 1 : !linalg.view -// CHECK: %[[C_0:.*]] = linalg.dim %[[C]], 0 : !linalg.view -// CHECK: %[[D_1:.*]] = linalg.dim %[[D]], 1 : !linalg.view +// CHECK: %[[A_0:.*]] = dim %[[A]], 0 : memref +// CHECK: %[[A_1:.*]] = dim %[[A]], 1 : memref +// CHECK: %[[C_1:.*]] = dim %[[C]], 1 : memref +// CHECK: %[[C_0:.*]] = dim %[[C]], 0 : memref +// CHECK: %[[D_1:.*]] = dim %[[D]], 1 : memref // CHECK: linalg.matmul(%[[A]], %[[C]], %[[E]]) // CHECK: loop.for %{{.*}} = %{{.*}} to %[[A_0]] step %{{.*}} { // CHECK: loop.for %{{.*}} = %{{.*}} to %[[C_1]] step %{{.*}} { @@ -286,31 +289,31 @@ func @f7(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view< // CHECK: linalg.matmul // CHECK-NOT: linalg.matmul -func @f8(%A: !linalg.view, %B: !linalg.view, %C: !linalg.view, %D: !linalg.view, %E: !linalg.view) -> !linalg.view { +func @f8(%A: memref, %B: memref, %C: memref, %D: memref, %E: memref) -> memref { %c1 = constant 1 : index %c0 = constant 0 : index %c4 = constant 4 : index %c3 = constant 3 : index %c2 = constant 2 : index - %0 = linalg.dim %A, 0 : !linalg.view - %1 = linalg.dim %A, 1 : !linalg.view - linalg.matmul(%A, %C, %D) : !linalg.view, !linalg.view, !linalg.view - linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view - %2 = linalg.dim %D, 1 : !linalg.view + %0 = dim %A, 0 : memref + %1 = dim %A, 1 : memref + linalg.matmul(%A, %C, %D) : memref, memref, memref + linalg.matmul(%A, %B, %C) : memref, memref, memref + %2 = dim %D, 1 : memref loop.for %arg5 = %c0 to %0 step %c2 { loop.for %arg6 = %c0 to %2 step %c3 { loop.for %arg7 = %c0 to %1 step %c4 { %3 = affine.apply #map0(%arg5) %4 = affine.apply #map1(%arg7) - %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : !linalg.view + %5 = linalg.subview %A[%arg5, %3, %c1, %arg7, %4, %c1] : memref %6 = affine.apply #map2(%arg6) - %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : !linalg.view - %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : !linalg.view - linalg.matmul(%5, %7, %8) : !linalg.view, !linalg.view, !linalg.view + %7 = linalg.subview %D[%arg7, %4, %c1, %arg6, %6, %c1] : memref + %8 = linalg.subview %E[%arg5, %3, %c1, %arg6, %6, %c1] : memref + linalg.matmul(%5, %7, %8) : memref, memref, memref } } } - return %E : !linalg.view + return %E : memref } 
// CHECK-LABEL: func @f8
// CHECK:   (%[[A:.*]]:{{.*}}, %[[B:.*]]:{{.*}}, %[[C:.*]]:{{.*}}, %[[D:.*]]:{{.*}}, %[[E:.*]]:{{.*}})
@@ -328,7 +331,7 @@ func @f8(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<
   n_loop_types = [2, 0, 0],
   n_views = [2, 1]
 }
-func @pointwise(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linalg.view<?x?xf32>, %D: !linalg.view<?x?xf32>) {
+func @pointwise(%A: memref<?x?xf32, #strided2D>, %B: memref<?x?xf32, #strided2D>, %C: memref<?x?xf32, #strided2D>, %D: memref<?x?xf32, #strided2D>) {
   %c1 = constant 1 : index
   %c0 = constant 0 : index
   %c3 = constant 3 : index
   %c2 = constant 2 : index
@@ -337,21 +340,21 @@ func @pointwise(%A: !linalg.view<?x?xf32>, %B: !linalg.view<?x?xf32>, %C: !linal
     ^bb0(%E: f32, %arg5: f32, %arg6: f32):   // no predecessors
       %2 = addf %E, %arg5 : f32
       linalg.yield %2 : f32
-  }: !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  %0 = linalg.dim %B, 0 : !linalg.view<?x?xf32>
-  %1 = linalg.dim %B, 1 : !linalg.view<?x?xf32>
+  }: memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
+  %0 = dim %B, 0 : memref<?x?xf32, #strided2D>
+  %1 = dim %B, 1 : memref<?x?xf32, #strided2D>
   loop.for %E = %c0 to %0 step %c2 {
     loop.for %arg5 = %c0 to %1 step %c3 {
       %2 = affine.apply #map0(%E)
       %3 = affine.apply #map1(%arg5)
-      %4 = linalg.subview %B[%E, %2, %c1, %arg5, %3, %c1] : !linalg.view<?x?xf32>
-      %5 = linalg.subview %C[%E, %2, %c1, %arg5, %3, %c1] : !linalg.view<?x?xf32>
-      %6 = linalg.subview %D[%E, %2, %c1, %arg5, %3, %c1] : !linalg.view<?x?xf32>
+      %4 = linalg.subview %B[%E, %2, %c1, %arg5, %3, %c1] : memref<?x?xf32, #strided2D>
+      %5 = linalg.subview %C[%E, %2, %c1, %arg5, %3, %c1] : memref<?x?xf32, #strided2D>
+      %6 = linalg.subview %D[%E, %2, %c1, %arg5, %3, %c1] : memref<?x?xf32, #strided2D>
       linalg.generic #pointwise_2d_trait %4, %5, %6 {
       ^bb0(%arg6: f32, %arg7: f32, %arg8: f32):   // no predecessors
         %7 = mulf %arg6, %arg7 : f32
         linalg.yield %7 : f32
-      }: !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      }: memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
     }
   }
   return
diff --git a/mlir/test/Dialect/Linalg/invalid.mlir b/mlir/test/Dialect/Linalg/invalid.mlir
index 9ea9e51..c8a3cb4 100644
--- a/mlir/test/Dialect/Linalg/invalid.mlir
+++ b/mlir/test/Dialect/Linalg/invalid.mlir
@@ -44,63 +44,63 @@ func @buffer_valid_element_type() {
 
 // -----
 
-func @load_number_of_indices(%v : !linalg.view<f32>) {
-  // expected-error @+2 {{expected 0 indices, got 1}}
+func @load_number_of_indices(%v : memref<f32>) {
+  // expected-error @+2 {{incorrect number of indices for load}}
   %c0 = constant 0 : index
-  linalg.load %v[%c0] : !linalg.view<f32>
+  load %v[%c0] : memref<f32>
 }
 
 // -----
 
-func @slice_number_of_indexings(%arg0: !linalg.view<?x?xf32>) {
+func @slice_number_of_indexings(%arg0: memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>) {
   // expected-error @+2 {{expected 2 indexings, got 1}}
   %c0 = constant 0: index
-  %0 = linalg.slice %arg0[%c0] : !linalg.view<?x?xf32>, index, !linalg.view<?x?xf32>
+  %0 = linalg.slice %arg0[%c0] : memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>, index, memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>
 }
 
 // -----
 
-func @slice_rank_vs_range_indices(%arg0: !linalg.view<?x?xf32>) {
+func @slice_rank_vs_range_indices(%arg0: memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>) {
   // expected-error @+2 {{op expected rank of the view(1) to be the number of ranges(0)}}
   %c0 = constant 0: index
-  %0 = linalg.slice %arg0[%c0, %c0] : !linalg.view<?x?xf32>, index, index, !linalg.view<?xf32>
+  %0 = linalg.slice %arg0[%c0, %c0] : memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>, index, index, memref<?xf32, (i)[off] -> (off + i)>
 }
 
 // -----
 
-func @store_number_of_indices(%v : !linalg.view<f32>) {
-  // expected-error @+3 {{expected 0 indices, got 1}}
+func @store_number_of_indices(%v : memref<f32>) {
+  // expected-error @+3 {{store index operand count not equal to memref rank}}
   %c0 = constant 0 : index
   %f0 = constant 0.0 : f32
-  linalg.store %f0, %v[%c0] : !linalg.view<f32>
+  store %f0, %v[%c0] : memref<f32>
 }
 
 // -----
 
-func @subview_number_of_indices(%v : !linalg.view<?x?xf32>) {
-  // expected-error @+2 {{expected a view followed by 6 indices specifying a range for each dimension}}
+func @subview_number_of_indices(%v : memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>) {
+  // expected-error @+2 {{expected a strided memref followed by 6 indices specifying a range for each dimension}}
   %c0 = constant 0 : index
-  linalg.subview %v[%c0, %c0] : !linalg.view<?x?xf32>
+  linalg.subview %v[%c0, %c0] : memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>
 }
 
 // -----
 
-func @transpose_not_permutation(%v : !linalg.view<?x?xf32>) {
+func @transpose_not_permutation(%v : memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>) {
   // expected-error @+1 {{expected a permutation map}}
-  linalg.transpose %v (i, j) -> (i, i) : !linalg.view<?x?xf32>
+  linalg.transpose %v (i, j) -> (i, i) : memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>
 }
 
 // -----
 
-func @transpose_bad_rank(%v : !linalg.view<?x?xf32>) {
+func @transpose_bad_rank(%v : memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>) {
   // expected-error @+1 {{expected a permutation map of same rank as the view}}
-  linalg.transpose %v (i) -> (i) : !linalg.view<?x?xf32>
+  linalg.transpose %v (i) -> (i) : memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>
 }
 
 // -----
 
 func @view_type(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
-  // expected-error @+2 {{expected view type}}
+  // expected-error @+2 {{expected memref type}}
   %r = linalg.range %min:%max:%step : !linalg.range
   %0 = linalg.view %buf[%r]: !linalg.buffer<?xf32> -> index
 }
@@ -110,134 +110,134 @@ func @view_num_ranges(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: in
 func @view_num_ranges(%buf: !linalg.buffer<?xf32>, %min: index, %max: index, %step: index) {
   // expected-error @+2 {{expected 2 ranges}}
   %r = linalg.range %min:%max:%step : !linalg.range
-  %0 = linalg.view %buf[%r]: !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+  %0 = linalg.view %buf[%r]: !linalg.buffer<?xf32> -> memref<?x?xf32, (i, j)[M, off] -> (off + M * i + j)>
 }
 
 // -----
 
-func @yield_parent(%arg0: !linalg.view<?xf32>) {
+func @yield_parent(%arg0: memref<?xf32, (i)[off] -> (off + i)>) {
   // expected-error @+1 {{op expected 'linalg.generic' parent op}}
-  linalg.yield %arg0: !linalg.view<?xf32>
+  linalg.yield %arg0: memref<?xf32, (i)[off] -> (off + i)>
 }
 
 // -----
 
-func @generic_at_least_2_operands(%arg0: !linalg.view<f32>) {
+func @generic_at_least_2_operands(%arg0: memref<f32>) {
   // expected-error @+1 {{op expected 2 or more operands}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (0) ],
    n_views = [1, 1],
    n_loop_types = [0, 0, 0]
-  } %arg0: !linalg.view<f32>
+  } %arg0: memref<f32>
 }
 
 // -----
 
-func @generic_exactly_2_views(%arg0: !linalg.view<f32>) {
+func @generic_exactly_2_views(%arg0: memref<f32>) {
   // expected-error @+1 {{op expected exactly 2 view operands}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (0) ],
    n_views = [1, 1],
    n_loop_types = [0, 0, 0]
-  } %arg0, %arg0, %arg0: !linalg.view<f32>, !linalg.view<f32>, !linalg.view<f32>
+  } %arg0, %arg0, %arg0: memref<f32>, memref<f32>, memref<f32>
 }
 
 // -----
 
-func @generic_undefined_fun(%arg0: !linalg.view<f32>) {
+func @generic_undefined_fun(%arg0: memref<f32>) {
   // expected-error @+1 {{op expected fun attribute to refer to a defined symbol}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (0) ],
    n_views = [1, 1],
    n_loop_types = [0, 0, 0]
-  } %arg0, %arg0: !linalg.view<f32>, !linalg.view<f32>
+  } %arg0, %arg0: memref<f32>, memref<f32>
 }
 
 // -----
 
 func @foo() { return }
 
-func @generic_mismatched_num_arguments(%arg0: !linalg.view<f32>) {
+func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
   // expected-error @+1 {{op expected fun arguments to match number of views}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (0) ],
    n_views = [0, 1],
    n_loop_types = [0, 0, 0]
-  } %arg0: !linalg.view<f32>
+  } %arg0: memref<f32>
 }
 
 // -----
 
 func @foo(%0: i32) { return }
 
-func @generic_mismatched_num_returns(%arg0: !linalg.view<f32>) {
+func @generic_mismatched_num_returns(%arg0: memref<f32>) {
   // expected-error @+1 {{op expected fun results to match number of output views}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (0) ],
    n_views = [0, 1],
    n_loop_types = [0, 0, 0]
-  } %arg0: !linalg.view<f32>
+  } %arg0: memref<f32>
 }
 
 // -----
 
 func @foo(%0: i32) -> i32 { return %0: i32 }
 
-func @generic_symbol_in_map(%arg0: !linalg.view<i32>) {
+func @generic_symbol_in_map(%arg0: memref<i32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ ()[N] -> (0) ],
    n_views = [0, 1],
    n_loop_types = [1, 0, 0]
-  } %arg0: !linalg.view<i32>
+  } %arg0: memref<i32>
 }
 
 // -----
 
 func @foo(%0: i32) -> i32 { return %0: i32 }
 
-func @generic_wrong_dim_in_map(%arg0: !linalg.view<i32>) {
+func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (0) ],
    n_views = [0, 1],
    n_loop_types = [1, 0, 0]
-  } %arg0: !linalg.view<i32>
+  } %arg0: memref<i32>
 }
 
 // -----
 
 func @foo(%0: i32) -> i32 { return %0: i32 }
 
-func @generic_zero_d_view(%arg0: !linalg.view<i32>) {
-  // expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: '!linalg.view<i32>'}}
+func @generic_zero_d_view(%arg0: memref<i32>) {
+  // expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (1) ],
    n_views = [0, 1],
    n_loop_types = [0, 0, 0]
-  } %arg0: !linalg.view<i32>
+  } %arg0: memref<i32>
 }
 
 // -----
 
 func @foo(%0: f32) -> f32 { return %0: f32 }
 
-func @generic_one_d_view(%arg0: !linalg.view<?xf32>) {
-  // expected-error @+1 {{op expected indexing_map #0 results to match view rank: '!linalg.view<?xf32>'}}
+func @generic_one_d_view(%arg0: memref<?xf32, (i)[off] -> (off + i)>) {
+  // expected-error @+1 {{op expected indexing_map #0 results to match view rank: 'memref<?xf32, (d0)[s0] -> (d0 + s0)>'}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (0, 0) ],
    n_views = [0, 1],
    n_loop_types = [0, 0, 0]
-  } %arg0: !linalg.view<?xf32>
+  } %arg0: memref<?xf32, (i)[off] -> (off + i)>
 }
 
 // -----
 
@@ -247,14 +247,14 @@ func @foo(%0: i32) -> f32 {
   return %1: f32
 }
 
-func @generic_fun_arg_0_element_type(%arg0: !linalg.view<?xf32>) {
+func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, (i)[off] -> (off + i)>) {
   // expected-error @+1 {{op expected fun argument 0 to match view element type: 'f32'}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (0) ],
    n_views = [0, 1],
    n_loop_types = [0, 0, 0]
-  } %arg0: !linalg.view<?xf32>
+  } %arg0: memref<?xf32, (i)[off] -> (off + i)>
 }
 
 // -----
 
@@ -264,21 +264,21 @@ func @foo(%0: f32) -> i4 {
   return %1: i4
 }
 
-func @generic_fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
+func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off] -> (off + i)>) {
   // expected-error @+1 {{op expected fun result 0 to match output view element type: 'f32'}}
   linalg.generic {
     fun = @foo,
    indexing_maps =  [ () -> (0) ],
    n_views = [0, 1],
    n_loop_types = [0, 0, 0]
-  } %arg0: !linalg.view<?xf32>
+  } %arg0: memref<?xf32, (i)[off] -> (off + i)>
 }
 
 // -----
 
 func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 }
 
-func @generic_singular_maps(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>) {
+func @generic_singular_maps(%arg0: memref<?xf32, (i)[off] -> (off + i)>, %arg1: memref<?xf32, (i)[off] -> (off + i)>) {
   // expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
   linalg.generic {
     fun = @foo,
@@ -288,7 +288,7 @@ func @generic_singular_maps(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<
-  } %arg0, %arg1: !linalg.view<?xf32>, !linalg.view<?xf32>
+  } %arg0, %arg1: memref<?xf32, (i)[off] -> (off + i)>, memref<?xf32, (i)[off] -> (off + i)>
 }
 
////////////////////////////////////////////////////////////////////////////////
@@ -297,7 +297,7 @@ func @generic_singular_maps(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<
-func @generic_empty_region(%arg0: !linalg.view<f32>) {
+func @generic_empty_region(%arg0: memref<f32>) {
   // expected-error @+1 {{op expected region with 1 block}}
   linalg.generic {
    indexing_maps =  [ () -> (0) ],
@@ -306,12 +306,12 @@ func @generic_empty_region(%arg0: !linalg.view<f32>) {
   } %arg0, %arg0 {
     ^bb1:
     ^bb2:
-  }: !linalg.view<f32>, !linalg.view<f32>
+  }: memref<f32>, memref<f32>
 }
 
 // -----
 
-func @generic_mismatched_num_arguments(%arg0: !linalg.view<f32>) {
+func @generic_mismatched_num_arguments(%arg0: memref<f32>) {
   // expected-error @+1 {{op expected number of block arguments to match number of views}}
   linalg.generic {
    indexing_maps =  [ () -> (0) ],
@@ -319,25 +319,25 @@ func @generic_empty_region(%arg0: !linalg.view<f32>) {
    n_loop_types = [0, 0, 0]
   } %arg0 {
     ^bb:
-  }: !linalg.view<f32>
+  }: memref<f32>
 }
 
 // -----
 
-func @generic_block_arg_type(%arg0: !linalg.view<f32>) {
-  // expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: '!linalg.view<f32>'}}
+func @generic_block_arg_type(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected block argument 0 of the same type as elemental type of output view: 'memref<f32>'}}
   linalg.generic {
    indexing_maps =  [ () -> (0) ],
    n_views = [0, 1],
    n_loop_types = [0, 0, 0]
   } %arg0 {
     ^bb(%i: i1):
-  }: !linalg.view<f32>
+  }: memref<f32>
 }
 
 // -----
 
-func @generic_fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
+func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off] -> (off + i)>) {
   // expected-error @+8 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
   linalg.generic {
    indexing_maps =  [ (i) -> (i) ],
@@ -347,5 +347,5 @@ func @generic_fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
     ^bb(%i: f32):
      %0 = constant 0: i1
      linalg.yield %0: i1
-  }: !linalg.view<?xf32>
+  }: memref<?xf32, (i)[off] -> (off + i)>
 }
diff --git a/mlir/test/Dialect/Linalg/llvm.mlir b/mlir/test/Dialect/Linalg/llvm.mlir
index 3d21051..3cedbe0 100644
--- a/mlir/test/Dialect/Linalg/llvm.mlir
+++ b/mlir/test/Dialect/Linalg/llvm.mlir
@@ -1,4 +1,9 @@
-// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | FileCheck %s
+// RUN: mlir-opt %s -linalg-convert-to-llvm
+// RUN: mlir-opt %s -linalg-convert-to-llvm | FileCheck %s
+
+#strided1D = (d0)[s0] -> (d0 + s0)
+#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+#strided3D = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)
 
 func @buffer_size(%arg0: !linalg.buffer<?xf32>) {
   %c1 = constant 1 : index
@@ -44,7 +49,7 @@ func @range(%arg0: index) {
 // CHECK-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
 
 func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
-  %0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
+  %0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
   return
 }
 // CHECK-LABEL: func @view
@@ -64,7 +69,7 @@ func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
 // CHECK-NEXT:  llvm.return
 
 func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
-  %0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.buffer<?xf32> -> !linalg.view<?x?x?xf32>
+  %0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.buffer<?xf32> -> memref<?x?x?xf32, #strided3D>
   return
 }
 // CHECK-LABEL: func @view3d
@@ -78,22 +83,23 @@ func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.
 // CHECK-NEXT:  llvm.insertvalue %{{.*}}, %{{.*}}[3, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 
 func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
-  %0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-  %1 = linalg.slice %0[%arg1] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
+  %0 = linalg.view %arg0[%arg1] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
+  %1 = linalg.slice %0[%arg1] : memref<?xf32, #strided1D>, !linalg.range, memref<?xf32, #strided1D>
   return
 }
 // CHECK-LABEL: func @slice
 //   insert ptr for view op
 // CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 //   insert data ptr for slice op
-// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-// CHECK-NEXT: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// CHECK: llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 // CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 // CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
 // CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
 // CHECK-NEXT: llvm.add %{{.*}}, %{{.*}} : !llvm.i64
 //   insert offset
+// CHECK: llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[1] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
+// CHECK-NEXT: llvm.mlir.constant(0 : index)
 // CHECK-NEXT: llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
 // CHECK-NEXT: llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
 // CHECK-NEXT: llvm.extractvalue %{{.*}}[2] : !llvm<"{ i64, i64, i64 }">
@@ -112,24 +118,24 @@ func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
 // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 // CHECK-NEXT: llvm.insertvalue %{{.*}}, %{{.*}}[3, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
 
-func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
-  linalg.dot(%arg0, %arg1, %arg2) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+func @dot(%arg0: memref<?xf32, #strided1D>, %arg1: memref<?xf32, #strided1D>, %arg2: memref<f32>) {
+  linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>, memref<f32>
   return
 }
-// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">, %{{.*}}: !llvm<"{ float*, i64, [0 x i64], [0 x i64] }">) {
+// CHECK-LABEL: func @dot(%{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, %{{.*}}: !llvm<"{ float*, i64 }*">) {
 // CHECK-COUNT-3: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
-// CHECK-NEXT: llvm.call @linalg_dot_viewxf32_viewxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [0 x i64], [0 x i64] }*">) -> ()
+// CHECK-NEXT: llvm.call @linalg_dot_viewsxf32_viewsxf32_viewf32(%{{.*}}, %{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64, [1 x i64], [1 x i64] }*">, !llvm<"{ float*, i64 }*">) -> ()
 
-func @dim(%arg0: !linalg.view<?x?xf32>) {
-  %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
+func @dim(%arg0: memref<?x?xf32, #strided2D>) {
+  %0 = dim %arg0, 1 : memref<?x?xf32, #strided2D>
   return
 }
-// CHECK-LABEL: func @dim(%{{.*}}: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">) {
+// CHECK-LABEL: func @dim(%{{.*}}: !llvm<"{ float*, i64, [2 x i64], [2 x i64] }*">) {
 // CHECK: llvm.extractvalue %{{.*}}[2, 1] : !llvm<"{ float*, i64, [2 x i64], [2 x i64] }">
 
-func @subview(%arg0: !linalg.view<?x?xf32>) {
+func @subview(%arg0: memref<?x?xf32, #strided2D>) {
   %c0 = constant 0 : index
-  %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view<?x?xf32>
+  %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : memref<?x?xf32, #strided2D>
   return
 }
 // CHECK-LABEL: func @subview
@@ -156,37 +162,37 @@ func @subview(%arg0: !linalg.view<?x?xf32>) {
 // CHECK-NEXT: llvm.select %{{.*}}, %{{.*}}, %{{.*}} : !llvm.i1, !llvm.i64
 // CHECK-NEXT: llvm.mul %{{.*}}, %{{.*}} : !llvm.i64
 
-func @view_with_range_and_index(%arg0: !linalg.view<?x?xf64>) {
+func @view_with_range_and_index(%arg0: memref<?x?xf64, #strided2D>) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %R = linalg.range %c0:%c1:%c1 : !linalg.range
   loop.for %i0 = %c0 to %c1 step %c1 {
-    %1 = linalg.slice %arg0[%i0, %R] : !linalg.view<?x?xf64>, index, !linalg.range, !linalg.view<?xf64>
+    %1 = linalg.slice %arg0[%i0, %R] : memref<?x?xf64, #strided2D>, index, !linalg.range, memref<?xf64, #strided1D>
   }
   return
 }
 // CHECK-LABEL: func @view_with_range_and_index
 //   loop-body.
 // CHECK:   llvm.mlir.undef : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
-// CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
 // CHECK:   llvm.extractvalue %{{.*}}[3, 0] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
 // CHECK:   llvm.extractvalue %{{.*}}[3, 1] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
 // CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm<"{ double*, i64, [2 x i64], [2 x i64] }">
+// CHECK:   llvm.insertvalue %{{.*}}, %{{.*}}[0] : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
 // CHECK:   llvm.insertvalue %{{.*}}[1] : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
 // CHECK:   llvm.extractvalue %{{.*}}[0] : !llvm<"{ i64, i64, i64 }">
 // CHECK:   llvm.extractvalue %{{.*}}[1] : !llvm<"{ i64, i64, i64 }">
 // CHECK:   llvm.insertvalue %{{.*}}[2, 0] : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
 // CHECK:   llvm.insertvalue %{{.*}}[3, 0] : !llvm<"{ double*, i64, [1 x i64], [1 x i64] }">
 
-func @copy(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>) {
-  linalg.copy(%arg0, %arg1) : !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+func @copy(%arg0: memref<?x?x?xf32, #strided3D>, %arg1: memref<?x?x?xf32, #strided3D>) {
+  linalg.copy(%arg0, %arg1) : memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>
   return
 }
 // CHECK-LABEL: func @copy
-// CHECK: llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()
+// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()
 
-func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
-  %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : !linalg.view<?x?x?xf32>
+func @transpose(%arg0: memref<?x?x?xf32, #strided3D>) {
+  %0 = linalg.transpose %arg0 (i, j, k) -> (k, i, j) : memref<?x?x?xf32, #strided3D>
   return
 }
 // CHECK-LABEL: func @transpose
@@ -200,10 +206,10 @@ func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
 // CHECK:   llvm.extractvalue {{.*}}[2, 2] :  !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 // CHECK:   llvm.insertvalue {{.*}}[2, 1] :  !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
 
-func @copy_transpose(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>) {
+func @copy_transpose(%arg0: memref<?x?x?xf32, #strided3D>, %arg1: memref<?x?x?xf32, #strided3D>) {
   linalg.copy(%arg0, %arg1) {inputPermutation = (i, j, k) -> (i, k, j),
                              outputPermutation = (i, j, k) -> (k, j, i)}
-    : !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+    : memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>
   return
 }
 // CHECK-LABEL: func @copy
@@ -227,4 +233,4 @@ func @copy_transpose(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<
 //   Call external copy after promoting input and output structs to pointers
 // CHECK-COUNT-2: llvm.mlir.constant(1 : index){{.*[[:space:]].*}}llvm.alloca{{.*[[:space:]].*}}llvm.store
-// CHECK: llvm.call @linalg_copy_viewxxxf32_viewxxxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()
+// CHECK: llvm.call @linalg_copy_viewsxsxsxf32_viewsxsxsxf32(%{{.*}}, %{{.*}}) : (!llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">, !llvm<"{ float*, i64, [3 x i64], [3 x i64] }*">) -> ()
diff --git a/mlir/test/Dialect/Linalg/loops.mlir b/mlir/test/Dialect/Linalg/loops.mlir
index 83311ee..518189b 100644
--- a/mlir/test/Dialect/Linalg/loops.mlir
+++ b/mlir/test/Dialect/Linalg/loops.mlir
@@ -1,8 +1,18 @@
 // RUN: mlir-opt %s -linalg-lower-to-loops | FileCheck %s
 
-// CHECK-DAG: #[[S2D1:.*]] = (d0, d1) -> (d0 * 2 + d1)
-// CHECK-DAG: #[[S2D3:.*]] = (d0, d1) -> (d0 * 2 + d1 * 4)
-// CHECK-DAG: #[[S3D2:.*]] = (d0, d1) -> (d0 * 3 + d1 * 5)
+// CHECK-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
+#strided1D = (d0)[s0] -> (d0 + s0)
+// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+// CHECK-DAG: #[[strided3D:.*]] = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)
+#strided3D = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)
+// CHECK-DAG: #[[strided4D:.*]] = (d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)
+#strided4D = (d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)
+
+// CHECK-DAG: #[[Stride2Dilation1:.*]] = (d0, d1) -> (d0 * 2 + d1)
+// CHECK-DAG: #[[Stride2Dilation4:.*]] = (d0, d1) -> (d0 * 2 + d1 * 4)
+// CHECK-DAG: #[[Stride3Dilation5:.*]] = (d0, d1) -> (d0 * 3 + d1 * 5)
 
 func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
   %c0 = constant 0 : index
@@ -10,182 +20,189 @@ func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: in
   %I = linalg.range %c0:%arg1:%c1 : !linalg.range
   %J = linalg.range %c0:%arg2:%c1 : !linalg.range
   %K = linalg.range %c0:%arg3:%c1 : !linalg.range
-  %A = linalg.view %arg0[%I, %K] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  %B = linalg.view %arg0[%K, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  %C = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  %A = linalg.view %arg0[%I, %K] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  %B = linalg.view %arg0[%K, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  %C = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
   return
 }
 // CHECK-LABEL: func @matmul(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
-// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
-// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
-// CHECK: %[[N:.*]] = linalg.dim %[[B]], 1 : !linalg.view<?x?xf32>
+// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
+// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
+// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
+// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
+// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
+// CHECK: %[[N:.*]] = dim %[[B]], 1 : memref<?x?xf32, #[[strided2D]]>
 // CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
 // CHECK:   loop.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
 // CHECK:     loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
-// CHECK-DAG:   %[[a:.*]] = linalg.load %[[A]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
-// CHECK-DAG:   %[[b:.*]] = linalg.load %[[B]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
+// CHECK-DAG:   %[[a:.*]] = load %[[A]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
+// CHECK-DAG:   %[[b:.*]] = load %[[B]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 // CHECK-DAG:   %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
-// CHECK-DAG:   %[[c:.*]] = linalg.load %[[C]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
+// CHECK-DAG:   %[[c:.*]] = load %[[C]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 // CHECK-DAG:   %[[res:.*]] = addf %[[c]], %[[inc]] : f32
-// CHECK:   linalg.store %[[res]], %[[C]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
+// CHECK:   store %[[res]], %[[C]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 
 func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %I = linalg.range %c0:%arg1:%c1 : !linalg.range
   %J = linalg.range %c0:%arg2:%c1 : !linalg.range
-  %2 = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  %3 = linalg.view %arg0[%J] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-  %4 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-  linalg.matvec(%2, %3, %4) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
+  %2 = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  %3 = linalg.view %arg0[%J] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
+  %4 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
+  linalg.matvec(%2, %3, %4) : memref<?x?xf32, #strided2D>, memref<?xf32, #strided1D>, memref<?xf32, #strided1D>
   return
 }
 // CHECK-LABEL: func @matvec(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
-// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-// CHECK: %[[M:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?x?xf32>
-// CHECK: %[[K:.*]] = linalg.dim %[[A]], 1 : !linalg.view<?x?xf32>
+// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
+// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
+// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
+// CHECK: %[[M:.*]] = dim %[[A]], 0 : memref<?x?xf32, #[[strided2D]]>
+// CHECK: %[[K:.*]] = dim %[[A]], 1 : memref<?x?xf32, #[[strided2D]]>
 // CHECK: loop.for %{{.*}} = %{{.*}} to %[[M]] step %{{.*}} {
 // CHECK:   loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
-// CHECK-DAG:   %[[a:.*]] = linalg.load %[[A]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>
-// CHECK-DAG:   %[[b:.*]] = linalg.load %[[B]][%{{.*}}] : !linalg.view<?xf32>
+// CHECK-DAG:   %[[a:.*]] = load %[[A]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
+// CHECK-DAG:   %[[b:.*]] = load %[[B]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
 // CHECK-DAG:   %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
-// CHECK-DAG:   %[[c:.*]] = linalg.load %[[C]][%{{.*}}] : !linalg.view<?xf32>
+// CHECK-DAG:   %[[c:.*]] = load %[[C]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
 // CHECK-DAG:   %[[res:.*]] = addf %[[c]], %[[inc]] : f32
-// CHECK:   linalg.store %[[res]], %[[C]][%{{.*}}] : !linalg.view<?xf32>
+// CHECK:   store %[[res]], %[[C]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
 
 func @dot(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %I = linalg.range %c0:%arg1:%c1 : !linalg.range
-  %1 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-  %2 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-  %3 = linalg.view %arg0[] : !linalg.buffer<?xf32> -> !linalg.view<f32>
-  linalg.dot(%1, %2, %3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+  %1 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
+  %2 = linalg.view %arg0[%I] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
+  %3 = linalg.view %arg0[] : !linalg.buffer<?xf32> -> memref<f32>
+  linalg.dot(%1, %2, %3) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>, memref<f32>
  return
 }
 // CHECK-LABEL: func @dot(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
-// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.buffer<?xf32> -> !linalg.view<f32>
-// CHECK: %[[K:.*]] = linalg.dim %[[A]], 0 : !linalg.view<?xf32>
+// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
+// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.buffer<?xf32> -> memref<?xf32, #[[strided1D]]>
+// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.buffer<?xf32> -> memref<f32>
+// CHECK: %[[K:.*]] = dim %[[A]], 0 : memref<?xf32, #[[strided1D]]>
 // CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
-// CHECK-DAG:   %[[a:.*]] = linalg.load %[[A]][%{{.*}}] : !linalg.view<?xf32>
-// CHECK-DAG:   %[[b:.*]] = linalg.load %[[B]][%{{.*}}] : !linalg.view<?xf32>
+// CHECK-DAG:   %[[a:.*]] = load %[[A]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
+// CHECK-DAG:   %[[b:.*]] = load %[[B]][%{{.*}}] : memref<?xf32, #[[strided1D]]>
 // CHECK-DAG:   %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
-// CHECK-DAG:   %[[c:.*]] = linalg.load %[[C]][] : !linalg.view<f32>
+// CHECK-DAG:   %[[c:.*]] = load %[[C]][] : memref<f32>
 // CHECK-DAG:   %[[res:.*]] = addf %[[c]], %[[inc]] : f32
-// CHECK:   linalg.store %[[res]], %[[C]][] : !linalg.view<f32>
+// CHECK:   store %[[res]], %[[C]][] : memref<f32>
 
-func @dot_view(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
-  linalg.dot(%arg0, %arg1, %arg2) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+func @dot_view(%arg0: memref<?xf32, #strided1D>, %arg1: memref<?xf32, #strided1D>, %arg2: memref<f32>) {
+  linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>, memref<f32>
   return
 }
-// CHECK-LABEL: func @dot_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<f32>) {
-// CHECK: %[[K:.*]] = linalg.dim %arg0, 0 : !linalg.view<?xf32>
+// CHECK-LABEL: func @dot_view(
+// CHECK:  %{{.*}}: memref<?xf32, #[[strided1D]]>, %{{.*}}: memref<?xf32, #[[strided1D]]>, %{{.*}}: memref<f32>) {
+// CHECK: %[[K:.*]] = dim %arg0, 0 : memref<?xf32, #[[strided1D]]>
 // CHECK: loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
-// CHECK-DAG:   %[[a:.*]] = linalg.load %arg0[%{{.*}}] : !linalg.view<?xf32>
-// CHECK-DAG:   %[[b:.*]] = linalg.load %{{.*}}[%{{.*}}] : !linalg.view<?xf32>
+// CHECK-DAG:   %[[a:.*]] = load %arg0[%{{.*}}] : memref<?xf32, #[[strided1D]]>
+// CHECK-DAG:   %[[b:.*]] = load %{{.*}}[%{{.*}}] : memref<?xf32, #[[strided1D]]>
 // CHECK-DAG:   %[[inc:.*]] = mulf %[[a]], %[[b]] : f32
-// CHECK-DAG:   %[[c:.*]] = linalg.load %{{.*}}[] : !linalg.view<f32>
+// CHECK-DAG:   %[[c:.*]] = load %{{.*}}[] : memref<f32>
 // CHECK-DAG:   %[[res:.*]] = addf %[[c]], %[[inc]] : f32
-// CHECK:   linalg.store %[[res]], %{{.*}}[] : !linalg.view<f32>
+// CHECK:   store %[[res]], %{{.*}}[] : memref<f32>
 
-func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
-  linalg.fill(%arg0, %arg1) : !linalg.view<?xf32>, f32
+func @fill_view(%arg0: memref<?xf32, #strided1D>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : memref<?xf32, #strided1D>, f32
   return
 }
-// CHECK-LABEL: func @fill_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: f32) {
+// CHECK-LABEL: func @fill_view(
+// CHECK:  %{{.*}}: memref<?xf32, #[[strided1D]]>, %{{.*}}: f32) {
 // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
-// CHECK:   linalg.store %{{.*}}, %{{.*}}[%{{.*}}] : !linalg.view<?xf32>
+// CHECK:   store %{{.*}}, %{{.*}}[%{{.*}}] : memref<?xf32, #[[strided1D]]>
 
-func @fill_view0(%arg0: !linalg.view<f32>, %arg1: f32) {
-  linalg.fill(%arg0, %arg1) : !linalg.view<f32>, f32
+func @fill_view0(%arg0: memref<f32>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : memref<f32>, f32
   return
 }
-// CHECK-LABEL: func @fill_view0(%{{.*}}: !linalg.view<f32>, %{{.*}}: f32) {
-// CHECK:   linalg.store %{{.*}}, %{{.*}}[] : !linalg.view<f32>
+// CHECK-LABEL: func @fill_view0(%{{.*}}: memref<f32>, %{{.*}}: f32) {
+// CHECK:   store %{{.*}}, %{{.*}}[] : memref<f32>
 
-func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
-  linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
+func @fill_view3(%arg0: memref<?x?x?xf32, #strided3D>, %arg1: f32) {
  linalg.fill(%arg0, %arg1) : memref<?x?x?xf32, #strided3D>, f32
   return
 }
-// CHECK-LABEL: func @fill_view3(%{{.*}}: !linalg.view<?x?x?xf32>, %{{.*}}: f32) {
+// CHECK-LABEL: func @fill_view3(
+// CHECK:  %{{.*}}: memref<?x?x?xf32, #[[strided3D]]>, %{{.*}}: f32) {
 // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK:   loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK:     loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
-// CHECK:       linalg.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?xf32>
+// CHECK:       store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?xf32, #[[strided3D]]>
 
-func @copy_view(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>) {
-  linalg.copy(%arg0, %arg1) : !linalg.view<?xf32>, !linalg.view<?xf32>
+func @copy_view(%arg0: memref<?xf32, #strided1D>, %arg1: memref<?xf32, #strided1D>) {
+  linalg.copy(%arg0, %arg1) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>
   return
 }
-// CHECK-LABEL: func @copy_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>) {
+// CHECK-LABEL: func @copy_view(
+// CHECK:  %{{.*}}: memref<?xf32, #[[strided1D]]>, %{{.*}}: memref<?xf32, #[[strided1D]]>) {
 // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
-// CHECK:   %[[L:.*]] = linalg.load %{{.*}}[%{{.*}}] : !linalg.view<?xf32>
-// CHECK:   linalg.store %[[L]], %{{.*}}[%{{.*}}] : !linalg.view<?xf32>
+// CHECK:   %[[L:.*]] = load %{{.*}}[%{{.*}}] : memref<?xf32, #[[strided1D]]>
+// CHECK:   store %[[L]], %{{.*}}[%{{.*}}] : memref<?xf32, #[[strided1D]]>
 
-func @copy_view0(%arg0: !linalg.view<f32>, %arg1: !linalg.view<f32>) {
-  linalg.copy(%arg0, %arg1) : !linalg.view<f32>, !linalg.view<f32>
+func @copy_view0(%arg0: memref<f32>, %arg1: memref<f32>) {
+  linalg.copy(%arg0, %arg1) : memref<f32>, memref<f32>
   return
 }
-// CHECK-LABEL: func @copy_view0(%{{.*}}: !linalg.view<f32>, %{{.*}}: !linalg.view<f32>) {
-// CHECK:   %{{.*}} = linalg.load %{{.*}}[] : !linalg.view<f32>
-// CHECK:   linalg.store %{{.*}}, %{{.*}}[] : !linalg.view<f32>
+// CHECK-LABEL: func @copy_view0(%{{.*}}: memref<f32>, %{{.*}}: memref<f32>) {
+// CHECK:   %{{.*}} = load %{{.*}}[] : memref<f32>
+// CHECK:   store %{{.*}}, %{{.*}}[] : memref<f32>
 
-func @copy_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>) {
+func @copy_view3(%arg0: memref<?x?x?xf32, #strided3D>, %arg1: memref<?x?x?xf32, #strided3D>) {
   linalg.copy(%arg0, %arg1) {inputPermutation = (i, j, k) -> (i, k, j),
                              outputPermutation = (i, j, k) -> (k, j, i)} :
-    !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+    memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>
   return
 }
-// CHECK-LABEL: func @copy_view3(%{{.*}}: !linalg.view<?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?xf32>) {
+// CHECK-LABEL: func @copy_view3
+// CHECK:  (%{{.*}}: memref<?x?x?xf32, #[[strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[strided3D]]>) {
 // CHECK: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK:   loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK:     loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
-// CHECK:       %[[L:.*]] = linalg.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?xf32>
-// CHECK:       linalg.store %[[L]], %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?xf32>
+// CHECK:       %[[L:.*]] = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK:       store %[[L]], %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?xf32, #[[strided3D]]>
 
-func @conv_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
-  linalg.conv(%arg0, %arg1, %arg2) {strides = [2]}: !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+func @conv_view3(%arg0: memref<?x?x?xf32, #strided3D>, %arg1: memref<?x?x?xf32, #strided3D>, %arg2: memref<?x?x?xf32, #strided3D>) {
  linalg.conv(%arg0, %arg1, %arg2) {strides = [2]}: memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>
   return
 }
-// CHECK-LABEL: func @conv_view3(%{{.*}}: !linalg.view<?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?xf32>) {
-// CHECK: %[[Z0:.*]] = linalg.dim %arg0, 0 : !linalg.view<?x?x?xf32>
-// CHECK: %[[Q:.*]] = linalg.dim %arg0, 1 : !linalg.view<?x?x?xf32>
-// CHECK: %[[K:.*]] = linalg.dim %arg0, 2 : !linalg.view<?x?x?xf32>
-// CHECK: %[[B:.*]] = linalg.dim %arg1, 0 : !linalg.view<?x?x?xf32>
-// CHECK: %[[X0:.*]] = linalg.dim %arg2, 1 : !linalg.view<?x?x?xf32>
+// CHECK-LABEL: func @conv_view3(
+// CHECK:  %{{.*}}: memref<?x?x?xf32, #[[strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[strided3D]]>) {
+// CHECK: %[[Z0:.*]] = dim %arg0, 0 : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK: %[[Q:.*]] = dim %arg0, 1 : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK: %[[K:.*]] = dim %arg0, 2 : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK: %[[B:.*]] = dim %arg1, 0 : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK: %[[X0:.*]] = dim %arg2, 1 : memref<?x?x?xf32, #[[strided3D]]>
 // CHECK: loop.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} {
 // CHECK:   loop.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} {
 // CHECK:     loop.for %{{.*}} = %{{.*}} to %[[K]] step %{{.*}} {
 // CHECK:       loop.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} {
 // CHECK:         loop.for %{{.*}} = %{{.*}} to %[[Z0]] step %{{.*}} {
-// CHECK:           %[[SUM:.*]] = affine.apply #[[S2D1]](%{{.*}}, %{{.*}})
-// CHECK:           %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %[[SUM]], %{{.*}}] : !linalg.view<?x?x?xf32>
-// CHECK:           %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?xf32>
+// CHECK:           %[[SUM:.*]] = affine.apply #[[Stride2Dilation1]](%{{.*}}, %{{.*}})
+// CHECK:           %{{.*}} = load %{{.*}}[%{{.*}}, %[[SUM]], %{{.*}}] : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK:           %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?xf32, #[[strided3D]]>
 // CHECK:           %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?xf32>
+// CHECK:           %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?xf32, #[[strided3D]]>
 // CHECK:           %{{.*}} = addf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?xf32>
+// CHECK:           store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?xf32, #[[strided3D]]>
 
-func @conv_view4(%arg0: !linalg.view<?x?x?x?xf32>, %arg1: !linalg.view<?x?x?x?xf32>, %arg2: !linalg.view<?x?x?x?xf32>) {
-  linalg.conv(%arg0, %arg1, %arg2) {dilations = [4, 5], strides = [2, 3]} : !linalg.view<?x?x?x?xf32>, !linalg.view<?x?x?x?xf32>, !linalg.view<?x?x?x?xf32>
+func @conv_view4(%arg0: memref<?x?x?x?xf32, #strided4D>, %arg1: memref<?x?x?x?xf32, #strided4D>, %arg2: memref<?x?x?x?xf32, #strided4D>) {
  linalg.conv(%arg0, %arg1, %arg2) {dilations = [4, 5], strides = [2, 3]} : memref<?x?x?x?xf32, #strided4D>, memref<?x?x?x?xf32, #strided4D>, memref<?x?x?x?xf32, #strided4D>
   return
 }
-// CHECK-LABEL: func @conv_view4(%{{.*}}: !linalg.view<?x?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?x?xf32>) {
-// CHECK: %[[Z0:.*]] = linalg.dim %arg0, 0 : !linalg.view<?x?x?x?xf32>
-// CHECK: %[[Z1:.*]] = linalg.dim %arg0, 1 : !linalg.view<?x?x?x?xf32>
-// CHECK: %[[Q:.*]] = linalg.dim %arg0, 2 : !linalg.view<?x?x?x?xf32>
-// CHECK: %[[K:.*]] = linalg.dim %arg0, 3 : !linalg.view<?x?x?x?xf32>
-// CHECK: %[[B:.*]] = linalg.dim %arg1, 0 : !linalg.view<?x?x?x?xf32>
-// CHECK: %[[X0:.*]] = linalg.dim %arg2, 1 : !linalg.view<?x?x?x?xf32>
-// CHECK: %[[X1:.*]] = linalg.dim %arg2, 2 : !linalg.view<?x?x?x?xf32>
+// CHECK-LABEL: func @conv_view4(
+// CHECK:  %{{.*}}: memref<?x?x?x?xf32, #[[strided4D]]>, %{{.*}}: memref<?x?x?x?xf32, #[[strided4D]]>, %{{.*}}: memref<?x?x?x?xf32, #[[strided4D]]>) {
+// CHECK: %[[Z0:.*]] = dim %arg0, 0 : memref<?x?x?x?xf32, #[[strided4D]]>
+// CHECK: %[[Z1:.*]] = dim %arg0, 1 : memref<?x?x?x?xf32, #[[strided4D]]>
+// CHECK: %[[Q:.*]] = dim %arg0, 2 : memref<?x?x?x?xf32, #[[strided4D]]>
+// CHECK: %[[K:.*]] = dim %arg0, 3 : memref<?x?x?x?xf32, #[[strided4D]]>
+// CHECK: %[[B:.*]] = dim %arg1, 0 : memref<?x?x?x?xf32, #[[strided4D]]>
+// CHECK: %[[X0:.*]] = dim %arg2, 1 : memref<?x?x?x?xf32, #[[strided4D]]>
+// CHECK: %[[X1:.*]] = dim %arg2, 2 : memref<?x?x?x?xf32, #[[strided4D]]>
 // CHECK: loop.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} {
 // CHECK:   loop.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} {
 // CHECK:     loop.for %{{.*}} = %{{.*}} to %[[X1]] step %{{.*}} {
@@ -193,14 +210,14 @@ func @conv_view4(%arg0: !linalg.view<?x?x?x?xf32>, %arg1: !linalg.view<
-// CHECK:           %[[SUM0:.*]] = affine.apply #[[S2D3]](%{{.*}}, %{{.*}})
-// CHECK:           %[[SUM1:.*]] = affine.apply #[[S3D2]](%{{.*}}, %{{.*}})
-// CHECK:           %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %[[SUM0]], %[[SUM1]], %{{.*}}] : !linalg.view<?x?x?x?xf32>
-// CHECK:           %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?x?xf32>
+// CHECK:           %[[SUM0:.*]] = affine.apply #[[Stride2Dilation4]](%{{.*}}, %{{.*}})
+// CHECK:           %[[SUM1:.*]] = affine.apply #[[Stride3Dilation5]](%{{.*}}, %{{.*}})
+// CHECK:           %{{.*}} = load %{{.*}}[%{{.*}}, %[[SUM0]], %[[SUM1]], %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
+// CHECK:           %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
 // CHECK:           %{{.*}} = mulf %{{.*}}, %{{.*}} : f32
-// CHECK:           %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?x?xf32>
+// CHECK:           %{{.*}} = load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
 // CHECK:           %{{.*}} = addf %{{.*}}, %{{.*}} : f32
-// CHECK:           linalg.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?x?xf32>
+// CHECK:           store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
 
 func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
   %f0 = constant 0.0 : f32
@@ -219,9 +236,9 @@ func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
   library_call = "external_function_name",
   doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
 }
-func @generic_function(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
+func @generic_function(%arg0: memref<?x?xf32, #strided2D>, %arg1: memref<?x?x?xf32, #strided3D>, %arg2: memref<?x?x?xf32, #strided3D>) {
   linalg.generic #trait %arg0, %arg1, %arg2:
-    !linalg.view<?x?xf32>, !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+    memref<?x?xf32, #strided2D>, memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>
   return
 }
 // CHECK-LABEL: @foo
@@ -229,12 +246,12 @@ func @generic_function(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<
 // CHECK-LABEL: @generic_function
 // CHECK: loop.for %[[i:.*]] = {{.*}}
 // CHECK:   loop.for %[[j:.*]] = {{.*}}
 // CHECK:     loop.for %[[k:.*]] = {{.*}}
-// CHECK:       %[[a:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]]] : !linalg.view<?x?xf32>
-// CHECK:       %[[b:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
-// CHECK:       %[[c:.*]] = linalg.load %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
+// CHECK:       %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
+// CHECK:       %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK:       %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
 // CHECK:       %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
-// CHECK:       linalg.store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
-// CHECK:       linalg.store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
+// CHECK:       store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK:       store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
 
 #trait2 = {
   n_views = [1, 2],
@@ -243,23 +260,23 @@ func @generic_function(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<
   library_call = "external_function_name",
   doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
 }
-func @generic_region(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
+func @generic_region(%arg0: memref<?x?xf32, #strided2D>, %arg1: memref<?x?x?xf32, #strided3D>, %arg2: memref<?x?x?xf32, #strided3D>) {
   linalg.generic #trait2 %arg0, %arg1, %arg2 {
     ^bb0(%a: f32, %b: f32, %c: f32):
       %d = mulf %a, %b : f32
       %e = addf %c, %d : f32
       linalg.yield %d, %e : f32, f32
-  }: !linalg.view<?x?xf32>, !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+  }: memref<?x?xf32, #strided2D>, memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>
   return
 }
 // CHECK-LABEL: @generic_region
 // CHECK: loop.for %[[i:.*]] = {{.*}}
 // CHECK:   loop.for %[[j:.*]] = {{.*}}
 // CHECK:     loop.for %[[k:.*]] = {{.*}}
-// CHECK:       %[[a:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]]] : !linalg.view<?x?xf32>
-// CHECK:       %[[b:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
-// CHECK:       %[[c:.*]] = linalg.load %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
+// CHECK:       %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
+// CHECK:       %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK:       %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
 // CHECK:       %[[d:.*]] = mulf %[[a]], %[[b]] : f32
 // CHECK:       %[[e:.*]] = addf %[[c]], %[[d]] : f32
-// CHECK:       linalg.store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
-// CHECK:       linalg.store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
+// CHECK:       store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK:       store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
diff --git a/mlir/test/Dialect/Linalg/promote.mlir b/mlir/test/Dialect/Linalg/promote.mlir
index 611f1aa..becd95f 100644
--- a/mlir/test/Dialect/Linalg/promote.mlir
+++ b/mlir/test/Dialect/Linalg/promote.mlir
@@ -1,47 +1,51 @@
 // RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-tile-promote-full-tile-views=true | FileCheck %s -check-prefix=TILE-1D
 
+// TILE-1D-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+// TILE-1D-DAG: #[[strided2DnoOffset:.*]] = (d0, d1)[s0] -> (d0 * s0 + d1)
+#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+
 func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %I = linalg.range %c0:%arg1:%c1 : !linalg.range
   %J = linalg.range %c0:%arg2:%c1 : !linalg.range
   %K = linalg.range %c0:%arg3:%c1 : !linalg.range
-  %A = linalg.view %arg0[%I, %K] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  %B = linalg.view %arg0[%K, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  %C = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  %A = linalg.view %arg0[%I, %K] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  %B = linalg.view %arg0[%K, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  %C = linalg.view %arg0[%I, %J] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
   return
 }
 // TILE-1D-LABEL: func @matmul(%{{.*}}: !linalg.buffer<?xf32>, %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
 // TILE-1D: loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // TILE-1D:   loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
 // TILE-1D:     loop.for %{{.*}} = %{{.*}} to %{{.*}} step %{{.*}} {
-// TILE-1D:       %[[vA:.*]] = linalg.subview {{.*}} : !linalg.view<?x?xf32>
-// TILE-1D:       %[[vB:.*]] = linalg.subview {{.*}} : !linalg.view<?x?xf32>
-// TILE-1D:       %[[vC:.*]] = linalg.subview {{.*}} : !linalg.view<?x?xf32>
+// TILE-1D:       %[[vA:.*]] = linalg.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
+// TILE-1D:       %[[vB:.*]] = linalg.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
+// TILE-1D:       %[[vC:.*]] = linalg.subview {{.*}} : memref<?x?xf32, #[[strided2D]]>
 ///
 // TILE-1D:       %[[tmpA:.*]] = linalg.buffer_alloc : !linalg.buffer<8xf32>
-// TILE-1D:       %[[fullA:.*]] = linalg.view %[[tmpA]][{{.*}}] : !linalg.buffer<8xf32> -> !linalg.view<?x?xf32>
-// TILE-1D:       %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+// TILE-1D:       %[[fullA:.*]] = linalg.view %[[tmpA]][{{.*}}] : !linalg.buffer<8xf32> -> memref<?x?xf32, #[[strided2DnoOffset]]>
+// TILE-1D:       %[[partialA:.*]] = linalg.slice %[[fullA]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2DnoOffset]]>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2D]]>
 ///
 // TILE-1D:       %[[tmpB:.*]] = linalg.buffer_alloc : !linalg.buffer<12xf32>
-// TILE-1D:       %[[fullB:.*]] = linalg.view %[[tmpB]][{{.*}}] : !linalg.buffer<12xf32> -> !linalg.view<?x?xf32>
-// TILE-1D:       %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+// TILE-1D:       %[[fullB:.*]] = linalg.view %[[tmpB]][{{.*}}] : !linalg.buffer<12xf32> -> memref<?x?xf32, #[[strided2DnoOffset]]>
+// TILE-1D:       %[[partialB:.*]] = linalg.slice %[[fullB]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2DnoOffset]]>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2D]]>
 ///
 // TILE-1D:       %[[tmpC:.*]] = linalg.buffer_alloc : !linalg.buffer<6xf32>
-// TILE-1D:       %[[fullC:.*]] = linalg.view %[[tmpC]][{{.*}}] : !linalg.buffer<6xf32> -> !linalg.view<?x?xf32>
-// TILE-1D:       %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
+// TILE-1D:       %[[fullC:.*]] = linalg.view %[[tmpC]][{{.*}}] : !linalg.buffer<6xf32> -> memref<?x?xf32, #[[strided2DnoOffset]]>
+// TILE-1D:       %[[partialC:.*]] = linalg.slice %[[fullC]][%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2DnoOffset]]>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2D]]>
 
-// TILE-1D:       linalg.fill(%[[fullA]], {{.*}}) : !linalg.view<?x?xf32>, f32
-// TILE-1D:       linalg.fill(%[[fullB]], {{.*}}) : !linalg.view<?x?xf32>, f32
-// TILE-1D:       linalg.fill(%[[fullC]], {{.*}}) : !linalg.view<?x?xf32>, f32
-// TILE-1D:       linalg.copy(%[[vA]], %[[partialA]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-// TILE-1D:       linalg.copy(%[[vB]], %[[partialB]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-// TILE-1D:       linalg.copy(%[[vC]], %[[partialC]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// TILE-1D:       linalg.fill(%[[fullA]], {{.*}}) : memref<?x?xf32, #[[strided2DnoOffset]]>, f32
+// TILE-1D:       linalg.fill(%[[fullB]], {{.*}}) : memref<?x?xf32, #[[strided2DnoOffset]]>, f32
+// TILE-1D:       linalg.fill(%[[fullC]], {{.*}}) : memref<?x?xf32, #[[strided2DnoOffset]]>, f32
+// TILE-1D:       linalg.copy(%[[vA]], %[[partialA]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
+// TILE-1D:       linalg.copy(%[[vB]], %[[partialB]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
+// TILE-1D:       linalg.copy(%[[vC]], %[[partialC]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
 //
-// TILE-1D:       linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// TILE-1D:       linalg.matmul(%[[fullA]], %[[fullB]], %[[fullC]]) : memref<?x?xf32, #[[strided2DnoOffset]]>, memref<?x?xf32, #[[strided2DnoOffset]]>, memref<?x?xf32, #[[strided2DnoOffset]]>
 //
-// TILE-1D:       linalg.copy(%[[partialC]], %[[vC]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// TILE-1D:       linalg.copy(%[[partialC]], %[[vC]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
 //
 // TILE-1D:       linalg.buffer_dealloc %[[tmpA]] : !linalg.buffer<8xf32>
 // TILE-1D:       linalg.buffer_dealloc %[[tmpB]] : !linalg.buffer<12xf32>
diff --git a/mlir/test/Dialect/Linalg/roundtrip.mlir b/mlir/test/Dialect/Linalg/roundtrip.mlir
index eefa409..58aa805 100644
--- a/mlir/test/Dialect/Linalg/roundtrip.mlir
+++ b/mlir/test/Dialect/Linalg/roundtrip.mlir
@@ -1,5 +1,14 @@
 // RUN: mlir-opt %s | mlir-opt | FileCheck %s
 
+// CHECK-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
+#strided1D = (d0)[s0] -> (d0 + s0)
+// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+// CHECK-DAG: #[[strided3D:.*]] = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)
+#strided3D = (d0, d1, d2)[s0, s1, s2] -> (d0 * s1 + s0 + d1 * s2 + d2)
+// CHECK-DAG: #[[strided6D:.*]] = (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5 * s6)
+#strided6D = (d0, d1, d2, d3, d4, d5)[s0, s1, s2, s3, s4, s5, s6] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3 * s4 + d4 * s5 + d5 * s6)
+
 // CHECK-DAG: #[[map0:.*]] = (d0, d1, d2) -> (d0, d2, d1)
 // CHECK-DAG: #[[map1:.*]] = (d0, d1, d2) -> (d2, d1, d0)
 
@@ -40,21 +49,16 @@ func @buffer(%arg0: index, %arg1: index) {
 // CHECK-NEXT:  linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xvector<4xf32>>
 // CHECK-NEXT:  linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xvector<4xf32>>
 
-func @view_fun(%arg0: !linalg.view<?x?xvector<4xf32>>) {
-  return
-}
-// CHECK-LABEL: func @view_fun(%{{.*}}: !linalg.view<?x?xvector<4xf32>>) {
-
 func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
   %0 = muli %arg0, %arg0 : index
   %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
   %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
-  %3 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
-  %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
-  %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
-  %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
-  %8 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4xf32>>
+  %3 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  %4 = linalg.slice %3[%2, %2] : memref<?x?xf32, #strided2D>, !linalg.range, !linalg.range, memref<?x?xf32, #strided2D>
+  %5 = linalg.slice %3[%2, %arg2] : memref<?x?xf32, #strided2D>, !linalg.range, index, memref<?xf32, #strided1D>
+  %6 = linalg.slice %3[%arg2, %2] : memref<?x?xf32, #strided2D>, index, !linalg.range, memref<?xf32, #strided1D>
+  %7 = linalg.slice %3[%arg2, %arg3] : memref<?x?xf32, #strided2D>, index, index, memref<f32>
+  %8 = linalg.view %1[%2, %2] : !linalg.buffer<?xf32> -> memref<?x?xvector<4xf32>, #strided2D>
   linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
   return
 }
@@ -62,33 +66,35 @@ func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index
 // CHECK-NEXT:  muli %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:  linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
 // CHECK-NEXT:  linalg.range %{{.*}}:%{{.*}}:%{{.*}} : !linalg.range
-// CHECK-NEXT:  linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-// CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
-// CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
-// CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
-// CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
-// CHECK-NEXT:  linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> !linalg.view<?x?xvector<4xf32>>
+// CHECK-NEXT:  linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xf32, #[[strided2D]]>
+// CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, !linalg.range, !linalg.range, memref<?x?xf32, #[[strided2D]]>
+// CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, !linalg.range, index, memref<?xf32, #[[strided1D]]>
+// CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, index, !linalg.range, memref<?xf32, #[[strided1D]]>
+// CHECK-NEXT:  linalg.slice %{{.*}}[%{{.*}}, %{{.*}}] : memref<?x?xf32, #[[strided2D]]>, index, index, memref<f32>
+// CHECK-NEXT:  linalg.view %{{.*}}[%{{.*}}, %{{.*}}] : !linalg.buffer<?xf32> -> memref<?x?xvector<4xf32>, #[[strided2D]]>
 // CHECK-NEXT:  linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
 
-func @ops(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<?xf32>, %arg3: !linalg.view<f32>) {
-  linalg.matmul(%arg0, %arg0, %arg0) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  linalg.matvec(%arg0, %arg1, %arg2) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
-  linalg.dot(%arg1, %arg2, %arg3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+func @ops(%arg0: memref<?x?xf32, #strided2D>, %arg1: memref<?xf32, #strided1D>, %arg2: memref<?xf32, #strided1D>, %arg3: memref<f32>) {
+  linalg.matmul(%arg0, %arg0, %arg0) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
+  linalg.matvec(%arg0, %arg1, %arg2) : memref<?x?xf32, #strided2D>, memref<?xf32, #strided1D>, memref<?xf32, #strided1D>
+  linalg.dot(%arg1, %arg2, %arg3) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>, memref<f32>
   return
 }
-// CHECK-LABEL: func @ops(%{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<f32>) {
-// CHECK-NEXT:  linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-// CHECK-NEXT:  linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
-// CHECK-NEXT:  linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+// CHECK-LABEL: func @ops(%
+// CHECK:  {{.*}}: memref<?x?xf32, #[[strided2D]]>, %{{.*}}: memref<?xf32, #[[strided1D]]>, %{{.*}}: memref<?xf32, #[[strided1D]]>, %{{.*}}: memref<f32>) {
+// CHECK-NEXT:  linalg.matmul(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
+// CHECK-NEXT:  linalg.matvec(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?x?xf32, #[[strided2D]]>, memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>
+// CHECK-NEXT:  linalg.dot(%{{.*}}, %{{.*}}, %{{.*}}) : memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>, memref<f32>
 
-func @dim(%arg0: !linalg.view<?x?xf32>) {
-  %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
+func @dim(%arg0: memref<?x?xf32, #strided2D>) {
+  %0 = dim %arg0, 1 : memref<?x?xf32, #strided2D>
   %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
   linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
   return
 }
-// CHECK-LABEL: func @dim(%{{.*}}: !linalg.view<?x?xf32>) {
-// CHECK-NEXT:   linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
+// CHECK-LABEL: func @dim(
+// CHECK:  %{{.*}}: memref<?x?xf32, #[[strided2D]]>) {
+// CHECK-NEXT:   dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
 // CHECK-NEXT:   linalg.buffer_alloc %{{.*}} : !linalg.buffer<?xf32>
 // CHECK-NEXT:   linalg.buffer_dealloc %{{.*}} : !linalg.buffer<?xf32>
 
@@ -105,7 +111,8 @@ func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) {
   }
   return
 }
-// CHECK-LABEL: func @linalg_for(%{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
+// CHECK-LABEL: func @linalg_for(
+// CHECK:  %{{.*}}: index, %{{.*}}: index, %{{.*}}: index) {
 // CHECK-NEXT:   loop.for %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK-NEXT:     loop.for %{{.*}} to %{{.*}} step %{{.*}} {
 // CHECK-NEXT:       cmpi "slt", %{{.*}}, %{{.*}} : index
@@ -114,70 +121,77 @@ func @linalg_for(%arg0 : index, %arg1 : index, %arg2 : index) {
 // CHECK-NEXT:       select %{{.*}}, %{{.*}}, %{{.*}} : index
 // CHECK-NEXT:       loop.for %{{.*}} to %{{.*}} step %{{.*}} {
 
-func @fill_view(%arg0: !linalg.view<?xf32>, %arg1: f32) {
-  linalg.fill(%arg0, %arg1) : !linalg.view<?xf32>, f32
+func @fill_view(%arg0: memref<?xf32, #strided1D>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : memref<?xf32, #strided1D>, f32
   return
 }
-// CHECK-LABEL: func @fill_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: f32) {
-// CHECK:   linalg.fill(%{{.*}}, %{{.*}}) : !linalg.view<?xf32>, f32
+// CHECK-LABEL: func @fill_view(
+// CHECK:  %{{.*}}: memref<?xf32, #[[strided1D]]>, %{{.*}}: f32) {
+// CHECK:   linalg.fill(%{{.*}}, %{{.*}}) : memref<?xf32, #[[strided1D]]>, f32
 
-func @transpose(%arg0: !linalg.view<?x?x?xf32>) {
-  %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : !linalg.view<?x?x?xf32>
+func @transpose(%arg0: memref<?x?x?xf32, #strided3D>) {
  %0 = linalg.transpose %arg0 (i, j, k) -> (k, j, i) : memref<?x?x?xf32, #strided3D>
   return
 }
 // CHECK-LABEL: func @transpose
-// CHECK:   linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : !linalg.view<?x?x?xf32>
+// CHECK:   linalg.transpose %{{.*}} ([[i:.*]], [[j:.*]], [[k:.*]]) -> ([[k]], [[j]], [[i]]) : memref<?x?x?xf32, #[[strided3D]]>
 
-func @fill_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
-  linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
+func @fill_view3(%arg0: memref<?x?x?xf32, #strided3D>, %arg1: f32) {
  linalg.fill(%arg0, %arg1) : memref<?x?x?xf32, #strided3D>, f32
   return
 }
-// CHECK-LABEL: func @fill_view3(%{{.*}}: !linalg.view<?x?x?xf32>, %{{.*}}: f32) {
-// CHECK:   linalg.fill(%{{.*}}, %{{.*}}) : !linalg.view<?x?x?xf32>, f32
+// CHECK-LABEL: func @fill_view3(
+// CHECK:  %{{.*}}: memref<?x?x?xf32, #[[strided3D]]>, %{{.*}}: f32) {
+// CHECK:   linalg.fill(%{{.*}}, %{{.*}}) : memref<?x?x?xf32, #[[strided3D]]>, f32
 
-func @copy_view(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>) {
-  linalg.copy(%arg0, %arg1) : !linalg.view<?xf32>, !linalg.view<?xf32>
+func @copy_view(%arg0: memref<?xf32, #strided1D>, %arg1: memref<?xf32, #strided1D>) {
  linalg.copy(%arg0, %arg1) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>
   return
 }
-// CHECK-LABEL: func @copy_view(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>) {
-// CHECK:   linalg.copy(%{{.*}}, %{{.*}}) : !linalg.view<?xf32>, !linalg.view<?xf32>
+// CHECK-LABEL: func @copy_view(
+// CHECK:  %{{.*}}: memref<?xf32, #[[strided1D]]>, %{{.*}}: memref<?xf32, #[[strided1D]]>) {
+// CHECK:   linalg.copy(%{{.*}}, %{{.*}}) : memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>
 
-func @copy_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>) {
+func @copy_view3(%arg0: memref<?x?x?xf32, #strided3D>, %arg1: memref<?x?x?xf32, #strided3D>) {
   linalg.copy(%arg0, %arg1) {inputPermutation = (i, j, k) -> (i, k, j),
                              outputPermutation = (i, j, k) -> (k, j, i)} :
-    !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+    memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>
   return
 }
-// CHECK-LABEL: func @copy_view3(%{{.*}}: !linalg.view<?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?xf32>) {
-// CHECK:   linalg.copy(%{{.*}}, %{{.*}}) {inputPermutation = #[[map0]], outputPermutation = #[[map1]]} : !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+// CHECK-LABEL: func @copy_view3(
+// CHECK:  %{{.*}}: memref<?x?x?xf32, #[[strided3D]]>, %{{.*}}: memref<?x?x?xf32, #[[strided3D]]>) {
+// CHECK:   linalg.copy(%{{.*}}, %{{.*}}) {inputPermutation = #[[map0]], outputPermutation = #[[map1]]} : memref<?x?x?xf32, #[[strided3D]]>, memref<?x?x?xf32, #[[strided3D]]>
 
-func @conv_view3(%arg0: !linalg.view<?x?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
-  linalg.conv(%arg0, %arg1, %arg2) : !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+func @conv_view3(%arg0: memref<?x?x?xf32, #strided3D>, %arg1: memref<?x?x?xf32, #strided3D>, %arg2: memref<?x?x?xf32, #strided3D>) {
  linalg.conv(%arg0, %arg1, %arg2) : memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>, memref<?x?x?xf32, #strided3D>
   return
 }
-// CHECK-LABEL: func @conv_view3(%{{.*}}: !linalg.view<?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?xf32>) {
@conv_view3(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { -// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) : !linalg.view, !linalg.view, !linalg.view +// CHECK-LABEL: func @conv_view3( +// CHECK: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { +// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) : memref, memref, memref -func @conv_view6(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg.view) { - linalg.conv(%arg0, %arg1, %arg2) {dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} : !linalg.view, !linalg.view, !linalg.view +func @conv_view6(%arg0: memref, %arg1: memref, %arg2: memref) { + linalg.conv(%arg0, %arg1, %arg2) {dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} : memref, memref, memref return } -// CHECK-LABEL: func @conv_view6(%{{.*}}: !linalg.view, %{{.*}}: !linalg.view, %{{.*}}: !linalg.view) { -// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) {dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} : !linalg.view, !linalg.view, !linalg.view +// CHECK-LABEL: func @conv_view6( +// CHECK: %{{.*}}: memref, %{{.*}}: memref, %{{.*}}: memref) { +// CHECK: linalg.conv(%{{.*}}, %{{.*}}, %{{.*}}) {dilations = [4, 4, 5, 5], strides = [2, 2, 3, 3]} : memref, memref, memref -func @subview(%arg0: !linalg.view>) { +func @subview(%arg0: memref, #strided2D>) { %c0 = constant 0 : index - %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : !linalg.view> + %0 = linalg.subview %arg0[%c0, %c0, %c0, %c0, %c0, %c0] : memref, #strided2D> return } -// CHECK-LABEL: func @subview(%{{.*}}: !linalg.view>) { +// CHECK-LABEL: func @subview( +// CHECK: %{{.*}}: memref, #[[strided2D]]>) { // CHECK: constant 0 : index -// CHECK: linalg.subview %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view> +// CHECK: linalg.subview %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : memref, #[[strided2D]]> func @const_buffer_view(%arg0: index, %arg1: index, %arg2: index) { %c0 = linalg.buffer_alloc : !linalg.buffer<17xf32> %c1 = linalg.range %arg0:%arg1:%arg2 : !linalg.range - %c2 = linalg.view %c0[%c1] : !linalg.buffer<17xf32> -> !linalg.view + %c2 = linalg.view %c0[%c1] : !linalg.buffer<17xf32> -> memref return } @@ -196,13 +210,13 @@ func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 { %f0 = constant 0.0 : f32 return %f0 : f32 } -func @generic(%arg0: !linalg.view>, %arg1: !linalg.view) { - linalg.generic #trait %arg0, %arg1 {foo = 1} : !linalg.view>, !linalg.view +func @generic(%arg0: memref, #strided2D>, %arg1: memref) { + linalg.generic #trait %arg0, %arg1 {foo = 1} : memref, #strided2D>, memref return } // CHECK-LABEL: func @foo // CHECK-LABEL: func @generic -// CHECK: linalg.generic {fun = @foo, indexing_maps = [#map2, #map3], library_call = "external_function_name", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: !linalg.view>, !linalg.view +// CHECK: linalg.generic {fun = @foo, indexing_maps = [#{{.*}}, #{{.*}}], library_call = "external_function_name", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: memref, #[[strided2D]]>, memref #trait2 = { indexing_maps = #accesses, @@ -210,15 +224,15 @@ func @generic(%arg0: !linalg.view>, %arg1: !linalg.view>, %arg1: !linalg.view) { +func @generic_region(%arg0: memref, #strided2D>, %arg1: memref) { linalg.generic #trait2 %arg0, %arg1 { ^bb(%a: vector<3x4xi4>, %b: f32) : linalg.yield %b : f32 - } {foo = 1}: !linalg.view>, !linalg.view + } {foo = 1}: memref, #strided2D>, memref return } // CHECK-LABEL: func @generic_region -// CHECK: linalg.generic 
-// CHECK: linalg.generic {indexing_maps = [#map2, #map3], library_call = "external_function_name", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {
+// CHECK: linalg.generic {indexing_maps = [#{{.*}}, #{{.*}}], library_call = "external_function_name", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {
 // CHECK: ^{{.*}}(%{{.*}}: vector<3x4xi4>, %{{.*}}: f32): // no predecessors
 // CHECK: linalg.yield %{{.*}} : f32
-// CHECK: } {foo = 1 : i64}: !linalg.view<?x?xvector<3x4xi4>>, !linalg.view<?x?x?xf32>
\ No newline at end of file
+// CHECK: } {foo = 1 : i64}: memref<?x?xvector<3x4xi4>, #[[strided2D]]>, memref<?x?x?xf32, #[[strided3D]]>
diff --git a/mlir/test/Dialect/Linalg/tile.mlir b/mlir/test/Dialect/Linalg/tile.mlir
index c3d6826..6ace583 100644
--- a/mlir/test/Dialect/Linalg/tile.mlir
+++ b/mlir/test/Dialect/Linalg/tile.mlir
@@ -10,134 +10,146 @@
 // TILE-234-DAG: #[[UB1:.*]] = (d0) -> (d0 + 3)
 // TILE-234-DAG: #[[UB2:.*]] = (d0) -> (d0 + 4)
 
-func @matmul(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?xf32>, %arg2: !linalg.view<?x?xf32>) {
-  linalg.matmul(%arg0, %arg1, %arg2) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// TILE-2-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
+// TILE-02-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
+// TILE-002-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
+// TILE-234-DAG: #[[strided1D:.*]] = (d0)[s0] -> (d0 + s0)
+#strided1D = (d0)[s0] -> (d0 + s0)
+// CHECK-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+// TILE-2-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+// TILE-02-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+// TILE-002-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+// TILE-234-DAG: #[[strided2D:.*]] = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
+
+func @matmul(%arg0: memref<?x?xf32, #strided2D>, %arg1: memref<?x?xf32, #strided2D>, %arg2: memref<?x?xf32, #strided2D>) {
+  linalg.matmul(%arg0, %arg1, %arg2) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
   return
 }
-// TILE-2-LABEL: func @matmul(%{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?x?xf32>) {
-// TILE-2: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
+// TILE-2-LABEL: func @matmul(
+// TILE-2: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32, #[[strided2D]]>
 // TILE-2: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} {
 // TILE-2: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-2: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
-// TILE-2: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}, %{{.*}}, %[[K]], %{{.*}}] : !linalg.view<?x?xf32>
+// TILE-2: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
+// TILE-2: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}, %{{.*}}, %[[K]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 // TILE-2: %[[c:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-2: %[[N:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
-// TILE-2: %[[sCi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[c]], %{{.*}}, %{{.*}}, %[[N]], %{{.*}}] : !linalg.view<?x?xf32>
-// TILE-2: linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// TILE-2: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
+// TILE-2: %[[sCi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[c]], %{{.*}}, %{{.*}}, %[[N]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
+// TILE-2: linalg.matmul(%[[sAi]], %{{.*}}, %[[sCi]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
 
-// TILE-02-LABEL: func @matmul(%{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?x?xf32>) {
-// TILE-02: %[[N:.*]] = linalg.dim %arg1, 1 : !linalg.view<?x?xf32>
+// TILE-02-LABEL: func @matmul(
+// TILE-02: %[[N:.*]] = dim %arg1, 1 : memref<?x?xf32, #[[strided2D]]>
 // TILE-02: loop.for %{{.*}} = %{{.*}} to %[[N]] step %{{.*}} {
-// TILE-02: %[[K:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
+// TILE-02: %[[K:.*]] = dim %{{.*}}, 0 : memref<?x?xf32, #[[strided2D]]>
 // TILE-02: %[[b:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-02: %[[sBj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[K]], %{{.*}}, %{{.*}}, %[[b]], %{{.*}}] : !linalg.view<?x?xf32>
-// TILE-02: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
+// TILE-02: %[[sBj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[K]], %{{.*}}, %{{.*}}, %[[b]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
+// TILE-02: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32, #[[strided2D]]>
 // TILE-02: %[[c:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-02: %[[sCj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[M]], %{{.*}}, %{{.*}}, %[[c]], %{{.*}}] : !linalg.view<?x?xf32>
-// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// TILE-02: %[[sCj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[M]], %{{.*}}, %{{.*}}, %[[c]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
+// TILE-02: linalg.matmul(%{{.*}}, %[[sBj]], %[[sCj]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
 
-// TILE-002-LABEL: func @matmul(%{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?x?xf32>) {
-// TILE-002: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
+// TILE-002-LABEL: func @matmul(
+// TILE-002: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
 // TILE-002: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} {
-// TILE-002: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
+// TILE-002: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32, #[[strided2D]]>
 // TILE-002: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-002: %[[sAj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[M]], %{{.*}}, %{{.*}}, %[[a]], %{{.*}}] : !linalg.view<?x?xf32>
+// TILE-002: %[[sAj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[M]], %{{.*}}, %{{.*}}, %[[a]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 // TILE-002: %[[b:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-002: %[[N:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
-// TILE-002: %[[sBj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}, %{{.*}}, %[[N]], %{{.*}}] : !linalg.view<?x?xf32>
-// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-
-// TILE-234-LABEL: func @matmul(%{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?x?xf32>) {
-// TILE-234: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
-// TILE-234: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
-// TILE-234: %[[N:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
+// TILE-002: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
+// TILE-002: %[[sBj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}, %{{.*}}, %[[N]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
+// TILE-002: linalg.matmul(%[[sAj]], %[[sBj]], %{{.*}}) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
+
+// TILE-234-LABEL: func @matmul(
+// TILE-234: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32, #[[strided2D]]>
+// TILE-234: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
+// TILE-234: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
 // TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} {
 // TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[N]] step %{{.*}} {
 // TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} {
 // TILE-234: %[[ai:.*]] = affine.apply #[[UB0]](%{{.*}})
 // TILE-234: %[[ak:.*]] = affine.apply #[[UB2]](%{{.*}})
-// TILE-234: %[[sAik:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ai]], %{{.*}}, %{{.*}}, %[[ak]], %{{.*}}] : !linalg.view<?x?xf32>
+// TILE-234: %[[sAik:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ai]], %{{.*}}, %{{.*}}, %[[ak]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 // TILE-234: %[[bk:.*]] = affine.apply #[[UB2]](%{{.*}})
 // TILE-234: %[[bj:.*]] = affine.apply #[[UB1]](%{{.*}})
-// TILE-234: %[[sBkj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[bk]], %{{.*}}, %{{.*}}, %[[bj]], %{{.*}}] : !linalg.view<?x?xf32>
+// TILE-234: %[[sBkj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[bk]], %{{.*}}, %{{.*}}, %[[bj]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 // TILE-234: %[[ci:.*]] = affine.apply #[[UB0]](%{{.*}})
 // TILE-234: %[[cj:.*]] = affine.apply #[[UB1]](%{{.*}})
-// TILE-234: %[[sCij:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ci]], %{{.*}}, %{{.*}}, %[[cj]], %{{.*}}] : !linalg.view<?x?xf32>
+// TILE-234: %[[sCij:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ci]], %{{.*}}, %{{.*}}, %[[cj]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 //
-// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+// TILE-234: linalg.matmul(%[[sAik]], %[[sBkj]], %[[sCij]]) : memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>, memref<?x?xf32, #[[strided2D]]>
 
-func @matvec(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<?xf32>) {
-  linalg.matvec(%arg0, %arg1, %arg2) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
+func @matvec(%arg0: memref<?x?xf32, #strided2D>, %arg1: memref<?xf32, #strided1D>, %arg2: memref<?xf32, #strided1D>) {
+  linalg.matvec(%arg0, %arg1, %arg2) : memref<?x?xf32, #strided2D>, memref<?xf32, #strided1D>, memref<?xf32, #strided1D>
   return
 }
-// TILE-2-LABEL: func @matvec(%{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>) {
-// TILE-2: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
+// TILE-2-LABEL: func @matvec(
+// TILE-2: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32, #[[strided2D]]>
 // TILE-2: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} {
 // TILE-2: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-2: %[[N:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
-// TILE-2: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}, %{{.*}}, %[[N]], %{{.*}}] : !linalg.view<?x?xf32>
+// TILE-2: %[[N:.*]] = dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
+// TILE-2: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}, %{{.*}}, %[[N]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 // TILE-2: %[[c:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-2: %[[sCi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[c]], %{{.*}}] : !linalg.view<?xf32>
-// TILE-2: linalg.matvec(%[[sAi]], %{{.*}}, %[[sCi]]) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
+// TILE-2: %[[sCi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[c]], %{{.*}}] : memref<?xf32, #[[strided1D]]>
+// TILE-2: linalg.matvec(%[[sAi]], %{{.*}}, %[[sCi]]) : memref<?x?xf32, #[[strided2D]]>, memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>
 
-// TILE-02-LABEL: func @matvec(%{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>) {
-// TILE-02: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
+// TILE-02-LABEL: func @matvec(
+// TILE-02: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
 // TILE-02: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} {
-// TILE-02: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
+// TILE-02: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32, #[[strided2D]]>
 // TILE-02: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-02: %[[sAj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[M]], %{{.*}}, %{{.*}}, %[[a]], %{{.*}}] : !linalg.view<?x?xf32>
+// TILE-02: %[[sAj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[M]], %{{.*}}, %{{.*}}, %[[a]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 // TILE-02: %[[b:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-02: %[[sBj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}] : !linalg.view<?xf32>
-// TILE-02: linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
+// TILE-02: %[[sBj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}] : memref<?xf32, #[[strided1D]]>
+// TILE-02: linalg.matvec(%[[sAj]], %[[sBj]], %{{.*}}) : memref<?x?xf32, #[[strided2D]]>, memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>
 
-// TILE-002-LABEL: func @matvec(%{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>) {
+// TILE-002-LABEL: func @matvec(
 // TILE-002-NOT: loop.for
 
-// TILE-234-LABEL: func @matvec(%{{.*}}: !linalg.view<?x?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>) {
-// TILE-234: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?xf32>
-// TILE-234: %[[K:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?xf32>
+// TILE-234-LABEL: func @matvec(
+// TILE-234: %[[M:.*]] = dim %{{.*}}, 0 : memref<?x?xf32, #[[strided2D]]>
+// TILE-234: %[[K:.*]] = dim %{{.*}}, 1 : memref<?x?xf32, #[[strided2D]]>
 // TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} {
 // TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} {
 // TILE-234: %[[ai:.*]] = affine.apply #[[UB0]](%{{.*}})
 // TILE-234: %[[aj:.*]] = affine.apply #[[UB1]](%{{.*}})
-// TILE-234: %[[sAij:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ai]], %{{.*}}, %{{.*}}, %[[aj]], %{{.*}}] : !linalg.view<?x?xf32>
+// TILE-234: %[[sAij:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ai]], %{{.*}}, %{{.*}}, %[[aj]], %{{.*}}] : memref<?x?xf32, #[[strided2D]]>
 // TILE-234: %[[bj:.*]] = affine.apply #[[UB1]](%{{.*}})
-// TILE-234: %[[sBj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[bj]], %{{.*}}] : !linalg.view<?xf32>
+// TILE-234: %[[sBj:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[bj]], %{{.*}}] : memref<?xf32, #[[strided1D]]>
 // TILE-234: %[[ci:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-234: %[[sCi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ci]], %{{.*}}] : !linalg.view<?xf32>
+// TILE-234: %[[sCi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[ci]], %{{.*}}] : memref<?xf32, #[[strided1D]]>
 //
-// TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
+// TILE-234: linalg.matvec(%[[sAij]], %[[sBj]], %[[sCi]]) : memref<?x?xf32, #[[strided2D]]>, memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>
 
-func @dot(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<f32>) {
-  linalg.dot(%arg0, %arg1, %arg2) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+func @dot(%arg0: memref<?xf32, #strided1D>, %arg1: memref<?xf32, #strided1D>, %arg2: memref<f32>) {
+  linalg.dot(%arg0, %arg1, %arg2) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>, memref<f32>
   return
 }
-// TILE-2-LABEL: func @dot(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<f32>) {
-// TILE-2: %[[M:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?xf32>
+// TILE-2-LABEL: func @dot(
+// TILE-2: %[[M:.*]] = dim %{{.*}}, 0 : memref<?xf32, #[[strided1D]]>
 // TILE-2: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[M]] step %{{.*}} {
 // TILE-2: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-2: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}] : !linalg.view<?xf32>
+// TILE-2: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}] : memref<?xf32, #[[strided1D]]>
 // TILE-2: %[[b:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-2: %[[sBi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}] : !linalg.view<?xf32>
-// TILE-2: linalg.dot(%[[sAi]], %[[sBi]], {{.*}}) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+// TILE-2: %[[sBi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}] : memref<?xf32, #[[strided1D]]>
+// TILE-2: linalg.dot(%[[sAi]], %[[sBi]], {{.*}}) : memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>, memref<f32>
 
-// TILE-02-LABEL: func @dot(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<f32>) {
+// TILE-02-LABEL: func @dot(
 // TILE-02-NOT: loop.for
 
-// TILE-002-LABEL: func @dot(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<f32>) {
+// TILE-002-LABEL: func @dot(
 // TILE-002-NOT: loop.for
 
-// TILE-234-LABEL: func @dot(%{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<?xf32>, %{{.*}}: !linalg.view<f32>) {
-// TILE-234: %[[K:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?xf32>
+// TILE-234-LABEL: func @dot(
+// TILE-234: %[[K:.*]] = dim %{{.*}}, 0 : memref<?xf32, #[[strided1D]]>
 // TILE-234: loop.for %{{.*}} = %{{.*}}{{.*}} to %[[K]] step %{{.*}} {
 // TILE-234: %[[a:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-234: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}] : !linalg.view<?xf32>
+// TILE-234: %[[sAi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[a]], %{{.*}}] : memref<?xf32, #[[strided1D]]>
 // TILE-234: %[[b:.*]] = affine.apply #[[UB0]](%{{.*}})
-// TILE-234: %[[sBi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}] : !linalg.view<?xf32>
-// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
+// TILE-234: %[[sBi:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[b]], %{{.*}}] : memref<?xf32, #[[strided1D]]>
+// TILE-234: linalg.dot(%[[sAi]], %[[sBi]], %{{.*}}) : memref<?xf32, #[[strided1D]]>, memref<?xf32, #[[strided1D]]>, memref<f32>
 
-func @fill(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
-  linalg.fill(%arg0, %arg1) : !linalg.view<?x?x?xf32>, f32
+func @fill(%arg0: memref<?x?x?xf32, #strided3D>, %arg1: f32) {
+  linalg.fill(%arg0, %arg1) : memref<?x?x?xf32, #strided3D>, f32
   return
 }
 // TILE-2-LABEL: func @fill
@@ -167,13 +179,13 @@ func @fill(%arg0: !linalg.view<?x?x?xf32>, %arg1: f32) {
   n_views = [2, 1]
 }
 
-func @pointwise(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?xf32>,
-                %arg2: !linalg.view<?x?xf32>) {
+func @pointwise(%arg0: memref<?x?xf32, #strided2D>, %arg1: memref<?x?xf32, #strided2D>,
+                %arg2: memref<?x?xf32, #strided2D>) {
   linalg.generic #pointwise_2d_trait %arg0, %arg1, %arg2 {
   ^bb0(%arg4: f32, %arg5: f32, %arg6: f32):   // no predecessors
     %4 = addf %arg4, %arg5 : f32
     linalg.yield %4 : f32
-  }: !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+  }: memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
   return
 }
 // TILE-2-LABEL: func @pointwise
diff --git a/mlir/test/Dialect/Linalg/tile_conv.mlir b/mlir/test/Dialect/Linalg/tile_conv.mlir
index 128161e..a06b9f1 100644
--- a/mlir/test/Dialect/Linalg/tile_conv.mlir
+++ b/mlir/test/Dialect/Linalg/tile_conv.mlir
@@ -5,35 +5,39 @@
 // TILE-23004-DAG: #[[UB2:.*]] = (d0) -> (d0 + 4)
 // TILE-23004-DAG: #[[D0x30pS0x10:.*]] = (d0) -> (d0 * 30)
 // TILE-23004-DAG: #[[D0x30pS0x10p90:.*]] = (d0)[s0] -> (d0 * 30 + s0 * 10 + 90)
-func @conv(%arg0: !linalg.view<?x?x?x?xf32>, %arg1: !linalg.view<?x?x?x?xf32>, %arg2: !linalg.view<?x?x?x?xf32>) {
-  linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], strides = [30, 40]} : !linalg.view<?x?x?x?xf32>, !linalg.view<?x?x?x?xf32>, !linalg.view<?x?x?x?xf32>
+// TILE-23004-DAG: #[[strided4D:.*]] = (d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)
+#strided4D = (d0, d1, d2, d3)[s0, s1, s2, s3] -> (d0 * s1 + s0 + d1 * s2 + d2 * s3 + d3)
+
+func @conv(%arg0: memref<?x?x?x?xf32, #strided4D>, %arg1: memref<?x?x?x?xf32, #strided4D>, %arg2: memref<?x?x?x?xf32, #strided4D>) {
+  linalg.conv(%arg0, %arg1, %arg2) {dilations = [10, 20], strides = [30, 40]} : memref<?x?x?x?xf32, #strided4D>, memref<?x?x?x?xf32, #strided4D>, memref<?x?x?x?xf32, #strided4D>
   return
 }
-// TILE-23004-LABEL: func @conv(%{{.*}}: !linalg.view<?x?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?x?xf32>, %{{.*}}: !linalg.view<?x?x?x?xf32>) {
-// TILE-23004: %[[Q:.*]] = linalg.dim %{{.*}}, 2 : !linalg.view<?x?x?x?xf32>
-// TILE-23004: %[[B:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?x?x?xf32>
-// TILE-23004: %[[PaddedInput0:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?x?x?xf32>
-// TILE-23004: %[[X0:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?x?x?xf32>
+// TILE-23004-LABEL: func @conv(
+// TILE-23004: %{{.*}}: memref<?x?x?x?xf32, #[[strided4D]]>, %{{.*}}: memref<?x?x?x?xf32, #[[strided4D]]>, %{{.*}}: memref<?x?x?x?xf32, #[[strided4D]]>) {
+// TILE-23004: %[[Q:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-23004: %[[B:.*]] = dim %{{.*}}, 0 : memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-23004: %[[PaddedInput0:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-23004: %[[X0:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32, #[[strided4D]]>
 // TILE-23004: loop.for %{{.*}} = %{{.*}} to %[[B]] step %{{.*}} {
 // TILE-23004: loop.for %{{.*}} = %{{.*}} to %[[X0]] step %{{.*}} {
 // TILE-23004: loop.for %{{.*}} = %{{.*}} to %[[Q]] step %{{.*}} {
-// TILE-23004: %[[Z0:.*]] = linalg.dim %{{.*}}, 0 : !linalg.view<?x?x?x?xf32>
-// TILE-23004: %[[Z1:.*]] = linalg.dim %{{.*}}, 1 : !linalg.view<?x?x?x?xf32>
+// TILE-23004: %[[Z0:.*]] = dim %{{.*}}, 0 : memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-23004: %[[Z1:.*]] = dim %{{.*}}, 1 : memref<?x?x?x?xf32, #[[strided4D]]>
 // TILE-23004: %[[I2p4:.*]] = affine.apply #[[UB2]](%{{.*}})
-// TILE-23004: %[[K:.*]] = linalg.dim %{{.*}}, 3 : !linalg.view<?x?x?x?xf32>
-// TILE-23004: %[[FilterView:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[Z0]], %{{.*}}, %{{.*}}, %[[Z1]], %{{.*}}, %{{.*}}, %[[I2p4]], %{{.*}}, %{{.*}}, %[[K]], %{{.*}}] : !linalg.view<?x?x?x?xf32>
+// TILE-23004: %[[K:.*]] = dim %{{.*}}, 3 : memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-23004: %[[FilterView:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[Z0]], %{{.*}}, %{{.*}}, %[[Z1]], %{{.*}}, %{{.*}}, %[[I2p4]], %{{.*}}, %{{.*}}, %[[K]], %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
 //
 // TILE-23004: %[[I0p3:.*]] = affine.apply #[[UB0]](%{{.*}})
 // TILE-23004: %[[I1:.*]] = affine.apply #[[D0x30pS0x10]](%{{.*}})
 // TILE-23004: %[[I1pStep:.*]] = affine.apply #[[D0x30pS0x10p90]](%{{.*}})[%[[PaddedInput0]]]
-// TILE-23004: %[[SZ2:.*]] = linalg.dim %{{.*}}, 2 : !linalg.view<?x?x?x?xf32>
+// TILE-23004: %[[SZ2:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32, #[[strided4D]]>
 // TILE-23004: %[[I2p2:.*]] = affine.apply #[[UB2]](%{{.*}})
-// TILE-23004: %[[InputView:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[I0p3]], %{{.*}}, %[[I1]], %[[I1pStep]], %{{.*}}, %{{.*}}, %[[SZ2]], %{{.*}}, %{{.*}}, %[[I2p2]], %{{.*}}] : !linalg.view<?x?x?x?xf32>
+// TILE-23004: %[[InputView:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[I0p3]], %{{.*}}, %[[I1]], %[[I1pStep]], %{{.*}}, %{{.*}}, %[[SZ2]], %{{.*}}, %{{.*}}, %[[I2p2]], %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
 //
 // TILE-23004: %[[B:.*]] = affine.apply #[[UB0]](%{{.*}})
 // TILE-23004: %[[I1p3:.*]] = affine.apply #[[UB1]](%{{.*}})
-// TILE-23004: %[[X0:.*]] = linalg.dim %{{.*}}, 2 : !linalg.view<?x?x?x?xf32>
-// TILE-23004: %[[X1:.*]] = linalg.dim %{{.*}}, 3 : !linalg.view<?x?x?x?xf32>
-// TILE-23004: %[[OutputView:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[B]], %{{.*}}, %{{.*}}, %[[I1p3]], %{{.*}}, %{{.*}}, %[[X0]], %{{.*}}, %{{.*}}, %[[X1]], %{{.*}}] : !linalg.view<?x?x?x?xf32>
+// TILE-23004: %[[X0:.*]] = dim %{{.*}}, 2 : memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-23004: %[[X1:.*]] = dim %{{.*}}, 3 : memref<?x?x?x?xf32, #[[strided4D]]>
+// TILE-23004: %[[OutputView:.*]] = linalg.subview %{{.*}}[%{{.*}}, %[[B]], %{{.*}}, %{{.*}}, %[[I1p3]], %{{.*}}, %{{.*}}, %[[X0]], %{{.*}}, %{{.*}}, %[[X1]], %{{.*}}] : memref<?x?x?x?xf32, #[[strided4D]]>
 //
-// TILE-23004: linalg.conv(%[[FilterView]], %[[InputView]], %[[OutputView]]) {dilations = [10, 20], strides = [30, 40]} : !linalg.view<?x?x?x?xf32>, !linalg.view<?x?x?x?xf32>, !linalg.view<?x?x?x?xf32>
+// TILE-23004: linalg.conv(%[[FilterView]], %[[InputView]], %[[OutputView]]) {dilations = [10, 20], strides = [30, 40]} : memref<?x?x?x?xf32, #[[strided4D]]>, memref<?x?x?x?xf32, #[[strided4D]]>, memref<?x?x?x?xf32, #[[strided4D]]>
diff --git a/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp b/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp
index 6efaec3..01dfdee 100644
--- a/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp
+++ b/mlir/test/lib/Transforms/TestMemRefStrideCalculation.cpp
@@ -35,15 +35,13 @@ void TestMemRefStrideCalculation::runOnFunction() {
   llvm::outs() << "Testing: " << getFunction().getName() << "\n";
   getFunction().walk([&](AllocOp allocOp) {
     auto memrefType = allocOp.getResult()->getType().cast<MemRefType>();
-    SmallVector<int64_t, 4> strideVector;
-    if (failed(memrefType.getStridesAndOffset(strideVector))) {
+    int64_t offset;
+    SmallVector<int64_t, 4> strides;
+    if (failed(memrefType.getStridesAndOffset(strides, offset))) {
       llvm::outs() << "MemRefType " << memrefType << " cannot be converted to "
                    << "strided form\n";
       return;
     }
-    ArrayRef<int64_t> strides(strideVector);
-    auto offset = strides.back();
-    strides = strides.drop_back();
     llvm::outs() << "MemRefType offset: ";
     if (offset == MemRefType::kDynamicStrideOrOffset)
       llvm::outs() << "?";
diff --git a/mlir/test/mlir-cpu-runner/cblas_interface.cpp b/mlir/test/mlir-cpu-runner/cblas_interface.cpp
index 514d522..aae3112 100644
--- a/mlir/test/mlir-cpu-runner/cblas_interface.cpp
+++ b/mlir/test/mlir-cpu-runner/cblas_interface.cpp
@@ -29,23 +29,21 @@ template <typename T, int N> struct ViewType {
   unsigned long strides[N];
 };
 
-// This is separated out to avoid `unsigned long sizes[0]` which triggers:
-// warning: ISO C++ forbids zero-size array [-Wpedantic]
 template <typename T> struct ViewType<T, 0> {
   T *data;
   unsigned long offset;
 };
 
 extern "C" void linalg_fill_viewf32_f32(ViewType<float, 0> *X, float f) {
-  *(X->data + X->offset) = f;
+  X->data[X->offset] = f;
 }
 
-extern "C" void linalg_fill_viewxf32_f32(ViewType<float, 1> *X, float f) {
+extern "C" void linalg_fill_viewsxf32_f32(ViewType<float, 1> *X, float f) {
   for (unsigned i = 0; i < X->sizes[0]; ++i)
     *(X->data + X->offset + i * X->strides[0]) = f;
 }
 
-extern "C" void linalg_fill_viewxxf32_f32(ViewType<float, 2> *X, float f) {
+extern "C" void linalg_fill_viewsxsxf32_f32(ViewType<float, 2> *X, float f) {
   for (unsigned i = 0; i < X->sizes[0]; ++i)
     for (unsigned j = 0; j < X->sizes[1]; ++j)
       *(X->data + X->offset + i * X->strides[0] + j * X->strides[1]) = f;
@@ -56,16 +54,16 @@ extern "C" void linalg_copy_viewf32_viewf32(ViewType<float, 0> *I,
   O->data[O->offset] = I->data[I->offset];
 }
 
-extern "C" void linalg_copy_viewxf32_viewxf32(ViewType<float, 1> *I,
-                                              ViewType<float, 1> *O) {
+extern "C" void linalg_copy_viewsxf32_viewsxf32(ViewType<float, 1> *I,
+                                                ViewType<float, 1> *O) {
   assert(I->sizes[0] == O->sizes[0]);
   for (unsigned i = 0; i < I->sizes[0]; ++i)
     O->data[O->offset + i * O->strides[0]] =
         I->data[I->offset + i * I->strides[0]];
 }
 
-extern "C" void linalg_copy_viewxxf32_viewxxf32(ViewType<float, 2> *I,
-                                                ViewType<float, 2> *O) {
+extern "C" void linalg_copy_viewsxsxf32_viewsxsxf32(ViewType<float, 2> *I,
+                                                    ViewType<float, 2> *O) {
   assert(I->sizes[0] == O->sizes[0]);
   assert(I->sizes[1] == O->sizes[1]);
   auto so0 = O->strides[0], so1 = O->strides[1];
@@ -76,18 +74,18 @@ extern "C" void linalg_copy_viewxxf32_viewxxf32(ViewType<float, 2> *I,
       I->data[I->offset + i * si0 + j * si1];
 }
 
-extern "C" void linalg_dot_viewxf32_viewxf32_viewf32(ViewType<float, 1> *X,
-                                                     ViewType<float, 1> *Y,
-                                                     ViewType<float, 0> *Z) {
+extern "C" void linalg_dot_viewsxf32_viewsxf32_viewf32(ViewType<float, 1> *X,
+                                                       ViewType<float, 1> *Y,
+                                                       ViewType<float, 0> *Z) {
   assert(X->strides[0] == 1);
   assert(Y->strides[0] == 1);
   assert(X->sizes[0] == Y->sizes[0] && "Expected X and Y of same size");
-  *(Z->data + Z->offset) +=
+  Z->data[Z->offset] +=
       cblas_sdot(X->sizes[0], X->data + X->offset, X->strides[0],
                  Y->data + Y->offset, Y->strides[0]);
 }
 
-extern "C" void linalg_matmul_viewxxf32_viewxxf32_viewxxf32(
+extern "C" void linalg_matmul_viewsxsxf32_viewsxsxf32_viewsxsxf32(
     ViewType<float, 2> *A, ViewType<float, 2> *B, ViewType<float, 2> *C) {
   assert(A->strides[1] == B->strides[1]);
   assert(A->strides[1] == C->strides[1]);
diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
index 9c5d9aa..9fdd6b9 100644
--- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
+++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir
@@ -1,9 +1,12 @@
-// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-tile-promote-full-tile-views=true -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-tile-promote-full-tile-views=true -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -linalg-convert-to-llvm | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e dot -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -linalg-convert-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-tile-promote-full-tile-views=true -linalg-lower-to-loops -linalg-convert-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+// RUN: mlir-opt %s -linalg-tile -linalg-tile-sizes=2,3,4 -linalg-tile-promote-full-tile-views=true -linalg-convert-to-llvm | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
+
+#strided1D = (d0)[s0] -> (d0 + s0)
+#strided2D = (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1)
 
 // Creates and returns a 1-D buffer of size %s filled with the value %f
 func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
@@ -11,8 +14,8 @@ func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
   %c0 = constant 0 : index
   %c1 = constant 1 : index
   %buf = linalg.buffer_alloc %s {alignment = 256} : !linalg.buffer<?xf32>
   %R = linalg.range %c0:%s:%c1 : !linalg.range
-  %V = linalg.view %buf[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-  linalg.fill(%V, %f) : !linalg.view<?xf32>, f32
+  %V = linalg.view %buf[%R] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
+  linalg.fill(%V, %f) : memref<?xf32, #strided1D>, f32
   return %buf : !linalg.buffer<?xf32>
 }
@@ -30,12 +33,12 @@ func @dot() -> f32 {
   %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<?xf32>)
 
   %R = linalg.range %c0:%c16:%c1 : !linalg.range
-  %A = linalg.view %bA[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-  %B = linalg.view %bB[%R] : !linalg.buffer<?xf32> -> !linalg.view<?xf32>
-  %C = linalg.view %bC[] : !linalg.buffer<?xf32> -> !linalg.view<f32>
+  %A = linalg.view %bA[%R] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
+  %B = linalg.view %bB[%R] : !linalg.buffer<?xf32> -> memref<?xf32, #strided1D>
+  %C = linalg.view %bC[] : !linalg.buffer<?xf32> -> memref<f32>
 
-  linalg.dot(%A, %B, %C) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
-  %res = linalg.load %C[] : !linalg.view<f32>
+  linalg.dot(%A, %B, %C) : memref<?xf32, #strided1D>, memref<?xf32, #strided1D>, memref<f32>
+  %res = load %C[] : memref<f32>
 
   linalg.buffer_dealloc %bC : !linalg.buffer<?xf32>
   linalg.buffer_dealloc %bB : !linalg.buffer<?xf32>
@@ -65,12 +68,12 @@ func @matmul() -> f32 {
   %M = linalg.range %c0:%c10:%c1 : !linalg.range
   %N = linalg.range %c0:%c10:%c1 : !linalg.range
   %K = linalg.range %c0:%c16:%c1 : !linalg.range
-  %A = linalg.view %bA[%M, %K] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  %B = linalg.view %bB[%K, %N] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
-  %C = linalg.view %bC[%M, %N] : !linalg.buffer<?xf32> -> !linalg.view<?x?xf32>
+  %A = linalg.view %bA[%M, %K] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  %B = linalg.view %bB[%K, %N] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
+  %C = linalg.view %bC[%M, %N] : !linalg.buffer<?xf32> -> memref<?x?xf32, #strided2D>
 
-  linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
-  %res = linalg.load %C[%c6, %c7] : !linalg.view<?x?xf32>
+  linalg.matmul(%A, %B, %C) : memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>, memref<?x?xf32, #strided2D>
+  %res = load %C[%c6, %c7] : memref<?x?xf32, #strided2D>
 
   linalg.buffer_dealloc %bC : !linalg.buffer<?xf32>
   linalg.buffer_dealloc %bB : !linalg.buffer<?xf32>
-- 
2.7.4
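
Editor's note (not part of the patch): the cblas_interface.cpp shims above receive a plain struct mirroring the strided memref descriptor that this CL standardizes on: a data pointer, an offset, and one size and one stride per dimension, with element (i0, ..., iN-1) located at data[offset + sum_k(ik * strides[k])]. The following is a minimal stand-alone C++ sketch of that convention, written to match the addressing used by linalg_fill_viewsxsxf32_f32; the names StridedView and at() are hypothetical, chosen for illustration only, and do not appear in the patch.

    #include <array>
    #include <cassert>
    #include <cstdio>
    #include <vector>

    // Hypothetical model of the strided view descriptor used by the shims.
    template <typename T, int N> struct StridedView {
      T *data;                              // base pointer of the underlying buffer
      unsigned long offset;                 // offset of element (0, ..., 0), in elements
      std::array<unsigned long, N> sizes;   // extent of each dimension
      std::array<unsigned long, N> strides; // step per dimension, in elements

      // For N = 2 this computes d0 * strides[0] + offset + d1 * strides[1],
      // i.e. the #strided2D map (d0, d1)[s0, s1] -> (d0 * s1 + s0 + d1) with
      // s0 = offset and s1 = strides[0] when strides[1] == 1.
      T &at(std::array<unsigned long, N> idx) {
        unsigned long linear = offset;
        for (int d = 0; d < N; ++d) {
          assert(idx[d] < sizes[d] && "index out of bounds");
          linear += idx[d] * strides[d];
        }
        return data[linear];
      }
    };

    int main() {
      // A 2x3 view into an 8-element buffer, starting at element 1 with a row
      // stride of 3: non-contiguous in the buffer, as produced by linalg.subview.
      std::vector<float> buffer(8, 0.0f);
      StridedView<float, 2> view{buffer.data(), 1, {2, 3}, {3, 1}};

      // Same loop structure as linalg_fill_viewsxsxf32_f32 in the patch.
      for (unsigned long i = 0; i < view.sizes[0]; ++i)
        for (unsigned long j = 0; j < view.sizes[1]; ++j)
          view.at({i, j}) = static_cast<float>(10 * i + j);

      std::printf("view(1, 2) = %g\n", static_cast<double>(view.at({1, 2}))); // 12
      return 0;
    }

The same size/stride split is what the TestMemRefStrideCalculation.cpp hunk exercises: the updated getStridesAndOffset returns the strides and the offset as two separate results instead of packing the offset into the last element of a single vector.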