Lower linalg.indexed_generic to loops.
authorAlexander Belyaev <pifon@google.com>
Mon, 18 Nov 2019 23:39:56 +0000 (15:39 -0800)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Tue, 19 Nov 2019 00:55:15 +0000 (16:55 -0800)
PiperOrigin-RevId: 281169885

mlir/lib/Dialect/Linalg/IR/LinalgOps.cpp
mlir/lib/Dialect/Linalg/Transforms/LowerToLLVMDialect.cpp
mlir/lib/Dialect/Linalg/Transforms/LowerToLoops.cpp
mlir/test/Dialect/Linalg/invalid.mlir
mlir/test/Dialect/Linalg/loops.mlir

index 867b8ab..b137764 100644 (file)
@@ -108,12 +108,11 @@ static ParseResult parseGenericOp(OpAsmParser &parser, OperationState &result) {
 }
 
 template <typename GenericOpType>
-LogicalResult verifyBlockArgs(GenericOpType op, Block &block, unsigned nViews,
-                              unsigned nLoops, unsigned nInputViews);
+LogicalResult verifyBlockArgs(GenericOpType op, Block &block);
 
-template <>
-LogicalResult verifyBlockArgs(GenericOp op, Block &block, unsigned nViews,
-                              unsigned nLoops, unsigned nInputViews) {
+template <> LogicalResult verifyBlockArgs(GenericOp op, Block &block) {
+  auto nViews = op.getNumInputsAndOutputs();
+  auto nInputViews = op.getNumInputs();
   if (block.getNumArguments() != nViews)
     return op.emitError(
         "op expected number of block arguments to match number of views");
@@ -129,10 +128,10 @@ LogicalResult verifyBlockArgs(GenericOp op, Block &block, unsigned nViews,
   return success();
 }
 
-template <>
-LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block,
-                              unsigned nViews, unsigned nLoops,
-                              unsigned nInputViews) {
+template <> LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block) {
+  auto nInputViews = op.getNumInputs();
+  auto nLoops = op.getNumLoops();
+  auto nViews = op.getNumInputsAndOutputs();
   if (block.getNumArguments() != nViews + nLoops)
     return op.emitError(
         "op expected number of block arguments to match number of views + "
@@ -158,6 +157,76 @@ LogicalResult verifyBlockArgs(IndexedGenericOp op, Block &block,
 }
 
 template <typename GenericOpType>
+LogicalResult verifyFuncArgs(GenericOpType op, FunctionType funType);
+
+template <> LogicalResult verifyFuncArgs(GenericOp op, FunctionType funType) {
+  auto nViews = op.getNumInputsAndOutputs();
+  auto nInputViews = op.getNumInputs();
+  if (funType.getNumInputs() != nViews)
+    return op.emitError("op expected fun arguments to match number of views");
+  if (funType.getNumResults() != op.getNumOutputs())
+    return op.emitError(
+        "op expected fun results to match number of output views");
+
+  for (auto en : llvm::enumerate(op.indexing_maps())) {
+    auto idx = en.index();
+    auto view = (idx < nInputViews) ? op.getInputViewType(idx)
+                                    : op.getOutputViewType(idx - nInputViews);
+    if (funType.getInput(idx) != view.getElementType())
+      return op.emitError("op expected fun argument ")
+             << idx << " of the same type as elemental type "
+             << view.getElementType() << " of view " << idx;
+
+    if (idx >= nInputViews) {
+      auto resultIdx = idx - nInputViews;
+      if (funType.getResult(resultIdx) != view.getElementType())
+        return op.emitError("op expected fun result ")
+               << resultIdx << " of the same type as elemental type "
+               << view.getElementType() << " of view " << idx;
+    }
+  }
+  return success();
+}
+
+template <>
+LogicalResult verifyFuncArgs(IndexedGenericOp op, FunctionType funType) {
+  auto nLoops = op.getNumLoops();
+  auto nInputViews = op.getNumInputs();
+  auto nOutputs = op.getNumOutputs();
+  auto nViews = op.getNumInputsAndOutputs();
+  if (funType.getNumInputs() != nViews + nLoops)
+    return op.emitError(
+        "op expected fun arguments to match number of views + number of loops");
+  if (funType.getNumResults() != nOutputs)
+    return op.emitError(
+        "op expected fun results to match number of output views");
+  for (unsigned i = 0; i < nLoops; ++i) {
+    if (!funType.getInput(i).isIndex())
+      return op.emitError("op expected fun argument ")
+             << i << " to be of IndexType";
+  }
+  for (auto en : llvm::enumerate(op.indexing_maps())) {
+    auto idx = en.index();
+    auto funIdx = nLoops + idx;
+    auto view = (idx < nInputViews) ? op.getInputViewType(idx)
+                                    : op.getOutputViewType(idx - nInputViews);
+    if (funType.getInput(funIdx) != view.getElementType())
+      return op.emitError("op expected fun argument ")
+             << funIdx << " of the same type as elemental type "
+             << view.getElementType() << " of view " << idx;
+
+    if (idx >= nInputViews) {
+      auto resultIdx = idx - nInputViews;
+      if (funType.getResult(resultIdx) != view.getElementType())
+        return op.emitError("op expected fun result ")
+               << resultIdx << " of the same type as elemental type "
+               << view.getElementType() << " of view " << idx;
+    }
+  }
+  return success();
+}
+
+template <typename GenericOpType>
 LogicalResult verifyGenericOp(GenericOpType op) {
   auto nInputViews = op.getNumInputs();
   auto nLoops = op.getNumLoops();
@@ -171,20 +240,14 @@ LogicalResult verifyGenericOp(GenericOpType op) {
   if (!region.empty()) {
     if (region.getBlocks().size() != 1)
       return op.emitError("op expected region with 1 block");
-
-    auto &block = region.getBlocks().front();
-    if (failed(verifyBlockArgs(op, block, nViews, nLoops, nInputViews))) {
+    if (failed(verifyBlockArgs(op, region.getBlocks().front())))
       return failure();
-    }
   } else {
     if (!funOp || !funOp.getType())
       return op.emitError(
           "op expected fun attribute to refer to a defined symbol");
-    if (funType.getNumInputs() != nViews)
-      return op.emitError("op expected fun arguments to match number of views");
-    if (funType.getNumResults() != op.getNumOutputs())
-      return op.emitError(
-          "op expected fun results to match number of output views");
+    if (failed(verifyFuncArgs(op, funType)))
+      return failure();
   }
 
   SmallVector<AffineMap, 4> indexingMaps;
@@ -215,19 +278,6 @@ LogicalResult verifyGenericOp(GenericOpType op) {
     if (m.getNumResults() != view.getRank())
       return op.emitError("op expected indexing_map #")
              << idx << " results to match view rank: " << view;
-
-    if (funType) {
-      if (funType.getInput(idx) != view.getElementType())
-        return op.emitError("op expected fun argument ")
-               << idx
-               << " to match view element type: " << view.getElementType();
-
-      if (idx >= nInputViews)
-        if (funType.getResult(idx - nInputViews) != view.getElementType())
-          return op.emitError("op expected fun result ")
-                 << idx << " to match output view element type: "
-                 << view.getElementType();
-    }
   }
 
   auto concatMap = concatAffineMaps(indexingMaps);
@@ -718,6 +768,13 @@ SmallVector<AffineMap, 4> mlir::linalg::loopToOperandRangesMaps(Operation *op) {
       res.push_back(genericOp.getIndexingMap(i));
     }
     return res;
+  } else if (auto indexedGenericOp = dyn_cast<IndexedGenericOp>(op)) {
+    SmallVector<AffineMap, 4> res;
+    unsigned nViews = indexedGenericOp.getNumInputsAndOutputs();
+    res.reserve(nViews);
+    for (unsigned i = 0, e = nViews; i < e; ++i)
+      res.push_back(indexedGenericOp.getIndexingMap(i));
+    return res;
   }
   llvm_unreachable("Missing loopToOperandRangesMaps for op");
 }
index 6aace80..6e97a7a 100644 (file)
@@ -474,10 +474,11 @@ populateLinalgToStandardConversionPatterns(OwningRewritePatternList &patterns,
                                            MLIRContext *ctx) {
   // TODO(ntv) ConvOp conversion needs to export a descriptor with relevant
   // attribute values such as kernel striding and dilation.
-  patterns.insert<CopyTransposeConversion, LinalgOpConversion<CopyOp>,
-                  LinalgOpConversion<DotOp>, LinalgOpConversion<FillOp>,
-                  LinalgOpConversion<MatvecOp>, LinalgOpConversion<MatmulOp>,
-                  LinalgOpConversion<ConvOp>, LinalgOpConversion<GenericOp>>(
+  patterns.insert<CopyTransposeConversion, LinalgOpConversion<ConvOp>,
+                  LinalgOpConversion<CopyOp>, LinalgOpConversion<DotOp>,
+                  LinalgOpConversion<FillOp>, LinalgOpConversion<GenericOp>,
+                  LinalgOpConversion<IndexedGenericOp>,
+                  LinalgOpConversion<MatmulOp>, LinalgOpConversion<MatvecOp>>(
       ctx);
 }
 
index 058dc07..0bf4cea 100644 (file)
@@ -244,14 +244,14 @@ public:
     SmallVector<Value *, 4> indexedValues(nInputs + nOutputs);
 
     // 1.a. Emit std_load from input views.
-    for (unsigned i = 0, e = nInputs; i < e; ++i) {
+    for (unsigned i = 0; i < nInputs; ++i) {
       ValueHandleArray indexing(foldedAffineApplies(
           b, loc, genericOp.getInputIndexingMap(i), allIvs, folder));
       indexedValues[i] = std_load(genericOp.getInput(i), indexing);
     }
 
     // 1.b. Emit std_load from output views.
-    for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+    for (unsigned i = 0; i < nOutputs; ++i) {
       ValueHandleArray indexing(foldedAffineApplies(
           b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
       indexedValues[nInputs + i] = std_load(genericOp.getOutput(i), indexing);
@@ -264,49 +264,138 @@ public:
       assert(callOp->getNumResults() == genericOp.getNumOutputs());
 
       // 3. Emit std_store.
-      for (unsigned i = 0, e = nOutputs; i < e; ++i) {
+      for (unsigned i = 0; i < nOutputs; ++i) {
         ValueHandleArray indexing(foldedAffineApplies(
             b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
         std_store(callOp->getResult(i), genericOp.getOutput(i), indexing);
       }
-    } else {
-      // TODO(ntv): When a region inliner exists, use it.
-      // 2. Inline region, currently only works for a single basic block.
-      BlockAndValueMapping map;
-      auto &block = genericOp.region().front();
-      for (auto it : llvm::zip(block.getArguments(), indexedValues))
+      return;
+    }
+    // TODO(ntv): When a region inliner exists, use it.
+    // 2. Inline region, currently only works for a single basic block.
+    BlockAndValueMapping map;
+    auto &block = genericOp.region().front();
+    for (auto it : llvm::zip(block.getArguments(), indexedValues))
+      map.map(std::get<0>(it), std::get<1>(it));
+    for (auto &op : block.without_terminator()) {
+      assert(op.getNumRegions() == 0);
+      auto *newOp = b.clone(op, map);
+      for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
         map.map(std::get<0>(it), std::get<1>(it));
-      for (auto &op : block) {
-        // Skip terminator.
-        if (&op == &block.back())
-          continue;
-        assert(op.getNumRegions() == 0);
-        auto *newOp = b.clone(op, map);
-        for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
-          map.map(std::get<0>(it), std::get<1>(it));
-      }
+    }
 
-      // 3. Emit std_store.
-      auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
-      assert(yieldOp->getNumOperands() == nOutputs);
-      for (unsigned i = 0, e = nOutputs; i < e; ++i) {
-        ValueHandleArray indexing(foldedAffineApplies(
-            b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
-        std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
-                  indexing);
-      }
+    // 3. Emit std_store.
+    auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
+    assert(yieldOp->getNumOperands() == nOutputs);
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, genericOp.getOutputIndexingMap(i), allIvs, folder));
+      std_store(map.lookup(yieldOp->getOperand(i)), genericOp.getOutput(i),
+                indexing);
     }
   }
 };
 
+// Emits the MLIR for the scalar part of the indexed generic op by:
+//   1. Emitting std_load and std_store ops for each input and output view in
+//      order. This is achieved by applying the appropriate input or output map
+//      to the enclosing induction variables.
+//   2. Emitting a call to `op.fun()` that takes as arguments the induction
+//      variables and the scalars from point 1. above.
+//   3. Emitting std_store to store the results of 2. to the output views.
+//
+// An example output may resemble:
+//
+// ```
+//    loop.for %i = %c0 to %0 step %c1 {
+//      loop.for %j = %c0 to %1 step %c1 {
+//        loop.for %k = %c0 to %4 step %c1 {
+//          %11 = load %arg0[%i, %j] :
+//            memref<?x?xf32, stride_specification>
+//          %12 = load %arg1[%i, %j, %k] :
+//            memref<?x?x?xf32, stride_specification>
+//          %13 = load %arg2[%i, %k, %j] :
+//            memref<?x?x?xf32, stride_specification>
+//          %14:2 = call @foo(%i, %j, %k, %11, %12, %13) :
+//            (index, index, index, f32, f32, f32) -> (f32, f32)
+//          store %14#0, %arg1[%i, %j, %k] :
+//            memref<?x?x?Xf32, stride_specification>
+//          store %14#1, %arg2[%i, %k, %j] :
+//            memref<?x?x?Xf32, stride_specification>
+//       }
+//      }
+//    }
+// ```
 template <typename IndexedValueType>
 class LinalgScopedEmitter<IndexedValueType, IndexedGenericOp> {
 public:
   static void emitScalarImplementation(ArrayRef<Value *> allIvs,
-                                       IndexedGenericOp genericOp,
+                                       IndexedGenericOp indexedGenericOp,
                                        OperationFolder *folder) {
-    // This is just a shim to make Linalg compile.
-    // TODO(pifon): Implement lowering after IndexedGenericOp def is submitted.
+    auto b = ScopedContext::getBuilder();
+    auto loc = ScopedContext::getLocation();
+    using edsc::intrinsics::detail::ValueHandleArray;
+    unsigned nInputs = indexedGenericOp.getNumInputs();
+    unsigned nOutputs = indexedGenericOp.getNumOutputs();
+    unsigned nLoops = allIvs.size();
+    SmallVector<Value *, 4> indexedValues(nLoops + nInputs + nOutputs);
+
+    for (unsigned i = 0; i < nLoops; ++i) {
+      indexedValues[i] = allIvs[i];
+    }
+
+    // 1.a. Emit std_load from input views.
+    for (unsigned i = 0; i < nInputs; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, indexedGenericOp.getInputIndexingMap(i), allIvs, folder));
+      indexedValues[nLoops + i] =
+          std_load(indexedGenericOp.getInput(i), indexing);
+    }
+
+    // 1.b. Emit std_load from output views.
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      indexedValues[nLoops + nInputs + i] =
+          std_load(indexedGenericOp.getOutput(i), indexing);
+    }
+
+    if (auto funcOp = indexedGenericOp.getFunction()) {
+      // 2. Emit call.
+      Operation *callOp = call(funcOp, indexedValues);
+      assert(callOp->getNumResults() == indexedGenericOp.getNumOutputs());
+
+      // 3. Emit std_store.
+      for (unsigned i = 0; i < nOutputs; ++i) {
+        ValueHandleArray indexing(foldedAffineApplies(
+            b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+        std_store(callOp->getResult(i), indexedGenericOp.getOutput(i),
+                  indexing);
+      }
+      return;
+    }
+    // TODO(ntv): When a region inliner exists, use it.
+    // 2. Inline region, currently only works for a single basic block.
+    BlockAndValueMapping map;
+    auto &block = indexedGenericOp.region().front();
+    for (auto it : llvm::zip(block.getArguments(), indexedValues))
+      map.map(std::get<0>(it), std::get<1>(it));
+    for (auto &op : block.without_terminator()) {
+      assert(op.getNumRegions() == 0);
+      auto *newOp = b.clone(op, map);
+      for (auto it : llvm::zip(op.getResults(), newOp->getResults()))
+        map.map(std::get<0>(it), std::get<1>(it));
+    }
+
+    // 3. Emit std_store.
+    auto *yieldOp = cast<YieldOp>(block.back()).getOperation();
+    assert(yieldOp->getNumOperands() == nOutputs);
+    for (unsigned i = 0; i < nOutputs; ++i) {
+      ValueHandleArray indexing(foldedAffineApplies(
+          b, loc, indexedGenericOp.getOutputIndexingMap(i), allIvs, folder));
+      std_store(map.lookup(yieldOp->getOperand(i)),
+                indexedGenericOp.getOutput(i), indexing);
+    }
   }
 };
 
index 57b9f09..603c67e 100644 (file)
@@ -120,42 +120,42 @@ func @generic_mismatched_num_returns(%arg0: memref<f32>) {
 
 func @foo(%0: i32) -> i32 { return %0: i32 }
 
-func @generic_symbol_in_map(%arg0: memref<f32>) {
+func @generic_symbol_in_map(%arg0: memref<i32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have no symbols}}
   linalg.generic {
     fun = @foo,
     indexing_maps =  [ ()[N] -> (0) ],
     n_views = [0, 1],
     n_loop_types = [1, 0, 0]
-  } %arg0: memref<f32>
+  } %arg0: memref<i32>
 }
 
 // -----
 
 func @foo(%0: i32) -> i32 { return %0: i32 }
 
-func @generic_wrong_dim_in_map(%arg0: memref<f32>) {
+func @generic_wrong_dim_in_map(%arg0: memref<i32>) {
   // expected-error @+1 {{op expected indexing_map #0 to have 1 dim(s) to match the number of loops}}
   linalg.generic {
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
     n_views = [0, 1],
     n_loop_types = [1, 0, 0]
-  } %arg0: memref<f32>
+  } %arg0: memref<i32>
 }
 
 // -----
 
 func @foo(%0: i32) -> i32 { return %0: i32 }
 
-func @generic_zero_d_view(%arg0: memref<f32>) {
-  // expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<f32>'}}
+func @generic_zero_d_view(%arg0: memref<i32>) {
+  // expected-error @+1 {{op expected indexing_map #0 to be 0 to match 0-D view: 'memref<i32>'}}
   linalg.generic {
     fun = @foo,
     indexing_maps =  [ () -> (1) ],
     n_views = [0, 1],
     n_loop_types = [0, 0, 0]
-  } %arg0: memref<f32>
+  } %arg0: memref<i32>
 }
 
 // -----
@@ -180,7 +180,7 @@ func @foo(%0: i32) -> f32 {
 }
 
 func @generic_fun_arg_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
-  // expected-error @+1 {{op expected fun argument 0 to match view element type: 'f32'}}
+  // expected-error @+1 {{op expected fun argument 0 of the same type as elemental type 'f32' of view 0}}
   linalg.generic {
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
@@ -197,7 +197,7 @@ func @foo(%0: f32) -> i4 {
 }
 
 func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
-  // expected-error @+1 {{op expected fun result 0 to match output view element type: 'f32'}}
+  // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'f32' of view 0}}
   linalg.generic {
     fun = @foo,
     indexing_maps =  [ () -> (0) ],
@@ -308,6 +308,82 @@ func @indexed_generic_block_arg_type(%arg0: memref<f32>) {
 
 // -----
 
+func @foo(%f: f32) -> (f32) {
+  return %f : f32
+}
+func @indexed_generic_fun_arg_count(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected fun arguments to match number of views + number of loops}}
+  linalg.indexed_generic {
+    indexing_maps =  [ (d0) -> (d0) ],
+    n_views = [0, 1],
+    n_loop_types = [1, 0, 0],
+    fun = @foo
+  } %arg0:  memref<f32>
+}
+
+// -----
+
+func @foo(%i: i32, %val: f32) -> (f32) {
+  return %val : f32
+}
+func @indexed_generic_fun_induction_var_arg_type(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected fun argument 0 to be of IndexType}}
+  linalg.indexed_generic {
+    n_views = [0, 1],
+    n_loop_types = [1, 0, 0],
+    indexing_maps = [ (i) -> (i) ],
+    fun = @foo
+  } %arg0 : memref<f32>
+}
+
+// -----
+
+func @foo(%i: index, %val: i1) -> (i1) {
+  return %val : i1
+}
+func @indexed_generic_fun_arg_type(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected fun argument 1 of the same type as elemental type 'f32' of view 0}}
+  linalg.indexed_generic {
+    indexing_maps =  [ (d0) -> (d0) ],
+    n_views = [0, 1],
+    n_loop_types = [1, 0, 0],
+    fun = @foo
+  } %arg0: memref<f32>
+}
+
+// -----
+
+func @foo(%i: index, %val: i1) -> (i1, i1) {
+  return %val, %val : i1, i1
+}
+func @indexed_generic_fun_result_count(%arg0: memref<f32>) {
+  // expected-error @+1 {{op expected fun results to match number of output views}}
+  linalg.indexed_generic {
+    indexing_maps =  [ (d0) -> (d0) ],
+    n_views = [0, 1],
+    n_loop_types = [1, 0, 0],
+    fun = @foo
+  } %arg0: memref<f32>
+}
+
+// -----
+
+func @foo(%i: index, %val: i32) -> (f32) {
+  %val_float = sitofp %val : i32 to f32
+  return %val_float : f32
+}
+func @indexed_generic_fun_result_count(%arg0: memref<i32>) {
+  // expected-error @+1 {{op expected fun result 0 of the same type as elemental type 'i32' of view 0}}
+  linalg.indexed_generic {
+    indexing_maps =  [ (d0) -> (d0) ],
+    n_views = [0, 1],
+    n_loop_types = [1, 0, 0],
+    fun = @foo
+  } %arg0: memref<i32>
+}
+
+// -----
+
 func @generic_fun_result_0_element_type(%arg0: memref<?xf32, (i)[off]->(off + i)>) {
   // expected-error @+8 {{type of return operand 0 ('i1') doesn't match view element type ('f32')}}
   linalg.generic {
index d62a288..93cf69f 100644 (file)
@@ -273,3 +273,82 @@ func @generic_region(%arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>, %arg1:
 //       CHECK:       %[[e:.*]] = addf %[[c]], %[[d]] : f32
 //       CHECK:       store %[[d]], %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
 //       CHECK:       store %[[e]], %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
+
+func @indexed_foo(%i: index, %j: index, %k: index, %0: f32, %1: f32, %2: f32) -> (f32, f32) {
+  %i_int = index_cast %i: index to i32
+  %i_float = sitofp %i_int : i32 to f32
+  return %i_float, %i_float : f32, f32
+}
+#trait3 = {
+  n_views = [1, 2],
+  n_loop_types = [3, 0, 0],
+  indexing_maps = #accesses,
+  fun = @indexed_foo,
+  library_call = "some_external_function_name_1",
+  doc = "b(i,j,k), c(i,k,j) = foo(a(i, j), b(i,j,k), c(i,k,j))"
+}
+func @indexed_generic_function(
+         %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+         %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+         %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+  linalg.indexed_generic #trait3 %arg0, %arg1, %arg2:
+    memref<?x?xf32, offset: ?, strides: [?, 1]>,
+    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+    memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  return
+}
+// CHECK-LABEL: @indexed_foo
+// CHECK-LABEL: @indexed_generic_function
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   loop.for %[[j:.*]] = {{.*}}
+// CHECK:     loop.for %[[k:.*]] = {{.*}}
+// CHECK:       %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]] : memref<?x?xf32, #[[strided2D]]>
+// CHECK:       %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK:       %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK:       %[[res:.*]]:2 = call @indexed_foo(%[[i]], %[[j]], %[[k]], %[[a]], %[[b]], %[[c]]) : (index, index, index, f32, f32, f32) -> (f32, f32)
+// CHECK:       store %[[res]]#0, %{{.*}}[%[[i]], %[[j]], %[[k]]] : memref<?x?x?xf32, #[[strided3D]]>
+// CHECK:       store %[[res]]#1, %{{.*}}[%[[i]], %[[k]], %[[j]]] : memref<?x?x?xf32, #[[strided3D]]>
+
+#trait4 = {
+  n_views = [1, 2],
+  n_loop_types = [3, 0, 0],
+  indexing_maps = #accesses,
+  library_call = "some_external_function_name_2",
+  doc = "B(i,j,k), C(i,k,j) = foo(A(i, j) * B(i,j,k), i * j * k + C(i,k,j))"
+}
+func @indexed_generic_region(
+        %arg0: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+        %arg1: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+        %arg2: memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>) {
+  linalg.indexed_generic #trait4 %arg0, %arg1, %arg2 {
+    ^bb0(%i: index, %j: index, %k: index, %a: f32, %b: f32, %c: f32):
+      %result_1 = mulf %a, %b : f32
+
+      %ij = addi %i, %j : index
+      %ijk = addi %ij, %k : index
+      %ijk_int = index_cast %ijk : index to i32
+      %ijk_float = sitofp %ijk_int : i32 to f32
+
+      %result_2 = addf %c, %ijk_float : f32
+      linalg.yield %result_1, %result_2 : f32, f32
+  }: memref<?x?xf32, offset: ?, strides: [?, 1]>,
+     memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>,
+     memref<?x?x?xf32, offset: ?, strides: [?, ?, 1]>
+  return
+}
+
+// CHECK-LABEL: @indexed_generic_region
+// CHECK: loop.for %[[i:.*]] = {{.*}}
+// CHECK:   loop.for %[[j:.*]] = {{.*}}
+// CHECK:     loop.for %[[k:.*]] = {{.*}}
+// CHECK:       %[[a:.*]] = load %{{.*}}[%[[i]], %[[j]]]
+// CHECK:       %[[b:.*]] = load %{{.*}}[%[[i]], %[[j]], %[[k]]]
+// CHECK:       %[[c:.*]] = load %{{.*}}[%[[i]], %[[k]], %[[j]]]
+// CHECK:       %[[result_1:.*]] = mulf %[[a]], %[[b]] : f32
+// CHECK:       %[[ij:.*]] = addi %[[i]], %[[j]] : index
+// CHECK:       %[[ijk:.*]] = addi %[[ij]], %[[k]] : index
+// CHECK:       %[[ijk_int:.*]] = index_cast %[[ijk]] : index to i32
+// CHECK:       %[[ijk_float:.*]] = sitofp %[[ijk_int]] : i32 to f32
+// CHECK:       %[[result_2:.*]] = addf %[[c]], %[[ijk_float]] : f32
+// CHECK:       store %[[result_1]], %{{.*}}[%[[i]], %[[j]], %[[k]]]
+// CHECK:       store %[[result_2]], %{{.*}}[%[[i]], %[[k]], %[[j]]]