[mlgo][nfc] Virtualize Logger implementation
author Mircea Trofin <mtrofin@google.com>
Thu, 1 Dec 2022 00:42:09 +0000 (16:42 -0800)
committer Mircea Trofin <mtrofin@google.com>
Fri, 2 Dec 2022 00:03:08 +0000 (16:03 -0800)
This is in preparation for dropping the dependency on protobuf. This
first step allows us to subsequently introduce the non-protobuf
implementation behind a flag. After that we can update the training side
to ingest the new format, after which we can drop the protobuf
implementation and de-virtualize everything.

Differential Revision: https://reviews.llvm.org/D139062
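
This patch only virtualizes the existing protobuf (SequenceExample) backend; the replacement backend arrives later, behind a flag. Purely as a sketch of what the new interface permits (not part of this patch: the class name, the raw-byte layout, and the use of TensorSpec::getTotalTensorBufferSize() are illustrative assumptions), an alternative subclass could look roughly like this:

  // Hypothetical, illustration only: a protobuf-free backend implementing
  // the LoggerDataImpl interface introduced below.
  class RawBytesLoggerDataImpl : public LoggerDataImpl {
    // One growable byte buffer per logged feature, plus one for rewards.
    std::vector<std::string> FeatureBuffers;
    std::string RewardBuffer;
    size_t NrRecords = 0;

  public:
    RawBytesLoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
                           const TensorSpec &RewardSpec, bool IncludeReward)
        : LoggerDataImpl(LoggedSpecs, RewardSpec, IncludeReward),
          FeatureBuffers(LoggedSpecs.size()) {}

    char *addNewTensor(size_t FeatureID) override {
      // Grow the feature's buffer by one tensor and hand back the fresh slot.
      // The caller is expected to fill it before the buffer grows again.
      std::string &Buf = FeatureBuffers[FeatureID];
      const size_t TensorSize =
          LoggedFeatureSpecs[FeatureID].getTotalTensorBufferSize();
      const size_t Offset = Buf.size();
      Buf.resize(Offset + TensorSize);
      if (FeatureID == 0)
        ++NrRecords;
      return Buf.data() + Offset;
    }

    size_t getNrRecords() const override { return NrRecords; }

    void flush(std::string *Str) override {
      // Placeholder wire format (concatenated raw bytes); the actual
      // replacement format is a follow-up and not defined here.
      Str->clear();
      if (IncludeReward)
        Str->append(RewardBuffer);
      for (std::string &Buf : FeatureBuffers) {
        Str->append(Buf);
        Buf.clear();
      }
      RewardBuffer.clear();
      NrRecords = 0;
    }

  protected:
    void logRewardImpl(const char *Value, size_t Size) override {
      assert(IncludeReward);
      RewardBuffer.append(Value, Size);
    }
  };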

llvm/lib/Analysis/TrainingLogger.cpp

index 81e8100..2aff026 100644
@@ -52,10 +52,29 @@ void serialize(const Message &SE, std::string *OutStr) {
 namespace llvm {
 
 class LoggerDataImpl {
+protected:
   const std::vector<TensorSpec> LoggedFeatureSpecs;
   const TensorSpec RewardSpec;
   const bool IncludeReward;
+  LoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
+                 const TensorSpec &RewardSpec, bool IncludeReward)
+      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
+        IncludeReward(IncludeReward) {}
+  virtual void logRewardImpl(const char *Value, size_t Size) = 0;
+
+public:
+  // flush the logged info to a stream and clear the log contents.
+  virtual void flush(std::string *Str) = 0;
+  virtual char *addNewTensor(size_t FeatureID) = 0;
+  virtual size_t getNrRecords() const = 0;
+  virtual ~LoggerDataImpl() = default;
+
+  template <typename T> void logReward(T Value) {
+    logRewardImpl(reinterpret_cast<const char *>(&Value), sizeof(T));
+  }
+};
 
+class TFSequenceExampleLoggerDataImpl : public LoggerDataImpl {
   std::vector<tensorflow::FeatureList> FeatureLists;
   tensorflow::FeatureList Reward;
 
@@ -94,13 +113,14 @@ class LoggerDataImpl {
   }
 
 public:
-  LoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
-                 const TensorSpec &RewardSpec, bool IncludeReward)
-      : LoggedFeatureSpecs(LoggedSpecs), RewardSpec(RewardSpec),
-        IncludeReward(IncludeReward), FeatureLists(LoggedFeatureSpecs.size()) {}
+  TFSequenceExampleLoggerDataImpl(const std::vector<TensorSpec> &LoggedSpecs,
+                                  const TensorSpec &RewardSpec,
+                                  bool IncludeReward)
+      : LoggerDataImpl(LoggedSpecs, RewardSpec, IncludeReward),
+        FeatureLists(LoggedFeatureSpecs.size()) {}
 
   // flush the logged info to a stream and clear the log contents.
-  void flush(std::string *Str) {
+  void flush(std::string *Str) override {
     size_t NrRecords = getNrRecords();
     (void)NrRecords;
     tensorflow::SequenceExample SE;
@@ -109,7 +129,7 @@ public:
     serialize(SE, Str);
   }
 
-  char *addNewTensor(size_t FeatureID) {
+  char *addNewTensor(size_t FeatureID) override {
     const auto &Spec = LoggedFeatureSpecs[FeatureID];
     if (Spec.isElementType<float>()) {
       auto *RF = FeatureLists[FeatureID]
@@ -129,18 +149,22 @@ public:
     llvm_unreachable("Unsupported tensor type.");
   }
 
-  template <typename T> void logReward(T Value) {
+  void logRewardImpl(const char *Value, size_t Size) override {
     assert(IncludeReward);
     if (RewardSpec.isElementType<float>())
-      Reward.add_feature()->mutable_float_list()->add_value(Value);
-    else if (RewardSpec.isElementType<int32_t>() ||
-             RewardSpec.isElementType<int64_t>())
-      Reward.add_feature()->mutable_int64_list()->add_value(Value);
+      Reward.add_feature()->mutable_float_list()->add_value(
+          *reinterpret_cast<const float *>(Value));
+    else if (RewardSpec.isElementType<int32_t>())
+      Reward.add_feature()->mutable_int64_list()->add_value(
+          *reinterpret_cast<const int32_t *>(Value));
+    else if (RewardSpec.isElementType<int64_t>())
+      Reward.add_feature()->mutable_int64_list()->add_value(
+          *reinterpret_cast<const int64_t *>(Value));
     else
       llvm_unreachable("Unsupported tensor type.");
   }
 
-  size_t getNrRecords() const {
+  size_t getNrRecords() const override {
     return FeatureLists.empty() ? 0 : FeatureLists[0].feature().size();
   }
 };
@@ -150,8 +174,8 @@ Logger::Logger(const std::vector<TensorSpec> &FeatureSpecs,
                const TensorSpec &RewardSpec, bool IncludeReward)
     : FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
       IncludeReward(IncludeReward),
-      LoggerData(std::make_unique<LoggerDataImpl>(FeatureSpecs, RewardSpec,
-                                                  IncludeReward)) {}
+      LoggerData(std::make_unique<TFSequenceExampleLoggerDataImpl>(
+          FeatureSpecs, RewardSpec, IncludeReward)) {}
 
 Logger::~Logger() {}
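
Since member templates cannot be virtual, logReward stays non-virtual on the base class and forwards a type-erased (pointer, size) pair to the virtual logRewardImpl hook, which the concrete backend decodes against RewardSpec. A self-contained toy example of that dispatch pattern (generic names, unrelated to the classes above):

  #include <cassert>
  #include <cstddef>
  #include <cstdint>
  #include <cstring>
  #include <iostream>

  class RewardSinkBase {
  public:
    virtual ~RewardSinkBase() = default;

    // The public entry point is a template that erases the value type down
    // to raw bytes before calling the virtual hook.
    template <typename T> void logReward(T Value) {
      logRewardImpl(reinterpret_cast<const char *>(&Value), sizeof(T));
    }

  protected:
    virtual void logRewardImpl(const char *Value, std::size_t Size) = 0;
  };

  // A concrete sink that expects std::int64_t rewards and prints them.
  class PrintingInt64Sink : public RewardSinkBase {
  protected:
    void logRewardImpl(const char *Value, std::size_t Size) override {
      assert(Size == sizeof(std::int64_t) && "unexpected reward width");
      std::int64_t V;
      std::memcpy(&V, Value, sizeof(V));
      std::cout << "reward: " << V << '\n';
    }
  };

  int main() {
    PrintingInt64Sink Sink;
    Sink.logReward(std::int64_t{42}); // prints "reward: 42"
  }

The patched logRewardImpl above does the same thing, except it dispatches on RewardSpec's element type (float, int32_t, int64_t) instead of asserting a single width.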