[mlir][VectorOps] Generalized vector.print to i32/i64
authoraartbik <ajcbik@google.com>
Fri, 7 Feb 2020 17:09:05 +0000 (09:09 -0800)
committeraartbik <ajcbik@google.com>
Fri, 7 Feb 2020 17:25:30 +0000 (09:25 -0800)
Summary:
Lowering to LLVM IR was restricted to float/double.
This CL also adds the integral values.

Reviewers: andydavis1, nicolasvasilache, ftynse

Reviewed By: nicolasvasilache, ftynse

Subscribers: mehdi_amini, rriddle, jpienaar, burmako, shauheen, antiagainst, nicolasvasilache, arpith-jacob, mgester, lucyrfox, liufengdb, Joonsoo, llvm-commits

Tags: #llvm

Differential Revision: https://reviews.llvm.org/D74179

mlir/lib/Conversion/VectorToLLVM/ConvertVectorToLLVM.cpp
mlir/test/Conversion/VectorToLLVM/vector-to-llvm.mlir
mlir/test/mlir-cpu-runner/mlir_runner_utils.cpp

index 32a6964..7df164b 100644 (file)
@@ -809,7 +809,11 @@ public:
     Type eltType = vectorType ? vectorType.getElementType() : printType;
     int64_t rank = vectorType ? vectorType.getRank() : 0;
     Operation *printer;
-    if (eltType.isF32())
+    if (eltType.isInteger(32))
+      printer = getPrintI32(op);
+    else if (eltType.isInteger(64))
+      printer = getPrintI64(op);
+    else if (eltType.isF32())
       printer = getPrintFloat(op);
     else if (eltType.isF64())
       printer = getPrintDouble(op);
@@ -872,6 +876,16 @@ private:
   }
 
   // Helpers for method names.
+  Operation *getPrintI32(Operation *op) const {
+    LLVM::LLVMDialect *dialect = lowering.getDialect();
+    return getPrint(op, dialect, "print_i32",
+                    LLVM::LLVMType::getInt32Ty(dialect));
+  }
+  Operation *getPrintI64(Operation *op) const {
+    LLVM::LLVMDialect *dialect = lowering.getDialect();
+    return getPrint(op, dialect, "print_i64",
+                    LLVM::LLVMType::getInt64Ty(dialect));
+  }
   Operation *getPrintFloat(Operation *op) const {
     LLVM::LLVMDialect *dialect = lowering.getDialect();
     return getPrint(op, dialect, "print_f32",
index fce479a..1852512 100644 (file)
@@ -235,8 +235,8 @@ func @shuffle_1D_direct(%arg0: vector<2xf32>, %arg1: vector<2xf32>) -> vector<2x
   return %1 : vector<2xf32>
 }
 // CHECK-LABEL: llvm.func @shuffle_1D_direct
-// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"<2 x float>">
-// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"<2 x float>">
+// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">,
+// CHECK-SAME: %[[B:.*]]: !llvm<"<2 x float>">
 //       CHECK:   %[[s:.*]] = llvm.shufflevector %[[A]], %[[B]] [0, 1] : !llvm<"<2 x float>">, !llvm<"<2 x float>">
 //       CHECK:   llvm.return %[[s]] : !llvm<"<2 x float>">
 
@@ -245,8 +245,8 @@ func @shuffle_1D(%arg0: vector<2xf32>, %arg1: vector<3xf32>) -> vector<5xf32> {
   return %1 : vector<5xf32>
 }
 // CHECK-LABEL: llvm.func @shuffle_1D
-// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"<2 x float>">
-// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"<3 x float>">
+// CHECK-SAME: %[[A:.*]]: !llvm<"<2 x float>">,
+// CHECK-SAME: %[[B:.*]]: !llvm<"<3 x float>">
 //       CHECK:   %[[u0:.*]] = llvm.mlir.undef : !llvm<"<5 x float>">
 //       CHECK:   %[[c2:.*]] = llvm.mlir.constant(2 : index) : !llvm.i64
 //       CHECK:   %[[e1:.*]] = llvm.extractelement %[[B]][%[[c2]] : !llvm.i64] : !llvm<"<3 x float>">
@@ -275,8 +275,8 @@ func @shuffle_2D(%a: vector<1x4xf32>, %b: vector<2x4xf32>) -> vector<3x4xf32> {
   return %1 : vector<3x4xf32>
 }
 // CHECK-LABEL: llvm.func @shuffle_2D
-// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[1 x <4 x float>]">
-// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"[2 x <4 x float>]">
+// CHECK-SAME: %[[A:.*]]: !llvm<"[1 x <4 x float>]">,
+// CHECK-SAME: %[[B:.*]]: !llvm<"[2 x <4 x float>]">
 //       CHECK:   %[[u0:.*]] = llvm.mlir.undef : !llvm<"[3 x <4 x float>]">
 //       CHECK:   %[[e1:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[2 x <4 x float>]">
 //       CHECK:   %[[i1:.*]] = llvm.insertvalue %[[e1]], %[[u0]][0] : !llvm<"[3 x <4 x float>]">
@@ -292,7 +292,7 @@ func @extract_element(%arg0: vector<16xf32>) -> f32 {
   return %1 : f32
 }
 // CHECK-LABEL: llvm.func @extract_element
-// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"<16 x float>">
+// CHECK-SAME: %[[A:.*]]: !llvm<"<16 x float>">
 //       CHECK:   %[[c:.*]] = llvm.mlir.constant(15 : i32) : !llvm.i32
 //       CHECK:   %[[x:.*]] = llvm.extractelement %[[A]][%[[c]] : !llvm.i32] : !llvm<"<16 x float>">
 //       CHECK:   llvm.return %[[x]] : !llvm.float
@@ -338,8 +338,8 @@ func @insert_element(%arg0: f32, %arg1: vector<4xf32>) -> vector<4xf32> {
   return %1 : vector<4xf32>
 }
 // CHECK-LABEL: llvm.func @insert_element
-// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.float
-// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"<4 x float>">
+// CHECK-SAME: %[[A:.*]]: !llvm.float,
+// CHECK-SAME: %[[B:.*]]: !llvm<"<4 x float>">
 //       CHECK:   %[[c:.*]] = llvm.mlir.constant(3 : i32) : !llvm.i32
 //       CHECK:   %[[x:.*]] = llvm.insertelement %[[A]], %[[B]][%[[c]] : !llvm.i32] : !llvm<"<4 x float>">
 //       CHECK:   llvm.return %[[x]] : !llvm<"<4 x float>">
@@ -395,21 +395,48 @@ func @vector_type_cast(%arg0: memref<8x8x8xf32>) -> memref<vector<8x8x8xf32>> {
 //       CHECK:   llvm.mlir.constant(0 : index
 //       CHECK:   llvm.insertvalue {{.*}}[2] : !llvm<"{ [8 x [8 x <8 x float>]]*, [8 x [8 x <8 x float>]]*, i64 }">
 
-func @vector_print_scalar(%arg0: f32) {
+func @vector_print_scalar_i32(%arg0: i32) {
+  vector.print %arg0 : i32
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_i32
+// CHECK-SAME: %[[A:.*]]: !llvm.i32
+//       CHECK:    llvm.call @print_i32(%[[A]]) : (!llvm.i32) -> ()
+//       CHECK:    llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_i64(%arg0: i64) {
+  vector.print %arg0 : i64
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_i64
+// CHECK-SAME: %[[A:.*]]: !llvm.i64
+//       CHECK:    llvm.call @print_i64(%[[A]]) : (!llvm.i64) -> ()
+//       CHECK:    llvm.call @print_newline() : () -> ()
+
+func @vector_print_scalar_f32(%arg0: f32) {
   vector.print %arg0 : f32
   return
 }
-// CHECK-LABEL: llvm.func @vector_print_scalar
-// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm.float
+// CHECK-LABEL: llvm.func @vector_print_scalar_f32
+// CHECK-SAME: %[[A:.*]]: !llvm.float
 //       CHECK:    llvm.call @print_f32(%[[A]]) : (!llvm.float) -> ()
 //       CHECK:    llvm.call @print_newline() : () -> ()
 
+func @vector_print_scalar_f64(%arg0: f64) {
+  vector.print %arg0 : f64
+  return
+}
+// CHECK-LABEL: llvm.func @vector_print_scalar_f64
+// CHECK-SAME: %[[A:.*]]: !llvm.double
+//       CHECK:    llvm.call @print_f64(%[[A]]) : (!llvm.double) -> ()
+//       CHECK:    llvm.call @print_newline() : () -> ()
+
 func @vector_print_vector(%arg0: vector<2x2xf32>) {
   vector.print %arg0 : vector<2x2xf32>
   return
 }
 // CHECK-LABEL: llvm.func @vector_print_vector
-// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[2 x <2 x float>]">
+// CHECK-SAME: %[[A:.*]]: !llvm<"[2 x <2 x float>]">
 //       CHECK:    llvm.call @print_open() : () -> ()
 //       CHECK:    %[[x0:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <2 x float>]">
 //       CHECK:    llvm.call @print_open() : () -> ()
@@ -549,8 +576,8 @@ func @insert_strided_slice3(%arg0: vector<2x4xf32>, %arg1: vector<16x4x8xf32>) -
   return %0 : vector<16x4x8xf32>
 }
 // CHECK-LABEL: llvm.func @insert_strided_slice3
-// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[2 x <4 x float>]">
-// CHECK-SAME: %[[B:arg[0-9]+]]: !llvm<"[16 x [4 x <8 x float>]]">
+// CHECK-SAME: %[[A:.*]]: !llvm<"[2 x <4 x float>]">,
+// CHECK-SAME: %[[B:.*]]: !llvm<"[16 x [4 x <8 x float>]]">
 //      CHECK: %[[s0:.*]] = llvm.extractvalue %[[B]][0] : !llvm<"[16 x [4 x <8 x float>]]">
 //      CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][0] : !llvm<"[2 x <4 x float>]">
 //      CHECK: %[[s2:.*]] = llvm.extractvalue %[[s0]][0] : !llvm<"[4 x <8 x float>]">
@@ -600,7 +627,7 @@ func @extract_strides(%arg0: vector<3x3xf32>) -> vector<1x1xf32> {
   return %1 : vector<1x1xf32>
 }
 // CHECK-LABEL: llvm.func @extract_strides
-// CHECK-SAME: %[[A:arg[0-9]+]]: !llvm<"[3 x <3 x float>]">
+// CHECK-SAME: %[[A:.*]]: !llvm<"[3 x <3 x float>]">
 //      CHECK: %[[s0:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1x1xf32>) : !llvm<"[1 x <1 x float>]">
 //      CHECK: %[[s1:.*]] = llvm.extractvalue %[[A]][2] : !llvm<"[3 x <3 x float>]">
 //      CHECK: %[[s3:.*]] = llvm.mlir.constant(dense<0.000000e+00> : vector<1xf32>) : !llvm<"<1 x float>">
index 77a5b0a..9225eab 100644 (file)
@@ -13,6 +13,7 @@
 
 #include "include/mlir_runner_utils.h"
 
+#include <cinttypes>
 #include <cstdio>
 
 extern "C" void
@@ -76,7 +77,9 @@ extern "C" void print_memref_4d_f32(StridedMemRefType<float, 4> *M) {
 // Small runtime support "lib" for vector.print lowering.
 // By providing elementary printing methods only, this
 // library can remain fully unaware of low-level implementation
-// details of our vectors.
+// details of our vectors. Also useful for direct LLVM IR output.
+extern "C" void print_i32(int32_t i) { fprintf(stdout, "%" PRId32, i); }
+extern "C" void print_i64(int64_t l) { fprintf(stdout, "%" PRId64, l); }
 extern "C" void print_f32(float f) { fprintf(stdout, "%g", f); }
 extern "C" void print_f64(double d) { fprintf(stdout, "%lg", d); }
 extern "C" void print_open() { fputs("( ", stdout); }