#include "Session.h"
#include <queue>
-
+#include <cmath>
#include <cassert>
namespace
q.push(dst_bag);
}
+// Folds a unary operation whose operand lives in a bag with allocated (constant) data.
+//
+// "evaluate" is invoked as evaluate(float *dst, float src) once per (batch, depth, row, column)
+// element of the feature map. When folding succeeds, the Eval instruction that owns "op" is
+// detached and the destination bag is pushed onto "q" for further visits.
+//
+// Nothing is folded when either the operand or the output is not a feature object.
+template <typename Callable>
+void fold_constant_op(std::queue<coco::Bag *> &q, coco::UnaryOp *op, Callable evaluate)
+{
+  auto data = enco::data(op->module());
+  auto eval = op->parent()->asEval();
+
+  // UnaryOp has only one arg, so take the first object it uses as the operand
+  auto in_obj = *(op->uses().begin());
+  auto in_bag = in_obj->bag();
+
+  auto out_obj = eval->out();
+  auto out_bag = out_obj->bag();
+
+  // The operand is expected to be constant, and the output not yet materialized
+  assert(data->allocated(in_bag));
+  assert(!data->allocated(out_bag));
+
+  // TODO Support other data type
+  auto in_span = data->f32()->weight(in_bag);
+  assert(in_span.data() != nullptr);
+
+  auto in_feature = in_obj->asFeature();
+  auto out_feature = out_obj->asFeature();
+
+  // TODO Support other object type
+  if (!(in_feature != nullptr && out_feature != nullptr))
+  {
+    return;
+  }
+
+  // Allocate weight for destination
+  data->f32()->allocate(out_bag);
+  auto out_span = data->f32()->weight(out_bag);
+
+  auto in_layout = in_feature->layout();
+  auto out_layout = out_feature->layout();
+
+  // Source and destination should have the same shape
+  assert(in_layout->batch() == out_layout->batch());
+  assert(in_layout->depth() == out_layout->depth());
+  assert(in_layout->height() == out_layout->height());
+  assert(in_layout->width() == out_layout->width());
+
+  uint32_t const num_batch = in_layout->batch();
+  uint32_t const num_depth = in_layout->depth();
+  uint32_t const num_row = in_layout->height();
+  uint32_t const num_col = in_layout->width();
+
+  for (uint32_t n = 0; n < num_batch; ++n)
+  {
+    for (uint32_t c = 0; c < num_depth; ++c)
+    {
+      for (uint32_t y = 0; y < num_row; ++y)
+      {
+        for (uint32_t x = 0; x < num_col; ++x)
+        {
+          // Each layout decides where the element is stored inside its own bag
+          auto in_off = in_layout->at(n, c, y, x).value();
+          auto out_off = out_layout->at(n, c, y, x).value();
+
+          evaluate(&out_span[out_off], in_span[in_off]);
+        }
+      }
+    }
+  }
+
+  // The result is materialized now, so the Eval instruction is no longer necessary
+  eval->out(nullptr);
+  eval->detach();
+
+  // Readers of the destination bag may be foldable, too
+  q.push(out_bag);
+}
+
+// Folds a binary element-wise operation when both of its operands are constant.
+//
+// "evaluate" is invoked as evaluate(float *dst, float lhs, float rhs) once per
+// (batch, depth, row, column) element of the feature map. When folding succeeds, the Eval
+// instruction that owns "op" is detached and the destination bag is pushed onto "q".
+//
+// Nothing is folded (and "q" is left untouched) when
+//  - the Eval instruction has no output (already folded via the other operand),
+//  - an operand is not a direct Load,
+//  - an operand bag has no allocated data, or
+//  - any of the involved objects is not a feature object.
+template <typename Callable>
+void fold_constant_op(std::queue<coco::Bag *> &q, coco::BinaryOp *op, Callable evaluate)
+{
+  auto m = op->module();
+  auto d = enco::data(m);
+
+  auto ins = op->parent();
+  auto eval = ins->asEval();
+
+  // Already folded by the other bag
+  if (!eval->out())
+  {
+    return;
+  }
+
+  auto lhs_load = op->left()->asLoad();
+  auto rhs_load = op->right()->asLoad();
+
+  // TODO Support nested operands
+  // Do not dereference an operand that is not a direct Load (e.g. the result of another op)
+  if (lhs_load == nullptr || rhs_load == nullptr)
+  {
+    return;
+  }
+
+  auto lhs_obj = lhs_load->object();
+  auto lhs_bag = lhs_obj->bag();
+
+  auto rhs_obj = rhs_load->object();
+  auto rhs_bag = rhs_obj->bag();
+
+  auto dst_obj = eval->out();
+  auto dst_bag = dst_obj->bag();
+
+  // The other bag is non-constant
+  if (!d->allocated(lhs_bag) || !d->allocated(rhs_bag))
+  {
+    return;
+  }
+
+  // The destination should not be materialized yet
+  assert(!d->allocated(dst_bag));
+
+  // TODO Support other data type
+  auto lhs_span = d->f32()->weight(lhs_bag);
+  auto rhs_span = d->f32()->weight(rhs_bag);
+  assert(lhs_span.data() != nullptr);
+  assert(rhs_span.data() != nullptr);
+
+  auto lhs_feature = lhs_obj->asFeature();
+  auto rhs_feature = rhs_obj->asFeature();
+  auto dst_feature = dst_obj->asFeature();
+
+  // TODO Support other object type
+  if (lhs_feature == nullptr || rhs_feature == nullptr || dst_feature == nullptr)
+  {
+    return;
+  }
+
+  // Allocate weight for destination
+  d->f32()->allocate(dst_bag);
+  auto dst_span = d->f32()->weight(dst_bag);
+
+  // Both operands and the destination should have the same shape
+  assert(lhs_feature->layout()->batch() == rhs_feature->layout()->batch());
+  assert(lhs_feature->layout()->depth() == rhs_feature->layout()->depth());
+  assert(lhs_feature->layout()->height() == rhs_feature->layout()->height());
+  assert(lhs_feature->layout()->width() == rhs_feature->layout()->width());
+
+  assert(lhs_feature->layout()->batch() == dst_feature->layout()->batch());
+  assert(lhs_feature->layout()->depth() == dst_feature->layout()->depth());
+  assert(lhs_feature->layout()->height() == dst_feature->layout()->height());
+  assert(lhs_feature->layout()->width() == dst_feature->layout()->width());
+
+  uint32_t const B = lhs_feature->layout()->batch();
+  uint32_t const C = lhs_feature->layout()->depth();
+  uint32_t const H = lhs_feature->layout()->height();
+  uint32_t const W = lhs_feature->layout()->width();
+
+  for (uint32_t b = 0; b < B; ++b)
+  {
+    for (uint32_t ch = 0; ch < C; ++ch)
+    {
+      for (uint32_t row = 0; row < H; ++row)
+      {
+        for (uint32_t col = 0; col < W; ++col)
+        {
+          // Each layout decides where the element is stored inside its own bag
+          auto lhs_ind = lhs_feature->layout()->at(b, ch, row, col);
+          auto rhs_ind = rhs_feature->layout()->at(b, ch, row, col);
+          auto dst_ind = dst_feature->layout()->at(b, ch, row, col);
+
+          evaluate(&dst_span[dst_ind.value()], lhs_span[lhs_ind.value()],
+                   rhs_span[rhs_ind.value()]);
+        }
+      }
+    }
+  }
+
+  // Let's detach eval
+  eval->out(nullptr);
+  eval->detach();
+
+  // Let's visit destination bag!
+  q.push(dst_bag);
+}
+
+// Folds an Eval instruction according to the kind of operation it evaluates.
+//
+// Sqrt is folded as a unary operation; Add, Sub, Mul and Div are folded as binary operations.
+// Any other operation is left as it is.
+void fold_constant(std::queue<coco::Bag *> &q, coco::Eval *eval)
+{
+  // TODO Support other data types
+  auto root = eval->op();
+
+  if (auto sqrt_op = root->asSqrt())
+  {
+    fold_constant_op(q, sqrt_op, [](float *out, float in) { *out = std::sqrt(in); });
+    return;
+  }
+
+  if (auto add_op = root->asAdd())
+  {
+    fold_constant_op(q, add_op, [](float *out, float a, float b) { *out = a + b; });
+    return;
+  }
+
+  if (auto sub_op = root->asSub())
+  {
+    fold_constant_op(q, sub_op, [](float *out, float a, float b) { *out = a - b; });
+    return;
+  }
+
+  if (auto mul_op = root->asMul())
+  {
+    fold_constant_op(q, mul_op, [](float *out, float a, float b) { *out = a * b; });
+    return;
+  }
+
+  if (auto div_op = root->asDiv())
+  {
+    fold_constant_op(q, div_op, [](float *out, float a, float b) { *out = a / b; });
+    return;
+  }
+
+  // Not supported operation, do nothing
+  // TODO Support other operations
+}
+
void fold_constant(std::queue<coco::Bag *> &q, coco::Instr *ins) // Try each supported instruction kind in turn and fold "ins" with the matching overload
{
if (auto copy = coco::safe_cast<coco::Copy>(ins)) // Copy case; NOTE(review): the opening brace of this branch is not visible in this excerpt
fold_constant(q, copy);
return;
}
+ if (auto eval = coco::safe_cast<coco::Eval>(ins)) // Eval case; presumably safe_cast yields nullptr for other kinds - confirm
+ {
+ fold_constant(q, eval); // Handled by the Eval overload of fold_constant
+ return;
+ }
+
// TODO Add more cases for constant folding
}