CHECK(reporter->AssertEQ(x_shape[i], y_shape[i]))
<< "x and y must have the same shape: " << x_shape << " vs " << y_shape;
- CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
- << "Shape of condition " << condition->shape
- << " must be either equal to x or has dimension of 1.";
+ if (i < cond_shape.size()) {
+ CHECK(reporter->AssertEQ(cond_shape[i], x_shape[i]))
+ << "condition and x must have the same shape: " << cond_shape << " vs " << x_shape;
+ }
}
reporter->Assign(types[3], TensorTypeNode::make(x_shape, x->dtype));
return true;