From 4f6c29223ee5395dd955cefafce6f03ed99170e0 Mon Sep 17 00:00:00 2001 From: Lei Zhang Date: Fri, 30 Aug 2019 12:17:21 -0700 Subject: [PATCH] Add spv.Branch and spv.BranchConditional This CL just covers the op definition, its parsing, printing, and verification. (De)serialization is to be implemented in a subsequent CL. PiperOrigin-RevId: 266431077 --- mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td | 5 +- .../mlir/Dialect/SPIRV/SPIRVControlFlowOps.td | 106 +++++++++++++++ mlir/lib/Dialect/SPIRV/SPIRVOps.cpp | 115 ++++++++++++++++ mlir/test/Dialect/SPIRV/control-flow-ops.mlir | 144 +++++++++++++++++++++ mlir/utils/spirv/gen_spirv_dialect.py | 4 +- 5 files changed, 371 insertions(+), 3 deletions(-) diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td index 538891e..df15b92 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVBase.td @@ -132,6 +132,8 @@ def SPV_OC_OpSLessThan : I32EnumAttrCase<"OpSLessThan", 177>; def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>; def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>; def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>; +def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>; +def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>; def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>; def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>; @@ -154,7 +156,8 @@ def SPV_OpcodeAttr : SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan, SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan, SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual, - SPV_OC_OpLabel, SPV_OC_OpReturn, SPV_OC_OpReturnValue + SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn, + SPV_OC_OpReturnValue ]> { let returnType = "::mlir::spirv::Opcode"; let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())"; diff --git a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td index b0cde8b..0927684 100644 --- a/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td +++ b/mlir/include/mlir/Dialect/SPIRV/SPIRVControlFlowOps.td @@ -31,6 +31,112 @@ include "mlir/SPIRV/SPIRVBase.td" // ----- +def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> { + let summary = "Unconditional branch to target block."; + + let description = [{ + This instruction must be the last instruction in a block. + + ### Custom assembly form + + ``` {.ebnf} + branch-op ::= `spv.Branch` successor + ``` + + For example: + + ``` + spv.Branch ^target + ``` + }]; + + let arguments = (ins); + + let results = (outs); + + let builders = [ + OpBuilder< + "Builder *, OperationState *state, Block *successor", [{ + state->addSuccessor(successor, {}); + }] + > + ]; + + let skipDefaultBuilders = 1; + + let autogenSerialization = 0; +} + +// ----- + +def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> { + let summary = [{ + If Condition is true, branch to true block, otherwise branch to false + block. + }]; + + let description = [{ + Condition must be a Boolean type scalar. + + Branch weights are unsigned 32-bit integer literals. There must be + either no Branch Weights or exactly two branch weights. If present, the + first is the weight for branching to True Label, and the second is the + weight for branching to False Label. The implied probability that a + branch is taken is its weight divided by the sum of the two Branch + weights. At least one weight must be non-zero. A weight of zero does not + imply a branch is dead or permit its removal; branch weights are only + hints. The two weights must not overflow a 32-bit unsigned integer when + added together. + + This instruction must be the last instruction in a block. + + ### Custom assembly form + + ``` {.ebnf} + branch-conditional-op ::= `spv.BranchConditional` ssa-use + (`[` integer-literal, integer-literal `]`)? + `,` successor `,` successor + ``` + + For example: + + ``` + spv.BranchConditional %condition, ^true_branch, ^false_branch + ``` + }]; + + let arguments = (ins + SPV_Bool:$condition, + OptionalAttr:$branch_weights + ); + + let results = (outs); + + let builders = [ + OpBuilder< + "Builder *, OperationState *state, Value *condition, " + "Block *trueBranch, Block *falseBranch, /*optional*/ArrayAttr weights", + [{ + state->addOperands(condition); + state->addSuccessor(trueBranch, {}); + state->addSuccessor(falseBranch, {}); + state->addAttribute("branch_weights", weights); + }] + > + ]; + + let skipDefaultBuilders = 1; + + let autogenSerialization = 0; + + let extraClassDeclaration = [{ + // Branch indices into the successor list. + enum { kTrueIndex = 0, kFalseIndex = 1 }; + }]; +} + +// ----- + def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> { let summary = "Return with no value from a function with void return type."; diff --git a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp index aaa7ed5..2b1248b 100644 --- a/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp +++ b/mlir/lib/Dialect/SPIRV/SPIRVOps.cpp @@ -33,6 +33,7 @@ using namespace mlir; // TODO(antiagainst): generate these strings using ODS. static constexpr const char kAlignmentAttrName[] = "alignment"; +static constexpr const char kBranchWeightAttrName[] = "branch_weights"; static constexpr const char kDefaultValueAttrName[] = "default_value"; static constexpr const char kFnNameAttrName[] = "fn"; static constexpr const char kIndicesAttrName[] = "indices"; @@ -487,6 +488,119 @@ static LogicalResult verify(spirv::AddressOfOp addressOfOp) { } //===----------------------------------------------------------------------===// +// spv.BranchOp +//===----------------------------------------------------------------------===// + +static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *state) { + Block *dest; + SmallVector destOperands; + if (parser->parseSuccessorAndUseList(dest, destOperands)) + return failure(); + state->addSuccessor(dest, destOperands); + return success(); +} + +static void print(spirv::BranchOp branchOp, OpAsmPrinter *printer) { + *printer << spirv::BranchOp::getOperationName() << ' '; + printer->printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0); +} + +static LogicalResult verify(spirv::BranchOp branchOp) { + auto *op = branchOp.getOperation(); + if (op->getNumSuccessors() != 1) + branchOp.emitOpError("must have exactly one successor"); + + return success(); +} + +//===----------------------------------------------------------------------===// +// spv.BranchConditionalOp +//===----------------------------------------------------------------------===// + +static ParseResult parseBranchConditionalOp(OpAsmParser *parser, + OperationState *state) { + auto &builder = parser->getBuilder(); + OpAsmParser::OperandType condInfo; + Block *dest; + SmallVector destOperands; + + // Parse the condition. + Type boolTy = builder.getI1Type(); + if (parser->parseOperand(condInfo) || + parser->resolveOperand(condInfo, boolTy, state->operands)) + return failure(); + + // Parse the optional branch weights. + if (succeeded(parser->parseOptionalLSquare())) { + IntegerAttr trueWeight, falseWeight; + SmallVector weights; + + auto i32Type = builder.getIntegerType(32); + if (parser->parseAttribute(trueWeight, i32Type, "weight", weights) || + parser->parseComma() || + parser->parseAttribute(falseWeight, i32Type, "weight", weights) || + parser->parseRSquare()) + return failure(); + + state->addAttribute(kBranchWeightAttrName, + builder.getArrayAttr({trueWeight, falseWeight})); + } + + // Parse the true branch. + if (parser->parseComma() || + parser->parseSuccessorAndUseList(dest, destOperands)) + return failure(); + state->addSuccessor(dest, destOperands); + + // Parse the false branch. + destOperands.clear(); + if (parser->parseComma() || + parser->parseSuccessorAndUseList(dest, destOperands)) + return failure(); + state->addSuccessor(dest, destOperands); + + return success(); +} + +static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter *printer) { + *printer << spirv::BranchConditionalOp::getOperationName() << ' '; + printer->printOperand(branchOp.condition()); + + if (auto weights = branchOp.branch_weights()) { + *printer << " ["; + mlir::interleaveComma( + weights->getValue(), printer->getStream(), + [&](Attribute a) { *printer << a.cast().getInt(); }); + *printer << "]"; + } + + *printer << ", "; + printer->printSuccessorAndUseList(branchOp.getOperation(), + spirv::BranchConditionalOp::kTrueIndex); + *printer << ", "; + printer->printSuccessorAndUseList(branchOp.getOperation(), + spirv::BranchConditionalOp::kFalseIndex); +} + +static LogicalResult verify(spirv::BranchConditionalOp branchOp) { + auto *op = branchOp.getOperation(); + if (op->getNumSuccessors() != 2) + return branchOp.emitOpError("must have exactly two successors"); + + if (auto weights = branchOp.branch_weights()) { + if (weights->getValue().size() != 2) { + return branchOp.emitOpError("must have exactly two branch weights"); + } + if (llvm::all_of(*weights, [](Attribute attr) { + return attr.cast().getValue().isNullValue(); + })) + return branchOp.emitOpError("branch weights cannot both be zero"); + } + + return success(); +} + +//===----------------------------------------------------------------------===// // spv.CompositeExtractOp //===----------------------------------------------------------------------===// @@ -1093,6 +1207,7 @@ static LogicalResult verify(spirv::ReferenceOfOp referenceOfOp) { } return success(); } + //===----------------------------------------------------------------------===// // spv.Return //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir index bacea1e..11b8c9f 100644 --- a/mlir/test/Dialect/SPIRV/control-flow-ops.mlir +++ b/mlir/test/Dialect/SPIRV/control-flow-ops.mlir @@ -1,6 +1,150 @@ // RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s //===----------------------------------------------------------------------===// +// spv.Branch +//===----------------------------------------------------------------------===// + +func @branch() -> () { + // CHECK: spv.Branch ^bb1 + spv.Branch ^next +^next: + spv.Return +} + +// ----- + +func @missing_accessor() -> () { + spv.Branch + // expected-error @+1 {{expected block name}} +} + +// ----- + +func @wrong_accessor_count() -> () { + %true = spv.constant true + // expected-error @+1 {{must have exactly one successor}} + "spv.Branch"()[^one, ^two] : () -> () +^one: + spv.Return +^two: + spv.Return +} + +// ----- + +func @accessor_argument_disallowed() -> () { + %zero = spv.constant 0 : i32 + // expected-error @+1 {{requires zero operands}} + "spv.Branch"()[^next(%zero : i32)] : () -> () +^next(%arg: i32): + spv.Return +} + +// ----- + +//===----------------------------------------------------------------------===// +// spv.BranchConditional +//===----------------------------------------------------------------------===// + +func @cond_branch() -> () { + %true = spv.constant true + // CHECK: spv.BranchConditional %{{.*}}, ^bb1, ^bb2 + spv.BranchConditional %true, ^one, ^two +// CHECK: ^bb1 +^one: + spv.Return +// CHECK: ^bb2 +^two: + spv.Return +} + +// ----- + +func @cond_branch_with_weights() -> () { + %true = spv.constant true + // CHECK: spv.BranchConditional %{{.*}} [5, 10] + spv.BranchConditional %true [5, 10], ^one, ^two +^one: + spv.Return +^two: + spv.Return +} + +// ----- + +func @missing_condition() -> () { + // expected-error @+1 {{expected SSA operand}} + spv.BranchConditional ^one, ^two +^one: + spv.Return +^two: + spv.Return +} + +// ----- + +func @wrong_condition_type() -> () { + // expected-note @+1 {{prior use here}} + %zero = spv.constant 0 : i32 + // expected-error @+1 {{use of value '%zero' expects different type than prior uses: 'i1' vs 'i32'}} + spv.BranchConditional %zero, ^one, ^two +^one: + spv.Return +^two: + spv.Return +} + +// ----- + +func @wrong_accessor_count() -> () { + %true = spv.constant true + // expected-error @+1 {{must have exactly two successors}} + "spv.BranchConditional"(%true)[^one] : (i1) -> () +^one: + spv.Return +^two: + spv.Return +} + +// ----- + +func @accessor_argment_disallowed() -> () { + %true = spv.constant true + // expected-error @+1 {{requires a single operand}} + "spv.BranchConditional"(%true)[^one(%true : i1), ^two] : (i1) -> () +^one(%arg : i1): + spv.Return +^two: + spv.Return +} + +// ----- + +func @wrong_number_of_weights() -> () { + %true = spv.constant true + // expected-error @+1 {{must have exactly two branch weights}} + "spv.BranchConditional"(%true)[^one, ^two] {branch_weights = [1 : i32, 2 : i32, 3 : i32]} : (i1) -> () +^one: + spv.Return +^two: + spv.Return +} + +// ----- + +func @weights_cannot_both_be_zero() -> () { + %true = spv.constant true + // expected-error @+1 {{branch weights cannot both be zero}} + spv.BranchConditional %true [0, 0], ^one, ^two +^one: + spv.Return +^two: + spv.Return +} + +// ----- + +//===----------------------------------------------------------------------===// // spv.Return //===----------------------------------------------------------------------===// diff --git a/mlir/utils/spirv/gen_spirv_dialect.py b/mlir/utils/spirv/gen_spirv_dialect.py index 2017e22..e34945d 100755 --- a/mlir/utils/spirv/gen_spirv_dialect.py +++ b/mlir/utils/spirv/gen_spirv_dialect.py @@ -421,14 +421,14 @@ def get_op_definition(instruction, doc, existing_info): arguments = existing_info.get('arguments', None) if arguments is None: arguments = [map_spec_operand_to_ods_argument(o) for o in operands] - arguments = '\n '.join(arguments) + arguments = ',\n '.join(arguments) if arguments: # Prepend and append whitespace for formatting arguments = '\n {}\n '.format(arguments) assembly = existing_info.get('assembly', None) if assembly is None: - assembly = ' ``` {.ebnf}\n'\ + assembly = '\n ``` {.ebnf}\n'\ ' [TODO]\n'\ ' ```\n\n'\ ' For example:\n\n'\ -- 2.7.4