const uint8_t *_data = getData<uint8_t>();
const uint8_t *_rdata = rhs.getData<uint8_t>();
for (size_t i = 0; i < len; ++i) {
- /** not checking sign change is intentional to avoid float calculation
- * errors around 0 */
- if ((std::isnan(_data[i]) && !std::isnan(_rdata[i])) ||
- (!std::isnan(_data[i]) && std::isnan(_rdata[i])) ||
- _data[i] != _rdata[i])
+ if (_data[i] != _rdata[i])
return false;
}
} else if (dim.getDataType() == ml::train::TensorDim::DataType::QINT4) {
data = decode_qint(_data[i / 2], (i % 2 == 0));
rdata = decode_qint(_rdata[i / 2], (i % 2 == 0));
- if ((std::isnan(data) && !std::isnan(rdata)) ||
- (!std::isnan(data) && std::isnan(rdata)) || data != rdata)
+ if (data != rdata)
return false;
}
}