[MLIR][Shape] Support transforming shape.num_elements on tensors
author: Stephan Herhut <herhut@google.com>
Tue, 28 Jul 2020 11:09:45 +0000 (13:09 +0200)
committer: Stephan Herhut <herhut@google.com>
Tue, 28 Jul 2020 12:13:06 +0000 (14:13 +0200)
The current transformation to shape.reduce does not support tensor values.
This adds the required changes to make that work, including fixing the builder
for shape.reduce.

Differential Revision: https://reviews.llvm.org/D84744

mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp
mlir/test/Dialect/Shape/shape-to-shape.mlir

index 4887c87..3c71e34 100644 (file)
@@ -834,7 +834,13 @@ void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape,
   bodyRegion->push_back(new Block);
   Block &bodyBlock = bodyRegion->front();
   bodyBlock.addArgument(builder.getIndexType());
-  bodyBlock.addArgument(SizeType::get(builder.getContext()));
+
+  Type elementType;
+  if (auto tensorType = shape.getType().dyn_cast<TensorType>())
+    elementType = tensorType.getElementType();
+  else
+    elementType = SizeType::get(builder.getContext());
+  bodyBlock.addArgument(elementType);
 
   for (Type initValType : initVals.getTypes()) {
     bodyBlock.addArgument(initValType);
index bb2b03b..a84fad1 100644 (file)
@@ -9,6 +9,7 @@
 #include "PassDetail.h"
 #include "mlir/Dialect/Shape/IR/Shape.h"
 #include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/PatternMatch.h"
 #include "mlir/Pass/Pass.h"
@@ -32,14 +33,18 @@ LogicalResult
 NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
                                         PatternRewriter &rewriter) const {
   auto loc = op.getLoc();
-  Value init = rewriter.create<ConstSizeOp>(loc, rewriter.getIndexAttr(1));
+  Type valueType = op.getResult().getType();
+  Value init = op.getDialect()
+                   ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
+                                         valueType, loc)
+                   ->getResult(0);
   ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.shape(), init);
 
   // Generate reduce operator.
   Block *body = reduce.getBody();
   OpBuilder b = OpBuilder::atBlockEnd(body);
-  Value product = b.create<MulOp>(loc, b.getType<SizeType>(),
-                                  body->getArgument(1), body->getArgument(2));
+  Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
+                                  body->getArgument(2));
   b.create<YieldOp>(loc, product);
 
   rewriter.replaceOp(op, reduce.result());
@@ -60,7 +65,7 @@ void ShapeToShapeLowering::runOnFunction() {
   populateShapeRewritePatterns(&ctx, patterns);
 
   ConversionTarget target(getContext());
-  target.addLegalDialect<ShapeDialect>();
+  target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
   target.addIllegalOp<NumElementsOp>();
   if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
     signalPassFailure();
index d1b00bc..481d682 100644 (file)
@@ -14,3 +14,18 @@ func @num_elements_to_reduce(%shape : !shape.shape) -> !shape.size {
 // CHECK: }
 // CHECK: return [[NUM_ELEMENTS]] : !shape.size
 
+// -----
+
+// CHECK-LABEL: func @num_elements_to_reduce_on_index
+// CHECK-SAME:  ([[ARG:%.*]]: tensor<?xindex>) -> index
+func @num_elements_to_reduce_on_index(%shape : tensor<?xindex>) -> index {
+  %num_elements = shape.num_elements %shape : tensor<?xindex> -> index
+  return %num_elements : index
+}
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : tensor<?xindex> -> index
+// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: index, [[ACC:%.*]]: index
+// CHECK:   [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
+// CHECK:   shape.yield [[NEW_ACC]] : index
+// CHECK: }
+// CHECK: return [[NUM_ELEMENTS]] : index