Revert "[mlir][Linalg] Improve region support in Linalg ops."
authorMehdi Amini <joker.eph@gmail.com>
Fri, 12 Feb 2021 18:15:15 +0000 (18:15 +0000)
committerMehdi Amini <joker.eph@gmail.com>
Fri, 12 Feb 2021 18:15:51 +0000 (18:15 +0000)
This reverts commit 973e133b769773c89ce4b8bbfd6c77612d2ff9d4.

It triggers an issue in gcc5 that require investigation, the build is
broken with:

/tmp/ccdpj3B9.s: Assembler messages:
/tmp/ccdpj3B9.s:5821: Error: symbol `_ZNSt17_Function_handlerIFvjjEUljjE2_E9_M_invokeERKSt9_Any_dataOjS6_' is already defined
/tmp/ccdpj3B9.s:5860: Error: symbol `_ZNSt14_Function_base13_Base_managerIUljjE2_E10_M_managerERSt9_Any_dataRKS3_St18_Manager_operation' is already defined

mlir/include/mlir/Dialect/Linalg/IR/LinalgInterfaces.td
mlir/include/mlir/Dialect/Linalg/IR/LinalgStructuredOps.td
mlir/lib/Conversion/LinalgToStandard/LinalgToStandard.cpp
mlir/lib/Dialect/Linalg/EDSC/Builders.cpp
mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/Generalization.cpp
mlir/lib/Dialect/Linalg/Transforms/Loops.cpp
mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp
mlir/test/Transforms/copy-removal.mlir
mlir/test/mlir-linalg-ods-gen/test-linalg-ods-gen.tc
mlir/tools/mlir-linalg-ods-gen/mlir-linalg-ods-gen.cpp

index 95656eb..c26f022 100644 (file)
@@ -1056,6 +1056,20 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
     //===------------------------------------------------------------------===//
     // Other static interface methods.
     //===------------------------------------------------------------------===//
+    StaticInterfaceMethod<
+      /*desc=*/[{
+        Create an operation of the current type with the given location,
+        operands, and attributes.
+      }],
+      /*retTy=*/"Operation *",
+      /*methodName=*/"create",
+      (ins "OpBuilder &":$builder, "Location":$loc, "TypeRange":$resultTypes,
+           "ValueRange":$operands,
+           "ArrayRef<NamedAttribute>":$attributes), [{
+        return builder.create<ConcreteOp>(
+          loc, resultTypes, operands, attributes);
+      }]
+    >,
     InterfaceMethod<
       /*desc=*/[{
         Clone the current operation with the given location and operands. This
@@ -1068,13 +1082,14 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
       (ins "OpBuilder &":$b, "Location":$loc, "TypeRange":$resultTypes,
            "ValueRange":$operands),
       [{
-        BlockAndValueMapping bvm;
-        OperationState state(
-          loc, ConcreteOp::getOperationName(), operands, resultTypes,
-          $_op->getAttrs());
-        for (Region &r : $_op->getRegions())
-          r.cloneInto(state.addRegion(), bvm);
-        return b.createOperation(state);
+        BlockAndValueMapping map;
+        unsigned numRegions = $_op->getNumRegions();
+        Operation *res = create(b, loc, resultTypes, operands, $_op->getAttrs());
+        assert(res->getNumRegions() == numRegions && "inconsistent # regions");
+        for (unsigned ridx = 0; ridx < numRegions; ++ridx)
+          $_op->getRegion(ridx).cloneInto(
+            &res->getRegion(ridx), map);
+        return res;
       }]
     >,
     StaticInterfaceMethod<
@@ -1083,7 +1098,7 @@ def LinalgStructuredInterface : OpInterface<"LinalgOp"> {
         Returns a null function if this named op does not define a region
         builder.
       }],
-      /*retTy=*/"std::function<void(Block &, ValueRange)>",
+      /*retTy=*/"std::function<void(Block &)>",
       /*methodName=*/"getRegionBuilder",
       (ins),
       [{ return ConcreteOp::getRegionBuilder(); }]
index 05a6bb7..8988a3a 100644 (file)
@@ -110,13 +110,14 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
     AnyStridedMemRef:$output,
     OptionalAttr<AffineMapAttr>:$inputPermutation,
     OptionalAttr<AffineMapAttr>:$outputPermutation);
-  let regions = (region AnyRegion:$region);
 
-  let builders = [
-    OpBuilderDAG<(ins "Value":$input, "Value":$output,
-      CArg<"AffineMap", "AffineMap()">:$inputPermutation,
-      CArg<"AffineMap", "AffineMap()">:$outputPermutation,
-      CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs)>];
+  // TODO: this should go away once the usage of OptionalAttr triggers emission
+  // of builders with default arguments left unspecified.
+  let builders = [OpBuilderDAG<(ins "Value":$input, "Value":$output),
+    [{
+      return build(
+        $_builder, $_state, input, output, AffineMapAttr(), AffineMapAttr());
+    }]>];
 
   let extraClassDeclaration = structuredOpsDecls # [{
     ValueRange inputs() { return getOperands().take_front(); }
@@ -145,31 +146,24 @@ def CopyOp : LinalgStructured_Op<"copy", [CopyOpInterface]> {
     Value getSource() { return input();}
     Value getTarget() { return output(); }
 
-    static void regionBuilder(Block &block, ValueRange captures);
-    static std::function<void(Block &block, ValueRange captures)>
-    getRegionBuilder() {
-      return &regionBuilder;
+    static std::function<void(Block &)> getRegionBuilder() {
+      return nullptr;
     }
-    static unsigned getNumRegionArgs() { return 2; }
   }];
   let verifier = [{ return ::verify(*this); }];
 
   let assemblyFormat = [{
-    `(` $input `,` $output `)` attr-dict `:`
-        type($input) `,` type($output)
-      custom<CopyOpRegion>($region, ref(type($input)), ref(type($input)))
+    `(` operands `)` attr-dict `:` type(operands)
   }];
 
   let hasFolder = 1;
   let hasCanonicalizer = 1;
-  let skipDefaultBuilders = 1;
 }
 
 def FillOp : LinalgStructured_Op<"fill", []> {
   let arguments = (ins AnyShaped:$output,
                    AnyTypeOf<[AnyFloat, AnySignlessInteger, AnyVector]>:$value);
   let results = (outs Optional<AnyRankedTensor>:$result);
-  let regions = (region AnyRegion:$region);
   let extraClassDeclaration = structuredOpsDecls # [{
     ValueRange inputs() { return {}; }
     ValueRange outputs() { return getOperands().take_front(); }
@@ -189,18 +183,13 @@ def FillOp : LinalgStructured_Op<"fill", []> {
           extractOrIdentityMap(llvm::None, getNumParallelLoops(), context)});
     }
 
-    static void regionBuilder(Block &block, ValueRange captures);
-    static std::function<void(Block &block, ValueRange captures)>
-    getRegionBuilder() {
-      return &regionBuilder;
+    static std::function<void(Block &)> getRegionBuilder() {
+      return nullptr;
     }
-    static unsigned getNumRegionArgs() { return 1; }
   }];
 
   let assemblyFormat = [{
-    `(` $output `,` $value `)` attr-dict `:`
-        type($output) `,` type($value) (`->` type($result)^)?
-      custom<FillOpRegion>($region, ref(type($output)), ref($value))
+    `(` operands `)` attr-dict `:` type(operands) (`->` type($result)^)?
   }];
 
   let builders = [
@@ -279,8 +268,7 @@ class PoolingBase_Op<string mnemonic, list<OpTrait> props>
       return padding().getValue().getValue<int64_t>({i, 1});
     }
 
-    static std::function<void(Block &, ValueRange captures)> getRegionBuilder()
-    {
+    static std::function<void(Block &)> getRegionBuilder() {
       return nullptr;
     }
   }];
@@ -531,7 +519,7 @@ class GenericOpBase<string mnemonic> : LinalgStructuredBase_Op<mnemonic, [
         library_call()->str() : "op_has_no_registered_library_name";
     }
 
-    static std::function<void(Block &, ValueRange)> getRegionBuilder() {
+    static std::function<void(Block &)> getRegionBuilder() {
       return nullptr;
     }
   }];
index 276b124..8b53ecb 100644 (file)
@@ -154,13 +154,7 @@ LogicalResult mlir::linalg::CopyTransposeRewrite::matchAndRewrite(
   if (in == op.input() && out == op.output())
     return failure();
 
-  auto libraryCallName = getLibraryCallSymbolRef(op, rewriter);
-  if (!libraryCallName)
-    return failure();
-
-  rewriter.replaceOpWithNewOp<mlir::CallOp>(
-      op, libraryCallName.getValue(), TypeRange(),
-      createTypeCanonicalizedMemRefOperands(rewriter, op.getLoc(), {in, out}));
+  rewriter.replaceOpWithNewOp<CopyOp>(op, in, out);
   return success();
 }
 
index 46e42e2..8bb104d 100644 (file)
@@ -27,6 +27,8 @@ Operation *mlir::edsc::makeGenericLinalgOp(
     ArrayRef<StructuredIndexed> outputs, TypeRange resultTensorTypes,
     function_ref<void(ValueRange)> regionBuilder, ArrayRef<Value> otherValues,
     ArrayRef<Attribute> otherAttributes) {
+  OpBuilder &builder = edsc::ScopedContext::getBuilderRef();
+
   // Build maps
   SmallVector<SmallVector<AffineExpr, 4>, 4> exprsList;
   exprsList.reserve(inputs.size() + outputs.size());
@@ -52,10 +54,13 @@ Operation *mlir::edsc::makeGenericLinalgOp(
               resultTensorTypes,
               inputValues,
               outputValues,
-              maps,
-              iteratorStrTypes,
-              ""/*doc*/,
-              ""/*library_call*/)
+              builder.getAffineMapArrayAttr(maps),
+              builder.getStrArrayAttr(iteratorStrTypes),
+              StringAttr() /*doc*/,
+              StringAttr() /*library_call*/,
+              ArrayAttr() /*sparse*/
+              /* TODO: other attributes in op */
+              )
           .getOperation();
   // clang-format on
 
index 42a7900..3cc4a78 100644 (file)
@@ -33,53 +33,32 @@ using namespace mlir;
 using namespace mlir::linalg;
 
 /// Forward declarations.
-
-/// Generic entry point to create the block for the region of a LinalgOp.
-/// This is used by both named structured ops created by ods-gen and by manually
-/// defined C++ ops.
-/// This is used by both builders and parsers.
-/// This function creates the block in the region with arguments corresponding
-/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
-/// to be ShapedType.
-template <typename NamedStructuredOpType>
-static void fillStructuredOpRegion(
-    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
-    TypeRange outputTypes, ValueRange captures = {},
-    std::function<void(unsigned, unsigned)> errorHandler = [](unsigned,
-                                                              unsigned) {});
-
-/// Generic entry point to create both the region and the block of a LinalgOp.
 template <typename NamedStructuredOpType>
-static void
-createAndFillStructuredOpRegion(OpBuilder &opBuilder, OperationState &result,
-                                TypeRange inputTypes, TypeRange outputTypes,
-                                ValueRange captures = {});
+static void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
+                                                      OperationState &result,
+                                                      TypeRange inputTypes,
+                                                      TypeRange outputTypes);
 
-/// Common parsing and printing used for both named structured ops created by
-/// ods-gen and by manually defined C++ ops. Does not handle regions.
 static ParseResult
 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                              SmallVectorImpl<Type> &inputTypes,
                              SmallVectorImpl<Type> &outputTypes);
-template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
-                                         NamedStructuredOpType op);
 
-/// Specific parsing and printing for named structured ops created by ods-gen.
 template <typename NamedStructuredOpType>
 static ParseResult
 parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
-                             TypeRange inputTypes, TypeRange outputTypes,
-                             ArrayRef<OpAsmParser::OperandType> captures = {});
-
+                             TypeRange inputTypes, TypeRange outputTypes);
 static ParseResult
 parseNamedStructuredOpResults(OpAsmParser &parser,
                               SmallVectorImpl<Type> &resultTypes);
 
 template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
-                       ArrayRef<OpAsmParser::OperandType> captures = {});
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+                                          OperationState &result);
+
+template <typename NamedStructuredOpType>
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+                                         NamedStructuredOpType op);
 
 static void printNamedStructuredOpResults(OpAsmPrinter &p,
                                           TypeRange resultTypes);
@@ -123,135 +102,13 @@ static LogicalResult foldMemRefCast(Operation *op) {
 }
 
 //===----------------------------------------------------------------------===//
-// CopyOp
-//===----------------------------------------------------------------------===//
-void CopyOp::regionBuilder(Block &block, ValueRange captures) {
-  using namespace edsc::intrinsics;
-  assert(block.getNumArguments() == 2 && "CopyOp regionBuilder expects 2 args");
-  (linalg_yield(block.getArgument(0)));
-}
-
-void CopyOp::build(OpBuilder &builder, OperationState &result, Value input,
-                   Value output, AffineMap inputPermutation,
-                   AffineMap outputPermutation,
-                   ArrayRef<NamedAttribute> namedAttrs) {
-  result.addOperands({input, output});
-  result.addAttributes(namedAttrs);
-  if (inputPermutation)
-    result.addAttribute("inputPermutation",
-                        AffineMapAttr::get(inputPermutation));
-  if (outputPermutation)
-    result.addAttribute("outputPermutation",
-                        AffineMapAttr::get(outputPermutation));
-  result.addRegion();
-  fillStructuredOpRegion<CopyOp>(builder, *result.regions.front(),
-                                 TypeRange{input.getType()},
-                                 TypeRange{output.getType()});
-}
-
-ParseResult parseCopyOpRegion(OpAsmParser &parser, Region &r, Type inputType,
-                              Type outputType) {
-  OpBuilder opBuilder(parser.getBuilder().getContext());
-  fillStructuredOpRegion<CopyOp>(opBuilder, r, TypeRange{inputType},
-                                 TypeRange{outputType});
-  return success();
-}
-
-/// CopyOp region is elided when printing.
-void printCopyOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Type) {}
-
-static LogicalResult verify(CopyOp op) {
-  auto outputViewType = op.getOutputShapedType(0);
-  auto inputViewType = op.getInputShapedType(0);
-  if (inputViewType.getElementType() != outputViewType.getElementType())
-    return op.emitOpError("expects views of the same type");
-  if (inputViewType.getRank() != outputViewType.getRank())
-    return op.emitOpError("expects views of the same rank");
-  auto rank = op.getNumParallelLoops();
-  auto inputPermutationMap = op.inputPermutation();
-  if (inputPermutationMap) {
-    if (inputPermutationMap->getNumInputs() != rank)
-      return op.emitOpError("expects optional input_permutation map of rank ")
-             << rank;
-    if (!inputPermutationMap->isPermutation())
-      return op.emitOpError(
-          "expects optional input_permutation map to be a permutation");
-  }
-  auto outputPermutationMap = op.outputPermutation();
-  if (outputPermutationMap) {
-    if (outputPermutationMap->getNumInputs() != rank)
-      return op.emitOpError("expects optional output_permutation map of rank ")
-             << rank;
-    if (!outputPermutationMap->isPermutation())
-      return op.emitOpError(
-          "expects optional output_permutation map to be a permutation");
-  }
-  if (rank == 0 && inputPermutationMap)
-    return op.emitOpError("expected no input permutation when rank == 0");
-  if (rank == 0 && outputPermutationMap)
-    return op.emitOpError("expected no output permutation when rank == 0");
-  return success();
-}
-
-void CopyOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  effects.emplace_back(MemoryEffects::Read::get(), input(),
-                       SideEffects::DefaultResource::get());
-  effects.emplace_back(MemoryEffects::Write::get(), output(),
-                       SideEffects::DefaultResource::get());
-}
-
-//===----------------------------------------------------------------------===//
 // FillOp
 //===----------------------------------------------------------------------===//
-void FillOp::regionBuilder(Block &block, ValueRange captures) {
-  using namespace edsc::intrinsics;
-  assert(captures.size() == 1 && "FillOp regionBuilder expects 1 capture");
-  (linalg_yield(captures));
-}
 
 void FillOp::build(OpBuilder &builder, OperationState &result, Value output,
                    Value value) {
   build(builder, result, output.getType().dyn_cast<RankedTensorType>(), output,
         value);
-  fillStructuredOpRegion<FillOp>(builder, *result.regions.front(), TypeRange{},
-                                 TypeRange{output.getType()}, value);
-}
-
-ParseResult parseFillOpRegion(OpAsmParser &parser, Region &r, Type outputType,
-                              OpAsmParser::OperandType valueRef) {
-  OpBuilder opBuilder(parser.getBuilder().getContext());
-  // Resolve `valueRef` into `value` at parse time so we can build the region
-  // with captures.
-  SmallVector<Value> value;
-  parser.resolveOperand(valueRef, getElementTypeOrSelf(outputType), value);
-  fillStructuredOpRegion<FillOp>(opBuilder, r, TypeRange{},
-                                 TypeRange{outputType}, value);
-  return success();
-}
-
-/// FillOp region is elided when printing.
-void printFillOpRegion(OpAsmPrinter &, Operation *, Region &, Type, Value) {}
-
-static LogicalResult verify(FillOp op) {
-  auto viewType = op.getOutputShapedType(0);
-  auto fillType = op.value().getType();
-  if (viewType.getElementType() != fillType)
-    return op.emitOpError("expects fill type to match view elemental type");
-  if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
-    return op.emitOpError(
-        "expected fill op with no result value to use memref type");
-  }
-  return success();
-}
-
-void FillOp::getEffects(
-    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
-        &effects) {
-  if (output().getType().isa<MemRefType>())
-    effects.emplace_back(MemoryEffects::Write::get(), output(),
-                         SideEffects::DefaultResource::get());
 }
 
 //===----------------------------------------------------------------------===//
@@ -576,6 +433,7 @@ void InitTensorOp::build(OpBuilder &b, OperationState &result,
   result.addAttributes(attrs);
 }
 
+
 static LogicalResult verify(InitTensorOp op) {
   RankedTensorType resultType = op.getType();
   SmallVector<int64_t, 4> staticSizes = llvm::to_vector<4>(llvm::map_range(
@@ -1556,6 +1414,68 @@ static LogicalResult verify(linalg::YieldOp op) {
 
 /////// Operations corresponding to library calls defined with Tablegen ////////
 
+void FillOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  if (output().getType().isa<MemRefType>())
+    effects.emplace_back(MemoryEffects::Write::get(), output(),
+                         SideEffects::DefaultResource::get());
+}
+
+static LogicalResult verify(FillOp op) {
+  auto viewType = op.getOutputShapedType(0);
+  auto fillType = op.value().getType();
+  if (viewType.getElementType() != fillType)
+    return op.emitOpError("expects fill type to match view elemental type");
+  if (!op.getNumResults() && !viewType.isa<MemRefType>()) {
+    return op.emitOpError(
+        "expected fill op with no result value to use memref type");
+  }
+  return success();
+}
+
+void CopyOp::getEffects(
+    SmallVectorImpl<SideEffects::EffectInstance<MemoryEffects::Effect>>
+        &effects) {
+  effects.emplace_back(MemoryEffects::Read::get(), input(),
+                       SideEffects::DefaultResource::get());
+  effects.emplace_back(MemoryEffects::Write::get(), output(),
+                       SideEffects::DefaultResource::get());
+}
+
+static LogicalResult verify(CopyOp op) {
+  auto outputViewType = op.getOutputShapedType(0);
+  auto inputViewType = op.getInputShapedType(0);
+  if (inputViewType.getElementType() != outputViewType.getElementType())
+    return op.emitOpError("expects views of the same type");
+  if (inputViewType.getRank() != outputViewType.getRank())
+    return op.emitOpError("expects views of the same rank");
+  auto rank = op.getNumParallelLoops();
+  auto inputPermutationMap = op.inputPermutation();
+  if (inputPermutationMap) {
+    if (inputPermutationMap->getNumInputs() != rank)
+      return op.emitOpError("expects optional input_permutation map of rank ")
+             << rank;
+    if (!inputPermutationMap->isPermutation())
+      return op.emitOpError(
+          "expects optional input_permutation map to be a permutation");
+  }
+  auto outputPermutationMap = op.outputPermutation();
+  if (outputPermutationMap) {
+    if (outputPermutationMap->getNumInputs() != rank)
+      return op.emitOpError("expects optional output_permutation map of rank ")
+             << rank;
+    if (!outputPermutationMap->isPermutation())
+      return op.emitOpError(
+          "expects optional output_permutation map to be a permutation");
+  }
+  if (rank == 0 && inputPermutationMap)
+    return op.emitOpError("expected no input permutation when rank == 0");
+  if (rank == 0 && outputPermutationMap)
+    return op.emitOpError("expected no output permutation when rank == 0");
+  return success();
+}
+
 template <typename LinalgPoolingOp>
 static LogicalResult verifyStrideOrDilation(LinalgPoolingOp op,
                                             ArrayRef<Attribute> attrs,
@@ -1788,25 +1708,14 @@ OpFoldResult TensorReshapeOp::fold(ArrayRef<Attribute> operands) {
 }
 
 //===----------------------------------------------------------------------===//
-// Support for named Linalg ops defined in ods-gen.
+// Auto-generated Linalg named ops.
 //===----------------------------------------------------------------------===//
 
-/// Generic entry point to create the block for the region of a LinalgOp.
-/// This is used by both named structured ops created by ods-gen and by manually
-/// defined C++ ops.
-/// This is used by both builders and parsers.
-/// This function creates the block in the region with arguments corresponding
-/// to the elemental types of `inputTypes` and `outputTypes`, which are asserted
-/// to be ShapedType.
 template <typename NamedStructuredOpType>
-static void
-fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
-                       TypeRange inputTypes, TypeRange outputTypes,
-                       ValueRange captures,
-                       std::function<void(unsigned, unsigned)> errorHandler) {
-  assert(llvm::all_of(inputTypes, [](Type t) { return t.isa<ShapedType>(); }));
-  assert(llvm::all_of(outputTypes, [](Type t) { return t.isa<ShapedType>(); }));
-
+static void buildNamedStructuredOpRegionAndAttributesImpl(
+    OpBuilder &opBuilder, Region &region, TypeRange inputTypes,
+    TypeRange outputTypes,
+    std::function<void(unsigned, unsigned)> errorHandler) {
   // TODO: atm all operands go through getElementTypeOrSelf,
   // reconsider when we have evidence we need to.
   SmallVector<Type, 8> argTypes;
@@ -1816,7 +1725,7 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
 
   // RAII.
   OpBuilder::InsertionGuard guard(opBuilder);
-  Block *body = opBuilder.createBlock(&region, /*insertPt=*/{}, argTypes);
+  Block *body = opBuilder.createBlock(&region, {}, argTypes);
   unsigned actual = body->getNumArguments();
   unsigned expected = NamedStructuredOpType::getNumRegionArgs();
   if (expected != actual)
@@ -1824,30 +1733,53 @@ fillStructuredOpRegion(OpBuilder &opBuilder, Region &region,
 
   opBuilder.setInsertionPointToStart(body);
   mlir::edsc::ScopedContext scope(opBuilder, opBuilder.getUnknownLoc());
-  NamedStructuredOpType::regionBuilder(*body, captures);
+  NamedStructuredOpType::regionBuilder(*body);
 
   // indexing_maps is an auto-generated method.
 
   // iterator_types is an auto-generated method.
 }
 
-/// Generic entry point to create both the region and the block of a LinalgOp.
 template <typename NamedStructuredOpType>
-void createAndFillStructuredOpRegion(OpBuilder &opBuilder,
-                                     OperationState &result,
-                                     TypeRange inputTypes,
-                                     TypeRange outputTypes,
-                                     ValueRange captures) {
+void buildNamedStructuredOpRegionAndAttributes(OpBuilder &opBuilder,
+                                               OperationState &result,
+                                               TypeRange inputTypes,
+                                               TypeRange outputTypes) {
   Region &region = *result.addRegion();
-  fillStructuredOpRegion<NamedStructuredOpType>(
-      opBuilder, region, inputTypes, outputTypes, captures,
+  buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
+      opBuilder, region, inputTypes, outputTypes,
       [&](unsigned expected, unsigned actual) {
+        llvm::errs() << "region expects " << expected << " args, got "
+                     << actual;
         assert(expected != actual && "incorrect number of arguments");
       });
 }
 
-/// Common parsing used for both named structured ops created by ods-gen and by
-/// manually defined C++ ops. Does not handle regions.
+template <typename NamedStructuredOpType>
+static ParseResult
+parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
+                             TypeRange inputTypes, TypeRange outputTypes) {
+  ParseResult res = success();
+  OpBuilder opBuilder(parser.getBuilder().getContext());
+  buildNamedStructuredOpRegionAndAttributesImpl<NamedStructuredOpType>(
+      opBuilder, region, inputTypes, outputTypes,
+      [&](unsigned expected, unsigned actual) {
+        res = parser.emitError(parser.getCurrentLocation(),
+                               llvm::formatv("region expects {0} args, got {1}",
+                                             expected, actual));
+      });
+  return res;
+}
+
+static ParseResult
+parseNamedStructuredOpResults(OpAsmParser &parser,
+                              SmallVectorImpl<Type> &resultTypes) {
+  if (succeeded(parser.parseOptionalArrow()))
+    if (parser.parseTypeList(resultTypes))
+      return failure();
+  return success();
+}
+
 static ParseResult
 parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
                              SmallVectorImpl<Type> &inputTypes,
@@ -1888,56 +1820,8 @@ parseCommonStructuredOpParts(OpAsmParser &parser, OperationState &result,
 }
 
 template <typename NamedStructuredOpType>
-static void printCommonStructuredOpParts(OpAsmPrinter &p,
-                                         NamedStructuredOpType op) {
-  if (!op.inputs().empty())
-    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
-  if (!op.outputs().empty())
-    p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
-}
-
-//===----------------------------------------------------------------------===//
-// Specific parsing and printing for named structured ops created by ods-gen.
-//===----------------------------------------------------------------------===//
-
-template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOpRegion(OpAsmParser &parser, Region &region,
-                             TypeRange inputTypes, TypeRange outputTypes,
-                             ArrayRef<OpAsmParser::OperandType> captures) {
-  ParseResult res = success();
-  OpBuilder opBuilder(parser.getBuilder().getContext());
-  // Resolve `captures` into `capturedValues` at parse time so we can build the
-  // region with captures.
-  SmallVector<Value> capturedValues;
-  fillStructuredOpRegion<NamedStructuredOpType>(
-      opBuilder, region, inputTypes, outputTypes, capturedValues,
-      [&](unsigned expected, unsigned actual) {
-        res = parser.emitError(
-            parser.getCurrentLocation(),
-            llvm::formatv("[parseNamedStructuredOpRegion] ods-gen generated "
-                          "region expects {0} args, got {1}",
-                          expected, actual));
-        region.front().dump();
-      });
-  return res;
-}
-
-static ParseResult
-parseNamedStructuredOpResults(OpAsmParser &parser,
-                              SmallVectorImpl<Type> &resultTypes) {
-  if (succeeded(parser.parseOptionalArrow()))
-    if (parser.parseTypeList(resultTypes))
-      return failure();
-  return success();
-}
-
-template <typename NamedStructuredOpType>
-static ParseResult
-parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
-                       ArrayRef<OpAsmParser::OperandType> captures) {
-  // TODO: Enable when ods-gen supports captures.
-  assert(captures.empty() && "unexpected captures for named structured ops");
+static ParseResult parseNamedStructuredOp(OpAsmParser &parser,
+                                          OperationState &result) {
   SmallVector<Type, 1> inputTypes, outputTypes;
   if (parseCommonStructuredOpParts(parser, result, inputTypes, outputTypes))
     return failure();
@@ -1951,7 +1835,7 @@ parseNamedStructuredOp(OpAsmParser &parser, OperationState &result,
 
   std::unique_ptr<Region> region = std::make_unique<Region>();
   if (parseNamedStructuredOpRegion<NamedStructuredOpType>(
-          parser, *region, inputTypes, outputTypes, captures))
+          parser, *region, inputTypes, outputTypes))
     return failure();
   result.addRegion(std::move(region));
 
@@ -1966,6 +1850,15 @@ static void printNamedStructuredOpResults(OpAsmPrinter &p,
 }
 
 template <typename NamedStructuredOpType>
+static void printCommonStructuredOpParts(OpAsmPrinter &p,
+                                         NamedStructuredOpType op) {
+  if (!op.inputs().empty())
+    p << " ins(" << op.inputs() << " : " << op.inputs().getTypes() << ")";
+  if (!op.outputs().empty())
+    p << " outs(" << op.outputs() << " : " << op.outputs().getTypes() << ")";
+}
+
+template <typename NamedStructuredOpType>
 static void printNamedStructuredOp(OpAsmPrinter &p, NamedStructuredOpType op) {
   p << op.getOperationName();
   p.printOptionalAttrDict(op.getAttrs(),
@@ -1986,10 +1879,6 @@ static LogicalResult verifyNamedStructuredOp(NamedStructuredOpType op) {
   return verifyGenericOp<NamedStructuredOpType>(op);
 }
 
-//===----------------------------------------------------------------------===//
-// Canonicalizers and Folders.
-//===----------------------------------------------------------------------===//
-
 namespace {
 struct EraseDeadLinalgOp : public RewritePattern {
   EraseDeadLinalgOp(PatternBenefit benefit = 1)
index 69de55c..0be1c55 100644 (file)
@@ -49,7 +49,7 @@ static linalg::GenericOp createGenericOpFromNamedOp(linalg::LinalgOp namedOp,
       indexingMaps, iterators,
       [&regionBuilder](OpBuilder &bodyBuilder, Location loc, ValueRange) {
         edsc::ScopedContext scope(bodyBuilder, loc);
-        regionBuilder(*bodyBuilder.getBlock(), /*captures=*/{});
+        regionBuilder(*bodyBuilder.getBlock());
       });
 }
 
index d09d3e0..391562b 100644 (file)
@@ -52,6 +52,14 @@ static SmallVector<Value, 8> makeCanonicalAffineApplies(OpBuilder &b,
   return res;
 }
 
+static SmallVector<Value, 4> permuteIvs(ArrayRef<Value> ivs,
+                                        Optional<AffineMap> permutation) {
+  return permutation ? applyMapToValues(ScopedContext::getBuilderRef(),
+                                        ScopedContext::getLocation(),
+                                        permutation.getValue(), ivs)
+                     : SmallVector<Value, 4>(ivs.begin(), ivs.end());
+}
+
 template <typename IndexedValueType, typename OpType>
 static void inlineRegionAndEmitStore(OpType op, ArrayRef<Value> indexedValues,
                                      ArrayRef<SmallVector<Value, 8>> indexing,
@@ -170,6 +178,40 @@ static void emitScalarImplementation(ArrayRef<Value> allIvs,
                                              outputBuffers);
 }
 
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs, CopyOp copyOp) {
+  assert(copyOp.hasBufferSemantics() &&
+         "expected linalg op with buffer semantics");
+  auto nPar = copyOp.getNumParallelLoops();
+  assert(nPar == allIvs.size());
+  auto inputIvs =
+      permuteIvs(allIvs.take_front(nPar), copyOp.inputPermutation());
+  auto outputIvs =
+      permuteIvs(allIvs.take_front(nPar), copyOp.outputPermutation());
+  SmallVector<Value, 8> iivs(inputIvs.begin(), inputIvs.end());
+  SmallVector<Value, 8> oivs(outputIvs.begin(), outputIvs.end());
+  IndexedValueType O(copyOp.getOutputBuffer(0)), I(copyOp.getInput(0));
+  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+  // an n-D loop nest; with or without permutations.
+  // clang-format off
+    nPar > 0 ? O(oivs) = I(iivs) :
+               O() = I();
+  // clang-format on
+}
+
+template <typename IndexedValueType>
+static void emitScalarImplementation(ArrayRef<Value> allIvs, FillOp fillOp) {
+  assert(fillOp.hasBufferSemantics() &&
+         "expected linalg op with buffer semantics");
+  auto nPar = fillOp.getNumParallelLoops();
+  assert(nPar == allIvs.size());
+  auto ivs = SmallVector<Value, 4>(allIvs.begin(), allIvs.begin() + nPar);
+  IndexedValueType O(fillOp.getOutputBuffer(0));
+  // Emit the proper scalar assignment, whether we are dealing with a 0-D or
+  // an n-D loop nest; with or without permutations.
+  nPar > 0 ? O(ivs) = fillOp.value() : O() = fillOp.value();
+}
+
 // Create a padded view into the given `input` tensor using the 'indices'
 // to access the tensor. `skipPadding` lists the dimensions for which no padding
 // is needed e.g. the non-spatial dimensions for convolutions.
@@ -491,8 +533,8 @@ linalgOpToLoopsImpl(Operation *op, OpBuilder &builder,
         assert(iterArgs.empty() && "unexpected iterArgs");
         allIvs.append(ivs.begin(), ivs.end());
         llvm::TypeSwitch<Operation *>(op)
-            .Case<ConvOp, PoolingMaxOp, PoolingMinOp, PoolingSumOp,
-                  IndexedGenericOp, LinalgOp>([&](auto op) {
+            .Case<CopyOp, FillOp, ConvOp, PoolingMaxOp, PoolingMinOp,
+                  PoolingSumOp, IndexedGenericOp, LinalgOp>([&](auto op) {
               emitScalarImplementation<IndexedValueTy>(allIvs, op);
             })
             .Default([&](Operation *op) { assert(false && "unexpected op"); });
index bfd2884..49d323a 100644 (file)
@@ -267,7 +267,7 @@ static Optional<VectorizedLinalgOp> vectorizeAsLinalgGeneric(
         llvm::map_range(linalgOp.getShapedOperandTypes(),
                         [](ShapedType t) { return t.getElementType(); }));
     block->addArguments(elementTypes);
-    linalgOp.getRegionBuilder()(*block, /*captures=*/{});
+    linalgOp.getRegionBuilder()(*block);
   }
   Block *block = &region->front();
 
@@ -333,26 +333,24 @@ static bool hasOnlyScalarElementwiseOp(Region &r) {
 
 // Return true if the op is an element-wise linalg op.
 static bool isElementwise(Operation *op) {
-  auto linalgOp = dyn_cast<linalg::LinalgOp>(op);
-  if (!linalgOp)
+  auto genericOp = dyn_cast<linalg::GenericOp>(op);
+  if (!genericOp)
     return false;
-  if (linalgOp.getNumLoops() != linalgOp.getNumParallelLoops())
+  if (genericOp.getNumLoops() != genericOp.getNumParallelLoops())
     return false;
   // TODO: relax the restrictions on indexing map.
-  for (unsigned i = 0, e = linalgOp.getNumOutputs(); i < e; i++) {
-    if (!linalgOp.getOutputIndexingMap(i).isIdentity())
+  for (unsigned i = 0, e = genericOp.getNumOutputs(); i < e; i++) {
+    if (!genericOp.getOutputIndexingMap(i).isIdentity())
       return false;
   }
   // Currently bound the input indexing map to minor identity as other
   // permutations might require adding transpose ops to convert the vector read
   // to the right shape.
-  for (unsigned i = 0, e = linalgOp.getNumInputs(); i < e; i++) {
-    if (!linalgOp.getInputIndexingMap(i).isMinorIdentity())
+  for (unsigned i = 0, e = genericOp.getNumInputs(); i < e; i++) {
+    if (!genericOp.getInputIndexingMap(i).isMinorIdentity())
       return false;
   }
-  if (linalgOp->getNumRegions() != 1)
-    return false;
-  return hasOnlyScalarElementwiseOp(linalgOp->getRegion(0));
+  return hasOnlyScalarElementwiseOp(genericOp.getRegion());
 }
 
 static Optional<VectorizedLinalgOp> vectorizeContraction(OpBuilder &builder,
@@ -395,6 +393,9 @@ LogicalResult mlir::linalg::vectorizeLinalgOpPrecondition(Operation *op) {
   for (Type outputTensorType : linalgOp.getOutputTensorTypes())
     if (!outputTensorType.cast<ShapedType>().hasStaticShape())
       return failure();
+
+  if (isa<linalg::FillOp, linalg::CopyOp>(op))
+    return success();
   if (isElementwise(op))
     return success();
   return success(isaContractionOpInterface(linalgOp));
@@ -406,12 +407,43 @@ Optional<VectorizedLinalgOp> mlir::linalg::vectorizeLinalgOp(OpBuilder &builder,
     return llvm::None;
 
   edsc::ScopedContext scope(builder, op->getLoc());
+  // In the case of 0-D memrefs, return null and special case to scalar load or
+  // store later.
+  if (auto fillOp = dyn_cast<linalg::FillOp>(op)) {
+    // Vectorize fill as a vector.broadcast.
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
+                      << "Rewrite linalg.fill as vector.broadcast: " << *op);
+    VectorizedLinalgOp res;
+    if (Value v = buildVectorWrite(builder, fillOp.value(), fillOp.output()))
+      res.tensorResults.push_back(v);
+    return res;
+  }
+  if (auto copyOp = dyn_cast<linalg::CopyOp>(op)) {
+    // Vectorize copy as a vector.transfer_read+vector.transfer_write.
+    LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
+                      << "Rewrite linalg.copy as vector.transfer_read + "
+                         "vector.transfer_write: "
+                      << *op);
+    Value vector = buildVectorRead(builder, copyOp.input());
+    VectorizedLinalgOp res;
+    if (Value v = buildVectorWrite(builder, vector, copyOp.output()))
+      res.tensorResults.push_back(v);
+    return res;
+  }
   if (isElementwise(op)) {
     LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
                       << "Vectorize linalg op as a generic: " << *op);
     return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
   }
 
+  // TODO: as soon as Copy and FillOp. get a region builder, replace all the
+  // above by:
+  // if (isa<FillOp, CopyOp>(op) || isElementwise(op)) {
+  //   LLVM_DEBUG(dbgs() << "\n[" DEBUG_TYPE "]: "
+  //                     << "Vectorize linalg op as a generic: " << *op);
+  //   return vectorizeAsLinalgGeneric(builder, cast<LinalgOp>(op));
+  // }
+
   return vectorizeContraction(builder, cast<LinalgOp>(op));
 }
 
index 1432037..a66006f 100644 (file)
@@ -1,4 +1,5 @@
-// RUN: mlir-opt -copy-removal -split-input-file %s | FileCheck %s
+// RUN: mlir-opt -copy-removal -split-input-file %s
+//| FileCheck %s
 
 // All linalg copies except the linalg.copy(%1, %9) must be removed since the
 // defining operation of %1 and its DeallocOp have been defined in another block.
@@ -255,7 +256,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>, %result: memref<2xf32>)
     %tmp2 = math.exp %gen2_arg0 : f32
     linalg.yield %tmp2 : f32
   }
-  linalg.copy(%temp, %result) : memref<2xf32>, memref<2xf32>
+  "linalg.copy"(%temp, %result) : (memref<2xf32>, memref<2xf32>) -> ()
   dealloc %temp : memref<2xf32>
   // CHECK: return
   return
@@ -291,7 +292,7 @@ func @test_ReuseCopyTargetAsSource(%arg0: memref<2xf32>){
     linalg.yield %tmp2 : f32
   }
   // CHECK: linalg.copy
-  linalg.copy(%temp, %to) : memref<2xf32>, memref<2xf32>
+  "linalg.copy"(%temp, %to) : (memref<2xf32>, memref<2xf32>) -> ()
   dealloc %temp : memref<2xf32>
   return
 }
@@ -354,7 +355,7 @@ func @check_with_affine_dialect(%arg0: memref<4xf32>, %arg1: memref<4xf32>, %arg
   }
   // CHECK-NOT: linalg.copy
   // CHECK-NOT: dealloc
-  linalg.copy(%0, %arg2) : memref<4xf32>, memref<4xf32>
+  "linalg.copy"(%0, %arg2) : (memref<4xf32>, memref<4xf32>) -> ()
   dealloc %0 : memref<4xf32>
   //CHECK: return
   return
index b197ba3..a16a2b8 100644 (file)
@@ -23,7 +23,7 @@
 //  IMPL-NEXT: map2 = simplifyAffineMap(map2);
 //  IMPL-NEXT: return {{.+}}.getAffineMapArrayAttr({ map0, map1, map2 });
 //
-//       IMPL:  void Test1Op::regionBuilder(Block &block, ValueRange captures) {
+//       IMPL:  void Test1Op::regionBuilder(Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = std_addf([[c]], [[d]]);
@@ -47,7 +47,7 @@ def test1(A: f32(M, K), B: f32(K)) -> (C: f32(M)) {
 //       IMPL:  AffineMap::get(3, 3, {d2, d1}, context)
 //       IMPL:  AffineMap::get(3, 3, {d0, d1}, context)
 //
-//       IMPL:  Test2Op::regionBuilder(Block &block, ValueRange captures) {
+//       IMPL:  Test2Op::regionBuilder(Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = std_addf([[c]], [[d]]);
@@ -71,7 +71,7 @@ def test2(A: f32(M, K), B: f32(K, N)) -> (C: f32(M, N)) {
 //       IMPL:  AffineMap::get(4, 4, {d3, d2}, context)
 //       IMPL:  AffineMap::get(4, 4, {d0, d1, d2}, context)
 //
-//       IMPL:  Test3Op::regionBuilder(Block &block, ValueRange captures) {
+//       IMPL:  Test3Op::regionBuilder(Block &block) {
 //       IMPL:  Value [[a:.*]](args[0]), [[b:.*]](args[1]), [[c:.*]](args[2]);
 //       IMPL:  Value [[d:.*]] = std_mulf([[a]], [[b]]);
 //       IMPL:  Value [[e:.*]] = std_addf([[c]], [[d]]);
index 4f57322..0934967 100644 (file)
@@ -1871,11 +1871,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder.getI32VectorAttr({{
               static_cast<int32_t>(inputs.size()),
               static_cast<int32_t>(outputs.size())}));
-          createAndFillStructuredOpRegion<{0}>(
+          buildNamedStructuredOpRegionAndAttributes<{0}>(
             $_builder,
             $_state,
             TypeRange(inputs),
-            TypeRange(outputs)/*, TODO: support captures*/);
+            TypeRange(outputs));
         }]>,
         OpBuilderDAG<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
@@ -1889,11 +1889,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
             $_builder.getI32VectorAttr({{
               static_cast<int32_t>(inputs.size()),
               static_cast<int32_t>(outputs.size())}));
-          createAndFillStructuredOpRegion<{0}>(
+          buildNamedStructuredOpRegionAndAttributes<{0}>(
             $_builder,
             $_state,
             TypeRange(inputs),
-            TypeRange(outputs)/*, TODO: support captures*/);
+            TypeRange(outputs));
         }]>,
         OpBuilderDAG<
         (ins "TypeRange":$resultTensorTypes, "ValueRange":$operands,
@@ -1907,9 +1907,7 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         {6}
       ];
       let printer = [{{ return ::printNamedStructuredOp(p, *this); }];
-      let parser = [{{
-        return ::parseNamedStructuredOp<{0}>(parser, result/*TODO:, captures*/);
-      }];
+      let parser = [{{ return ::parseNamedStructuredOp<{0}>(parser, result); }];
       let hasFolder = 1;
       let hasCanonicalizer = 1;
 
@@ -1917,8 +1915,8 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
         // Auto-generated.
         ArrayAttr iterator_types();
         ArrayAttr indexing_maps();
-        static void regionBuilder(Block &block, ValueRange captures);
-        static std::function<void(Block &, ValueRange)> getRegionBuilder() {{
+        static void regionBuilder(Block &block);
+        static std::function<void(Block &)> getRegionBuilder() {{
           return regionBuilder;
         }
 
@@ -1982,11 +1980,11 @@ void TCParser::printODS(llvm::raw_ostream &os, StringRef cppOpName,
           $_builder.getI32VectorAttr({{
             static_cast<int32_t>(inputs.size()),
             static_cast<int32_t>(outputs.size())}));
-        createAndFillStructuredOpRegion<{0}>(
+        buildNamedStructuredOpRegionAndAttributes<{0}>(
           $_builder,
           $_state,
           TypeRange(inputs),
-          TypeRange(outputs)/*, TODO: support captures*/);
+          TypeRange(outputs));
         {2}
       }]>
     )FMT";
@@ -2313,7 +2311,7 @@ void TCParser::printRegionBuilder(llvm::raw_ostream &os, StringRef cppOpName,
   };
 
   const char *regionBuilderFmt = R"FMT(
-  void {0}::regionBuilder(Block &block, ValueRange captures) {
+  void {0}::regionBuilder(Block &block) {
     using namespace edsc;
     using namespace intrinsics;
     auto args = block.getArguments();