namespace toco {
+namespace {
+
+bool TransposeAffectsMemoryOrder(std::vector<int> perm,
+ std::vector<int> in_shape) {
+ CHECK_EQ(perm.size(), in_shape.size());
+ // See what the ordering of the non-unary columns are before and after
+ // transpose permutation. If the major indices stay in the same order (not
+ // just the shape) then the flat buffer representation shouldn't change.
+ std::vector<int> old_major_index_ordering;
+ std::vector<int> new_major_index_ordering;
+ for (int i = 0; i < in_shape.size(); i++) {
+ if (in_shape[i] != 1) {
+ old_major_index_ordering.push_back(i);
+ }
+
+ if (in_shape[perm[i]] != 1) {
+ new_major_index_ordering.push_back(perm[i]);
+ }
+ }
+
+ CHECK_EQ(new_major_index_ordering.size(), old_major_index_ordering.size());
+
+ return old_major_index_ordering != new_major_index_ordering;
+}
+
+} // namespace
+
bool ConvertTrivialTransposeToReshape::Run(Model* model, std::size_t op_index) {
auto transpose_it = model->operators.begin() + op_index;
if (transpose_it->get()->type != OperatorType::kTranspose) {
TransposeOperator* transpose_op =
static_cast<TransposeOperator*>(transpose_it->get());
+ const auto& input_array = model->GetArray(transpose_op->inputs[0]);
const auto& output_array = model->GetArray(transpose_op->outputs[0]);
- if (!output_array.has_shape()) {
+ if (!input_array.has_shape() || !output_array.has_shape()) {
// Yield until PropagateFixedSizes has been run on this op.
return false;
}
// Note: We can assume we have error checked inputs in PropagateFixedSizes.
- // This transpose is trivial if we only have one non-unitary dimension.
- std::vector<int> const& dims = output_array.shape().dims();
- unsigned non_unitary_axis_count = 0;
- for (int i = 0; i < dims.size(); i++) {
- if (dims[i] != 1) {
- non_unitary_axis_count++;
- }
+ // Check that the permutation has propogated.
+ std::vector<int> const& perm = transpose_op->perm;
+ if (perm.empty()) {
+ return false;
}
- if (non_unitary_axis_count > 1) {
- // Transpose is not trivial
+
+ // This transpose is trivial if non-unitary dimensions remain in the same
+ // order.
+ std::vector<int> const& input_dims = input_array.shape().dims();
+ std::vector<int> const& output_dims = output_array.shape().dims();
+
+ if (TransposeAffectsMemoryOrder(perm, input_dims)) {
return false;
}
string shape_array_name = toco::AvailableArrayName(*model, perm_array_name);
Array& shape_array = model->GetOrCreateArray(shape_array_name);
*(shape_array.mutable_shape()->mutable_dims()) = {
- 1, static_cast<int>(dims.size())};
+ 1, static_cast<int>(output_dims.size())};
reshape_op->inputs.push_back(shape_array_name);
shape_array.data_type = ArrayDataType::kInt32;
auto& shape_buffer = shape_array.GetMutableBuffer<ArrayDataType::kInt32>();
- shape_buffer.data = dims;
+ shape_buffer.data = output_dims;
// Delete perm array if unused
if (IsDiscardableArray(*model, perm_array_name) &&