From b4c06416df0e58e71c9c36a85f18699f0a98a8e7 Mon Sep 17 00:00:00 2001 From: Nicolas Vasilache Date: Thu, 9 May 2019 12:54:43 -0700 Subject: [PATCH] Move edsc python tests to Filecheck -- PiperOrigin-RevId: 247479507 --- mlir/bindings/python/test/test_py2and3.py | 511 ++++++++++++++++-------------- 1 file changed, 271 insertions(+), 240 deletions(-) diff --git a/mlir/bindings/python/test/test_py2and3.py b/mlir/bindings/python/test/test_py2and3.py index a004640..bc6582f 100644 --- a/mlir/bindings/python/test/test_py2and3.py +++ b/mlir/bindings/python/test/test_py2and3.py @@ -12,17 +12,21 @@ # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== -"""Python2 and 3 test for the MLIR EDSC Python bindings""" - -from __future__ import absolute_import -from __future__ import division -from __future__ import print_function -import unittest +# RUN: $(dirname %s)/test_edsc %s | FileCheck %s +"""Python2 and 3 test for the MLIR EDSC Python bindings""" import google_mlir.bindings.python.pybind as E +import inspect + +# Prints `str` prefixed by the current test function name so we can use it in +# Filecheck label directives. +# This is achieved by inspecting the stack and getting the parent name. +def printWithCurrentFunctionName(str): + print(inspect.stack()[1][3]) + print(str) -class EdscTest(unittest.TestCase): +class EdscTest: def setUp(self): self.module = E.MLIRModule() @@ -32,34 +36,36 @@ class EdscTest(unittest.TestCase): self.indexType = self.module.make_index_type() def testFunctionContext(self): + self.setUp() with self.module.function_context("foo", [], []): pass - self.assertIsNotNone(self.module.get_function("foo")) + printWithCurrentFunctionName(self.module.get_function("foo")) + # CHECK-LABEL: testFunctionContext + # CHECK: func @foo() { def testMultipleFunctions(self): + self.setUp() + with self.module.function_context("foo", [], []): + pass with self.module.function_context("foo", [], []): E.constant_index(0) - code = str(self.module) - self.assertIn("func @foo()", code) - self.assertIn(" %c0 = constant 0 : index", code) - - with self.module.function_context("bar", [], []): - E.constant_index(42) - code = str(self.module) - barPos = code.find("func @bar()") - c42Pos = code.find("%c42 = constant 42 : index") - self.assertNotEqual(barPos, -1) - self.assertNotEqual(c42Pos, -1) - self.assertGreater(c42Pos, barPos) + printWithCurrentFunctionName(str(self.module)) + # CHECK-LABEL: testMultipleFunctions + # CHECK: func @foo() + # CHECK: func @foo_0() + # CHECK: %c0 = constant 0 : index def testFunctionArgs(self): + self.setUp() with self.module.function_context("foo", [self.f32Type, self.f32Type], [self.indexType]) as fun: pass - code = str(fun) - self.assertIn("func @foo(%arg0: f32, %arg1: f32) -> index", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testFunctionArgs + # CHECK: func @foo(%arg0: f32, %arg1: f32) -> index def testLoopContext(self): + self.setUp() with self.module.function_context("foo", [], []) as fun: lhs = E.constant_index(0) rhs = E.constant_index(42) @@ -67,61 +73,49 @@ class EdscTest(unittest.TestCase): lhs + rhs + i with E.LoopContext(rhs, rhs + rhs, 2) as j: x = i + j - code = str(fun) - # TODO(zinenko,ntv): use FileCheck for these tests - self.assertIn(' "affine.for"() ( {\n', code) - self.assertIn( - "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}", - code) - self.assertIn(" ^bb1(%i0: index):", code) - self.assertIn(' "affine.for"(%c42, %2) ( {\n', code) - self.assertIn( - "{lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> ()", - code) - self.assertIn(" ^bb2(%i1: index):", code) - self.assertIn( - ' %3 = "affine.apply"(%i0, %i1) {map: (d0, d1) -> (d0 + d1)} : (index, index) -> index', - code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testLoopContext + # CHECK: "affine.for"() ( + # CHECK: ^bb1(%i0: index): + # CHECK: "affine.for"(%c42, %2) ( + # CHECK: ^bb2(%i1: index): + # CHECK: "affine.apply"(%i0, %i1) {map: (d0, d1) -> (d0 + d1)} : (index, index) -> index + # CHECK: {lower_bound: (d0) -> (d0), step: 2 : index, upper_bound: (d0) -> (d0)} : (index, index) -> () + # CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)} def testLoopNestContext(self): + self.setUp() with self.module.function_context("foo", [], []) as fun: lbs = [E.constant_index(i) for i in range(4)] ubs = [E.constant_index(10 * i + 5) for i in range(4)] with E.LoopNestContext(lbs, ubs, [1, 3, 5, 7]) as (i, j, k, l): i + j + k + l - - code = str(fun) - self.assertIn(' "affine.for"() ( {\n', code) - self.assertIn(" ^bb1(%i0: index):", code) - self.assertIn(' "affine.for"() ( {\n', code) - self.assertIn(" ^bb2(%i1: index):", code) - self.assertIn(' "affine.for"() ( {\n', code) - self.assertIn(" ^bb3(%i2: index):", code) - self.assertIn(' "affine.for"() ( {\n', code) - self.assertIn(" ^bb4(%i3: index):", code) - self.assertIn( - ' %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index', - code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testLoopNestContext + # CHECK: "affine.for"() ( + # CHECK: ^bb1(%i0: index): + # CHECK: "affine.for"() ( + # CHECK: ^bb2(%i1: index): + # CHECK: "affine.for"() ( + # CHECK: ^bb3(%i2: index): + # CHECK: "affine.for"() ( + # CHECK: ^bb4(%i3: index): + # CHECK: %2 = "affine.apply"(%i0, %i1, %i2, %i3) {map: (d0, d1, d2, d3) -> (d0 + d1 + d2 + d3)} : (index, index, index, index) -> index def testBlockContext(self): + self.setUp() with self.module.function_context("foo", [], []) as fun: cst = E.constant_index(42) with E.BlockContext(): cst + cst - code = str(fun) - # Find positions of instructions and make sure they are in the block we - # put them by comparing those positions. - # TODO(zinenko,ntv): this (and tests below) should use FileCheck instead. - c42pos = code.find("%c42 = constant 42 : index") - bb1pos = code.find("^bb1:") - c84pos = code.find('%0 = "affine.apply"() {map: () -> (84)} : () -> index') - self.assertNotEqual(c42pos, -1) - self.assertNotEqual(bb1pos, -1) - self.assertNotEqual(c84pos, -1) - self.assertGreater(bb1pos, c42pos) - self.assertLess(bb1pos, c84pos) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testBlockContext + # CHECK: %c42 = constant 42 : index + # CHECK: ^bb1: + # CHECK: %0 = "affine.apply"() {map: () -> (84)} : () -> index def testBlockContextAppend(self): + self.setUp() with self.module.function_context("foo", [], []) as fun: E.constant_index(41) with E.BlockContext() as b: @@ -130,25 +124,16 @@ class EdscTest(unittest.TestCase): E.constant_index(42) with E.BlockContext(E.appendTo(blk)): E.constant_index(1) - code = str(fun) - # Find positions of instructions and make sure they are in the block we put - # them by comparing those positions. - c41pos = code.find("%c41 = constant 41 : index") - c42pos = code.find("%c42 = constant 42 : index") - bb1pos = code.find("^bb1:") - c0pos = code.find("%c0 = constant 0 : index") - c1pos = code.find("%c1 = constant 1 : index") - self.assertNotEqual(c41pos, -1) - self.assertNotEqual(c42pos, -1) - self.assertNotEqual(bb1pos, -1) - self.assertNotEqual(c0pos, -1) - self.assertNotEqual(c1pos, -1) - self.assertGreater(bb1pos, c41pos) - self.assertGreater(bb1pos, c42pos) - self.assertLess(bb1pos, c0pos) - self.assertLess(bb1pos, c1pos) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testBlockContextAppend + # CHECK: %c41 = constant 41 : index + # CHECK: %c42 = constant 42 : index + # CHECK: ^bb1: + # CHECK: %c0 = constant 0 : index + # CHECK: %c1 = constant 1 : index def testBlockContextStandalone(self): + self.setUp() with self.module.function_context("foo", [], []) as fun: blk1 = E.BlockContext() blk2 = E.BlockContext() @@ -161,81 +146,72 @@ class EdscTest(unittest.TestCase): with blk1: E.constant_index(1) E.constant_index(42) - code = str(fun) - # Find positions of instructions and make sure they are in the block we put - # them by comparing those positions. - c41pos = code.find(" %c41 = constant 41 : index") - c42pos = code.find(" %c42 = constant 42 : index") - bb1pos = code.find("^bb1:") - c0pos = code.find(" %c0 = constant 0 : index") - c1pos = code.find(" %c1 = constant 1 : index") - bb2pos = code.find("^bb2:") - c56pos = code.find(" %c56 = constant 56 : index") - c57pos = code.find(" %c57 = constant 57 : index") - self.assertNotEqual(c41pos, -1) - self.assertNotEqual(c42pos, -1) - self.assertNotEqual(bb1pos, -1) - self.assertNotEqual(c0pos, -1) - self.assertNotEqual(c1pos, -1) - self.assertNotEqual(bb2pos, -1) - self.assertNotEqual(c56pos, -1) - self.assertNotEqual(c57pos, -1) - self.assertGreater(bb1pos, c41pos) - self.assertGreater(bb1pos, c42pos) - self.assertLess(bb1pos, c0pos) - self.assertLess(bb1pos, c1pos) - self.assertGreater(bb2pos, c0pos) - self.assertGreater(bb2pos, c1pos) - self.assertGreater(bb2pos, bb1pos) - self.assertLess(bb2pos, c56pos) - self.assertLess(bb2pos, c57pos) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testBlockContextStandalone + # CHECK: %c41 = constant 41 : index + # CHECK: %c42 = constant 42 : index + # CHECK: ^bb1: + # CHECK: %c0 = constant 0 : index + # CHECK: %c1 = constant 1 : index + # CHECK: ^bb2: + # CHECK: %c56 = constant 56 : index + # CHECK: %c57 = constant 57 : index def testBlockArguments(self): + self.setUp() with self.module.function_context("foo", [], []) as fun: E.constant_index(42) with E.BlockContext([self.f32Type, self.f32Type]) as b: b.arg(0) + b.arg(1) - code = str(fun) - self.assertIn("%c42 = constant 42 : index", code) - self.assertIn("^bb1(%0: f32, %1: f32):", code) - self.assertIn(" %2 = addf %0, %1 : f32", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testBlockArguments + # CHECK: %c42 = constant 42 : index + # CHECK: ^bb1(%0: f32, %1: f32): + # CHECK: %2 = addf %0, %1 : f32 def testBr(self): + self.setUp() with self.module.function_context("foo", [], []) as fun: with E.BlockContext() as b: blk = b E.ret() E.br(blk) - code = str(fun) - self.assertIn(" br ^bb1", code) - self.assertIn("^bb1:", code) - self.assertIn(" return", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testBr + # CHECK: br ^bb1 + # CHECK: ^bb1: + # CHECK: return def testBrDeclaration(self): + self.setUp() with self.module.function_context("foo", [], []) as fun: blk = E.BlockContext() E.br(blk.handle()) with blk: E.ret() - code = str(fun) - self.assertIn(" br ^bb1", code) - self.assertIn("^bb1:", code) - self.assertIn(" return", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testBrDeclaration + # CHECK: br ^bb1 + # CHECK: ^bb1: + # CHECK: return def testBrArgs(self): + self.setUp() with self.module.function_context("foo", [], []) as fun: # Create an infinite loop. with E.BlockContext([self.indexType, self.indexType]) as b: E.br(b, [b.arg(1), b.arg(0)]) E.br(b, [E.constant_index(0), E.constant_index(1)]) - code = str(fun) - self.assertIn(" %c0 = constant 0 : index", code) - self.assertIn(" %c1 = constant 1 : index", code) - self.assertIn(" br ^bb1(%c0, %c1 : index, index)", code) - self.assertIn("^bb1(%0: index, %1: index):", code) - self.assertIn(" br ^bb1(%1, %0 : index, index)", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testBrArgs + # CHECK: %c0 = constant 0 : index + # CHECK: %c1 = constant 1 : index + # CHECK: br ^bb1(%c0, %c1 : index, index) + # CHECK: ^bb1(%0: index, %1: index): + # CHECK: br ^bb1(%1, %0 : index, index) def testCondBr(self): + self.setUp() with self.module.function_context("foo", [self.boolType], []) as fun: with E.BlockContext() as blk1: E.ret([]) @@ -243,86 +219,95 @@ class EdscTest(unittest.TestCase): E.ret([]) cst = E.constant_index(0) E.cond_br(fun.arg(0), blk1, [], blk2, [cst]) - - code = str(fun) - self.assertIn("cond_br %arg0, ^bb1, ^bb2(%c0 : index)", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testCondBr + # CHECK: cond_br %arg0, ^bb1, ^bb2(%c0 : index) def testRet(self): + self.setUp() with self.module.function_context("foo", [], [self.indexType, self.indexType]) as fun: c42 = E.constant_index(42) c0 = E.constant_index(0) E.ret([c42, c0]) - code = str(fun) - self.assertIn(" %c42 = constant 42 : index", code) - self.assertIn(" %c0 = constant 0 : index", code) - self.assertIn(" return %c42, %c0 : index, index", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testRet + # CHECK: %c42 = constant 42 : index + # CHECK: %c0 = constant 0 : index + # CHECK: return %c42, %c0 : index, index def testSelectOp(self): + self.setUp() with self.module.function_context("foo", [self.boolType], [self.i32Type]) as fun: a = E.constant_int(42, 32) b = E.constant_int(0, 32) E.ret([E.select(fun.arg(0), a, b)]) - - code = str(fun) - self.assertIn("%0 = select %arg0, %c42_i32, %c0_i32 : i32", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testSelectOp + # CHECK: %0 = select %arg0, %c42_i32, %c0_i32 : i32 def testCallOp(self): + self.setUp() callee = self.module.declare_function("sqrtf", [self.f32Type], [self.f32Type]) with self.module.function_context("call", [self.f32Type], []) as fun: funCst = E.constant_function(callee) funCst([fun.arg(0)]) + E.constant_float(42., self.f32Type) - - code = str(self.module) - self.assertIn("func @sqrtf(f32) -> f32", code) - self.assertIn("%f = constant @sqrtf : (f32) -> f32", code) - self.assertIn("%0 = call_indirect %f(%arg0) : (f32) -> f32", code) + printWithCurrentFunctionName(str(self.module)) + # CHECK-LABEL: testCallOp + # CHECK: func @sqrtf(f32) -> f32 + # CHECK: %f = constant @sqrtf : (f32) -> f32 + # CHECK: %0 = call_indirect %f(%arg0) : (f32) -> f32 def testBooleanOps(self): + self.setUp() with self.module.function_context( "booleans", [self.boolType for _ in range(4)], []) as fun: i, j, k, l = (fun.arg(x) for x in range(4)) stmt1 = (i < j) & (j >= k) stmt2 = ~(stmt1 | (k == l)) - - code = str(fun) - self.assertIn('%0 = cmpi "slt", %arg0, %arg1 : i1', code) - self.assertIn('%1 = cmpi "sge", %arg1, %arg2 : i1', code) - self.assertIn("%2 = muli %0, %1 : i1", code) - self.assertIn('%3 = cmpi "eq", %arg2, %arg3 : i1', code) - self.assertIn("%true = constant 1 : i1", code) - self.assertIn("%4 = subi %true, %2 : i1", code) - self.assertIn("%true_0 = constant 1 : i1", code) - self.assertIn("%5 = subi %true_0, %3 : i1", code) - self.assertIn("%6 = muli %4, %5 : i1", code) - self.assertIn("%true_1 = constant 1 : i1", code) - self.assertIn("%7 = subi %true_1, %6 : i1", code) - self.assertIn("%true_2 = constant 1 : i1", code) - self.assertIn("%8 = subi %true_2, %7 : i1", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testBooleanOps + # CHECK: %0 = cmpi "slt", %arg0, %arg1 : i1 + # CHECK: %1 = cmpi "sge", %arg1, %arg2 : i1 + # CHECK: %2 = muli %0, %1 : i1 + # CHECK: %3 = cmpi "eq", %arg2, %arg3 : i1 + # CHECK: %true = constant 1 : i1 + # CHECK: %4 = subi %true, %2 : i1 + # CHECK: %true_0 = constant 1 : i1 + # CHECK: %5 = subi %true_0, %3 : i1 + # CHECK: %6 = muli %4, %5 : i1 + # CHECK: %true_1 = constant 1 : i1 + # CHECK: %7 = subi %true_1, %6 : i1 + # CHECK: %true_2 = constant 1 : i1 + # CHECK: %8 = subi %true_2, %7 : i1 def testDivisions(self): + self.setUp() with self.module.function_context( "division", [self.indexType, self.i32Type, self.i32Type], []) as fun: # indices only support floor division fun.arg(0) // E.constant_index(42) # regular values only support regular division fun.arg(1) / fun.arg(2) - - code = str(self.module) - self.assertIn("floordiv 42", code) - self.assertIn("divis %arg1, %arg2 : i32", code) + printWithCurrentFunctionName(str(self.module)) + # CHECK-LABEL: testDivisions + # CHECK: floordiv 42 + # CHECK: divis %arg1, %arg2 : i32 def testCustom(self): + self.setUp() with self.module.function_context("custom", [self.indexType, self.f32Type], []) as fun: E.op("foo", [fun.arg(0)], [self.f32Type]) + fun.arg(1) - code = str(fun) - self.assertIn('%0 = "foo"(%arg0) : (index) -> f32', code) - self.assertIn("%1 = addf %0, %arg1 : f32", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testCustom + # CHECK: %0 = "foo"(%arg0) : (index) -> f32 + # CHECK: %1 = addf %0, %arg1 : f32 def testConstants(self): + self.setUp() with self.module.function_context("constants", [], []) as fun: E.constant_float(1.23, self.module.make_scalar_type("bf16")) E.constant_float(1.23, self.module.make_scalar_type("f16")) @@ -335,20 +320,21 @@ class EdscTest(unittest.TestCase): E.constant_int(123, 64) E.constant_index(123) E.constant_function(fun) - - code = str(fun) - self.assertIn("constant 1.230000e+00 : bf16", code) - self.assertIn("constant 1.230470e+00 : f16", code) - self.assertIn("constant 1.230000e+00 : f32", code) - self.assertIn("constant 1.230000e+00 : f64", code) - self.assertIn("constant 1 : i1", code) - self.assertIn("constant 123 : i8", code) - self.assertIn("constant 123 : i16", code) - self.assertIn("constant 123 : i32", code) - self.assertIn("constant 123 : index", code) - self.assertIn("constant @constants : () -> ()", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testConstants + # CHECK: constant 1.230000e+00 : bf16 + # CHECK: constant 1.230470e+00 : f16 + # CHECK: constant 1.230000e+00 : f32 + # CHECK: constant 1.230000e+00 : f64 + # CHECK: constant 1 : i1 + # CHECK: constant 123 : i8 + # CHECK: constant 123 : i16 + # CHECK: constant 123 : i32 + # CHECK: constant 123 : index + # CHECK: constant @constants : () -> () def testIndexedValue(self): + self.setUp() memrefType = self.module.make_memref_type(self.f32Type, [10, 42]) with self.module.function_context("indexed", [memrefType], [memrefType]) as fun: @@ -359,21 +345,18 @@ class EdscTest(unittest.TestCase): [E.constant_index(10), E.constant_index(42)], [1, 1]) as (i, j): A.store([i, j], A.load([i, j]) + cst) E.ret([fun.arg(0)]) - - code = str(fun) - self.assertIn('"affine.for"()', code) - self.assertIn( - "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (10)}", - code) - self.assertIn('"affine.for"()', code) - self.assertIn( - "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)}", - code) - self.assertIn("%0 = load %arg0[%i0, %i1] : memref<10x42xf32>", code) - self.assertIn("%1 = addf %0, %cst : f32", code) - self.assertIn("store %1, %arg0[%i0, %i1] : memref<10x42xf32>", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testIndexedValue + # CHECK: "affine.for"() + # CHECK: "affine.for"() + # CHECK: %0 = load %arg0[%i0, %i1] : memref<10x42xf32> + # CHECK: %1 = addf %0, %cst : f32 + # CHECK: store %1, %arg0[%i0, %i1] : memref<10x42xf32> + # CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (42)} + # CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (10)} def testMatrixMultiply(self): + self.setUp() memrefType = self.module.make_memref_type(self.f32Type, [32, 32]) with self.module.function_context( "matmul", [memrefType, memrefType, memrefType], []) as fun: @@ -386,65 +369,70 @@ class EdscTest(unittest.TestCase): k): C.store([i, j], A.load([i, k]) * B.load([k, j])) E.ret([]) - - code = str(fun) - self.assertIn('"affine.for"()', code) - self.assertIn( - "{lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> ()", - code) - self.assertIn("%0 = load %arg0[%i0, %i2] : memref<32x32xf32>", code) - self.assertIn("%1 = load %arg1[%i2, %i1] : memref<32x32xf32>", code) - self.assertIn("%2 = mulf %0, %1 : f32", code) - self.assertIn("store %2, %arg2[%i0, %i1] : memref<32x32xf32>", code) + printWithCurrentFunctionName(str(fun)) + # CHECK-LABEL: testMatrixMultiply + # CHECK: "affine.for"() + # CHECK: "affine.for"() + # CHECK: "affine.for"() + # CHECK-DAG: %0 = load %arg0[%i0, %i2] : memref<32x32xf32> + # CHECK-DAG: %1 = load %arg1[%i2, %i1] : memref<32x32xf32> + # CHECK: %2 = mulf %0, %1 : f32 + # CHECK: store %2, %arg2[%i0, %i1] : memref<32x32xf32> + # CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> () + # CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> () + # CHECK: {lower_bound: () -> (0), step: 1 : index, upper_bound: () -> (32)} : () -> () def testMLIRScalarTypes(self): + self.setUp() module = E.MLIRModule() - t = module.make_scalar_type("bf16") - self.assertIn("bf16", t.__str__()) - t = module.make_scalar_type("f16") - self.assertIn("f16", t.__str__()) - t = module.make_scalar_type("f32") - self.assertIn("f32", t.__str__()) - t = module.make_scalar_type("f64") - self.assertIn("f64", t.__str__()) - t = module.make_scalar_type("i", 1) - self.assertIn("i1", t.__str__()) - t = module.make_scalar_type("i", 8) - self.assertIn("i8", t.__str__()) - t = module.make_scalar_type("i", 32) - self.assertIn("i32", t.__str__()) - t = module.make_scalar_type("i", 123) - self.assertIn("i123", t.__str__()) - t = module.make_scalar_type("index") - self.assertIn("index", t.__str__()) + printWithCurrentFunctionName(str(module.make_scalar_type("bf16"))) + print(str(module.make_scalar_type("f16"))) + print(str(module.make_scalar_type("f32"))) + print(str(module.make_scalar_type("f64"))) + print(str(module.make_scalar_type("i", 1))) + print(str(module.make_scalar_type("i", 8))) + print(str(module.make_scalar_type("i", 32))) + print(str(module.make_scalar_type("i", 123))) + print(str(module.make_scalar_type("index"))) + # CHECK-LABEL: testMLIRScalarTypes + # CHECK: bf16 + # CHECK: f16 + # CHECK: f32 + # CHECK: f64 + # CHECK: i1 + # CHECK: i8 + # CHECK: i32 + # CHECK: i123 + # CHECK: index def testMLIRFunctionCreation(self): + self.setUp() module = E.MLIRModule() t = module.make_scalar_type("f32") - self.assertIn("f32", t.__str__()) m = module.make_memref_type(t, [3, 4, -1, 5]) - self.assertIn("memref<3x4x?x5xf32>", m.__str__()) - f = module.make_function("copy", [m, m], []) - self.assertIn( - "func @copy(%arg0: memref<3x4x?x5xf32>, %arg1: memref<3x4x?x5xf32>) {", - f.__str__()) - - f = module.make_function("sqrtf", [t], [t]) - self.assertIn("func @sqrtf(%arg0: f32) -> f32", f.__str__()) + printWithCurrentFunctionName(str(t)) + print(str(m)) + print(str(module.make_function("copy", [m, m], []))) + print(str(module.make_function("sqrtf", [t], [t]))) + # CHECK-LABEL: testMLIRFunctionCreation + # CHECK: f32 + # CHECK: memref<3x4x?x5xf32> + # CHECK: func @copy(%arg0: memref<3x4x?x5xf32>, %arg1: memref<3x4x?x5xf32>) { + # CHECK: func @sqrtf(%arg0: f32) -> f32 def testFunctionDeclaration(self): - module = E.MLIRModule() + self.setUp() boolAttr = self.module.boolAttr(True) - t = module.make_memref_type(self.f32Type, [10]) + t = self.module.make_memref_type(self.f32Type, [10]) t_llvm_noalias = t({"llvm.noalias": boolAttr}) t_readonly = t({"readonly": boolAttr}) - f = module.declare_function("foo", [t, t_llvm_noalias, t_readonly], []) - str = module.__str__() - self.assertIn( - "func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias: true}, memref<10xf32> {readonly: true})", - str) + f = self.module.declare_function("foo", [t, t_llvm_noalias, t_readonly], []) + printWithCurrentFunctionName(str(self.module)) + # CHECK-LABEL: testFunctionDeclaration + # CHECK: func @foo(memref<10xf32>, memref<10xf32> {llvm.noalias: true}, memref<10xf32> {readonly: true}) def testMLIRBooleanCompilation(self): + self.setUp() m = self.module.make_memref_type(self.boolType, [10]) # i1 tensor with self.module.function_context("mkbooltensor", [m, m], []) as f: input = E.IndexedValue(f.arg(0)) @@ -457,23 +445,66 @@ class EdscTest(unittest.TestCase): b3 = b2 | (k < j) output.store([i], input.load([i]) & b3) E.ret([]) - - self.module.compile() - self.assertNotEqual(self.module.get_engine_address(), 0) + self.module.compile() + printWithCurrentFunctionName(str(self.module.get_engine_address() == 0)) + # CHECK-LABEL: testMLIRBooleanCompilation + # CHECK: False # Create 'addi' using the generic Op interface. We need an operation known # to the execution engine so that the engine can compile it. def testCustomOpCompilation(self): + self.setUp() with self.module.function_context("adder", [self.i32Type], []) as f: c1 = E.op( "std.constant", [], [self.i32Type], value=self.module.integerAttr(self.i32Type, 42)) E.op("std.addi", [c1, f.arg(0)], [self.i32Type]) E.ret([]) - self.module.compile() - self.assertNotEqual(self.module.get_engine_address(), 0) - - -if __name__ == "__main__": - unittest.main() + printWithCurrentFunctionName(str(self.module.get_engine_address() == 0)) + # CHECK-LABEL: testCustomOpCompilation + # CHECK: False + +# Until python 3.6 this cannot be used because the order in the dict is not the +# order of method declaration. +def runTests(edscTest): + def isTest(attr): + return inspect.ismethod(attr) and "__init" not in str(attr) + + tests = filter(isTest, (getattr(edscTest, attr) for attr in dir(edscTest))) + for test in tests: + test() + +# So instead one must list the functions in order of their Filecheck appearance. +def main(): + edscTest = EdscTest() + edscTest.testFunctionContext() + edscTest.testMultipleFunctions() + edscTest.testFunctionArgs() + edscTest.testLoopContext() + edscTest.testLoopNestContext() + edscTest.testBlockContext() + edscTest.testBlockContextAppend() + edscTest.testBlockContextStandalone() + edscTest.testBlockArguments() + edscTest.testBr() + edscTest.testBrDeclaration() + edscTest.testBrArgs() + edscTest.testCondBr() + edscTest.testRet() + edscTest.testSelectOp() + edscTest.testCallOp() + edscTest.testBooleanOps() + edscTest.testDivisions() + edscTest.testCustom() + edscTest.testConstants() + edscTest.testIndexedValue() + edscTest.testMatrixMultiply() + edscTest.testMLIRScalarTypes() + edscTest.testMLIRFunctionCreation() + edscTest.testFunctionDeclaration() + edscTest.testMLIRBooleanCompilation() + edscTest.testCustomOpCompilation() + +if __name__ == '__main__': + main() -- 2.7.4