#define SHAPE_OPS
include "mlir/IR/OpBase.td"
+include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/SideEffects.td"
// TODO(jpienaar): Move to base.
let parser = [{ return ::parse$cppClass(parser, result); }];
}
-def Shape_CreateShapeOp : Shape_Op<"create_shape", []> {
- let summary = "Creates a shape descriptor from a tensor";
+def Shape_FromExtentTensorOp : Shape_Op<"from_extent_tensor", []> {
+ let summary = "Creates a shape from a tensor of extents";
let description = [{
- Creates a shape from a 1D integral tensor. The rank equals the number of
- elements in the tensor, and extent matches the values of the elements.
+ Creates a shape from a 1D integral tensor of extents. The rank of the
+ resulting shape equals the number of elements in the tensor, and the
+ extents match the values of the elements.
}];
let arguments = (ins I32Tensor:$input);
let results = (outs Shape_ShapeType:$result);
}
+def Shape_ToExtentTensorOp : Shape_Op<"to_tensor", []> {
+ let summary = "Creates a dimension tensor from a shape";
+ // TODO: Think more about the error situation. Perhaps factor out the
+ // error detection into a separate op so downstream consumers can control
+ // their error behavior. Then this op would assume that the input has
+ // been properly checked to not be an error (and could thus be a
+ // NoSideEffect op).
+ let description = [{
+ Converts a shape to a 1D integral tensor of extents. The number of elements
+ in the tensor equals the rank of the shape, and the elements equal the
+ extents of the shape.
+
+ If the shape represents an error, then this op currently aborts the program.
+ }];
+
+ let arguments = (ins Shape_ShapeType:$input);
+ let results = (outs I32Tensor:$result);
+}
+
def Shape_JoinOp : Shape_Op<"join", []> {
let summary = "Returns the least general shape.size of its operands";
let description = [{
let results = (outs Shape_ShapeOrSizeType:$output);
}
+def Shape_SplitAtOp : Shape_Op<"split_at",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Splits a shape at a given index.";
+ let description = [{
+ Splits a shape at a given dimension `index`, returning two shapes.
+ If `index` is negative, it is treated as indexing from the back of the
+ shape. This negative-handling behavior is important when handling unranked
+ shapes, where the positive index is not necessarily knowable due to a
+ dynamic number of leading dimensions.
+
+ Examples:
+ - split_at([4,5,6], index=0) -> [], [4,5,6]
+ - split_at([4,5,6], index=1) -> [4], [5,6]
+ - split_at([4,5,6], index=2) -> [4,5], [6]
+ - split_at([4,5,6], index=3) -> [4,5,6], []
+ - split_at([4,5,6], index=4) -> error
+ - split_at([4,5,6], index=-1) -> [4,5], [6]
+ - split_at([4,5,6], index=-2) -> [4], [5,6]
+ - split_at([4,5,6], index=-3) -> [], [4,5,6]
+ - split_at([4,5,6], index=-4) -> error
+
+ Requires:
+ - `index` is in the range [-rank(operand),rank(operand)]
+ }];
+
+ let arguments = (ins Shape_ShapeType:$operand, I32:$index);
+ let results = (outs Shape_ShapeType:$head, Shape_ShapeType:$tail);
+}
+
+def Shape_ConcatOp : Shape_Op<"concat",
+ [DeclareOpInterfaceMethods<InferTypeOpInterface>]> {
+ let summary = "Concatenates two shapes.";
+ let description = [{
+ Creates a shape whose dimensions consist of first the dimensions from `lhs`
+ followed by the dimensions of `rhs`.
+
+ Example:
+ concat([2,3], [4,5]) -> [2,3,4,5]
+ concat([], []) -> []
+ concat([], [4,5,6]) -> [4,5,6]
+ }];
+
+ let arguments = (ins Shape_ShapeType:$lhs, Shape_ShapeType:$rhs);
+ let results = (outs Shape_ShapeType:$result);
+}
+
#endif // SHAPE_OPS
static LogicalResult verify(ConstantOp &op) { return success(); }
+//===----------------------------------------------------------------------===//
+// SplitAtOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult SplitAtOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ auto shapeType = ShapeType::get(context);
+ inferredReturnTypes.push_back(shapeType);
+ inferredReturnTypes.push_back(shapeType);
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// ConcatOp
+//===----------------------------------------------------------------------===//
+
+LogicalResult ConcatOp::inferReturnTypes(
+ MLIRContext *context, Optional<Location> location, ValueRange operands,
+ ArrayRef<NamedAttribute> attributes, RegionRange regions,
+ SmallVectorImpl<Type> &inferredReturnTypes) {
+ auto shapeType = ShapeType::get(context);
+ inferredReturnTypes.push_back(shapeType);
+ return success();
+}
+
namespace mlir {
namespace shape {