Generalize I32ElementsAttr definition and introduce I64ElementsAttr
authorSmit Hinsu <hinsu@google.com>
Thu, 5 Sep 2019 06:15:33 +0000 (23:15 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Thu, 5 Sep 2019 06:16:01 +0000 (23:16 -0700)
Also, fix constBuilderCall to return attribute of the storage class DenseIntElementsAttr

PiperOrigin-RevId: 267305813

mlir/include/mlir/IR/OpBase.td
mlir/test/lib/TestDialect/TestOps.td
mlir/test/mlir-tblgen/pattern.mlir

index d7b475a2d457a976aa0203e9155de4000c576865..e4acb038941cecd53ffb709f72968bb746a9ad3c 100644 (file)
@@ -866,9 +866,28 @@ class ElementsAttrBase<Pred condition, string description> :
   let convertFromStorage = "$_self";
 }
 
-def ElementsAttr: ElementsAttrBase<CPred<"$_self.isa<ElementsAttr>()">,
+def ElementsAttr : ElementsAttrBase<CPred<"$_self.isa<ElementsAttr>()">,
                                    "constant vector/tensor attribute">;
 
+class IntElementsAttr<int width> : ElementsAttrBase<
+  CPred<"$_self.isa<DenseIntElementsAttr>() &&"
+      "$_self.cast<DenseIntElementsAttr>().getType()."
+      "getElementType().isInteger(" # width # ")">,
+  width # "-bit integer elements attribute"> {
+
+  let storageType = [{ DenseIntElementsAttr }];
+  let returnType = [{ DenseIntElementsAttr }];
+
+  // Note that this is only constructing scalar elements attribute.
+  let constBuilderCall = "DenseElementsAttr::get("
+    "$_builder.getTensorType({}, $_builder.getIntegerType(" # width # ")), "
+    "llvm::makeArrayRef($0)).cast<DenseIntElementsAttr>()";
+  let convertFromStorage = "$_self";
+}
+
+def I32ElementsAttr : IntElementsAttr<32>;
+def I64ElementsAttr : IntElementsAttr<64>;
+
 // Base class for array attributes.
 class ArrayAttrBase<Pred condition, string description> :
     Attr<condition, description> {
@@ -916,18 +935,6 @@ def TypeArrayAttr : TypedArrayAttrBase<TypeAttr, "type array attribute"> {
   let constBuilderCall = ?;
 }
 
-def I32ElementsAttr : Attr<
-  CPred<"$_self.isa<DenseIntElementsAttr>() &&"
-      "$_self.cast<DenseIntElementsAttr>().getType()."
-      "getElementType().isInteger(32)">,
-  "32-bit integer elements attribute"> {
-  let storageType = [{ DenseIntElementsAttr }];
-  let returnType = [{ DenseIntElementsAttr }];
-  let constBuilderCall = "$_builder.getDenseElementsAttr("
-    "$_builder.getTensorType({}, $_builder.getIntegerType(32)), "
-      "{$_builder.getI32IntegerAttr($0)})";
-  let convertFromStorage = "$_self";
-}
 // Attribute information for an Attribute field within a StructAttr.
 class StructFieldAttr<string thisName, Attr thisType> {
   // Name of this field in the StructAttr.
index ee7a3965bb9984accdc5e3003465a2c1f1d732f8..0b2703c3e061041b77aefe57139e0755c1326e87 100644 (file)
@@ -263,10 +263,16 @@ def SingleBlockImplicitTerminatorOp : TEST_Op<"SingleBlockImplicitTerminator",
   let regions = (region SizedRegion<1>:$region);
 }
 
-def I32ElementsAttributesOp : TEST_Op<"i32ElementsAttr"> {
+def I32ElementsAttrOp : TEST_Op<"i32ElementsAttr"> {
   let arguments = (ins I32ElementsAttr:$attr);
 }
 
+def IsNotScalar : Constraint<CPred<"$0.getType().getRank() != 0">>;
+
+def UpdateAttr : Pat<(I32ElementsAttrOp $attr),
+                     (I32ElementsAttrOp ConstantAttr<I32ElementsAttr, "0">),
+                     [(IsNotScalar $attr)]>;
+
 //===----------------------------------------------------------------------===//
 // Test Patterns
 //===----------------------------------------------------------------------===//
index 3b9e25e6b79c63b03f55f6451c0f7eecc2cddded..ee5acf92a461bce226274886d6d0c93ca794cd85 100644 (file)
@@ -152,6 +152,17 @@ func @verifyI64EnumAttr() -> i32 {
   return %0 : i32
 }
 
+//===----------------------------------------------------------------------===//
+// Test ElelementsAttr
+//===----------------------------------------------------------------------===//
+
+// CHECK-LABEL: rewrite_i32elementsattr
+func @rewrite_i32elementsattr() -> () {
+  // CHECK: attr = dense<0> : tensor<i32>
+  "test.i32ElementsAttr"() {attr = dense<[3, 5]>:tensor<2xi32>} : () -> ()
+  return
+}
+
 //===----------------------------------------------------------------------===//
 // Test Multi-result Ops
 //===----------------------------------------------------------------------===//