2 * Copyright (c) 2022 Samsung Electronics Co., Ltd. All Rights Reserved
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
8 * http://www.apache.org/licenses/LICENSE-2.0
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
17 #include "dio_hdf5/HDF5Importer.h"
25 #include <gtest/gtest.h>
27 using HDF5Importer = dio::hdf5::HDF5Importer;
28 using Shape = std::vector<loco::Dimension>;
29 using DataType = loco::DataType;
34 const std::string file_name("dio_hdf5_test.h5");
38 // File already exists. Remove it.
39 if (auto f = fopen(file_name.c_str(), "r"))
42 if (remove(file_name.c_str()) != 0)
43 throw std::runtime_error("Error deleting file.");
47 hsize_t dim[3] = {1, 2, 3};
48 H5::DataSpace space(rank, dim);
50 float data[] = {0, 1, 2, 3, 4, 5};
52 // Create test file in the current directory
53 H5::H5File file(file_name, H5F_ACC_TRUNC);
55 file.createGroup("/value");
56 file.createGroup("/value/0");
57 H5::DataSet dataset(file.createDataSet("/value/0/0", H5::PredType::IEEE_F32BE, space));
58 dataset.write(data, H5::PredType::IEEE_F32LE);
64 TEST(dio_hdf5_test, read_with_type_shape)
68 HDF5Importer h5(::file_name);
70 h5.importGroup("value");
72 std::vector<float> buffer(6);
76 h5.readTensor(0, 0, &dtype, &shape, buffer.data());
78 for (uint32_t i = 0; i < 6; i++)
79 EXPECT_EQ(i, buffer[i]);
81 EXPECT_EQ(DataType::FLOAT32, dtype);
82 EXPECT_EQ(3, shape.size());
83 EXPECT_EQ(1, shape[0]);
84 EXPECT_EQ(2, shape[1]);
85 EXPECT_EQ(3, shape[2]);
88 TEST(dio_hdf5_test, wrong_path_NEG)
90 const std::string wrong_path = "not_existing_file_for_dio_hdf5_test";
92 EXPECT_ANY_THROW(HDF5Importer h5(wrong_path));
95 TEST(dio_hdf5_test, wrong_group_name_NEG)
99 HDF5Importer h5(::file_name);
101 EXPECT_ANY_THROW(h5.importGroup("wrong"));
104 TEST(dio_hdf5_test, data_out_of_index_NEG)
108 HDF5Importer h5(::file_name);
110 h5.importGroup("value");
112 std::vector<float> buffer(6);
116 // Read non-existing data (data_idx = 1)
117 EXPECT_ANY_THROW(h5.readTensor(1, 0, &dtype, &shape, buffer.data()));
120 TEST(dio_hdf5_test, input_out_of_index_NEG)
124 HDF5Importer h5(::file_name);
126 h5.importGroup("value");
128 std::vector<float> buffer(6);
132 // Read non-existing input (input_idx = 1)
133 EXPECT_ANY_THROW(h5.readTensor(0, 1, &dtype, &shape, buffer.data()));