bodyRegion->push_back(new Block);
Block &bodyBlock = bodyRegion->front();
bodyBlock.addArgument(builder.getIndexType());
- bodyBlock.addArgument(SizeType::get(builder.getContext()));
+
+ Type elementType;
+ if (auto tensorType = shape.getType().dyn_cast<TensorType>())
+ elementType = tensorType.getElementType();
+ else
+ elementType = SizeType::get(builder.getContext());
+ bodyBlock.addArgument(elementType);
for (Type initValType : initVals.getTypes()) {
bodyBlock.addArgument(initValType);
#include "PassDetail.h"
#include "mlir/Dialect/Shape/IR/Shape.h"
#include "mlir/Dialect/Shape/Transforms/Passes.h"
+#include "mlir/Dialect/StandardOps/IR/Ops.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Pass/Pass.h"
NumElementsOpConverter::matchAndRewrite(NumElementsOp op,
PatternRewriter &rewriter) const {
auto loc = op.getLoc();
- Value init = rewriter.create<ConstSizeOp>(loc, rewriter.getIndexAttr(1));
+ Type valueType = op.getResult().getType();
+ Value init = op.getDialect()
+ ->materializeConstant(rewriter, rewriter.getIndexAttr(1),
+ valueType, loc)
+ ->getResult(0);
ReduceOp reduce = rewriter.create<ReduceOp>(loc, op.shape(), init);
// Generate reduce operator.
Block *body = reduce.getBody();
OpBuilder b = OpBuilder::atBlockEnd(body);
- Value product = b.create<MulOp>(loc, b.getType<SizeType>(),
- body->getArgument(1), body->getArgument(2));
+ Value product = b.create<MulOp>(loc, valueType, body->getArgument(1),
+ body->getArgument(2));
b.create<YieldOp>(loc, product);
rewriter.replaceOp(op, reduce.result());
populateShapeRewritePatterns(&ctx, patterns);
ConversionTarget target(getContext());
- target.addLegalDialect<ShapeDialect>();
+ target.addLegalDialect<ShapeDialect, StandardOpsDialect>();
target.addIllegalOp<NumElementsOp>();
if (failed(mlir::applyPartialConversion(getFunction(), target, patterns)))
signalPassFailure();
// CHECK: }
// CHECK: return [[NUM_ELEMENTS]] : !shape.size
+// -----
+
+// CHECK-LABEL: func @num_elements_to_reduce_on_index
+// CHECK-SAME: ([[ARG:%.*]]: tensor<?xindex>) -> index
+func @num_elements_to_reduce_on_index(%shape : tensor<?xindex>) -> index {
+ %num_elements = shape.num_elements %shape : tensor<?xindex> -> index
+ return %num_elements : index
+}
+// CHECK: [[C1:%.*]] = constant 1 : index
+// CHECK: [[NUM_ELEMENTS:%.*]] = shape.reduce([[ARG]], [[C1]]) : tensor<?xindex> -> index
+// CHECK: ^bb0({{.*}}: index, [[DIM:%.*]]: index, [[ACC:%.*]]: index
+// CHECK: [[NEW_ACC:%.*]] = shape.mul [[DIM]], [[ACC]]
+// CHECK: shape.yield [[NEW_ACC]] : index
+// CHECK: }
+// CHECK: return [[NUM_ELEMENTS]] : index