namespace caffe2 {
+// In the early days, Caffe set the default MIOpen workspace limit to 8MB. We
+// bump it up to 64MB in Caffe2, as this enables the use of Winograd in many
+// cases, something very beneficial to more recent CNN models.
+static constexpr size_t kCONV_MIOPEN_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024;
+
class MIOPENConvOpBase : public ConvPoolOpBase<HIPContext> {
public:
MIOPENConvOpBase(const OperatorDef& operator_def, Workspace* ws)
miopen_wrapper_(&context_),
miopen_state_(
OperatorBase::GetSingleArgument<size_t>("miopen_state", 0)),
+ miopen_ws_nbytes_limit_(OperatorBase::GetSingleArgument<size_t>(
+ "ws_nbytes_limit",
+ kCONV_MIOPEN_WORKSPACE_LIMIT_BYTES)),
exhaustive_search_(
OperatorBase::GetSingleArgument<bool>("exhaustive_search", false)),
alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)),
miopenConvolutionDescriptor_t conv_desc_;
miopenConvolutionMode_t mode_;
size_t miopen_state_;
+ const size_t miopen_ws_nbytes_limit_;
bool exhaustive_search_;
const float alpha_;
const float beta_;
OperatorBase::GetSingleArgument<int>("returnedAlgoCount_", 1)),
bestAlgoFound_(
OperatorBase::GetSingleArgument<bool>("bestAlgoFound_", false)),
+ fwdConvWs_(nullptr),
+ fwdConvWsSize_(0),
fwdAlgo_(miopenConvolutionFwdAlgoGEMM) {}
- ~MIOPENConvOp() {}
+ // Free the lazily hipMalloc'd forward-convolution workspace (allocated in
+ // the run path once MIOpen reports a nonzero required size).
+ // NOTE(review): hipFree's return status is ignored; errors in a destructor
+ // are deliberately swallowed — confirm this matches project convention.
+ // NOTE(review): the run path only allocates when fwdConvWs_ == nullptr, so
+ // the buffer is never regrown if a later run needs a larger workspace —
+ // verify input shapes cannot change across runs.
+ ~MIOPENConvOp() {
+   if (fwdConvWs_) {
+     hipFree(fwdConvWs_);
+     fwdConvWs_ = nullptr;
+     fwdConvWsSize_ = 0;
+   }
+ }
template <
typename T_X,
const int requestAlgoCount_;
int returnedAlgoCount_;
bool bestAlgoFound_;
+ char* fwdConvWs_;
size_t fwdConvWsSize_;
miopenConvFwdAlgorithm_t fwdAlgo_;
// Input: X, W, b
bestWeightAlgoFound_(
OperatorBase::GetSingleArgument<bool>("bestAlgoFound", false)),
bwdWeiAlgo_(miopenConvolutionBwdWeightsAlgoGEMM),
- bwdDataAlgo_(miopenConvolutionBwdDataAlgoGEMM) {
+ bwdDataAlgo_(miopenConvolutionBwdDataAlgoGEMM),
+ bwdWeightWsSize_(0),
+ bwdDataWsSize_(0),
+ bwdWeightWs_(nullptr),
+ bwdDataWs_(nullptr) {
CAFFE_ENFORCE(
!(no_bias_ && OutputSize() == 3),
"If bias is not present, you should not have 3 grad output.");
}
- ~MIOPENConvGradientOp() {}
+ // Release both backward workspaces (weight-gradient and data-gradient),
+ // each hipMalloc'd lazily in the gradient run path when a nonzero size is
+ // reported. NOTE(review): hipFree status codes are ignored here — errors
+ // during teardown are swallowed; confirm this matches project convention.
+ ~MIOPENConvGradientOp() {
+   if (bwdWeightWs_) {
+     hipFree(bwdWeightWs_);
+     bwdWeightWs_ = nullptr;
+     bwdWeightWsSize_ = 0;
+   }
+   if (bwdDataWs_) {
+     hipFree(bwdDataWs_);
+     bwdDataWs_ = nullptr;
+     bwdDataWsSize_ = 0;
+   }
+ }
template <
typename T_X,
miopenConvBwdDataAlgorithm_t bwdDataAlgo_;
size_t bwdWeightWsSize_;
size_t bwdDataWsSize_;
+ char* bwdWeightWs_;
+ char* bwdDataWs_;
// input: X, W, dY
// output: dW, db, and optionally dX
INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
conv_desc_,
top_desc_,
&fwdConvWsSize_));
+ if ((fwdConvWsSize_ > 0) && (fwdConvWs_ == nullptr)) {
+ HIP_CHECK(hipMalloc(&fwdConvWs_, fwdConvWsSize_));
+ }
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
MIOPEN_ENFORCE(miopenFindConvolutionForwardAlgorithm(
requestAlgoCount_,
&returnedAlgoCount_,
&perf,
- state->workspace().get(fwdConvWsSize_),
+ fwdConvWs_,
fwdConvWsSize_,
false));
});
&beta_,
top_desc_,
Y->template mutable_data<T_Y>(),
- state->workspace().get(fwdConvWsSize_),
+ fwdConvWs_,
fwdConvWsSize_));
});
conv_desc_,
bottom_desc_,
&bwdDataWsSize_));
+ if ((bwdDataWsSize_ > 0) && (bwdDataWs_ == nullptr)) {
+ HIP_CHECK(hipMalloc(&bwdDataWs_, bwdDataWsSize_));
+ }
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
MIOPEN_ENFORCE(miopenFindConvolutionBackwardDataAlgorithm(
requestAlgoCount_,
&returnedAlgoCount_,
&perf,
- state->workspace().get(bwdDataWsSize_),
+ bwdDataWs_,
bwdDataWsSize_,
false));
});
conv_desc_,
weight_desc_,
&bwdWeightWsSize_));
+ if ((bwdWeightWsSize_ > 0) && (bwdWeightWs_ == nullptr)) {
+ HIP_CHECK(hipMalloc(&bwdWeightWs_, bwdWeightWsSize_));
+ }
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
MIOPEN_ENFORCE(miopenFindConvolutionBackwardWeightsAlgorithm(
requestAlgoCount_,
&returnedAlgoCount_,
&perf,
- state->workspace().get(bwdWeightWsSize_),
+ bwdWeightWs_,
bwdWeightWsSize_,
false));
});
&beta_,
bottom_desc_,
dX->template mutable_data<T_DX>(),
- state->workspace().get(bwdDataWsSize_),
+ bwdDataWs_,
bwdDataWsSize_));
});
}
&beta_,
weight_desc_,
dW->template mutable_data<T_DW>(),
- state->workspace().get(bwdWeightWsSize_),
+ bwdWeightWs_,
bwdWeightWsSize_));
});
+  // Synchronize the work across groups. NOTE(review): hipDeviceSynchronize()
+  // is a device-wide barrier blocking all streams — confirm a narrower
+  // stream-level synchronization would not suffice here.
+ hipDeviceSynchronize();
+
////////////////////////////////////// BIAS ///////////////////////////
if (!no_bias_) {
auto* dbias = Output(BIAS_OR_INPUT_GRAD);
namespace caffe2 {
+static constexpr size_t kCONV_MIOPEN_WORKSPACE_LIMIT_BYTES = 64 * 1024 * 1024;
+
class MIOPENConvTransposeOpBase : public ConvTransposeUnpoolBase<HIPContext> {
public:
MIOPENConvTransposeOpBase(const OperatorDef& operator_def, Workspace* ws)
miopen_wrapper_(&context_),
miopen_state_(
OperatorBase::GetSingleArgument<size_t>("miopen_state", 0)),
+ miopen_ws_nbytes_limit_(OperatorBase::GetSingleArgument<size_t>(
+ "ws_nbytes_limit",
+ kCONV_MIOPEN_WORKSPACE_LIMIT_BYTES)),
exhaustive_search_(
OperatorBase::GetSingleArgument<bool>("exhaustive_search", false)),
alpha_(OperatorBase::GetSingleArgument<float>("alpha", 1.0)),
miopenTensorDescriptor_t top_desc_for_bias_;
miopenConvolutionDescriptor_t conv_desc_;
size_t miopen_state_;
+ const size_t miopen_ws_nbytes_limit_;
bool exhaustive_search_;
const float alpha_;
const float beta_;
OperatorBase::GetSingleArgument<int>("returnedAlgoCount_", 1)),
bestAlgoFound_(
OperatorBase::GetSingleArgument<bool>("bestAlgoFound_", false)),
+ fwdConvWs_(nullptr),
+ fwdConvWsSize_(0),
fwdAlgo_(miopenConvolutionFwdAlgoGEMM) {}
- ~MIOPENConvTransposeOp() {}
+ // Free the lazily hipMalloc'd forward-convolution workspace (allocated in
+ // the run path once MIOpen reports a nonzero required size).
+ // NOTE(review): hipFree's return status is ignored; errors in a destructor
+ // are deliberately swallowed — confirm this matches project convention.
+ // NOTE(review): the run path only allocates when fwdConvWs_ == nullptr, so
+ // the buffer is never regrown if a later run needs a larger workspace —
+ // verify input shapes cannot change across runs.
+ ~MIOPENConvTransposeOp() {
+   if (fwdConvWs_) {
+     hipFree(fwdConvWs_);
+     fwdConvWs_ = nullptr;
+     fwdConvWsSize_ = 0;
+   }
+ }
bool RunOnDevice() override;
const int requestAlgoCount_;
int returnedAlgoCount_;
bool bestAlgoFound_;
+ char* fwdConvWs_;
size_t fwdConvWsSize_;
miopenConvFwdAlgorithm_t fwdAlgo_;
// Input: X, W, b
bestWeightAlgoFound_(
OperatorBase::GetSingleArgument<bool>("bestAlgoFound", false)),
bwdWeiAlgo_(miopenConvolutionBwdWeightsAlgoGEMM),
- bwdDataAlgo_(miopenConvolutionBwdDataAlgoGEMM) {
+ bwdDataAlgo_(miopenConvolutionBwdDataAlgoGEMM),
+ bwdWeightWsSize_(0),
+ bwdDataWsSize_(0),
+ bwdWeightWs_(nullptr),
+ bwdDataWs_(nullptr) {
CAFFE_ENFORCE(
!(no_bias_ && OutputSize() == 3),
"If bias is not present, you should not have 3 grad output.");
}
- ~MIOPENConvTransposeGradientOp() {}
+ // Release both backward workspaces (weight-gradient and data-gradient),
+ // each hipMalloc'd lazily in the gradient run path when a nonzero size is
+ // reported. NOTE(review): hipFree status codes are ignored here — errors
+ // during teardown are swallowed; confirm this matches project convention.
+ ~MIOPENConvTransposeGradientOp() {
+   if (bwdWeightWs_) {
+     hipFree(bwdWeightWs_);
+     bwdWeightWs_ = nullptr;
+     bwdWeightWsSize_ = 0;
+   }
+   if (bwdDataWs_) {
+     hipFree(bwdDataWs_);
+     bwdDataWs_ = nullptr;
+     bwdDataWsSize_ = 0;
+   }
+ }
bool RunOnDevice() override;
miopenConvBwdDataAlgorithm_t bwdDataAlgo_;
size_t bwdWeightWsSize_;
size_t bwdDataWsSize_;
+ char* bwdWeightWs_;
+ char* bwdDataWs_;
// input: X, W, dY
// output: dW, db, and optionally dX
INPUT_TAGS(INPUT, FILTER, OUTPUT_GRAD);
conv_desc_,
top_desc_,
&fwdConvWsSize_));
+ if ((fwdConvWsSize_ > 0) && (fwdConvWs_ == nullptr)) {
+ HIP_CHECK(hipMalloc(&fwdConvWs_, fwdConvWsSize_));
+ }
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
MIOPEN_ENFORCE(miopenFindConvolutionForwardAlgorithm(
requestAlgoCount_,
&returnedAlgoCount_,
&perf,
- state->workspace().get(fwdConvWsSize_),
+ fwdConvWs_,
fwdConvWsSize_,
false));
});
&beta_,
top_desc_,
Y->template mutable_data<T>(),
- state->workspace().get(fwdConvWsSize_),
+ fwdConvWs_,
fwdConvWsSize_));
});
conv_desc_,
bottom_desc_,
&bwdDataWsSize_));
+ if ((bwdDataWsSize_ > 0) && (bwdDataWs_ == nullptr)) {
+ HIP_CHECK(hipMalloc(&bwdDataWs_, bwdDataWsSize_));
+ }
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
MIOPEN_ENFORCE(miopenFindConvolutionBackwardDataAlgorithm(
requestAlgoCount_,
&returnedAlgoCount_,
&perf,
- state->workspace().get(bwdDataWsSize_),
+ bwdDataWs_,
bwdDataWsSize_,
false));
});
conv_desc_,
weight_desc_,
&bwdWeightWsSize_));
+ if ((bwdWeightWsSize_ > 0) && (bwdWeightWs_ == nullptr)) {
+ HIP_CHECK(hipMalloc(&bwdWeightWs_, bwdWeightWsSize_));
+ }
miopen_wrapper_.with_miopen_state(miopen_state_, [&](MIOPENState* state) {
MIOPEN_ENFORCE(miopenFindConvolutionBackwardWeightsAlgorithm(
requestAlgoCount_,
&returnedAlgoCount_,
&perf,
- state->workspace().get(bwdWeightWsSize_),
+ bwdWeightWs_,
bwdWeightWsSize_,
false));
});
&beta_,
bottom_desc_,
dX->template mutable_data<T>(),
- state->workspace().get(bwdDataWsSize_),
+ bwdDataWs_,
bwdDataWsSize_));
});
}
&beta_,
weight_desc_,
dW->template mutable_data<T>(),
- state->workspace().get(bwdWeightWsSize_),
+ bwdWeightWs_,
bwdWeightWsSize_));
});