[Arith] ExtendedEuclidean merge impl to int_operator (#5625)
authorANSHUMAN TRIPATHY <anshuman.t@huawei.com>
Mon, 1 Jun 2020 15:38:26 +0000 (21:08 +0530)
committerGitHub <noreply@github.com>
Mon, 1 Jun 2020 15:38:26 +0000 (08:38 -0700)
include/tvm/arith/util.h [deleted file]
src/arith/int_operator.h
src/arith/modular_set.cc
src/arith/solve_linear_equation.cc
src/arith/util.cc [deleted file]

diff --git a/include/tvm/arith/util.h b/include/tvm/arith/util.h
deleted file mode 100644 (file)
index adfcefc..0000000
+++ /dev/null
@@ -1,45 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/arith/util.h
- * \brief Utils for arithmetic analysis.
- */
-#ifndef TVM_ARITH_UTIL_H_
-#define TVM_ARITH_UTIL_H_
-
-#include <cstdint>
-#include <tuple>
-
-namespace tvm {
-/*! \brief namespace of arithmetic analysis. */
-namespace arith {
-
-/*!
- * \brief Calculate the extended greatest common divisor for two values.
- *        See https://en.wikipedia.org/wiki/Extended_Euclidean_algorithm.
- * \param a an integer number
- * \param b an integer number
- * \return 3 integers (div, m, n) where div = gcd(a, b) and a*m + b*n = div
- */
-std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b);
-
-}  // namespace arith
-}  // namespace tvm
-#endif  // TVM_ARITH_UTIL_H_
index 8e4dda0..b69ce4f 100644 (file)
@@ -25,6 +25,7 @@
 #define TVM_ARITH_INT_OPERATOR_H_
 
 #include <limits>
+#include <utility>
 
 namespace tvm {
 namespace arith {
@@ -117,6 +118,70 @@ inline int64_t floormod(int64_t x, int64_t y) {
   return is_floor_div ? rmod : rmod + y;
 }
 
+/*!
+ * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
+ * \param a The first coefficient.
+ * \param b The second coefficient.
+ * \param x The solution of x.
+ * \param y The solution of y.
+ * \return The GCD of a and b.
+ */
+inline int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
+  // Extended Euclidean algorithm
+  // if a < 0, the problem can be convert into
+  // |a|* (-x) + b * y = gcd(|a|, b)
+  //
+  // initial condition:
+  // a * 0 + b * 1 = b
+  // a * 1 + b * 0 = a
+  int64_t s = 0, old_s = 1;
+  int64_t r = b, old_r = a >= 0 ? a : -a;
+  // Iteration (r2 < r1):
+  // a * x1 + b * y1 = r1
+  // a * x2 + b * y2 = r2
+  // The above two eqs can derive the following eq (q = r1 / r2)
+  // a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3
+  // Because r3 < r2, the iteration can eventually terminate
+  while (r != 0) {
+    int64_t q = old_r / r;
+    int64_t tmp = old_r;
+    old_r = r;
+    r = tmp - q * r;
+    tmp = old_s;
+    old_s = s;
+    s = tmp - q * s;
+  }
+
+  *x = a >= 0 ? old_s : -old_s;
+  if (b != 0) {
+    *y = (old_r - (*x) * a) / b;
+  } else {
+    *y = 1;
+  }
+
+  return old_r;
+}
+
+/*!
+ * \brief Take GCD of a and b.
+ * \param a The first operand.
+ * \param b The second operand.
+ * \return The result.
+ */
+inline int64_t ZeroAwareGCD(int64_t a, int64_t b) {
+  if (a < 0) a = -a;
+  if (b < 0) b = -b;
+  if (a < b) std::swap(a, b);
+  if (b == 0) return a;
+  // perform GCD (greatest common divisor)
+  // ax + by = gcd(a, b) z if a != 0, b != 0
+  while (a % b != 0) {
+    a = a % b;
+    std::swap(a, b);
+  }
+  return b;
+}
+
 }  // namespace arith
 }  // namespace tvm
 #endif  // TVM_ARITH_INT_OPERATOR_H_
index 7ddb8f5..2645fe9 100644 (file)
@@ -270,49 +270,7 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
       return Entry(ZeroAwareGCD(ZeroAwareGCD(base0, base1), coeff), base0);
     }
   }
-  /*!
-   * \brief Use Extended Euclidean algorithm to solve ax + by = gcd(a, b)
-   * \param a The first coefficient.
-   * \param b The second coefficient.
-   * \param x The solution of x.
-   * \param y The solution of y.
-   * \return The GCD of a and b.
-   */
-  static int64_t ExtendedEuclidean(int64_t a, int64_t b, int64_t* x, int64_t* y) {
-    // Extended Euclidean algorithm
-    // if a < 0, the problem can be convert into
-    // |a|* (-x) + b * y = gcd(|a|, b)
-    //
-    // initial condition:
-    // a * 0 + b * 1 = b
-    // a * 1 + b * 0 = a
-    int64_t s = 0, old_s = 1;
-    int64_t r = b, old_r = a >= 0 ? a : -a;
-    // Iteration (r2 < r1):
-    // a * x1 + b * y1 = r1
-    // a * x2 + b * y2 = r2
-    // The above two eqs can derive the following eq (q = r1 / r2)
-    // a * (x1 - x2 * q) + b * (y1 - y2 * q) = r1 - r2 * q = r3
-    // Because r3 < r2, the iteration can eventually terminate
-    while (r != 0) {
-      int64_t q = old_r / r;
-      int64_t tmp = old_r;
-      old_r = r;
-      r = tmp - q * r;
-      tmp = old_s;
-      old_s = s;
-      s = tmp - q * s;
-    }
-
-    *x = a >= 0 ? old_s : -old_s;
-    if (b != 0) {
-      *y = (old_r - (*x) * a) / b;
-    } else {
-      *y = 1;
-    }
 
-    return old_r;
-  }
   /*!
    * \brief Create interect of two sets.
    * \param a The left operand.
@@ -340,25 +298,6 @@ class ModularSetAnalyzer::Impl : public ExprFunctor<ModularSetAnalyzer::Entry(co
     }
   }
   /*!
-   * \brief Take GCD of a and b.
-   * \param a The first operand.
-   * \param b The second operand.
-   * \return The result.
-   */
-  static int64_t ZeroAwareGCD(int64_t a, int64_t b) {
-    if (a < 0) a = -a;
-    if (b < 0) b = -b;
-    if (a < b) std::swap(a, b);
-    if (b == 0) return a;
-    // perform GCD (greatest common divisor)
-    // ax + by = gcd(a, b) z if a != 0, b != 0
-    while (a % b != 0) {
-      a = a % b;
-      std::swap(a, b);
-    }
-    return b;
-  }
-  /*!
    * \brief return everything dtype can represent.
    * \return Bound that represent everything dtype can represent.
    */
index 50a3243..5bf0e0e 100644 (file)
 #include <tvm/arith/analyzer.h>
 #include <tvm/arith/int_solver.h>
 #include <tvm/arith/pattern.h>
-#include <tvm/arith/util.h>
 #include <tvm/runtime/data_type.h>
 #include <tvm/runtime/registry.h>
 #include <tvm/tir/expr.h>
 #include <tvm/tir/op.h>
 #include <tvm/tir/stmt_functor.h>
 
+#include "int_operator.h"
+
 namespace tvm {
 namespace arith {
 
@@ -96,7 +97,7 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::
         int64_t g, a, b;
         // g = a*matrix[index][index] + b*matrix[i][index]
         if ((*S)[i][index] % (*S)[index][index] != 0) {
-          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[i][index]);
+          g = ExtendedEuclidean((*S)[index][index], (*S)[i][index], &a, &b);
         } else {
           // Explicitly avoid changing the index-th row. This is important to avoid infinite loop.
           g = (*S)[index][index];
@@ -149,7 +150,7 @@ void SmithNormalFormDiag(std::vector<std::vector<int64_t>>* S, std::vector<std::
         int64_t g, a, b;
         // g = a*matrix[index][index] + b*matrix[index][j]
         if ((*S)[index][j] % (*S)[index][index] != 0) {
-          std::tie(g, a, b) = xgcd((*S)[index][index], (*S)[index][j]);
+          g = ExtendedEuclidean((*S)[index][index], (*S)[index][j], &a, &b);
           // During this phase we may disrupt the zeroness of the index-th column, so we will
           // have to take some action if this might have happened.
           changed = true;
diff --git a/src/arith/util.cc b/src/arith/util.cc
deleted file mode 100644 (file)
index 7b71892..0000000
+++ /dev/null
@@ -1,53 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file util.cc
- * \brief The utils for arithmetic analysis.
- */
-#include <dmlc/logging.h>
-#include <tvm/arith/util.h>
-
-namespace tvm {
-namespace arith {
-
-std::tuple<int64_t, int64_t, int64_t> xgcd(int64_t a, int64_t b) {
-  int64_t s = 0, old_s = 1;
-  int64_t t = 1, old_t = 0;
-  int64_t r = b, old_r = a;
-
-  while (r != 0) {
-    int64_t q = old_r / r;
-    std::swap(r, old_r);
-    r -= q * old_r;
-    std::swap(s, old_s);
-    s -= q * old_s;
-    std::swap(t, old_t);
-    t -= q * old_t;
-  }
-
-  CHECK_EQ(a % old_r, 0);
-  CHECK_EQ(b % old_r, 0);
-  CHECK(old_r == old_s * a + old_t * b);
-
-  return std::make_tuple(old_r, old_s, old_t);
-}
-
-}  // namespace arith
-}  // namespace tvm