#include "GraphBuilder.h"
#include "GraphBuilderContext.h"
-#include "CompilerArgs.h"
#include <moco/tf/Names.h>
#include <plier/tf/Convert.h>
pull_node->dtype(dtype);
// Setting shape info.
- auto compiler_args = CompilerArgs::get();
- auto arg_shape = compiler_args->getInputShape(node.name());
-
- if (arg_shape == nullptr) // no user-provided shape info
+ pull_node->rank(num_dims);
+ for (int64_t d = 0; d < num_dims; d++)
{
- pull_node->rank(num_dims);
- for (int64_t d = 0; d < num_dims; d++)
+ assert(shape.dim(d).size() < std::numeric_limits<uint32_t>::max());
+ int64_t dim_value = shape.dim(d).size();
+ if (dim_value >= 0ULL)
{
- assert(shape.dim(d).size() < std::numeric_limits<uint32_t>::max());
- int64_t dim_value = shape.dim(d).size();
- if (dim_value >= 0ULL)
- {
- uint32_t dim_value32 = static_cast<uint32_t>(dim_value);
- pull_node->dim(d) = dim_value32;
- }
- else
- {
- pull_node->dim(d).unset();
- // TODO Remove assert() and do implement
- // NOTE Current implementation assumes dim is all know
- assert(false);
- }
+ uint32_t dim_value32 = static_cast<uint32_t>(dim_value);
+ pull_node->dim(d) = dim_value32;
}
- }
- else // when user provided shape info
- {
- // validation: compare shape attr with user-provided-shape
- if (num_dims != arg_shape->rank())
- throw std::runtime_error(
- "Shape information in Placeholder does not matched with the shape user provided: " +
- node.name());
-
- pull_node->rank(num_dims);
-
- for (int64_t d = 0; d < num_dims; d++)
+ else
{
- assert(shape.dim(d).size() < std::numeric_limits<uint32_t>::max());
- int64_t tf_dim = shape.dim(d).size();
- uint32_t arg_dim = arg_shape->dim(d);
-
- if (tf_dim >= 0 && tf_dim != arg_dim)
- throw std::runtime_error("Shape that user entered is different from shape in Placeholder");
-
- // set shape of pull
- pull_node->dim(d) = arg_dim;
+ pull_node->dim(d).unset();
+ // TODO Remove assert() and do implement
+ // NOTE Current implementation assumes dim is all know
+ assert(false);
}
}
ASSERT_TRUE(pull_node->dim(2).known() && pull_node->dim(2).value() == 3);
ASSERT_TRUE(pull_node->dim(3).known() && pull_node->dim(3).value() == 4);
}
-
-TEST(TensorFlowImport, placeholder_wrong_user_input_0)
-{
- // user provides shape info
- nncc::core::ADT::tensor::Shape shape{1024, 2, 111, 4}; // '111' is not matched with shape in model
- moco::tf::CompilerArgs::get()->addInput("placeholder", shape);
-
- // load graph
- moco::tf::Importer importer;
- moco::tf::ModelSignature signature;
- signature.add_output(moco::tf::TensorName("output", 0));
-
- tensorflow::GraphDef graph_def;
- EXPECT_TRUE(plier::tf::parse_graphdef(known_batch_pbtxt, graph_def));
- ASSERT_ANY_THROW(importer.import(signature, graph_def));
-
- moco::tf::CompilerArgs::get()->clear();
-}
-
-namespace
-{
-// clang-format off
-const char *unknown_batch_pbtxt = STRING_CONTENT(
-node {
- name: "placeholder"
- op: "Placeholder"
- attr {
- key: "dtype" value { type: DT_FLOAT }
- }
- attr {
- key: "shape"
- value {
- shape {
- dim { size: -1 }
- dim { size: 2 }
- dim { size: 3 }
- dim { size: 4 }
- }
- }
- }
-}
-node {
- name: "output"
- op: "Identity"
- input: "placeholder"
- attr {
- key: "T" value { type: DT_FLOAT }
- }
-}
-);
-// clang-format on
-
-} // namespace
-
-TEST(TensorFlowImport, placeholder_unknwon_batch)
-{
- constexpr uint32_t USER_PROVIDED_BATCH = 1024;
-
- // user provides shape info
- nncc::core::ADT::tensor::Shape shape{USER_PROVIDED_BATCH, 2, 3, 4};
- moco::tf::CompilerArgs::get()->addInput("placeholder", shape);
-
- // load graph
- moco::tf::Importer importer;
- moco::tf::ModelSignature signature;
- signature.add_output(moco::tf::TensorName("output", 0));
-
- tensorflow::GraphDef graph_def;
- EXPECT_TRUE(plier::tf::parse_graphdef(unknown_batch_pbtxt, graph_def));
- std::unique_ptr<loco::Graph> graph = importer.import(signature, graph_def);
-
- // get loco::Pull
- loco::Graph::NodeContext *loco_nodes = graph->nodes();
- loco::Pull *pull_node = dynamic_cast<loco::Pull *>(loco_nodes->at(0));
-
- // Check dim
- ASSERT_TRUE(pull_node->dim(0).known() && pull_node->dim(0).value() == USER_PROVIDED_BATCH);
- ASSERT_TRUE(pull_node->dim(1).known() && pull_node->dim(1).value() == 2);
- ASSERT_TRUE(pull_node->dim(2).known() && pull_node->dim(2).value() == 3);
- ASSERT_TRUE(pull_node->dim(3).known() && pull_node->dim(3).value() == 4);
-
- moco::tf::CompilerArgs::get()->clear();
-}
-
-TEST(TensorFlowImport, placeholder_wrong_user_input_1)
-{
- constexpr uint32_t USER_PROVIDED_BATCH = 1024;
-
- // user provides shape info
- nncc::core::ADT::tensor::Shape shape{USER_PROVIDED_BATCH, 2, 3, 111}; // '111' is not matched
- moco::tf::CompilerArgs::get()->addInput("placeholder", shape);
-
- // load graph
- moco::tf::Importer importer;
- moco::tf::ModelSignature signature;
- signature.add_output(moco::tf::TensorName("output", 0));
-
- tensorflow::GraphDef graph_def;
- EXPECT_TRUE(plier::tf::parse_graphdef(unknown_batch_pbtxt, graph_def));
- ASSERT_ANY_THROW(importer.import(signature, graph_def));
-
- moco::tf::CompilerArgs::get()->clear();
-}