Add StdIndexedValue to EDSC helpers
authorDiego Caballero <diego.caballero@intel.com>
Fri, 2 Aug 2019 15:23:48 +0000 (08:23 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 2 Aug 2019 15:24:17 +0000 (08:24 -0700)
Add StdIndexedValue to EDSC helper so that we can use it
to generated std.load and std.store in EDSC.

Closes tensorflow/mlir#59

PiperOrigin-RevId: 261324965

mlir/include/mlir/EDSC/Helpers.h
mlir/test/EDSC/builder-api-test.cpp

index 4b0b277..69b7290 100644 (file)
@@ -33,9 +33,12 @@ namespace edsc {
 template <typename Load, typename Store> class TemplatedIndexedValue;
 
 // By default, edsc::IndexedValue provides an index notation around the affine
-// load and stores.
+// load and stores. edsc::StdIndexedValue provides the standard load/store
+// counterpart.
 using IndexedValue =
     TemplatedIndexedValue<intrinsics::affine_load, intrinsics::affine_store>;
+using StdIndexedValue =
+    TemplatedIndexedValue<intrinsics::std_load, intrinsics::std_store>;
 
 // Base class for MemRefView and VectorView.
 class View {
index 81792ac..ef87e3e 100644 (file)
@@ -603,6 +603,42 @@ memref<?x?x?xf32>, index, index, index) -> ()
   f.erase();
 }
 */
+
+// Exercise StdIndexedValue for loads and stores.
+TEST_FUNC(indirect_access) {
+  using namespace edsc;
+  using namespace edsc::intrinsics;
+  using namespace edsc::op;
+  auto memrefType =
+      MemRefType::get({-1}, FloatType::getF32(&globalContext()), {}, 0);
+  auto f = makeFunction("indirect_access", {},
+                        {memrefType, memrefType, memrefType, memrefType});
+
+  OpBuilder builder(f.getBody());
+  ScopedContext scope(builder, f.getLoc());
+  ValueHandle zero = constant_index(0);
+  MemRefView vC(f.getArgument(2));
+  IndexedValue B(f.getArgument(1)), D(f.getArgument(3));
+  StdIndexedValue A(f.getArgument(0)), C(f.getArgument(2));
+  IndexHandle i, N(vC.ub(0));
+
+  // clang-format off
+  LoopBuilder(&i, zero, N, 1)([&]{
+      C((ValueHandle)D(i)) = A((ValueHandle)B(i));
+  });
+  // clang-format on
+
+  // clang-format off
+  // CHECK-LABEL: func @indirect_access(
+  // CHECK:  [[B:%.*]] = affine.load
+  // CHECK:  [[D:%.*]] = affine.load
+  // CHECK:  load %{{.*}}{{\[}}[[B]]{{\]}}
+  // CHECK:  store %{{.*}}, %{{.*}}{{\[}}[[D]]{{\]}}
+  // clang-format on
+  f.print(llvm::outs());
+  f.erase();
+}
+
 int main() {
   RUN_TESTS();
   return 0;