From 705550357fb9f1955207b5953779e8a382744f30 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 9 May 2018 13:43:14 -0700 Subject: [PATCH] Adding constant slice op support. PiperOrigin-RevId: 196021899 --- tensorflow/contrib/lite/toco/BUILD | 1 + .../graph_transformations/graph_transformations.h | 1 + .../resolve_constant_slice.cc | 165 +++++++++++++++++++++ tensorflow/contrib/lite/toco/toco_tooling.cc | 1 + 4 files changed, 168 insertions(+) create mode 100644 tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc diff --git a/tensorflow/contrib/lite/toco/BUILD b/tensorflow/contrib/lite/toco/BUILD index 01ce0d9..b8acc9a 100644 --- a/tensorflow/contrib/lite/toco/BUILD +++ b/tensorflow/contrib/lite/toco/BUILD @@ -273,6 +273,7 @@ cc_library( "graph_transformations/resolve_constant_range.cc", "graph_transformations/resolve_constant_reshape.cc", "graph_transformations/resolve_constant_shape_or_rank.cc", + "graph_transformations/resolve_constant_slice.cc", "graph_transformations/resolve_constant_stack.cc", "graph_transformations/resolve_constant_strided_slice.cc", "graph_transformations/resolve_constant_transpose.cc", diff --git a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h index 4e3ea72..8da242a 100644 --- a/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h +++ b/tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h @@ -182,6 +182,7 @@ DECLARE_GRAPH_TRANSFORMATION(ResolveTransposeAttributes) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRandomUniform) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantRange) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantShapeOrRank) +DECLARE_GRAPH_TRANSFORMATION(ResolveConstantSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStack) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantStridedSlice) DECLARE_GRAPH_TRANSFORMATION(ResolveConstantFill) diff --git a/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc new file mode 100644 index 0000000..b35c3e1 --- /dev/null +++ b/tensorflow/contrib/lite/toco/graph_transformations/resolve_constant_slice.cc @@ -0,0 +1,165 @@ +/* Copyright 2018 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include + +#include "tensorflow/contrib/lite/toco/graph_transformations/graph_transformations.h" +#include "tensorflow/contrib/lite/toco/model.h" +#include "tensorflow/contrib/lite/toco/tooling_util.h" +#include "tensorflow/core/platform/logging.h" + +namespace toco { + +namespace { + +template +bool Slice(SliceOperator const& op, Array const& input_array, + Array* output_array) { + // Implementation is taken from the tflite kernel. + + CHECK(input_array.data_type == Type); + CHECK(output_array->data_type == Type); + const auto& input_data = input_array.GetBuffer().data; + + // Create a buffer for the output array. + std::vector>& output_data = + output_array->GetMutableBuffer().data; + output_data.resize(RequiredBufferSizeForShape(output_array->shape())); + + std::vector size = op.size; + if (size.size() != op.begin.size()) { + // Broadcast the end positions. + CHECK_EQ(op.size.size(), 1); + int broadcast_size = size[0]; + while (size.size() < op.begin.size()) size.push_back(broadcast_size); + } + + // Calculate begin and end indices along each dimension. + CHECK_LE(op.begin.size(), 4); + CHECK_LE(size.size(), 4); + std::vector begin = op.begin; + std::vector end; + for (int i = 0; i < begin.size(); ++i) { + int dim_size = size[i]; + if (dim_size == -1) { + // -1 means the rest of the dimension. + dim_size = input_array.shape().dims()[i] - begin[i]; + } + CHECK_GE(dim_size, 1); + end.push_back(begin[i] + dim_size - 1); + } + + // Pad out so that we always have 4 dims, makes this loop easier. + while (begin.size() < 4) begin.insert(begin.begin(), 0); + while (end.size() < 4) end.insert(end.begin(), 0); + Shape padded_shape = input_array.shape(); + while (padded_shape.dimensions_count() < 4) { + padded_shape.mutable_dims()->insert(padded_shape.mutable_dims()->begin(), + 1); + } + + auto* out_ptr = output_data.data(); + for (int in_b = begin[0]; in_b <= end[0]; ++in_b) { + for (int in_h = begin[1]; in_h <= end[1]; ++in_h) { + for (int in_w = begin[2]; in_w <= end[2]; ++in_w) { + for (int in_d = begin[3]; in_d <= end[3]; ++in_d) { + *out_ptr++ = + input_data[Offset(padded_shape, {in_b, in_h, in_w, in_d})]; + } + } + } + } + + return true; +} + +} // namespace + +bool ResolveConstantSlice::Run(Model* model, std::size_t op_index) { + const auto it = model->operators.begin() + op_index; + const auto* base_op = it->get(); + if (base_op->type != OperatorType::kSlice) { + return false; + } + + const SliceOperator* op = static_cast(base_op); + + CHECK_EQ(op->outputs.size(), 1); + auto& output_array = model->GetArray(op->outputs[0]); + if (output_array.data_type == ArrayDataType::kNone) { + // Yield until the output type has been set by PropagateArrayDataTypes. + return false; + } + + if (!output_array.has_shape()) { + // Yield until the output shape has been set by PropagateFixedShapes. + return false; + } + + if (op->begin.empty() || op->size.empty()) { + // Attributes have not resolved yet. + return false; + } + + const auto& input_array = model->GetArray(op->inputs[0]); + if (!input_array.has_shape()) { + // Yield until the value shape has been resolved. + return false; + } + if (!IsConstantParameterArray(*model, op->inputs[0])) { + // Yield until the value is constant. + return false; + } + + CHECK(!output_array.buffer); + switch (output_array.data_type) { + case ArrayDataType::kFloat: + if (!Slice(*op, input_array, &output_array)) { + return false; + } + break; + case ArrayDataType::kUint8: + if (!Slice(*op, input_array, &output_array)) { + return false; + } + break; + case ArrayDataType::kInt32: + if (!Slice(*op, input_array, &output_array)) { + return false; + } + break; + case ArrayDataType::kInt64: + if (!Slice(*op, input_array, &output_array)) { + return false; + } + break; + default: + LOG(FATAL) << "Unsupported data type input to Slice op with output \"" + << op->outputs[0] << "\""; + break; + } + + // Erase input array if no longer used. + if (IsDiscardableArray(*model, op->inputs[0]) && + CountOpsWithInput(*model, op->inputs[0]) == 1) { + model->EraseArray(op->inputs[0]); + } + + // Erase the operator + model->operators.erase(it); + + return true; +} + +} // namespace toco diff --git a/tensorflow/contrib/lite/toco/toco_tooling.cc b/tensorflow/contrib/lite/toco/toco_tooling.cc index 58c9905..d894916 100644 --- a/tensorflow/contrib/lite/toco/toco_tooling.cc +++ b/tensorflow/contrib/lite/toco/toco_tooling.cc @@ -86,6 +86,7 @@ void MakeGeneralGraphTransformationsSet( transformations->Add(new ResolveConstantRandomUniform); transformations->Add(new ResolveConstantRange); transformations->Add(new ResolveConstantReshape); + transformations->Add(new ResolveConstantSlice); transformations->Add(new ResolveConstantStack); transformations->Add(new ResolveConstantStridedSlice); transformations->Add(new ResolveConstantTranspose); -- 2.7.4