const auto* max_op =
op_0->type == OperatorType::kTensorFlowMaximum ? op_0 : op_1;
- CHECK_EQ(min_op->inputs.size(), 2);
- CHECK_EQ(max_op->inputs.size(), 2);
+ if (min_op->inputs.size() != 2 || max_op->inputs.size() != 2) {
+ return false;
+ }
if (min_op->outputs.size() != 1 || max_op->outputs.size() != 1) {
return false;
}
// At that point we know that none of the outputs is used, so we will
// definitely remove the node and all its outputs.
- // Remove any input array that is not used by anything else,
- // and that is not the output of some other operator.
+ // Remove any input array that not the output of another op, and only used by
+ // this op.
for (const auto& input : op->inputs) {
- if (IsDiscardableArray(*model, input) &&
- CountOpsWithInput(*model, input) == 1 &&
- !GetOpWithOutput(*model, input)) {
- model->EraseArray(input);
+ if (!GetOpWithOutput(*model, input)) {
+ DeleteArrayIfUsedOnce(input, model);
}
}
for (const auto& output : op->outputs) {
// If the output array is the model's input array, don't remove that.
// That's the case when cropping a model at a given --input_array.
- if (!IsDiscardableArray(*model, output)) {
- continue;
- }
- // Likewise, if the output array is a RNN state array, don't remove that.
- bool found_output_as_rnn_state_array = false;
- for (const auto& rnn_state : model->flags.rnn_states()) {
- if (output == rnn_state.state_array()) {
- found_output_as_rnn_state_array = true;
- break;
- }
- }
- if (found_output_as_rnn_state_array) {
- continue;
+ if (IsDiscardableArray(*model, output)) {
+ model->EraseArray(output);
}
- // Generic case: do delete this output array.
- model->EraseArray(output);
}
model->operators.erase(it);
return true;
}
}
+ int axis = op->axis;
+ if (axis < 0) {
+ // Handle negative axis
+ axis += model->GetArray(op->inputs[0]).shape().dims().size();
+ }
+ CHECK_EQ(axis, 0) << "Stacking only supported along 0th axis";
+
CHECK(!output_array.buffer);
switch (output_array.data_type) {
case ArrayDataType::kFloat:
// Erase input arrays if no longer used
for (const auto& input : op->inputs) {
- if (IsDiscardableArray(*model, input) &&
- CountOpsWithInput(*model, input) == 1) {
- model->EraseArray(input);
- }
+ toco::DeleteArrayIfUsedOnce(input, model);
}
// Erase the operator
if (HasAttr(node, "dilations")) {
const auto& dilations = GetListAttr(node, "dilations");
CHECK_EQ(dilations.i_size(), 4);
- CHECK_EQ(dilations.i(0), 1);
- CHECK_EQ(dilations.i(3), 1);
+ CHECK_EQ(dilations.i(0), 1)
+ << "Can only import Conv ops with dilation along the height (1st) or "
+ "width (2nd) axis. TensorFlow op \""
+ << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
+ << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
+ << "].";
+ CHECK_EQ(dilations.i(3), 1)
+ << "Can only import Conv ops with dilation along the height (1st) or "
+ "width (2nd) axis. TensorFlow op \""
+ << node.name() << "\" had dilations:[ " << dilations.i(0) << ", "
+ << dilations.i(1) << ", " << dilations.i(2) << ", " << dilations.i(3)
+ << "].";
conv->dilation_height_factor = dilations.i(1);
conv->dilation_width_factor = dilations.i(2);
} else {