lambda: self._build_dataset_bytes_stats(num_outputs),
lambda: self._build_dataset_bytes_stats(num_outputs // 10), num_outputs)
+ def _build_dataset_latency_stats(self, num_elements, tag="record_latency"):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ stats_ops.latency_stats(tag))
+
+ def _build_dataset_multiple_tags(self,
+ num_elements,
+ tag1="record_latency",
+ tag2="record_latency_2"):
+ return dataset_ops.Dataset.range(num_elements).apply(
+ stats_ops.latency_stats(tag1)).apply(stats_ops.latency_stats(tag2))
+
+ def testLatencyStatsDatasetSaveableCore(self):
+ num_outputs = 100
+
+ self.run_core_tests(
+ lambda: self._build_dataset_latency_stats(num_outputs),
+ lambda: self._build_dataset_latency_stats(num_outputs // 10),
+ num_outputs)
+
+ self.run_core_tests(lambda: self._build_dataset_multiple_tags(num_outputs),
+ None, num_outputs)
+
+ tag1 = "record_latency"
+ tag2 = "record_latency"
+ self.run_core_tests(
+ lambda: self._build_dataset_multiple_tags(num_outputs, tag1, tag2),
+ None, num_outputs)
+
if __name__ == "__main__":
test.main()
DatasetBase** output) override {
string tag;
OP_REQUIRES_OK(ctx, ParseScalarArgument(ctx, "tag", &tag));
- *output = new Dataset(input, std::move(tag));
+ *output = new Dataset(ctx, input, std::move(tag));
}
private:
- class Dataset : public DatasetBase {
+ class Dataset : public GraphDatasetBase {
public:
- explicit Dataset(const DatasetBase* input, string tag)
- : input_(input), tag_(std::move(tag)) {
+ explicit Dataset(OpKernelContext* ctx, const DatasetBase* input, string tag)
+ : GraphDatasetBase(ctx), input_(input), tag_(std::move(tag)) {
input_->Ref();
}
string DebugString() override { return "LatencyStatsDatasetOp::Dataset"; }
+ protected:
+ Status AsGraphDefInternal(OpKernelContext* ctx, DatasetGraphDefBuilder* b,
+ Node** output) const override {
+ Node* input_node;
+ TF_RETURN_IF_ERROR(b->AddParentDataset(ctx, input_, &input_node));
+ Node* tag_node;
+ TF_RETURN_IF_ERROR(b->AddScalar(tag_, &tag_node));
+ TF_RETURN_IF_ERROR(b->AddDataset(this, {input_node, tag_node}, output));
+ return Status::OK();
+ }
+
private:
class Iterator : public DatasetIterator<Dataset> {
public:
Status GetNextInternal(IteratorContext* ctx,
std::vector<Tensor>* out_tensors,
bool* end_of_sequence) override {
+ tf_shared_lock l(mu_);
uint64 start = ctx->env()->NowMicros();
Status s = input_impl_->GetNext(ctx, out_tensors, end_of_sequence);
uint64 end = ctx->env()->NowMicros();
return s;
}
+ protected:
+ Status SaveInternal(IteratorStateWriter* writer) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(SaveParent(writer, input_impl_));
+ return Status::OK();
+ }
+
+ Status RestoreInternal(OpKernelContext* ctx,
+ IteratorStateReader* reader) override {
+ mutex_lock l(mu_);
+ TF_RETURN_IF_ERROR(RestoreParent(ctx, reader, input_impl_));
+ return Status::OK();
+ }
+
private:
- const std::unique_ptr<IteratorBase> input_impl_;
+ mutex mu_;
+ std::unique_ptr<IteratorBase> input_impl_ GUARDED_BY(mu_);
};
const DatasetBase* const input_;