From 783a351785c14b7c2eb9f65bd40d37be11cbf38b Mon Sep 17 00:00:00 2001 From: Frederik Gossen Date: Fri, 24 Jul 2020 13:24:23 +0000 Subject: [PATCH] [MLIR][Shape] Allow `shape.mul` to operate in indices Differential Revision: https://reviews.llvm.org/D84437 --- mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td | 23 ++++++++----- mlir/lib/Dialect/Shape/IR/Shape.cpp | 39 +++++++++++++++++----- .../Shape/Transforms/ShapeToShapeLowering.cpp | 4 +-- .../ShapeToStandard/shape-to-standard.mlir | 15 +++++++-- mlir/test/Dialect/Shape/invalid.mlir | 26 +++++++++++++++ mlir/test/Dialect/Shape/ops.mlir | 13 +++++--- 6 files changed, 94 insertions(+), 26 deletions(-) diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 32d6eba..425cf91 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -307,18 +307,25 @@ def Shape_JoinOp : Shape_Op<"join", [Commutative]> { let results = (outs Shape_ShapeOrSizeType:$result); } -def Shape_MulOp : Shape_Op<"mul", [Commutative, SameOperandsAndResultType]> { - let summary = "Multiplication of sizes"; +def Shape_MulOp : Shape_Op<"mul", [Commutative, NoSideEffect]> { + let summary = "Multiplication of sizes and indices"; let description = [{ - Multiplies two valid sizes as follows: - - lhs * rhs = unknown if either lhs or rhs unknown; - - lhs * rhs = (int)lhs * (int)rhs if both known; + Multiplies two sizes or indices. If either operand is an error it will be + propagated to the result. The operands can be of type `size` or `index`. If + at least one of the operands can hold an error, i.e. if it is of type `size`, + then also the result must be of type `size`. If error propagation is not + possible because both operands are of type `index` then the result must also + be of type `index`. }]; - let arguments = (ins Shape_SizeType:$lhs, Shape_SizeType:$rhs); - let results = (outs Shape_SizeType:$result); + let arguments = (ins Shape_SizeOrIndexType:$lhs, Shape_SizeOrIndexType:$rhs); + let results = (outs Shape_SizeOrIndexType:$result); - let assemblyFormat = "$lhs `,` $rhs attr-dict"; + let assemblyFormat = [{ + $lhs `,` $rhs `:` type($lhs) `,` type($rhs) `->` type($result) attr-dict + }]; + + let verifier = [{ return ::verify(*this); }]; } def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index 3bdc5cc..2f64130 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -28,6 +28,13 @@ static RankedTensorType getExtentTensorType(MLIRContext *ctx) { return RankedTensorType::get({ShapedType::kDynamicSize}, IndexType::get(ctx)); } +static bool isErrorPropagationPossible(ArrayRef operandTypes) { + for (Type ty : operandTypes) + if (ty.isa() || ty.isa() || ty.isa()) + return true; + return false; +} + ShapeDialect::ShapeDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< @@ -539,9 +546,7 @@ static LogicalResult verify(GetExtentOp op) { Type shapeTy = op.shape().getType(); Type dimTy = op.dim().getType(); Type extentTy = op.extent().getType(); - bool errorPropagationPossible = - shapeTy.isa() || dimTy.isa(); - if (errorPropagationPossible) { + if (isErrorPropagationPossible({shapeTy, dimTy})) { if (!extentTy.isa()) op.emitError() << "if at least one of the operands can hold error values then the " @@ -593,9 +598,8 @@ void GetExtentOp::build(OpBuilder &builder, OperationState &result, Value shape, //===----------------------------------------------------------------------===// static LogicalResult verify(shape::RankOp op) { - Type argTy = op.shape().getType(); - Type resultTy = op.rank().getType(); - if (argTy.isa() && !resultTy.isa()) + if (op.shape().getType().isa() && + !op.rank().getType().isa()) return op.emitOpError() << "if operand is of type `shape` then the result must be of type " "`size` to propagate potential errors"; @@ -673,6 +677,25 @@ OpFoldResult NumElementsOp::fold(ArrayRef operands) { } //===----------------------------------------------------------------------===// +// MulOp +//===----------------------------------------------------------------------===// + +static LogicalResult verify(MulOp op) { + Type resultTy = op.result().getType(); + if (isErrorPropagationPossible({op.lhs().getType(), op.rhs().getType()})) { + if (!resultTy.isa()) + return op.emitOpError() + << "if at least one of the operands can hold error values then " + "the result must be of type `size` to propagate them"; + } else { + if (resultTy.isa()) + return op.emitError() << "if none of the operands can hold error values " + "then the result must be of type `index`"; + } + return success(); +} + +//===----------------------------------------------------------------------===// // ShapeOfOp //===----------------------------------------------------------------------===// @@ -685,15 +708,13 @@ OpFoldResult ShapeOfOp::fold(ArrayRef) { } static LogicalResult verify(ShapeOfOp op) { - Type argTy = op.arg().getType(); Type resultTy = op.result().getType(); - if (argTy.isa()) { + if (isErrorPropagationPossible(op.arg().getType())) { if (!resultTy.isa()) return op.emitOpError() << "if operand is of type `value_shape` then the result must be " "of type `shape` to propagate potential error shapes"; } else { - assert(argTy.isa()); if (resultTy != getExtentTensorType(op.getContext())) return op.emitOpError() << "if operand is a shaped type then the result " "must be an extent tensor"; diff --git a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp index 467f3d3..bb2b03b 100644 --- a/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp +++ b/mlir/lib/Dialect/Shape/Transforms/ShapeToShapeLowering.cpp @@ -38,8 +38,8 @@ NumElementsOpConverter::matchAndRewrite(NumElementsOp op, // Generate reduce operator. Block *body = reduce.getBody(); OpBuilder b = OpBuilder::atBlockEnd(body); - Value product = - b.create(loc, body->getArgument(1), body->getArgument(2)); + Value product = b.create(loc, b.getType(), + body->getArgument(1), body->getArgument(2)); b.create(loc, product); rewriter.replaceOp(op, reduce.result()); diff --git a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir index 908acab..8236c6f 100644 --- a/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir +++ b/mlir/test/Conversion/ShapeToStandard/shape-to-standard.mlir @@ -24,10 +24,19 @@ func @shape_id(%shape : !shape.shape) -> !shape.shape { // CHECK-LABEL: @binary_ops // CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index) func @binary_ops(%lhs : !shape.size, %rhs : !shape.size) { + // CHECK: addi %[[LHS]], %[[RHS]] : index %sum = "shape.add"(%lhs, %rhs) : (!shape.size, !shape.size) -> !shape.size - // CHECK-NEXT: addi %[[LHS]], %[[RHS]] : index - %product = shape.mul %lhs, %rhs - // CHECK-NEXT: muli %[[LHS]], %[[RHS]] : index + return +} + +// ----- + +// Lower binary ops. +// CHECK-LABEL: @binary_ops +// CHECK-SAME: (%[[LHS:.*]]: index, %[[RHS:.*]]: index) +func @binary_ops(%lhs : index, %rhs : index) { + // CHECK: muli %[[LHS]], %[[RHS]] : index + %product = shape.mul %lhs, %rhs : index, index -> index return } diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir index d7e9e40..b4900e4 100644 --- a/mlir/test/Dialect/Shape/invalid.mlir +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -6,6 +6,7 @@ func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) { ^bb0(%index: index, %dim: !shape.size): shape.yield %dim : !shape.size } + return } // ----- @@ -18,6 +19,7 @@ func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) { : (!shape.size, !shape.size) -> !shape.size shape.yield %new_acc : !shape.size } + return } // ----- @@ -28,6 +30,7 @@ func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) { ^bb0(%index: index, %dim: f32, %lci: !shape.size): shape.yield } + return } // ----- @@ -38,6 +41,7 @@ func @reduce_op_arg1_wrong_type(%shape : tensor, %init : index) { ^bb0(%index: index, %dim: f32, %lci: index): shape.yield } + return } // ----- @@ -48,6 +52,7 @@ func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) { ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size): shape.yield } + return } // ----- @@ -58,6 +63,7 @@ func @yield_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) { ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size): shape.yield %dim, %dim : !shape.size, !shape.size } + return } // ----- @@ -69,6 +75,7 @@ func @yield_op_type_mismatch(%shape : !shape.shape, %init : !shape.size) { %c0 = constant 1 : index shape.yield %c0 : index } + return } // ----- @@ -85,6 +92,7 @@ func @shape_of(%value_arg : !shape.value_shape, %shaped_arg : tensor) { // expected-error@+1 {{if operand is of type `value_shape` then the result must be of type `shape` to propagate potential error shapes}} %0 = shape.shape_of %value_arg : !shape.value_shape -> tensor + return } // ----- @@ -93,6 +101,7 @@ func @shape_of(%value_arg : !shape.value_shape, %shaped_arg : tensor) { // expected-error@+1 {{if operand is a shaped type then the result must be an extent tensor}} %1 = shape.shape_of %shaped_arg : tensor -> !shape.shape + return } // ----- @@ -100,6 +109,7 @@ func @shape_of(%value_arg : !shape.value_shape, func @rank(%arg : !shape.shape) { // expected-error@+1 {{if operand is of type `shape` then the result must be of type `size` to propagate potential errors}} %0 = shape.rank %arg : !shape.shape -> index + return } // ----- @@ -120,3 +130,19 @@ func @get_extent_error_possible(%arg : tensor) -> index { return %result : index } +// ----- + +func @mul_error_free(%arg : index) -> !shape.size { + // expected-error@+1 {{if none of the operands can hold error values then the result must be of type `index`}} + %result = shape.mul %arg, %arg : index, index -> !shape.size + return %result : !shape.size +} + +// ----- + +func @mul_error_possible(%lhs : !shape.size, %rhs : index) -> index { + // expected-error@+1 {{if at least one of the operands can hold error values then the result must be of type `size` to propagate them}} + %result = shape.mul %lhs, %rhs : !shape.size, index -> index + return %result : index +} + diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir index b6b8392..3a0cb77 100644 --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -9,6 +9,7 @@ func @shape_num_elements(%shape : !shape.shape) -> !shape.size { %num_elements = shape.reduce(%shape, %init) : !shape.shape -> !shape.size { ^bb0(%index : index, %extent : !shape.size, %acc : !shape.size): %acc_next = shape.mul %acc, %extent + : !shape.size, !shape.size -> !shape.size shape.yield %acc_next : !shape.size } return %num_elements : !shape.size @@ -19,7 +20,7 @@ func @extent_tensor_num_elements(%shape : tensor) -> index { %init = constant 1 : index %num_elements = shape.reduce(%shape, %init) : tensor -> index { ^bb0(%index : index, %extent : index, %acc : index): - %acc_next = muli %acc, %extent : index + %acc_next = shape.mul %acc, %extent : index, index -> index shape.yield %acc_next : index } return %num_elements : index @@ -110,9 +111,13 @@ func @broadcastable_on_extent_tensors(%lhs : tensor, return } -func @test_mul(%lhs: !shape.size, %rhs: !shape.size) -> !shape.size { - %product = shape.mul %lhs, %rhs - return %product: !shape.size +func @mul(%size_arg : !shape.size, %index_arg : index) { + %size_prod = shape.mul %size_arg, %size_arg + : !shape.size, !shape.size -> !shape.size + %index_prod = shape.mul %index_arg, %index_arg : index, index -> index + %mixed_prod = shape.mul %size_arg, %index_arg + : !shape.size, index -> !shape.size + return } func @const_size() { -- 2.7.4