From 266841751f072828c5982c972d9a482cda201659 Mon Sep 17 00:00:00 2001 From: Mahesh Ravishankar Date: Fri, 28 Jun 2019 09:59:22 -0700 Subject: [PATCH] Add buffer size information to Linalg::BufferType. If the size is constant then it is represented as . If the size is not a compile time constant, then it is represented as . PiperOrigin-RevId: 255619400 --- mlir/include/mlir/Linalg/IR/LinalgTypes.h | 5 +- mlir/lib/Linalg/IR/LinalgTypes.cpp | 62 ++++++++++++++++++---- mlir/test/Linalg/llvm.mlir | 10 ++-- mlir/test/Linalg/loops.mlir | 12 ++--- mlir/test/Linalg/roundtrip.mlir | 24 ++++----- .../mlir-cpu-runner/linalg_integration_test.mlir | 36 ++++++------- 6 files changed, 97 insertions(+), 52 deletions(-) diff --git a/mlir/include/mlir/Linalg/IR/LinalgTypes.h b/mlir/include/mlir/Linalg/IR/LinalgTypes.h index 30930f3..b1ce221 100644 --- a/mlir/include/mlir/Linalg/IR/LinalgTypes.h +++ b/mlir/include/mlir/Linalg/IR/LinalgTypes.h @@ -53,12 +53,15 @@ public: // Used for generic hooks in TypeBase. using Base::Base; /// Construction hook. - static BufferType get(MLIRContext *context, Type elementType); + static BufferType get(MLIRContext *context, Type elementType, + int64_t bufferSize = -1); /// Used to implement llvm-style cast. static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; } // Type-specific functionality. Type getElementType(); + bool hasConstantSize(); + Optional getBufferSize(); }; /// A RangeType represents a minimal range abstraction (min, max, step). diff --git a/mlir/lib/Linalg/IR/LinalgTypes.cpp b/mlir/lib/Linalg/IR/LinalgTypes.cpp index 6b2e541..82be170 100644 --- a/mlir/lib/Linalg/IR/LinalgTypes.cpp +++ b/mlir/lib/Linalg/IR/LinalgTypes.cpp @@ -50,8 +50,10 @@ struct mlir::linalg::BufferTypeStorage : public TypeStorage { /// Underlying Key type to transport the payload needed to construct a custom /// type in a generic way. struct Key { - Key(Type elementType) : elementType(elementType) {} + Key(Type elementType, int64_t bufferSize = -1) + : elementType(elementType), bufferSize(bufferSize) {} Type elementType; + int64_t bufferSize; }; /// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing. using KeyTy = Key; @@ -64,31 +66,48 @@ struct mlir::linalg::BufferTypeStorage : public TypeStorage { /// Equality operator for hashing. bool operator==(const Key &key) const { - return elementType == key.elementType; + return elementType == key.elementType && bufferSize == key.bufferSize; } /// Hashing for unique'ing. static unsigned hashKey(const Key &key) { - return llvm::hash_combine(key.elementType); + return llvm::hash_combine(key.elementType, key.bufferSize); } - Type getElementType() { return elementType; }; + Type getElementType() { return elementType; } + bool hasConstantSize() { return bufferSize >= 0; } + Optional getBufferSize() { + if (hasConstantSize()) { + return bufferSize; + } + return llvm::None; + } private: - BufferTypeStorage(const Key &key) : elementType(key.elementType) {} + BufferTypeStorage(const Key &key) + : elementType(key.elementType), bufferSize(key.bufferSize) {} Type elementType; + int64_t bufferSize; }; -BufferType mlir::linalg::BufferType::get(MLIRContext *context, - Type elementType) { - return Base::get(context, LinalgTypes::Buffer, elementType); +BufferType mlir::linalg::BufferType::get(MLIRContext *context, Type elementType, + int64_t bufferSize) { + return Base::get(context, LinalgTypes::Buffer, elementType, bufferSize); } Type mlir::linalg::BufferType::getElementType() { return getImpl()->getElementType(); } +bool mlir::linalg::BufferType::hasConstantSize() { + return getImpl()->hasConstantSize(); +} + +Optional mlir::linalg::BufferType::getBufferSize() { + return getImpl()->getBufferSize(); +} + Type mlir::linalg::LinalgDialect::parseType(StringRef spec, Location loc) const { StringRef origSpec = spec; @@ -97,8 +116,24 @@ Type mlir::linalg::LinalgDialect::parseType(StringRef spec, return RangeType::get(getContext()); else if (spec.consume_front("buffer")) { if (spec.consume_front("<") && spec.consume_back(">")) { + // Check for '?' + int64_t bufferSize = -1; + if (!spec.consume_front("?")) { + unsigned parsedBufferSize; + if (!spec.consumeInteger(10, parsedBufferSize)) { + emitError(loc, "expected buffer size to be an unsigned integer"); + return Type(); + } + bufferSize = static_cast(parsedBufferSize); + } + if (!spec.consume_front("x")) { + emitError(loc, "missing x in buffer type descrition : ") << spec; + return Type(); + } if (auto t = mlir::parseType(spec, context)) - return BufferType::get(getContext(), t); + return (bufferSize == -1 + ? BufferType::get(getContext(), t) + : BufferType::get(getContext(), t, bufferSize)); } } else if (spec.consume_front("view")) { if (spec.consume_front("<") && spec.consume_back(">")) { @@ -173,7 +208,14 @@ unsigned mlir::linalg::ViewType::getRank() { return getImpl()->getRank(); } /// BufferType prints as "buffer". static void print(BufferType bt, raw_ostream &os) { - os << "buffer<" << bt.getElementType() << ">"; + os << "buffer<"; + auto bs = bt.getBufferSize(); + if (bs) { + os << bs.getValue(); + } else { + os << "?"; + } + os << "x" << bt.getElementType() << ">"; } /// RangeType prints as just "range". diff --git a/mlir/test/Linalg/llvm.mlir b/mlir/test/Linalg/llvm.mlir index c1abade..8362fc5 100644 --- a/mlir/test/Linalg/llvm.mlir +++ b/mlir/test/Linalg/llvm.mlir @@ -1,7 +1,7 @@ // RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | FileCheck %s -func @buffer_size(%arg0: !linalg.buffer) { - %s = linalg.buffer_size %arg0 : !linalg.buffer +func @buffer_size(%arg0: !linalg.buffer) { + %s = linalg.buffer_size %arg0 : !linalg.buffer return } // CHECK-LABEL: func @buffer_size(%arg0: !llvm<"{ float*, i64 }">) { @@ -21,7 +21,7 @@ func @range(%arg0: index) { // CHECK-NEXT: %4 = llvm.insertvalue %arg0, %3[1] : !llvm<"{ i64, i64, i64 }"> // CHECK-NEXT: %5 = llvm.insertvalue %1, %4[2] : !llvm<"{ i64, i64, i64 }"> -func @view(%arg0: !linalg.buffer, %arg1: !linalg.range) { +func @view(%arg0: !linalg.buffer, %arg1: !linalg.range) { %0 = linalg.view %arg0[%arg1] : !linalg.view return } @@ -40,7 +40,7 @@ func @view(%arg0: !linalg.buffer, %arg1: !linalg.range) { // CHECK-NEXT: %11 = llvm.sub %10, %9 : !llvm.i64 // CHECK-NEXT: %12 = llvm.insertvalue %11, %8[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }"> -func @view3d(%arg0: !linalg.buffer, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) { +func @view3d(%arg0: !linalg.buffer, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) { %0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.view return } @@ -55,7 +55,7 @@ func @view3d(%arg0: !linalg.buffer, %arg1: !linalg.range, %arg2: !linalg.ra // CHECK-NEXT: %15 = llvm.mul %13, %14 : !llvm.i64 // CHECK-NEXT: %16 = llvm.insertvalue %15, %12[3, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }"> -func @slice(%arg0: !linalg.buffer, %arg1: !linalg.range) { +func @slice(%arg0: !linalg.buffer, %arg1: !linalg.range) { %0 = linalg.view %arg0[%arg1] : !linalg.view %1 = linalg.slice %0[%arg1] : !linalg.view, !linalg.range, !linalg.view return diff --git a/mlir/test/Linalg/loops.mlir b/mlir/test/Linalg/loops.mlir index 35086ba..9907b4a 100644 --- a/mlir/test/Linalg/loops.mlir +++ b/mlir/test/Linalg/loops.mlir @@ -4,7 +4,7 @@ // CHECK-DAG: #[[S2D3:.*]] = (d0, d1) -> (d0 * 2 + d1 * 4) // CHECK-DAG: #[[S3D2:.*]] = (d0, d1) -> (d0 * 3 + d1 * 5) -func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { +func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { %c0 = constant 0 : index %c1 = constant 1 : index %I = linalg.range %c0:%arg1:%c1 : !linalg.range @@ -16,7 +16,7 @@ func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: inde linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view return } -// CHECK-LABEL: func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { +// CHECK-LABEL: func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { // CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view @@ -33,7 +33,7 @@ func @matmul(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: inde // CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 // CHECK: linalg.store %[[res]], %[[C]][%i0, %i1] : !linalg.view -func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { +func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { %c0 = constant 0 : index %c1 = constant 1 : index %I = linalg.range %c0:%arg1:%c1 : !linalg.range @@ -44,7 +44,7 @@ func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: inde linalg.matvec(%2, %3, %4) : !linalg.view, !linalg.view, !linalg.view return } -// CHECK-LABEL: func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { +// CHECK-LABEL: func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { // CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view @@ -59,7 +59,7 @@ func @matvec(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: inde // CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32 // CHECK: linalg.store %[[res]], %[[C]][%i0] : !linalg.view -func @dot(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { +func @dot(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { %c0 = constant 0 : index %c1 = constant 1 : index %I = linalg.range %c0:%arg1:%c1 : !linalg.range @@ -69,7 +69,7 @@ func @dot(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) linalg.dot(%1, %2, %3) : !linalg.view, !linalg.view, !linalg.view return } -// CHECK-LABEL: func @dot(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { +// CHECK-LABEL: func @dot(%arg0: !linalg.buffer, %arg1: index, %arg2: index, %arg3: index) { // CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view // CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.view diff --git a/mlir/test/Linalg/roundtrip.mlir b/mlir/test/Linalg/roundtrip.mlir index 29c2497..2133b24 100644 --- a/mlir/test/Linalg/roundtrip.mlir +++ b/mlir/test/Linalg/roundtrip.mlir @@ -12,14 +12,14 @@ func @range(%arg0: index, %arg1: index, %arg2: index) { func @buffer(%arg0: index, %arg1: index) { %0 = muli %arg0, %arg0 : index - %1 = linalg.buffer_alloc %0 : !linalg.buffer> - linalg.buffer_dealloc %1 : !linalg.buffer> + %1 = linalg.buffer_alloc %0 : !linalg.buffer> + linalg.buffer_dealloc %1 : !linalg.buffer> return } // CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) { // CHECK-NEXT: %0 = muli %arg0, %arg0 : index -// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer> -// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer> +// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer> +// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer> func @view_fun(%arg0: !linalg.view>) { return @@ -28,26 +28,26 @@ func @view_fun(%arg0: !linalg.view>) { func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { %0 = muli %arg0, %arg0 : index - %1 = linalg.buffer_alloc %0 : !linalg.buffer + %1 = linalg.buffer_alloc %0 : !linalg.buffer %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range %3 = linalg.view %1[%2, %2] : !linalg.view %4 = linalg.slice %3[%2, %2] : !linalg.view, !linalg.range, !linalg.range, !linalg.view %5 = linalg.slice %3[%2, %arg2] : !linalg.view, !linalg.range, index, !linalg.view %6 = linalg.slice %3[%arg2, %2] : !linalg.view, index, !linalg.range, !linalg.view %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view, index, index, !linalg.view - linalg.buffer_dealloc %1 : !linalg.buffer + linalg.buffer_dealloc %1 : !linalg.buffer return } // CHECK-LABEL: func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) { // CHECK-NEXT: %0 = muli %arg0, %arg0 : index -// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer +// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer // CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range // CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.view // CHECK-NEXT: %4 = linalg.slice %3[%2, %2] : !linalg.view, !linalg.range, !linalg.range, !linalg.view // CHECK-NEXT: %5 = linalg.slice %3[%2, %arg2] : !linalg.view, !linalg.range, index, !linalg.view // CHECK-NEXT: %6 = linalg.slice %3[%arg2, %2] : !linalg.view, index, !linalg.range, !linalg.view // CHECK-NEXT: %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view, index, index, !linalg.view -// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer +// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer func @ops(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !linalg.view, %arg3: !linalg.view) { linalg.matmul(%arg0, %arg0, %arg0) : !linalg.view, !linalg.view, !linalg.view @@ -62,14 +62,14 @@ func @ops(%arg0: !linalg.view, %arg1: !linalg.view, %arg2: !lina func @dim(%arg0: !linalg.view) { %0 = linalg.dim %arg0, 1 : !linalg.view - %1 = linalg.buffer_alloc %0 : !linalg.buffer - linalg.buffer_dealloc %1 : !linalg.buffer + %1 = linalg.buffer_alloc %0 : !linalg.buffer + linalg.buffer_dealloc %1 : !linalg.buffer return } // CHECK-LABEL: func @dim(%arg0: !linalg.view) { // CHECK-NEXT: %0 = linalg.dim %arg0, 1 : !linalg.view -// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer -// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer +// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer +// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.range { %0 = linalg.range_intersect %arg0, %arg1 : !linalg.range diff --git a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir index e72b49d..c5dcbb3 100644 --- a/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir +++ b/mlir/test/mlir-cpu-runner/linalg_integration_test.mlir @@ -3,10 +3,10 @@ // RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s // RUN: mlir-opt %s -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s -func @fill_f32(%arg0 : !linalg.buffer, %f : f32) { +func @fill_f32(%arg0 : !linalg.buffer, %f : f32) { %c0 = constant 0 : index %c1 = constant 1 : index - %s = linalg.buffer_size %arg0 : !linalg.buffer + %s = linalg.buffer_size %arg0 : !linalg.buffer %R = linalg.range %c0:%s:%c1 : !linalg.range %V = linalg.view %arg0[%R] : !linalg.view affine.for %i0 = 0 to %s { @@ -15,10 +15,10 @@ func @fill_f32(%arg0 : !linalg.buffer, %f : f32) { return } -func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer { - %A = linalg.buffer_alloc %s : !linalg.buffer - call @fill_f32(%A, %f) : (!linalg.buffer, f32) -> () - return %A : !linalg.buffer +func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer { + %A = linalg.buffer_alloc %s : !linalg.buffer + call @fill_f32(%A, %f) : (!linalg.buffer, f32) -> () + return %A : !linalg.buffer } func @dot() -> f32 { @@ -29,9 +29,9 @@ func @dot() -> f32 { %f1 = constant 1.00000e+00 : f32 %f2 = constant 2.00000e+00 : f32 - %bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer) - %bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer) - %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer) + %bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer) + %bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer) + %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer) %R = linalg.range %c0:%c16:%c1 : !linalg.range %A = linalg.view %bA[%R] : !linalg.view @@ -41,9 +41,9 @@ func @dot() -> f32 { linalg.dot(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view %res = linalg.load %C[] : !linalg.view - linalg.buffer_dealloc %bC : !linalg.buffer - linalg.buffer_dealloc %bB : !linalg.buffer - linalg.buffer_dealloc %bA : !linalg.buffer + linalg.buffer_dealloc %bC : !linalg.buffer + linalg.buffer_dealloc %bB : !linalg.buffer + linalg.buffer_dealloc %bA : !linalg.buffer return %res : f32 } @@ -61,9 +61,9 @@ func @matmul() -> f32 { %f2 = constant 2.00000e+00 : f32 %f10 = constant 10.00000e+00 : f32 - %bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (!linalg.buffer) - %bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (!linalg.buffer) - %bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (!linalg.buffer) + %bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (!linalg.buffer) + %bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (!linalg.buffer) + %bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (!linalg.buffer) %M = linalg.range %c0:%c10:%c1 : !linalg.range %N = linalg.range %c0:%c10:%c1 : !linalg.range @@ -75,9 +75,9 @@ func @matmul() -> f32 { linalg.matmul(%A, %B, %C) : !linalg.view, !linalg.view, !linalg.view %res = linalg.load %C[%c6, %c7] : !linalg.view - linalg.buffer_dealloc %bC : !linalg.buffer - linalg.buffer_dealloc %bB : !linalg.buffer - linalg.buffer_dealloc %bA : !linalg.buffer + linalg.buffer_dealloc %bC : !linalg.buffer + linalg.buffer_dealloc %bB : !linalg.buffer + linalg.buffer_dealloc %bA : !linalg.buffer return %res : f32 } -- 2.7.4