class AvgPool2D : public Op
{
public:
+ enum class Divisor
+ {
+ Unknown,
+ // Use the number of elements in each receptive field as a divisor
+ Static,
+ // Use the number of valid (non-padding) elements in each receptive field as a divisor
+ PaddingExcluded
+ };
+
+public:
explicit AvgPool2D(const PtrLink<Op, Instr> *);
public:
const AvgPool2D *asAvgPool2D(void) const override { return this; }
public:
+ Divisor divisor(void) const { return _divisor; }
+ void divisor(const Divisor &divisor) { _divisor = divisor; }
+
+public:
Window2D *window(void) { return &_window; }
const Window2D *window(void) const { return &_window; }
const PtrLink<Op, Instr> *const _op_link;
private:
+ Divisor _divisor = Divisor::Unknown;
+
Window2D _window;
Stride2D _stride;
Padding2D _pad;
// parent() should be nullptr on construction
ASSERT_EQ(op->parent(), nullptr);
+ // divisor() SHOULD be unknow on construction
+ ASSERT_EQ(immutable_ptr->divisor(), coco::AvgPool2D::Divisor::Unknown);
+
// window() SHOULD return a valid pointer
ASSERT_NE(mutable_ptr->window(), nullptr);
ASSERT_EQ(mutable_ptr->window(), immutable_ptr->window());
ASSERT_TRUE(mutable_ptr->accept(IsAvgPool2D{}));
ASSERT_TRUE(immutable_ptr->accept(IsAvgPool2D{}));
}
+
+TEST_F(AvgPool2DTest, disivor)
+{
+ auto op = allocate();
+
+ op->divisor(coco::AvgPool2D::Divisor::Static);
+
+ ASSERT_EQ(op->divisor(), coco::AvgPool2D::Divisor::Static);
+}