Add unary ops and ExpOp to Standard Dialect.
authorAlexander Belyaev <pifon@google.com>
Fri, 11 Oct 2019 12:13:18 +0000 (05:13 -0700)
committerA. Unique TensorFlower <gardener@tensorflow.org>
Fri, 11 Oct 2019 12:13:55 +0000 (05:13 -0700)
PiperOrigin-RevId: 274152154

mlir/g3doc/Dialects/Standard.md
mlir/include/mlir/Dialect/StandardOps/Ops.td
mlir/include/mlir/IR/OpDefinition.h
mlir/lib/Conversion/StandardToLLVM/ConvertStandardToLLVM.cpp
mlir/lib/Dialect/StandardOps/Ops.cpp
mlir/test/Conversion/StandardToLLVM/convert-to-llvmir.mlir
mlir/test/IR/core-ops.mlir

index 70a6d4c..9b1648e 100644 (file)
@@ -452,15 +452,38 @@ Example:
 tensor_store %8, %10 : memref<4x?xf32, #layout, memspace0>
 ```
 
+## Unary Operations
+
+### 'exp' operation
+
+Syntax:
+
+``` {.ebnf}
+operation ::= ssa-id `=` `exp` ssa-use `:` type
+```
+
+Examples:
+
+```mlir {.mlir}
+// Scalar natural exponential.
+%a = exp %b : f64
+
+// SIMD vector element-wise natural exponential.
+%f = exp %g : vector<4xf32>
+
+// Tensor element-wise natural exponential.
+%x = exp %y : tensor<4x?xf8>
+```
+
+The `exp` operation takes one operand and returns one result of the same type.
+This type may be a float scalar type, a vector whose element type is float, or a
+tensor of floats. It has no standard attributes.
+
 ## Arithmetic Operations
 
 Basic arithmetic in MLIR is specified by standard operations described in this
 section.
 
-TODO: "sub" etc. Let's not get excited about filling this out yet, we can define
-these on demand. We should be highly informed by and learn from the operations
-supported by HLO and LLVM.
-
 ### 'addi' operation
 
 Syntax:
@@ -478,7 +501,7 @@ Examples:
 // SIMD vector element-wise addition, e.g. for Intel SSE.
 %f = addi %g, %h : vector<4xi32>
 
-// Tensor element-wise addition, analogous to HLO's add operation.
+// Tensor element-wise addition.
 %x = addi %y, %z : tensor<4x?xi8>
 ```
 
@@ -504,7 +527,7 @@ Examples:
 // SIMD vector addition, e.g. for Intel SSE.
 %f = addf %g, %h : vector<4xf32>
 
-// Tensor addition, analogous to HLO's add operation.
+// Tensor addition.
 %x = addf %y, %z : tensor<4x?xbf16>
 ```
 
@@ -757,7 +780,7 @@ Examples:
 // SIMD pointwise vector multiplication, e.g. for Intel SSE.
 %f = mulf %g, %h : vector<4xf32>
 
-// Tensor pointwise multiplication, analogous to HLO's pointwise multiply operation.
+// Tensor pointwise multiplication.
 %x = mulf %y, %z : tensor<4x?xbf16>
 ```
 
index 0e1f1a9..dd02f75 100644 (file)
@@ -72,6 +72,27 @@ class CastOp<string mnemonic, list<OpTrait> traits = []> :
   let hasFolder = 1;
 }
 
+// Base class for unary ops. Requires single operand and result. Individual
+// classes will have `operand` accessor.
+class UnaryOp<string mnemonic, list<OpTrait> traits = []> :
+    Op<Std_Dialect, mnemonic, !listconcat(traits, [NoSideEffect])> {
+  let results = (outs AnyType);
+  let printer = [{
+    return printStandardUnaryOp(this->getOperation(), p);
+  }];
+}
+
+class UnaryOpSameOperandAndResultType<string mnemonic, list<OpTrait> traits = []> :
+    UnaryOp<mnemonic, !listconcat(traits, [SameOperandsAndResultType])> {
+  let parser = [{
+    return impl::parseOneResultSameOperandTypeOp(parser, result);
+  }];
+}
+
+class FloatUnaryOp<string mnemonic, list<OpTrait> traits = []> :
+    UnaryOpSameOperandAndResultType<mnemonic, traits>,
+    Arguments<(ins FloatLike:$operand)>;
+
 // Base class for standard arithmetic operations.  Requires operands and
 // results to be of the same type, but does not constrain them to specific
 // types.  Individual classes will have `lhs` and `rhs` accessor to operands.
@@ -597,6 +618,10 @@ def DivIUOp : IntArithmeticOp<"diviu"> {
   let hasFolder = 1;
 }
 
+def ExpOp : FloatUnaryOp<"exp"> {
+  let summary = "base-e exponential of the specified value";
+}
+
 def ExtractElementOp : Std_Op<"extract_element", [NoSideEffect]> {
   let summary = "element extract operation";
   let description = [{
index c500e73..65033b6 100644 (file)
@@ -1141,13 +1141,17 @@ private:
   Concept *impl;
 };
 
-// These functions are out-of-line implementations of the methods in BinaryOp,
-// which avoids them being template instantiated/duplicated.
+// These functions are out-of-line implementations of the methods in UnaryOp and
+// BinaryOp, which avoids them being template instantiated/duplicated.
 namespace impl {
+ParseResult parseOneResultOneOperandTypeOp(OpAsmParser &parser,
+                                           OperationState &result);
+
 void buildBinaryOp(Builder *builder, OperationState &result, Value *lhs,
                    Value *rhs);
 ParseResult parseOneResultSameOperandTypeOp(OpAsmParser &parser,
                                             OperationState &result);
+
 // Prints the given binary `op` in custom assembly form if both the two operands
 // and the result have the same time. Otherwise, prints the generic assembly
 // form.
index 76de499..206dde7 100644 (file)
@@ -443,28 +443,43 @@ static SmallVector<int64_t, 4> getCoordinates(ArrayRef<int64_t> basis,
   return res;
 }
 
+template <typename SourceOp, unsigned OpCount> struct OpCountValidator {
+  static_assert(
+      std::is_base_of<
+          typename OpTrait::NOperands<OpCount>::template Impl<SourceOp>,
+          SourceOp>::value,
+      "wrong operand count");
+};
+
+template <typename SourceOp> struct OpCountValidator<SourceOp, 1> {
+  static_assert(std::is_base_of<OpTrait::OneOperand<SourceOp>, SourceOp>::value,
+                "expected a single operand");
+};
+
+template <typename SourceOp, unsigned OpCount> void ValidateOpCount() {
+  OpCountValidator<SourceOp, OpCount>();
+}
+
 // Basic lowering implementation for rewriting from Standard Ops to LLVM Dialect
-// Ops for binary ops with one result. This supports higher-dimensional vector
+// Ops for N-ary ops with one result. This supports higher-dimensional vector
 // types.
-template <typename SourceOp, typename TargetOp>
-struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
+template <typename SourceOp, typename TargetOp, unsigned OpCount>
+struct NaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
   using LLVMLegalizationPattern<SourceOp>::LLVMLegalizationPattern;
-  using Super = BinaryOpLLVMOpLowering<SourceOp, TargetOp>;
+  using Super = NaryOpLLVMOpLowering<SourceOp, TargetOp, OpCount>;
 
   // Convert the type of the result to an LLVM type, pass operands as is,
   // preserve attributes.
   PatternMatchResult
   matchAndRewrite(Operation *op, ArrayRef<Value *> operands,
                   ConversionPatternRewriter &rewriter) const override {
-    static_assert(
-        std::is_base_of<OpTrait::NOperands<2>::Impl<SourceOp>, SourceOp>::value,
-        "expected binary op");
+    ValidateOpCount<SourceOp, OpCount>();
     static_assert(
         std::is_base_of<OpTrait::OneResult<SourceOp>, SourceOp>::value,
         "expected single result op");
     static_assert(std::is_base_of<OpTrait::SameOperandsAndResultType<SourceOp>,
                                   SourceOp>::value,
-                  "expected single result op");
+                  "expected same operands and result type");
 
     // Cannot convert ops if their operands are not of LLVM type.
     for (Value *operand : operands) {
@@ -489,7 +504,7 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
       arraySizes.push_back(llvmTy.getArrayNumElements());
       llvmTy = llvmTy.getArrayElementType();
     }
-    assert(llvmTy.isVectorTy() && "unexpected binary op over non-vector type");
+    assert(llvmTy.isVectorTy() && "unexpected n-ary op over non-vector type");
     auto llvmVectorTy = llvmTy;
 
     // Iteratively extract a position coordinates with basis `arraySize` from a
@@ -511,13 +526,13 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
 
       // For this unrolled `position` corresponding to the `linearIndex`^th
       // element, extract operand vectors
-      Value *extractedLHS = rewriter.create<LLVM::ExtractValueOp>(
-          loc, llvmVectorTy, operands[0], position);
-      Value *extractedRHS = rewriter.create<LLVM::ExtractValueOp>(
-          loc, llvmVectorTy, operands[1], position);
+      SmallVector<Value *, OpCount> extractedOperands;
+      for (unsigned i = 0; i < OpCount; ++i) {
+        extractedOperands.push_back(rewriter.create<LLVM::ExtractValueOp>(
+            loc, llvmVectorTy, operands[i], position));
+      }
       Value *newVal = rewriter.create<TargetOp>(
-          loc, llvmVectorTy, ArrayRef<Value *>{extractedLHS, extractedRHS},
-          op->getAttrs());
+          loc, llvmVectorTy, extractedOperands, op->getAttrs());
       desc = rewriter.create<LLVM::InsertValueOp>(loc, llvmArrayTy, desc,
                                                   newVal, position);
     }
@@ -526,8 +541,16 @@ struct BinaryOpLLVMOpLowering : public LLVMLegalizationPattern<SourceOp> {
   }
 };
 
+template <typename SourceOp, typename TargetOp>
+using UnaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 1>;
+template <typename SourceOp, typename TargetOp>
+using BinaryOpLLVMOpLowering = NaryOpLLVMOpLowering<SourceOp, TargetOp, 2>;
+
 // Specific lowerings.
 // FIXME: this should be tablegen'ed.
+struct ExpOpLowering : public UnaryOpLLVMOpLowering<ExpOp, LLVM::exp> {
+  using Super::Super;
+};
 struct AddIOpLowering : public BinaryOpLLVMOpLowering<AddIOp, LLVM::AddOp> {
   using Super::Super;
 };
@@ -1301,18 +1324,49 @@ void mlir::LLVM::ensureDistinctSuccessors(ModuleOp m) {
 void mlir::populateStdToLLVMConversionPatterns(
     LLVMTypeConverter &converter, OwningRewritePatternList &patterns) {
   // FIXME: this should be tablegen'ed
+  // clang-format off
   patterns.insert<
-      AddFOpLowering, AddIOpLowering, AndOpLowering, AllocOpLowering,
-      BranchOpLowering, CallIndirectOpLowering, CallOpLowering, CmpIOpLowering,
-      CmpFOpLowering, CondBranchOpLowering, ConstLLVMOpLowering,
-      DeallocOpLowering, DimOpLowering, DivISOpLowering, DivIUOpLowering,
-      DivFOpLowering, FuncOpConversion, IndexCastOpLowering, LoadOpLowering,
-      MemRefCastOpLowering, MulFOpLowering, MulIOpLowering, OrOpLowering,
-      RemISOpLowering, RemIUOpLowering, RemFOpLowering, ReturnOpLowering,
-      SelectOpLowering, SIToFPLowering, FPExtLowering, FPTruncLowering,
-      SignExtendIOpLowering, SplatOpLowering, StoreOpLowering, SubFOpLowering,
-      SubIOpLowering, TruncateIOpLowering, XOrOpLowering,
+      AddFOpLowering,
+      AddIOpLowering,
+      AllocOpLowering,
+      AndOpLowering,
+      BranchOpLowering,
+      CallIndirectOpLowering,
+      CallOpLowering,
+      CmpFOpLowering,
+      CmpIOpLowering,
+      CondBranchOpLowering,
+      ConstLLVMOpLowering,
+      DeallocOpLowering,
+      DimOpLowering,
+      DivFOpLowering,
+      DivISOpLowering,
+      DivIUOpLowering,
+      ExpOpLowering,
+      FPExtLowering,
+      FPTruncLowering,
+      FuncOpConversion,
+      IndexCastOpLowering,
+      LoadOpLowering,
+      MemRefCastOpLowering,
+      MulFOpLowering,
+      MulIOpLowering,
+      OrOpLowering,
+      RemFOpLowering,
+      RemISOpLowering,
+      RemIUOpLowering,
+      ReturnOpLowering,
+      SIToFPLowering,
+      SelectOpLowering,
+      SignExtendIOpLowering,
+      SplatOpLowering,
+      StoreOpLowering,
+      SubFOpLowering,
+      SubIOpLowering,
+      TruncateIOpLowering,
+      XOrOpLowering,
       ZeroExtendIOpLowering>(*converter.getDialect(), converter);
+  // clang-format on
 }
 
 // Convert types using the stored LLVM IR module.
index 5cbdb67..443aa64 100644 (file)
@@ -124,6 +124,19 @@ struct StdInlinerInterface : public DialectInlinerInterface {
 // StandardOpsDialect
 //===----------------------------------------------------------------------===//
 
+/// A custom unary operation printer that omits the "std." prefix from the
+/// operation names.
+static void printStandardUnaryOp(Operation *op, OpAsmPrinter &p) {
+  assert(op->getNumOperands() == 1 && "unary op should have one operand");
+  assert(op->getNumResults() == 1 && "unary op should have one result");
+
+  const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
+    << *op->getOperand(0);
+  p.printOptionalAttrDict(op->getAttrs());
+  p << " : " << op->getOperand(0)->getType();
+}
+
 /// A custom binary operation printer that omits the "std." prefix from the
 /// operation names.
 static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
@@ -139,7 +152,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
     return;
   }
 
-  p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
+  const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
     << *op->getOperand(0) << ", " << *op->getOperand(1);
   p.printOptionalAttrDict(op->getAttrs());
 
@@ -150,7 +164,8 @@ static void printStandardBinaryOp(Operation *op, OpAsmPrinter &p) {
 /// A custom cast operation printer that omits the "std." prefix from the
 /// operation names.
 static void printStandardCastOp(Operation *op, OpAsmPrinter &p) {
-  p << op->getName().getStringRef().drop_front(strlen("std.")) << ' '
+  const int stdDotLen = StandardOpsDialect::getDialectNamespace().size() + 1;
+  p << op->getName().getStringRef().drop_front(stdDotLen) << ' '
     << *op->getOperand(0) << " : " << op->getOperand(0)->getType() << " to "
     << op->getResult(0)->getType();
 }
index 5fd51bd..fb23a76 100644 (file)
@@ -421,6 +421,8 @@ func @ops(f32, f32, i32, i32) -> (f32, i32) {
   %12 = or %arg2, %arg3 : i32
 // CHECK-NEXT: %12 = llvm.xor %arg2, %arg3 : !llvm.i32
   %13 = xor %arg2, %arg3 : i32
+// CHECK-NEXT: %13 = "llvm.intr.exp"(%arg0) : (!llvm.float) -> !llvm.float
+  %14 = std.exp %arg0 : f32
 
   return %0, %4 : f32, i32
 }
index abb731d..417068a 100644 (file)
@@ -351,6 +351,14 @@ func @standard_instrs(tensor<4x4x?xf32>, f32, i32, index, i64, f16) {
   // CHECK: = fptrunc {{.*}} : f32 to f16
   %95 = fptrunc %f : f32 to f16
 
+  // CHECK: %{{[0-9]+}} = exp %arg1 : f32
+  %96 = "std.exp"(%f) : (f32) -> f32
+
+  // CHECK: %{{[0-9]+}} = exp %arg1 : f32
+  %97 = exp %f : f32
+
+  // CHECK: %{{[0-9]+}} = exp %arg0 : tensor<4x4x?xf32>
+  %98 = exp %t : tensor<4x4x?xf32>
   return
 }