[MLIR] Matrix: support matrix-vector multiplication
authorArjun P <arjunpitchanathan@gmail.com>
Wed, 2 Feb 2022 12:09:32 +0000 (17:39 +0530)
committerArjun P <arjunpitchanathan@gmail.com>
Wed, 2 Feb 2022 12:11:14 +0000 (17:41 +0530)
This just moves in the implementation from LinearTransform.

Reviewed By: Groverkss, bondhugula

Differential Revision: https://reviews.llvm.org/D118479

mlir/include/mlir/Analysis/Presburger/LinearTransform.h
mlir/include/mlir/Analysis/Presburger/Matrix.h
mlir/lib/Analysis/Presburger/LinearTransform.cpp
mlir/lib/Analysis/Presburger/Matrix.cpp

index 54d1217..107e57f 100644 (file)
@@ -39,12 +39,16 @@ public:
 
   // The given vector is interpreted as a row vector v. Post-multiply v with
   // this transform, say T, and return vT.
-  SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const;
+  SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
+    return matrix.preMultiplyWithRow(rowVec);
+  }
 
   // The given vector is interpreted as a column vector v. Pre-multiply v with
   // this transform, say T, and return Tv.
   SmallVector<int64_t, 8>
-  postMultiplyWithColumn(ArrayRef<int64_t> colVec) const;
+  postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
+    return matrix.postMultiplyWithColumn(colVec);
+  }
 
 private:
   Matrix matrix;
index 2f25eed..95d663c 100644 (file)
@@ -117,6 +117,15 @@ public:
   /// Negate the specified column.
   void negateColumn(unsigned column);
 
+  /// The given vector is interpreted as a row vector v. Post-multiply v with
+  /// this matrix, say M, and return vM.
+  SmallVector<int64_t, 8> preMultiplyWithRow(ArrayRef<int64_t> rowVec) const;
+
+  /// The given vector is interpreted as a column vector v. Pre-multiply v with
+  /// this matrix, say M, and return Mv.
+  SmallVector<int64_t, 8>
+  postMultiplyWithColumn(ArrayRef<int64_t> colVec) const;
+
   /// Resize the matrix to the specified dimensions. If a dimension is smaller,
   /// the values are truncated; if it is bigger, the new values are initialized
   /// to zero.
index 073b64c..ab34f53 100644 (file)
@@ -111,30 +111,6 @@ LinearTransform::makeTransformToColumnEchelon(Matrix m) {
   return {echelonCol, LinearTransform(std::move(resultMatrix))};
 }
 
-SmallVector<int64_t, 8>
-LinearTransform::preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
-  assert(rowVec.size() == matrix.getNumRows() &&
-         "row vector dimension should match transform output dimension");
-
-  SmallVector<int64_t, 8> result(matrix.getNumColumns(), 0);
-  for (unsigned col = 0, e = matrix.getNumColumns(); col < e; ++col)
-    for (unsigned i = 0, e = matrix.getNumRows(); i < e; ++i)
-      result[col] += rowVec[i] * matrix(i, col);
-  return result;
-}
-
-SmallVector<int64_t, 8>
-LinearTransform::postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
-  assert(matrix.getNumColumns() == colVec.size() &&
-         "column vector dimension should match transform input dimension");
-
-  SmallVector<int64_t, 8> result(matrix.getNumRows(), 0);
-  for (unsigned row = 0, e = matrix.getNumRows(); row < e; row++)
-    for (unsigned i = 0, e = matrix.getNumColumns(); i < e; i++)
-      result[row] += matrix(row, i) * colVec[i];
-  return result;
-}
-
 IntegerPolyhedron
 LinearTransform::applyTo(const IntegerPolyhedron &poly) const {
   IntegerPolyhedron result(poly.getNumIds());
index e167e91..cdad8ab 100644 (file)
@@ -203,6 +203,29 @@ void Matrix::negateColumn(unsigned column) {
     at(row, column) = -at(row, column);
 }
 
+SmallVector<int64_t, 8>
+Matrix::preMultiplyWithRow(ArrayRef<int64_t> rowVec) const {
+  assert(rowVec.size() == getNumRows() && "Invalid row vector dimension!");
+
+  SmallVector<int64_t, 8> result(getNumColumns(), 0);
+  for (unsigned col = 0, e = getNumColumns(); col < e; ++col)
+    for (unsigned i = 0, e = getNumRows(); i < e; ++i)
+      result[col] += rowVec[i] * at(i, col);
+  return result;
+}
+
+SmallVector<int64_t, 8>
+Matrix::postMultiplyWithColumn(ArrayRef<int64_t> colVec) const {
+  assert(getNumColumns() == colVec.size() &&
+         "Invalid column vector dimension!");
+
+  SmallVector<int64_t, 8> result(getNumRows(), 0);
+  for (unsigned row = 0, e = getNumRows(); row < e; row++)
+    for (unsigned i = 0, e = getNumColumns(); i < e; i++)
+      result[row] += at(row, i) * colVec[i];
+  return result;
+}
+
 void Matrix::print(raw_ostream &os) const {
   for (unsigned row = 0; row < nRows; ++row) {
     for (unsigned column = 0; column < nColumns; ++column)