#include "Convert.h"
#include "GraphBuilder.h"
#include "GraphBuilderContext.h"
+#include "Knob.h"
+
+#include "IR/TFConst.h"
#include <moco/tf/Names.h>
#include <loco.h>
} // namespace
+namespace
+{
+
+void read_value_int32(moco::tf::TFConst *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::S32>(num_elements);
+
+ int32_t input_elements = input_tensor.int_val_size();
+
+ if (input_tensor.tensor_content().size() == num_elements * sizeof(int32_t))
+ {
+ const std::string &str_content = input_tensor.tensor_content();
+ const int32_t *s32_ptr = reinterpret_cast<const int32_t *>(str_content.c_str());
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = *(s32_ptr + i);
+ }
+ }
+ else if (0 < input_elements && input_elements <= num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(i);
+ }
+
+ for (int32_t i = input_elements; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::S32>(i) = input_tensor.int_val(input_elements - 1);
+ }
+ }
+ else
+ {
+ throw std::runtime_error("Error: Invalid Const values");
+ }
+}
+
+void read_value_float32(moco::tf::TFConst *const_node, int num_elements,
+ const tensorflow::TensorProto &input_tensor)
+{
+ const_node->size<loco::DataType::FLOAT32>(num_elements);
+
+ int32_t input_elements = input_tensor.float_val_size();
+
+ if (input_tensor.tensor_content().size() == num_elements * sizeof(float))
+ {
+ const std::string &str_content = input_tensor.tensor_content();
+ const float *float_ptr = reinterpret_cast<const float *>(str_content.c_str());
+ for (int32_t i = 0; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = *(float_ptr + i);
+ }
+ }
+ else if (0 < input_elements && input_elements <= num_elements)
+ {
+ for (int32_t i = 0; i < input_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(i);
+ }
+
+ for (int32_t i = input_elements; i < num_elements; i++)
+ {
+ const_node->at<loco::DataType::FLOAT32>(i) = input_tensor.float_val(input_elements - 1);
+ }
+ }
+ else
+ {
+ throw std::runtime_error("Error: Invalid Const values");
+ }
+}
+
+} // namespace
+
namespace moco
{
namespace tf
public:
bool validate(const tensorflow::NodeDef &) const override;
void build(const tensorflow::NodeDef &, GraphBuilderContext *) const override;
+
+private:
+ void buildCanonical(const tensorflow::NodeDef &, GraphBuilderContext *) const;
+ void buildTF(const tensorflow::NodeDef &, GraphBuilderContext *) const;
};
bool ConstGraphBuilder::validate(const tensorflow::NodeDef &node) const
{
assert(context != nullptr);
+ if (moco::tf::get<moco::tf::Knob::ImportAsTFConst>())
+ buildTF(node, context);
+ else
+ buildCanonical(node, context);
+}
+
+void ConstGraphBuilder::buildCanonical(const tensorflow::NodeDef &node,
+ GraphBuilderContext *context) const
+{
loco::Graph *graph = context->graph();
SymbolTable *tensor_names = context->tensor_names();
tensor_names->enroll(output_name, const_node);
}
+void ConstGraphBuilder::buildTF(const tensorflow::NodeDef &node, GraphBuilderContext *context) const
+{
+ loco::Graph *graph = context->graph();
+ SymbolTable *tensor_names = context->tensor_names();
+
+ // Create a "TFConstant" node for Const
+ auto const_node = graph->nodes()->create<moco::tf::TFConst>();
+
+ // set dtype
+ auto dtype = as_loco_datatype(get_datatype_attr(node, "dtype"));
+ const_node->dtype(dtype);
+
+ // import shape and value
+ const auto &input_tensor = get_tensor_attr(node, "value");
+ const auto &input_shape = input_tensor.tensor_shape();
+ const auto &input_dims = input_shape.dim();
+ assert(input_shape.dim_size() <= 6);
+ const_node->rank(input_shape.dim_size());
+ int index = 0;
+ bool zero_sized_shape = false;
+ for (auto &d : input_dims)
+ {
+ if (d.size() > std::numeric_limits<int>::max())
+ throw std::runtime_error("Shape element overflows");
+ if (d.size() == 0)
+ zero_sized_shape = true;
+
+ if (d.size() >= 0)
+ const_node->dim(index++) = d.size();
+ else
+ throw std::runtime_error{"Error: Unknown dim size for " + node.name()};
+ }
+
+ int num_elements = 1;
+ if (zero_sized_shape)
+ {
+ const_node->rank(0);
+ num_elements = 0;
+ }
+ else
+ {
+ for (int d = 0; d < const_node->rank(); d++)
+ {
+ num_elements *= const_node->dim(d).value();
+ }
+ }
+
+ switch (dtype)
+ {
+ case loco::DataType::S32:
+ read_value_int32(const_node, num_elements, input_tensor);
+ break;
+
+ case loco::DataType::FLOAT32:
+ read_value_float32(const_node, num_elements, input_tensor);
+ break;
+
+ // TODO support other types
+
+ default:
+ throw std::runtime_error{"Error: Unsupported data type for " + node.name()};
+ }
+
+ // register string-name to node
+ TensorName output_name(node.name(), 0);
+ tensor_names->enroll(output_name, const_node);
+}
+
} // namespace tf
} // namespace moco