const float* X_data = X.data<float>();
const float* Y_data = Y.data<float>();
// Auxiliary arrays, one allocation of memory
- aux_.Resize(2 * N);
+ ReinitializeTensor(&aux_, {2 * N}, at::dtype<float>().device(CUDA));
float* aux_data = aux_.mutable_data<float>();
float* x2 = aux_data;
float* y2 = aux_data + N;
auto* dY_data = dY->template mutable_data<float>();
// one memory allocation, a few arrays
- aux_.Resize(6 * N);
+ ReinitializeTensor(&aux_, {6 * N}, at::dtype<float>().device(CUDA));
float* aux_data = aux_.mutable_data<float>();
float* xn = aux_data;
float* yn = aux_data + N;
OUTPUT_TAGS(COS_OUT);
private:
- Tensor aux_{Context::GetDeviceType()};
+ Tensor aux_;
};
template <typename T, class Context>
OUTPUT_TAGS(DER_X_OUT, DER_Y_OUT);
private:
- Tensor aux_{Context::GetDeviceType()};
+ Tensor aux_;
};
template <typename T, class Context>
math_type);
// Add bias term
Tensor bias_multiplier(cache->bias_multiplier_);
- if (bias_multiplier.numel() != M) {
- // If the helper bias multiplier is not M, reshape and fill it with one.
- bias_multiplier.Resize(M);
- caffe2::math::Set<DataType, Context>(
- M,
- caffe2::convert::To<float, DataType>(1),
- bias_multiplier.template mutable_data<DataType>(),
- static_cast<Context*>(&context));
- }
// NOTE(review): the ones-fill below now runs on every call; the code it
// replaces only refilled when bias_multiplier.numel() != M. Confirm the
// extra per-call Set is intended.
+ ReinitializeTensor(&bias_multiplier, {M}, at::dtype<DataType>().device(CPU));
+ caffe2::math::Set<DataType, Context>(
+ M,
+ caffe2::convert::To<float, DataType>(1),
+ bias_multiplier.template mutable_data<DataType>(),
+ static_cast<Context*>(&context));
caffe2::math::Gemm<DataType, Context, caffe2::DefaultEngine>(
CblasNoTrans,
CblasNoTrans,
// State bundle holding a shape vector and a bias-multiplier tensor.
// NOTE(review): presumably persisted across kernel invocations — confirm
// against the code that owns this struct.
struct Cache final {
// Presumably the cached output shape (name-based; confirm).
vector<int64_t> Y_shape_cache_;
// The initializer no longer passes a device: it wraps a
// default-constructed Tensor instead of Tensor{CPU}.
- C10Tensor bias_multiplier_ = C10Tensor(Tensor{CPU});
+ C10Tensor bias_multiplier_ = C10Tensor(Tensor());
};
using Signature = void(
#ifndef CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
#define CAFFE2_OPERATORS_FULLY_CONNECTED_OP_H_
+#include <c10/util/Optional.h>
#include "caffe2/core/context.h"
#include "caffe2/core/operator.h"
#include "caffe2/utils/conversions.h"
&context_,
math_type);
// Add bias term
- if (bias_multiplier_.numel() != M) {
- // If the helper bias multiplier is not M, reshape and fill it with one.
- bias_multiplier_.Resize(M);
+ if (!bias_multiplier_.has_value()) {
+ bias_multiplier_ = caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
math::Set<T_B, Context>(
M,
convert::To<float, T_B>(1),
- bias_multiplier_.template mutable_data<T_B>(),
+ bias_multiplier_->template mutable_data<T_B>(),
+ &context_);
+ } else if (bias_multiplier_->numel() != M) {
+ bias_multiplier_->Resize(M);
+ math::Set<T_B, Context>(
+ M,
+ convert::To<float, T_B>(1),
+ bias_multiplier_->template mutable_data<T_B>(),
&context_);
}
+
math::Gemm<T_B, Context, Engine>(
CblasNoTrans,
CblasNoTrans,
N,
1,
1,
- bias_multiplier_.template data<T_B>(),
+ bias_multiplier_->template data<T_B>(),
b.template data<T_B>(),
1,
Y->template mutable_data<T_Y>(),
// A local vector to cache the output shape so we don't need to recreate
// a vector object every time we run Run().
vector<int64_t> Y_shape_cache_;
- Tensor bias_multiplier_{Context::GetDeviceType()};
+ c10::optional<Tensor> bias_multiplier_;
bool float16_compute_;
};
dW->template mutable_data<T_DW>(),
&context_,
math_type);
- if (bias_multiplier_.numel() != M) {
- // If the helper bias multiplier is not M, reshape and fill it
- // with one.
- bias_multiplier_.Resize(M);
+ if (!bias_multiplier_.has_value()) {
+ bias_multiplier_ = caffe2::empty({M}, at::dtype<T_B>().device(Context::GetDeviceType()));
+ math::Set<T_B, Context>(
+ M,
+ convert::To<float, T_B>(1),
+ bias_multiplier_->template mutable_data<T_B>(),
+ &context_);
+ } else if (bias_multiplier_->numel() != M) {
+ bias_multiplier_->Resize(M);
math::Set<T_B, Context>(
M,
convert::To<float, T_B>(1),
- bias_multiplier_.template mutable_data<T_B>(),
+ bias_multiplier_->template mutable_data<T_B>(),
&context_);
}
// Compute dB
N,
1,
dY.template data<T_DY>(),
- bias_multiplier_.template data<T_B>(),
+ bias_multiplier_->template data<T_B>(),
0,
db->template mutable_data<T_DB>(),
&context_);
protected:
size_t axis_{1};
size_t axis_w_{1};
- Tensor bias_multiplier_{Context::GetDeviceType()};
+ c10::optional<Tensor> bias_multiplier_;
bool float16_compute_;
};
<< " given size: " << source_values.size();
auto str = source_values[0];
- values_.Resize(str.size());
+ ReinitializeTensor(&values_, {static_cast<int64_t>(str.size())}, at::dtype<uint8_t>().device(CPU));
uint8_t* values_data = values_.template mutable_data<uint8_t>();
for (int i = 0; i < str.size(); i++) {
values_data[i] = static_cast<uint8_t>(str[i]);
}
- }
+}
- Tensor values_{CPU};
+Tensor values_;
};
} // namespace caffe2
void ExtractValues() {
auto source_values =
this->template GetRepeatedArgument<Type>("values");
- values_.Resize(source_values.size());
+ ReinitializeTensor(&values_, {static_cast<int64_t>(source_values.size())}, at::dtype<Type>().device(CPU));
Type* values_data = values_.template mutable_data<Type>();
for (int i = 0; i < source_values.size(); i++) {
values_data[i] = static_cast<Type>(source_values[i]);
}
bool (GivenTensorFillOp::*body_)(Tensor* output);
- Tensor values_{CPU};
+ Tensor values_;
};
} // namespace caffe2
// dL/ds = Sum(dL/dY * gamma * X)
// dL/db = Sum(dL/dY * gamma)
const int C = G * D;
- ds_.Resize(N, G);
- db_.Resize(N, G);
+ ReinitializeTensor(
+ &ds_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
+ ReinitializeTensor(
+ &db_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
T* ds_data = ds_.template mutable_data<T>();
T* db_data = db_.template mutable_data<T>();
math::Set<T, Context>(N * G, T(0), ds_data, &context_);
float* dbeta_data) {
const int size = N * G * D * HxW;
const int C = G * D;
- ds_.Resize(N, G);
- db_.Resize(N, G);
+ ReinitializeTensor(
+ &ds_, {N, G}, at::dtype<float>().device(CUDA));
+ ReinitializeTensor(
+ &db_, {N, G}, at::dtype<float>().device(CUDA));
float* ds_data = ds_.mutable_data<float>();
float* db_data = db_.mutable_data<float>();
if (order_ == StorageOrder::NCHW) {
mu_data = mu->template mutable_data<T>();
rsig_data = rsig->template mutable_data<T>();
} else {
- mu_.Resize(N, G);
- rsig_.Resize(N, G);
+ ReinitializeTensor(&mu_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
+ ReinitializeTensor(&rsig_, {N, G}, at::dtype<T>().device(Context::GetDeviceType()));
mu_data = mu_.template mutable_data<T>();
rsig_data = rsig_.template mutable_data<T>();
}
T* mu,
T* rsig) {
const int C = G * D;
- scale_.Resize(N, C);
- bias_.Resize(N, C);
+ ReinitializeTensor(&scale_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
+ ReinitializeTensor(&bias_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
T* scale_data = scale_.template mutable_data<T>();
T* bias_data = bias_.template mutable_data<T>();
if (order_ == StorageOrder::NCHW) {
const StorageOrder order_;
const bool is_test_;
- Tensor mu_{Context::GetDeviceType()};
- Tensor rsig_{Context::GetDeviceType()};
- Tensor scale_{Context::GetDeviceType()};
- Tensor bias_{Context::GetDeviceType()};
+ Tensor mu_;
+ Tensor rsig_;
+ Tensor scale_;
+ Tensor bias_;
// Input: X, gamma, beta
// Output: Y, mu, inv_sig
const int group_;
const StorageOrder order_;
- Tensor ds_{Context::GetDeviceType()};
- Tensor db_{Context::GetDeviceType()};
+ Tensor ds_;
+ Tensor db_;
// Input: dY, X, gamma, beta, mu, inv_sig
// Output: dX, dgamma, dbeta
//Softmax
float* softmax_output_data = int_output + int_output_offset;
- if (scale_.numel() != 1) {
- scale_.Resize(1);
+ if (!scale_.has_value()) {
+ scale_ = caffe2::empty({1}, at::dtype<float>().device(CPU));
}
- if (sum_multiplier_.numel() != dim_out) {
- sum_multiplier_.Resize(dim_out);
+
+ if (!sum_multiplier_.has_value()) {
+ sum_multiplier_ = caffe2::empty({dim_out}, at::dtype<float>().device(CPU));
+ math::Set<float, CPUContext>(dim_out, 1.f,
+ sum_multiplier_->mutable_data<float>(), &context_);
+ } else if (sum_multiplier_->numel() != dim_out) {
+ sum_multiplier_->Resize(dim_out);
math::Set<float, CPUContext>(dim_out, 1.f,
- sum_multiplier_.mutable_data<float>(), &context_);
+ sum_multiplier_->mutable_data<float>(), &context_);
}
math::RowwiseMax<float, CPUContext>(1, dim_out, fc_output_data,
- scale_.mutable_data<float>(), &context_);
+ scale_->mutable_data<float>(), &context_);
// Put the intermediate result X - max(X) into Y
context_.template CopyFromCPU<float>(
dim_out, fc_output_data, softmax_output_data);
// Subtract the scale
math::Gemv<float, CPUContext>(CblasNoTrans, dim_out, 1, -1,
- sum_multiplier_.data<float>(), scale_.data<float>(), 1, softmax_output_data,
+ sum_multiplier_->data<float>(), scale_->data<float>(), 1, softmax_output_data,
&context_);
// Exponentiation
math::Exp<float, CPUContext>(dim_out, softmax_output_data,
softmax_output_data, &context_);
math::Gemv<float, CPUContext>(CblasNoTrans, 1, dim_out, 1,
- softmax_output_data, sum_multiplier_.data<float>(), 0,
- scale_.mutable_data<float>(), &context_);
+ softmax_output_data, sum_multiplier_->data<float>(), 0,
+ scale_->mutable_data<float>(), &context_);
// Do division
- const float scale = *scale_.data<float>();
+ const float scale = *(scale_->data<float>());
for (int j = 0; j < dim_out; ++j) {
softmax_output_data[j] /= scale;
}
float* int_output_data = intermediate_output->template mutable_data<float>();
int int_output_offset = 0;
- if (bias_multiplier_.numel() != M) {
- bias_multiplier_.Resize(M);
+ if (!bias_multiplier_.has_value()) {
+ bias_multiplier_ = caffe2::empty({M}, at::dtype<float>().device(CPU));
+ math::Set<float, CPUContext>(M, static_cast<float>(1),
+ bias_multiplier_->mutable_data<float>(), &context_);
+ } else if (bias_multiplier_->numel() != M) {
+ bias_multiplier_->Resize(M);
math::Set<float, CPUContext>(M, static_cast<float>(1),
- bias_multiplier_.mutable_data<float>(), &context_);
+ bias_multiplier_->mutable_data<float>(), &context_);
}
for (int sample = 0; sample < M; ++sample) {
//Adding log probabilities
Ydata[sample] += RunForwardSingle(X.data<float>() + sample*K,
W.data<float>() + w_offset*K, b.data<float>() + w_offset, target,
- int_output_data, bias_multiplier_.data<float>()+sample, w_length, K,
+ int_output_data, bias_multiplier_->data<float>()+sample, w_length, K,
int_output_offset);
}
}
int_output_offset -= dim_out;
//Softmax
- if (scale_.numel() != 1) {
- scale_.Resize(1);
+ if (!scale_.has_value()) {
+ scale_ = caffe2::empty({1}, at::dtype<float>().device(CPU));
}
- float* scaledata = scale_.mutable_data<float>();
+ float* scaledata = scale_->mutable_data<float>();
- if (sum_multiplier_.numel() != dim_out) {
- sum_multiplier_.Resize(dim_out);
+ if (!sum_multiplier_.has_value()) {
+ sum_multiplier_ = caffe2::empty({dim_out}, at::dtype<float>().device(CPU));
math::Set<float, CPUContext>(dim_out, 1.f,
- sum_multiplier_.mutable_data<float>(), &context_);
+ sum_multiplier_->mutable_data<float>(), &context_);
+ } else if (sum_multiplier_->numel() != dim_out) {
+ sum_multiplier_->Resize(dim_out);
+ math::Set<float, CPUContext>(dim_out, 1.f,
+ sum_multiplier_->mutable_data<float>(), &context_);
}
float* dX_softmax = dint_output + int_output_offset - dim_out;
math::Dot<float, CPUContext>(dim_out, X_entropy, dX_entropy, scaledata,
&context_);
math::Gemv<float, CPUContext>(CblasTrans, 1, dim_out, -1,
- sum_multiplier_.data<float>(), scaledata , 1, dX_softmax, &context_);
+ sum_multiplier_->data<float>(), scaledata , 1, dX_softmax, &context_);
math::Mul<float, CPUContext>(dim_out, dX_softmax, X_entropy, dX_softmax,
&context_);
int_output_offset -= dim_out;
//FC
- if (bias_multiplier_.numel() != 1) {
+ if (!bias_multiplier_.has_value()) {
// If the helper bias multiplier has not been created, reshape and fill
// it with 1
- bias_multiplier_.Resize(1);
+ bias_multiplier_ = caffe2::empty({1}, at::dtype<float>().device(CPU));
math::Set<float, CPUContext>(1, static_cast<float>(1),
- bias_multiplier_.template mutable_data<float>(), &context_);
+ bias_multiplier_->template mutable_data<float>(), &context_);
}
// Compute dW and add incrementally
// Compute dB and add incrementally
// db = db + dX_softmax*bias_multiplier_
math::Gemv<float, CPUContext>(CblasTrans, 1, dim_out, 1, dX_softmax,
- bias_multiplier_.template data<float>(), 1, db, &context_);
+ bias_multiplier_->template data<float>(), 1, db, &context_);
// Compute dX and add incrementally
// dX = dX + W'dX_softmax
b + w_offset,
-1,
int_output_data,
- bias_multiplier_.template data<float>() + sample,
+ bias_multiplier_->template data<float>() + sample,
w_length,
K,
int_output_offset);
auto* Y_names = Output(0, {M, top_n_}, at::dtype<string>());
auto* Y_scores = Output(1, {M, top_n_}, at::dtype<float>());
- if (bias_multiplier_.numel() != M) {
- bias_multiplier_.Resize(M);
- math::Set<float, CPUContext>(
- M,
- static_cast<float>(1),
- bias_multiplier_.mutable_data<float>(),
- &context_);
+ if (!bias_multiplier_.has_value()) {
+ bias_multiplier_ = caffe2::empty({M}, at::dtype<float>().device(CPU));
+ math::Set<float, CPUContext>(M, static_cast<float>(1),
+ bias_multiplier_->mutable_data<float>(), &context_);
+ } else if (bias_multiplier_->numel() != M) {
+ bias_multiplier_->Resize(M);
+ math::Set<float, CPUContext>(M, static_cast<float>(1),
+ bias_multiplier_->mutable_data<float>(), &context_);
}
for (int sample = 0; sample < M; ++sample) {
#ifndef CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
#define CAFFE2_OPERATORS_H_SOFTMAX_OP_H_
+#include <c10/util/Optional.h>
#include "caffe2/core/context.h"
#include "caffe2/core/logging.h"
#include "caffe2/core/operator.h"
protected:
std::unordered_map<int, PathProto> hierarchy_all_map_;
- Tensor scale_{Context::GetDeviceType()};
- Tensor sum_multiplier_{Context::GetDeviceType()};
- Tensor bias_multiplier_{Context::GetDeviceType()};
+ c10::optional<Tensor> scale_;
+ c10::optional<Tensor> sum_multiplier_;
+ c10::optional<Tensor> bias_multiplier_;
static constexpr T kLOG_THRESHOLD() {
return 1e-20f;
}
// Resize before we get into the per-instance loop
if (InputSize() < 5) {
- mean_.Resize(N, C);
+ ReinitializeTensor(
+ &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
}
if (InputSize() < 6) {
- inv_stdev_.Resize(N, C);
+ ReinitializeTensor(
+ &inv_stdev_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
}
// looping over per-instance and using Eigen blocks to extract out
// Compute mean if it wasn't passed in
if (InputSize() < 5) {
- mean_.Resize(N, C);
+ ReinitializeTensor(
+ &mean_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
EigenVectorArrayMap<T> mean_mutable_arr(
mean_.template mutable_data<T>(), N * C);
mean_mutable_arr = input_mat.colwise().mean();
// compute 1 / stdev if not passed in
if (InputSize() < 6) {
- inv_stdev_.Resize(N, C);
+ ReinitializeTensor(
+ &inv_stdev_, {N, C}, at::dtype<T>().device(Context::GetDeviceType()));
EigenVectorArrayMap<T> inv_stdev_mutable_arr(
inv_stdev_.template mutable_data<T>(), N * C);
const auto dim_stride = C;
if (InputSize() < 5) {
- mean_.Resize(N, C);
+ ReinitializeTensor(&mean_, {N, C}, at::dtype<float>().device(CUDA));
auto mean_mutable_data = mean_.mutable_data<float>();
InstanceNormMeanKernel<<<
CAFFE_GET_BLOCKS(N * C),
const auto mean_data = mean.data<float>();
if (InputSize() < 6) {
- inv_stdev_.Resize(N, C);
+ ReinitializeTensor(&inv_stdev_, {N, C}, at::dtype<float>().device(CUDA));
auto inv_stdev_mutable_data = inv_stdev_.mutable_data<float>();
InstanceNormInvStdevKernel<<<
CAFFE_GET_BLOCKS(N * C),
// temp results that could get passed through to this gradient, but if not,
// are stored here
- Tensor mean_{Context::GetDeviceType()};
- Tensor inv_stdev_{Context::GetDeviceType()};
+ Tensor mean_;
+ Tensor inv_stdev_;
INPUT_TAGS(INPUT, SCALE, BIAS, OUTPUT_GRAD, MEAN, INV_STDEV);
OUTPUT_TAGS(INPUT_GRAD, SCALE_GRAD, BIAS_GRAD);
// Col pass reduces shape to (N, C, H, W)
vector<int64_t> row_pass_shape(dY.sizes().vec());
row_pass_shape[3] -= 1;
- row_pass_buffer_.Resize(row_pass_shape);
+ ReinitializeTensor(&row_pass_buffer_, row_pass_shape, at::dtype<float>().device(CUDA));
const int chans = row_pass_buffer_.dim32(1);
const int rows_out = row_pass_buffer_.dim32(2);
const int cols_out = row_pass_buffer_.dim32(3);
bool RunOnDevice() override;
protected:
- Tensor row_pass_buffer_{Context::GetDeviceType()};
+ Tensor row_pass_buffer_;
};
} // namespace caffe2
const int N = X.size_from_dim(canonical_axis);
auto* dX = Output(0, X.sizes(), at::dtype<T>());
- ds_.Resize(M);
- db_.Resize(M);
- dY_scale_.Resize(M);
- X_scale_.Resize(M);
- bias_.Resize(M);
+ ReinitializeTensor(&ds_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+ ReinitializeTensor(&db_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+ ReinitializeTensor(&dY_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+ ReinitializeTensor(&X_scale_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
+ ReinitializeTensor(&bias_, {M}, at::dtype<T>().device(Context::GetDeviceType()));
const T* dY_data = dY.template data<T>();
const T* X_data = X.template data<T>();
const T* mean_data = mean.template data<T>();
const int axis_;
- Tensor ds_{Context::GetDeviceType()};
- Tensor db_{Context::GetDeviceType()};
- Tensor dY_scale_{Context::GetDeviceType()};
- Tensor X_scale_{Context::GetDeviceType()};
- Tensor bias_{Context::GetDeviceType()};
+ Tensor ds_;
+ Tensor db_;
+ Tensor dY_scale_;
+ Tensor X_scale_;
+ Tensor bias_;
};
} // namespace caffe2
bool LengthsTileOp<CUDAContext>::RunOnDevice() {
auto& data = Input(DATA);
auto& lengths = Input(LENGTHS);
-
+
CAFFE_ENFORCE_EQ(lengths.ndim(), 1, "LENGTHS must be 1-D");
CAFFE_ENFORCE_GE(data.ndim(), 1, "DATA should be at least 1-D");
auto numElements = total_length * numElementsPerRow;
auto numBlocks = CAFFE_GET_BLOCKS(numElements);
- rowMappingHost_.Resize(total_length);
- rowMappingDevice_.Resize(total_length);
+ // Host-side staging buffer for the row mapping; filled on the CPU below.
+ ReinitializeTensor(&rowMappingHost_, {total_length}, at::dtype<int32_t>().device(CPU));
+ // The device-side copy must live on the GPU: this is the CUDAContext
+ // specialization, and the member was previously constructed with
+ // Context::GetDeviceType() rather than CPU.
+ ReinitializeTensor(&rowMappingDevice_, {total_length}, at::dtype<int32_t>().device(CUDA));
auto* rowOffsets = rowMappingHost_.mutable_data<int32_t>();
int32_t outputRow = 0;
for (int64_t i = 0; i < lengths_size; i++) {
private:
Tensor lengths_host_{CPU};
- Tensor rowMappingHost_{CPU};
- Tensor rowMappingDevice_{Context::GetDeviceType()};
+ Tensor rowMappingHost_;
+ Tensor rowMappingDevice_;
};
} // namespace caffe2
new_size <<= 1;
}
if (new_size != old_size) {
- inv_log_i_.Resize(new_size);
+ ReinitializeTensor(
+ &inv_log_i_,
+ {new_size},
+ at::dtype<float>().device(CPU));
auto* data = inv_log_i_.template mutable_data<float>();
EigenVectorArrayMap<float> vec(data, inv_log_i_.numel());
const float log2f_ = std::log(2.f);
template <>
void LambdaRankNdcgOp<float, CPUContext>::ComputeDiscounts(int* idx, int N) {
- discount_.Resize(N);
+ ReinitializeTensor(
+ &discount_, {N}, at::dtype<float>().device(CPU));
auto* discount_data = discount_.template mutable_data<float>();
auto* inv_log_i_data = inv_log_i_.template mutable_data<float>();
for (int i = 0; i < N; i++) {
return 0;
}
- ideal_idx_.Resize(N);
- rank_idx_.Resize(N);
+ ReinitializeTensor(
+ &ideal_idx_, {N}, at::dtype<int>().device(CPU));
+ ReinitializeTensor(
+ &rank_idx_, {N}, at::dtype<int>().device(CPU));
auto* rank_idx_data = rank_idx_.template mutable_data<int>();
auto* ideal_idx_data = ideal_idx_.template mutable_data<int>();
}
const double log2f_ = std::log(2.f);
- gain_.Resize(N);
+ ReinitializeTensor(
+ &gain_, {N}, at::dtype<float>().device(CPU));
auto* gain_data = gain_.template mutable_data<float>();
EigenVectorArrayMap<float> gain_vec(gain_data, gain_.numel());
// similar to ideal but replace with actual discounts
double dcg = (gain_vec * discount_vec).sum();
- lambda_.Resize(N * N);
+ ReinitializeTensor(
+ &lambda_, {N * N}, at::dtype<float>().device(CPU));
auto* lambda_data = lambda_.template mutable_data<float>();
EigenArrayMap<float> lambda_mat(lambda_data, N, N);
// computes lambda weight (i, j) = abs(gain_diff * discount_diff)
Tensor** dy);
bool use_ndcg_as_loss_;
bool use_exp_gain_;
- Tensor gain_{Context::GetDeviceType()};
- Tensor discount_{Context::GetDeviceType()};
- Tensor rank_idx_{Context::GetDeviceType()};
- Tensor ideal_idx_{Context::GetDeviceType()};
- Tensor lambda_{Context::GetDeviceType()};
- Tensor inv_log_i_{Context::GetDeviceType()};
+ Tensor gain_;
+ Tensor discount_;
+ Tensor rank_idx_;
+ Tensor ideal_idx_;
+ Tensor lambda_;
+ Tensor inv_log_i_;
};
template <typename T, class Context>
num_values,
"Sum of lengths should be equal to the total number of samples");
- values_tensor.Resize(num_values);
- percentiles_tensor.Resize(num_values);
+ ReinitializeTensor(
+ &values_tensor,
+ {num_values},
+ at::dtype<float>().device(CPU));
+ ReinitializeTensor(
+ &percentiles_tensor,
+ {num_values},
+ at::dtype<float>().device(CPU));
float* values_tensor_data = values_tensor.template mutable_data<float>();
float* percentiles_tensor_data =
percentiles_tensor.template mutable_data<float>();
protected:
INPUT_TAGS(X, VAL_PCT_PAIRS, LENS);
OUTPUT_TAGS(PCT);
- Tensor values_tensor{Context::GetDeviceType()};
- Tensor percentiles_tensor{Context::GetDeviceType()};
+ Tensor values_tensor;
+ Tensor percentiles_tensor;
};
} // namespace caffe2
}
Y_dims[axis_] += Xi.t.size(axis_);
}
- Y->t.Resize(Y_dims);
+ ReinitializeTensor(&Y->t, Y_dims, at::dtype<uint8_t>().device(CPU));
int before = X0.t.size_to_dim(axis_);
int after = X0.t.size_from_dim(axis_ + 1);
const auto C_total = Y_dims[axis_];
CHECK_EQ(K, W.t.size(1));
CHECK_EQ(N, B.t.numel());
const auto M = X.t.numel() / K;
- Y->t.Resize(M, N);
+ ReinitializeTensor(&Y->t, {M, N}, at::dtype<uint8_t>().device(CPU));
runWithSharedBuffer<CPUContext>(ws_, [&](Tensor* buffer) {
initQNNPACK();