#include "caffe2/core/flags.h"
#include "caffe2/core/tensor_int8.h"
+#include "caffe2/operators/fc_inference.h"
#include "caffe2/utils/cpuid.h"
#include "fbgemm_pack_matrix_cache.h"
#include "fbgemm_pack_op.h"
using namespace std;
-template <typename T>
-FullyConnectedDNNLowPOp<T>::FullyConnectedDNNLowPOp(
+template <typename T, bool ReluFused>
+FullyConnectedDNNLowPOp<T, ReluFused>::FullyConnectedDNNLowPOp(
const OperatorDef& operator_def,
Workspace* ws)
: BaseType(operator_def, ws),
VLOG(2) << "DNNLOWP FC with output " << operator_def.output(0);
}
-template <typename T>
-bool FullyConnectedDNNLowPOp<T>::RunOnDevice() {
+template <typename T, bool ReluFused>
+bool FullyConnectedDNNLowPOp<T, ReluFused>::RunOnDevice() {
using namespace std;
using namespace dnnlowp;
+ bool first_invocation = !this->arguments_parsed_;
this->ParseDNNLowPOperatorArguments_();
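+  // On the first invocation only: when ReLU is fused, record that this op is
+  // "followed by" Relu so the output quantization parameters are chosen to
+  // cover only the non-negative range (a ReLU output is never below real 0).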
+ if (first_invocation && ReluFused) {
+ followed_by_ = "Relu";
+ AdjustOutputTensorQuantizationParamsWithFollowedBy(this, followed_by_);
+ }
if ((!GetCpuId().avx2() || FLAGS_caffe2_dnnlowp_enforce_default_operators) &&
dequantize_output_) {
row_offsets_.data());
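+  // ReluFused is forwarded into fbgemm's requantization epilogue below, so
+  // the clamp at the output zero point happens in the same pass as
+  // requantization instead of in a separate ReLU kernel.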
if (quantize_channelwise_) {
- ReQuantizeOutput<
- false /* FUSE_RELU */,
- QuantizationGranularity::OUT_CHANNEL>
+ ReQuantizeOutput<ReluFused, QuantizationGranularity::OUT_CHANNEL>
outputProcObj(
doNothingObj,
requantization_multipliers_.data(),
0, // thread_id
1); // num_threads
} else {
- ReQuantizeOutput<false /* FUSE_RELU */> outputProcObj(
+ ReQuantizeOutput<ReluFused> outputProcObj(
doNothingObj,
requantization_multipliers_.data(),
out_qparams_.zero_point,
X_pack_buf_.data(), // buffer for packed matrix
1); // group
- ReQuantizeOutput<false /* FUSE_RELU */> outputProcObj(
+ ReQuantizeOutput<ReluFused> outputProcObj(
doNothingObj,
requantization_multipliers_.data(),
out_qparams_.zero_point,
DoNothing<float, float> doNothingObj{};
if (quantize_channelwise_) {
- ReQuantizeForFloat<
- false /* FUSE_RELU*/,
- QuantizationGranularity::OUT_CHANNEL>
+ ReQuantizeForFloat<ReluFused, QuantizationGranularity::OUT_CHANNEL>
outputProcObj(
doNothingObj,
in_qparams_[0].scale,
0, // thread_id
1); // num_threads
} else {
- ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
+ ReQuantizeForFloat<ReluFused> outputProcObj(
doNothingObj,
in_qparams_[0].scale,
filter_scales_.data(),
DoNothing<float, float> doNothingObj{};
if (quantize_channelwise_) {
- ReQuantizeForFloat<
- false /* FUSE_RELU*/,
- QuantizationGranularity::OUT_CHANNEL>
+ ReQuantizeForFloat<ReluFused, QuantizationGranularity::OUT_CHANNEL>
outputProcObj(
doNothingObj,
in_qparams_[0].scale,
0, // thread_id
1); // num_threads
} else {
- ReQuantizeForFloat<false /* FUSE_RELU*/> outputProcObj(
+ ReQuantizeForFloat<ReluFused> outputProcObj(
doNothingObj,
in_qparams_[0].scale,
filter_scales_.data(),
Ydata[i * N + j] = Y_int32_[i * N + j] * in_qparams_[0].scale *
filter_qparams_[quant_group].scale +
b_dequantized_data_[j];
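+          // Fused ReLU in the dequantized (float) domain: real zero is
+          // simply 0.0f.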
+ if (ReluFused) {
+ Ydata[i * N + j] = std::max(Ydata[i * N + j], 0.0f);
+ }
}
}
}
Ydata[i * N + j] = fbgemm::Requantize<T>(
Y_int32_[i * N + j], requantization_params_[quant_group]);
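+        // Fused ReLU in the quantized domain: real zero maps to the output
+        // zero point, so clamp at zero_point rather than at 0.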
+ if (ReluFused) {
+ Ydata[i * N + j] =
+ std::max<T>(out_qparams_.zero_point, Ydata[i * N + j]);
+ }
}
}
}
return true;
}
-template <typename T>
-bool FullyConnectedDNNLowPOp<T>::GetQuantizationParameters_() {
+template <typename T, bool ReluFused>
+bool FullyConnectedDNNLowPOp<T, ReluFused>::GetQuantizationParameters_() {
using namespace dnnlowp;
#ifdef DNNLOWP_MEASURE_TIME_BREAKDOWN
DNNLOWP_ROWWISE,
FullyConnectedDNNLowPOp<uint8_t>);
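+// Register the fused FC+ReLU op under both the default and the row-wise
+// weight-quantization engines, mirroring the existing Int8FC registrations.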
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+ Int8FCRelu,
+ DNNLOWP,
+ FullyConnectedDNNLowPOp<uint8_t, true>);
+REGISTER_CPU_OPERATOR_WITH_ENGINE(
+ Int8FCRelu,
+ DNNLOWP_ROWWISE,
+ FullyConnectedDNNLowPOp<uint8_t, true>);
+
+using namespace std::placeholders;
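+// ReLU is elementwise, so Int8FCRelu has exactly FC's signature and can
+// reuse the stock FC shape and cost inference functions.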
+OPERATOR_SCHEMA(Int8FCRelu)
+ .NumInputs(3)
+ .NumOutputs(1)
+ .TensorInferenceFunction(std::bind(FCShapeInference, _1, _2, false))
+ .CostInferenceFunction(std::bind(CostInferenceForFC, _1, _2, false));
+
} // namespace caffe2
prepack_weight=st.booleans(),
preserve_activation_sparsity=st.booleans(),
preserve_weight_sparsity=st.booleans(),
+ fuse_relu=st.booleans(),
**hu.gcs_cpu_only
)
def test_dnnlowp_fully_connected_int(
prepack_weight,
preserve_activation_sparsity,
preserve_weight_sparsity,
+ fuse_relu,
gc,
dc,
):
op_engine_list = [
("FC", ""),
- ("FC", "DNNLOWP"),
- ("FC", "DNNLOWP_16"),
- ("Int8FC", "DNNLOWP"),
]
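+    # When exercising the fused kernel, compare Int8FCRelu only against the
+    # FP32 reference (plain FC with a separate Relu appended below).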
+ if fuse_relu:
+ op_engine_list += [
+ ("Int8FCRelu", "DNNLOWP"),
+ ]
+ else:
+ op_engine_list += [
+ ("FC", "DNNLOWP"),
+ ("FC", "DNNLOWP_16"),
+ ("Int8FC", "DNNLOWP"),
+ ]
for op_type, engine in op_engine_list:
init_net = core.Net("test_init_net")
fc, outputs[0][0], preserve_activation_sparsity
)
net.Proto().op.extend([fc])
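+        # For the FP32 reference ops, apply Relu as a separate operator so
+        # their output is comparable to the fused int8 op.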
+ if fuse_relu and "DNNLOWP" not in engine:
+ net.Relu(["Y"], "Y")
if do_dequantize:
dequantize = core.CreateOperator(