class LoggerDataImpl {
const std::vector<LoggedFeatureSpec> LoggedFeatureSpecs;
const TensorSpec RewardSpec;
+ const bool IncludeReward;
+
+ std::vector<tensorflow::FeatureList> FeatureLists;
+ tensorflow::FeatureList Reward;
+
+ bool isSelfConsistent(const tensorflow::SequenceExample &SE,
+ size_t NrRecords) const {
+ bool Ret = true;
+ for (const auto &TSpecs : LoggedFeatureSpecs) {
+ const auto &Name = TSpecs.getLoggingName();
+ const auto &FL = SE.feature_lists().feature_list().at(Name).feature();
+ if (NrRecords != static_cast<size_t>(FL.size())) {
+ dbgs() << "[TF-UTILS]: " << Name << " has missing records. Expected "
+ << NrRecords << " got " << FL.size() << "\n";
+ Ret = false;
+ }
+ }
+ if (IncludeReward && static_cast<size_t>(SE.feature_lists()
+ .feature_list()
+ .at(RewardSpec.name())
+ .feature()
+ .size()) != NrRecords) {
+ dbgs() << "[TF-UTILS]: reward is missing records.\n";
+ Ret = false;
+ }
+ return Ret;
+ }
- tensorflow::SequenceExample SE;
- std::vector<tensorflow::FeatureList *> FeatureLists;
- tensorflow::FeatureList *Reward = nullptr;
-
-public:
- LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
- const TensorSpec &RewardSpec, bool IncludeReward)
- : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec) {
+ void transferLog(tensorflow::SequenceExample &SE) {
auto *FL = SE.mutable_feature_lists()->mutable_feature_list();
if (IncludeReward)
- Reward = &(*FL)[RewardSpec.name()];
- // Allocate first the map entries, then capture their address. We will not
- // mutate the set of features after this (i.e. the pointers won't dangle).
- for (const auto &LFS : LoggedSpecs) {
- (*FL)[LFS.LoggingName ? *LFS.LoggingName : LFS.Spec.name()] = {};
+ (*FL)[RewardSpec.name()].Swap(&Reward);
+ assert(FeatureLists.size() == LoggedFeatureSpecs.size());
+ for (size_t I = 0; I < FeatureLists.size(); ++I) {
+ const auto &LFS = LoggedFeatureSpecs[I];
+ (*FL)[LFS.getLoggingName()].Swap(&FeatureLists[I]);
}
- for (const auto &LFS : LoggedSpecs)
- FeatureLists.push_back(
- &(*FL)[LFS.LoggingName ? *LFS.LoggingName : LFS.Spec.name()]);
}
- void print(raw_ostream &OS) {
+public:
+ LoggerDataImpl(const std::vector<LoggedFeatureSpec> &LoggedSpecs,
+ const TensorSpec &RewardSpec, bool IncludeReward)
+ : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
+ IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
+
+ // flush the logged info to a stream and clear the log contents.
+ void flush(raw_ostream &OS) {
+ size_t NrRecords = getNrRecords();
+ tensorflow::SequenceExample SE;
+ transferLog(SE);
+ assert(isSelfConsistent(SE, NrRecords));
std::string OutStr;
if (ProtobufTextMode)
google::protobuf::TextFormat::PrintToString(SE, &OutStr);
const auto &Spec = LoggedFeatureSpecs[FeatureID].Spec;
if (Spec.isElementType<float>()) {
auto *RF = FeatureLists[FeatureID]
- ->add_feature()
+ .add_feature()
->mutable_float_list()
->mutable_value();
RF->Resize(Spec.getElementCount(), 0.0);
return reinterpret_cast<char *>(RF->mutable_data());
} else if (Spec.isElementType<int32_t>() || Spec.isElementType<int64_t>()) {
auto *RF = FeatureLists[FeatureID]
- ->add_feature()
+ .add_feature()
->mutable_int64_list()
->mutable_value();
RF->Resize(Spec.getElementCount(), 0);
}
template <typename T> void logReward(T Value) {
+ assert(IncludeReward);
if (RewardSpec.isElementType<float>())
- Reward->add_feature()->mutable_float_list()->add_value(Value);
+ Reward.add_feature()->mutable_float_list()->add_value(Value);
else if (RewardSpec.isElementType<int32_t>() ||
RewardSpec.isElementType<int64_t>())
- Reward->add_feature()->mutable_int64_list()->add_value(Value);
+ Reward.add_feature()->mutable_int64_list()->add_value(Value);
else
llvm_unreachable("Unsupported tensor type.");
}
size_t getNrRecords() const {
- return FeatureLists.empty() ? 0 : FeatureLists[0]->feature().size();
+ return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
}
};
} // namespace llvm
return reinterpret_cast<char *>(LoggerData->addNewTensor(FeatureID));
}
-void Logger::print(raw_ostream &OS) { LoggerData->print(OS); }
+void Logger::flush(raw_ostream &OS) { LoggerData->flush(OS); }
#endif // defined(LLVM_HAVE_TF_API)