const auto& U = Input(1);
const auto& V = Input(2);
const auto& b = Input(3);
- auto* Y = Output(0);
+
//auto* buffer_ptr = Output(1);
// Size M * middle;
//auto& multi_buffer_ = *buffer_ptr;
int middle = U.dim32(0);
CAFFE_ENFORCE_EQ(K, V.dim32(0));
CAFFE_ENFORCE_EQ(N, b.dim32(0));
+ std::vector<int64_t> dims;
if (X.dim() > 1) {
- Y->Resize(M, N);
+ dims = {M, N};
multi_buffer_.Resize(M, middle);
} else {
- Y->Resize(N);
+ dims = {N};
multi_buffer_.Resize(middle);
}
- // The col buffer is stored in CHW order as well - kernel_dim, and the height
- // and width.
+ auto* Y = Output(0, dims, at::dtype<T>());
+ // The col buffer is stored in CHW order as well - kernel_dim, and the
+ // height and width.
// multi_buffer_.Resize(M, middle);
T* multi_buffer_data = multi_buffer_.template mutable_data<T>();
// X * V * transpose(U)
}
auto* dU = Output(0);
auto* dV = Output(1);
- auto* db = Output(2);
+
dU->ResizeLike(U);
dV->ResizeLike(V);
- db->Resize(N);
+ auto* db = Output(2, {N}, at::dtype<T>());
// Compute dU
// first compute X * V
const auto& W = Input(1);
const auto& Mask = Input(2);
const auto& b = Input(3);
- auto* Y = Output(0);
+
CAFFE_ENFORCE_GE(X.dim(), 1);
CAFFE_ENFORCE_GE(W.dim(), 2);
if (X.dim() > 2 || W.dim() > 2) {
int N = W.dim32(0);
CAFFE_ENFORCE_EQ(K, W.numel() / W.dim32(0));
CAFFE_ENFORCE_EQ(N, b.dim32(0));
+ std::vector<int64_t> dims;
if (X.dim() > 1) {
- Y->Resize(M, N);
+ dims = {M, N};
} else {
- Y->Resize(N);
+ dims = {N};
}
+ auto* Y = Output(0, dims, at::dtype<T>());
// W * x
math::Gemm<T, Context, Engine>(
CblasNoTrans, CblasTrans, M, N, K, 1, X.template data<T>(),
bias_multiplier_.template data<T>(), b.template data<T>(), 1,
Y->template mutable_data<T>(), &context_);
if (OutputSize() == 2){
- auto* Comp_rate = Output(1);
- Comp_rate->Resize(vector<int64_t>());
+ auto* Comp_rate = Output(1, vector<int64_t>(), at::dtype<T>());
T* comp_data = Comp_rate->template mutable_data<T>();
math::Sum<T, Context>(
Mask.numel(), Mask.template data<T>(), comp_data, &context_);
DCHECK_EQ(N, dY.numel());
}
auto* dW = Output(0);
- auto* db = Output(1);
+
dW->ResizeLike(W);
- db->Resize(N);
+ auto* db = Output(1, {N}, at::dtype<T>());
// Compute dW
math::Gemm<T, Context, Engine>(
const auto& jw = Input(3);
// Notice that we do not need to transpose b
const auto& b = Input(4);
- auto* Yt = Output(0); // transposed Y
+ // Yt (allocated below) holds the transposed Y
// here we assume X is k-by-m
CAFFE_ENFORCE_EQ(Xt.dim(), 2);
CAFFE_ENFORCE_EQ(b.dim(), 1);
// number of outputs.
int N = iw.dim32(0)-1;
CAFFE_ENFORCE_EQ(N, b.dim32(0));
- Yt->Resize(shape(N, M));
+ auto* Yt = Output(0, shape(N, M), at::dtype<T>());
// Y' = W * X';
Sparse_mm<T, Context>(
++n_segments;
}
- auto* output = Output(0);
- output->Resize(n_segments, num_outputs_);
+ auto* output = Output(0, {n_segments, num_outputs_}, at::dtype<T>());
T* output_data = output->template mutable_data<T>();
++n_segments;
}
- auto* output = Output(0);
- output->Resize(n_segments, num_outputs_);
+ auto* output = Output(0, {n_segments, num_outputs_}, at::dtype<T>());
T* output_data = output->template mutable_data<T>();
int64_t num_nz_ent = seg.size(0);
int64_t grad_weight_size = num_nz_ent * num_outputs_ * num_alpha;
- auto* grad_weight_val = Output(0);
- grad_weight_val->Resize(grad_weight_size);
+
+ auto* grad_weight_val = Output(0, {grad_weight_size}, at::dtype<T>());
T* grad_weight_val_data = grad_weight_val->template mutable_data<T>();
- auto* grad_weight_ind = Output(1);
- grad_weight_ind->Resize(grad_weight_size);
+ auto* grad_weight_ind = Output(1, {grad_weight_size}, at::dtype<int64_t>());
auto* grad_weight_ind_data =
grad_weight_ind->template mutable_data<int64_t>();
CAFFE_ENFORCE(
old_row.numel() == nnz,
"Column and row tensors must have the same size.");
- auto* new_col = Output(0);
- auto* new_row = Output(1);
- new_col->Resize(nnz);
- new_row->Resize(nnz);
+
+ auto* new_col = Output(0, {nnz}, at::dtype<int64_t>());
+ auto* new_row = Output(1, {nnz}, at::dtype<int>());
const auto* old_col_data = old_col.template data<int64_t>();
const auto* old_row_data = old_row.template data<int>();
bool RunOnDevice() override {
const auto& A = Input(0);
const auto& B = Input(1);
- auto* C = Output(0);
CAFFE_ENFORCE(A.dim() == 2, A.dim());
int64_t D_ = B_size / (K_ * N_);
int64_t C_size = D_ * M_ * N_;
- C->Resize(vector<int64_t>{C_size});
+ auto* C = Output(0, vector<int64_t>{C_size}, at::dtype<T>());
int64_t B_stride = K_ * N_;
int64_t C_stride = M_ * N_;
const auto& G = Input(0);
const auto& A = Input(1);
const auto& B = Input(2);
- auto* dA = Output(0);
- auto* dB = Output(1);
int64_t G_size = G.numel();
int64_t D_ = G_size / (M_ * N_);
int64_t dB_size = D_ * K_ * N_;
- dA->Resize(A.sizes());
- dB->Resize(B.sizes());
+ auto* dA = Output(0, A.sizes(), at::dtype<T>());
+ auto* dB = Output(1, B.sizes(), at::dtype<T>());
int64_t B_stride = K_ * N_;
int64_t G_stride = M_ * N_;
auto X_dim0 = X.size(0);
auto X_dim1 = X.size(1);
- auto* X_orig_dim0 = Output(1);
- X_orig_dim0->Resize(1);
+ auto* X_orig_dim0 = Output(1, {1}, at::dtype<int64_t>());
*X_orig_dim0->template mutable_data<int64_t>() = X_dim0;
if (X_dim0 % scale_ != 0) {