[mlir][shape] refine shape.func and shape.with_shape
authorJacques Pienaar <jpienaar@google.com>
Mon, 22 Aug 2022 18:47:15 +0000 (11:47 -0700)
committerJacques Pienaar <jpienaar@google.com>
Mon, 22 Aug 2022 21:52:18 +0000 (14:52 -0700)
- shape.with_shape supports ExtentTensorType
- add helper to create shape.func

Reviewed By: jpienaar

Differential Revision: https://reviews.llvm.org/D131977

mlir/include/mlir/Dialect/Shape/IR/ShapeOps.td
mlir/lib/Dialect/Shape/IR/Shape.cpp
mlir/test/Dialect/Shape/ops.mlir

index 4570772..8503b9d 100644 (file)
@@ -736,7 +736,7 @@ def Shape_WithOp : Shape_Op<"with_shape", [NoSideEffect]> {
   }];
 
   let arguments = (ins AnyTypeOf<[AnyShaped, Shape_ValueShapeType]>:$operand,
-                       Shape_ShapeType:$shape);
+                       Shape_ShapeOrExtentTensorType:$shape);
   let results = (outs Shape_ValueShapeType:$result);
 
   let assemblyFormat = "operands attr-dict `:` type($operand) `,` type($shape)";
@@ -1110,7 +1110,20 @@ def Shape_FuncOp : Shape_Op<"func",
                        OptionalAttr<StrAttr>:$sym_visibility);
   let regions = (region AnyRegion:$body);
 
+  let builders = [OpBuilder<(ins
+    "StringRef":$name, "FunctionType":$type,
+    CArg<"ArrayRef<NamedAttribute>", "{}">:$attrs,
+    CArg<"ArrayRef<DictionaryAttr>", "{}">:$argAttrs)
+  >];
+
   let extraClassDeclaration = [{
+    static FuncOp create(Location location, StringRef name, FunctionType type,
+                         ArrayRef<NamedAttribute> attrs = {});
+    static FuncOp create(Location location, StringRef name, FunctionType type,
+                         Operation::dialect_attr_range attrs);
+    static FuncOp create(Location location, StringRef name, FunctionType type,
+                         ArrayRef<NamedAttribute> attrs,
+                         ArrayRef<DictionaryAttr> argAttrs);
     //===------------------------------------------------------------------===//
     // CallableOpInterface
     //===------------------------------------------------------------------===//
index 133e757..2b8cae8 100644 (file)
@@ -1267,6 +1267,43 @@ void FunctionLibraryOp::print(OpAsmPrinter &p) {
 // FuncOp
 //===----------------------------------------------------------------------===//
 
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      ArrayRef<NamedAttribute> attrs) {
+  OpBuilder builder(location->getContext());
+  OperationState state(location, getOperationName());
+  FuncOp::build(builder, state, name, type, attrs);
+  return cast<FuncOp>(Operation::create(state));
+}
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      Operation::dialect_attr_range attrs) {
+  SmallVector<NamedAttribute, 8> attrRef(attrs);
+  return create(location, name, type, llvm::makeArrayRef(attrRef));
+}
+FuncOp FuncOp::create(Location location, StringRef name, FunctionType type,
+                      ArrayRef<NamedAttribute> attrs,
+                      ArrayRef<DictionaryAttr> argAttrs) {
+  FuncOp func = create(location, name, type, attrs);
+  func.setAllArgAttrs(argAttrs);
+  return func;
+}
+
+void FuncOp::build(OpBuilder &builder, OperationState &state, StringRef name,
+                   FunctionType type, ArrayRef<NamedAttribute> attrs,
+                   ArrayRef<DictionaryAttr> argAttrs) {
+  state.addAttribute(FuncOp::getSymNameAttrName(state.name),
+                     builder.getStringAttr(name));
+  state.addAttribute(FuncOp::getFunctionTypeAttrName(state.name),
+                     TypeAttr::get(type));
+  state.attributes.append(attrs.begin(), attrs.end());
+  state.addRegion();
+
+  if (argAttrs.empty())
+    return;
+  assert(type.getNumInputs() == argAttrs.size());
+  function_interface_impl::addArgAndResultAttrs(builder, state, argAttrs,
+                                                /*resultAttrs=*/llvm::None);
+}
+
 ParseResult FuncOp::parse(OpAsmParser &parser, OperationState &result) {
   auto buildFuncType =
       [](Builder &builder, ArrayRef<Type> argTypes, ArrayRef<Type> results,
index 33da6b2..8a90ed8 100644 (file)
@@ -268,6 +268,12 @@ func.func @shape_with_shape(%a : !shape.value_shape, %b : !shape.value_shape) ->
   return %2 : !shape.shape
 }
 
+func.func @shape_with_shape_extent_tensor_type(%a : tensor<?x?x?xf32>, %b : !shape.value_shape) -> !shape.value_shape {
+  %0 = shape.shape_of %a : tensor<?x?x?xf32> -> tensor<3xindex>
+  %1 = shape.with_shape %b, %0 : !shape.value_shape, tensor<3xindex>
+  return %1 : !shape.value_shape
+}
+
 func.func @any_on_shape(%a : !shape.shape, %b : !shape.shape, %c : !shape.shape)
     -> !shape.shape {
   %result = shape.any %a, %b, %c