[MLIR] Add VectorTransferOps
authorNicolas Vasilache <ntv@google.com>
Mon, 3 Dec 2018 23:21:27 +0000 (15:21 -0800)
committerjpienaar <jpienaar@google.com>
Fri, 29 Mar 2019 21:15:25 +0000 (14:15 -0700)
This CL implements and uses VectorTransferOps in lieu of the former custom
call op. Tests are updated accordingly.

VectorTransferOps come in 2 flavors: VectorTransferReadOp and
VectorTransferWriteOp.

VectorTransferOps can be thought of as a backend-independent
pseudo op/library call that needs to be legalized to MLIR (whiteboxed) before
it can be lowered to backend-dependent IR.

Note that the current implementation does not yet support a real permutation
map. Proper support will come in a followup CL.

VectorTransferReadOp
====================
VectorTransferReadOp performs a blocking read from a scalar memref
location into a super-vector of the same elemental type. This operation is
called 'read' by opposition to 'load' because the super-vector granularity
is generally not representable with a single hardware register. As a
consequence, memory transfers will generally be required when lowering
VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
that supports super-vectorization with non-effecting padding for full-tile
only code.

A vector transfer read has semantics similar to a vector load, with additional
support for:
  1. an optional value of the elemental type of the MemRef. This value
     supports non-effecting padding and is inserted in places where the
     vector read exceeds the MemRef bounds. If the value is not specified,
     the access is statically guaranteed to be within bounds;
  2. an attribute of type AffineMap to specify a slice of the original
     MemRef access and its transposition into the super-vector shape. The
     permutation_map is an unbounded AffineMap that must represent a
     permutation from the MemRef dim space projected onto the vector dim
     space.

Example:
```mlir
  %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
  ...
  %val = `ssa-value` : f32
  // let %i, %j, %k, %l be ssa-values of type index
  %v0 = vector_transfer_read %src, %i, %j, %k, %l
        {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
          (memref<?x?x?x?xf32>, index, index, index, index) ->
            vector<16x32x64xf32>
  %v1 = vector_transfer_read %src, %i, %j, %k, %l, %val
        {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
          (memref<?x?x?x?xf32>, index, index, index, index, f32) ->
            vector<16x32x64xf32>
```

VectorTransferWriteOp
=====================
VectorTransferWriteOp performs a blocking write from a super-vector to
a scalar memref of the same elemental type. This operation is
called 'write' by opposition to 'store' because the super-vector
granularity is generally not representable with a single hardware register. As
a consequence, memory transfers will generally be required when lowering
VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level
abstraction that supports super-vectorization with non-effecting padding
for full-tile only code.
A vector transfer write has semantics similar to a vector store, with
additional support for handling out-of-bounds situations.

Example:
```mlir
  %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>.
  %val = `ssa-value` : vector<16x32x64xf32>
  // let %i, %j, %k, %l be ssa-values of type index
  vector_transfer_write %val, %src, %i, %j, %k, %l
    {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
  (vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index)
```
PiperOrigin-RevId: 223873234

12 files changed:
mlir/include/mlir/Analysis/VectorAnalysis.h
mlir/include/mlir/StandardOps/StandardOps.h
mlir/include/mlir/Support/Functional.h
mlir/lib/Analysis/LoopAnalysis.cpp
mlir/lib/Analysis/VectorAnalysis.cpp
mlir/lib/StandardOps/StandardOps.cpp
mlir/lib/Transforms/MaterializeVectors.cpp
mlir/lib/Transforms/Vectorize.cpp
mlir/test/IR/core-ops.mlir
mlir/test/IR/invalid-ops.mlir
mlir/test/Transforms/materialize_vectors.mlir
mlir/test/Transforms/vectorize.mlir

index 82bffb8fa7d6623e786f580531c945a696c3c017..a3d31b2f964c57a079c91fa3889c8f576be15966 100644 (file)
 
 namespace mlir {
 
+class AffineMap;
+class MemRefType;
 class OperationStmt;
 class VectorType;
 
-// TODO(ntv): Drop this once we have proper Ops.
-static constexpr auto kVectorTransferReadOpName = "vector_transfer_read";
-static constexpr auto kVectorTransferWriteOpName = "vector_transfer_write";
-bool isaVectorTransferRead(const OperationStmt &stmt);
-bool isaVectorTransferWrite(const OperationStmt &stmt);
-
 /// Computes and returns the multi-dimensional ratio of `superShape` to
 /// `subShape`. This is calculated by performing a traversal from minor to major
 /// dimensions (i.e. in reverse shape order). If integral division is not
@@ -49,6 +45,16 @@ shapeRatio(ArrayRef<int> superShape, ArrayRef<int> subShape);
 llvm::Optional<llvm::SmallVector<unsigned, 4>>
 shapeRatio(VectorType superVectorType, VectorType subVectorType);
 
+/// Creates a permutation map to be used as an attribute in VectorTransfer ops.
+/// Currently only returns the minor vectorType.rank identity submatrix.
+///
+/// For example, assume memrefType is of rank 5 and vectorType is of rank 3,
+/// returns the affine map:
+///     (d0, d1, d2, d3, d4) -> (d2, d3, d4)
+///
+/// TODO(ntv): support real permutations.
+AffineMap makePermutationMap(MemRefType memrefType, VectorType vectorType);
+
 namespace matcher {
 
 /// Matches vector_transfer_read, vector_transfer_write and ops that return a
index 8301b131addf13ecd6e094e82556ab8ad5ddb2bd..ffd903aa779e916c3d0229e0958ce118c6f359b2 100644 (file)
@@ -28,6 +28,7 @@
 #include "mlir/IR/OpDefinition.h"
 
 namespace mlir {
+class AffineMap;
 class Builder;
 class MLValue;
 
@@ -795,6 +796,149 @@ private:
   explicit TensorCastOp(const Operation *state) : CastOp(state) {}
 };
 
+/// VectorTransferReadOp performs a blocking read from a scalar memref
+/// location into a super-vector of the same elemental type. This operation is
+/// called 'read' by opposition to 'load' because the super-vector granularity
+/// is generally not representable with a single hardware register. As a
+/// consequence, memory transfers will generally be required when lowering
+/// VectorTransferReadOp. A VectorTransferReadOp is thus a mid-level abstraction
+/// that supports super-vectorization with non-effecting padding for full-tile
+/// only code.
+//
+/// A vector transfer read has semantics similar to a vector load, with
+/// additional support for:
+///   1. an optional value of the elemental type of the MemRef. This value
+///      supports non-effecting padding and is inserted in places where the
+///      vector read exceeds the MemRef bounds. If the value is not specified,
+///      the access is statically guaranteed to be within bounds;
+///   2. an attribute of type AffineMap to specify a slice of the original
+///      MemRef access and its transposition into the super-vector shape. The
+///      permutation_map is an unbounded AffineMap that must represent a
+///      permutation from the MemRef dim space projected onto the vector dim
+///      space.
+//
+/// Example:
+/// ```mlir
+///   %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>
+///   ...
+///   %val = `ssa-value` : f32
+///   // let %i, %j, %k, %l be ssa-values of type index
+///   %v0 = vector_transfer_read %src, %i, %j, %k, %l
+///          {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+///        (memref<?x?x?x?xf32>, index, index, index, index) ->
+///          vector<16x32x64xf32>
+///   %v1 = vector_transfer_read %src, %i, %j, %k, %l, %val
+///          {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+///        (memref<?x?x?x?xf32>, index, index, index, index, f32) ->
+///           vector<16x32x64xf32>
+/// ```
+class VectorTransferReadOp
+    : public Op<VectorTransferReadOp, OpTrait::VariadicOperands,
+                OpTrait::OneResult> {
+  enum Offsets : unsigned { MemRefOffset = 0, FirstIndexOffset = 1 };
+
+public:
+  static StringRef getOperationName() { return "vector_transfer_read"; }
+  static StringRef getPermutationMapAttrName() { return "permutation_map"; }
+  static void build(Builder *builder, OperationState *result,
+                    VectorType vectorType, SSAValue *srcMemRef,
+                    ArrayRef<SSAValue *> srcIndices, AffineMap permutationMap,
+                    Optional<SSAValue *> paddingValue = None);
+  VectorType getResultType() const {
+    return getResult()->getType().cast<VectorType>();
+  }
+  SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
+  const SSAValue *getMemRef() const {
+    return getOperand(Offsets::MemRefOffset);
+  }
+  MemRefType getMemRefType() const {
+    return getMemRef()->getType().cast<MemRefType>();
+  }
+  llvm::iterator_range<Operation::operand_iterator> getIndices();
+  llvm::iterator_range<Operation::const_operand_iterator> getIndices() const;
+  Optional<SSAValue *> getPaddingValue();
+  Optional<const SSAValue *> getPaddingValue() const;
+  AffineMap getPermutationMap() const;
+
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p) const;
+  bool verify() const;
+
+private:
+  friend class Operation;
+  explicit VectorTransferReadOp(const Operation *state) : Op(state) {}
+};
+
+/// VectorTransferWriteOp performs a blocking write from a super-vector to
+/// a scalar memref of the same elemental type. This operation is
+/// called 'write' by opposition to 'store' because the super-vector granularity
+/// is generally not representable with a single hardware register. As a
+/// consequence, memory transfers will generally be required when lowering
+/// VectorTransferWriteOp. A VectorTransferWriteOp is thus a mid-level
+/// abstraction that supports super-vectorization with non-effecting padding for
+/// full-tile only code.
+///
+/// A vector transfer write has semantics similar to a vector store, with
+/// additional support for handling out-of-bounds situations. It is the
+/// responsibility of vector_transfer_write's implementation to ensure the
+/// memory writes are valid. Different implementations may be pertinent
+/// depending on the hardware support including:
+/// 1. predication;
+/// 2. explicit control-flow;
+/// 3. Read-Modify-Write;
+/// 4. writing out of bounds of the memref when the allocation allows it.
+///
+/// Example:
+/// ```mlir
+///   %A = alloc(%size1, %size2, %size3, %size4) : memref<?x?x?x?xf32>.
+///   %val = `ssa-value` : vector<16x32x64xf32>
+///   // let %i, %j, %k, %l be ssa-values of type index
+///   vector_transfer_write %val, %src, %i, %j, %k, %l
+///     {permutation_map: (d0, d1, d2, d3) -> (d3, d1, d2)} :
+///   vector<16x32x64xf32>, memref<?x?x?x?xf32>, index, index, index, index
+/// ```
+class VectorTransferWriteOp
+    : public Op<VectorTransferWriteOp, OpTrait::VariadicOperands,
+                OpTrait::ZeroResult> {
+  enum Offsets : unsigned {
+    VectorOffset = 0,
+    MemRefOffset = 1,
+    FirstIndexOffset = 2
+  };
+
+public:
+  static StringRef getOperationName() { return "vector_transfer_write"; }
+  static StringRef getPermutationMapAttrName() { return "permutation_map"; }
+  static void build(Builder *builder, OperationState *result,
+                    SSAValue *srcVector, SSAValue *dstMemRef,
+                    ArrayRef<SSAValue *> dstIndices, AffineMap permutationMap);
+  SSAValue *getVector() { return getOperand(Offsets::VectorOffset); }
+  const SSAValue *getVector() const {
+    return getOperand(Offsets::VectorOffset);
+  }
+  VectorType getVectorType() const {
+    return getVector()->getType().cast<VectorType>();
+  }
+  SSAValue *getMemRef() { return getOperand(Offsets::MemRefOffset); }
+  const SSAValue *getMemRef() const {
+    return getOperand(Offsets::MemRefOffset);
+  }
+  MemRefType getMemRefType() const {
+    return getMemRef()->getType().cast<MemRefType>();
+  }
+  llvm::iterator_range<Operation::operand_iterator> getIndices();
+  llvm::iterator_range<Operation::const_operand_iterator> getIndices() const;
+  AffineMap getPermutationMap() const;
+
+  static bool parse(OpAsmParser *parser, OperationState *result);
+  void print(OpAsmPrinter *p) const;
+  bool verify() const;
+
+private:
+  friend class Operation;
+  explicit VectorTransferWriteOp(const Operation *state) : Op(state) {}
+};
+
 } // end namespace mlir
 
 #endif
index 071a611151c96e39570d1c014b5c3a2d0282e7fc..e1b1ee5ce58a409e8c82c9051bc8affad26d4f6b 100644 (file)
@@ -19,6 +19,7 @@
 #define MLIR_SUPPORT_FUNCTIONAL_H_
 
 #include "llvm/ADT/STLExtras.h"
+#include "llvm/Support/Casting.h"
 
 /// This file provides some simple template functional-style sugar to operate
 /// on **value** types. Make sure when using that the stored type is cheap to
@@ -78,6 +79,14 @@ void zipApply(Fun fun, ContainerType1 input1, ContainerType2 input2) {
   }
 }
 
+/// Unwraps a pointer type to another type (possibly the same).
+/// Used in particular to allow easier compositions of
+///   llvm::iterator_range<ForStmt::operand_iterator> types.
+template <typename T, typename ToType = T>
+inline std::function<ToType *(T *)> makePtrDynCaster() {
+  return [](T *val) { return llvm::dyn_cast<ToType>(val); };
+}
+
 /// Simple ScopeGuard.
 struct ScopeGuard {
   explicit ScopeGuard(std::function<void(void)> destruct)
index 8406a37d7938fe8a089051d07e700b2d2e569963..de98849136c1b9c67f9842cb7b7cd591e4cf5e07 100644 (file)
@@ -194,7 +194,8 @@ static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
 // TODO(ntv): make the following into MLIR instructions, then use isa<>.
 static bool isVectorTransferReadOrWrite(const Statement &stmt) {
   const auto *opStmt = cast<OperationStmt>(&stmt);
-  return isaVectorTransferRead(*opStmt) || isaVectorTransferWrite(*opStmt);
+  return opStmt->isa<VectorTransferReadOp>() ||
+         opStmt->isa<VectorTransferWriteOp>();
 }
 
 using VectorizableStmtFun =
index 75f62299e1b2b004a27adc8c51d3336461662fb0..9c2160c1450a189d06f5afcc1e3f9a8e2b86311a 100644 (file)
@@ -18,6 +18,7 @@
 #include "mlir/Analysis/VectorAnalysis.h"
 #include "mlir/IR/BuiltinOps.h"
 #include "mlir/IR/Statements.h"
+#include "mlir/StandardOps/StandardOps.h"
 #include "mlir/Support/Functional.h"
 #include "mlir/Support/STLExtras.h"
 
 
 using namespace mlir;
 
-bool mlir::isaVectorTransferRead(const OperationStmt &stmt) {
-  return stmt.getName().getStringRef().str() == kVectorTransferReadOpName;
-}
-
-bool mlir::isaVectorTransferWrite(const OperationStmt &stmt) {
-  return stmt.getName().getStringRef().str() == kVectorTransferWriteOpName;
-}
-
 Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(ArrayRef<int> superShape,
                                                     ArrayRef<int> subShape) {
   if (superShape.size() < subShape.size()) {
@@ -83,6 +76,20 @@ Optional<SmallVector<unsigned, 4>> mlir::shapeRatio(VectorType superVectorType,
   return shapeRatio(superVectorType.getShape(), subVectorType.getShape());
 }
 
+AffineMap mlir::makePermutationMap(MemRefType memrefType,
+                                   VectorType vectorType) {
+  unsigned memRefRank = memrefType.getRank();
+  unsigned vectorRank = vectorType.getRank();
+  assert(memRefRank >= vectorRank && "Broadcast not supported");
+  unsigned offset = memRefRank - vectorRank;
+  SmallVector<AffineExpr, 4> perm;
+  perm.reserve(memRefRank);
+  for (unsigned i = 0; i < vectorRank; ++i) {
+    perm.push_back(getAffineDimExpr(offset + i, memrefType.getContext()));
+  }
+  return AffineMap::get(memRefRank, 0, perm, {});
+}
+
 bool mlir::matcher::operatesOnStrictSuperVectors(const OperationStmt &opStmt,
                                                  VectorType subVectorType) {
   // First, extract the vector type and ditinguish between:
@@ -96,15 +103,11 @@ bool mlir::matcher::operatesOnStrictSuperVectors(const OperationStmt &opStmt,
   /// do not have to special case. Maybe a trait, or just a method, unclear atm.
   bool mustDivide = false;
   VectorType superVectorType;
-  if (isaVectorTransferRead(opStmt)) {
-    superVectorType = opStmt.getResult(0)->getType().cast<VectorType>();
+  if (auto read = opStmt.dyn_cast<VectorTransferReadOp>()) {
+    superVectorType = read->getResultType();
     mustDivide = true;
-  } else if (isaVectorTransferWrite(opStmt)) {
-    // TODO(ntv): if vector_transfer_write had store-like semantics we could
-    // have written something similar to:
-    //   auto store = storeOp->cast<StoreOp>();
-    //   auto *value = store->getValueToStore();
-    superVectorType = opStmt.getOperand(0)->getType().cast<VectorType>();
+  } else if (auto write = opStmt.dyn_cast<VectorTransferWriteOp>()) {
+    superVectorType = write->getVectorType();
     mustDivide = true;
   } else if (opStmt.getNumResults() == 0) {
     assert(opStmt.isa<ReturnOp>() &&
index 4de951a12f9803b98d92b8dc868b159a5af97be4..4d71ddeab1641042017652a6323cf0826a733920 100644 (file)
@@ -40,7 +40,8 @@ StandardOpsDialect::StandardOpsDialect(MLIRContext *context)
   addOperations<AddFOp, AddIOp, AllocOp, CallOp, CallIndirectOp, CmpIOp,
                 DeallocOp, DimOp, DmaStartOp, DmaWaitOp, ExtractElementOp,
                 LoadOp, MemRefCastOp, MulFOp, MulIOp, SelectOp, StoreOp, SubFOp,
-                SubIOp, TensorCastOp>();
+                SubIOp, TensorCastOp, VectorTransferReadOp,
+                VectorTransferWriteOp>();
 }
 
 //===----------------------------------------------------------------------===//
@@ -1321,3 +1322,427 @@ bool TensorCastOp::verify() const {
 
   return false;
 }
+
+//===----------------------------------------------------------------------===//
+// VectorTransferReadOp
+//===----------------------------------------------------------------------===//
+template <typename EmitFun>
+static bool verifyPermutationMap(AffineMap permutationMap,
+                                 EmitFun emitOpError) {
+  SmallVector<bool, 8> seen(permutationMap.getNumInputs(), false);
+  for (auto expr : permutationMap.getResults()) {
+    auto dim = expr.dyn_cast<AffineDimExpr>();
+    if (!dim) {
+      return emitOpError(
+          "requires a permutation_map that is an actual permutation");
+    }
+    if (seen[dim.getPosition()]) {
+      return emitOpError(
+          "requires a permutation_map that is a full column-rank "
+          "permutation (i.e. a permutation composed with an "
+          "orthogonal projection)");
+    }
+    seen[dim.getPosition()] = true;
+  }
+  return false;
+}
+
+void VectorTransferReadOp::build(Builder *builder, OperationState *result,
+                                 VectorType vectorType, SSAValue *srcMemRef,
+                                 ArrayRef<SSAValue *> srcIndices,
+                                 AffineMap permutationMap,
+                                 Optional<SSAValue *> paddingValue) {
+  result->addOperands(srcMemRef);
+  result->addOperands(srcIndices);
+  if (paddingValue) {
+    result->addOperands({*paddingValue});
+  }
+  result->addAttribute(getPermutationMapAttrName(),
+                       builder->getAffineMapAttr(permutationMap));
+  result->addTypes(vectorType);
+}
+
+llvm::iterator_range<Operation::operand_iterator>
+VectorTransferReadOp::getIndices() {
+  auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+  auto end = begin + getMemRefType().getRank();
+  return {begin, end};
+}
+
+llvm::iterator_range<Operation::const_operand_iterator>
+VectorTransferReadOp::getIndices() const {
+  auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+  auto end = begin + getMemRefType().getRank();
+  return {begin, end};
+}
+
+Optional<SSAValue *> VectorTransferReadOp::getPaddingValue() {
+  auto memRefRank = getMemRefType().getRank();
+  if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
+    return None;
+  }
+  return Optional<SSAValue *>(
+      getOperand(Offsets::FirstIndexOffset + memRefRank));
+}
+
+Optional<const SSAValue *> VectorTransferReadOp::getPaddingValue() const {
+  auto memRefRank = getMemRefType().getRank();
+  if (getNumOperands() <= Offsets::FirstIndexOffset + memRefRank) {
+    return None;
+  }
+  return Optional<const SSAValue *>(
+      getOperand(Offsets::FirstIndexOffset + memRefRank));
+}
+
+AffineMap VectorTransferReadOp::getPermutationMap() const {
+  return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
+}
+
+void VectorTransferReadOp::print(OpAsmPrinter *p) const {
+  *p << getOperationName() << " ";
+  p->printOperand(getMemRef());
+  *p << ", ";
+  p->printOperands(getIndices());
+  auto optionalPaddingValue = getPaddingValue();
+  if (optionalPaddingValue) {
+    *p << ", ";
+    p->printOperand(*optionalPaddingValue);
+  }
+  p->printOptionalAttrDict(getAttrs());
+  // Construct the FunctionType and print it.
+  llvm::SmallVector<Type, 8> inputs{getMemRefType()};
+  // Must have at least one actual index, see verify.
+  const SSAValue *firstIndex = *(getIndices().begin());
+  Type indexType = firstIndex->getType();
+  inputs.append(getMemRefType().getRank(), indexType);
+  if (optionalPaddingValue) {
+    inputs.push_back((*optionalPaddingValue)->getType());
+  }
+  *p << " : "
+     << FunctionType::get(inputs, {getResultType()}, indexType.getContext());
+}
+
+bool VectorTransferReadOp::parse(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 8> parsedOperands;
+  Type type;
+
+  // Parsing with support for optional paddingValue.
+  auto fail = parser->parseOperandList(parsedOperands) ||
+              parser->parseOptionalAttributeDict(result->attributes) ||
+              parser->parseColonType(type);
+  if (fail) {
+    return true;
+  }
+
+  // Resolution.
+  auto funType = type.dyn_cast<FunctionType>();
+  if (!funType) {
+    parser->emitError(parser->getNameLoc(), "Function type expected");
+    return true;
+  }
+  if (funType.getNumInputs() < 1) {
+    parser->emitError(parser->getNameLoc(),
+                      "Function type expects at least one input");
+    return true;
+  }
+  MemRefType memrefType =
+      funType.getInput(Offsets::MemRefOffset).dyn_cast<MemRefType>();
+  if (!memrefType) {
+    parser->emitError(parser->getNameLoc(),
+                      "MemRef type expected for first input");
+    return true;
+  }
+  if (funType.getNumResults() < 1) {
+    parser->emitError(parser->getNameLoc(),
+                      "Function type expects exactly one vector result");
+    return true;
+  }
+  VectorType vectorType = funType.getResult(0).dyn_cast<VectorType>();
+  if (!vectorType) {
+    parser->emitError(parser->getNameLoc(),
+                      "Vector type expected for first result");
+    return true;
+  }
+  if (parsedOperands.size() != funType.getNumInputs()) {
+    parser->emitError(parser->getNameLoc(), "requires " +
+                                                Twine(funType.getNumInputs()) +
+                                                " operands");
+    return true;
+  }
+
+  // Extract optional paddingValue.
+  OpAsmParser::OperandType memrefInfo = parsedOperands[0];
+  // At this point, indexInfo may contain the optional paddingValue, pop it out.
+  SmallVector<OpAsmParser::OperandType, 8> indexInfo{
+      parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()};
+  Type paddingType;
+  OpAsmParser::OperandType paddingValue;
+  bool hasPaddingValue = indexInfo.size() > memrefType.getRank();
+  unsigned expectedNumOperands = Offsets::FirstIndexOffset +
+                                 memrefType.getRank() +
+                                 (hasPaddingValue ? 1 : 0);
+  if (hasPaddingValue) {
+    paddingType = funType.getInputs().back();
+    paddingValue = indexInfo.pop_back_val();
+  }
+  if (funType.getNumInputs() != expectedNumOperands) {
+    parser->emitError(
+        parser->getNameLoc(),
+        "requires actual number of operands to match function type");
+    return true;
+  }
+
+  auto indexType = parser->getBuilder().getIndexType();
+  return parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
+         parser->resolveOperands(indexInfo, indexType, result->operands) ||
+         (hasPaddingValue && parser->resolveOperand(paddingValue, paddingType,
+                                                    result->operands)) ||
+         parser->addTypeToList(vectorType, result->types);
+}
+
+bool VectorTransferReadOp::verify() const {
+  // Consistency of memref type in function type.
+  if (llvm::empty(getOperands())) {
+    return emitOpError(
+        "requires at least a memref operand followed by 'rank' indices");
+  }
+  if (!getMemRef()->getType().isa<MemRefType>()) {
+    return emitOpError("requires a memref as first operand");
+  }
+  // Consistency of vector type in function type.
+  if (!getResult()->getType().isa<VectorType>()) {
+    return emitOpError("should have a vector result type in function type: "
+                       "(memref_type [, elemental_type]) -> vector_type");
+  }
+  // Consistency of elemental types in memref and vector.
+  MemRefType memrefType = getMemRefType();
+  VectorType vectorType = getResultType();
+  if (memrefType.getElementType() != vectorType.getElementType())
+    return emitOpError(
+        "requires memref and vector types of the same elemental type");
+  // Consistency of number of input types.
+  auto optionalPaddingValue = getPaddingValue();
+  unsigned expectedNumOperands = Offsets::FirstIndexOffset +
+                                 memrefType.getRank() +
+                                 (optionalPaddingValue ? 1 : 0);
+  // Checks on the actual operands and their types.
+  if (getNumOperands() != expectedNumOperands) {
+    return emitOpError("expects " + Twine(expectedNumOperands) +
+                       " operands to match the types");
+  }
+  // Consistency of padding value with vector type.
+  if (optionalPaddingValue) {
+    auto paddingValue = *optionalPaddingValue;
+    auto elementalType = paddingValue->getType();
+    if (!VectorType::isValidElementType(elementalType)) {
+      return emitOpError("requires valid padding vector elemental type");
+    }
+    if (elementalType != vectorType.getElementType()) {
+      return emitOpError(
+          "requires formal padding and vector of the same elemental type");
+    }
+  }
+  // Consistency of indices types.
+  unsigned numIndices = 0;
+  for (auto *idx : getIndices()) {
+    if (!idx->getType().isIndex()) {
+      return emitOpError(
+          "index to vector_transfer_read must have 'index' type");
+    }
+    ++numIndices;
+  }
+  if (numIndices != memrefType.getRank()) {
+    return emitOpError("requires at least a memref operand followed by " +
+                       Twine(memrefType.getRank()) + " indices");
+  }
+
+  // Consistency of AffineMap attribute.
+  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
+    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
+  }
+  auto permutationMap = getPermutationMap();
+  if (!permutationMap.getRangeSizes().empty()) {
+    return emitOpError("requires an unbounded permutation_map");
+  }
+  if (permutationMap.getNumSymbols() != 0) {
+    return emitOpError("requires a permutation_map without symbols");
+  }
+  if (permutationMap.getNumInputs() != memrefType.getRank()) {
+    return emitOpError("requires a permutation_map with input dims of the "
+                       "same rank as the memref type");
+  }
+  if (permutationMap.getNumResults() != vectorType.getRank()) {
+    return emitOpError("requires a permutation_map with result dims of the "
+                       "same rank as the vector type");
+  }
+  return verifyPermutationMap(permutationMap,
+                              [this](Twine t) { return emitOpError(t); });
+}
+
+//===----------------------------------------------------------------------===//
+// VectorTransferWriteOp
+//===----------------------------------------------------------------------===//
+void VectorTransferWriteOp::build(Builder *builder, OperationState *result,
+                                  SSAValue *srcVector, SSAValue *dstMemRef,
+                                  ArrayRef<SSAValue *> dstIndices,
+                                  AffineMap permutationMap) {
+  result->addOperands({srcVector, dstMemRef});
+  result->addOperands(dstIndices);
+  result->addAttribute(getPermutationMapAttrName(),
+                       builder->getAffineMapAttr(permutationMap));
+}
+
+llvm::iterator_range<Operation::operand_iterator>
+VectorTransferWriteOp::getIndices() {
+  auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+  auto end = begin + getMemRefType().getRank();
+  return {begin, end};
+}
+
+llvm::iterator_range<Operation::const_operand_iterator>
+VectorTransferWriteOp::getIndices() const {
+  auto begin = getOperation()->operand_begin() + Offsets::FirstIndexOffset;
+  auto end = begin + getMemRefType().getRank();
+  return {begin, end};
+}
+
+AffineMap VectorTransferWriteOp::getPermutationMap() const {
+  return getAttrOfType<AffineMapAttr>(getPermutationMapAttrName()).getValue();
+}
+
+void VectorTransferWriteOp::print(OpAsmPrinter *p) const {
+  *p << getOperationName();
+  *p << " " << *getVector();
+  *p << ", " << *getMemRef();
+  *p << ", ";
+  p->printOperands(getIndices());
+  p->printOptionalAttrDict(getAttrs());
+  Type indexType = (*getIndices().begin())->getType();
+  *p << " : ";
+  p->printType(getVectorType());
+  *p << ", ";
+  p->printType(getMemRefType());
+  for (unsigned r = 0, n = getMemRefType().getRank(); r < n; ++r) {
+    *p << ", ";
+    p->printType(indexType);
+  }
+}
+
+bool VectorTransferWriteOp::parse(OpAsmParser *parser, OperationState *result) {
+  SmallVector<OpAsmParser::OperandType, 8> parsedOperands;
+  SmallVector<Type, 8> types;
+
+  // Parsing with support for optional paddingValue.
+  auto fail = parser->parseOperandList(parsedOperands) ||
+              parser->parseOptionalAttributeDict(result->attributes) ||
+              parser->parseColonTypeList(types);
+  if (fail) {
+    return true;
+  }
+
+  // Resolution.
+  if (parsedOperands.size() != types.size()) {
+    parser->emitError(parser->getNameLoc(),
+                      "requires number of operands and input types to match");
+    return true;
+  }
+  if (parsedOperands.size() < Offsets::FirstIndexOffset) {
+    parser->emitError(parser->getNameLoc(),
+                      "requires at least vector and memref operands");
+    return true;
+  }
+  VectorType vectorType = types[Offsets::VectorOffset].dyn_cast<VectorType>();
+  if (!vectorType) {
+    parser->emitError(parser->getNameLoc(),
+                      "Vector type expected for first input type");
+    return true;
+  }
+  MemRefType memrefType = types[Offsets::MemRefOffset].dyn_cast<MemRefType>();
+  if (!memrefType) {
+    parser->emitError(parser->getNameLoc(),
+                      "MemRef type expected for second input type");
+    return true;
+  }
+
+  unsigned expectedNumOperands =
+      Offsets::FirstIndexOffset + memrefType.getRank();
+  if (parsedOperands.size() != expectedNumOperands) {
+    parser->emitError(parser->getNameLoc(),
+                      "requires " + Twine(expectedNumOperands) + " operands");
+    return true;
+  }
+
+  OpAsmParser::OperandType vectorInfo = parsedOperands[Offsets::VectorOffset];
+  OpAsmParser::OperandType memrefInfo = parsedOperands[Offsets::MemRefOffset];
+  SmallVector<OpAsmParser::OperandType, 8> indexInfo{
+      parsedOperands.begin() + Offsets::FirstIndexOffset, parsedOperands.end()};
+  auto indexType = parser->getBuilder().getIndexType();
+  return parser->resolveOperand(vectorInfo, vectorType, result->operands) ||
+         parser->resolveOperand(memrefInfo, memrefType, result->operands) ||
+         parser->resolveOperands(indexInfo, indexType, result->operands);
+}
+
+bool VectorTransferWriteOp::verify() const {
+  // Consistency of memref type in function type.
+  if (llvm::empty(getOperands())) {
+    return emitOpError(
+        "requires at least a memref operand followed by 'rank' indices");
+  }
+  if (!getMemRef()->getType().isa<MemRefType>()) {
+    return emitOpError("requires a memref first operand");
+  }
+  // Consistency of vector type in function type.
+  if (!getVector()->getType().isa<VectorType>()) {
+    return emitOpError("should have a vector input type in function type: "
+                       "(vector_type, memref_type [, elemental_type]) -> ()");
+  }
+  // Consistency of elemental types in memref and vector.
+  MemRefType memrefType = getMemRefType();
+  VectorType vectorType = getVectorType();
+  if (memrefType.getElementType() != vectorType.getElementType())
+    return emitOpError(
+        "requires memref and vector types of the same elemental type");
+  // Consistency of number of input types.
+  unsigned expectedNumOperands =
+      Offsets::FirstIndexOffset + memrefType.getRank();
+  // Checks on the actual operands and their types.
+  if (getNumOperands() != expectedNumOperands) {
+    return emitOpError("expects " + Twine(expectedNumOperands) +
+                       " operands to match the types");
+  }
+  // Consistency of indices types.
+  unsigned numIndices = 0;
+  for (auto *idx : getIndices()) {
+    if (!idx->getType().isIndex()) {
+      return emitOpError(
+          "index to vector_transfer_write must have 'index' type");
+    }
+    numIndices++;
+  }
+  if (numIndices != memrefType.getRank()) {
+    return emitOpError("requires at least a memref operand followed by " +
+                       Twine(memrefType.getRank()) + " indices");
+  }
+
+  // Consistency of AffineMap attribute.
+  if (!getAttrOfType<AffineMapAttr>(getPermutationMapAttrName())) {
+    return emitOpError("requires an AffineMapAttr named 'permutation_map'");
+  }
+  auto permutationMap = getPermutationMap();
+  if (!permutationMap.getRangeSizes().empty()) {
+    return emitOpError("requires an unbounded permutation_map");
+  }
+  if (permutationMap.getNumSymbols() != 0) {
+    return emitOpError("requires a permutation_map without symbols");
+  }
+  if (permutationMap.getNumInputs() != memrefType.getRank()) {
+    return emitOpError("requires a permutation_map with input dims of the "
+                       "same rank as the memref type");
+  }
+  if (permutationMap.getNumResults() != vectorType.getRank()) {
+    return emitOpError("requires a permutation_map with result dims of the "
+                       "same rank as the vector type");
+  }
+  return verifyPermutationMap(permutationMap,
+                              [this](Twine t) { return emitOpError(t); });
+}
index 60f0c06aad54cfde269d034b330c44233373200c..400b4fdf9341b35ab8d0894980fabd13dface525 100644 (file)
@@ -89,6 +89,7 @@ using llvm::SetVector;
 
 using namespace mlir;
 
+using functional::makePtrDynCaster;
 using functional::map;
 
 static llvm::cl::list<int>
@@ -243,11 +244,11 @@ substitute(SSAValue *v,
 /// TODO(ntv): support a concrete AffineMap and compose with it.
 /// TODO(ntv): these implementation details should be captured in a
 /// vectorization trait at the op level directly.
-static SmallVector<MLValue *, 8>
-reindexAffineIndices(MLFuncBuilder *b, Type hwVectorType,
+static SmallVector<SSAValue *, 8>
+reindexAffineIndices(MLFuncBuilder *b, VectorType hwVectorType,
                      ArrayRef<unsigned> hwVectorInstance,
                      ArrayRef<SSAValue *> memrefIndices) {
-  auto vectorShape = hwVectorType.cast<VectorType>().getShape();
+  auto vectorShape = hwVectorType.getShape();
   assert(hwVectorInstance.size() >= vectorShape.size());
 
   unsigned numIndices = memrefIndices.size();
@@ -287,78 +288,21 @@ reindexAffineIndices(MLFuncBuilder *b, Type hwVectorType,
   // TODO(ntv): support a concrete map and composition.
   auto app = b->create<AffineApplyOp>(b->getInsertionPoint()->getLoc(),
                                       affineMap, memrefIndices);
-  unsigned numResults = app->getNumResults();
-  SmallVector<MLValue *, 8> res;
-  for (unsigned i = 0; i < numResults; ++i) {
-    res.push_back(cast<MLValue>(app->getResult(i)));
-  }
-  return res;
+  return SmallVector<SSAValue *, 8>{app->getResults()};
 }
 
-/// Returns the cloned operands of `opStmt` for the instance of
-/// `hwVectorInstance` when lowering from a super-vector type to
-/// `hwVectorType`. `hwVectorInstance` represents one particular instance of
-/// `hwVectorType` int the covering of the super-vector type. For a more
-/// detailed description of the problem, see the description of
-/// reindexAffineIndices.
-static SmallVector<MLValue *, 8>
-cloneAndUnrollOperands(OperationStmt *opStmt, Type hwVectorType,
-                       ArrayRef<unsigned> hwVectorInstance,
-                       DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
-  using functional::map;
-
-  // For Ops that are not vector_transfer_read/vector_transfer_write we can just
-  // substitute and be done.
-  if (!isaVectorTransferRead(*opStmt) && !isaVectorTransferWrite(*opStmt)) {
-    return map([substitutionsMap](
-                   SSAValue *v) { return substitute(v, *substitutionsMap); },
-               opStmt->getOperands());
-  }
-
-  // TODO(ntv): this error-prone boilerplate can be removed once we have a
-  // proper Op for vectr_transfer.
-  unsigned offset = 0;
-  unsigned numIndices = 0;
-  SmallVector<MLValue *, 8> res;
-  auto operands = opStmt->getOperands();
-  if (isaVectorTransferRead(*opStmt)) {
-    offset = 1;
-    numIndices = opStmt->getNumOperands() - 1;
-  } else if (isaVectorTransferWrite(*opStmt)) {
-    offset = 2;
-    numIndices = opStmt->getNumOperands() - 2;
-  }
-  // Copy as-is the [optional valueToStore], memref.
-  for (unsigned i = 0; i < offset; ++i) {
-    res.push_back(substitute(*(operands.begin() + i), *substitutionsMap));
-  }
-
-  MLFuncBuilder b(opStmt);
-  // TODO(ntv): indices extraction is brittle and unsafe before we have an Op.
-  SmallVector<SSAValue *, 8> indices;
-  for (auto it = operands.begin() + offset; it != operands.end(); ++it) {
-    indices.push_back(*it);
-  }
-  auto affineValues =
-      reindexAffineIndices(&b, hwVectorType, hwVectorInstance, indices);
-  res.append(affineValues.begin(), affineValues.end());
-
-  return res;
-}
-
-// Returns attributes with the following substitutions applied:
-//   - splat of `superVectorType` is replaced by splat of `hwVectorType`.
-// TODO(ntv): add more substitutions on a per-need basis.
-static SmallVector<NamedAttribute, 2>
+/// Returns attributes with the following substitutions applied:
+///   - splat of `superVectorType` is replaced by splat of `hwVectorType`.
+/// TODO(ntv): add more substitutions on a per-need basis.
+static SmallVector<NamedAttribute, 1>
 materializeAttributes(OperationStmt *opStmt, VectorType superVectorType,
                       VectorType hwVectorType) {
-  SmallVector<NamedAttribute, 2> res;
+  SmallVector<NamedAttribute, 1> res;
   for (auto a : opStmt->getAttrs()) {
     auto splat = a.second.dyn_cast<SplatElementsAttr>();
     bool splatOfSuperVectorType = splat && (splat.getType() == superVectorType);
     if (splatOfSuperVectorType) {
-      auto attr = SplatElementsAttr::get(hwVectorType.cast<VectorType>(),
-                                         splat.getValue());
+      auto attr = SplatElementsAttr::get(hwVectorType, splat.getValue());
       res.push_back(NamedAttribute(a.first, attr));
     } else {
       res.push_back(a);
@@ -367,6 +311,70 @@ materializeAttributes(OperationStmt *opStmt, VectorType superVectorType,
   return res;
 }
 
+/// Creates an instantiated version of `opStmt`.
+/// Ops other than VectorTransferReadOp/VectorTransferWriteOp require no
+/// affine reindexing. Just substitute their SSAValue* operands and be done. For
+/// this case the actual instance is irrelevant. Just use the SSA values in
+/// substitutionsMap.
+static OperationStmt *
+instantiate(MLFuncBuilder *b, OperationStmt *opStmt, VectorType superVectorType,
+            VectorType hwVectorType,
+            DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
+  assert(!opStmt->isa<VectorTransferReadOp>() &&
+         "Should call the function specialized for VectorTransferReadOp");
+  assert(!opStmt->isa<VectorTransferWriteOp>() &&
+         "Should call the function specialized for VectorTransferWriteOp");
+  auto operands =
+      map([substitutionsMap](
+              SSAValue *v) { return substitute(v, *substitutionsMap); },
+          opStmt->getOperands());
+  return b->createOperation(
+      opStmt->getLoc(), opStmt->getName(), operands, {hwVectorType},
+      materializeAttributes(opStmt, superVectorType, hwVectorType));
+}
+
+/// Creates an instantiated version of `read` for the instance of
+/// `hwVectorInstance` when lowering from a super-vector type to
+/// `hwVectorType`. `hwVectorInstance` represents one particular instance of
+/// `hwVectorType` int the covering of the super-vector type. For a more
+/// detailed description of the problem, see the description of
+/// reindexAffineIndices.
+static OperationStmt *
+instantiate(MLFuncBuilder *b, VectorTransferReadOp *read,
+            VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
+            DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
+  SmallVector<SSAValue *, 8> indices =
+      map(makePtrDynCaster<SSAValue>(), read->getIndices());
+  auto affineIndices =
+      reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
+  auto cloned = b->create<VectorTransferReadOp>(
+      read->getLoc(), hwVectorType, read->getMemRef(), affineIndices,
+      makePermutationMap(read->getMemRefType(), hwVectorType),
+      read->getPaddingValue());
+  return cast<OperationStmt>(cloned->getOperation());
+}
+
+/// Creates an instantiated version of `write` for the instance of
+/// `hwVectorInstance` when lowering from a super-vector type to
+/// `hwVectorType`. `hwVectorInstance` represents one particular instance of
+/// `hwVectorType` int the covering of th3e super-vector type. For a more
+/// detailed description of the problem, see the description of
+/// reindexAffineIndices.
+static OperationStmt *
+instantiate(MLFuncBuilder *b, VectorTransferWriteOp *write,
+            VectorType hwVectorType, ArrayRef<unsigned> hwVectorInstance,
+            DenseMap<const MLValue *, MLValue *> *substitutionsMap) {
+  SmallVector<SSAValue *, 8> indices =
+      map(makePtrDynCaster<SSAValue>(), write->getIndices());
+  auto affineIndices =
+      reindexAffineIndices(b, hwVectorType, hwVectorInstance, indices);
+  auto cloned = b->create<VectorTransferWriteOp>(
+      write->getLoc(), substitute(write->getVector(), *substitutionsMap),
+      write->getMemRef(), affineIndices,
+      makePermutationMap(write->getMemRefType(), hwVectorType));
+  return cast<OperationStmt>(cloned->getOperation());
+}
+
 /// Returns `true` if stmt instance is properly cloned and inserted, false
 /// otherwise.
 /// The multi-dimensional `hwVectorInstance` belongs to the shapeRatio of
@@ -386,45 +394,52 @@ materializeAttributes(OperationStmt *opStmt, VectorType superVectorType,
 ///      type, all operands are substituted according to `substitutions`. Thanks
 ///      to the topological order of a slice, the substitution is always
 ///      possible.
-static bool cloneAndInsertHardwareVectorInstance(Statement *stmt,
-                                                 MaterializationState *state) {
-  LLVM_DEBUG(dbgs() << "\nclone" << *stmt);
-  if (auto *opStmt = dyn_cast<OperationStmt>(stmt)) {
-    // TODO(ntv): Is it worth considering an OperationStmt.clone operation
-    // which changes the type so we can promote an OperationStmt with less
-    // boilerplate?
-    assert(opStmt->getNumResults() <= 1 && "NYI: opStmt has > 1 results");
-    auto operands = cloneAndUnrollOperands(opStmt, state->hwVectorType,
-                                           state->hwVectorInstance,
-                                           state->substitutionsMap);
-    MLFuncBuilder b(stmt);
-    if (opStmt->getNumResults() == 0) {
-      // vector_transfer_write
-      b.createOperation(stmt->getLoc(), opStmt->getName(), operands, {},
-                        materializeAttributes(opStmt, state->superVectorType,
-                                              state->hwVectorType));
-    } else {
-      // vector_transfer_read
-      auto *cloned = b.createOperation(
-          stmt->getLoc(), opStmt->getName(), operands, {state->hwVectorType},
-          materializeAttributes(opStmt, state->superVectorType,
-                                state->hwVectorType));
-      state->substitutionsMap->insert(std::make_pair(
-          cast<MLValue>(opStmt->getResult(0)),
-          cast<MLValue>(cast<OperationStmt>(cloned)->getResult(0))));
-    }
-    return false;
-  }
+static bool instantiateMaterialization(Statement *stmt,
+                                       MaterializationState *state) {
+  LLVM_DEBUG(dbgs() << "\ninstantiate: " << *stmt);
 
+  // Fail hard and wake up when needed.
   if (isa<ForStmt>(stmt)) {
-    // Fail hard and wake up when needed.
     stmt->emitError("NYI path ForStmt");
     return true;
   }
 
   // Fail hard and wake up when needed.
-  stmt->emitError("NYI path IfStmt");
-  return true;
+  if (isa<IfStmt>(stmt)) {
+    stmt->emitError("NYI path IfStmt");
+    return true;
+  }
+
+  // Create a builder here for unroll-and-jam effects.
+  MLFuncBuilder b(stmt);
+  auto *opStmt = cast<OperationStmt>(stmt);
+  if (auto write = opStmt->dyn_cast<VectorTransferWriteOp>()) {
+    instantiate(&b, &*write, state->hwVectorType, state->hwVectorInstance,
+                state->substitutionsMap);
+    return false;
+  } else if (auto read = opStmt->dyn_cast<VectorTransferReadOp>()) {
+    auto *clone = instantiate(&b, &*read, state->hwVectorType,
+                              state->hwVectorInstance, state->substitutionsMap);
+    state->substitutionsMap->insert(std::make_pair(
+        cast<MLValue>(read->getResult()), cast<MLValue>(clone->getResult(0))));
+    return false;
+  }
+  // The only op with 0 results reaching this point must, by construction, be
+  // VectorTransferWriteOps and have been caught above. Ops with >= 2 results
+  // are not yet supported. So just support 1 result.
+  if (opStmt->getNumResults() != 1) {
+    stmt->emitError("NYI: ops with != 1 results");
+    return true;
+  }
+  if (opStmt->getResult(0)->getType() != state->superVectorType) {
+    stmt->emitError("Op does not return a supervector.");
+    return true;
+  }
+  auto *clone = instantiate(&b, opStmt, state->superVectorType,
+                            state->hwVectorType, state->substitutionsMap);
+  state->substitutionsMap->insert(std::make_pair(
+      cast<MLValue>(opStmt->getResult(0)), cast<MLValue>(clone->getResult(0))));
+  return false;
 }
 
 /// Takes a slice and rewrites the operations in it so that occurrences
@@ -463,15 +478,22 @@ static void emitSlice(MaterializationState *state,
     scopedState.substitutionsMap = &substitutionMap;
     // slice are topologically sorted, we can just clone them in order.
     for (auto *stmt : *slice) {
-      auto fail = cloneAndInsertHardwareVectorInstance(stmt, &scopedState);
+      auto fail = instantiateMaterialization(stmt, &scopedState);
       (void)fail;
       assert(!fail && "Unhandled super-vector materialization failure");
     }
   }
+
+  LLVM_DEBUG(dbgs() << "\nMLFunction is now\n");
+  LLVM_DEBUG(
+      cast<OperationStmt>((*slice)[0])->getOperationFunction()->print(dbgs()));
+
   // slice are topologically sorted, we can just erase them in reverse
   // order. Reverse iterator does not just work simply with an operator*
   // dereference.
   for (int idx = slice->size() - 1; idx >= 0; --idx) {
+    LLVM_DEBUG(dbgs() << "\nErase: ");
+    LLVM_DEBUG((*slice)[idx]->print(dbgs()));
     (*slice)[idx]->erase();
   }
 }
@@ -497,25 +519,21 @@ static void materialize(MLFunction *f,
                         const SetVector<OperationStmt *> &terminators,
                         MaterializationState *state) {
   DenseSet<Statement *> seen;
-  for (auto terminator : terminators) {
-    LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *terminator);
-
+  for (auto *term : terminators) {
     // Short-circuit test, a given terminator may have been reached by some
     // other previous transitive use-def chains.
-    if (seen.count(terminator) > 0) {
+    if (seen.count(term) > 0) {
       continue;
     }
 
-    // Terminators are vector_transfer_write with 0 results by construction atm.
-    assert(isaVectorTransferWrite(*terminator) && "");
-    assert(terminator->getNumResults() == 0 &&
-           "NYI: terminators must have 0 results");
+    auto terminator = term->cast<VectorTransferWriteOp>();
+    LLVM_DEBUG(dbgs() << "\nFrom terminator:" << *term);
 
     // Get the transitive use-defs starting from terminator, limited to the
     // current enclosing scope of the terminator. See the top of the function
     // Note for the justification of this restriction.
     // TODO(ntv): relax scoping constraints.
-    auto *enclosingScope = terminator->getParentStmt();
+    auto *enclosingScope = term->getParentStmt();
     auto keepIfInSameScope = [enclosingScope](Statement *stmt) {
       assert(stmt && "NULL stmt");
       if (!enclosingScope) {
@@ -525,7 +543,7 @@ static void materialize(MLFunction *f,
       return properlyDominates(*enclosingScope, *stmt);
     };
     SetVector<Statement *> slice =
-        getSlice(terminator, keepIfInSameScope, keepIfInSameScope);
+        getSlice(term, keepIfInSameScope, keepIfInSameScope);
     assert(!slice.empty());
 
     // Sanity checks: transitive slice must be completely disjoint from
@@ -540,10 +558,9 @@ static void materialize(MLFunction *f,
 
     // Emit the current slice.
     // Set scoped super-vector and corresponding hw vector types.
-    state->superVectorType =
-        terminator->getOperand(0)->getType().cast<VectorType>();
+    state->superVectorType = terminator->getVectorType();
     assert((state->superVectorType.getElementType() ==
-            Type::getF32(terminator->getContext())) &&
+            Type::getF32(term->getContext())) &&
            "Only f32 supported for now");
     state->hwVectorType = VectorType::get(
         state->hwVectorSize, state->superVectorType.getElementType());
@@ -568,7 +585,7 @@ PassResult MaterializeVectors::runOnMLFunction(MLFunction *f) {
   // super-vector of subVectorType.
   auto filter = [subVectorType](const Statement &stmt) {
     const auto &opStmt = cast<OperationStmt>(stmt);
-    if (!isaVectorTransferWrite(opStmt)) {
+    if (!opStmt.isa<VectorTransferWriteOp>()) {
       return false;
     }
     return matcher::operatesOnStrictSuperVectors(opStmt, subVectorType);
index 5a408b0a2d7c947bf1453f37bd1ce137344ae473..e4822c27ac9a80eb396df15d06d2e7dacfb917f7 100644 (file)
@@ -541,6 +541,7 @@ using namespace mlir;
 #define DEBUG_TYPE "early-vect"
 
 using functional::apply;
+using functional::makePtrDynCaster;
 using functional::map;
 using functional::ScopeGuard;
 using llvm::dbgs;
@@ -820,23 +821,15 @@ void VectorizationState::registerReplacement(const SSAValue *key,
 /// TODO(andydavis,bondhugula,ntv):
 ///   1. generalize to support padding semantics and offsets within vector type.
 static OperationStmt *
-createVectorTransferRead(MLFuncBuilder *b, Location loc, VectorType vectorType,
+createVectorTransferRead(OperationStmt *loadOp, VectorType vectorType,
                          SSAValue *srcMemRef, ArrayRef<SSAValue *> srcIndices) {
-  SmallVector<SSAValue *, 8> operands;
-  operands.reserve(1 + srcIndices.size());
-  operands.insert(operands.end(), srcMemRef);
-  operands.insert(operands.end(), srcIndices.begin(), srcIndices.end());
-  OperationState opState(b->getContext(), loc, kVectorTransferReadOpName,
-                         operands, vectorType);
-  return b->createOperation(opState);
-}
-
-/// Unwraps a pointer type to another type (possibly the same).
-/// Used in particular to allow easier compositions of
-///   llvm::iterator_range<ForStmt::operand_iterator> types.
-template <typename T, typename ToType = T>
-static std::function<ToType *(T *)> unwrapPtr() {
-  return [](T *val) { return dyn_cast<ToType>(val); };
+  auto memRefType = srcMemRef->getType().cast<MemRefType>();
+  MLFuncBuilder b(loadOp);
+  // TODO(ntv): neutral for noneffective padding.
+  auto transfer = b.create<VectorTransferReadOp>(
+      loadOp->getLoc(), vectorType, srcMemRef, srcIndices,
+      makePermutationMap(memRefType, vectorType));
+  return cast<OperationStmt>(transfer->getOperation());
 }
 
 /// Handles the vectorization of load and store MLIR operations.
@@ -865,15 +858,14 @@ static bool vectorizeRootOrTerminal(MLValue *iv, LoadOrStoreOpPointer memoryOp,
 
   // Materialize a MemRef with 1 vector.
   auto *opStmt = cast<OperationStmt>(memoryOp->getOperation());
-  MLFuncBuilder b(opStmt);
   // For now, vector_transfers must be aligned, operate only on indices with an
   // identity subset of AffineMap and do not change layout.
   // TODO(ntv): increase the expressiveness power of vector_transfer operations
   // as needed by various targets.
   if (opStmt->template isa<LoadOp>()) {
     auto *transfer = createVectorTransferRead(
-        &b, opStmt->getLoc(), vectorType, memoryOp->getMemRef(),
-        map(unwrapPtr<SSAValue>(), memoryOp->getIndices()));
+        opStmt, vectorType, memoryOp->getMemRef(),
+        map(makePtrDynCaster<SSAValue>(), memoryOp->getIndices()));
     state->registerReplacement(opStmt, transfer);
   } else {
     state->registerTerminator(opStmt);
@@ -1008,7 +1000,7 @@ static MLValue *vectorizeConstant(Statement *stmt, const ConstantOp &constant,
   auto *splat = cast<OperationStmt>(b.createOperation(
       loc, constantOpStmt->getName(), {}, {vectorType},
       {make_pair(Identifier::get("value", b.getContext()), attr)}));
-  return cast<MLValue>(cast<OperationStmt>(splat)->getResult(0));
+  return cast<MLValue>(splat->getResult(0));
 }
 
 /// Returns a uniqu'ed VectorType.
@@ -1106,17 +1098,17 @@ static MLValue *vectorizeOperand(SSAValue *operand, Statement *stmt,
 static OperationStmt *createVectorTransferWrite(OperationStmt *storeOp,
                                                 VectorizationState *state) {
   auto store = storeOp->cast<StoreOp>();
+  auto *memRef = store->getMemRef();
+  auto memRefType = memRef->getType().cast<MemRefType>();
   auto *value = store->getValueToStore();
-  auto indices = map(unwrapPtr<SSAValue>(), store->getIndices());
-  SmallVector<SSAValue *, 8> operands;
-  operands.reserve(1 + 1 + indices.size());
-  operands.insert(operands.end(), vectorizeOperand(value, storeOp, state));
-  operands.insert(operands.end(), store->getMemRef());
-  operands.insert(operands.end(), indices.begin(), indices.end());
+  auto *vectorValue = vectorizeOperand(value, storeOp, state);
+  auto vectorType = vectorValue->getType().cast<VectorType>();
+  auto indices = map(makePtrDynCaster<SSAValue>(), store->getIndices());
   MLFuncBuilder b(storeOp);
-  OperationState opState(b.getContext(), storeOp->getLoc(),
-                         kVectorTransferWriteOpName, operands, {});
-  return b.createOperation(opState);
+  auto transfer = b.create<VectorTransferWriteOp>(
+      storeOp->getLoc(), vectorValue, memRef, indices,
+      makePermutationMap(memRefType, vectorType));
+  return cast<OperationStmt>(transfer->getOperation());
 }
 
 /// Encodes OperationStmt-specific behavior for vectorization. In general we
@@ -1134,9 +1126,9 @@ static OperationStmt *vectorizeOneOperationStmt(MLFuncBuilder *b,
   // Sanity checks.
   assert(!stmt->isa<LoadOp>() &&
          "all loads must have already been fully vectorized independently");
-  assert(!isaVectorTransferRead(*stmt) &&
+  assert(!stmt->isa<VectorTransferReadOp>() &&
          "vector_transfer_read cannot be further vectorized");
-  assert(!isaVectorTransferWrite(*stmt) &&
+  assert(!stmt->isa<VectorTransferWriteOp>() &&
          "vector_transfer_write cannot be further vectorized");
 
   if (stmt->isa<StoreOp>()) {
index f9e0e5f9404165cc6a2f64f5940dd6086d281429..9a056994c6bb0e2e4a4cd96427b2b7b9f14c0324 100644 (file)
@@ -9,6 +9,9 @@
 
 // CHECK: #map2 = (d0, d1)[s0, s1] -> (d0 + s1, d1 + s0)
 // CHECK: #map3 = ()[s0] -> (s0 + 1)
+// CHECK-DAG: #[[map_proj_d0d1_d0:map[0-9]+]] = (d0, d1) -> (d0)
+// CHECK-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1)
+// CHECK-DAG: #[[map_proj_d0d1_d1d0:map[0-9]+]] = (d0, d1) -> (d1, d0)
 
 // CHECK-LABEL: cfgfunc @cfgfunc_with_ops(f32) {
 cfgfunc @cfgfunc_with_ops(f32) {
@@ -259,3 +262,23 @@ mlfunc @test_dimop(%arg0 : tensor<4x4x?xf32>) {
   return
 }
 
+
+// CHECK-LABEL: mlfunc @test_vector_transfer_ops(%arg0
+mlfunc @test_vector_transfer_ops(%arg0 : memref<?x?xf32>) {
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // CHECK: %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: #[[map_proj_d0d1_d0]]} : (memref<?x?xf32>, index, index) -> vector<128xf32>
+  %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : (memref<?x?xf32>, index, index) -> vector<128xf32>
+  // CHECK: %1 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: #[[map_proj_d0d1_d1d0]]} : (memref<?x?xf32>, index, index) -> vector<3x7xf32>
+  %1 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d1, d0)} : (memref<?x?xf32>, index, index) -> vector<3x7xf32>
+  // CHECK: %2 = vector_transfer_read %arg0, %c3, %c3, %cst {permutation_map: #[[map_proj_d0d1_d0]]} : (memref<?x?xf32>, index, index, f32) -> vector<128xf32>
+  %2 = vector_transfer_read %arg0, %c3, %c3, %cst {permutation_map: (d0, d1)->(d0)} : (memref<?x?xf32>, index, index, f32) -> vector<128xf32>
+  // CHECK: %3 = vector_transfer_read %arg0, %c3, %c3, %cst {permutation_map: #[[map_proj_d0d1_d1]]} : (memref<?x?xf32>, index, index, f32) -> vector<128xf32>
+  %3 = vector_transfer_read %arg0, %c3, %c3, %cst {permutation_map: (d0, d1)->(d1)} : (memref<?x?xf32>, index, index, f32) -> vector<128xf32>
+  //
+  // CHECK: vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: #[[map_proj_d0d1_d0]]} : vector<128xf32>, memref<?x?xf32>, index, index
+  vector_transfer_write %0, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
+  // CHECK: vector_transfer_write %1, %arg0, %c3, %c3 {permutation_map: #[[map_proj_d0d1_d1d0]]} : vector<3x7xf32>, memref<?x?xf32>, index, index
+  vector_transfer_write %1, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d1, d0)} : vector<3x7xf32>, memref<?x?xf32>, index, index
+  return
+}
index 75c812b9cd430bd66ad0018822d5e296938ac41b..03474a8be5715171b1fc72b927a6132732d3418b 100644 (file)
@@ -286,3 +286,184 @@ bb0(%cond : tensor<?xi1>, %t : tensor<42xi32>, %f : tensor<42xi32>):
   // expected-error@+1 {{requires the condition to have the same shape as arguments}}
   %r = "select"(%cond, %t, %f) : (tensor<?xi1>, tensor<42xi32>, tensor<42xi32>) -> tensor<42xi32>
 }
+
+// -----
+
+cfgfunc @test_vector_transfer_read(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{expected 4 operand types but had 3}}
+  %0 = "vector_transfer_read"(%arg0, %c3, %c3, %c3) : (memref<?x?xf32>, index, index) -> vector<128xf32>
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_read(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{requires 3 operands}}
+  %0 = vector_transfer_read %arg0, %c3, %c3, %c3 : (memref<?x?xf32>, index, index) -> vector<128xf32>
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_read(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{requires an AffineMapAttr named 'permutation_map'}}
+  %0 = vector_transfer_read %arg0, %c3, %c3 : (memref<?x?xf32>, index, index) -> vector<128xf32>
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_read(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{requires an AffineMapAttr named 'permutation_map'}}
+  %0 = vector_transfer_read %arg0, %c3, %c3 {perm: (d0)->(d0)} : (memref<?x?xf32>, index, index) -> vector<128xf32>
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_read(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}}
+  %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0)->(d0)} : (memref<?x?xf32>, index, index) -> vector<128xf32>
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_read(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{requires a permutation_map with result dims of the same rank as the vector type}}
+  %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0, d1)} : (memref<?x?xf32>, index, index) -> vector<128xf32>
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_read(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{requires a permutation_map that is an actual permutation}}
+  %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0 + d1)} : (memref<?x?xf32>, index, index) -> vector<128xf32>
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_read(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{requires a permutation_map that is an actual permutation}}
+  %0 = vector_transfer_read %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0 + 1)} : (memref<?x?xf32>, index, index) -> vector<128xf32>
+}
+// -----
+
+cfgfunc @test_vector_transfer_read(memref<?x?x?xf32>) {
+bb0(%arg0 : memref<?x?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant 3.0 : f32
+  // expected-error@+1 {{requires a permutation_map that is a full column-rank permutation}}
+  %0 = vector_transfer_read %arg0, %c3, %c3, %c3 {permutation_map: (d0, d1, d2)->(d0, d0)} : (memref<?x?x?xf32>, index, index, index) -> vector<3x7xf32>
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_write(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant splat<vector<128 x f32>, 3.0>  : vector<128 x f32>
+  // expected-error@+1 {{expected 5 operand types but had 4}}
+  %0 = "vector_transfer_write"(%cst, %arg0, %c3, %c3, %c3) : (vector<128xf32>, memref<?x?xf32>, index, index) -> ()
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_write(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant splat<vector<128 x f32>, 3.0>  : vector<128 x f32>
+  // expected-error@+1 {{requires number of operands and input types to match}}
+  vector_transfer_write %cst, %arg0, %c3, %c3, %c3 : vector<128xf32>, memref<?x?xf32>, index, index
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_write(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant splat<vector<128 x f32>, 3.0>  : vector<128 x f32>
+  // expected-error@+1 {{requires an AffineMapAttr named 'permutation_map'}}
+  vector_transfer_write %cst, %arg0, %c3, %c3 : vector<128xf32>, memref<?x?xf32>, index, index
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_write(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant splat<vector<128 x f32>, 3.0>  : vector<128 x f32>
+  // expected-error@+1 {{requires an AffineMapAttr named 'permutation_map'}}
+  vector_transfer_write %cst, %arg0, %c3, %c3 {perm: (d0)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_write(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant splat<vector<128 x f32>, 3.0>  : vector<128 x f32>
+  // expected-error@+1 {{requires a permutation_map with input dims of the same rank as the memref type}}
+  vector_transfer_write %cst, %arg0, %c3, %c3 {permutation_map: (d0)->(d0)} : vector<128xf32>, memref<?x?xf32>, index, index
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_write(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant splat<vector<128 x f32>, 3.0>  : vector<128 x f32>
+  // expected-error@+1 {{requires a permutation_map with result dims of the same rank as the vector type}}
+  vector_transfer_write %cst, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0, d1)} : vector<128xf32>, memref<?x?xf32>, index, index
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_write(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant splat<vector<128 x f32>, 3.0>  : vector<128 x f32>
+  // expected-error@+1 {{requires a permutation_map that is an actual permutation}}
+  vector_transfer_write %cst, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0 + d1)} : vector<128xf32>, memref<?x?xf32>, index, index
+}
+
+// -----
+
+cfgfunc @test_vector_transfer_write(memref<?x?xf32>) {
+bb0(%arg0 : memref<?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant splat<vector<128 x f32>, 3.0>  : vector<128 x f32>
+  // expected-error@+1 {{requires a permutation_map that is an actual permutation}}
+  vector_transfer_write %cst, %arg0, %c3, %c3 {permutation_map: (d0, d1)->(d0 + 1)} : vector<128xf32>, memref<?x?xf32>, index, index
+}
+// -----
+
+cfgfunc @test_vector_transfer_write(memref<?x?x?xf32>) {
+bb0(%arg0 : memref<?x?x?xf32>):
+  %c3 = constant 3 : index
+  %cst = constant splat<vector<3 x 7 x f32>, 3.0>  : vector<3 x 7 x f32>
+  // expected-error@+1 {{requires a permutation_map that is a full column-rank permutation}}
+  vector_transfer_write %cst, %arg0, %c3, %c3, %c3 {permutation_map: (d0, d1, d2)->(d0, d0)} : vector<3x7xf32>, memref<?x?x?xf32>, index, index, index
+}
+
+
+
index cc38442cb1a2b2cdeb5fe9f335b9d8b58e59b1e8..93d17ea10d777bae922e3f6bd827da779e09a848 100644 (file)
@@ -2,21 +2,25 @@
 // RUN: mlir-opt %s -vectorize -virtual-vector-size 3 -virtual-vector-size 16 --test-fastest-varying=1 --test-fastest-varying=0 -materialize-vectors -vector-size=8 | FileCheck %s -check-prefix=VEC2DTO1D
 // RUN: mlir-opt %s -vectorize -virtual-vector-size 3 -virtual-vector-size 32 --test-fastest-varying=1 --test-fastest-varying=0 -materialize-vectors -vector-size=3 -vector-size=16 | FileCheck %s -check-prefix=VEC2DTO2D
 
+// Capture permutation maps used in vectorization.
+// VEC1DTO1D-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1)
+// VEC2DTO1D-DAG: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1)
+// VEC2DTO2D-DAG: #[[map_proj_d0d1_d0d1:map[0-9]+]] = (d0, d1) -> (d0, d1)
+
 // vector<32xf32> -> vector<8xf32>
-// VEC1DTO1D: [[MAP0:#.*]] = (d0, d1) -> (d0, d1)
-// VEC1DTO1D: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 8)
-// VEC1DTO1D: [[MAP2:#.*]] = (d0, d1) -> (d0, d1 + 16)
-// VEC1DTO1D: [[MAP3:#.*]] = (d0, d1) -> (d0, d1 + 24)
+// VEC1DTO1D-DAG: [[MAP0:#.*]] = (d0, d1) -> (d0, d1)
+// VEC1DTO1D-DAG: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 8)
+// VEC1DTO1D-DAG: [[MAP2:#.*]] = (d0, d1) -> (d0, d1 + 16)
+// VEC1DTO1D-DAG: [[MAP3:#.*]] = (d0, d1) -> (d0, d1 + 24)
 // vector<3x16xf32> -> vector<8xf32>
-// VEC2DTO1D: [[MAP0:#.*]] = (d0, d1) -> (d0, d1)
-// VEC2DTO1D: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 8)
-// VEC2DTO1D: [[MAP2:#.*]] = (d0, d1) -> (d0 + 1, d1)
-// VEC2DTO1D: [[MAP3:#.*]] = (d0, d1) -> (d0 + 1, d1 + 8)
-// VEC2DTO1D: [[MAP4:#.*]] = (d0, d1) -> (d0 + 2, d1)
-// VEC2DTO1D: [[MAP5:#.*]] = (d0, d1) -> (d0 + 2, d1 + 8)
+// VEC2DTO1D-DAG: [[MAP0:#.*]] = (d0, d1) -> (d0, d1)
+// VEC2DTO1D-DAG: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 8)
+// VEC2DTO1D-DAG: [[MAP2:#.*]] = (d0, d1) -> (d0 + 1, d1)
+// VEC2DTO1D-DAG: [[MAP3:#.*]] = (d0, d1) -> (d0 + 1, d1 + 8)
+// VEC2DTO1D-DAG: [[MAP4:#.*]] = (d0, d1) -> (d0 + 2, d1)
+// VEC2DTO1D-DAG: [[MAP5:#.*]] = (d0, d1) -> (d0 + 2, d1 + 8)
 // vector<3x32xf32> -> vector<3x16xf32>
-// VEC2DTO2D: [[MAP0:#.*]] = (d0, d1) -> (d0, d1)
-// VEC2DTO2D: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 16)
+// VEC2DTO2D-DAG: [[MAP1:#.*]] = (d0, d1) -> (d0, d1 + 16)
 mlfunc @vector_add_2d(%M : index, %N : index) -> f32 {
   %A = alloc (%M, %N) : memref<?x?xf32, 0>
   %B = alloc (%M, %N) : memref<?x?xf32, 0>
@@ -32,13 +36,13 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 {
       // VEC1DTO1D: [[CST2:%.*]] = constant splat<vector<8xf32>, 1.000000e+00> : vector<8xf32>
       // VEC1DTO1D: [[CST3:%.*]] = constant splat<vector<8xf32>, 1.000000e+00> : vector<8xf32>
       // VEC1DTO1D: [[VAL0:%.*]] = affine_apply [[MAP0]]{{.*}}
-      // VEC1DTO1D: "vector_transfer_write"([[CST0]], {{.*}}, [[VAL0]]#0, [[VAL0]]#1) : (vector<8xf32>
+      // VEC1DTO1D: vector_transfer_write [[CST0]], {{.*}}, [[VAL0]]#0, [[VAL0]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32>
       // VEC1DTO1D: [[VAL1:%.*]] = affine_apply [[MAP1]]{{.*}}
-      // VEC1DTO1D: "vector_transfer_write"([[CST1]], {{.*}}, [[VAL1]]#0, [[VAL1]]#1) : (vector<8xf32>
+      // VEC1DTO1D: vector_transfer_write [[CST1]], {{.*}}, [[VAL1]]#0, [[VAL1]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32>
       // VEC1DTO1D: [[VAL2:%.*]] = affine_apply [[MAP2]]{{.*}}
-      // VEC1DTO1D:"vector_transfer_write"([[CST2]], {{.*}}, [[VAL2]]#0, [[VAL2]]#1) : (vector<8xf32>
+      // VEC1DTO1D:vector_transfer_write [[CST2]], {{.*}}, [[VAL2]]#0, [[VAL2]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32>
       // VEC1DTO1D: [[VAL3:%.*]] = affine_apply [[MAP3]]{{.*}}
-      // VEC1DTO1D:"vector_transfer_write"([[CST3]], {{.*}}, [[VAL3]]#0, [[VAL3]]#1) : (vector<8xf32>
+      // VEC1DTO1D:vector_transfer_write [[CST3]], {{.*}}, [[VAL3]]#0, [[VAL3]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32>
       //
       store %f1, %A[%i0, %i1] : memref<?x?xf32, 0>
     }
@@ -49,10 +53,10 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 {
       // VEC2DTO1D does (3x4)x unrolling.
       // VEC2DTO1D-COUNT-6: {{.*}} = constant splat<vector<8xf32>, 1.000000e+00> : vector<8xf32>
       // VEC2DTO1D: [[VAL0:%.*]] = affine_apply [[MAP0]]{{.*}}
-      // VEC2DTO1D: "vector_transfer_write"({{.*}}, [[VAL0]]#0, [[VAL0]]#1) : (vector<8xf32>
+      // VEC2DTO1D: vector_transfer_write {{.*}}, [[VAL0]]#0, [[VAL0]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32>
       // ... 4 other interleaved affine_apply, vector_transfer_write
       // VEC2DTO1D: [[VAL5:%.*]] = affine_apply [[MAP5]]{{.*}}
-      // VEC2DTO1D: "vector_transfer_write"({{.*}}, [[VAL5]]#0, [[VAL5]]#1) : (vector<8xf32>
+      // VEC2DTO1D: vector_transfer_write {{.*}}, [[VAL5]]#0, [[VAL5]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : vector<8xf32>
       //
       store %f2, %B[%i2, %i3] : memref<?x?xf32, 0>
     }
@@ -60,19 +64,19 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 {
   for %i4 = 0 to %M {
     for %i5 = 0 to %N {
       // VEC2DTO2D: %7 = affine_apply #map0(%i4, %i5)
-      // VEC2DTO2D: %8 = "vector_transfer_read"(%0, %7#0, %7#1) : (memref<?x?xf32>, index, index) -> vector<3x16xf32>
+      // VEC2DTO2D: %8 = vector_transfer_read %0, %7#0, %7#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<3x16xf32>
       // VEC2DTO2D: %9 = affine_apply #map1(%i4, %i5)
-      // VEC2DTO2D: %10 = "vector_transfer_read"(%0, %9#0, %9#1) : (memref<?x?xf32>, index, index) -> vector<3x16xf32>
+      // VEC2DTO2D: %10 = vector_transfer_read %0, %9#0, %9#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<3x16xf32>
       // VEC2DTO2D: %11 = affine_apply #map0(%i4, %i5)
-      // VEC2DTO2D: %12 = "vector_transfer_read"(%1, %11#0, %11#1) : (memref<?x?xf32>, index, index) -> vector<3x16xf32>
+      // VEC2DTO2D: %12 = vector_transfer_read %1, %11#0, %11#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<3x16xf32>
       // VEC2DTO2D: %13 = affine_apply #map1(%i4, %i5)
-      // VEC2DTO2D: %14 = "vector_transfer_read"(%1, %13#0, %13#1) : (memref<?x?xf32>, index, index) -> vector<3x16xf32>
+      // VEC2DTO2D: %14 = vector_transfer_read %1, %13#0, %13#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<3x16xf32>
       // VEC2DTO2D: %15 = addf %8, %12 : vector<3x16xf32>
       // VEC2DTO2D: %16 = addf %10, %14 : vector<3x16xf32>
       // VEC2DTO2D: %17 = affine_apply #map0(%i4, %i5)
-      // VEC2DTO2D: "vector_transfer_write"(%15, %2, %17#0, %17#1) : (vector<3x16xf32>, memref<?x?xf32>, index, index) -> ()
+      // VEC2DTO2D: vector_transfer_write %15, %2, %17#0, %17#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<3x16xf32>, memref<?x?xf32>, index, index
       // VEC2DTO2D: %18 = affine_apply #map1(%i4, %i5)
-      // VEC2DTO2D: "vector_transfer_write"(%16, %2, %18#0, %18#1) : (vector<3x16xf32>, memref<?x?xf32>, index, index) -> ()
+      // VEC2DTO2D: vector_transfer_write %16, %2, %18#0, %18#1 {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<3x16xf32>, memref<?x?xf32>, index, index
       //
       %a5 = load %A[%i4, %i5] : memref<?x?xf32, 0>
       %b5 = load %B[%i4, %i5] : memref<?x?xf32, 0>
index 824e8167b06c7585954e3bcefdc079d41db8229c..3533ef7601371f85cbd099781a2eb2b316edd5bf 100644 (file)
@@ -5,6 +5,14 @@
 // RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 256 --test-fastest-varying=0 --test-fastest-varying=2 | FileCheck %s -check-prefix=VEC2D_OT
 // RUN: mlir-opt %s -vectorize -virtual-vector-size 32 -virtual-vector-size 64 -virtual-vector-size 256 --test-fastest-varying=2 --test-fastest-varying=1 --test-fastest-varying=0 | FileCheck %s -check-prefix=VEC3D
 
+// Permutation maps used in vectorization.
+// VEC1D: #[[map_proj_d0d1_d1:map[0-9]+]] = (d0, d1) -> (d1)
+// VEC2D: #[[map_proj_d0d1_d0d1:map[0-9]+]] = (d0, d1) -> (d0, d1)
+// VEC2D_T: #[[map_proj_d0d1d2_d1d2:map[0-9]+]] = (d0, d1, d2) -> (d1, d2)
+// VEC2D_O: #[[map_proj_d0d1d2_d1d2:map[0-9]+]] = (d0, d1, d2) -> (d1, d2)
+// VEC2D_OT: #[[map_proj_d0d1d2_d1d2:map[0-9]+]] = (d0, d1, d2) -> (d1, d2)
+// VEC3D: #[[map_proj_d0d1d2_d0d1d2:map[0-9]+]] = (d0, d1, d2) -> (d0, d1, d2)
+
 #map0 = (d0) -> (d0)
 #map1 = (d0, d1) -> (d0, d1)
 #map1_t = (d0, d1) -> (d1, d0)
@@ -15,6 +23,7 @@
 #mapadd2 = (d0) -> (d0 + 2)
 #mapadd3 = (d0) -> (d0 + 3)
 #set0 = (i) : (i >= 0)
+
 // Maps introduced to vectorize fastest varying memory index.
 mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // VEC1D-DAG: [[C0:%[a-z0-9_]+]] = constant 0 : index
@@ -26,26 +35,26 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    %P = dim %B, 2 : memref<?x?x?xf32>
    %cst0 = constant 0 : index
 // VEC1D:for [[IV0:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
-// VEC1D-NEXT: {{.*}} = "vector_transfer_read"(%arg0, [[C0]], [[C0]]) : (memref<?x?xf32>, index, index) -> vector<128xf32>
+// VEC1D-NEXT: {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_d1]]} : (memref<?x?xf32>, index, index) -> vector<128xf32>
 // For this simple loop, the current transformation generates:
 //   for %i0 = 0 to %0 step 128 {
-//     %3 = "vector_transfer_read"(%arg0, %c0_0, %c0_0) : (memref<?x?xf32>, index, index) -> vector<128xf32>
+//     %3 = vector_transfer_read %arg0, %c0_0, %c0_0 : (memref<?x?xf32>, index, index) -> vector<128xf32>
 //   }
-   for %i0 = 0 to %M { // vectorized due to scalar -> vector 
+   for %i0 = 0 to %M { // vectorized due to scalar -> vector
      %a0 = load %A[%cst0, %cst0] : memref<?x?xf32>
    }
 // VEC1D:for {{.*}} [[ARG_M]] {
-   for %i1 = 0 to %M { // not vectorized 
+   for %i1 = 0 to %M { // not vectorized
      %a1 = load %A[%i1, %i1] : memref<?x?xf32>
    }
 // VEC1D:   for %i{{[0-9]*}} = 0 to [[ARG_M]] {
-   for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1 
+   for %i2 = 0 to %M { // not vectorized, would vectorize with --test-fastest-varying=1
      %r2 = affine_apply (d0) -> (d0) (%i2)
      %a2 = load %A[%r2#0, %cst0] : memref<?x?xf32>
    }
 // VEC1D:for [[IV3:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
 // VEC1D-NEXT:   [[APP3:%[a-zA-Z0-9]+]] = affine_apply {{.*}}[[IV3]]
-// VEC1D-NEXT:   {{.*}} = "vector_transfer_read"(%arg0, [[C0]], [[APP3]]) : {{.*}} -> vector<128xf32>
+// VEC1D-NEXT:   {{.*}} = vector_transfer_read %arg0, [[C0]], [[APP3]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32>
    for %i3 = 0 to %M { // vectorized
      %r3 = affine_apply (d0) -> (d0) (%i3)
      %a3 = load %A[%cst0, %r3#0] : memref<?x?xf32>
@@ -53,8 +62,8 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // VEC1D:for [[IV4:%[i0-9]+]] = 0 to [[ARG_M]] step 128 {
 // VEC1D-NEXT:   for [[IV5:%[i0-9]*]] = 0 to [[ARG_N]] {
 // VEC1D-NEXT:   [[APP5:%[0-9]+]] = affine_apply {{.*}}([[IV4]], [[IV5]])
-// VEC1D-NEXT:   {{.*}} = "vector_transfer_read"(%arg0, [[APP5]]#0, [[APP5]]#1) : {{.*}} -> vector<128xf32>
-   for %i4 = 0 to %M { // vectorized 
+// VEC1D-NEXT:   {{.*}} = vector_transfer_read %arg0, [[APP5]]#0, [[APP5]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32>
+   for %i4 = 0 to %M { // vectorized
      for %i5 = 0 to %N { // not vectorized, would vectorize with --test-fastest-varying=1
        %r5 = affine_apply #map1_t (%i4, %i5)
        %a5 = load %A[%r5#0, %r5#1] : memref<?x?xf32>
@@ -71,7 +80,7 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
 // VEC1D:for [[IV8:%[i0-9]+]] = 0 to [[ARG_M]] step 128
 // VEC1D-NEXT:   for [[IV9:%[i0-9]*]] = 0 to [[ARG_N]] {
 // VEC1D-NEXT:   [[APP9:%[0-9]+]] = affine_apply {{.*}}([[IV8]], [[IV9]])
-// VEC1D-NEXT:   {{.*}} = "vector_transfer_read"(%arg0, [[APP9]]#0, [[APP9]]#1) : {{.*}} -> vector<128xf32>
+// VEC1D-NEXT:   {{.*}} = vector_transfer_read %arg0, [[APP9]]#0, [[APP9]]#1 {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32>
    for %i8 = 0 to %M { // vectorized
      for %i9 = 0 to %N {
        %r9 = affine_apply #map3 (%i8, %i9)
@@ -80,8 +89,8 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    }
 // VEC1D: for [[IV10:%[i0-9]*]] = 0 to %{{[0-9]*}} {
 // VEC1D:   for [[IV11:%[i0-9]*]] = 0 to %{{[0-9]*}} {
-   for %i10 = 0 to %M { // not vectorized, need per load transposes 
-     for %i11 = 0 to %N { // not vectorized, need per load transposes 
+   for %i10 = 0 to %M { // not vectorized, need per load transposes
+     for %i11 = 0 to %N { // not vectorized, need per load transposes
        %r11 = affine_apply #map1 (%i10, %i11)
        %a11 = load %A[%r11#0, %r11#1] : memref<?x?xf32>
        %r12 = affine_apply #map1_t (%i10, %i11)
@@ -112,7 +121,7 @@ mlfunc @vec1d(%A : memref<?x?xf32>, %B : memref<?x?x?xf32>) {
    }
 // VEC1D: for %i{{[0-9]*}} = 0 to %{{[0-9]*}} {
 // VEC1D:   for [[IV18:%[a-zA-Z0-9]+]] = 0 to [[ARG_M]] step 128
-// VEC1D:     {{.*}} = "vector_transfer_read"(%arg0, [[C0]], [[C0]]) : {{.*}} -> vector<128xf32>
+// VEC1D:     {{.*}} = vector_transfer_read %arg0, [[C0]], [[C0]] {permutation_map: #[[map_proj_d0d1_d1]]} : {{.*}} -> vector<128xf32>
    for %i17 = 0 to %M { // not vectorized, the 1-D pattern that matched %i18 in DFS post-order prevents vectorizing %i17
      for %i18 = 0 to %M { // vectorized due to scalar -> vector
        %a18 = load %A[%cst0, %cst0] : memref<?x?xf32>
@@ -211,22 +220,22 @@ mlfunc @vec2d_imperfectly_nested(%A : memref<?x?x?xf32>) {
    // VEC2D_T: for %i0 = 0 to %0 step 32 {
    // VEC2D_T:   for %i1 = 0 to %1 step 256 {
    // VEC2D_T:     for %i2 = 0 to %2 {
-   // VEC2D_T:       %3 = "vector_transfer_read"(%arg0, %i2, %i1, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
+   // VEC2D_T:       %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
    // VEC2D_T:   for %i3 = 0 to %1 {
    // VEC2D_T:     for %i4 = 0 to %2 step 256 {
-   // VEC2D_T:       %4 = "vector_transfer_read"(%arg0, %i3, %i4, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
+   // VEC2D_T:       %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
    // VEC2D_T:     for %i5 = 0 to %2 step 256 {
-   // VEC2D_T:       %5 = "vector_transfer_read"(%arg0, %i3, %i5, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
+   // VEC2D_T:       %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
    //
    // VEC2D_OT: for %i0 = 0 to %0 step 32 {
    // VEC2D_OT:   for %i1 = 0 to %1 {
    // VEC2D_OT:     for %i2 = 0 to %2 step 256 {
-   // VEC2D_OT:       %3 = "vector_transfer_read"(%arg0, %i2, %i1, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
+   // VEC2D_OT:       %3 = vector_transfer_read %arg0, %i2, %i1, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
    // VEC2D_OT:   for %i3 = 0 to %1 step 256 {
    // VEC2D_OT:     for %i4 = 0 to %2 {
-   // VEC2D_OT:       %4 = "vector_transfer_read"(%arg0, %i3, %i4, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
+   // VEC2D_OT:       %4 = vector_transfer_read %arg0, %i3, %i4, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
    // VEC2D_OT:     for %i5 = 0 to %2 {
-   // VEC2D_OT:       %5 = "vector_transfer_read"(%arg0, %i3, %i5, %i0) : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
+   // VEC2D_OT:       %5 = vector_transfer_read %arg0, %i3, %i5, %i0 {permutation_map: #[[map_proj_d0d1d2_d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x256xf32>
    for %i0 = 0 to %0 {
      for %i1 = 0 to %1 {
        for %i2 = 0 to %2 {
@@ -254,7 +263,7 @@ mlfunc @vec3d(%A : memref<?x?x?xf32>) {
    // VEC3D:     for %i2 = 0 to %0 step 32 {
    // VEC3D:       for %i3 = 0 to %1 step 64 {
    // VEC3D:         for %i4 = 0 to %2 step 256 {
-   // VEC3D:           %3 = "vector_transfer_read"(%arg0, %i2, %i3, %i4) : (memref<?x?x?xf32>, index, index, index) -> vector<32x64x256xf32>
+   // VEC3D:           %3 = vector_transfer_read %arg0, %i2, %i3, %i4 {permutation_map: #[[map_proj_d0d1d2_d0d1d2]]} : (memref<?x?x?xf32>, index, index, index) -> vector<32x64x256xf32>
    for %t0 = 0 to %0 {
      for %t1 = 0 to %0 {
        for %i0 = 0 to %0 {
@@ -278,9 +287,9 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 {
   for %i0 = 0 to %M {
     for %i1 = 0 to %N {
       // VEC1D: [[C1:%.*]] = constant splat<vector<128xf32>, 1.000000e+00> : vector<128xf32>
-      // VEC1D: "vector_transfer_write"([[C1]], {{.*}}) : (vector<128xf32>, memref<?x?xf32>, index, index) -> ()
+      // VEC1D: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref<?x?xf32>, index, index
       // VEC2D: [[C1:%.*]] = constant splat<vector<32x256xf32>, 1.000000e+00> : vector<32x256xf32>
-      // VEC2D: "vector_transfer_write"([[C1]], {{.*}}) : (vector<32x256xf32>, memref<?x?xf32>, index, index) -> ()
+      // VEC2D: vector_transfer_write [[C1]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref<?x?xf32>, index, index
       // non-scoped %f1
       store %f1, %A[%i0, %i1] : memref<?x?xf32, 0>
     }
@@ -288,9 +297,9 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 {
   for %i2 = 0 to %M {
     for %i3 = 0 to %N {
       // VEC1D: [[C3:%.*]] = constant splat<vector<128xf32>, 2.000000e+00> : vector<128xf32>
-      // VEC1D: "vector_transfer_write"([[C3]], {{.*}}) : (vector<128xf32>, memref<?x?xf32>, index, index) -> ()
+      // VEC1D: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref<?x?xf32>, index, index
       // VEC2D: [[C3:%.*]] = constant splat<vector<32x256xf32>, 2.000000e+00> : vector<32x256xf32>
-      // VEC2D: "vector_transfer_write"([[C3]], {{.*}}) : (vector<32x256xf32>, memref<?x?xf32>, index, index) -> ()
+      // VEC2D: vector_transfer_write [[C3]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]}  : vector<32x256xf32>, memref<?x?xf32>, index, index
       // non-scoped %f2
       store %f2, %B[%i2, %i3] : memref<?x?xf32, 0>
     }
@@ -298,25 +307,25 @@ mlfunc @vector_add_2d(%M : index, %N : index) -> f32 {
   for %i4 = 0 to %M {
     for %i5 = 0 to %N {
       //
-      // VEC1D: [[A5:%.*]] = "vector_transfer_read"(%0, {{.*}}) : (memref<?x?xf32>, index, index) -> vector<128xf32>
-      // VEC1D: [[B5:%.*]] = "vector_transfer_read"(%1, {{.*}}) : (memref<?x?xf32>, index, index) -> vector<128xf32>
+      // VEC1D: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : (memref<?x?xf32>, index, index) -> vector<128xf32>
+      // VEC1D: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : (memref<?x?xf32>, index, index) -> vector<128xf32>
       // VEC1D: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<128xf32>
       // VEC1D: [[SPLAT1:%.*]] = constant splat<vector<128xf32>, 1.000000e+00> : vector<128xf32>
       // VEC1D: [[S6:%.*]] = addf [[S5]], [[SPLAT1]] : vector<128xf32>
       // VEC1D: [[SPLAT2:%.*]] = constant splat<vector<128xf32>, 2.000000e+00> : vector<128xf32>
       // VEC1D: [[S7:%.*]] = addf [[S5]], [[SPLAT2]] : vector<128xf32>
       // VEC1D: [[S8:%.*]] = addf [[S7]], [[S6]] : vector<128xf32>
-      // VEC1D: "vector_transfer_write"([[S8]], {{.*}}) : (vector<128xf32>, memref<?x?xf32>, index, index) -> ()
+      // VEC1D: vector_transfer_write [[S8]], {{.*}} {permutation_map: #[[map_proj_d0d1_d1]]} : vector<128xf32>, memref<?x?xf32>, index, index
       //
-      // VEC2D: [[A5:%.*]] = "vector_transfer_read"(%0, {{.*}}) : (memref<?x?xf32>, index, index) -> vector<32x256xf32>
-      // VEC2D: [[B5:%.*]] = "vector_transfer_read"(%1, {{.*}}) : (memref<?x?xf32>, index, index) -> vector<32x256xf32>
+      // VEC2D: [[A5:%.*]] = vector_transfer_read %0, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<32x256xf32>
+      // VEC2D: [[B5:%.*]] = vector_transfer_read %1, {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : (memref<?x?xf32>, index, index) -> vector<32x256xf32>
       // VEC2D: [[S5:%.*]] = addf [[A5]], [[B5]] : vector<32x256xf32>
       // VEC2D: [[SPLAT1:%.*]] = constant splat<vector<32x256xf32>, 1.000000e+00> : vector<32x256xf32>
       // VEC2D: [[S6:%.*]] = addf [[S5]], [[SPLAT1]] : vector<32x256xf32>
       // VEC2D: [[SPLAT2:%.*]] = constant splat<vector<32x256xf32>, 2.000000e+00> : vector<32x256xf32>
       // VEC2D: [[S7:%.*]] = addf [[S5]], [[SPLAT2]] : vector<32x256xf32>
       // VEC2D: [[S8:%.*]] = addf [[S7]], [[S6]] : vector<32x256xf32>
-      // VEC2D: "vector_transfer_write"([[S8]], {{.*}}) : (vector<32x256xf32>, memref<?x?xf32>, index, index) -> ()
+      // VEC2D: vector_transfer_write [[S8]], {{.*}} {permutation_map: #[[map_proj_d0d1_d0d1]]} : vector<32x256xf32>, memref<?x?xf32>, index, index
       //
       %a5 = load %A[%i4, %i5] : memref<?x?xf32, 0>
       %b5 = load %B[%i4, %i5] : memref<?x?xf32, 0>