def SPV_OC_OpULessThanEqual : I32EnumAttrCase<"OpULessThanEqual", 178>;
def SPV_OC_OpSLessThanEqual : I32EnumAttrCase<"OpSLessThanEqual", 179>;
def SPV_OC_OpLabel : I32EnumAttrCase<"OpLabel", 248>;
+def SPV_OC_OpBranch : I32EnumAttrCase<"OpBranch", 249>;
+def SPV_OC_OpBranchConditional : I32EnumAttrCase<"OpBranchConditional", 250>;
def SPV_OC_OpReturn : I32EnumAttrCase<"OpReturn", 253>;
def SPV_OC_OpReturnValue : I32EnumAttrCase<"OpReturnValue", 254>;
SPV_OC_OpINotEqual, SPV_OC_OpUGreaterThan, SPV_OC_OpSGreaterThan,
SPV_OC_OpUGreaterThanEqual, SPV_OC_OpSGreaterThanEqual, SPV_OC_OpULessThan,
SPV_OC_OpSLessThan, SPV_OC_OpULessThanEqual, SPV_OC_OpSLessThanEqual,
- SPV_OC_OpLabel, SPV_OC_OpReturn, SPV_OC_OpReturnValue
+ SPV_OC_OpLabel, SPV_OC_OpBranch, SPV_OC_OpBranchConditional, SPV_OC_OpReturn,
+ SPV_OC_OpReturnValue
]> {
let returnType = "::mlir::spirv::Opcode";
let convertFromStorage = "static_cast<::mlir::spirv::Opcode>($_self.getInt())";
// -----
+def SPV_BranchOp : SPV_Op<"Branch", [Terminator]> {
+ let summary = "Unconditional branch to target block.";
+
+ let description = [{
+ This instruction must be the last instruction in a block.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ branch-op ::= `spv.Branch` successor
+ ```
+
+ For example:
+
+ ```
+ spv.Branch ^target
+ ```
+ }];
+
+ let arguments = (ins);
+
+ let results = (outs);
+
+ let builders = [
+ OpBuilder<
+ "Builder *, OperationState *state, Block *successor", [{
+ state->addSuccessor(successor, {});
+ }]
+ >
+ ];
+
+ let skipDefaultBuilders = 1;
+
+ let autogenSerialization = 0;
+}
+
+// -----
+
+def SPV_BranchConditionalOp : SPV_Op<"BranchConditional", [Terminator]> {
+ let summary = [{
+ If Condition is true, branch to true block, otherwise branch to false
+ block.
+ }];
+
+ let description = [{
+ Condition must be a Boolean type scalar.
+
+ Branch weights are unsigned 32-bit integer literals. There must be
+ either no Branch Weights or exactly two branch weights. If present, the
+ first is the weight for branching to True Label, and the second is the
+ weight for branching to False Label. The implied probability that a
+ branch is taken is its weight divided by the sum of the two Branch
+ weights. At least one weight must be non-zero. A weight of zero does not
+ imply a branch is dead or permit its removal; branch weights are only
+ hints. The two weights must not overflow a 32-bit unsigned integer when
+ added together.
+
+ This instruction must be the last instruction in a block.
+
+ ### Custom assembly form
+
+ ``` {.ebnf}
+ branch-conditional-op ::= `spv.BranchConditional` ssa-use
+ (`[` integer-literal, integer-literal `]`)?
+ `,` successor `,` successor
+ ```
+
+ For example:
+
+ ```
+ spv.BranchConditional %condition, ^true_branch, ^false_branch
+ ```
+ }];
+
+ let arguments = (ins
+ SPV_Bool:$condition,
+ OptionalAttr<I32ArrayAttr>:$branch_weights
+ );
+
+ let results = (outs);
+
+ let builders = [
+ OpBuilder<
+ "Builder *, OperationState *state, Value *condition, "
+ "Block *trueBranch, Block *falseBranch, /*optional*/ArrayAttr weights",
+ [{
+ state->addOperands(condition);
+ state->addSuccessor(trueBranch, {});
+ state->addSuccessor(falseBranch, {});
+ state->addAttribute("branch_weights", weights);
+ }]
+ >
+ ];
+
+ let skipDefaultBuilders = 1;
+
+ let autogenSerialization = 0;
+
+ let extraClassDeclaration = [{
+ // Branch indices into the successor list.
+ enum { kTrueIndex = 0, kFalseIndex = 1 };
+ }];
+}
+
+// -----
+
def SPV_ReturnOp : SPV_Op<"Return", [InFunctionScope, Terminator]> {
let summary = "Return with no value from a function with void return type.";
// TODO(antiagainst): generate these strings using ODS.
static constexpr const char kAlignmentAttrName[] = "alignment";
+static constexpr const char kBranchWeightAttrName[] = "branch_weights";
static constexpr const char kDefaultValueAttrName[] = "default_value";
static constexpr const char kFnNameAttrName[] = "fn";
static constexpr const char kIndicesAttrName[] = "indices";
}
//===----------------------------------------------------------------------===//
+// spv.BranchOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseBranchOp(OpAsmParser *parser, OperationState *state) {
+ Block *dest;
+ SmallVector<Value *, 4> destOperands;
+ if (parser->parseSuccessorAndUseList(dest, destOperands))
+ return failure();
+ state->addSuccessor(dest, destOperands);
+ return success();
+}
+
+static void print(spirv::BranchOp branchOp, OpAsmPrinter *printer) {
+ *printer << spirv::BranchOp::getOperationName() << ' ';
+ printer->printSuccessorAndUseList(branchOp.getOperation(), /*index=*/0);
+}
+
+static LogicalResult verify(spirv::BranchOp branchOp) {
+ auto *op = branchOp.getOperation();
+ if (op->getNumSuccessors() != 1)
+ branchOp.emitOpError("must have exactly one successor");
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
+// spv.BranchConditionalOp
+//===----------------------------------------------------------------------===//
+
+static ParseResult parseBranchConditionalOp(OpAsmParser *parser,
+ OperationState *state) {
+ auto &builder = parser->getBuilder();
+ OpAsmParser::OperandType condInfo;
+ Block *dest;
+ SmallVector<Value *, 4> destOperands;
+
+ // Parse the condition.
+ Type boolTy = builder.getI1Type();
+ if (parser->parseOperand(condInfo) ||
+ parser->resolveOperand(condInfo, boolTy, state->operands))
+ return failure();
+
+ // Parse the optional branch weights.
+ if (succeeded(parser->parseOptionalLSquare())) {
+ IntegerAttr trueWeight, falseWeight;
+ SmallVector<NamedAttribute, 2> weights;
+
+ auto i32Type = builder.getIntegerType(32);
+ if (parser->parseAttribute(trueWeight, i32Type, "weight", weights) ||
+ parser->parseComma() ||
+ parser->parseAttribute(falseWeight, i32Type, "weight", weights) ||
+ parser->parseRSquare())
+ return failure();
+
+ state->addAttribute(kBranchWeightAttrName,
+ builder.getArrayAttr({trueWeight, falseWeight}));
+ }
+
+ // Parse the true branch.
+ if (parser->parseComma() ||
+ parser->parseSuccessorAndUseList(dest, destOperands))
+ return failure();
+ state->addSuccessor(dest, destOperands);
+
+ // Parse the false branch.
+ destOperands.clear();
+ if (parser->parseComma() ||
+ parser->parseSuccessorAndUseList(dest, destOperands))
+ return failure();
+ state->addSuccessor(dest, destOperands);
+
+ return success();
+}
+
+static void print(spirv::BranchConditionalOp branchOp, OpAsmPrinter *printer) {
+ *printer << spirv::BranchConditionalOp::getOperationName() << ' ';
+ printer->printOperand(branchOp.condition());
+
+ if (auto weights = branchOp.branch_weights()) {
+ *printer << " [";
+ mlir::interleaveComma(
+ weights->getValue(), printer->getStream(),
+ [&](Attribute a) { *printer << a.cast<IntegerAttr>().getInt(); });
+ *printer << "]";
+ }
+
+ *printer << ", ";
+ printer->printSuccessorAndUseList(branchOp.getOperation(),
+ spirv::BranchConditionalOp::kTrueIndex);
+ *printer << ", ";
+ printer->printSuccessorAndUseList(branchOp.getOperation(),
+ spirv::BranchConditionalOp::kFalseIndex);
+}
+
+static LogicalResult verify(spirv::BranchConditionalOp branchOp) {
+ auto *op = branchOp.getOperation();
+ if (op->getNumSuccessors() != 2)
+ return branchOp.emitOpError("must have exactly two successors");
+
+ if (auto weights = branchOp.branch_weights()) {
+ if (weights->getValue().size() != 2) {
+ return branchOp.emitOpError("must have exactly two branch weights");
+ }
+ if (llvm::all_of(*weights, [](Attribute attr) {
+ return attr.cast<IntegerAttr>().getValue().isNullValue();
+ }))
+ return branchOp.emitOpError("branch weights cannot both be zero");
+ }
+
+ return success();
+}
+
+//===----------------------------------------------------------------------===//
// spv.CompositeExtractOp
//===----------------------------------------------------------------------===//
}
return success();
}
+
//===----------------------------------------------------------------------===//
// spv.Return
//===----------------------------------------------------------------------===//
// RUN: mlir-opt -split-input-file -verify-diagnostics %s | FileCheck %s
//===----------------------------------------------------------------------===//
+// spv.Branch
+//===----------------------------------------------------------------------===//
+
+func @branch() -> () {
+ // CHECK: spv.Branch ^bb1
+ spv.Branch ^next
+^next:
+ spv.Return
+}
+
+// -----
+
+func @missing_accessor() -> () {
+ spv.Branch
+ // expected-error @+1 {{expected block name}}
+}
+
+// -----
+
+func @wrong_accessor_count() -> () {
+ %true = spv.constant true
+ // expected-error @+1 {{must have exactly one successor}}
+ "spv.Branch"()[^one, ^two] : () -> ()
+^one:
+ spv.Return
+^two:
+ spv.Return
+}
+
+// -----
+
+func @accessor_argument_disallowed() -> () {
+ %zero = spv.constant 0 : i32
+ // expected-error @+1 {{requires zero operands}}
+ "spv.Branch"()[^next(%zero : i32)] : () -> ()
+^next(%arg: i32):
+ spv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
+// spv.BranchConditional
+//===----------------------------------------------------------------------===//
+
+func @cond_branch() -> () {
+ %true = spv.constant true
+ // CHECK: spv.BranchConditional %{{.*}}, ^bb1, ^bb2
+ spv.BranchConditional %true, ^one, ^two
+// CHECK: ^bb1
+^one:
+ spv.Return
+// CHECK: ^bb2
+^two:
+ spv.Return
+}
+
+// -----
+
+func @cond_branch_with_weights() -> () {
+ %true = spv.constant true
+ // CHECK: spv.BranchConditional %{{.*}} [5, 10]
+ spv.BranchConditional %true [5, 10], ^one, ^two
+^one:
+ spv.Return
+^two:
+ spv.Return
+}
+
+// -----
+
+func @missing_condition() -> () {
+ // expected-error @+1 {{expected SSA operand}}
+ spv.BranchConditional ^one, ^two
+^one:
+ spv.Return
+^two:
+ spv.Return
+}
+
+// -----
+
+func @wrong_condition_type() -> () {
+ // expected-note @+1 {{prior use here}}
+ %zero = spv.constant 0 : i32
+ // expected-error @+1 {{use of value '%zero' expects different type than prior uses: 'i1' vs 'i32'}}
+ spv.BranchConditional %zero, ^one, ^two
+^one:
+ spv.Return
+^two:
+ spv.Return
+}
+
+// -----
+
+func @wrong_accessor_count() -> () {
+ %true = spv.constant true
+ // expected-error @+1 {{must have exactly two successors}}
+ "spv.BranchConditional"(%true)[^one] : (i1) -> ()
+^one:
+ spv.Return
+^two:
+ spv.Return
+}
+
+// -----
+
+func @accessor_argment_disallowed() -> () {
+ %true = spv.constant true
+ // expected-error @+1 {{requires a single operand}}
+ "spv.BranchConditional"(%true)[^one(%true : i1), ^two] : (i1) -> ()
+^one(%arg : i1):
+ spv.Return
+^two:
+ spv.Return
+}
+
+// -----
+
+func @wrong_number_of_weights() -> () {
+ %true = spv.constant true
+ // expected-error @+1 {{must have exactly two branch weights}}
+ "spv.BranchConditional"(%true)[^one, ^two] {branch_weights = [1 : i32, 2 : i32, 3 : i32]} : (i1) -> ()
+^one:
+ spv.Return
+^two:
+ spv.Return
+}
+
+// -----
+
+func @weights_cannot_both_be_zero() -> () {
+ %true = spv.constant true
+ // expected-error @+1 {{branch weights cannot both be zero}}
+ spv.BranchConditional %true [0, 0], ^one, ^two
+^one:
+ spv.Return
+^two:
+ spv.Return
+}
+
+// -----
+
+//===----------------------------------------------------------------------===//
// spv.Return
//===----------------------------------------------------------------------===//
arguments = existing_info.get('arguments', None)
if arguments is None:
arguments = [map_spec_operand_to_ods_argument(o) for o in operands]
- arguments = '\n '.join(arguments)
+ arguments = ',\n '.join(arguments)
if arguments:
# Prepend and append whitespace for formatting
arguments = '\n {}\n '.format(arguments)
assembly = existing_info.get('assembly', None)
if assembly is None:
- assembly = ' ``` {.ebnf}\n'\
+ assembly = '\n ``` {.ebnf}\n'\
' [TODO]\n'\
' ```\n\n'\
' For example:\n\n'\