From 95f125129b47f80d08100f682e5885b134dbc575 Mon Sep 17 00:00:00 2001 From: "jijoong.moon" Date: Fri, 20 Mar 2020 08:17:24 +0900 Subject: [PATCH] add TensorDim Class Add TensorDim class to handle tensor demension. **Self evaluation:** 1. Build test: [X]Passed [ ]Failed [ ]Skipped 2. Run test: [X]Passed [ ]Failed [ ]Skipped Signed-off-by: jijoong.moon --- include/tensor.h | 34 ++++++++++++++++++++++++++++------ src/tensor.cpp | 18 +++++++++++++++++- 2 files changed, 45 insertions(+), 7 deletions(-) diff --git a/include/tensor.h b/include/tensor.h index c6cf050..f476e49 100644 --- a/include/tensor.h +++ b/include/tensor.h @@ -35,6 +35,9 @@ extern "C" { #include #include #include +#include + +#define MAXDIM 4 /** * @Namespace Namespace of Tensor @@ -42,12 +45,31 @@ extern "C" { */ namespace Tensors { -typedef struct { - unsigned int batch; - unsigned int channel; - unsigned int height; - unsigned int width; -} TensorDim; +class TensorDim { + public: + TensorDim() { + for (int i = 0; i < MAXDIM; ++i) { + Dim[i] = 1; + } + } + ~TensorDim(){}; + unsigned int batch() { return Dim[0]; }; + unsigned int channel() { return Dim[1]; }; + unsigned int height() { return Dim[2]; }; + unsigned int width() { return Dim[3]; }; + + void batch(unsigned int b) { Dim[0] = b; }; + void channel(unsigned int c) { Dim[1] = c; }; + void height(unsigned int h) { Dim[2] = h; }; + void width(unsigned int w) { Dim[3] = w; }; + + unsigned int *getDim() { return Dim; } + + void setTensorDim(std::string input_shape); + + private: + unsigned int Dim[4]; +}; /** * @class Tensor Class for Calculation diff --git a/src/tensor.cpp b/src/tensor.cpp index 56362d9..f87e27f 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -32,7 +32,23 @@ #include #endif -namespace Tensors{ +namespace Tensors { + +void TensorDim::setTensorDim(std::string input_shape) { + std::regex words_regex("[^\\s.,:;!?]+"); + auto words_begin = std::sregex_iterator(input_shape.begin(), input_shape.end(), words_regex); + auto words_end = std::sregex_iterator(); + int cur_dim = std::distance(words_begin, words_end); + if (cur_dim > 4) { + std::cout << "Tensor Dimension should be less than 4" << std::endl; + exit(0); + } + int cn =0; + for (std::sregex_iterator i = words_begin; i != words_end; ++i) { + Dim[MAXDIM - cur_dim + cn] = std::stoi((*i).str()); + cn++; + } +} Tensor::Tensor(int height, int width) { this->height = height; -- 2.7.4