Add a generic Linalg op
authorNicolas Vasilache <ntv@google.com>
Fri, 2 Aug 2019 16:53:08 +0000 (09:53 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 2 Aug 2019 16:53:41 +0000 (09:53 -0700)
This CL introduces a linalg.generic op to represent generic tensor contraction operations on views.

A linalg.generic operation requires a numbers of attributes that are sufficient to emit the computation in scalar form as well as compute the appropriate subviews to enable tiling and fusion.

These attributes are very similar to the attributes for existing operations such as linalg.matmul etc and existing operations can be implemented with the generic form.

In the future, most existing operations can be implemented using the generic form.

This CL starts by splitting out most of the functionality of the linalg::NInputsAndOutputs trait into a ViewTrait that queries the per-instance properties of the op. This allows using the attribute informations.

This exposes an ordering of verifiers issue where ViewTrait::verify uses attributes but the verifiers for those attributes have not been run. The desired behavior would be for the verifiers of the attributes specified in the builder to execute first but it is not the case atm. As a consequence, to emit proper error messages and avoid crashing, some of the
linalg.generic methods are defensive as such:
```
    unsigned getNumInputs() {
      // This is redundant with the `n_views` attribute verifier but ordering of verifiers
      // may exhibit cases where we crash instead of emitting an error message.
      if (!getAttr("n_views") || n_views().getValue().size() != 2)
        return 0;
```

In pretty-printed form, the specific attributes required for linalg.generic are factored out in an independent dictionary named "_". When parsing its content is flattened and the "_name" is dropped. This allows using aliasing for reducing boilerplate at each linalg.generic invocation while benefiting from the Tablegen'd verifier form for each named attribute in the dictionary.

For instance, implementing linalg.matmul in terms of linalg.generic resembles:

```
func @mac(%a: f32, %b: f32, %c: f32) -> f32 {
  %d = mulf %a, %b: f32
  %e = addf %c, %d: f32
  return %e: f32
}
#matmul_accesses = [
  (m, n, k) -> (m, k),
  (m, n, k) -> (k, n),
  (m, n, k) -> (m, n)
]
#matmul_trait = {
  doc = "C(m, n) += A(m, k) * B(k, n)",
  fun = @mac,
  indexing_maps = #matmul_accesses,
  library_call = "linalg_matmul",
  n_views = [2, 1],
  n_loop_types = [2, 1, 0]
}
```

And can be used in multiple places as:
```
  linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
    !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
```

In the future it would be great to have a mechanism to alias / register a new
linalg.op as a pair of linalg.generic, #trait.

Also, note that with one could theoretically only specify the `doc` string and parse all the attributes from it.

PiperOrigin-RevId: 261338740

16 files changed:
mlir/include/mlir/AffineOps/AffineOps.td
mlir/include/mlir/AffineOps/AffineOpsBase.td [new file with mode: 0644]
mlir/include/mlir/EDSC/Intrinsics.h
mlir/include/mlir/IR/AffineMap.h
mlir/include/mlir/IR/Builders.h
mlir/include/mlir/Linalg/IR/LinalgLibraryOps.td
mlir/include/mlir/Linalg/IR/LinalgOps.h
mlir/include/mlir/Linalg/IR/LinalgTraits.h
mlir/lib/IR/AffineMap.cpp
mlir/lib/IR/Builders.cpp
mlir/lib/Linalg/IR/LinalgOps.cpp
mlir/lib/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Linalg/Transforms/Tiling.cpp
mlir/test/Linalg/invalid-generic-op.mlir [new file with mode: 0644]
mlir/test/Linalg/loops.mlir
mlir/test/Linalg/roundtrip.mlir

index 306b399..c517ed0 100644 (file)
@@ -1,4 +1,4 @@
-//===- Ops.td - Affine operation definitions ---------------*- tablegen -*-===//
+//===- AffineOps.td - Affine operation definitions ---------*- tablegen -*-===//
 //
 // Copyright 2019 The MLIR Authors.
 //
@@ -28,6 +28,8 @@
 include "mlir/IR/OpBase.td"
 #endif // OP_BASE
 
+include "mlir/AffineOps/AffineOpsBase.td"
+
 def Affine_Dialect : Dialect {
   let name = "affine";
   let cppNamespace = "";
diff --git a/mlir/include/mlir/AffineOps/AffineOpsBase.td b/mlir/include/mlir/AffineOps/AffineOpsBase.td
new file mode 100644 (file)
index 0000000..2ac1d37
--- /dev/null
@@ -0,0 +1,44 @@
+//===- AffineOpsBase.td - Affine operation definitions -----*- tablegen -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// Defines base support for MLIR affine operations.
+//
+//===----------------------------------------------------------------------===//
+
+#ifdef AFFINE_OPS_BASE
+#else
+#define AFFINE_OPS_BASE
+
+#ifdef OP_BASE
+#else
+include "mlir/IR/OpBase.td"
+#endif // OP_BASE
+
+// Attributes containing affine maps.
+def AffineMapAttr : Attr<
+    CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
+  let storageType = [{ AffineMapAttr }];
+  let returnType = [{ AffineMap }];
+  let constBuilderCall = "$_builder.getAffineMapAttr($0)";
+}
+
+def AffineMapArrayAttr : TypedArrayAttrBase<AffineMapAttr,
+                                      "AffineMap array attribute"> {
+  let constBuilderCall = "$_builder.getAffineMapArrayAttr($0)";
+}
+
+#endif // AFFINE_OPS_BASE
index 872c3b2..021fec2 100644 (file)
@@ -190,6 +190,7 @@ using alloc = ValueBuilder<AllocOp>;
 using affine_apply = ValueBuilder<AffineApplyOp>;
 using affine_load = ValueBuilder<AffineLoadOp>;
 using affine_store = OperationBuilder<AffineStoreOp>;
+using call = OperationBuilder<mlir::CallOp>;
 using constant_float = ValueBuilder<ConstantFloatOp>;
 using constant_index = ValueBuilder<ConstantIndexOp>;
 using constant_int = ValueBuilder<ConstantIntOp>;
index a29db18..6a60223 100644 (file)
@@ -153,11 +153,12 @@ AffineMap simplifyAffineMap(AffineMap map);
 
 /// Returns a map of codomain to domain dimensions such that the first codomain
 /// dimension for a particular domain dimension is selected.
-/// Returns an empty map if the input map is empty.
+/// Returns an empty map if the input map is empty or if `map` is not invertible
+/// (i.e. `map` does not contain a subset that is a permutation of full domain
+/// rank).
 ///
 /// Prerequisites:
-///   1. `map` must contain a subset that is a permutation of full domain rank.
-///   2. `map` has no symbols.
+///   1. `map` has no symbols.
 ///
 /// Example 1:
 ///
index 185ac2c..9f5f873 100644 (file)
@@ -130,9 +130,11 @@ public:
   FloatAttr getF16FloatAttr(float value);
   FloatAttr getF32FloatAttr(float value);
   FloatAttr getF64FloatAttr(double value);
+
   IntegerAttr getI32IntegerAttr(int32_t value);
   IntegerAttr getI64IntegerAttr(int64_t value);
 
+  ArrayAttr getAffineMapArrayAttr(ArrayRef<AffineMap> values);
   ArrayAttr getI32ArrayAttr(ArrayRef<int32_t> values);
   ArrayAttr getI64ArrayAttr(ArrayRef<int64_t> values);
   ArrayAttr getF32ArrayAttr(ArrayRef<float> values);
index a3796d2..2ea8db2 100644 (file)
 //
 //===----------------------------------------------------------------------===//
 
-include "mlir/Linalg/IR/LinalgBase.td"
-
 #ifdef LINALG_LIBRARY_OPS
 #else
 #define LINALG_LIBRARY_OPS
 
+include "mlir/AffineOps/AffineOpsBase.td"
+include "mlir/Linalg/IR/LinalgBase.td"
+
 class LinalgParametricNativeOpTrait<string prop, string parameters> :
   NativeOpTrait<"linalg::" # prop # parameters>
 {}
@@ -65,29 +66,28 @@ class ViewRanks<list<int> ranks> :
 LinalgParametricIntNativeOpTrait<"ViewRanks", ranks>
 {}
 
+def ViewTraits : NativeOpTrait<"linalg::ViewTraits">;
+
 // Base Tablegen class for Linalg ops.
 // Linalg ops that correspond to library calls operate on linalg::View as their
 // first operands. These may be optionally followed by non-view operands
 // depending on the specific Linalg op.
-class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
-  : Op<Linalg_Dialect, mnemonic, props> {
+class LinalgLibraryBase_Op<string mnemonic, list<OpTrait> props>
+  : Op<Linalg_Dialect, mnemonic, !listconcat(props, [ViewTraits])> {
   let parser = [{ return parseLinalgLibraryOp(parser, result); }];
   let printer = [{ printLinalgLibraryOp(p, *this); }];
+}
 
-  let extraClassDeclaration = [{
-    static StringRef getLibraryCallName() {
+class LinalgLibrary_Op<string mnemonic, list<OpTrait> props>
+  : LinalgLibraryBase_Op<mnemonic, props> {
+
+  code classDeclaration = [{
+    StringRef getLibraryCallName() {
       return "linalg_}] # mnemonic # [{";
     }
   }];
 }
 
-def AffineMapAttr : Attr<
-    CPred<"$_self.isa<AffineMapAttr>()">, "AffineMap attribute"> {
-  let storageType = [{ AffineMapAttr }];
-  let returnType = [{ AffineMap }];
-  let constBuilderCall = "$_builder.getAffineMapAttr($0)";
-}
-
 ////////////////////////////////////////////////////////////////////////////////
 // Concrete Linalg ops.
 ////////////////////////////////////////////////////////////////////////////////
@@ -138,7 +138,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
     return build(
       builder, result, input, output, AffineMapAttr(), AffineMapAttr());
   }]>];
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = classDeclaration # [{
     unsigned getNumParallelLoops() {
       auto *view = *(getOperands().begin());
       return view->getType().cast<ViewType>().getRank();
@@ -151,7 +151,7 @@ def CopyOp : LinalgLibrary_Op<"copy", [NInputsAndOutputs<1, 1>]> {
 
 def FillOp : LinalgLibrary_Op<"fill", [NInputsAndOutputs<0, 1>]> {
   let arguments = (ins View, AnyTypeOf<[AnyFloat, AnyInteger, AnyVector]>);
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = classDeclaration # [{
     unsigned getNumParallelLoops() {
       auto *view = *(getOperands().begin());
       return view->getType().cast<ViewType>().getRank();
@@ -170,6 +170,7 @@ def DotOp : LinalgLibrary_Op<"dot",
                              NLoopTypes<0, 1, 0>,
                              ViewRanks<[1, 1, 0]>]> {
   let arguments = (ins View, View, View);
+  let extraClassDeclaration = classDeclaration;
 }
 
 def MatvecOp : LinalgLibrary_Op<"matvec",
@@ -177,6 +178,7 @@ def MatvecOp : LinalgLibrary_Op<"matvec",
                                    NLoopTypes<1, 1, 0>,
                                    ViewRanks<[2, 1, 1]>]> {
   let arguments = (ins View, View, View);
+  let extraClassDeclaration = classDeclaration;
 }
 
 def MatmulOp : LinalgLibrary_Op<"matmul",
@@ -184,6 +186,7 @@ def MatmulOp : LinalgLibrary_Op<"matmul",
                                    NLoopTypes<2, 1, 0>,
                                    ViewRanks<[2, 2, 2]>]> {
   let arguments = (ins View, View, View);
+  let extraClassDeclaration = classDeclaration;
 }
 
 def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
@@ -208,7 +211,7 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
   let arguments = (ins View:$filter, View:$input, View:$output,
                    OptionalAttr<I64ArrayAttr>:$strides,
                    OptionalAttr<I64ArrayAttr>:$dilations);
-  let extraClassDeclaration = [{
+  let extraClassDeclaration = classDeclaration # [{
     // TODO(ntv) extend to support more than 1 dimensions and potentially
     // grouping too.
     unsigned getNumBatchDimensions() { return 1; }
@@ -248,4 +251,163 @@ def ConvOp : LinalgLibrary_Op<"conv", [NInputsAndOutputs<2, 1>]> {
   let verifier = [{ return ::verify(*this); }];
 }
 
+def GenericOp : LinalgLibraryBase_Op<"generic", []> {
+  let description = [{
+    Generic Linalg op form where the key properties of the computation are
+    specified as attributes. In pretty form, a linalg.generic op is written as:
+
+      ```
+        linalg.generic #trait_attribute %A, %B, %C {other-attributes} :
+          !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      ```
+
+    Where #trait_attributes is an alias of a dictionary attribute containing:
+      - doc [optional]: a documentation string
+      - fun: a SymbolRefAttr that must resolve to an existing function symbol.
+        To support inplace updates in a generic fashion, the signature of the
+        function must be:
+        ```
+          fun([input views element types], [output views element types])
+            -> ([output views element types])
+        ```
+      - indexing_maps: a list of AffineMapAttr, one AffineMapAttr per each input
+        and output view. Such AffineMapAttr specifies the mapping between the
+        loops and the indexing within each view.
+      - library_call [optional]: a StringAttr containing the name of an
+        external library function that the linalg.generic operation maps to.
+        The external library is assumed to be dynamically linked and no strong
+        compile-time guarantees are provided. In the absence of such a library
+        call, linalg.generic will always lower to loops.
+      - n_loops: a triple of I64Attr representing the number of enclosing
+        [parallel, reduction, window] loops respectively.
+      - n_views: a pair of I64Attr representing the number of input (readonly)
+        and output (readwrite) views.
+
+    Example:
+    Defining a #matmul_trait attribute in MLIR can be done as follows:
+      ```
+        func @fma(%a: f32, %b: f32, %c: f32) -> f32 {
+          %d = mulf %a, %b: f32
+          %e = addf %c, %d: f32
+          return %e: f32
+        }
+        #matmul_accesses = [
+          (m, n, k) -> (m, k),
+          (m, n, k) -> (k, n),
+          (m, n, k) -> (m, n)
+        ]
+        #matmul_trait = {
+          doc = "C(m, n) += A(m, k) * B(k, n)",
+          fun = @fma,
+          indexing_maps = #matmul_accesses,
+          library_call = "linalg_matmul",
+          n_views = [2, 1],
+          n_loop_types = [2, 1, 0]
+        }
+      ```
+
+    And can be reused in multiple places as:
+      ```
+        linalg.generic #matmul_trait %A, %B, %C [other-attributes] :
+          !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
+      ```
+
+    This may lower to either:
+      ```
+        call @linalg_matmul(%A, %B, %C) :
+          (!linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>)
+          -> ()
+      ```
+
+    or IR resembling:
+    ```
+    loop.for %m = %c0 to %M step %c1 {
+      loop.for %n = %c0 to %N step %c1 {
+        loop.for %k = %c0 to %K step %c1 {
+          %a = linalg.load %A[%m, %k] : !linalg.view<?x?xf32>
+          %b = linalg.load %B[%k, %n] : !linalg.view<?x?xf32>
+          %c = linalg.load %C[%m, %n] : !linalg.view<?x?xf32>
+          %d = call @mac(%a, %b, %c) : (f32, f32, f32) -> (f32)
+          linalg.store %d, %C[%m, %n] : !linalg.view<?x?x?xf32>
+        }
+      }
+    }
+    ```
+  }];
+  let arguments = (ins Variadic<View>:$views,
+                   SymbolRefAttr:$fun,
+                   AffineMapArrayAttr:$indexing_maps,
+                   I64ArrayAttr:$n_loop_types,
+                   I64ArrayAttr:$n_views,
+                   OptionalAttr<StrAttr>:$doc,
+                   OptionalAttr<StrAttr>:$library_call);
+  let extraClassDeclaration = [{
+    SmallVector<StringRef, 8> linalgTraitAttrNames() {
+      return SmallVector<StringRef, 8>{
+        "doc", "fun", "indexing_maps", "library_call", "n_loop_types", "n_views"
+      };
+    }
+    unsigned getNumInputs() {
+      if (!getAttr("n_views") || n_views().getValue().size() != 2)
+        return 0;
+      auto val = n_views().getValue()[0].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumOutputs() {
+      if (!getAttr("n_views") || n_views().getValue().size() != 2)
+        return 0;
+      auto val = n_views().getValue()[1].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumParallelLoops() {
+      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+        return 0;
+      auto val = n_loop_types().getValue()[0].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumReductionLoops() {
+      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+        return 0;
+      auto val = n_loop_types().getValue()[1].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumWindowLoops() {
+      if (!getAttr("n_loop_types") || n_loop_types().getValue().size() != 3)
+        return 0;
+      auto val = n_loop_types().getValue()[2].cast<IntegerAttr>().getValue();
+      assert(val.getSExtValue() >= 0);
+      return val.getZExtValue();
+    }
+    unsigned getNumLoops() {
+      return getNumParallelLoops() + getNumReductionLoops() +
+        getNumWindowLoops();
+    }
+    StringRef getFunName() {
+      return fun();
+    }
+    StringRef getLibraryCallName() {
+      return library_call().hasValue() ? library_call().getValue() : "";
+    }
+    AffineMap getIndexingMap(unsigned i) {
+      assert(i < getNumInputsAndOutputs());
+      return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
+    }
+    AffineMap getInputIndexingMap(unsigned i) {
+      assert(i < getNumInputs());
+      return indexing_maps().getValue()[i].cast<AffineMapAttr>().getValue();
+    }
+    AffineMap getOutputIndexingMap(unsigned i) {
+      assert(i < getNumOutputs());
+      return indexing_maps().getValue()[i + getNumInputs()]
+          .cast<AffineMapAttr>().getValue();
+    }
+  }];
+  let printer = [{ return ::print(p, *this); }];
+  let verifier = [{ return ::verify(*this); }];
+  let parser = [{ return ::parse$cppClass(parser, result); }];
+}
 #endif // LINALG_LIBRARY_OPS
index 9a16766..41767ad 100644 (file)
@@ -18,6 +18,7 @@
 #ifndef MLIR_LINALG_LINALGOPS_H_
 #define MLIR_LINALG_LINALGOPS_H_
 
+#include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/OpDefinition.h"
 #include "mlir/Linalg/IR/LinalgTraits.h"
index 022aef4..5c94be3 100644 (file)
@@ -41,54 +41,95 @@ public:
   public:
     static unsigned getNumInputs() { return NInputs; }
     static unsigned getNumOutputs() { return NOutputs; }
-    static unsigned getNumInputsAndOutputs() { return NInputs + NOutputs; }
-    Value *getInput(unsigned i) { return this->getOperation()->getOperand(i); }
-    llvm::Optional<unsigned> getIndexOfInput(Value *view) {
-      auto it = llvm::find(getInputs(), view);
-      if (it != getInputs().end())
-        return it - getInputs().begin();
-      return llvm::None;
-    }
-    mlir::linalg::ViewType getInputViewType(unsigned i) {
-      return this->getOperation()
-          ->getOperand(i)
-          ->getType()
-          .template cast<mlir::linalg::ViewType>();
-    }
-    Operation::operand_range getInputs() {
-      auto range = this->getOperation()->getOperands();
-      return {range.begin(), range.begin() + getNumInputs()};
-    }
-    Value *getOutput(unsigned i) {
-      return this->getOperation()->getOperand(getNumInputs() + i);
-    }
-    llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
-      auto it = llvm::find(getOutputs(), view);
-      if (it != getOutputs().end())
-        return it - getOutputs().begin();
-      return llvm::None;
-    }
-    mlir::linalg::ViewType getOutputViewType(unsigned i) {
-      return this->getOperation()
-          ->getOperand(getNumInputs() + i)
-          ->getType()
-          .template cast<mlir::linalg::ViewType>();
-    }
-    Operation::operand_range getOutputs() {
-      auto range = this->getOperation()->getOperands();
-      return {range.begin() + getNumInputs(),
-              range.begin() + getNumInputsAndOutputs()};
-    }
-    Operation::operand_range getInputsAndOutputs() {
-      auto range = this->getOperation()->getOperands();
-      return {range.begin(), range.begin() + getNumInputsAndOutputs()};
-    }
     static LogicalResult verifyTrait(Operation *op) {
       return OpTrait::impl::verifyAtLeastNOperands(op, NInputs + NOutputs);
     }
   };
 };
 
+/// This class provides the API for ops that are known to operate on views. This
+/// trait must be used in conjunction with an op definition or a trait that
+/// provides the methods `getNumInputs` and `getNumOutputs`. This is used as a
+/// trait like this:
+///
+///   class DotOp : public Op<DotOp, OpTrait::ViewTrait> {
+///
+template <typename ConcreteType>
+class ViewTraits : public OpTrait::TraitBase<ConcreteType, ViewTraits> {
+private:
+  /// Return the number of input views. For internal use only.
+  unsigned nInputs() {
+    return cast<ConcreteType>(this->getOperation()).getNumInputs();
+  }
+  /// Return the number of input views. For internal use only.
+  unsigned nOutputs() {
+    return cast<ConcreteType>(this->getOperation()).getNumOutputs();
+  }
+
+public:
+  /// Return the `i`-th input view.
+  Value *getInput(unsigned i) {
+    assert(i < nInputs());
+    return this->getOperation()->getOperand(i);
+  }
+  /// Return the index of `view` in the list of input views if found, llvm::None
+  /// otherwise.
+  llvm::Optional<unsigned> getIndexOfInput(Value *view) {
+    auto it = llvm::find(getInputs(), view);
+    if (it != getInputs().end())
+      return it - getInputs().begin();
+    return llvm::None;
+  }
+  /// Return the `i`-th input view type.
+  mlir::linalg::ViewType getInputViewType(unsigned i) {
+    return getInput(i)->getType().template cast<mlir::linalg::ViewType>();
+  }
+  /// Return the range over input views.
+  Operation::operand_range getInputs() {
+    auto range = this->getOperation()->getOperands();
+    return {range.begin(), range.begin() + nInputs()};
+  }
+  /// Return the `i`-th output view.
+  Value *getOutput(unsigned i) {
+    return this->getOperation()->getOperand(nInputs() + i);
+  }
+  /// Return the index of `view` in the list of output views if found,
+  /// llvm::None otherwise.
+  llvm::Optional<unsigned> getIndexOfOutput(Value *view) {
+    auto it = llvm::find(getOutputs(), view);
+    if (it != getOutputs().end())
+      return it - getOutputs().begin();
+    return llvm::None;
+  }
+  /// Return the `i`-th output view type.
+  mlir::linalg::ViewType getOutputViewType(unsigned i) {
+    return getOutput(i)->getType().template cast<mlir::linalg::ViewType>();
+  }
+  /// Return the range over output views.
+  Operation::operand_range getOutputs() {
+    auto range = this->getOperation()->getOperands();
+    return {range.begin() + nInputs(),
+            range.begin() + getNumInputsAndOutputs()};
+  }
+  /// Return the number of input and output views.
+  unsigned getNumInputsAndOutputs() { return nInputs() + nOutputs(); }
+  /// Return the range over input and output views.
+  Operation::operand_range getInputsAndOutputs() {
+    auto range = this->getOperation()->getOperands();
+    return {range.begin(), range.begin() + getNumInputsAndOutputs()};
+  }
+  static LogicalResult verifyTrait(Operation *op) {
+    auto nViews = cast<ConcreteType>(op).getNumInputsAndOutputs();
+    if (failed(OpTrait::impl::verifyAtLeastNOperands(op, nViews)))
+      return failure();
+    for (unsigned i = 0, e = nViews; i < e; ++i) {
+      if (!op->getOperand(i)->getType().dyn_cast<mlir::linalg::ViewType>())
+        return op->emitOpError("operand ") << i << " must have view type ";
+    }
+    return success();
+  }
+};
+
 /// This class provides the API for ops that are known to have a specified
 /// number of parallel, reduction and window loops. This is used as a trait like
 /// this:
index 9adf1df..1b6bbe5 100644 (file)
@@ -296,8 +296,8 @@ AffineMap mlir::inversePermutation(AffineMap map) {
   for (auto expr : exprs)
     if (expr)
       seenExprs.push_back(expr);
-  assert(seenExprs.size() == map.getNumInputs() &&
-         "map does not include a full rank permutation");
+  if (seenExprs.size() != map.getNumInputs())
+    return AffineMap();
   return AffineMap::get(map.getNumResults(), 0, seenExprs);
 }
 
index f6d228e..f31868b 100644 (file)
@@ -236,6 +236,12 @@ ArrayAttr Builder::getStrArrayAttr(ArrayRef<StringRef> values) {
   return getArrayAttr(attrs);
 }
 
+ArrayAttr Builder::getAffineMapArrayAttr(ArrayRef<AffineMap> values) {
+  auto attrs = functional::map(
+      [this](AffineMap v) -> Attribute { return getAffineMapAttr(v); }, values);
+  return getArrayAttr(attrs);
+}
+
 Attribute Builder::getZeroAttr(Type type) {
   switch (type.getKind()) {
   case StandardTypes::F16:
index fa1a315..ad8bb48 100644 (file)
@@ -25,6 +25,8 @@
 #include "mlir/IR/AffineExpr.h"
 #include "mlir/IR/AffineMap.h"
 #include "mlir/IR/Builders.h"
+#include "mlir/IR/Function.h"
+#include "mlir/IR/Module.h"
 #include "mlir/IR/OpImplementation.h"
 #include "mlir/IR/StandardTypes.h"
 #include "mlir/Linalg/IR/LinalgTypes.h"
@@ -33,6 +35,8 @@
 #include "mlir/Support/STLExtras.h"
 #include "mlir/Transforms/FoldUtils.h"
 
+#include "llvm/ADT/StringSet.h"
+
 using namespace mlir;
 using namespace mlir::edsc;
 using namespace mlir::edsc::intrinsics;
@@ -525,6 +529,119 @@ static void print(OpAsmPrinter *p, RangeIntersectOp op) {
   *p << " : " << op.getOperand(0)->getType();
 }
 
+//===----------------------------------------------------------------------===//
+// GenericOp
+//===----------------------------------------------------------------------===//
+
+static void print(OpAsmPrinter *p, GenericOp op) {
+  auto attrNames = op.linalgTraitAttrNames();
+  llvm::StringSet<> linalgTraitAttrsSet;
+  linalgTraitAttrsSet.insert(attrNames.begin(), attrNames.end());
+  SmallVector<NamedAttribute, 8> attrs;
+  for (auto attr : op.getAttrs()) {
+    if (linalgTraitAttrsSet.count(attr.first.strref()) > 0)
+      attrs.push_back(attr);
+  }
+  auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
+  *p << op.getOperationName() << " " << dictAttr << " ";
+  p->printOperands(op.getOperands());
+  p->printOptionalAttrDict(op.getAttrs(), attrNames);
+  *p << ": ";
+  interleaveComma(op.getOperandTypes(), *p);
+}
+
+static ParseResult parseGenericOp(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 8> operandsInfo;
+  DictionaryAttr dictAttr;
+  // Parse the core linalg traits that must check into a dictAttr.
+  // The name is unimportant as we will overwrite result->attributes.
+  // The core linalg traits must contain the information necessary to pass the
+  // verifier.
+  if (parser->parseAttribute(dictAttr, "_", result->attributes) ||
+      parser->parseOperandList(operandsInfo))
+    return failure();
+  result->attributes.assign(dictAttr.getValue().begin(),
+                            dictAttr.getValue().end());
+
+  // Optional attributes may be added.
+  SmallVector<Type, 8> operandTypes;
+  if (parser->parseOptionalAttributeDict(result->attributes) ||
+      parser->parseColonTypeList(operandTypes))
+    return failure();
+  return parser->resolveOperands(operandsInfo, operandTypes,
+                                 parser->getCurrentLocation(),
+                                 result->operands);
+}
+
+static LogicalResult verify(GenericOp op) {
+  auto nInputViews = op.getNumInputs();
+  auto nViews = op.getNumInputsAndOutputs();
+  if (nViews != llvm::size(op.views()))
+    return op.emitError("op expected exactly ") << nViews << " view operands";
+
+  auto m = op.getParentOfType<ModuleOp>();
+  auto fun = m.lookupSymbol<FuncOp>(op.fun());
+  if (!fun || !fun.getType())
+    return op.emitError(
+        "op expected fun attribute to refer to a defined symbol");
+
+  auto funType = fun.getType();
+  if (funType.getNumInputs() != nViews)
+    return op.emitError("op expected fun arguments to match number of views");
+  if (funType.getNumResults() != op.getNumOutputs())
+    return op.emitError(
+        "op expected fun results to match number of output views");
+
+  auto nLoops = op.getNumLoops();
+  SmallVector<AffineMap, 4> indexingMaps;
+  indexingMaps.reserve(op.indexing_maps().size());
+  for (auto en : llvm::enumerate(op.indexing_maps())) {
+    auto idx = en.index();
+    auto m = en.value().cast<AffineMapAttr>().getValue();
+    indexingMaps.push_back(m); // Save reference to map for further checks.
+    auto view = (idx < nInputViews) ? op.getInputViewType(idx)
+                                    : op.getOutputViewType(idx - nInputViews);
+
+    if (m.getNumSymbols() != 0)
+      return op.emitError("op expected indexing_map #")
+             << idx << " to have no symbols";
+
+    if (m.getNumDims() != nLoops)
+      return op.emitError("op expected indexing_map #")
+             << idx << " to have " << nLoops
+             << " dim(s) to match the number of loops";
+
+    if (m.getNumResults() == 1 && view.getRank() == 0) {
+      auto cst = m.getResult(0).dyn_cast<AffineConstantExpr>();
+      if (!cst || cst.getValue() != 0)
+        return op.emitError("op expected indexing_map #")
+               << idx << " to be 0 to match 0-D view: " << view;
+    }
+
+    if (m.getNumResults() != view.getRank())
+      return op.emitError("op expected indexing_map #")
+             << idx << " results to match view rank: " << view;
+
+    if (funType.getInput(idx) != view.getElementType())
+      return op.emitError("op expected fun argument ")
+             << idx << " to match view element type: " << view.getElementType();
+
+    if (idx >= nInputViews)
+      if (funType.getResult(idx - nInputViews) != view.getElementType())
+        return op.emitError("op expected fun result ")
+               << idx << " to match output view element type: "
+               << view.getElementType();
+  }
+
+  auto concatMap = concatAffineMaps(indexingMaps);
+  auto aggregateMap = inversePermutation(concatMap);
+  if (!aggregateMap)
+    return op.emitError("op expected the concatenation of maps in indexing_map "
+                        "to be invertible");
+
+  return success();
+}
+
 static ParseResult parseRangeIntersectOp(OpAsmParser *parser,
                                          OperationState *result) {
   SmallVector<OpAsmParser::OperandType, 2> ops;
@@ -840,6 +957,14 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
         AffineMap::get(idx, 0, concat(concat(bs, ws), qs)),
         // output[b, x[0], ..., x[N-1], k]
         AffineMap::get(idx, 0, concat(concat(bs, xs), ks))};
+  } else if (auto genericOp = dyn_cast<GenericOp>(op)) {
+    SmallVector<AffineMap, 4> res;
+    unsigned nViews = genericOp.getNumInputsAndOutputs();
+    res.reserve(nViews);
+    for (unsigned i = 0, e = nViews; i < e; ++i) {
+      res.push_back(genericOp.getIndexingMap(i));
+    }
+    return res;
   }
   llvm_unreachable("Missing loopToOperandRangesMaps for op");
 }
@@ -945,5 +1070,62 @@ void mlir::linalg::emitScalarImplementation(
     O(oIdx) += F(fIdx) * I(imIdx);
     return;
   }
+  if (auto genericOp = dyn_cast<GenericOp>(op)) {
+    using edsc::intrinsics::detail::ValueHandleArray;
+    unsigned nInputs = genericOp.getNumInputs();
+    unsigned nOutputs = genericOp.getNumOutputs();
+    SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
+    // Emits the MLIR for the scalar part of the generic op by:
+    //   1. Emitting linalg_load and linalg_store ops for each input and output
+    //      view in order. This is achieved by applying the appropriate input or
+    //      output map to the enclosing induction variables.
+    //   2. Emitting a call to `op.fun()` that takes as arguments the scalars
+    //      from point 1. above.
+    //   3. Emitting linalg_store to store the results of 2. to the output
+    //      views.
+    //
+    // An example output may resemble:
+    //
+    // ```
+    //    loop.for %i = %c0 to %0 step %c1 {
+    //      loop.for %j = %c0 to %1 step %c1 {
+    //        loop.for %k = %c0 to %4 step %c1 {
+    //          %11 = linalg.load %arg0[%i, %j] : !linalg.view<?x?xf32>
+    //          %12 = linalg.load %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
+    //          %13 = linalg.load %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
+    //          %14:2 = call @foo(%11, %12, %13) : (f32, f32, f32) -> (f32, f32)
+    //          linalg.store %14#0, %arg1[%i, %j, %k] : !linalg.view<?x?x?xf32>
+    //          linalg.store %14#1, %arg2[%i, %k, %j] : !linalg.view<?x?x?xf32>
+    //       }
+    //      }
+    //    }
+    // ```
+
+    // 1.a. Emit linalg_load from input views.
+    for (unsigned i = 0, e = nInputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
+      indexedValues[i] = linalg_load(genericOp.getInput(i), indexing);
+    }
+    // 1.b. Emit linalg_load from output views..
+    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      indexedValues[nInputs + i] =
+          linalg_load(genericOp.getOutput(i), indexing);
+    }
+    // 2. Emit call.
+    auto m = genericOp.getParentOfType<ModuleOp>();
+    auto fun = m.lookupSymbol<FuncOp>(genericOp.fun());
+    Operation *callOp = call(fun, indexedValues);
+    assert(callOp->getNumResults() == genericOp.getNumOutputs());
+    // 3. Emit linalg_store.
+    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      linalg_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
+    }
+    return;
+  }
   llvm_unreachable("Missing emitScalarImplementation for op");
 }
index b6bfa58..6e5e270 100644 (file)
@@ -55,12 +55,12 @@ using namespace mlir::linalg::intrinsics;
 using add = ValueBuilder<mlir::LLVM::AddOp>;
 using addi = ValueBuilder<mlir::AddIOp>;
 using bitcast = ValueBuilder<mlir::LLVM::BitcastOp>;
-using call = OperationBuilder<mlir::LLVM::CallOp>;
 using cmpi = ValueBuilder<mlir::CmpIOp>;
 using constant = ValueBuilder<mlir::LLVM::ConstantOp>;
 using extractvalue = ValueBuilder<mlir::LLVM::ExtractValueOp>;
 using gep = ValueBuilder<mlir::LLVM::GEPOp>;
 using insertvalue = ValueBuilder<mlir::LLVM::InsertValueOp>;
+using llvm_call = OperationBuilder<mlir::LLVM::CallOp>;
 using llvm_icmp = ValueBuilder<LLVM::ICmpOp>;
 using llvm_load = ValueBuilder<LLVM::LoadOp>;
 using llvm_store = OperationBuilder<LLVM::StoreOp>;
@@ -206,7 +206,7 @@ public:
     Value *allocSize =
         mul(size, constant(int64Ty, IntegerAttr::get(indexType, elementSize)));
     Value *allocated =
-        call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
+        llvm_call(voidPtrTy, rewriter.getSymbolRefAttr(mallocFunc), allocSize)
             .getOperation()
             ->getResult(0);
     allocated = bitcast(elementPtrType, allocated);
@@ -251,7 +251,7 @@ public:
     edsc::ScopedContext context(rewriter, op->getLoc());
     Value *casted = bitcast(voidPtrTy, extractvalue(elementPtrTy, operands[0],
                                                     positionAttr(rewriter, 0)));
-    call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
+    llvm_call(ArrayRef<Type>(), rewriter.getSymbolRefAttr(freeFunc), casted);
     rewriter.replaceOp(op, llvm::None);
     return matchSuccess();
   }
@@ -611,8 +611,12 @@ template <typename LinalgOp>
 static FuncOp
 getLLVMLibraryCallDeclaration(Operation *op, LLVMTypeConverter &lowering,
                               ConversionPatternRewriter &rewriter) {
-  assert(isa<LinalgOp>(op));
-  auto fnName = LinalgOp::getLibraryCallName();
+  auto linalgOp = cast<LinalgOp>(op);
+  auto fnName = linalgOp.getLibraryCallName();
+  if (fnName.empty()) {
+    op->emitWarning("No library call defined for: ") << *op;
+    return FuncOp();
+  }
   auto module = op->getParentOfType<ModuleOp>();
   if (auto f = module.lookupSymbol<FuncOp>(fnName)) {
     return f;
@@ -661,7 +665,7 @@ static void getLLVMLibraryCallDefinition(FuncOp fn,
     implFnArgs.push_back(alloca);
     llvm_store(arg, alloca);
   }
-  call(ArrayRef<Type>(), builder.getSymbolRefAttr(implFn), implFnArgs);
+  llvm_call(ArrayRef<Type>(), builder.getSymbolRefAttr(implFn), implFnArgs);
   llvm_return{ArrayRef<Value *>()};
 }
 
@@ -704,6 +708,8 @@ public:
                   ConversionPatternRewriter &rewriter) const override {
     // Only emit library call declaration. Fill in the body later.
     auto f = getLLVMLibraryCallDeclaration<LinalgOp>(op, lowering, rewriter);
+    if (!f)
+      return matchFailure();
     static_cast<LinalgTypeConverter &>(lowering).addLibraryFnDeclaration(f);
 
     auto fAttr = rewriter.getSymbolRefAttr(f);
index b667333..25ffdeb 100644 (file)
@@ -373,6 +373,7 @@ mlir::linalg::tileLinalgOp(LinalgOp op, ArrayRef<Value *> tileSizes,
   // permutation map (asserted in the inverse calculation).
   auto viewSizesToLoopsMap =
       inversePermutation(concatAffineMaps(loopToOperandRangesMaps(op)));
+  assert(viewSizesToLoopsMap && "expected invertible map");
   auto loopRanges =
       makeTiledLoopRanges(scope.getBuilder(), scope.getLocation(),
                           viewSizesToLoopsMap, viewSizes, tileSizes, folder);
diff --git a/mlir/test/Linalg/invalid-generic-op.mlir b/mlir/test/Linalg/invalid-generic-op.mlir
new file mode 100644 (file)
index 0000000..8d9956c
--- /dev/null
@@ -0,0 +1,196 @@
+// RUN: mlir-opt %s -split-input-file -verify-diagnostics
+
+// -----
+
+// CHECK-LABEL: at_least_2_operands
+func @at_least_2_operands(%arg0: !linalg.view<f32>) {
+  // expected-error @+1 {{op expected 2 or more operands}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (0) ],
+    n_views = [1, 1],
+    n_loop_types = [0, 0, 0]
+  } %arg0: !linalg.view<f32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: exactly_2_views
+func @exactly_2_views(%arg0: !linalg.view<f32>) {
+  // expected-error @+1 {{op expected exactly 2 view operands}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (0) ],
+    n_views = [1, 1],
+    n_loop_types = [0, 0, 0]
+  } %arg0, %arg0, %arg0: !linalg.view<f32>, !linalg.view<f32>, !linalg.view<f32>
+  return
+}
+
+// -----
+
+// CHECK-LABEL: undefined_fun
+func @undefined_fun(%arg0: !linalg.view<f32>) {
+  // expected-error @+1 {{op expected fun attribute to refer to a defined symbol}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (0) ],
+    n_views = [1, 1],
+    n_loop_types = [0, 0, 0]
+  } %arg0, %arg0: !linalg.view<f32>, !linalg.view<f32>
+  return
+}
+
+// -----
+
+func @foo() { return }
+
+// CHECK-LABEL: mismatched_num_arguments
+func @mismatched_num_arguments(%arg0: !linalg.view<f32>) {
+  // expected-error @+1 {{op expected fun arguments to match number of views}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (0) ],
+    n_views = [0, 1],
+    n_loop_types = [0, 0, 0]
+  } %arg0: !linalg.view<f32>
+  return
+}
+
+// -----
+
+func @foo(%0: i32) { return }
+
+// CHECK-LABEL: mismatched_num_returns
+func @mismatched_num_returns(%arg0: !linalg.view<f32>) {
+  // expected-error @+1 {{op expected fun results to match number of output views}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (0) ],
+    n_views = [0, 1],
+    n_loop_types = [0, 0, 0]
+  } %arg0: !linalg.view<f32>
+  return
+}
+
+// -----
+
+func @foo(%0: i32) -> i32 { return %0: i32 }
+
+// CHECK-LABEL: symbol_in_map
+func @symbol_in_map(%arg0: !linalg.view<f32>) {
+  // expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ ()[N] -> (0) ],
+    n_views = [0, 1],
+    n_loop_types = [1, 0, 0]
+  } %arg0: !linalg.view<f32>
+  return
+}
+
+// -----
+
+func @foo(%0: i32) -> i32 { return %0: i32 }
+
+// CHECK-LABEL: wrong_dim_in_map
+func @wrong_dim_in_map(%arg0: !linalg.view<f32>) {
+  // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (0) ],
+    n_views = [0, 1],
+    n_loop_types = [1, 0, 0]
+  } %arg0: !linalg.view<f32>
+  return
+}
+
+// -----
+
+func @foo(%0: i32) -> i32 { return %0: i32 }
+
+// CHECK-LABEL: zero_d_view
+func @zero_d_view(%arg0: !linalg.view<f32>) {
+  // expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: '!linalg.view<f32>'}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (1) ],
+    n_views = [0, 1],
+    n_loop_types = [0, 0, 0]
+  } %arg0: !linalg.view<f32>
+  return
+}
+
+// -----
+
+func @foo(%0: f32) -> f32 { return %0: f32 }
+
+// CHECK-LABEL: one_d_view
+func @one_d_view(%arg0: !linalg.view<?xf32>) {
+  // expected-error @+1 {{op expected indexing_map #0 results to match view rank: '!linalg.view<?xf32>'}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (0, 0) ],
+    n_views = [0, 1],
+    n_loop_types = [0, 0, 0]
+  } %arg0: !linalg.view<?xf32>
+  return
+}
+
+// -----
+
+func @foo(%0: i32) -> f32 {
+  %1 = constant 0.0: f32
+  return %1: f32
+}
+
+// CHECK-LABEL: fun_arg_0_element_type
+func @fun_arg_0_element_type(%arg0: !linalg.view<?xf32>) {
+  // expected-error @+1 {{op expected fun argument 0 to match view element type: 'f32'}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (0) ],
+    n_views = [0, 1],
+    n_loop_types = [0, 0, 0]
+  } %arg0: !linalg.view<?xf32>
+  return
+}
+
+// -----
+
+func @foo(%0: f32) -> i4 {
+  %1 = constant 1: i4
+  return %1: i4
+}
+
+// CHECK-LABEL: fun_result_0_element_type
+func @fun_result_0_element_type(%arg0: !linalg.view<?xf32>) {
+  // expected-error @+1 {{op expected fun result 0 to match output view element type: 'f32'}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [ () -> (0) ],
+    n_views = [0, 1],
+    n_loop_types = [0, 0, 0]
+  } %arg0: !linalg.view<?xf32>
+  return
+}
+
+// -----
+
+func @foo(%0: f32, %1: f32) -> f32 { return %1: f32 }
+
+// CHECK-LABEL: singular_maps
+func @singular_maps(%arg0: !linalg.view<?xf32>, %arg1: !linalg.view<?xf32>) {
+  // expected-error @+1 {{op expected the concatenation of maps in indexing_map to be invertible}}
+  linalg.generic {
+    fun = @foo,
+    indexing_maps =  [
+      (i, j) -> (i + j) ,
+      (i, j) -> (i + j)
+    ],
+    n_views = [1, 1],
+    n_loop_types = [2, 0, 0]
+  } %arg0, %arg1: !linalg.view<?xf32>, !linalg.view<?xf32>
+  return
+}
index a12aa99..64e3bb1 100644 (file)
@@ -201,3 +201,37 @@ func @conv_view4(%arg0: !linalg.view<?x?x?x?xf32>, %arg1: !linalg.view<?x?x?x?xf
 //       CHECK:                 %{{.*}} = linalg.load %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?x?xf32>
 //       CHECK:                 %{{.*}} = addf %{{.*}}, %{{.*}} : f32
 //       CHECK:                 linalg.store %{{.*}}, %{{.*}}[%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}] : !linalg.view<?x?x?x?xf32>
+
+func @foo(%0: f32, %1: f32, %2: f32) -> (f32, f32) {
+  %f0 = constant 0.0 : f32
+  return %f0, %f0 : f32, f32
+}
+#accesses = [
+  (i, j, k) -> (i, j),
+  (i, j, k) -> (i, j, k),
+  (i, j, k) -> (i, k, j)
+]
+#trait = {
+  n_views = [1, 2],
+  n_loop_types = [3, 0, 0],
+  indexing_maps = #accesses,
+  fun = @foo,
+  library_call = "external_function_name",
+  doc = "B(i,j,k), C(i,k,j) = foo(A(i, j), B(i,j,k), C(i,k,j))"
+}
+func @generic(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?x?x?xf32>, %arg2: !linalg.view<?x?x?xf32>) {
+  linalg.generic #trait %arg0, %arg1, %arg2:
+    !linalg.view<?x?xf32>, !linalg.view<?x?x?xf32>, !linalg.view<?x?x?xf32>
+  return
+}
+// CHECK-LABEL: @foo
+// CHECK-LABEL: @generic
+//       CHECK: loop.for %[[i:.*]] = {{.*}}
+//       CHECK:   loop.for %[[j:.*]] = {{.*}}
+//       CHECK:     loop.for %[[k:.*]] = {{.*}}
+//       CHECK:       %[[a:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]]] : !linalg.view<?x?xf32>
+//       CHECK:       %[[b:.*]] = linalg.load %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
+//       CHECK:       %[[c:.*]] = linalg.load %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
+//       CHECK:       %[[res:.*]]:2 = call @foo(%[[a]], %[[b]], %[[c]]) : (f32, f32, f32) -> (f32, f32)
+//       CHECK:       linalg.store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : !linalg.view<?x?x?xf32>
+//       CHECK:       linalg.store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : !linalg.view<?x?x?xf32>
index 2a3a3c5..13934c9 100644 (file)
@@ -164,3 +164,26 @@ func @const_buffer_view(%arg0: index, %arg1: index, %arg2: index) {
   %c2 = linalg.view %c0[%c1] : !linalg.buffer<17xf32> -> !linalg.view<?xf32>
   return
 }
+
+#accesses = [
+  (i, j, k) -> (j, i),
+  (i, j, k) -> (i, k, i + j)
+]
+#trait = {
+  indexing_maps = #accesses,
+  n_views = [1, 1],
+  n_loop_types = [3, 0, 0],
+  fun = @foo,
+  library_call = "external_function_name"
+}
+func @foo(%0: vector<3x4xi4>, %1: f32) -> f32 {
+  %f0 = constant 0.0 : f32
+  return %f0 : f32
+}
+func @generic(%arg0: !linalg.view<?x?xvector<3x4xi4>>, %arg1: !linalg.view<?x?x?xf32>) {
+  linalg.generic #trait %arg0, %arg1 {foo = 1} : !linalg.view<?x?xvector<3x4xi4>>, !linalg.view<?x?x?xf32>
+  return
+}
+// CHECK-LABEL: func @foo
+// CHECK-LABEL: func @generic
+//       CHECK:   linalg.generic {fun = @foo, indexing_maps = [#map2, #map3], library_call = "external_function_name", n_loop_types = [3, 0, 0], n_views = [1, 1]} %{{.*}}, %{{.*}} {foo = 1 : i64}: !linalg.view<?x?xvector<3x4xi4>>, !linalg.view<?x?x?xf32>