Fix Linalg3 lowering to use the floating point element type matching the view
authorMehdi Amini <aminim@google.com>
Sun, 7 Apr 2019 20:12:48 +0000 (13:12 -0700)
committerMehdi Amini <joker.eph@gmail.com>
Mon, 8 Apr 2019 01:22:34 +0000 (18:22 -0700)
    It used to be hardcoded to f32, but Toy tutorial is using f64.

--

PiperOrigin-RevId: 242370172

mlir/examples/Linalg/Linalg3/lib/TensorOps.cpp

index c6f402f..a5b094c 100644 (file)
@@ -64,9 +64,14 @@ void linalg::DotOp::emitScalarImplementation(
   using edsc::intrinsics::select;
   ScopedContext scope( // account for affine.terminator in loop.
       FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
-  auto f32 = ScopedContext::getBuilder()->getF32Type();
+  FloatType fTy = getOperand(0)
+                      ->getType()
+                      .cast<ViewType>()
+                      .getElementType()
+                      .cast<FloatType>();
   IndexHandle zero(constant_index(0));
-  ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
+  ValueHandle zerof =
+      constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
   IndexHandle r_i(reductionIvs[0]);
   IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
   ValueHandle cond = (r_i == zero);
@@ -129,11 +134,16 @@ void linalg::MatvecOp::emitScalarImplementation(
   using edsc::intrinsics::select;
   ScopedContext scope( // account for affine.terminator in loop.
       FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
-  auto f32 = ScopedContext::getBuilder()->getF32Type();
+  FloatType fTy = getOperand(0)
+                      ->getType()
+                      .cast<ViewType>()
+                      .getElementType()
+                      .cast<FloatType>();
   IndexHandle i(parallelIvs[0]), r_j(reductionIvs[0]);
   IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
   IndexHandle zero(constant_index(0));
-  ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
+  ValueHandle zerof =
+      constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
   ValueHandle cond = (r_j == zero);
   ValueHandle scalarC = select(cond, zerof, *C(i));
   C(i) = scalarC + A(i, r_j) * B(r_j);
@@ -198,11 +208,16 @@ void linalg::MatmulOp::emitScalarImplementation(
   using edsc::intrinsics::select;
   ScopedContext scope( // account for affine.terminator in loop.
       FuncBuilder(body, std::prev(body->end(), 1)), innermostLoop.getLoc());
-  auto f32 = ScopedContext::getBuilder()->getF32Type();
+  FloatType fTy = getOperand(0)
+                      ->getType()
+                      .cast<ViewType>()
+                      .getElementType()
+                      .cast<FloatType>();
   IndexHandle i(parallelIvs[0]), j(parallelIvs[1]), r_k(reductionIvs[0]);
   IndexedValue A(getOperand(0)), B(getOperand(1)), C(getOperand(2));
   IndexHandle zero(constant_index(0));
-  ValueHandle zerof = constant_float(llvm::APFloat(0.0f), f32);
+  ValueHandle zerof =
+      constant_float(llvm::APFloat::getZero(fTy.getFloatSemantics()), fTy);
   ValueHandle cond = r_k == zero;
   ValueHandle scalarC = select(cond, zerof, *C(i, j));
   C(i, j) = scalarC + A(i, r_k) * B(r_k, j);