[mlir][python] Adapt to `segment_sizes` attribute type change.
authorStella Laurenzo <stellaraccident@gmail.com>
Sat, 20 Mar 2021 01:44:51 +0000 (18:44 -0700)
committerStella Laurenzo <stellaraccident@gmail.com>
Sat, 20 Mar 2021 01:47:00 +0000 (18:47 -0700)
* Broken by https://reviews.llvm.org/rG1a75be0023cd80fd8560d689999a63d4368c90e6

mlir/lib/Bindings/Python/IRCore.cpp
mlir/test/Bindings/Python/ods_helpers.py

index 9d87aa5..0a4c5fc 100644 (file)
@@ -1034,8 +1034,8 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
   py::object operandSegmentSpecObj = cls.attr("_ODS_OPERAND_SEGMENTS");
   py::object resultSegmentSpecObj = cls.attr("_ODS_RESULT_SEGMENTS");
 
-  std::vector<uint64_t> operandSegmentLengths;
-  std::vector<uint64_t> resultSegmentLengths;
+  std::vector<uint32_t> operandSegmentLengths;
+  std::vector<uint32_t> resultSegmentLengths;
 
   // Validate/determine region count.
   auto opRegionSpec = py::cast<std::tuple<int, bool>>(cls.attr("_ODS_REGIONS"));
@@ -1247,8 +1247,8 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
     // Add result_segment_sizes attribute.
     if (!resultSegmentLengths.empty()) {
       int64_t size = resultSegmentLengths.size();
-      MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
-          mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
+      MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
+          mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
           resultSegmentLengths.size(), resultSegmentLengths.data());
       (*attributes)["result_segment_sizes"] =
           PyAttribute(context, segmentLengthAttr);
@@ -1257,8 +1257,8 @@ PyOpView::buildGeneric(py::object cls, py::list resultTypeList,
     // Add operand_segment_sizes attribute.
     if (!operandSegmentLengths.empty()) {
       int64_t size = operandSegmentLengths.size();
-      MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt64Get(
-          mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 64)),
+      MlirAttribute segmentLengthAttr = mlirDenseElementsAttrUInt32Get(
+          mlirVectorTypeGet(1, &size, mlirIntegerTypeGet(context->get(), 32)),
           operandSegmentLengths.size(), operandSegmentLengths.data());
       (*attributes)["operand_segment_sizes"] =
           PyAttribute(context, segmentLengthAttr);
index 54f68a8..badeac3 100644 (file)
@@ -125,8 +125,8 @@ def testOdsBuildDefaultSizedVariadic():
       # CHECK: %[[V2:.+]] = "custom.value"
       # CHECK: %[[V3:.+]] = "custom.value"
       # CHECK: "custom.test_op"(%[[V0]], %[[V1]], %[[V2]], %[[V3]])
-      # CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi64>
-      # CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi64>
+      # CHECK-SAME: operand_segment_sizes = dense<[1, 2, 1]> : vector<3xi32>
+      # CHECK-SAME: result_segment_sizes = dense<[2, 1, 1]> : vector<3xi32>
       # CHECK-SAME: : (i32, i32, i32, i32) -> (i8, i16, i32, i64)
       op = TestOp.build_generic(
           results=[[t0, t1], t2, t3],