#pragma once
#include "kernel_selector_common.h"
-#include "kernel_selector_params.h"
-#include "tensor_type.h"
#include <sstream>
#include <cmath>
namespace kernel_selector {
+struct base_params;
+
using JitDefinitions = std::vector<std::pair<std::string, std::string>>;
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template<typename T>
std::string toCodeString(T val) { return std::to_string(val); }
-template<>
-inline std::string toCodeString<std::string>(std::string val) { return val; }
-
-template<>
-inline std::string toCodeString<const char*>(const char* val) { return val; }
-
-template<>
-inline std::string toCodeString<char*>(char* val) { return val; }
-
-template<>
-inline std::string toCodeString<bool>(bool val)
-{
- std::stringstream ss;
- ss << static_cast<int>(val);
- return ss.str();
-}
-
-template<>
-inline std::string toCodeString<const bool>(const bool val)
-{
- std::stringstream ss;
- ss << static_cast<int>(val);
- return ss.str();
-}
-
-template<>
-inline std::string toCodeString<float>(float val) {
- if (std::isinf(val))
- return std::signbit(val) ? "-INFINITY" : "INFINITY";
- std::stringstream ss;
-#ifdef __GNUC__
- // Workaround GCC compiler/STL bug
- ss << "as_float(0x" << std::hex << *reinterpret_cast<uint32_t*>(&val) << ")";
-#else
- ss << std::hexfloat << val << "f";
-#endif
- ss << " /*" << std::scientific << val << "*/";
- return ss.str();
-}
-
-template<>
-inline std::string toCodeString<double>(double val) {
- if (std::isinf(val))
- return std::signbit(val) ? "-INFINITY" : "INFINITY";
- std::stringstream ss;
-#ifdef __GNUC__
- // Workaround GCC compiler/STL bug
- ss << "as_double(0x" << std::hex << *reinterpret_cast<uint64_t*>(&val) << ")";
-#else
- ss << std::hexfloat << val;
-#endif
- ss << " /*" << std::scientific << val << "*/";
- return ss.str();
-}
+inline std::string toCodeString(const std::string& val) { return val; }
+inline std::string toCodeString(const char* val) { return val; }
+inline std::string toCodeString(bool val) { return val ? "1" : "0"; }
+std::string toCodeString(float val);
+std::string toCodeString(double val);
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// JitConstant
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
template <typename VecT, typename ValT, typename Func>
-inline std::string toVectorString(const VecT& vec, const std::string& vertorType, size_t maxDim, ValT padFillingVal, Func fetchFunc)
+inline std::string toVectorString(const VecT& vec, const std::string& vectorType, size_t maxDim, ValT padFillingVal, Func fetchFunc)
{
std::stringstream ss;
- ss << "(" << vertorType << " []){ ";
+ ss << "(" << vectorType << " []){ ";
for (size_t i = 0; i < vec.size(); i++)
ss << toCodeString(fetchFunc(vec[i])) << ",";
for (size_t i = vec.size(); i < maxDim; i++)
return std::static_pointer_cast<JitConstant>(std::make_shared<simple_jit_constant>(name, toCodeString(value)));
}
-////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-// TensorBaseTJitConstant
-////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-template<typename DType, typename Layout>
-class TensorBaseTJitConstant : public JitConstant
-{
-protected:
- TensorBaseTJitConstant(const std::string& name) : JitConstant(name) {}
-
-public:
-
- JitDefinitions GetDefinitions(const Tensor::TensorBaseT<DType, Layout>& t) const
- {
- JitDefinitions definitions{
- { _name + "_TYPE", toCLType(t.GetDType()) },
- { _name + "_OFFSET", toCodeString(t.GetFirstElementOffset()) },
- { _name + "_VIEW_OFFSET", toCodeString(t.GetViewOffset()) },
- { _name + "_LENGTH", toCodeString(t.LogicalSize()) },
- { _name + "_DIMS", toCodeString(t.GetDims().size()) },
- { _name + "_SIMPLE", toCodeString(t.SimpleLayout()) },
- { "TO_" + _name + "_TYPE", "convert_" + toCLType(t.GetDType()) },
- { _name + "_LAYOUT_" + toString(t.GetLayout()), "1" },
- };
-
- definitions.push_back({ _name + "_SIZE", toCodeString(t.GetDims().size()) });
- definitions.push_back({ _name + "_SIZES", toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 1, [](const Tensor::Dim& d) { return d.v; }) });
- definitions.push_back({ _name + "_PITCHES", toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 1, [](const Tensor::Dim& d) { return d.pitch; }) });
- definitions.push_back({ _name + "_PAD_BEFORE", toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 0, [](const Tensor::Dim& d) { return d.pad.before; }) });
- definitions.push_back({ _name + "_PAD_AFTER", toVectorString(t.GetDims(), "size_t", KERNEL_SELECTOR_TENSOR_DIM_MAX, 0, [](const Tensor::Dim& d) { return d.pad.after; }) });
-
- return definitions;
- }
-};
-
-////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-// DataTensorJitConstant
-////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-class DataTensorJitConstant : public TensorBaseTJitConstant<Datatype, DataLayout>
-{
- const DataTensor _tensor;
-
-public:
- DataTensorJitConstant(const std::string& name, const DataTensor& t) : TensorBaseTJitConstant(name), _tensor(t) {}
-
- JitDefinitions GetDefinitions() const override;
-};
-
-inline std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const DataTensor& value)
-{
- return std::static_pointer_cast<JitConstant>(std::make_shared<DataTensorJitConstant>(name, value));
-}
-
-////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-// WeightTensorJitConstant
-////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
-class WeightTensorJitConstant : public TensorBaseTJitConstant<WeightsType, WeightsLayout>
-{
- const WeightsTensor _tensor;
-
-public:
- WeightTensorJitConstant(const std::string& name, const WeightsTensor& t) : TensorBaseTJitConstant(name), _tensor(t) {}
-
- JitDefinitions GetDefinitions() const override;
-};
-
-inline std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const WeightsTensor& value)
-{
- return std::static_pointer_cast<JitConstant>(std::make_shared<WeightTensorJitConstant>(name, value));
-}
+std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const struct Tensor::DataTensor& value);
+std::shared_ptr<JitConstant> MakeJitConstant(const std::string& name, const struct Tensor::WeightsTensor& value);
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
// VectorDataJitConstant
JitDefinitions GetDefinitions() const;
};
+JitConstants MakeActivationJitConstants(const base_activation_params& params, const std::string& suffix="");
JitConstants MakeBaseParamsJitConstants(const base_params& params);
JitConstants MakeLoopUnrollParamsJitConstants(uint32_t loopCount);
JitConstants MakeUnitTypeJitConstants(Datatype dataType);