#include <cassert>
#include <stdexcept>
+namespace
+{
+
+void read_value_float32(loco::ConstGen *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::FLOAT32>(num_elements);
+
+ int32_t input_elements = input_tensor.float_val_size();
+
+ if (input_elements == num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(i);
+ }
+ }
+ else
+ {
+ throw std::runtime_error("Error: Invalid Const values");
+ }
+}
+
+} // namespace
+
namespace moco
{
namespace tf
// Create a "ConstGen" node for Const
auto const_node = graph->nodes()->create<loco::ConstGen>();
- // TODO fill attributes
+ // set dtype
+ auto dtype = as_loco_datatype(get_datatype_attr(node, "dtype"));
+ const_node->dtype(dtype);
+
+ // import shape and value
+ const auto &input_tensor = get_tensor_attr(node, "value");
+ const auto &input_shape = input_tensor.tensor_shape();
+ const auto &input_dims = input_shape.dim();
+ assert(input_shape.dim_size() <= 6);
+ const_node->rank(input_shape.dim_size());
+ int index = 0;
+ bool zero_sized_shape = false;
+ for (auto &d : input_dims)
+ {
+ if (d.size() > std::numeric_limits<int>::max())
+ throw std::runtime_error("Shape element overflows");
+ if (d.size() == 0)
+ zero_sized_shape = true;
+
+ const_node->dim(index++) = loco::make_dimension(d.size());
+ }
+
+ int num_elements = 1;
+ if (zero_sized_shape)
+ {
+ const_node->rank(0);
+ num_elements = 0;
+ }
+ else
+ {
+ for (int d = 0; d < const_node->rank(); d++)
+ {
+ num_elements *= const_node->dim(d).value();
+ }
+ }
+
+ switch (dtype)
+ {
+ case loco::DataType::FLOAT32:
+ read_value_float32(const_node, num_elements, input_tensor);
+ break;
+
+ // TODO support other types
+
+ default:
+ throw std::runtime_error{"Error: Unsupported data type for " + node.name()};
+ }
nodes->enroll(node.name(), const_node);
-
- throw std::runtime_error{"NYI"};
}
} // namespace tf
--- /dev/null
+/*
+ * Copyright (c) 2019 Samsung Electronics Co., Ltd. All Rights Reserved
+ *
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+#include "TestHelper.h"
+
+#include <moco/tf/Frontend.h>
+
+#include <loco.h>
+
+#include <gtest/gtest.h>
+
+#include <tensorflow/core/framework/graph.pb.h>
+
+#include <cstring>
+#include <memory>
+
+using namespace moco::tf::test;
+
+namespace
+{
+// Test case for "input_tensor.float_val_size() == num_elements"
+
+// clang-format off
+const char *const_float_01_pbtxtdata = STRING_CONTENT(
+node {
+ name: "const/float"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 1.1
+ float_val: 2.2
+ float_val: 3.3
+ float_val: 4.4
+ float_val: 5.5
+ float_val: 6.6
+ }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowFrontend, const_float_01)
+{
+ moco::tf::Frontend frontend;
+ moco::tf::ModelSignature signature;
+
+ imemstream mempb(const_float_01_pbtxtdata, std::strlen(const_float_01_pbtxtdata));
+
+ signature.add_output("const/float");
+
+ std::unique_ptr<loco::Graph> graph =
+ frontend.load(signature, &mempb, moco::tf::Frontend::FileType::Text);
+
+ loco::Graph::OutputContext *outputs = graph->outputs();
+ ASSERT_EQ(outputs->size(), 1);
+ loco::GraphOutput *output = outputs->at(0);
+ loco::Push *push = output->node();
+
+ loco::Graph::NodeContext *nodes = graph->nodes();
+ ASSERT_EQ(nodes->size(), 2);
+ loco::ConstGen *node0 = dynamic_cast<loco::ConstGen *>(nodes->at(0));
+ ASSERT_NE(node0, nullptr);
+ loco::Push *node1 = dynamic_cast<loco::Push *>(nodes->at(1));
+ ASSERT_EQ(node1, push);
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 2.2f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 3.3f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 4.4f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
+}