[mlir][sparse][pytaco] test cleanup
authorAart Bik <ajcbik@google.com>
Wed, 9 Feb 2022 22:23:22 +0000 (14:23 -0800)
committerAart Bik <ajcbik@google.com>
Thu, 10 Feb 2022 00:58:25 +0000 (16:58 -0800)
removed obsoleted TODO
removed strange Fp precision for coordinates
lined up meta data testing code for readability

Reviewed By: bixia

Differential Revision: https://reviews.llvm.org/D119377

mlir/test/Integration/Dialect/SparseTensor/taco/data/gold_C.tns
mlir/test/Integration/Dialect/SparseTensor/taco/test_SpMM.py
mlir/test/Integration/Dialect/SparseTensor/taco/tools/testing_utils.py

index 9f5aec5..61bec5d 100644 (file)
@@ -1,12 +1,12 @@
 # See http://frostt.io/tensors/file-formats.html for FROSTT (.tns) format
 2 9
 3 3
-1.0 1.0 100.0
-1.0 2.0 107.0
-1.0 3.0 114.0
-2.0 1.0 201.0
-2.0 2.0 216.0
-2.0 3.0 231.0
-3.0 1.0 318.0
-3.0 2.0 342.0
-3.0 3.0 366.0
+1 1 100
+1 2 107
+1 3 114
+2 1 201
+2 2 216
+2 3 231
+3 1 318
+3 2 342
+3 3 366
index 6092301..af1c6ba 100644 (file)
@@ -18,17 +18,13 @@ csr = pt.format([pt.dense, pt.compressed], [0, 1])
 # Read matrices A and B from file, infer size of output matrix C.
 A = pt.read(os.path.join(_SCRIPT_PATH, "data/A.mtx"), csr)
 B = pt.read(os.path.join(_SCRIPT_PATH, "data/B.mtx"), csr)
-C = pt.tensor((A.shape[0], B.shape[1]), csr)
+C = pt.tensor([A.shape[0], B.shape[1]], csr)
 
 # Define the kernel.
 i, j, k = pt.get_index_vars(3)
 C[i, j] = A[i, k] * B[k, j]
 
 # Force evaluation of the kernel by writing out C.
-#
-# TODO: use sparse_tensor.out for output, so that C.tns becomes
-#       a file in extended FROSTT format
-#
 with tempfile.TemporaryDirectory() as test_dir:
   golden_file = os.path.join(_SCRIPT_PATH, "data/gold_C.tns")
   out_file = os.path.join(test_dir, "C.tns")
index 4e277d6..1e13bad 100644 (file)
@@ -23,8 +23,8 @@ def compare_sparse_tns(expected: str, actual: str, rtol: float = 0.0001) -> bool
       _ = expected_f.readline()
 
       # Compare the two lines of meta data
-      if actual_f.readline() != expected_f.readline() or actual_f.readline(
-      ) != expected_f.readline():
+      if (actual_f.readline() != expected_f.readline() or
+          actual_f.readline() != expected_f.readline()):
         return FALSE
 
   actual_data = np.loadtxt(actual, np.float64, skiprows=3)