[mlir][Linalg] Add MapOp to Linalg structured ops.
authorAdrian Kuegel <akuegel@google.com>
Wed, 12 Oct 2022 08:18:38 +0000 (10:18 +0200)
committerAdrian Kuegel <akuegel@google.com>
Wed, 12 Oct 2022 11:56:21 +0000 (13:56 +0200)
This will allow to model elementwise ops with this special op instead of using
GenericOp.
Also allow MapOp and ReduceOp to have no result if the output type is not a tensor.
This is needed for buffer semantics.

Differential Revision: https://reviews.llvm.org/D135754

mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/roundtrip.mlir

index c37d978..6bcf509 100644 (file)
@@ -225,12 +225,77 @@ def GenericOp : LinalgStructuredBase_Op<"generic", [
 
 
 //===----------------------------------------------------------------------===//
-// Reduce op.
+// Map op.
 //===----------------------------------------------------------------------===//
 
 def TensorOrMemref :
   AnyTypeOf<[AnyMemRef, AnyRankedTensor], "", "::mlir::ShapedType">;
 
+def MapOp : LinalgStructuredBase_Op<"map", [
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
+    DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
+    SingleBlockImplicitTerminator<"YieldOp">]> {
+  let summary = "Elementwise operations";
+  let description = [{
+    Models elementwise operations on tensors in terms of arithmetic operations
+    on the corresponding elements.
+
+    Example:
+    ```
+      %add = linalg.map
+          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+          outs(%init: tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem: f32
+            linalg.yield %0: f32
+          }
+    ```
+  }];
+
+  let arguments = (ins
+    // Input args
+    Variadic<TensorOrMemref>:$inputs,
+
+    // Output arg
+    TensorOrMemref:$init
+  );
+  let results = (outs Variadic<AnyTensor>:$result);
+  let regions = (region SizedRegion<1>:$mapper);
+
+  let extraClassDeclaration = structuredOpsBaseDecls # [{
+    // Implement functions necessary for LinalgStructuredInterface.
+    ArrayAttr getIteratorTypes();
+    ArrayAttr getIndexingMaps();
+    std::string getLibraryCallName() {
+      return "op_has_no_registered_library_name";
+    }
+
+    // Implement functions necessary for DestinationStyleOpInterface.
+    unsigned getNumInputs() {
+      return this->getOperation()->getNumOperands() - getNumOutputs();
+    };
+    unsigned getNumOutputs() { return 1; };
+    mlir::ValueRange getOutputs() { return getOperands().take_back(1); }
+    linalg::OpOperandVector getOpOperandsMatchingBBargs() {
+      return getInputOperands();
+    }
+
+    static std::function<void(mlir::ImplicitLocOpBuilder &, mlir::Block &,
+                              mlir::ArrayRef<mlir::NamedAttribute>)>
+    getRegionBuilder() {
+      return nullptr;
+    }
+  }];
+
+  let hasCustomAssemblyFormat = 1;
+  let hasVerifier = 1;
+}
+
+
+//===----------------------------------------------------------------------===//
+// Reduce op.
+//===----------------------------------------------------------------------===//
+
 def ReduceOp : LinalgStructuredBase_Op<"reduce", [
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmResultNames"]>,
     DeclareOpInterfaceMethods<OpAsmOpInterface, ["getAsmBlockArgumentNames"]>,
@@ -264,7 +329,7 @@ def ReduceOp : LinalgStructuredBase_Op<"reduce", [
     ConfinedAttr<DenseI64ArrayAttr,
                  [DenseArrayStrictlySorted<DenseI64ArrayAttr>]>:$dimensions
   );
-  let results = (outs Variadic<TensorOrMemref>);
+  let results = (outs Variadic<AnyTensor>);
   let regions = (region SizedRegion<1>:$combiner);
 
   let extraClassDeclaration = structuredOpsBaseDecls # [{
index 5dbec1e..61b9386 100644 (file)
@@ -1289,6 +1289,135 @@ LogicalResult GenericOp::fold(ArrayRef<Attribute>,
 }
 
 //===----------------------------------------------------------------------===//
+// MapOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseDstStyleOp(
+    OpAsmParser &parser, OperationState &result,
+    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
+        nullptr) {
+  // Parse `ins` and `outs`.
+  SmallVector<Type, 4> inputTypes, outputTypes;
+  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
+                                   /*addOperandSegmentSizes=*/false))
+    return failure();
+
+  // Add result types.
+  for (Type outputType : outputTypes) {
+    if (outputType.isa<RankedTensorType>())
+      result.addTypes(outputType);
+  }
+
+  // Parse required attributes.
+  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
+    return failure();
+
+  // Parse optional attributes.
+  if (parser.parseOptionalAttrDict(result.attributes))
+    return failure();
+  return success();
+}
+
+void MapOp::getAsmBlockArgumentNames(Region &region,
+                                     OpAsmSetValueNameFn setNameFn) {
+  for (Value v : getRegionInputArgs())
+    setNameFn(v, "in");
+}
+
+void MapOp::getAsmResultNames(function_ref<void(Value, StringRef)> setNameFn) {
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "mapped");
+}
+
+ParseResult MapOp::parse(OpAsmParser &parser, OperationState &result) {
+  if (parseDstStyleOp(parser, result))
+    return failure();
+
+  SmallVector<OpAsmParser::Argument> regionArgs;
+  if (parser.parseArgumentList(regionArgs, OpAsmParser::Delimiter::Paren,
+                               /*allowType=*/true, /*allowAttrs=*/true)) {
+    return failure();
+  }
+
+  Region *body = result.addRegion();
+  if (parser.parseRegion(*body, regionArgs))
+    return failure();
+
+  return success();
+}
+
+void MapOp::print(OpAsmPrinter &p) {
+  printCommonStructuredOpParts(p, getInputs(), getOutputs());
+  p.printOptionalAttrDict((*this)->getAttrs());
+
+  p << "(";
+  llvm::interleaveComma(getMapper().getArguments(), p,
+                        [&](auto arg) { p.printRegionArgument(arg); });
+  p << ") ";
+
+  p.printRegion(getMapper(), /*printEntryBlockArgs=*/false);
+}
+
+LogicalResult MapOp::verify() {
+  auto *bodyBlock = getBody();
+  auto blockArgs = bodyBlock->getArguments();
+
+  // Checks if the number of `inputs` match the arity of the `mapper` region.
+  if (getInputs().size() != blockArgs.size())
+    return emitOpError() << "expects number of operands to match the arity of "
+                            "mapper, but got: "
+                         << getInputs().size() << " and " << blockArgs.size();
+
+  // The parameters of mapper should all match the element type // of inputs.
+  for (const auto &[bbArgType, inputArg] :
+       llvm::zip(bodyBlock->getArgumentTypes(), getInputs())) {
+    auto inputElemType = inputArg.getType().cast<ShapedType>().getElementType();
+    if (bbArgType != inputElemType) {
+      return emitOpError() << "expected element type of input " << inputElemType
+                           << " to match bbArg type " << bbArgType;
+    }
+  }
+
+  // The shape of each input must match the shape of the output.
+  auto outputShape =
+      getOutputs().front().getType().cast<ShapedType>().getShape();
+  for (Type inputArgType : TypeRange{getInputs()}) {
+    auto inputElemShape = inputArgType.cast<ShapedType>().getShape();
+    if (inputElemShape != outputShape) {
+      return emitOpError() << "expected shape of input (" << inputElemShape
+                           << ") to match shape of output (" << outputShape
+                           << ")";
+    }
+  }
+
+  return success();
+}
+
+ArrayAttr MapOp::getIteratorTypes() {
+  int64_t rank = getInit().getType().getRank();
+  return Builder(getContext())
+      .getStrArrayAttr(
+          SmallVector<StringRef>(rank, getParallelIteratorTypeName()));
+}
+
+ArrayAttr MapOp::getIndexingMaps() {
+  Builder builder(getContext());
+  int64_t rank = getInit().getType().getRank();
+  int64_t numIndexingMaps = getOperands().size();
+  return builder.getAffineMapArrayAttr(SmallVector<AffineMap>(
+      numIndexingMaps, builder.getMultiDimIdentityMap(rank)));
+}
+
+void MapOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  SmallVector<Value> inputBuffers = getInputBufferOperands();
+  SmallVector<Value> outputBuffers = getOutputBufferOperands();
+  getGenericEffectsImpl(effects, getOperation()->getResults(), inputBuffers,
+                        outputBuffers);
+}
+
+//===----------------------------------------------------------------------===//
 // ReduceOp
 //===----------------------------------------------------------------------===//
 
@@ -1302,7 +1431,8 @@ void ReduceOp::getAsmBlockArgumentNames(Region &region,
 
 void ReduceOp::getAsmResultNames(
     function_ref<void(Value, StringRef)> setNameFn) {
-  setNameFn(getResults().front(), "reduced");
+  if (!getResults().empty())
+    setNameFn(getResults().front(), "reduced");
 }
 
 ArrayAttr ReduceOp::getIteratorTypes() {
@@ -1336,33 +1466,6 @@ void ReduceOp::getEffects(
                         outputBuffers);
 }
 
-static ParseResult parseDstStyleOp(
-    OpAsmParser &parser, OperationState &result,
-    function_ref<ParseResult(OpAsmParser &, NamedAttrList &)> parseAttrsFn =
-        nullptr) {
-  // Parse `ins` and `outs`.
-  SmallVector<Type, 4> inputTypes, outputTypes;
-  if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes,
-                                   /*addOperandSegmentSizes=*/false))
-    return failure();
-
-  // Add result types.
-  for (Type outputType : outputTypes) {
-    if (!outputType.isa<RankedTensorType>())
-      return failure();
-    result.addTypes(outputType);
-  }
-
-  // Parse required attributes.
-  if (parseAttrsFn && failed(parseAttrsFn(parser, result.attributes)))
-    return failure();
-
-  // Parse optional attributes.
-  if (parser.parseOptionalAttrDict(result.attributes))
-    return failure();
-  return success();
-}
-
 static ParseResult parseDenseI64ArrayAttr(OpAsmParser &parser,
                                           NamedAttrList &attributes,
                                           StringRef attributeName) {
index 9ae761e..00352c4 100644 (file)
@@ -391,6 +391,70 @@ func.func @invalid_reverse(%A: memref<5xf32>, %B: memref<5xf32>) {
 
 // -----
 
+func.func @map_binary_wrong_yield_operands(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+   %add = linalg.map
+          ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+          outs(%init:tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem: f32
+            // expected-error @+1{{'linalg.yield' op expected number of yield values (1) to match the number of operands of the enclosing LinalgOp (2)}}
+            linalg.yield %0, %0: f32, f32
+          }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_mapper_arity_mismatch(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+  // expected-error@+1{{'linalg.map' op expects number of operands to match the arity of mapper, but got: 2 and 3}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32, %extra_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_mapper_type_mismatch(
+    %lhs: tensor<64xf32>, %rhs: tensor<64xf32>, %init: tensor<64xf32>)
+    -> tensor<64xf32> {
+    // expected-error@+1{{'linalg.map' op expected element type of input 'f32' to match bbArg type 'f64'}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64xf32>, tensor<64xf32>)
+      outs(%init:tensor<64xf32>)
+      (%lhs_elem: f64, %rhs_elem: f64) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f64
+        linalg.yield %0: f64
+      }
+  func.return %add : tensor<64xf32>
+}
+
+// -----
+
+func.func @map_input_output_shape_mismatch(
+    %lhs: tensor<64x64xf32>, %rhs: tensor<64x64xf32>, %init: tensor<32xf32>)
+    -> tensor<32xf32> {
+    // expected-error@+1{{'linalg.map' op expected shape of input (64, 64) to match shape of output (32)}}
+  %add = linalg.map
+      ins(%lhs, %rhs : tensor<64x64xf32>, tensor<64x64xf32>)
+      outs(%init:tensor<32xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return %add : tensor<32xf32>
+}
+
+// -----
+
 func.func @reduce_input_vs_init_dimension_mismatch(
     %input: tensor<16x32x64xf32>,
     %init: tensor<16x64xf32>)  -> tensor<16x64xf32> {
index 3fb6c3d..02471b1 100644 (file)
@@ -354,8 +354,70 @@ func.func @mixed_parallel_reduced_results(%arg0 : tensor<?x?x?xf32>,
 
 // -----
 
+func.func @map_binary(%lhs: tensor<64xf32>, %rhs: tensor<64xf32>,
+                      %init: tensor<64xf32>) -> tensor<64xf32> {
+   %add = linalg.map
+          ins(%lhs, %rhs: tensor<64xf32>, tensor<64xf32>)
+          outs(%init:tensor<64xf32>)
+          (%lhs_elem: f32, %rhs_elem: f32) {
+            %0 = arith.addf %lhs_elem, %rhs_elem: f32
+            linalg.yield %0: f32
+          }
+  func.return %add : tensor<64xf32>
+}
+// CHECK-LABEL: func @map_binary
+//       CHECK:     linalg.map
+
+// -----
+
+func.func @map_binary_memref(%lhs: memref<64xf32>, %rhs: memref<64xf32>,
+                      %init: memref<64xf32>) {
+   linalg.map
+      ins(%lhs, %rhs: memref<64xf32>, memref<64xf32>)
+      outs(%init:memref<64xf32>)
+      (%lhs_elem: f32, %rhs_elem: f32) {
+        %0 = arith.addf %lhs_elem, %rhs_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return
+}
+// CHECK-LABEL: func @map_binary_memref
+//       CHECK:     linalg.map
+
+// -----
+
+func.func @map_unary(%input: tensor<64xf32>, %init: tensor<64xf32>) -> tensor<64xf32> {
+   %abs = linalg.map
+          ins(%input:tensor<64xf32>)
+          outs(%init:tensor<64xf32>)
+          (%input_elem: f32) {
+            %0 = math.absf %input_elem: f32
+            linalg.yield %0: f32
+          }
+  func.return %abs : tensor<64xf32>
+}
+// CHECK-LABEL: func @map_unary
+//       CHECK:     linalg.map
+
+// -----
+
+func.func @map_unary_memref(%input: memref<64xf32>, %init: memref<64xf32>) {
+   linalg.map
+      ins(%input:memref<64xf32>)
+      outs(%init:memref<64xf32>)
+      (%input_elem: f32) {
+        %0 = math.absf %input_elem: f32
+        linalg.yield %0: f32
+      }
+  func.return
+}
+// CHECK-LABEL: func @map_unary_memref
+//       CHECK:     linalg.map
+
+// -----
+
 func.func @reduce(%input: tensor<16x32x64xf32>,
-                     %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
+                  %init: tensor<16x64xf32>) -> tensor<16x64xf32> {
   %reduce = linalg.reduce
       ins(%input:tensor<16x32x64xf32>)
       outs(%init:tensor<16x64xf32>)
@@ -371,6 +433,23 @@ func.func @reduce(%input: tensor<16x32x64xf32>,
 
 // -----
 
+func.func @reduce_memref(%input: memref<16x32x64xf32>,
+                         %init: memref<16x64xf32>) {
+  linalg.reduce
+      ins(%input:memref<16x32x64xf32>)
+      outs(%init:memref<16x64xf32>)
+      dimensions = [1]
+      (%in: f32, %out: f32) {
+        %0 = arith.addf %in, %out: f32
+        linalg.yield %0: f32
+      }
+  func.return
+}
+// CHECK-LABEL: func @reduce_memref
+//       CHECK:     linalg.reduce
+
+// -----
+
 func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
     %init1: tensor<16x64xf32>, %input2: tensor<16x32x64xi64>,
     %init2: tensor<16x64xi64>)  -> (tensor<16x64xf32>, tensor<16x64xi64>) {
@@ -387,3 +466,22 @@ func.func @variadic_reduce(%input1: tensor<16x32x64xf32>,
 }
 // CHECK-LABEL: func @variadic_reduce
 //       CHECK:     linalg.reduce
+
+// -----
+
+func.func @variadic_reduce_memref(%input1: memref<16x32x64xf32>,
+    %init1: memref<16x64xf32>, %input2: memref<16x32x64xi64>,
+    %init2: memref<16x64xi64>) {
+  linalg.reduce
+      ins(%input1, %input2 : memref<16x32x64xf32>, memref<16x32x64xi64>)
+      outs(%init1, %init2 : memref<16x64xf32>, memref<16x64xi64>)
+      dimensions = [1]
+      (%in1: f32, %in2: i64, %out1: f32, %out2: i64) {
+        %0 = arith.addf %in1, %out1: f32
+        %1 = arith.addi %in2, %out2: i64
+        linalg.yield %0, %1: f32, i64
+      }
+  func.return
+}
+// CHECK-LABEL: func @variadic_reduce_memref
+//       CHECK:     linalg.reduce