ElementIterator<T>(rawData, splat, getNumElements())};
}
+ template <typename T, typename = typename std::enable_if<
+ std::is_same<T, StringRef>::value>::type>
llvm::iterator_range<ElementIterator<StringRef>> getValues() const {
auto stringRefs = getRawStringData();
const char *ptr = reinterpret_cast<const char *>(stringRefs.data());
// If the data is already known to be a splat, the key hash value is
// directly the data buffer.
if (isKnownSplat)
- return KeyTy(ty, data, llvm::hash_value(data), isKnownSplat);
+ return KeyTy(ty, data, llvm::hash_value(data.front()), isKnownSplat);
// Handle the simple case of only one element.
assert(ty.getNumElements() != 1 &&
if (!firstElt.equals(data[i]))
return KeyTy(ty, data, llvm::hash_combine(hashVal, data.drop_front(i)));
- // Otherwise, this is a splat.
- return KeyTy(ty, data, hashVal, /*isSplat=*/true);
+ // Otherwise, this is a splat so just return the hash of the first element.
+ return KeyTy(ty, data.take_front(), hashVal, /*isSplat=*/true);
}
/// Hash the key for the storage.
//===----------------------------------------------------------------------===//
#include "mlir/IR/Attributes.h"
+#include "mlir/IR/Identifier.h"
#include "mlir/IR/StandardTypes.h"
#include "gtest/gtest.h"
template <typename EltTy>
static void testSplat(Type eltType, const EltTy &splatElt) {
- VectorType shape = VectorType::get({2, 1}, eltType);
+ RankedTensorType shape = RankedTensorType::get({2, 1}, eltType);
// Check that the generated splat is the same for 1 element and N elements.
DenseElementsAttr splat = DenseElementsAttr::get(shape, splatElt);
TEST(DenseSplatTest, BoolSplat) {
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
- VectorType shape = VectorType::get({2, 2}, boolTy);
+ RankedTensorType shape = RankedTensorType::get({2, 2}, boolTy);
// Check that splat is automatically detected for boolean values.
/// True.
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
- VectorType shape = VectorType::get({boolCount}, boolTy);
+ RankedTensorType shape = RankedTensorType::get({boolCount}, boolTy);
// Check that splat is automatically detected for boolean values.
/// True.
TEST(DenseSplatTest, BoolNonSplat) {
MLIRContext context;
IntegerType boolTy = IntegerType::get(1, &context);
- VectorType shape = VectorType::get({6}, boolTy);
+ RankedTensorType shape = RankedTensorType::get({6}, boolTy);
// Check that we properly handle non-splat values.
DenseElementsAttr nonSplat =
testSplat(floatTy, value);
}
+TEST(DenseSplatTest, StringSplat) {
+ MLIRContext context;
+ Type stringType =
+ OpaqueType::get(Identifier::get("test", &context), "string", &context);
+ StringRef value = "test-string";
+ testSplat(stringType, value);
+}
+
} // end namespace