{
nnfw_tensorinfo ti;
NNPR_ENSURE_STATUS(nnfw_input_tensorinfo(session, i, &ti));
+ // allocate memory for data
+ auto bufsz = bufsize_for(&ti);
+ inputs[i].alloc(bufsz);
H5::DataSet data_set = value_group.openDataSet(std::to_string(i));
-
- // check type
H5::DataType type = data_set.getDataType();
- if (!(type == H5::PredType::IEEE_F32BE || type == H5::PredType::IEEE_F32LE))
+ if (type == H5::PredType::IEEE_F32BE || type == H5::PredType::IEEE_F32LE)
{
- throw std::runtime_error("h5 input has non-float32 type. nnpkg_run supports float32 only.");
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_FLOAT);
+ NNPR_ENSURE_STATUS(
+ nnfw_set_input(session, i, NNFW_TYPE_TENSOR_FLOAT32, inputs[i].data(), bufsz));
+ }
+ else if (type == H5::PredType::STD_U8BE || type == H5::PredType::STD_U8LE)
+ {
+ data_set.read(inputs[i].data(), H5::PredType::NATIVE_UINT8);
+ NNPR_ENSURE_STATUS(
+ nnfw_set_input(session, i, NNFW_TYPE_TENSOR_QUANT8_ASYMM, inputs[i].data(), bufsz));
+ }
+ else
+ {
+ throw std::runtime_error("nnpkg_run can load float32 and uint8 only.");
}
-
- // allocate memory for data
- auto sz = num_elems(&ti);
- inputs[i].alloc(sz * sizeof(float));
- // read data
- data_set.read(inputs[i].data(), H5::PredType::NATIVE_FLOAT);
-
- NNPR_ENSURE_STATUS(nnfw_set_input(session, i, NNFW_TYPE_TENSOR_FLOAT32, inputs[i].data(),
- sizeof(float) * num_elems(&ti)));
NNPR_ENSURE_STATUS(nnfw_set_input_layout(session, i, NNFW_LAYOUT_CHANNELS_LAST));
}
}
}
}
H5::DataSpace data_space(ti.rank, dims.data());
- H5::DataSet data_set =
- value_group.createDataSet(std::to_string(i), H5::PredType::IEEE_F32BE, data_space);
- data_set.write(outputs[i].data(), H5::PredType::NATIVE_FLOAT);
+ if (ti.dtype == NNFW_TYPE_TENSOR_FLOAT32)
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::IEEE_F32BE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_FLOAT);
+ }
+ else if (ti.dtype == NNFW_TYPE_TENSOR_QUANT8_ASYMM)
+ {
+ H5::DataSet data_set =
+ value_group.createDataSet(std::to_string(i), H5::PredType::STD_U8BE, data_space);
+ data_set.write(outputs[i].data(), H5::PredType::NATIVE_UINT8);
+ }
+ else
+ {
+ throw std::runtime_error("nnpkg_run can dump float32 and uint8 only.");
+ }
}
}
catch (const H5::Exception &e)