const BufferAllocation::Slice& scratch_buffer, const Shape& input_shape,
const Shape& filter_shape, const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
- const HloInstruction* hlo)
+ bool tensor_ops_enabled, const HloInstruction* hlo)
: Thunk(Kind::kConvolution, hlo),
convolution_kind_(convolution_kind),
input_buffer_(input_buffer),
output_shape_(output_shape),
window_(window),
dim_nums_(dim_nums),
- algorithm_(algorithm) {}
+ algorithm_(algorithm),
+ tensor_ops_enabled_(tensor_ops_enabled) {}
Status ConvolutionThunk::ExecuteOnStream(
const BufferAllocations& buffer_allocations, se::Stream* stream) {
buffer_allocations.GetDeviceAddress(scratch_buffer_);
se::dnn::AlgorithmConfig algorithm_config(
- se::dnn::AlgorithmDesc(algorithm_, /*use_tensor_ops=*/false));
+ se::dnn::AlgorithmDesc(algorithm_, tensor_ops_enabled_));
TF_RETURN_IF_ERROR(RunCudnnConvolution(
convolution_kind_, input_shape_, filter_shape_, output_shape_, input_data,
const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dim_nums, int64 algorithm,
- const HloInstruction* hlo);
+ bool tensor_ops_enabled, const HloInstruction* hlo);
ConvolutionThunk(const ConvolutionThunk&) = delete;
ConvolutionThunk& operator=(const ConvolutionThunk&) = delete;
const Window window_;
const ConvolutionDimensionNumbers dim_nums_;
int64 algorithm_;
+ bool tensor_ops_enabled_;
};
} // namespace gpu
// cache misses and doing extra work. Overall, caching doesn't seem worth the
// trouble, but we may want to revisit this if we ever find a model where
// caching would speed up compilation a lot.
-optional<std::pair<int64, int64>>
+optional<std::tuple<int64, bool, int64>>
CudnnConvolutionAlgorithmPicker::PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
<< AlgorithmToString(best_result.algorithm()) << ", takes "
<< best_result.elapsed_time_in_ms() << "ms, and uses "
<< best_result_bytes_used << "B of scratch memory.";
- return std::make_pair(best_result.algorithm().algo_id(),
- best_result_bytes_used);
+ return std::make_tuple(best_result.algorithm().algo_id(),
+ best_result.algorithm().tensor_ops_enabled(),
+ best_result_bytes_used);
}
LOG(WARNING) << "All algorithms tried for convolution " << instr->ToString()
const auto& lhs_shape = instr->operand(0)->shape();
const auto& rhs_shape = instr->operand(1)->shape();
const auto& conv_result_shape = instr->shape().tuple_shapes(0);
- optional<std::pair<int64, int64>> alg_and_scratch_bytes;
+ optional<std::tuple<int64, bool, int64>> alg_scratch_and_tc;
if (call_target == kCudnnConvForwardCallTarget) {
- alg_and_scratch_bytes = PickBestAlgorithm(
+ alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kForward, /*input_shape=*/lhs_shape,
/*filter_shape=*/rhs_shape, /*output_shape=*/conv_result_shape,
instr->window(), instr->convolution_dimension_numbers(), instr);
} else if (call_target == kCudnnConvBackwardInputCallTarget) {
- alg_and_scratch_bytes = PickBestAlgorithm(
+ alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardInput, /*input_shape=*/conv_result_shape,
/*filter_shape=*/rhs_shape, /*output_shape=*/lhs_shape, instr->window(),
instr->convolution_dimension_numbers(), instr);
} else if (call_target == kCudnnConvBackwardFilterCallTarget) {
- alg_and_scratch_bytes = PickBestAlgorithm(
+ alg_scratch_and_tc = PickBestAlgorithm(
CudnnConvKind::kBackwardFilter, /*input_shape=*/lhs_shape,
/*filter_shape=*/conv_result_shape, /*output_shape=*/rhs_shape,
instr->window(), instr->convolution_dimension_numbers(), instr);
<< instr->ToString();
}
- if (!alg_and_scratch_bytes.has_value()) {
+ if (!alg_scratch_and_tc.has_value()) {
return false;
}
int64 algorithm;
+ bool tensor_ops_enabled;
int64 scratch_bytes;
- std::tie(algorithm, scratch_bytes) = *alg_and_scratch_bytes;
+
+ std::tie(algorithm, tensor_ops_enabled, scratch_bytes) = *alg_scratch_and_tc;
VLOG(1) << "Setting cudnn conv to use algorithm " << algorithm << " and "
<< NumBytesToString(scratch_bytes)
- << " of scratch memory: " << instr->ToString();
+ << " of scratch memory: " << instr->ToString()
+ << " tensor_ops_enabled: " << tensor_ops_enabled;
// Replace instr with a new CustomCall which has the correct algorithm, and
// whose output shape has the appropriate amount of scratch memory.
ShapeUtil::MakeShape(U8, {scratch_bytes})});
HloInstruction* algorithm_hlo = computation->AddInstruction(
HloInstruction::CreateConstant(Literal::CreateR0<int64>(algorithm)));
+ HloInstruction* tensor_ops_enabled_hlo =
+ computation->AddInstruction(HloInstruction::CreateConstant(
+ Literal::CreateR0<bool>(tensor_ops_enabled)));
+
HloInstruction* new_call =
computation->AddInstruction(HloInstruction::CreateCustomCall(
new_call_shape,
- {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo},
+ {instr->mutable_operand(0), instr->mutable_operand(1), algorithm_hlo,
+ tensor_ops_enabled_hlo},
instr->custom_call_target()));
new_call->set_window(instr->window());
new_call->set_convolution_dimension_numbers(
private:
StatusOr<bool> RunOnComputation(HloComputation* computation);
StatusOr<bool> RunOnInstruction(HloInstruction* instr);
- tensorflow::gtl::optional<std::pair<int64, int64>> PickBestAlgorithm(
+ tensorflow::gtl::optional<std::tuple<int64, bool, int64>> PickBestAlgorithm(
CudnnConvKind kind, const Shape& input_shape, const Shape& filter_shape,
const Shape& output_shape, const Window& window,
const ConvolutionDimensionNumbers& dnums, HloInstruction* instr);
se::ScratchAllocator* scratch_allocator, const Window& window,
const ConvolutionDimensionNumbers& dnums, AlgorithmConfig algorithm,
Stream* stream, ProfileResult* profile_result /*= nullptr*/) {
+ VLOG(3) << "Convolution Algorithm: " << algorithm.algorithm().algo_id();
+ VLOG(3) << "tensor_ops_enabled: "
+ << algorithm.algorithm().tensor_ops_enabled();
VLOG(3) << "Convolution kind: " << CudnnConvKindToString(kind);
VLOG(3) << "input shape: { " << ShapeUtil::HumanString(input_shape) << " }";
VLOG(3) << "filter shape: { " << ShapeUtil::HumanString(filter_shape) << " }";
TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
}
} else if (IsCustomCallToDnnConvolution(*hlo)) {
- // The last argument to a CUDNN convolution is its algorithm, which must
- // be an HLO constant -- it shouldn't be copied.
- for (int64 i = 0; i < hlo->operand_count() - 1; ++i) {
+ // The last two arguments to a CUDNN convolution are two HLO constants for
+ // cudnn algorithm and tensor_ops_enabled flag, which shouldn't be copied.
+ for (int64 i = 0; i < hlo->operand_count() - 2; ++i) {
TF_RETURN_IF_ERROR(copy_operand_if_constant(i));
}
} else if (ImplementedAsLibraryCall(*hlo)) {
// strings.
//
// These CustomCalls have window() and convolution_dimension_numbers() set like
-// regular convolution ops. They have the same LHS and RHS operands, plus one
-// additional int64 operand, representing which cudnn algorithm to run. This
-// operand must be an HLO constant. A value of -1 means that the implementation
-// is free to choose the best algorithm it can.
+// regular convolution ops. They have the same LHS and RHS operands, plus two
+// additional constant operands: an int64 operand for the cudnn algorithm and
+// a bool operand for whether tensor_ops is enabled. A value of -1 for the cudnn
+// algorithm means that the implementation is free to choose the best algorithm
+// it can.
//
// These calls output a tuple (conv_result, scratch_memory), where conv_result
// is the actual result of the convolution, and scratch_memory is temporary
CHECK(algorithm_inst->IsConstant()) << algorithm_inst->ToString();
int64 algorithm = algorithm_inst->literal().Get<int64>({});
+ const HloInstruction* tensor_ops_enabled_inst = custom_call->operand(3);
+ CHECK(tensor_ops_enabled_inst->IsConstant())
+ << tensor_ops_enabled_inst->ToString();
+ bool tensor_ops_enabled = tensor_ops_enabled_inst->literal().Get<bool>({});
+
const auto& target = custom_call->custom_call_target();
std::unique_ptr<ConvolutionThunk> thunk;
if (target == kCudnnConvForwardCallTarget) {
/*filter_shape=*/rhs_shape,
/*output_shape=*/conv_result_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- algorithm, custom_call);
+ algorithm, tensor_ops_enabled, custom_call);
} else if (target == kCudnnConvBackwardInputCallTarget) {
thunk = MakeUnique<ConvolutionThunk>(
CudnnConvKind::kBackwardInput,
/*filter_shape=*/rhs_shape,
/*output_shape=*/lhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- algorithm, custom_call);
+ algorithm, tensor_ops_enabled, custom_call);
} else if (target == kCudnnConvBackwardFilterCallTarget) {
thunk = MakeUnique<ConvolutionThunk>(
CudnnConvKind::kBackwardFilter,
/*filter_shape=*/conv_result_shape,
/*output_shape=*/rhs_shape, //
custom_call->window(), custom_call->convolution_dimension_numbers(),
- algorithm, custom_call);
+ algorithm, tensor_ops_enabled, custom_call);
} else {
LOG(FATAL) << "Unexpected custom call target: "
<< custom_call->custom_call_target();