unsigned int size = predicted.size();
float deriv_val = 1.0f / (float)size;
- deriv.apply_i([deriv_val](float x) {
+ deriv.apply_i<float>([deriv_val](float x) {
if (fabs(x) < EPSILON_) {
return 0.0f;
}
try {
float answer =
NN.inference({MAKE_SHARED_TENSOR(nntrainer::Tensor({o}, nntrainer::TensorDim::TensorType()))})[0]
- ->apply(stepFunction)
+ ->apply<float>(stepFunction)
.getValue(0, 0, 0, 0);
std::cout << answer << " : " << l[0] << std::endl;
return InPlace::RESTRICTING;
}
+  /**
+   * If the layer's input and output type is not FP32, then it cannot be
+   * executed in-place. We assume that the input is always FP32.
+   */
+  if (lnode->getInputConnections().empty()) {
+    if (!istrequal(getTensorType()[3], "FP32"))
+      return InPlace::NONE;
+  }
+
return InPlace::NONE;
}
}
// take exp
- output.apply(exp_util, output);
+ output.apply<float>(exp_util, output);
// take sum over the last dimension
Tensor sum = output.sum(3);
}
Tensor &ActiFunc::swish(Tensor const &t_in, Tensor &t_out) {
- t_in.apply([&](float x) { return sigmoid(x); }, t_out);
+ t_in.apply<float>([&](float x) { return sigmoid(x); }, t_out);
t_out.multiply_i(t_in);
return t_out;
outgoing_derivative = Tensor(t_out.getDim());
Tensor tmp = Tensor(t_out.getDim());
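+  // d/dx swish(x) = sigmoid(x) * (1 - swish(x)) + swish(x), computed below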
- t_in.apply([&](float x) { return sigmoid(x); }, outgoing_derivative);
- t_out.apply([&](float x) { return 1 - x; }, tmp);
+ t_in.apply<float>([&](float x) { return sigmoid(x); }, outgoing_derivative);
+ t_out.apply<float>([&](float x) { return 1 - x; }, tmp);
outgoing_derivative.multiply_i(tmp);
outgoing_derivative.add_i(t_out);
Tensor &ActiFunc::gelu(Tensor const &t_in, Tensor &t_out) {
float tmp = 1 / sqrt(2);
- t_in.apply([&](float x) { return 0.5 * x * (1 + erf(x * tmp)); }, t_out);
+ t_in.apply<float>([&](float x) { return 0.5 * x * (1 + erf(x * tmp)); }, t_out);
return t_out;
}
outgoing_derivative = Tensor(t_out.getDim());
float tmp = 1 / sqrt(2);
- t_in.apply(
+ t_in.apply<float>(
[&](float x) {
return 0.5 * (1 + erf(x * tmp) +
x * ((2 / sqrt(M_PI)) * exp(-pow(x * tmp, 2))) * tmp);
Tensor &y = context.getInput(SINGLE_INOUT_IDX);
// fill the output
- hidden_ = y.apply(ActiFunc::sigmoid, hidden_);
+ hidden_ = y.apply<float>(ActiFunc::sigmoid, hidden_);
if (context.isLabelAvailable(SINGLE_INOUT_IDX)) {
Tensor &y2 = context.getLabel(SINGLE_INOUT_IDX);
// @todo: change this to apply_i
// @note: the output should be logit before applying sigmoid
// log(1 + exp(-abs(y))) + max(y, 0)
- Tensor mid_term = y.apply(static_cast<float (*)(float)>(&std::fabs))
+ Tensor mid_term = y.apply<float>(static_cast<float (*)(float)>(&std::fabs))
.multiply(-1.0)
- .apply(static_cast<float (*)(float)>(&std::exp))
+ .apply<float>(static_cast<float (*)(float)>(&std::exp))
.add(1.0)
- .apply(logFloat);
- mid_term = mid_term.add(y.apply(ActiFunc::relu));
+ .apply<float>(logFloat);
+ mid_term = mid_term.add(y.apply<float>(ActiFunc::relu));
// y * y2
Tensor end_term = y2.chain().multiply_i(y).run();
const Tensor &y2 = context.getIncomingDerivative(SINGLE_INOUT_IDX);
Tensor &y = context.getInput(SINGLE_INOUT_IDX);
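+  // gradient w.r.t. the logits: (sigmoid(y) - y2), scaled by 1 / size below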
- y.apply(ActiFunc::sigmoid, ret_derivative);
+ y.apply<float>(ActiFunc::sigmoid, ret_derivative);
ret_derivative.subtract_i(y2);
if (ret_derivative.divide_i(ret_derivative.size()) != ML_ERROR_NONE) {
throw std::runtime_error("[CrossEntropySigmoidLossLayer::calcDerivative] "
if (context.isLabelAvailable(SINGLE_INOUT_IDX)) {
Tensor &y2 = context.getLabel(SINGLE_INOUT_IDX);
- l = y2.multiply(hidden_.apply(logFloat)).sum_by_batch().multiply(-1);
+ l = y2.multiply(hidden_.apply<float>(logFloat)).sum_by_batch().multiply(-1);
// update the loss value
LossLayer::updateLoss(context, l);
alpha_src.copy_with_stride(
fc_proj_out.getSharedDataTensor({batch, 1, 1, mol_k}, mol_k * 2, false));
- kappa_src.apply_i(&expf);
- beta_src.apply_i(&expf);
+ kappa_src.apply_i<float>(&expf);
+ beta_src.apply_i<float>(&expf);
Tensor kappa = kappa_src;
Tensor beta = beta_src;
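  // accumulate the squared gradient into the second-moment estimate:
  // wv += (1 - beta2) * grad^2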
wv.add_i(x_grad.multiply(x_grad), 1.0f - beta2);
if (torch_ref) {
- Tensor denom = wv.apply(sqrtFloat);
+ Tensor denom = wv.apply<float>(sqrtFloat);
denom.divide_i(sqrtFloat(biasCorrection2));
denom.add_i(epsilon);
wm.divide(denom, x_grad);
return 1 / (sqrtDouble(f) + epsilon);
};
- x_grad = wv.apply(sqrtEps, x_grad);
+ x_grad = wv.apply<float>(sqrtEps, x_grad);
x_grad.multiply_i(wm);
context.applyGradient(getUpdatedLearningRate(context.getIteration(),
context.getLearningRate()));
setDist<float, std::bernoulli_distribution>(
std::bernoulli_distribution(probability));
} else if (this->getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
    setDist<_FP16, std::bernoulli_distribution>(
-      std::bernoulli_distribution((_FP16)probability));
+      std::bernoulli_distribution(probability));
#else
    throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
}
}
Tensor &Tensor::multiply(float const &value, Tensor &out) const {
/// @todo add unittest
- // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+ if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
auto f = std::bind(std::multiplies<float>(), std::placeholders::_1, value);
- return apply(f, out);
-// } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-// #ifdef ENABLE_FP16
-// auto f = std::bind(std::multiplies<_FP16>(), std::placeholders::_1,
-// static_cast<_FP16>(value));
-// return apply(f, out);
-// #else
-// throw std::invalid_argument("Error: enable-fp16 is not enabled");
-// #endif
- // }
+ return apply<float>(f, out);
+ } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ auto f = std::bind(std::multiplies<_FP16>(), std::placeholders::_1,
+ static_cast<_FP16>(value));
+ return apply<_FP16>(f, out);
+#else
+ throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+ }
return out;
}
throw std::invalid_argument(ss.str().c_str());
}
- // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+ if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
auto f = std::bind(std::divides<float>(), std::placeholders::_1, value);
- return apply(f, out);
-// } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-// #ifdef ENABLE_FP16
-// auto f = std::bind(std::divides<_FP16>(), std::placeholders::_1, static_cast<_FP16>(value));
-// return apply(f, out);
-// #else
-// throw std::invalid_argument("Error: enable-fp16 is not enabled");
-// #endif
-// }
+ return apply<float>(f, out);
+ } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ auto f = std::bind(std::divides<_FP16>(), std::placeholders::_1, static_cast<_FP16>(value));
+ return apply<_FP16>(f, out);
+#else
+ throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+ }
return out;
}
Tensor &Tensor::add(float const &value, Tensor &out) const {
/// @todo add unittest
- // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+ if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
auto f = std::bind(std::plus<float>(), std::placeholders::_1, value);
- return apply(f, out);
-// } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-// #ifdef ENABLE_FP16
-// auto f = std::bind(std::plus<_FP16>(), std::placeholders::_1,
-// static_cast<_FP16>(value));
-// return apply(f, out);
-// #else
-// throw std::invalid_argument("Error: enable-fp16 is not enabled");
-// #endif
-// }
+ return apply<float>(f, out);
+ } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ auto f = std::bind(std::plus<_FP16>(), std::placeholders::_1,
+ static_cast<_FP16>(value));
+ return apply<_FP16>(f, out);
+#else
+ throw std::invalid_argument("Error: enable-fp16 is not enabled");
+#endif
+ }
return out;
}
Tensor &Tensor::subtract(float const &value, Tensor &out) const {
/// @todo add unittest
- // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+ if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
auto f = std::bind(std::minus<float>(), std::placeholders::_1, value);
- return apply(f, out);
-// } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-// #ifdef ENABLE_FP16
-// auto f = std::bind(std::minus<_FP16>(), std::placeholders::_1,
-// static_cast<_FP16>(value));
-// return apply(f, out);
-// #else
-// ml_loge("%s", "Error: enable-fp16 is not enabled");
-// #endif
-// }
+ return apply<float>(f, out);
+ } else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ auto f = std::bind(std::minus<_FP16>(), std::placeholders::_1,
+ static_cast<_FP16>(value));
+ return apply<_FP16>(f, out);
+#else
+ ml_loge("%s", "Error: enable-fp16 is not enabled");
+#endif
+ }
return out; // shouldn't reach
}
}
Tensor &Tensor::pow(float exponent, Tensor &out) const {
- // if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
+ if (dim.getDataType() == ml::train::TensorDim::DataType::FP32) {
auto f = [exponent](float in) { return powf(in, exponent); };
- return apply(f, out);
- // }
-// if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
-// #ifdef ENABLE_FP16
-// auto f = [exponent](_FP16 in) {
-// return static_cast<_FP16>(powf(in, exponent));
-// };
-// return apply(f, out);
-// #else
-// ml_loge("%s", "Error: enable-fp16 is not enabled");
-// #endif
-// }
- // return out;
+ return apply<float>(f, out);
+ }
+ if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
+#ifdef ENABLE_FP16
+ auto f = [exponent](_FP16 in) {
+ return static_cast<_FP16>(powf(in, exponent));
+ };
+ return apply<_FP16>(f, out);
+#else
+ ml_loge("%s", "Error: enable-fp16 is not enabled");
+#endif
+ }
+ return out;
}
Tensor Tensor::getBatchSlice(size_t offset, unsigned int size) const {
ret_dims[i].width(), ret_dims[i].channel()};
}
- ret_t.apply_i([&iter_value, &loc, &end_loc, &reset_dim_arr](float _) {
+ ret_t.apply_i<float>([&iter_value, &loc, &end_loc, &reset_dim_arr](float _) {
return iter_value(loc, end_loc, reset_dim_arr);
});
}
ret_dims[i].width(), ret_dims[i].channel()};
}
- ret_t.apply_i([&iter_value, &loc, &end_loc, &reset_dim_arr](float _) {
+ ret_t.apply_i<_FP16>([&iter_value, &loc, &end_loc, &reset_dim_arr](_FP16 _) {
return iter_value(loc, end_loc, reset_dim_arr);
});
}
if (contiguous)
sscal(size(), 0, getData<float>(), 1);
else
- apply_i([](float val) -> float { return 0; });
+ apply_i<float>([](float val) -> float { return 0; });
} else if (dim.getDataType() == ml::train::TensorDim::DataType::FP16) {
#ifdef ENABLE_FP16
if (contiguous)
sscal(size(), 0, getData<_FP16>(), 1);
else
- apply_i([](float val) -> float { return 0; });
+ apply_i<_FP16>([](_FP16 val) -> _FP16 { return 0; });
#else
throw std::invalid_argument("Error: enable-fp16 is not enabled");
#endif
* @param f function to apply
* @return int ML_ERROR_NONE if successful
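+   *
+   * Illustrative usage (a sketch, assuming a float tensor t):
+   * @code
+   *   t.apply_i<float>([](float x) { return x > 0.0f ? x : 0.0f; });
+   * @endcode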
*/
- int apply_i(std::function<float(float)> f) {
+ template <typename T = float>
+ int apply_i(std::function<T(T)> f) {
Tensor result = *this;
- apply(f, result);
+ apply<T>(f, result);
return ML_ERROR_NONE;
};
* @param[in] *function function pointer applied
* @retval Tensor
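+   *
+   * Illustrative usage (a sketch, assuming a float tensor t):
+   * @code
+   *   Tensor squared = t.apply<float>([](float x) { return x * x; });
+   * @endcode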
*/
- Tensor apply(std::function<float(float)> f) const {
+ template <typename T = float>
+ Tensor apply(std::function<T(T)> f) const {
Tensor result;
- return apply(f, result);
+ return apply<T>(f, result);
};
/**
* @param[out] output output tensor
* @retval Tensor
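+   *
+   * Illustrative usage (a sketch, assuming a float tensor t; pass _FP16 as T
+   * for half-precision data when built with ENABLE_FP16):
+   * @code
+   *   Tensor out(t.getDim());
+   *   t.apply<float>([](float x) { return std::tanh(x); }, out);
+   * @endcode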
*/
-
- Tensor &apply(std::function<float(float)> f, Tensor &output) const {
+ template <typename T = float>
+ Tensor &apply(std::function<T(T)> f, Tensor &output) const {
CREATE_IF_EMPTY_DIMS(output, dim, nullptr);
if (dim != output.dim) {
"[Tensor::apply] output dimension does not match");
}
- if (dim.getDataType() == Tdatatype::FP32) {
- if (contiguous && output.contiguous) {
- const float *data = (getData<float>());
- float *rdata = (output.getData<float>());
-
- std::transform(data, data + size(), rdata, f);
- } else if (strides[3] == 1 && output.strides[3] == 1) {
- /** @todo optimize this with combining these loops where stride is 1 */
- for (unsigned int b = 0; b < batch(); ++b) {
- for (unsigned int c = 0; c < channel(); ++c) {
- for (unsigned int h = 0; h < height(); ++h) {
- float *out_data = output.getAddress<float>(b, c, h, 0);
- const float *in_data = getAddress<float>(b, c, h, 0);
- std::transform(in_data, in_data + width(), out_data, f);
- }
+ if (contiguous && output.contiguous) {
+ const T *data = (getData<T>());
+ T *rdata = (output.getData<T>());
+
+ std::transform(data, data + size(), rdata, f);
+ } else if (strides[3] == 1 && output.strides[3] == 1) {
+ /** @todo optimize this with combining these loops where stride is 1 */
+ for (unsigned int b = 0; b < batch(); ++b) {
+ for (unsigned int c = 0; c < channel(); ++c) {
+ for (unsigned int h = 0; h < height(); ++h) {
+ T *out_data = output.getAddress<T>(b, c, h, 0);
+ const T *in_data = getAddress<T>(b, c, h, 0);
+ std::transform(in_data, in_data + width(), out_data, f);
}
}
- } else {
- for (unsigned int b = 0; b < batch(); ++b) {
- for (unsigned int c = 0; c < channel(); ++c) {
- for (unsigned int h = 0; h < height(); ++h) {
- for (unsigned int w = 0; w < width(); ++w) {
- output.setValue(b, c, h, w, f(getValue<float>(b, c, h, w)));
- }
+ }
+ } else {
+ for (unsigned int b = 0; b < batch(); ++b) {
+ for (unsigned int c = 0; c < channel(); ++c) {
+ for (unsigned int h = 0; h < height(); ++h) {
+ for (unsigned int w = 0; w < width(); ++w) {
+ output.setValue(b, c, h, w, f(getValue<T>(b, c, h, w)));
}
}
}
}
- } else if (dim.getDataType() == Tdatatype::FP16) {
-
- auto f_16 = [f](_FP16 x) -> _FP16 {
- return static_cast<_FP16>(f(static_cast<float>(x)));
- };
+ }
- // std::function<_FP16(_FP16)> f_16 =
- // static_cast<std::function<_FP16(_FP16)>>(f);
- if (contiguous && output.contiguous) {
- const _FP16 *data = (getData<_FP16>());
- _FP16 *rdata = (output.getData<_FP16>());
-
- std::transform(data, data + size(), rdata, f_16);
- } else if (strides[3] == 1 && output.strides[3] == 1) {
- /** @todo optimize this with combining these loops where stride is 1 */
- for (unsigned int b = 0; b < batch(); ++b) {
- for (unsigned int c = 0; c < channel(); ++c) {
- for (unsigned int h = 0; h < height(); ++h) {
- _FP16 *out_data = output.getAddress<_FP16>(b, c, h, 0);
- const _FP16 *in_data = getAddress<_FP16>(b, c, h, 0);
- std::transform(in_data, in_data + width(), out_data, f_16);
- }
- }
- }
- } else {
- for (unsigned int b = 0; b < batch(); ++b) {
- for (unsigned int c = 0; c < channel(); ++c) {
- for (unsigned int h = 0; h < height(); ++h) {
- for (unsigned int w = 0; w < width(); ++w) {
- output.setValue(b, c, h, w, f_16(getValue<_FP16>(b, c, h, w)));
- }
- }
- }
- }
- }
- }
return output;
};
/**
* @brief Apply the gradient to the weight
*/
- void applyGradient(double lr) { var->add_i(*grad.get(), -lr); }
+ void applyGradient(double lr) {
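+    // plain gradient step: var <- var - lr * grad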
+ var->add_i(*grad.get(), -lr);
+ }
/**
* @brief Check if the gradient is supposed to be clipped by global norm with
double sqrtDouble(double x) { return sqrt(x); };
float logFloat(float x) { return log(x + 1.0e-20); }
float exp_util(float x) { return exp(x); }
nntrainer::Tensor input(batch, channel, height, width);
GEN_TEST_INPUT(input, (l - 4) * 0.1 * (i + 1));
- nntrainer::Tensor Results = input.apply(nntrainer::ActiFunc::sigmoid);
+ nntrainer::Tensor Results = input.apply<float>(nntrainer::ActiFunc::sigmoid);
float *data = Results.getData();
ASSERT_NE(nullptr, data);
GEN_TEST_INPUT(input, (l - 4) * 0.1 * (i + 1));
nntrainer::Tensor sigmoid_result =
- input.apply(nntrainer::ActiFunc::sigmoid);
+ input.apply<float>(nntrainer::ActiFunc::sigmoid);
float *data = sigmoid_result.getData();
ASSERT_NE(nullptr, data);
nntrainer::Tensor prime_result =
- sigmoid_result.apply(nntrainer::ActiFunc::sigmoidPrime);
+ sigmoid_result.apply<float>(nntrainer::ActiFunc::sigmoidPrime);
data = prime_result.getData();
ASSERT_NE(nullptr, data);
nntrainer::Tensor input(batch, channel, height, width);
GEN_TEST_INPUT(input, (l - 4) * 0.1 * (i + 1));
- nntrainer::Tensor Results = input.apply(nntrainer::ActiFunc::tanhFloat);
+ nntrainer::Tensor Results = input.apply<float>(nntrainer::ActiFunc::tanhFloat);
float *data = Results.getData();
ASSERT_NE(nullptr, data);
nntrainer::Tensor input(batch, channel, height, width);
GEN_TEST_INPUT(input, (l - 4) * 0.1 * (i + 1));
- nntrainer::Tensor tanh_result = input.apply(nntrainer::ActiFunc::tanhFloat);
+ nntrainer::Tensor tanh_result = input.apply<float>(nntrainer::ActiFunc::tanhFloat);
float *data = tanh_result.getData();
ASSERT_NE(nullptr, data);
nntrainer::Tensor prime_result =
- tanh_result.apply(nntrainer::ActiFunc::tanhPrime);
+ tanh_result.apply<float>(nntrainer::ActiFunc::tanhPrime);
data = prime_result.getData();
ASSERT_NE(nullptr, data);
nntrainer::Tensor input(batch, channel, height, width);
GEN_TEST_INPUT(input, (l - 4) * 0.1 * (i + 1));
- nntrainer::Tensor Results = input.apply(nntrainer::ActiFunc::relu);
+ nntrainer::Tensor Results = input.apply<float>(nntrainer::ActiFunc::relu);
float *data = Results.getData();
ASSERT_NE(nullptr, data);
nntrainer::Tensor input(batch, channel, height, width);
GEN_TEST_INPUT(input, (l - 4) * 0.1 * (i + 1));
- nntrainer::Tensor relu_result = input.apply(nntrainer::ActiFunc::relu);
+ nntrainer::Tensor relu_result = input.apply<float>(nntrainer::ActiFunc::relu);
float *data = relu_result.getData();
ASSERT_NE(nullptr, data);
nntrainer::Tensor prime_result =
- relu_result.apply(nntrainer::ActiFunc::reluPrime);
+ relu_result.apply<float>(nntrainer::ActiFunc::reluPrime);
data = prime_result.getData();
ASSERT_NE(nullptr, data);
*/
nntrainer::Tensor constant_(float value) {
nntrainer::Tensor t(batch, channel, height, width);
- return t.apply([value](float) { return value; });
+ return t.apply<float>([value](float) { return value; });
}
nntrainer::Tensor target;
nntrainer::Tensor input(batch, channel, height, width);
GEN_TEST_INPUT(input, i * (width) + k + 1);
- nntrainer::Tensor Results = input.apply(nntrainer::logFloat);
+ nntrainer::Tensor Results = input.apply<float>(nntrainer::logFloat);
float *data = Results.getData();
ASSERT_NE(nullptr, data);