From: Alexander Belyaev Date: Fri, 5 Jun 2020 09:16:53 +0000 (+0200) Subject: [Mlir] Implement printer, parser, verifier and builder for shape.reduce. X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=04fb2b6123ee66e09b1956ff68b5436fe43cd3b4;p=platform%2Fupstream%2Fllvm.git [Mlir] Implement printer, parser, verifier and builder for shape.reduce. Differential Revision: https://reviews.llvm.org/D81186 --- diff --git a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td index 5fc2aa4..ac5bedf 100644 --- a/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td +++ b/mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td @@ -290,7 +290,8 @@ def Shape_NumElementsOp : Shape_Op<"num_elements", [NoSideEffect]> { let hasFolder = 1; } -def Shape_ReduceOp : Shape_Op<"reduce", []> { +def Shape_ReduceOp : Shape_Op<"reduce", + [SingleBlockImplicitTerminator<"YieldOp">]> { let summary = "Returns an expression reduced over a shape"; let description = [{ An operation that takes as input a shape, number of initial values and has a @@ -310,25 +311,32 @@ def Shape_ReduceOp : Shape_Op<"reduce", []> { number of elements ```mlir - func @shape_num_elements(%shape : !shape.shape) -> !shape.size { - %0 = "shape.constant_dim"() {value = 1 : i32} : () -> !shape.size - %1 = "shape.reduce"(%shape, %0) ( { - ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size): - %acc = "shape.mul"(%lci, %dim) : + func @reduce(%shape : !shape.shape, %init : !shape.size) -> !shape.size { + %num_elements = shape.reduce(%shape, %init) -> !shape.size { + ^bb0(%index: index, %dim: !shape.size, %acc: !shape.size): + %updated_acc = "shape.mul"(%acc, %dim) : (!shape.size, !shape.size) -> !shape.size - shape.yield %acc : !shape.size - }) : (!shape.shape, !shape.size) -> (!shape.size) - return %1 : !shape.size + shape.yield %updated_acc : !shape.size + } + return %num_elements : !shape.size } ``` If the shape is unranked, then the results of the op is also unranked. }]; - let arguments = (ins Shape_ShapeType:$shape, Variadic:$args); + let arguments = (ins Shape_ShapeType:$shape, Variadic:$initVals); let results = (outs Variadic:$result); - let regions = (region SizedRegion<1>:$body); + + let builders = [ + OpBuilder<"OpBuilder &builder, OperationState &result, " + "Value shape, ValueRange initVals">, + ]; + + let verifier = [{ return ::verify(*this); }]; + let printer = [{ return ::print(p, *this); }]; + let parser = [{ return ::parse$cppClass(parser, result); }]; } def Shape_ShapeOfOp : Shape_Op<"shape_of", [NoSideEffect]> { diff --git a/mlir/lib/Dialect/Shape/IR/Shape.cpp b/mlir/lib/Dialect/Shape/IR/Shape.cpp index ed89d5b..04b1a51 100644 --- a/mlir/lib/Dialect/Shape/IR/Shape.cpp +++ b/mlir/lib/Dialect/Shape/IR/Shape.cpp @@ -481,6 +481,89 @@ OpFoldResult ToExtentTensorOp::fold(ArrayRef operands) { return DenseIntElementsAttr::get(type, shape); } +//===----------------------------------------------------------------------===// +// ReduceOp +//===----------------------------------------------------------------------===// + +void ReduceOp::build(OpBuilder &builder, OperationState &result, Value shape, + ValueRange initVals) { + result.addOperands(shape); + result.addOperands(initVals); + + Region *bodyRegion = result.addRegion(); + bodyRegion->push_back(new Block); + Block &bodyBlock = bodyRegion->front(); + bodyBlock.addArgument(builder.getIndexType()); + bodyBlock.addArgument(SizeType::get(builder.getContext())); + + for (Type initValType : initVals.getTypes()) { + bodyBlock.addArgument(initValType); + result.addTypes(initValType); + } +} + +static LogicalResult verify(ReduceOp op) { + // Verify block arg types. + Block &block = op.body().front(); + + auto blockArgsCount = op.initVals().size() + 2; + if (block.getNumArguments() != blockArgsCount) + return op.emitOpError() << "ReduceOp body is expected to have " + << blockArgsCount << " arguments"; + + if (block.getArgument(0).getType() != IndexType::get(op.getContext())) + return op.emitOpError( + "argument 0 of ReduceOp body is expected to be of IndexType"); + + if (block.getArgument(1).getType() != SizeType::get(op.getContext())) + return op.emitOpError( + "argument 1 of ReduceOp body is expected to be of SizeType"); + + for (auto type : llvm::enumerate(op.initVals())) + if (block.getArgument(type.index() + 2).getType() != type.value().getType()) + return op.emitOpError() + << "type mismatch between argument " << type.index() + 2 + << " of ReduceOp body and initial value " << type.index(); + return success(); +} + +static ParseResult parseReduceOp(OpAsmParser &parser, OperationState &result) { + auto *ctx = parser.getBuilder().getContext(); + // Parse operands. + SmallVector operands; + if (parser.parseOperandList(operands, /*requiredOperandCount=*/-1, + OpAsmParser::Delimiter::Paren) || + parser.parseOptionalArrowTypeList(result.types)) + return failure(); + + // Resolve operands. + auto initVals = llvm::makeArrayRef(operands).drop_front(); + if (parser.resolveOperand(operands.front(), ShapeType::get(ctx), + result.operands) || + parser.resolveOperands(initVals, result.types, parser.getNameLoc(), + result.operands)) + return failure(); + + // Parse the body. + Region *body = result.addRegion(); + if (parser.parseRegion(*body, /*args=*/{}, /*argTypes=*/{})) + return failure(); + + // Parse attributes. + if (parser.parseOptionalAttrDict(result.attributes)) + return failure(); + + return success(); +} + +static void print(OpAsmPrinter &p, ReduceOp op) { + p << op.getOperationName() << '(' << op.shape() << ", " << op.initVals() + << ") "; + p.printOptionalArrowTypeList(op.getResultTypes()); + p.printRegion(op.body()); + p.printOptionalAttrDict(op.getAttrs()); +} + namespace mlir { namespace shape { diff --git a/mlir/test/Dialect/Shape/invalid.mlir b/mlir/test/Dialect/Shape/invalid.mlir new file mode 100644 index 0000000..63589c8 --- /dev/null +++ b/mlir/test/Dialect/Shape/invalid.mlir @@ -0,0 +1,40 @@ +// RUN: mlir-opt %s -split-input-file -verify-diagnostics + +func @reduce_op_args_num_mismatch(%shape : !shape.shape, %init : !shape.size) { + // expected-error@+1 {{ReduceOp body is expected to have 3 arguments}} + %num_elements = shape.reduce(%shape, %init) -> !shape.size { + ^bb0(%index: index, %dim: !shape.size): + "shape.yield"(%dim) : (!shape.size) -> () + } +} + +// ----- + +func @reduce_op_arg0_wrong_type(%shape : !shape.shape, %init : !shape.size) { + // expected-error@+1 {{argument 0 of ReduceOp body is expected to be of IndexType}} + %num_elements = shape.reduce(%shape, %init) -> !shape.size { + ^bb0(%index: f32, %dim: !shape.size, %lci: !shape.size): + %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size + "shape.yield"(%acc) : (!shape.size) -> () + } +} + +// ----- + +func @reduce_op_arg1_wrong_type(%shape : !shape.shape, %init : !shape.size) { + // expected-error@+1 {{argument 1 of ReduceOp body is expected to be of SizeType}} + %num_elements = shape.reduce(%shape, %init) -> !shape.size { + ^bb0(%index: index, %dim: f32, %lci: !shape.size): + "shape.yield"() : () -> () + } +} + +// ----- + +func @reduce_op_init_type_mismatch(%shape : !shape.shape, %init : f32) { + // expected-error@+1 {{type mismatch between argument 2 of ReduceOp body and initial value 0}} + %num_elements = shape.reduce(%shape, %init) -> f32 { + ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size): + "shape.yield"() : () -> () + } +} diff --git a/mlir/test/Dialect/Shape/ops.mlir b/mlir/test/Dialect/Shape/ops.mlir index 5f316d9..0df58ed 100644 --- a/mlir/test/Dialect/Shape/ops.mlir +++ b/mlir/test/Dialect/Shape/ops.mlir @@ -6,13 +6,13 @@ // CHECK-LABEL: shape_num_elements func @shape_num_elements(%shape : !shape.shape) -> !shape.size { - %0 = shape.const_size 0 - %1 = "shape.reduce"(%shape, %0) ( { - ^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size): + %init = shape.const_size 0 + %num_elements = shape.reduce(%shape, %init) -> !shape.size { + ^bb0(%index: index, %dim: !shape.size, %lci: !shape.size): %acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size "shape.yield"(%acc) : (!shape.size) -> () - }) : (!shape.shape, !shape.size) -> (!shape.size) - return %1 : !shape.size + } + return %num_elements : !shape.size } func @test_shape_num_elements_unknown() {