int32_t input_elements = input_tensor.float_val_size();
- if (input_elements == num_elements)
+ if (input_elements == 1)
+ {
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(0);
+ }
+ }
+ else if (input_elements == num_elements)
{
for (int32_t i = 0; i < input_elements; i++)
{
ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 5.5f);
ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 6.6f);
}
+
+namespace
+{
+// Test case for "input_tensor.float_val_size() == 1"
+
+// clang-format off
+const char *const_float_02_pbtxtdata = STRING_CONTENT(
+node {
+ name: "const/float"
+ op: "Const"
+ attr {
+ key: "dtype"
+ value {
+ type: DT_FLOAT
+ }
+ }
+ attr {
+ key: "value"
+ value {
+ tensor {
+ dtype: DT_FLOAT
+ tensor_shape {
+ dim {
+ size: 2
+ }
+ dim {
+ size: 3
+ }
+ }
+ float_val: 1.1
+ }
+ }
+ }
+}
+);
+// clang-format on
+
+} // namespace
+
+TEST(TensorFlowFrontend, const_float_02)
+{
+ moco::tf::Frontend frontend;
+ moco::tf::ModelSignature signature;
+
+ imemstream mempb(const_float_02_pbtxtdata, std::strlen(const_float_02_pbtxtdata));
+
+ signature.add_output("const/float");
+
+ std::unique_ptr<loco::Graph> graph =
+ frontend.load(signature, &mempb, moco::tf::Frontend::FileType::Text);
+
+ loco::Graph::OutputContext *outputs = graph->outputs();
+ ASSERT_EQ(outputs->size(), 1);
+ loco::GraphOutput *output = outputs->at(0);
+ loco::Push *push = output->node();
+
+ loco::Graph::NodeContext *nodes = graph->nodes();
+ ASSERT_EQ(nodes->size(), 2);
+ loco::ConstGen *node0 = dynamic_cast<loco::ConstGen *>(nodes->at(0));
+ ASSERT_NE(node0, nullptr);
+ loco::Push *node1 = dynamic_cast<loco::Push *>(nodes->at(1));
+ ASSERT_EQ(node1, push);
+
+ ASSERT_EQ(node0->size<loco::DataType::FLOAT32>(), 6);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(0), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(1), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(2), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(3), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(4), 1.1f);
+ ASSERT_EQ(node0->at<loco::DataType::FLOAT32>(5), 1.1f);
+}