class NanCheckOp final : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- NanCheckOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws) {}
+ template <class... Args>
+ explicit NanCheckOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override;
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- WallClockTimeOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws) {}
+ template <class... Args>
+ explicit WallClockTimeOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
int64_t nanoseconds = static_cast<long int>(
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
USE_DISPATCH_HELPER;
- PrintOp(const OperatorDef& operator_def, Workspace* ws)
+ explicit PrintOp(const OperatorDef& operator_def, Workspace* ws)
: Operator<Context>(operator_def, ws),
tensor_printer_(
operator_def.input(0),
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- WeightedSumGradientOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
+ template <class... Args>
+ explicit WeightedSumGradientOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
grad_on_w_(this->template GetSingleArgument<bool>("grad_on_w", false)) {
}
USE_OPERATOR_CONTEXT_FUNCTIONS;
virtual ~ScatterAssignOp() {}
- ScatterAssignOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
+ template <class... Args>
+ explicit ScatterAssignOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
runners_({{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT},
&ScatterAssignOp::DoRun<int32_t, float>},
{{TensorProto_DataType_INT32, TensorProto_DataType_FLOAT16},
class LengthsToWeightsOp : public Operator<Context> {
public:
USE_OPERATOR_CONTEXT_FUNCTIONS;
- LengthsToWeightsOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<Context>(operator_def, ws),
+ template <class... Args>
+ explicit LengthsToWeightsOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
power_(this->template GetSingleArgument<float>("power", 0.5)) {}
bool RunOnDevice() override {
template <typename T, class Context>
class AccumulateHistogramOp : public Operator<Context> {
public:
- AccumulateHistogramOp(const OperatorDef& def, Workspace* ws)
- : Operator<Context>(def, ws),
+ template <class... Args>
+ explicit AccumulateHistogramOp(Args&&... args)
+ : Operator<Context>(std::forward<Args>(args)...),
lower_bound_(
this->template GetSingleArgument<float>("lower_bound", 0.0)),
upper_bound_(
class ThrowExceptionOp : public Operator<CPUContext> {
public:
- ThrowExceptionOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CPUContext>(operator_def, ws),
+ template <class... Args>
+ explicit ThrowExceptionOp(Args&&... args)
+ : Operator<CPUContext>(std::forward<Args>(args)...),
message_(GetSingleArgument<std::string>(
"message",
"Exception from ThrowExceptionOp")) {}
class ThrowChildThreadExceptionOp : public Operator<CPUContext> {
public:
- ThrowChildThreadExceptionOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CPUContext>(operator_def, ws),
+ template <class... Args>
+ explicit ThrowChildThreadExceptionOp(Args&&... args)
+ : Operator<CPUContext>(std::forward<Args>(args)...),
message_(GetSingleArgument<std::string>(
"message",
"Exception from ThrowChildThreadExceptionOp")) {}
class LogFatalOp : public Operator<CPUContext> {
public:
- LogFatalOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CPUContext>(operator_def, ws),
+ template <class... Args>
+ explicit LogFatalOp(Args&&... args)
+ : Operator<CPUContext>(std::forward<Args>(args)...),
message_(GetSingleArgument<std::string>(
"message",
"Logging from LogFatalOp")) {}
class FailOp : public Operator<CPUContext> {
public:
- FailOp(const OperatorDef& operator_def, Workspace* ws)
- : Operator<CPUContext>(operator_def, ws) {}
+ template <class... Args>
+ explicit FailOp(Args&&... args)
+ : Operator<CPUContext>(std::forward<Args>(args)...) {}
bool RunOnDevice() override {
return false;