virtual void addShapeConstr(const ::internal::tflite::operand::Index &ind,
const ::arm_compute::TensorInfo &info) = 0;
- virtual void addShapeConstr(const ::internal::tflite::operand::Index &ind,
- const ::arm_compute::TensorInfo &info,
- const nnfw::util::tensor::Shape &shape) = 0;
+ virtual void addShapeConstr(const ::internal::tflite::operand::Index &lhs_ind,
+ const ::internal::tflite::operand::Object &lhs_obj,
+ const nnfw::util::tensor::Shape &lhs_shape,
+ const ::internal::tflite::operand::Index &rhs_ind,
+ const ::internal::tflite::operand::Object &rhs_obj,
+ const nnfw::util::tensor::Shape &rhs_shape) = 0;
virtual void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
const ::internal::tflite::operand::Index &base,
const ::arm_compute::Coordinates &offset,
void Planner::visit(const ::internal::tflite::op::Add::Node &node)
{
const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
-
const ::internal::tflite::operand::Index lhs_index{node.param().lhs_index};
const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
-
const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
// TODO Should move to the place where the operand is handled, if it is possible.
// Set Shape Constraints and TensorInfo
- _builder.addShapeConstr(ofm_index,
- asTensorInfo(ofm_shape, _ctx.at(ofm_index).type(),
- _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
-
- if (lhs_shape.rank() == 4 && rhs_shape.rank() < 4)
- {
- _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type(),
- _ctx.at(lhs_index).scale(),
- _ctx.at(lhs_index).zeroPoint()));
- _builder.addShapeConstr(rhs_index,
- asTensorInfoForBroadcast(rhs_shape, _ctx.at(rhs_index).type(),
- _ctx.at(ofm_index).scale(),
- _ctx.at(ofm_index).zeroPoint()),
- rhs_shape);
- }
- else if (rhs_shape.rank() == 4 && lhs_shape.rank() < 4)
- {
- _builder.addShapeConstr(lhs_index,
- asTensorInfoForBroadcast(lhs_shape, _ctx.at(lhs_index).type(),
- _ctx.at(lhs_index).scale(),
- _ctx.at(lhs_index).zeroPoint()),
- lhs_shape);
- _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type(),
- _ctx.at(rhs_index).scale(),
- _ctx.at(rhs_index).zeroPoint()));
- }
- else
- {
- _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type(),
- _ctx.at(lhs_index).scale(),
- _ctx.at(lhs_index).zeroPoint()));
- _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type(),
- _ctx.at(rhs_index).scale(),
- _ctx.at(rhs_index).zeroPoint()));
- }
+ _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+ _builder.addShapeConstr(lhs_index, _ctx.at(lhs_index), lhs_shape, rhs_index, _ctx.at(rhs_index),
+ rhs_shape);
// Construct operation parameters
struct Param
void Planner::visit(const ::internal::tflite::op::Sub::Node &node)
{
const ::internal::tflite::operand::Index ofm_index{node.param().ofm_index};
-
const ::internal::tflite::operand::Index lhs_index{node.param().lhs_index};
const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
-
const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
_builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
- _builder.addShapeConstr(lhs_index, asTensorInfo(lhs_shape, _ctx.at(lhs_index).type()));
- _builder.addShapeConstr(rhs_index, asTensorInfo(rhs_shape, _ctx.at(rhs_index).type()));
+ _builder.addShapeConstr(lhs_index, _ctx.at(lhs_index), lhs_shape, rhs_index, _ctx.at(rhs_index),
+ rhs_shape);
// Construct operation parameters
struct Param
const ::internal::tflite::operand::Index rhs_index{node.param().rhs_index};
const ::internal::tflite::operand::Index activation_index{node.param().activation_index};
- int32_t ofm_rank = _ctx.at(ofm_index).shape().rank();
- int32_t lhs_rank = _ctx.at(lhs_index).shape().rank();
- int32_t rhs_rank = _ctx.at(rhs_index).shape().rank();
const auto ofm_shape = _ctx.at(ofm_index).shape().asTensor();
const auto lhs_shape = _ctx.at(lhs_index).shape().asTensor();
const auto rhs_shape = _ctx.at(rhs_index).shape().asTensor();
- // not tested cases below
- assert(!(ofm_rank == 0 && lhs_rank == 0 && rhs_rank == 0));
- assert(ofm_rank < 4 && lhs_rank < 4 && rhs_rank < 4);
-
- if (ofm_rank > 3)
- {
- throw std::runtime_error("Not supported, yet");
- }
-
- _builder.addShapeConstr(ofm_index,
- asTensorInfo(ofm_shape, _ctx.at(ofm_index).type(),
- _ctx.at(ofm_index).scale(), _ctx.at(ofm_index).zeroPoint()));
-
- if (lhs_rank > 3)
- {
- throw std::runtime_error("Not supported, yet");
- }
-
- _builder.addShapeConstr(lhs_index,
- asTensorInfo(lhs_shape, _ctx.at(lhs_index).type(),
- _ctx.at(lhs_index).scale(), _ctx.at(lhs_index).zeroPoint()));
-
- if (rhs_rank > 3)
- {
- throw std::runtime_error("Not supported, yet");
- }
-
- _builder.addShapeConstr(rhs_index,
- asTensorInfo(rhs_shape, _ctx.at(rhs_index).type(),
- _ctx.at(rhs_index).scale(), _ctx.at(rhs_index).zeroPoint()));
+ _builder.addShapeConstr(ofm_index, asTensorInfo(ofm_shape, _ctx.at(ofm_index).type()));
+ _builder.addShapeConstr(lhs_index, _ctx.at(lhs_index), lhs_shape, rhs_index, _ctx.at(rhs_index),
+ rhs_shape);
struct Param
{
public:
void addShapeConstr(const ::internal::tflite::operand::Index &ind,
const ::arm_compute::TensorInfo &info) override;
- void addShapeConstr(const ::internal::tflite::operand::Index &ind,
- const ::arm_compute::TensorInfo &info,
- const nnfw::util::tensor::Shape &shape) override;
+ void addShapeConstr(const ::internal::tflite::operand::Index &lhs_ind,
+ const ::internal::tflite::operand::Object &lhs_obj,
+ const nnfw::util::tensor::Shape &lhs_shape,
+ const ::internal::tflite::operand::Index &rhs_ind,
+ const ::internal::tflite::operand::Object &rhs_obj,
+ const nnfw::util::tensor::Shape &rhs_shape) override;
public:
void addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
{
_tensor_info_ctx[ind.asInt()] = info;
}
-void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &ind,
- const ::arm_compute::TensorInfo &info,
- const nnfw::util::tensor::Shape &shape)
-{
- // ACL tensor info
- _tensor_info_ctx[ind.asInt()] = info;
- // broadcasting tensor shape
- internal::tflite::operand::Shape broadcastShape(4);
- if (shape.rank() == 1)
- {
- broadcastShape.dim(0) = 1;
- broadcastShape.dim(1) = 1;
- broadcastShape.dim(2) = 1;
- broadcastShape.dim(3) = shape.dim(0);
- }
- else if (shape.rank() == 2)
- {
- broadcastShape.dim(0) = 1;
- broadcastShape.dim(1) = 1;
- broadcastShape.dim(2) = shape.dim(0);
- broadcastShape.dim(3) = shape.dim(1);
+// Add tensor shape constraints considering broadcasting
+void PlanBuilder::addShapeConstr(const ::internal::tflite::operand::Index &lhs_ind,
+ const ::internal::tflite::operand::Object &lhs_obj,
+ const nnfw::util::tensor::Shape &lhs_shape,
+ const ::internal::tflite::operand::Index &rhs_ind,
+ const ::internal::tflite::operand::Object &rhs_obj,
+ const nnfw::util::tensor::Shape &rhs_shape)
+{
+ // right-side broadcasting
+ if (lhs_shape.rank() > rhs_shape.rank())
+ {
+ // ACL tensor info
+ _tensor_info_ctx[lhs_ind.asInt()] = asTensorInfo(lhs_shape, lhs_obj.type());
+ _tensor_info_ctx[rhs_ind.asInt()] =
+ asTensorInfoForBroadcast(rhs_shape, rhs_obj.type(), lhs_shape.rank());
+
+      // TFLite broadcasting tensor shape
+ if (lhs_shape.rank() == 4)
+ _broadcasting_tensor_shape.emplace(rhs_ind.asInt(),
+ asTensorShapeForTFLiteBroadcast(rhs_shape));
}
- else if (shape.rank() == 3)
+ // left-side broadcasting
+ else if (lhs_shape.rank() < rhs_shape.rank())
{
- broadcastShape.dim(0) = 1;
- broadcastShape.dim(1) = shape.dim(0);
- broadcastShape.dim(2) = shape.dim(1);
- broadcastShape.dim(3) = shape.dim(2);
+ _tensor_info_ctx[lhs_ind.asInt()] =
+ asTensorInfoForBroadcast(lhs_shape, lhs_obj.type(), rhs_shape.rank());
+ _tensor_info_ctx[rhs_ind.asInt()] = asTensorInfo(rhs_shape, rhs_obj.type());
+
+ if (rhs_shape.rank() == 4)
+ _broadcasting_tensor_shape.emplace(lhs_ind.asInt(),
+ asTensorShapeForTFLiteBroadcast(lhs_shape));
}
+ // no broadcasting
else
{
- throw std::runtime_error("Not supported, yet");
+ _tensor_info_ctx[lhs_ind.asInt()] = asTensorInfo(lhs_shape, lhs_obj.type());
+ _tensor_info_ctx[rhs_ind.asInt()] = asTensorInfo(rhs_shape, rhs_obj.type());
}
- _broadcasting_tensor_shape.emplace(ind.asInt(), broadcastShape);
}
void PlanBuilder::addSubsumptionConstr(const ::internal::tflite::operand::Index &ind,
}
}
-// in case of NHWC
+// ACL Broadcasting style in case of NHWC
// TODO HCHW
-::arm_compute::TensorShape asTensorShapeForBroadcast(const nnfw::util::tensor::Shape &shape)
+::arm_compute::TensorShape asTensorShapeForBroadcast(const nnfw::util::tensor::Shape &shape,
+ const size_t baseRank)
{
- if (shape.rank() == 1)
+  // Cases where the larger rank (baseRank) is 4 and the smaller rank is less than 4 must be
+  // transformed to a broadcasting TensorInfo because the dimension order differs.
+ if (baseRank == 4)
{
- return ::arm_compute::TensorShape(1, 1, shape.dim(0), 1);
+ if (shape.rank() == 0)
+ {
+ return ::arm_compute::TensorShape(1);
+ }
+ else if (shape.rank() == 1)
+ {
+ return ::arm_compute::TensorShape(1, 1, shape.dim(0), 1);
+ }
+ else if (shape.rank() == 2)
+ {
+ return ::arm_compute::TensorShape(shape.dim(0), 1, shape.dim(1), 1); // w c -> w h c n
+ }
+ else if (shape.rank() == 3)
+ {
+ return ::arm_compute::TensorShape(shape.dim(1), shape.dim(0), shape.dim(2),
+ 1); // h w c -> w h c n
+ }
+ else if (shape.rank() == 4)
+ {
+ assert(shape.dim(0) ==
+ 1); // In case of ADD, SUB, MUL and DIV at ACL OpenCL, 3D inputs are supported.
+ return ::arm_compute::TensorShape(shape.dim(2), shape.dim(1), shape.dim(3),
+ shape.dim(0)); // n h w c -> W H C N
+ }
+ else
+ {
+ throw std::runtime_error("Not supported, yet");
+ }
}
- else if (shape.rank() == 2)
+  // Other cases, where the larger rank is <= 3, need no transformation because the broadcast
+  // shape is the same as the original. For example, ::arm_compute::TensorShape(shape.dim(0), 1, 1)
+  // == ::arm_compute::TensorShape(shape.dim(0)).
+ else
{
- return ::arm_compute::TensorShape(shape.dim(0), 1, shape.dim(1), 1); // w c -> w h c n
+ return asTensorShape(shape);
}
- else if (shape.rank() == 3)
+}
+
+// TFLite broadcasting style: used for reading an input as broadcasting shape
+internal::tflite::operand::Shape
+asTensorShapeForTFLiteBroadcast(const nnfw::util::tensor::Shape &shape)
+{
+ internal::tflite::operand::Shape broadcastShape(4);
+ if (shape.rank() == 0)
{
- return ::arm_compute::TensorShape(shape.dim(1), shape.dim(0), shape.dim(2),
- 1); // h w c -> w h c n
+ broadcastShape.dim(0) = 1;
+ broadcastShape.dim(1) = 1;
+ broadcastShape.dim(2) = 1;
+ broadcastShape.dim(3) = 1;
}
- else if (shape.rank() == 4)
+ else if (shape.rank() == 1)
{
- assert(shape.dim(0) == 1); // In case of ADD, SUB, 3D inputs are supported.
- return ::arm_compute::TensorShape(shape.dim(2), shape.dim(1), shape.dim(3),
- shape.dim(0)); // n h w c -> W H C N
+ broadcastShape.dim(0) = 1;
+ broadcastShape.dim(1) = 1;
+ broadcastShape.dim(2) = 1;
+ broadcastShape.dim(3) = shape.dim(0);
+ }
+ else if (shape.rank() == 2)
+ {
+ broadcastShape.dim(0) = 1;
+ broadcastShape.dim(1) = 1;
+ broadcastShape.dim(2) = shape.dim(0);
+ broadcastShape.dim(3) = shape.dim(1);
+ }
+ else if (shape.rank() == 3)
+ {
+ broadcastShape.dim(0) = 1;
+ broadcastShape.dim(1) = shape.dim(0);
+ broadcastShape.dim(2) = shape.dim(1);
+ broadcastShape.dim(3) = shape.dim(2);
}
else
{
throw std::runtime_error("Not supported, yet");
}
+ return broadcastShape;
}
inline ::arm_compute::TensorShape asTensorShape(const internal::tflite::operand::Shape &shape)
}
::arm_compute::TensorInfo asTensorInfoForBroadcast(const nnfw::util::tensor::Shape &shape,
- const int32_t type, const float scale = 0.0f,
+ const int32_t type, const size_t baseRank,
+ const float scale = 0.0f,
const int32_t zeroPoint = 0)
{
- return ::arm_compute::TensorInfo(asTensorShapeForBroadcast(shape), 1, asDataType(type),
+ return ::arm_compute::TensorInfo(asTensorShapeForBroadcast(shape, baseRank), 1, asDataType(type),
asQuantizationInfo(scale, zeroPoint));
}