/**
 * @brief Evaluate an element-wise addition node.
 *
 * Input validation, per-element evaluation, and output annotation are all
 * delegated to eltwise_binary(); this function only supplies the scalar "+"
 * operation for FLOAT32 data.
 */
void NodeExecution::execute(loco::EltwiseAdd *eltwise_add)
{
  // Scalar functor: element-wise sum of the two operands
  struct Func final : public BinaryFunc
  {
    float apply(float lhs, float rhs) const override { return lhs + rhs; }
  };

  Func f;

  eltwise_binary(eltwise_add, f);
}
} // namespace locomotiv
/**
 * @brief Evaluate an element-wise division node.
 *
 * Input validation, per-element evaluation, and output annotation are all
 * delegated to eltwise_binary(); this function only supplies the scalar "/"
 * operation for FLOAT32 data.
 *
 * @note Division by zero is not guarded here; it follows IEEE-754 float
 *       semantics of the underlying "/" operator.
 */
void NodeExecution::execute(loco::EltwiseDiv *eltwise_div)
{
  // Scalar functor: element-wise quotient (lhs / rhs)
  struct Func final : public BinaryFunc
  {
    float apply(float lhs, float rhs) const override { return lhs / rhs; }
  };

  Func f;

  eltwise_binary(eltwise_div, f);
}
} // namespace locomotiv
/**
 * @brief Evaluate an element-wise multiplication node.
 *
 * Input validation, per-element evaluation, and output annotation are all
 * delegated to eltwise_binary(); this function only supplies the scalar "*"
 * operation for FLOAT32 data.
 */
void NodeExecution::execute(loco::EltwiseMul *eltwise_mul)
{
  // Scalar functor: element-wise product of the two operands
  struct Func final : public BinaryFunc
  {
    float apply(float lhs, float rhs) const override { return lhs * rhs; }
  };

  Func f;

  eltwise_binary(eltwise_mul, f);
}
} // namespace locomotiv
/**
 * @brief Evaluate an element-wise subtraction node.
 *
 * Input validation, per-element evaluation, and output annotation are all
 * delegated to eltwise_binary(); this function only supplies the scalar "-"
 * operation for FLOAT32 data.
 */
void NodeExecution::execute(loco::EltwiseSub *eltwise_sub)
{
  // Scalar functor: element-wise difference (lhs - rhs)
  struct Func final : public BinaryFunc
  {
    float apply(float lhs, float rhs) const override { return lhs - rhs; }
  };

  Func f;

  eltwise_binary(eltwise_sub, f);
}
} // namespace locomotiv
// Default UnaryFunc implementations: a data type remains unsupported until a
// concrete subclass overrides the matching overload.
float UnaryFunc::apply(float) const
{
  throw std::runtime_error{"F32 is not supported yet"};
}

int32_t UnaryFunc::apply(int32_t) const
{
  throw std::runtime_error{"S32 is not supported yet"};
}
+float BinaryFunc::apply(float, float) const
+{
+ throw std::runtime_error{"F32 is not supported yet"};
+}
+
+int32_t BinaryFunc::apply(int32_t, int32_t) const
+{
+ throw std::runtime_error{"S32 is not supported yet"};
+}
+
// TODO Use visitor pattern of loco when available
// NOTE(review): only a fragment of this dispatcher is visible in this chunk;
// `output_node` and `output_domain` are defined in code outside this view,
// so the surrounding dispatch logic cannot be described from here — confirm
// against the full file.
void NodeExecution::run(loco::Node *node)
{
annot_domain(output_node, output_domain);
}
+void NodeExecution::eltwise_binary(loco::Node *node, const BinaryFunc &f)
+{
+ auto lhs_node = node->arg(0);
+ auto rhs_node = node->arg(1);
+ auto lhs_data = annot_data(lhs_node);
+ auto rhs_data = annot_data(rhs_node);
+
+ validate(lhs_data && rhs_data, "Input not ready");
+ validate(annot_domain(lhs_node) == annot_domain(rhs_node), "Wrong input domain");
+ validate(lhs_data->dtype() == rhs_data->dtype(), "Wrong input type");
+ validate(*lhs_data->shape() == *rhs_data->shape(), "Wrong input shape");
+
+ auto out_node = node;
+ std::unique_ptr<NodeData> out_data = nullptr;
+
+ switch (lhs_data->dtype())
+ {
+ case loco::DataType::FLOAT32:
+ {
+ auto lhs_bufptr = lhs_data->as_f32_bufptr();
+ auto rhs_bufptr = rhs_data->as_f32_bufptr();
+ auto out_bufptr = make_buffer<float, LexicalLayout>(*lhs_data->shape());
+
+ auto *shape = lhs_data->shape();
+
+ for (IndexEnumerator e{*shape}; e.valid(); e.advance())
+ {
+ const auto &index = e.current();
+ out_bufptr.at(index) = f.apply(lhs_bufptr->at(index), rhs_bufptr->at(index));
+ }
+
+ out_data = make_data(out_bufptr);
+ break;
+ }
+ default:
+ throw std::runtime_error("NYI for this DataType");
+ }
+
+ assert(out_data != nullptr);
+ erase_annot_data(out_node);
+ annot_data(out_node, std::move(out_data));
+ annot_domain(out_node, annot_domain(lhs_node));
+}
+
} // namespace locomotiv
virtual int32_t apply(int32_t) const;
};
// Q. How to support mixed precision binary operators?
/**
 * @brief Scalar binary operation applied element-wise by eltwise_binary.
 *
 * Subclasses override the overload(s) for the data type(s) they support;
 * the default implementations throw std::runtime_error ("... not supported
 * yet") for every type.
 */
struct BinaryFunc
{
  virtual ~BinaryFunc() = default;

  virtual float apply(float, float) const;
  virtual int32_t apply(int32_t, int32_t) const;
};
+
/**
* @brief Helper class for Session, responsible to process one node calculation.
*/
#undef NODE
void eltwise_unary(loco::Node *node, const UnaryFunc &f);
+ void eltwise_binary(loco::Node *node, const BinaryFunc &f);
};
} // namespace locomotiv