[mlir][python] Support more types in IntegerAttr.value
authorrkayaith <rkayaith@gmail.com>
Thu, 24 Feb 2022 09:21:40 +0000 (10:21 +0100)
committerAlex Zinenko <zinenko@google.com>
Thu, 24 Feb 2022 09:26:31 +0000 (10:26 +0100)
Previously only accessing values for `index` and signless int types
would work; signed and unsigned ints would hit an assert in
`IntegerAttr::getInt`. This exposes `IntegerAttr::get{S,U}Int` to the C
API and calls the appropriate function from the python bindings.

Reviewed By: ftynse

Differential Revision: https://reviews.llvm.org/D120194

mlir/include/mlir-c/BuiltinAttributes.h
mlir/lib/Bindings/Python/IRAttributes.cpp
mlir/lib/CAPI/IR/BuiltinAttributes.cpp
mlir/test/CAPI/ir.c
mlir/test/python/ir/attributes.py

index 973b7e9..bb4431f 100644 (file)
@@ -125,9 +125,17 @@ MLIR_CAPI_EXPORTED MlirAttribute mlirIntegerAttrGet(MlirType type,
                                                     int64_t value);
 
 /// Returns the value stored in the given integer attribute, assuming the value
-/// fits into a 64-bit integer.
+/// is of signless type and fits into a signed 64-bit integer.
 MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr);
 
+/// Returns the value stored in the given integer attribute, assuming the value
+/// is of signed type and fits into a signed 64-bit integer.
+MLIR_CAPI_EXPORTED int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr);
+
+/// Returns the value stored in the given integer attribute, assuming the value
+/// is of unsigned type and fits into an unsigned 64-bit integer.
+MLIR_CAPI_EXPORTED uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr);
+
 //===----------------------------------------------------------------------===//
 // Bool attribute.
 //===----------------------------------------------------------------------===//
index 5d87641..bef3b95 100644 (file)
@@ -258,8 +258,13 @@ public:
         "Gets an uniqued integer attribute associated to a type");
     c.def_property_readonly(
         "value",
-        [](PyIntegerAttribute &self) {
-          return mlirIntegerAttrGetValueInt(self);
+        [](PyIntegerAttribute &self) -> py::int_ {
+          MlirType type = mlirAttributeGetType(self);
+          if (mlirTypeIsAIndex(type) || mlirIntegerTypeIsSignless(type))
+            return mlirIntegerAttrGetValueInt(self);
+          if (mlirIntegerTypeIsSigned(type))
+            return mlirIntegerAttrGetValueSInt(self);
+          return mlirIntegerAttrGetValueUInt(self);
         },
         "Returns the value of the integer attribute");
   }
index 7b718da..9ea277b 100644 (file)
@@ -129,6 +129,14 @@ int64_t mlirIntegerAttrGetValueInt(MlirAttribute attr) {
   return unwrap(attr).cast<IntegerAttr>().getInt();
 }
 
+int64_t mlirIntegerAttrGetValueSInt(MlirAttribute attr) {
+  return unwrap(attr).cast<IntegerAttr>().getSInt();
+}
+
+uint64_t mlirIntegerAttrGetValueUInt(MlirAttribute attr) {
+  return unwrap(attr).cast<IntegerAttr>().getUInt();
+}
+
 //===----------------------------------------------------------------------===//
 // Bool attribute.
 //===----------------------------------------------------------------------===//
index 7ac7a19..c8d2739 100644 (file)
@@ -813,11 +813,21 @@ int printBuiltinAttributes(MlirContext ctx) {
   // CHECK: f64
 
   MlirAttribute integer = mlirIntegerAttrGet(mlirIntegerTypeGet(ctx, 32), 42);
+  MlirAttribute signedInteger =
+      mlirIntegerAttrGet(mlirIntegerTypeSignedGet(ctx, 8), -1);
+  MlirAttribute unsignedInteger =
+      mlirIntegerAttrGet(mlirIntegerTypeUnsignedGet(ctx, 8), 255);
   if (!mlirAttributeIsAInteger(integer) ||
-      mlirIntegerAttrGetValueInt(integer) != 42)
+      mlirIntegerAttrGetValueInt(integer) != 42 ||
+      mlirIntegerAttrGetValueSInt(signedInteger) != -1 ||
+      mlirIntegerAttrGetValueUInt(unsignedInteger) != 255)
     return 2;
   mlirAttributeDump(integer);
+  mlirAttributeDump(signedInteger);
+  mlirAttributeDump(unsignedInteger);
   // CHECK: 42 : i32
+  // CHECK: -1 : si8
+  // CHECK: 255 : ui8
 
   MlirAttribute boolean = mlirBoolAttrGet(ctx, 1);
   if (!mlirAttributeIsABool(boolean) || !mlirBoolAttrGetValue(boolean))
index 48f2d4b..53d246b 100644 (file)
@@ -189,11 +189,20 @@ def testFloatAttr():
 @run
 def testIntegerAttr():
   with Context() as ctx:
-    iattr = IntegerAttr(Attribute.parse("42"))
-    # CHECK: iattr value: 42
-    print("iattr value:", iattr.value)
-    # CHECK: iattr type: i64
-    print("iattr type:", iattr.type)
+    i_attr = IntegerAttr(Attribute.parse("42"))
+    # CHECK: i_attr value: 42
+    print("i_attr value:", i_attr.value)
+    # CHECK: i_attr type: i64
+    print("i_attr type:", i_attr.type)
+    si_attr = IntegerAttr(Attribute.parse("-1 : si8"))
+    # CHECK: si_attr value: -1
+    print("si_attr value:", si_attr.value)
+    ui_attr = IntegerAttr(Attribute.parse("255 : ui8"))
+    # CHECK: ui_attr value: 255
+    print("ui_attr value:", ui_attr.value)
+    idx_attr = IntegerAttr(Attribute.parse("-1 : index"))
+    # CHECK: idx_attr value: -1
+    print("idx_attr value:", idx_attr.value)
 
     # Test factory methods.
     # CHECK: default_get: 42 : i32