// Used for generic hooks in TypeBase.
using Base::Base;
/// Construction hook.
- static BufferType get(MLIRContext *context, Type elementType);
+ static BufferType get(MLIRContext *context, Type elementType,
+ int64_t bufferSize = -1);
/// Used to implement llvm-style cast.
static bool kindof(unsigned kind) { return kind == LinalgTypes::Buffer; }
// Type-specific functionality.
Type getElementType();
+ bool hasConstantSize();
+ Optional<int64_t> getBufferSize();
};
/// A RangeType represents a minimal range abstraction (min, max, step).
/// Underlying Key type to transport the payload needed to construct a custom
/// type in a generic way.
struct Key {
- Key(Type elementType) : elementType(elementType) {}
+ Key(Type elementType, int64_t bufferSize = -1)
+ : elementType(elementType), bufferSize(bufferSize) {}
Type elementType;
+ int64_t bufferSize;
};
/// `KeyTy` is a necessary typename hook for MLIR's custom type unique'ing.
using KeyTy = Key;
/// Equality operator for hashing.
bool operator==(const Key &key) const {
- return elementType == key.elementType;
+ return elementType == key.elementType && bufferSize == key.bufferSize;
}
/// Hashing for unique'ing.
static unsigned hashKey(const Key &key) {
- return llvm::hash_combine(key.elementType);
+ return llvm::hash_combine(key.elementType, key.bufferSize);
}
- Type getElementType() { return elementType; };
+ Type getElementType() { return elementType; }
+ bool hasConstantSize() { return bufferSize >= 0; }
+ Optional<int64_t> getBufferSize() {
+ if (hasConstantSize()) {
+ return bufferSize;
+ }
+ return llvm::None;
+ }
private:
- BufferTypeStorage(const Key &key) : elementType(key.elementType) {}
+ BufferTypeStorage(const Key &key)
+ : elementType(key.elementType), bufferSize(key.bufferSize) {}
Type elementType;
+ int64_t bufferSize;
};
-BufferType mlir::linalg::BufferType::get(MLIRContext *context,
- Type elementType) {
- return Base::get(context, LinalgTypes::Buffer, elementType);
+BufferType mlir::linalg::BufferType::get(MLIRContext *context, Type elementType,
+ int64_t bufferSize) {
+ return Base::get(context, LinalgTypes::Buffer, elementType, bufferSize);
}
Type mlir::linalg::BufferType::getElementType() {
return getImpl()->getElementType();
}
+bool mlir::linalg::BufferType::hasConstantSize() {
+ return getImpl()->hasConstantSize();
+}
+
+Optional<int64_t> mlir::linalg::BufferType::getBufferSize() {
+ return getImpl()->getBufferSize();
+}
+
Type mlir::linalg::LinalgDialect::parseType(StringRef spec,
Location loc) const {
StringRef origSpec = spec;
return RangeType::get(getContext());
else if (spec.consume_front("buffer")) {
if (spec.consume_front("<") && spec.consume_back(">")) {
+ // Check for '?'
+ int64_t bufferSize = -1;
+ if (!spec.consume_front("?")) {
+ unsigned parsedBufferSize;
+ if (!spec.consumeInteger(10, parsedBufferSize)) {
+ emitError(loc, "expected buffer size to be an unsigned integer");
+ return Type();
+ }
+ bufferSize = static_cast<int64_t>(parsedBufferSize);
+ }
+ if (!spec.consume_front("x")) {
+ emitError(loc, "missing x in buffer type descrition : ") << spec;
+ return Type();
+ }
if (auto t = mlir::parseType(spec, context))
- return BufferType::get(getContext(), t);
+ return (bufferSize == -1
+ ? BufferType::get(getContext(), t)
+ : BufferType::get(getContext(), t, bufferSize));
}
} else if (spec.consume_front("view")) {
if (spec.consume_front("<") && spec.consume_back(">")) {
/// BufferType prints as "buffer<element_type>".
static void print(BufferType bt, raw_ostream &os) {
- os << "buffer<" << bt.getElementType() << ">";
+ os << "buffer<";
+ auto bs = bt.getBufferSize();
+ if (bs) {
+ os << bs.getValue();
+ } else {
+ os << "?";
+ }
+ os << "x" << bt.getElementType() << ">";
}
/// RangeType prints as just "range".
// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | FileCheck %s
-func @buffer_size(%arg0: !linalg.buffer<f32>) {
- %s = linalg.buffer_size %arg0 : !linalg.buffer<f32>
+func @buffer_size(%arg0: !linalg.buffer<?xf32>) {
+ %s = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
return
}
// CHECK-LABEL: func @buffer_size(%arg0: !llvm<"{ float*, i64 }">) {
// CHECK-NEXT: %4 = llvm.insertvalue %arg0, %3[1] : !llvm<"{ i64, i64, i64 }">
// CHECK-NEXT: %5 = llvm.insertvalue %1, %4[2] : !llvm<"{ i64, i64, i64 }">
-func @view(%arg0: !linalg.buffer<f32>, %arg1: !linalg.range) {
+func @view(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
%0 = linalg.view %arg0[%arg1] : !linalg.view<?xf32>
return
}
// CHECK-NEXT: %11 = llvm.sub %10, %9 : !llvm.i64
// CHECK-NEXT: %12 = llvm.insertvalue %11, %8[2, 0] : !llvm<"{ float*, i64, [1 x i64], [1 x i64] }">
-func @view3d(%arg0: !linalg.buffer<f32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
+func @view3d(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range, %arg2: !linalg.range, %arg3: !linalg.range) {
%0 = linalg.view %arg0[%arg1, %arg2, %arg3] : !linalg.view<?x?x?xf32>
return
}
// CHECK-NEXT: %15 = llvm.mul %13, %14 : !llvm.i64
// CHECK-NEXT: %16 = llvm.insertvalue %15, %12[3, 1] : !llvm<"{ float*, i64, [3 x i64], [3 x i64] }">
-func @slice(%arg0: !linalg.buffer<f32>, %arg1: !linalg.range) {
+func @slice(%arg0: !linalg.buffer<?xf32>, %arg1: !linalg.range) {
%0 = linalg.view %arg0[%arg1] : !linalg.view<?xf32>
%1 = linalg.slice %0[%arg1] : !linalg.view<?xf32>, !linalg.range, !linalg.view<?xf32>
return
// CHECK-DAG: #[[S2D3:.*]] = (d0, d1) -> (d0 * 2 + d1 * 4)
// CHECK-DAG: #[[S3D2:.*]] = (d0, d1) -> (d0 * 3 + d1 * 5)
-func @matmul(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
return
}
-// CHECK-LABEL: func @matmul(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+// CHECK-LABEL: func @matmul(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
// CHECK: linalg.store %[[res]], %[[C]][%i0, %i1] : !linalg.view<?x?xf32>
-func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
linalg.matvec(%2, %3, %4) : !linalg.view<?x?xf32>, !linalg.view<?xf32>, !linalg.view<?xf32>
return
}
-// CHECK-LABEL: func @matvec(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+// CHECK-LABEL: func @matvec(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?x?xf32>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
// CHECK: %[[C:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
// CHECK-DAG: %[[res:.*]] = addf %[[c]], %[[inc]] : f32
// CHECK: linalg.store %[[res]], %[[C]][%i0] : !linalg.view<?xf32>
-func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+func @dot(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
%c0 = constant 0 : index
%c1 = constant 1 : index
%I = linalg.range %c0:%arg1:%c1 : !linalg.range
linalg.dot(%1, %2, %3) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
return
}
-// CHECK-LABEL: func @dot(%arg0: !linalg.buffer<f32>, %arg1: index, %arg2: index, %arg3: index) {
+// CHECK-LABEL: func @dot(%arg0: !linalg.buffer<?xf32>, %arg1: index, %arg2: index, %arg3: index) {
// CHECK: %[[A:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
// CHECK: %[[B:.*]] = linalg.view %arg0[{{.*}}] : !linalg.view<?xf32>
// CHECK: %[[C:.*]] = linalg.view %arg0[] : !linalg.view<f32>
func @buffer(%arg0: index, %arg1: index) {
%0 = muli %arg0, %arg0 : index
- %1 = linalg.buffer_alloc %0 : !linalg.buffer<vector<4xi8>>
- linalg.buffer_dealloc %1 : !linalg.buffer<vector<4xi8>>
+ %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
+ linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
return
}
// CHECK-LABEL: func @buffer(%arg0: index, %arg1: index) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
-// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<vector<4xi8>>
-// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<vector<4xi8>>
+// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xvector<4xi8>>
+// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<?xvector<4xi8>>
func @view_fun(%arg0: !linalg.view<?x?xvector<3x4xi4>>) {
return
func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
%0 = muli %arg0, %arg0 : index
- %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+ %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
%2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
%3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
%4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
%5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
%6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
%7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
- linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+ linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
return
}
// CHECK-LABEL: func @views(%arg0: index, %arg1: index, %arg2: index, %arg3: index, %arg4: index) {
// CHECK-NEXT: %0 = muli %arg0, %arg0 : index
-// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
+// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
// CHECK-NEXT: %2 = linalg.range %arg2:%arg3:%arg4 : !linalg.range
// CHECK-NEXT: %3 = linalg.view %1[%2, %2] : !linalg.view<?x?xf32>
// CHECK-NEXT: %4 = linalg.slice %3[%2, %2] : !linalg.view<?x?xf32>, !linalg.range, !linalg.range, !linalg.view<?x?xf32>
// CHECK-NEXT: %5 = linalg.slice %3[%2, %arg2] : !linalg.view<?x?xf32>, !linalg.range, index, !linalg.view<?xf32>
// CHECK-NEXT: %6 = linalg.slice %3[%arg2, %2] : !linalg.view<?x?xf32>, index, !linalg.range, !linalg.view<?xf32>
// CHECK-NEXT: %7 = linalg.slice %3[%arg2, %arg3] : !linalg.view<?x?xf32>, index, index, !linalg.view<f32>
-// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
func @ops(%arg0: !linalg.view<?x?xf32>, %arg1: !linalg.view<?xf32>, %arg2: !linalg.view<?xf32>, %arg3: !linalg.view<f32>) {
linalg.matmul(%arg0, %arg0, %arg0) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
func @dim(%arg0: !linalg.view<?x?xf32>) {
%0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
- %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
- linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+ %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
+ linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
return
}
// CHECK-LABEL: func @dim(%arg0: !linalg.view<?x?xf32>) {
// CHECK-NEXT: %0 = linalg.dim %arg0, 1 : !linalg.view<?x?xf32>
-// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<f32>
-// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<f32>
+// CHECK-NEXT: %1 = linalg.buffer_alloc %0 : !linalg.buffer<?xf32>
+// CHECK-NEXT: linalg.buffer_dealloc %1 : !linalg.buffer<?xf32>
func @range_intersect(%arg0: !linalg.range, %arg1: !linalg.range) -> !linalg.range {
%0 = linalg.range_intersect %arg0, %arg1 : !linalg.range
// RUN: mlir-opt %s -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
// RUN: mlir-opt %s -linalg-lower-to-loops -linalg-lower-to-llvm-dialect | mlir-cpu-runner -e matmul -entry-point-result=f32 -shared-libs=%linalg_test_lib_dir/libcblas%shlibext,%linalg_test_lib_dir/libcblas_interface%shlibext | FileCheck %s
-func @fill_f32(%arg0 : !linalg.buffer<f32>, %f : f32) {
+func @fill_f32(%arg0 : !linalg.buffer<?xf32>, %f : f32) {
%c0 = constant 0 : index
%c1 = constant 1 : index
- %s = linalg.buffer_size %arg0 : !linalg.buffer<f32>
+ %s = linalg.buffer_size %arg0 : !linalg.buffer<?xf32>
%R = linalg.range %c0:%s:%c1 : !linalg.range
%V = linalg.view %arg0[%R] : !linalg.view<?xf32>
affine.for %i0 = 0 to %s {
return
}
-func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<f32> {
- %A = linalg.buffer_alloc %s : !linalg.buffer<f32>
- call @fill_f32(%A, %f) : (!linalg.buffer<f32>, f32) -> ()
- return %A : !linalg.buffer<f32>
+func @alloc_filled_f32(%s : index, %f : f32) -> !linalg.buffer<?xf32> {
+ %A = linalg.buffer_alloc %s : !linalg.buffer<?xf32>
+ call @fill_f32(%A, %f) : (!linalg.buffer<?xf32>, f32) -> ()
+ return %A : !linalg.buffer<?xf32>
}
func @dot() -> f32 {
%f1 = constant 1.00000e+00 : f32
%f2 = constant 2.00000e+00 : f32
- %bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer<f32>)
- %bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer<f32>)
- %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<f32>)
+ %bA = call @alloc_filled_f32(%c16, %f2) : (index, f32) -> (!linalg.buffer<?xf32>)
+ %bB = call @alloc_filled_f32(%c16, %f1) : (index, f32) -> (!linalg.buffer<?xf32>)
+ %bC = call @alloc_filled_f32(%c1, %f10) : (index, f32) -> (!linalg.buffer<?xf32>)
%R = linalg.range %c0:%c16:%c1 : !linalg.range
%A = linalg.view %bA[%R] : !linalg.view<?xf32>
linalg.dot(%A, %B, %C) : !linalg.view<?xf32>, !linalg.view<?xf32>, !linalg.view<f32>
%res = linalg.load %C[] : !linalg.view<f32>
- linalg.buffer_dealloc %bC : !linalg.buffer<f32>
- linalg.buffer_dealloc %bB : !linalg.buffer<f32>
- linalg.buffer_dealloc %bA : !linalg.buffer<f32>
+ linalg.buffer_dealloc %bC : !linalg.buffer<?xf32>
+ linalg.buffer_dealloc %bB : !linalg.buffer<?xf32>
+ linalg.buffer_dealloc %bA : !linalg.buffer<?xf32>
return %res : f32
}
%f2 = constant 2.00000e+00 : f32
%f10 = constant 10.00000e+00 : f32
- %bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (!linalg.buffer<f32>)
- %bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (!linalg.buffer<f32>)
- %bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (!linalg.buffer<f32>)
+ %bA = call @alloc_filled_f32(%c160, %f2) : (index, f32) -> (!linalg.buffer<?xf32>)
+ %bB = call @alloc_filled_f32(%c160, %f1) : (index, f32) -> (!linalg.buffer<?xf32>)
+ %bC = call @alloc_filled_f32(%c100, %f10) : (index, f32) -> (!linalg.buffer<?xf32>)
%M = linalg.range %c0:%c10:%c1 : !linalg.range
%N = linalg.range %c0:%c10:%c1 : !linalg.range
linalg.matmul(%A, %B, %C) : !linalg.view<?x?xf32>, !linalg.view<?x?xf32>, !linalg.view<?x?xf32>
%res = linalg.load %C[%c6, %c7] : !linalg.view<?x?xf32>
- linalg.buffer_dealloc %bC : !linalg.buffer<f32>
- linalg.buffer_dealloc %bB : !linalg.buffer<f32>
- linalg.buffer_dealloc %bA : !linalg.buffer<f32>
+ linalg.buffer_dealloc %bC : !linalg.buffer<?xf32>
+ linalg.buffer_dealloc %bB : !linalg.buffer<?xf32>
+ linalg.buffer_dealloc %bA : !linalg.buffer<?xf32>
return %res : f32
}