StringMap<size_t> ObservationIDs;
std::string CurrentContext;
- void writeHeader();
+ void writeHeader(std::optional<TensorSpec> AdviceSpec);
void writeTensor(const TensorSpec &Spec, const char *RawData) {
OS->write(RawData, Spec.getTotalTensorBufferSize());
}
/// corresponding to the model being trained/logged.
Logger(std::unique_ptr<raw_ostream> OS,
const std::vector<TensorSpec> &FeatureSpecs,
- const TensorSpec &RewardSpec, bool IncludeReward);
+ const TensorSpec &RewardSpec, bool IncludeReward,
+ std::optional<TensorSpec> AdviceSpec = std::nullopt);
void switchContext(StringRef Name);
void startObservation();
InputSpecs(Inputs), OutputSpec(Advice), Inbound(InboundName, InEC),
OutputBuffer(OutputSpec.getTotalTensorBufferSize()),
Log(std::make_unique<raw_fd_ostream>(OutboundName, OutEC), InputSpecs,
- Advice, /*IncludeReward=*/false) {
+ Advice, /*IncludeReward=*/false, Advice) {
if (InEC) {
Ctx.emitError("Cannot open inbound file: " + InEC.message());
return;
UseSimpleLogger("tfutils-use-simplelogger", cl::init(true), cl::Hidden,
cl::desc("Output simple (non-protobuf) log."));
-void Logger::writeHeader() {
+void Logger::writeHeader(std::optional<TensorSpec> AdviceSpec) {
json::OStream JOS(*OS);
JOS.object([&]() {
JOS.attributeArray("features", [&]() {
RewardSpec.toJSON(JOS);
JOS.attributeEnd();
}
+ if (AdviceSpec.has_value()) {
+ JOS.attributeBegin("advice");
+ AdviceSpec->toJSON(JOS);
+ JOS.attributeEnd();
+ }
});
*OS << "\n";
}
Logger::Logger(std::unique_ptr<raw_ostream> OS,
const std::vector<TensorSpec> &FeatureSpecs,
- const TensorSpec &RewardSpec, bool IncludeReward)
+ const TensorSpec &RewardSpec, bool IncludeReward,
+ std::optional<TensorSpec> AdviceSpec)
: OS(std::move(OS)), FeatureSpecs(FeatureSpecs), RewardSpec(RewardSpec),
IncludeReward(IncludeReward) {
- writeHeader();
+ writeHeader(AdviceSpec);
}
auto Header = json::parse(ReadLn());
EXPECT_FALSE(Header.takeError());
EXPECT_NE(Header->getAsObject()->getArray("features"), nullptr);
+ EXPECT_NE(Header->getAsObject()->getObject("advice"), nullptr);
// Then comes the context
EXPECT_FALSE(json::parse(ReadLn()).takeError());