} // namespace
namespace {
+/// Import the Shape Ops to Std Patterns.
+#include "ShapeToStandard.cpp.inc"
+} // namespace
+
+namespace {
/// Conversion pass.
class ConvertShapeToStandardPass
: public ConvertShapeToStandardBase<ConvertShapeToStandardPass> {
MLIRContext &ctx = getContext();
ConversionTarget target(ctx);
target.addLegalDialect<StandardOpsDialect, SCFDialect>();
- target.addLegalOp<FuncOp, ModuleOp, ModuleTerminatorOp>();
+ target.addLegalOp<CstrRequireOp, FuncOp, ModuleOp, ModuleTerminatorOp>();
// Setup conversion patterns.
OwningRewritePatternList patterns;
void mlir::populateShapeToStandardConversionPatterns(
OwningRewritePatternList &patterns, MLIRContext *ctx) {
// clang-format off
+ populateWithGenerated(ctx, patterns);
patterns.insert<
AnyOpConversion,
BinaryOpConversion<AddOp, AddIOp>,
--- /dev/null
+//==-- ShapeToStandard.td - Shape to Standard Patterns -------*- tablegen -*==//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+//
+// Defines Patterns to lower Shape ops to Std.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_CONVERSION_SHAPETOSTANDARD_TD
+#define MLIR_CONVERSION_SHAPETOSTANDARD_TD
+
+include "mlir/Dialect/Shape/IR/ShapeOps.td"
+
+def BroadcastableStringAttr : NativeCodeCall<[{
+ $_builder.getStringAttr("required broadcastable shapes")
+}]>;
+
+def : Pat<(Shape_CstrBroadcastableOp $LHS, $RHS),
+ (Shape_CstrRequireOp
+ (Shape_IsBroadcastableOp $LHS, $RHS),
+ (BroadcastableStringAttr))>;
+
+#endif // MLIR_CONVERSION_SHAPETOSTANDARD_TD
// CHECK: }
// CHECK: return %[[ALL_RESULT]] : i1
// CHECK: }
+
+// -----
+
+func @broadcast(%a : tensor<?xindex>, %b : tensor<?xindex>) -> !shape.witness {
+ %0 = shape.cstr_broadcastable %a, %b : tensor<?xindex>, tensor<?xindex>
+ return %0 : !shape.witness
+}
+
+// CHECK-LABEL: func @broadcast(
+// CHECK-SAME: %[[LHS:.*]]: tensor<?xindex>,
+// CHECK-SAME: %[[RHS:.*]]: tensor<?xindex>) -> !shape.witness {
+// CHECK: %[[C0:.*]] = constant 0 : index
+// CHECK: %[[C1:.*]] = constant 1 : index
+// CHECK: %[[LHS_RANK:.*]] = dim %[[LHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[RHS_RANK:.*]] = dim %[[RHS]], %[[C0]] : tensor<?xindex>
+// CHECK: %[[LHS_SMALLER:.*]] = cmpi "ule", %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[SMALLER_RANK:.*]] = select %[[LHS_SMALLER]], %[[LHS_RANK]], %[[RHS_RANK]] : index
+// CHECK: %[[LARGER_RANK:.*]] = select %[[LHS_SMALLER]], %[[RHS_RANK]], %[[LHS_RANK]] : index
+// CHECK: %[[RANK_ERASED_LHS:.*]] = tensor_cast %[[LHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[RANK_ERASED_RHS:.*]] = tensor_cast %[[RHS]] : tensor<?xindex> to tensor<?xindex>
+// CHECK: %[[SMALLER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_LHS]], %[[RANK_ERASED_RHS]] : tensor<?xindex>
+// CHECK: %[[LARGER_SHAPE:.*]] = select %[[LHS_SMALLER]], %[[RANK_ERASED_RHS]], %[[RANK_ERASED_LHS]] : tensor<?xindex>
+// CHECK: %[[RANK_DIFF:.*]] = subi %[[LARGER_RANK]], %[[SMALLER_RANK]] : index
+// CHECK: %[[TRUE:.*]] = constant true
+// CHECK: %[[ALL_RESULT:.*]] = scf.for %[[VAL_16:.*]] = %[[RANK_DIFF]] to %[[LARGER_RANK]] step %[[C1]] iter_args(%[[ALL_SO_FAR:.*]] = %[[TRUE]]) -> (i1) {
+// CHECK: %[[LARGER_EXTENT:.*]] = extract_element %[[LARGER_SHAPE]]{{\[}}%[[VAL_16]]] : tensor<?xindex>
+// CHECK: %[[LARGER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[C1]] : index
+// CHECK: %[[LHS_EXTENT_INDEX:.*]] = subi %[[VAL_16]], %[[RANK_DIFF]] : index
+// CHECK: %[[SMALLER_EXTENT:.*]] = extract_element %[[SMALLER_SHAPE]]{{\[}}%[[LHS_EXTENT_INDEX]]] : tensor<?xindex>
+// CHECK: %[[SMALLER_EXTENT_IS_ONE:.*]] = cmpi "eq", %[[SMALLER_EXTENT]], %[[C1]] : index
+// CHECK: %[[EXTENTS_ARE_EQUAL:.*]] = cmpi "eq", %[[LARGER_EXTENT]], %[[SMALLER_EXTENT]] : index
+// CHECK: %[[EITHER_EXTENT_IS_ONE:.*]] = or %[[LARGER_EXTENT_IS_ONE]], %[[SMALLER_EXTENT_IS_ONE]] : i1
+// CHECK: %[[OR_EXTENTS_ARE_EQUAL:.*]] = or %[[EITHER_EXTENT_IS_ONE]], %[[EXTENTS_ARE_EQUAL]] : i1
+// CHECK: %[[NEW_ALL_SO_FAR:.*]] = and %[[ALL_SO_FAR]], %[[OR_EXTENTS_ARE_EQUAL]] : i1
+// CHECK: scf.yield %[[NEW_ALL_SO_FAR]] : i1
+// CHECK: }
+// CHECK: %[[RESULT:.*]] = shape.cstr_require %[[ALL_RESULT]], "required broadcastable shapes"
+// CHECK: return %[[RESULT]] : !shape.witness
+// CHECK: }