namespace llvm {
class LoggerDataImpl {
+protected:
const std::vector<TensorSpec> LoggedFeatureSpecs;
const TensorSpec RewardSpec;
const bool IncludeReward;
+ LoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
+ const TensorSpec &RewardSpec, bool IncludeReward)
+ : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
+ IncludeReward(IncludeReward) {}
+ virtual void logRewardImpl(const char *Value, size_t Size) = 0;
+
+public:
+ // flush the logged info to a stream and clear the log contents.
+ virtual void flush(std::string *Str) = 0;
+ virtual char *addNewTensor(size_t FeatureID) = 0;
+ virtual size_t getNrRecords() const = 0;
+ virtual ~LoggerDataImpl() = default;
+
+ template <typename T> void logReward(T Value) {
+ logRewardImpl(reinterpret_cast<const char *>(&Value), sizeof(T));
+ }
+};
+class TFSequenceExampleLoggerDataImpl : public LoggerDataImpl {
std::vector<tensorflow::FeatureList> FeatureLists;
tensorflow::FeatureList Reward;
}
public:
- LoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
- const TensorSpec &RewardSpec, bool IncludeReward)
- : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
- IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
+ TFSequenceExampleLoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
+ const TensorSpec &RewardSpec,
+ bool IncludeReward)
+ : LoggerDataImpl(LoggedSpecs, RewardSpec, IncludeReward),
+ FeatureLists(LoggedFeatureSpecs.size()) {}
// flush the logged info to a stream and clear the log contents.
- void flush(std::string *Str) {
+ void flush(std::string *Str) override {
size_t NrRecords = getNrRecords();
(void)NrRecords;
tensorflow::SequenceExample SE;
serialize(SE, Str);
}
- char *addNewTensor(size_t FeatureID) {
+ char *addNewTensor(size_t FeatureID) override {
const auto &Spec = LoggedFeatureSpecs[FeatureID];
if (Spec.isElementType<float>()) {
auto *RF = FeatureLists[FeatureID]
llvm_unreachable("Unsupported tensor type.");
}
- template <typename T> void logReward(T Value) {
+ void logRewardImpl(const char *Value, size_t Size) override {
assert(IncludeReward);
if (RewardSpec.isElementType<float>())
- Reward.add_feature()->mutable_float_list()->add_value(Value);
- else if (RewardSpec.isElementType<int32_t>() ||
- RewardSpec.isElementType<int64_t>())
- Reward.add_feature()->mutable_int64_list()->add_value(Value);
+ Reward.add_feature()->mutable_float_list()->add_value(
+ *reinterpret_cast<const float *>(Value));
+ else if (RewardSpec.isElementType<int32_t>())
+ Reward.add_feature()->mutable_int64_list()->add_value(
+ *reinterpret_cast<const int32_t *>(Value));
+ else if (RewardSpec.isElementType<int64_t>())
+ Reward.add_feature()->mutable_int64_list()->add_value(
+ *reinterpret_cast<const int64_t *>(Value));
else
llvm_unreachable("Unsupported tensor type.");
}
- size_t getNrRecords() const {
+ size_t getNrRecords() const override {
return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
}
};
const TensorSpec &RewardSpec, bool IncludeReward)
: FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
IncludeReward(IncludeReward),
- LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
- IncludeReward)) {}
+ LoggerData(std::make_unique<TFSequenceExampleLoggerDataImpl>(
+ FeatureSpecs, RewardSpec, IncludeReward)) {}
Logger::~Logger() {}