+++ /dev/null
-/*
- * Copyright (c) 2018 Samsung Electronics Co., Ltd. All Rights Reserved
- *
- * Licensed under the Apache License, Version 2.0 (the "License");
- * you may not use this file except in compliance with the License.
- * You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-#ifndef _MIR_OPS_BATCH_NORM_H_
-#define _MIR_OPS_BATCH_NORM_H_
-
-#include "mir/Operation.h"
-
-namespace mir
-{
-namespace ops
-{
-
-class BatchNormOp : public Operation
-{
-public:
- BatchNormOp(Output *arg, float moving_avg_fraction, float eps, bool spatial)
- : Operation(Type::batchNorm, {arg}), _moving_avg_fraction(moving_avg_fraction), _eps(eps),
- _spatial(spatial)
- {
- // Infer output shape.
- setOutputShape(0, getInputShape(0));
- }
-
- Operation *copyWithInputs(const std::vector<Output *> &inputs) override
- {
- return new BatchNormOp(inputs[0], _moving_avg_fraction, _eps, _spatial);
- }
-
- /**
- * @return The epsilon value to use to avoid division by zero.
- */
- float getEps() const { return _eps; }
-
- /**
- * @return Factor used in computing the running mean and variance.
- * e.g., running_mean = running_mean * movingAvgFraction + mean * (1 - movingAvgFraction).
- */
- float getMovingAvgFraction() const { return _moving_avg_fraction; }
-
- /**
- * @return If true, compute the mean and variance across all spatial elements If false, compute
- * the mean and variance per feature.
- */
- bool getSpatial() const { return _spatial; }
-
-private:
- float _moving_avg_fraction;
- float _eps;
- bool _spatial;
-};
-
-} // namespace ops
-} // namespace mir
-
-#endif //_MIR_OPS_BATCH_NORM_H_
_dot_builder.updateWithOp(&op, node_info);
}
-void IrDotDumper::visit(ops::BatchNormOp &op)
-{
- auto nodeInfo = DotIrNodeInfo()
- .withType("BatchNorm", op.getName())
- .withInShapes(getInputShapes(op))
- .withOutShapes(getOutputShapes(op))
- .withMisc("Moving Average Fraction", op.getMovingAvgFraction())
- .withMisc("Eps", op.getEps())
- .withMisc("Spatial", op.getSpatial());
- _dot_builder.updateWithOp(&op, nodeInfo);
-}
-
void IrDotDumper::visit(ops::SliceOp &op)
{
auto node_info = DotIrNodeInfo()