- Implement a first constant fold for shape.shape_of (more ops coming in subsequent patches)
- Implement the right builder interfaces for ShapeType and other types
- Splits shape.constant into shape.const_size and shape.const_shape which plays better with dyn_cast and building vs one polymorphic op.
Also, fix the RUN line in ops.mlir to properly verify round-tripping.
}];
let cppNamespace = "shape";
+
+ let hasConstantMaterializer = 1;
}
def Shape_ComponentType : DialectType<ShapeDialect,
- CPred<"$_self.isa<::mlir::shape::ComponentType>()">, "component type"> {
+ CPred<"$_self.isa<::mlir::shape::ComponentType>()">, "component type">,
+ BuildableType<"$_builder.getType<::mlir::shape::ComponentType>()"> {
let typeDescription = [{
`shape.element_type` represents the element type of the ShapedType. It may
be unknown, error or regular element type supported by ShapedType.
}
def Shape_ElementType : DialectType<ShapeDialect,
- CPred<"$_self.isa<::mlir::shape::ElementType>()">, "element type"> {
+ CPred<"$_self.isa<::mlir::shape::ElementType>()">, "element type">,
+ BuildableType<"$_builder.getType<::mlir::shape::ElementType>()"> {
let typeDescription = [{
`shape.element_type` represents the element type of the ShapedType. It may
be unknown, error or regular element type supported by ShapedType.
}
def Shape_ShapeType : DialectType<ShapeDialect,
- CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape"> {
+ CPred<"$_self.isa<::mlir::shape::ShapeType>()">, "shape">,
+ BuildableType<"$_builder.getType<::mlir::shape::ShapeType>()"> {
let typeDescription = [{
`shape.type` represents either an unranked shape, a ranked shape with
possibly unknown dimensions or an invalid shape. The rank is of type
}
def Shape_SizeType : DialectType<ShapeDialect,
- CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size"> {
+ CPred<"$_self.isa<::mlir::shape::SizeType>()">, "size">,
+ BuildableType<"$_builder.getType<::mlir::shape::SizeType>()"> {
let typeDescription = [{
`shape.size` represents a non-negative integer with support for being
unknown and invalid.
}
def Shape_ValueShapeType : DialectType<ShapeDialect,
- CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape"> {
+ CPred<"$_self.isa<::mlir::shape::ValueShapeType>()">, "value shape">,
+ BuildableType<"::mlir::shape::ValueShapeType::get($_builder.getContext())">
+{
let typeDescription = [{
`shape.value_shape` represents the value produced by an operation (this
corresponds to `Value` in the compiler) and a shape. Conceptually this is a
let results = (outs Shape_ShapeType:$result);
}
-def Shape_ConstantOp : Shape_Op<"constant", []> {
- let summary = "Creates a shape constant";
+def Shape_ConstShapeOp : Shape_Op<"const_shape",
+ [ConstantLike,
+ NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Creates a constant of !shape.shape type.";
let description = [{
- An operation that builds a size or shape from integer or array attribute.
- It allows for creating dynamically valued shapes by using `?` for unknown
- values. A constant shape specified with `*` will return an unranked shape.
+ Creates a !shape.shape with rank given by the length of `shape` and with
+ dimension sizes given by the values of `shape`.
```mlir
- %x = shape.constant 10 : !shape.size
+ %0 = shape.const_shape []
+ %1 = shape.const_shape [1, 2, 3]
```
}];
-
- // TODO(jpienaar): Change to a more specialized attribute that would
- // encapsulate the unknown parsing while using denser packing.
- let arguments = (ins AnyAttr:$value);
- let results = (outs Shape_ShapeOrSizeType:$result);
+ let arguments = (ins I64ElementsAttr:$shape);
+ let results = (outs Shape_ShapeType:$result);
// TODO: Move this to main so that all shape ops implement these.
let printer = [{ return ::print(p, *this); }];
- let verifier = [{ return ::verify(*this); }];
let parser = [{ return ::parse$cppClass(parser, result); }];
+ let hasFolder = 1;
+}
+
+def Shape_ConstSizeOp : Shape_Op<"const_size",
+ [ConstantLike,
+ NoSideEffect,
+ DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Creates a constant of !shape.size type.";
+ let description = [{
+ Creates a !shape.size type representing the constant size given by `value`.
+
+ ```mlir
+ %x = shape.const_size 10
+ ```
+ }];
+
+ let arguments = (ins IndexAttr:$value);
+ let results = (outs Shape_SizeType:$result);
+
+ let assemblyFormat = "attr-dict $value";
}
def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$arg);
let results = (outs Shape_ShapeType:$result);
+
+ let hasFolder = 1;
}
def Shape_YieldOp : Shape_Op<"yield", [NoSideEffect, Terminator]> {
#include "mlir/IR/Builders.h"
#include "mlir/IR/DialectImplementation.h"
+#include "mlir/IR/PatternMatch.h"
#include "mlir/IR/StandardTypes.h"
#include "llvm/Support/raw_ostream.h"
allowUnknownOperations();
}
+Operation *ShapeDialect::materializeConstant(OpBuilder &builder,
+ Attribute value, Type type,
+ Location loc) {
+ if (auto shapeType = type.dyn_cast<ShapeType>()) {
+ return builder.create<ConstShapeOp>(loc, type,
+ value.cast<DenseIntElementsAttr>());
+ }
+ if (auto sizeType = type.dyn_cast<SizeType>()) {
+ return builder.create<ConstSizeOp>(loc, type, value.cast<IntegerAttr>());
+ }
+ return nullptr;
+}
+
/// Parse a type registered to this dialect.
Type ShapeDialect::parseType(DialectAsmParser &parser) const {
StringRef keyword;
}
//===----------------------------------------------------------------------===//
-// Constant*Op
+// ConstShapeOp
//===----------------------------------------------------------------------===//
-static void print(OpAsmPrinter &p, ConstantOp &op) {
- p << "shape.constant ";
- p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"value"});
-
- if (op.getAttrs().size() > 1)
- p << ' ';
- p.printAttributeWithoutType(op.value());
- p << " : " << op.getType();
+static void print(OpAsmPrinter &p, ConstShapeOp &op) {
+ p << "shape.const_shape ";
+ p.printOptionalAttrDict(op.getAttrs(), /*elidedAttrs=*/{"shape"});
+ p << "[";
+ interleaveComma(op.shape().getValues<int64_t>(), p,
+ [&](int64_t i) { p << i; });
+ p << "]";
}
-static ParseResult parseConstantOp(OpAsmParser &parser,
- OperationState &result) {
- Attribute valueAttr;
+static ParseResult parseConstShapeOp(OpAsmParser &parser,
+ OperationState &result) {
if (parser.parseOptionalAttrDict(result.attributes))
return failure();
- Type i64Type = parser.getBuilder().getIntegerType(64);
- if (parser.parseAttribute(valueAttr, i64Type, "value", result.attributes))
+ // We piggy-back on ArrayAttr parsing, though we don't internally store the
+ // shape as an ArrayAttr.
+ // TODO: Implement custom parser and maybe make syntax a bit more concise.
+ Attribute extentsRaw;
+ SmallVector<NamedAttribute, 6> dummy;
+ if (parser.parseAttribute(extentsRaw, "dummy", dummy))
return failure();
-
- Type type;
- if (parser.parseColonType(type))
+ auto extentsArray = extentsRaw.dyn_cast<ArrayAttr>();
+ if (!extentsArray)
return failure();
+ SmallVector<int64_t, 6> ints;
+ for (Attribute extent : extentsArray) {
+ IntegerAttr attr = extent.dyn_cast<IntegerAttr>();
+ if (!attr)
+ return failure();
+ ints.push_back(attr.getInt());
+ }
+ Builder &builder = parser.getBuilder();
+ result.addAttribute("shape", builder.getI64TensorAttr(ints));
+
+ result.types.push_back(ShapeType::get(builder.getContext()));
+ return success();
+}
+
+OpFoldResult ConstShapeOp::fold(ArrayRef<Attribute>) { return shape(); }
+
+LogicalResult ConstShapeOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.push_back(ShapeType::get(context));
+ return success();
+}
- // Add the attribute type to the list.
- return parser.addTypeToList(type, result.types);
+//===----------------------------------------------------------------------===//
+// ConstSizeOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConstSizeOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ inferredReturnTypes.push_back(SizeType::get(context));
+ return success();
}
-static LogicalResult verify(ConstantOp &op) { return success(); }
+//===----------------------------------------------------------------------===//
+// ShapeOfOp
+//===----------------------------------------------------------------------===//
+
+OpFoldResult ShapeOfOp::fold(ArrayRef<Attribute>) {
+ auto type = getOperand().getType().dyn_cast<ShapedType>();
+ if (!type || !type.hasStaticShape())
+ return nullptr;
+ Builder builder(getContext());
+ return builder.getI64TensorAttr(type.getShape());
+}
//===----------------------------------------------------------------------===//
// SplitAtOp
--- /dev/null
+// RUN: mlir-opt -canonicalize <%s | FileCheck %s --dump-input=fail
+
+// -----
+// CHECK-LABEL: func @f
+func @f(%arg0: tensor<2x3x4xf32>) -> !shape.shape {
+ // CHECK: shape.const_shape [2, 3, 4]
+ %0 = "shape.shape_of"(%arg0) : (tensor<2x3x4xf32>) -> !shape.shape
+ return %0 : !shape.shape
+}
-// RUN: mlir-opt -split-input-file %s | FileCheck %s --dump-input-on-failure
+// RUN: mlir-opt -split-input-file %s | mlir-opt | FileCheck %s --dump-input-on-failure
// CHECK-LABEL: shape_num_elements
func @shape_num_elements(%shape : !shape.shape) -> !shape.size {
- %0 = shape.constant 0 : !shape.size
+ %0 = shape.const_size 0
%1 = "shape.reduce"(%shape, %0) ( {
^bb0(%index: i32, %dim: !shape.size, %lci: !shape.size):
%acc = "shape.add"(%lci, %dim) : (!shape.size, !shape.size) -> !shape.size
}
func @test_shape_num_elements_fixed() {
- %0 = "shape.constant"() { value = [1, 57, 92] }: () -> !shape.shape
+ %0 = shape.const_shape [1, 57, 92]
%1 = call @shape_num_elements(%0) : (!shape.shape) -> (!shape.size)
%3 = "shape.print"(%1) : (!shape.size) -> !shape.size
return
}
func @test_broadcastable_fixed() {
- %0 = "shape.constant"() { value = [10, 1, 57, 92] }: () -> !shape.shape
- %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+ %0 = shape.const_shape [10, 1, 57, 92]
+ %1 = shape.const_shape [4, 57, 92]
%2 = "shape.broadcastable"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_fixed() {
- %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
- %1 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
+ %0 = shape.const_shape [4, 57, 92]
+ %1 = shape.const_shape [4, 57, 92]
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_unknown() {
- %0 = "shape.constant"() { value = [4, -1, 92] }: () -> !shape.shape
- %1 = "shape.constant"() { value = [-1, 57, 92] }: () -> !shape.shape
+ %0 = shape.const_shape [4, -1, 92]
+ %1 = shape.const_shape [-1, 57, 92]
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
func @test_shape_any_fixed_mismatch() {
- %0 = "shape.constant"() { value = [4, 57, 92] }: () -> !shape.shape
- %1 = "shape.constant"() { value = [2, 57, 92] }: () -> !shape.shape
+ %0 = shape.const_shape [4, 57, 92]
+ %1 = shape.const_shape [2, 57, 92]
%2 = "shape.join"(%0, %1) : (!shape.shape, !shape.shape) -> !shape.shape
%3 = "shape.print"(%2) : (!shape.shape) -> !shape.shape
return
}
+
+func @test_parse_const_shape() {
+ %0 = shape.const_shape []
+ %1 = shape.const_shape [1, 2, 3]
+ return
+}