)
cc_library(
+ name = "gce_env_utils",
+ srcs = ["gce_env_utils.cc"],
+ hdrs = ["gce_env_utils.h"],
+ copts = tf_copts(),
+ visibility = ["//visibility:private"],
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
+cc_library(
name = "gcs_file_system",
srcs = ["gcs_file_system.cc"],
hdrs = ["gcs_file_system.h"],
visibility = ["//tensorflow:__subpackages__"],
deps = [
":curl_http_request",
+ ":gce_env_utils",
":oauth_client",
":retrying_utils",
"//tensorflow/core:lib",
],
)
+cc_library(
+ name = "fake_env",
+ srcs = [
+ "fake_env.cc",
+ ],
+ hdrs = [
+ "fake_env.h",
+ ],
+ copts = tf_copts(),
+ deps = [
+ "//tensorflow/core:framework_headers_lib",
+ "//tensorflow/core:lib_internal",
+ ],
+)
+
tf_cc_test(
name = "expiring_lru_cache_test",
size = "small",
"testdata/service_account_credentials.json",
],
deps = [
+ ":fake_env",
":google_auth_provider",
":http_request_fake",
":oauth_client",
"//tensorflow/core:test_main",
],
)
+
+tf_cc_test(
+ name = "gce_env_utils_test",
+ size = "small",
+ srcs = ["gcp_env_utils_test.cc"],
+ deps = [
+ ":fake_env",
+ ":gce_env_utils",
+ "//tensorflow/core:test",
+ "//tensorflow/core:test_main",
+ ],
+)
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/fake_env.h"
+
+namespace tensorflow {
+namespace test {
+
+Status FakeEnv::FakeRandomAccessFile::Read(uint64 offset, size_t n,
+ StringPiece* result,
+ char* scratch) const {
+ CHECK_EQ(offset, 0);
+ CHECK_EQ(n, 256);
+ Status s;
+ string platform;
+ switch (env_type_) {
+ case kGoogle: {
+ platform = "Google\n ";
+ s = errors::OutOfRange("");
+ break;
+ }
+ case kGce: {
+ platform = " Google Compute Engine\n ";
+ s = errors::OutOfRange("");
+ break;
+ }
+ case kLocal: {
+ platform = "HP Linux Workstation";
+ s = Status::OK();
+ break;
+ }
+ case kBad: {
+ platform = "";
+ s = errors::Internal("Expected");
+ break;
+ }
+ }
+ strncpy(scratch, platform.data(), strlen(platform.data()));
+ *result = StringPiece(scratch, platform.length());
+ return s;
+}
+
+Status FakeEnv::NewRandomAccessFile(const string& fname,
+ std::unique_ptr<RandomAccessFile>* result) {
+ result->reset(new FakeRandomAccessFile(env_type_));
+ return Status::OK();
+}
+
+} // namespace test
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_FAKE_ENV_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_FAKE_ENV_H_
+
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+namespace test {
+
+/// Env implementation that stubs out the calls to read a file and time.
+class FakeEnv : public EnvWrapper {
+ public:
+ enum EnvType {
+ kGoogle,
+ kGce,
+ kLocal,
+ kBad,
+ };
+
+ FakeEnv(EnvType env_type) : EnvWrapper(Env::Default()), env_type_(env_type) {}
+
+ class FakeRandomAccessFile : public RandomAccessFile {
+ public:
+ FakeRandomAccessFile(EnvType env_type) : env_type_(env_type) {}
+
+ Status Read(uint64 offset, size_t n, StringPiece* result,
+ char* scratch) const override;
+
+ private:
+ EnvType env_type_;
+ };
+
+ Status NewRandomAccessFile(
+ const string& fname, std::unique_ptr<RandomAccessFile>* result) override;
+
+ uint64 NowSeconds() override { return now; }
+ uint64 now = 10000;
+
+ private:
+ EnvType env_type_;
+};
+
+} // namespace test
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_FAKE_ENV_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/gce_env_utils.h"
+
+#if defined(PLATFORM_WINDOWS)
+#include <algorithm>
+#include <cctype>
+#include <iostream>
+#include <string>
+
+// The order if these includes is important, windows.h has to come first.
+// clang-format off
+#include <windows.h> // NOLINT
+#include <tchar.h> // NOLINT
+#include <shellapi.h> // NOLINT
+// clang-format on
+#else
+#include "tensorflow/core/lib/gtl/stl_util.h"
+#include "tensorflow/core/lib/strings/str_util.h"
+#endif
+
+namespace tensorflow {
+
+constexpr char kExpectedGoogleProductName[] = "Google";
+constexpr char kExpectedGceProductName[] = "Google Compute Engine";
+
+constexpr char kWinCheckCommand[] = "powershell.exe";
+constexpr char kWinCheckCommandArgs[] =
+ "(Get-WmiObject -Class Win32_BIOS).Manufacturer";
+
+constexpr char kLinuxProductNameFile[] = "/sys/class/dmi/id/product_name";
+
+const size_t kBiosDataBufferSize = 256;
+
+namespace {
+
+#if defined(PLATFORM_WINDOWS)
+
+Status IsRunningOnWinGce(bool* is_running_under_gce) {
+ *is_running_under_gce = FALSE;
+ SECURITY_ATTRIBUTES sa;
+ sa.nLength = sizeof(sa);
+ sa.lpSecurityDescriptor = NULL;
+ sa.bInheritHandle = TRUE;
+
+ // Handles to input and output of the pipe connecting us
+ // to the child process running powershell(). The output of this
+ // child process will be written to 'process_output_in' and read from
+ // 'process_output_in'.
+ HANDLE process_output_out = NULL;
+ HANDLE process_output_in = NULL;
+
+ // Create the actually pipe connecting us to the child process.
+ if (!CreatePipe(&process_output_out, &process_output_in, &sa, 0)) {
+ return errors::Internal("CreatePipe() failed");
+ }
+ if (!SetHandleInformation(process_output_out, HANDLE_FLAG_INHERIT, 0)) {
+ return errors::Internal("SetHandleInformation() failed");
+ }
+
+ PROCESS_INFORMATION pi;
+ STARTUPINFO si;
+ DWORD flags = CREATE_NO_WINDOW;
+ ZeroMemory(&pi, sizeof(pi));
+ ZeroMemory(&si, sizeof(si));
+ si.cb = sizeof(si);
+ si.dwFlags |= STARTF_USESTDHANDLES;
+ si.hStdInput = NULL;
+
+ // Connect the process to pipe's input.
+ si.hStdError = process_output_in;
+ si.hStdOutput = process_output_in;
+ // Execute (and wait for) powershell command to read the product information
+ // out of the registry.
+ TCHAR cmd[kBiosDataBufferSize];
+ snprintf(cmd, kBiosDataBufferSize, "%s %s", _T(kWinCheckCommand),
+ _T(kWinCheckCommandArgs));
+
+ if (!CreateProcess(NULL, cmd, NULL, NULL, TRUE, flags, NULL, NULL, &si,
+ &pi)) {
+ return errors::Internal("CreateProcess() failed");
+ }
+
+ WaitForSingleObject(pi.hProcess, INFINITE);
+ CloseHandle(pi.hProcess);
+ CloseHandle(pi.hThread);
+
+ // Read data from the pipe. Note that we are reading only kBiosDataBufferSize
+ // chars. There might be technically more data than that but we are looking
+ // for Google product identifiers that are much shorter than
+ // kBiosDataBufferSize.
+ DWORD dwread = 0;
+ CHAR buffer[kBiosDataBufferSize];
+ if (!ReadFile(process_output_out, buffer, kBiosDataBufferSize, &dwread,
+ NULL)) {
+ return errors::Internal("Failed reading from the pipe.");
+ }
+ std::string output(buffer, 0, dwread);
+ // Trim whitespaces
+ output.erase(output.begin(),
+ std::find_if(output.begin(), output.end(),
+ [](int ch) { return !std::isspace(ch); }));
+ output.erase(std::find_if(output.rbegin(), output.rend(),
+ [](int ch) { return !std::isspace(ch); })
+ .base(),
+ output.end());
+ *is_running_under_gce =
+ output == kExpectedGceProductName || output == kExpectedGoogleProductName;
+ return Status::OK();
+}
+
+#else
+
+Status IsRunningOnLinuxGce(Env* env, bool* is_running_under_gce) {
+ std::unique_ptr<RandomAccessFile> file;
+ TF_RETURN_IF_ERROR(env->NewRandomAccessFile(kLinuxProductNameFile, &file));
+ char buf[kBiosDataBufferSize + 1];
+ std::fill(buf, buf + kBiosDataBufferSize + 1, '\0');
+ StringPiece product_name;
+ const Status s = file->Read(0, kBiosDataBufferSize, &product_name, buf);
+ if (!s.ok() && !errors::IsOutOfRange(s)) {
+ // We expect OutOfRange error because bios file doesn't correspond to its
+ // state size,
+ return s;
+ }
+ str_util::RemoveLeadingWhitespace(&product_name);
+ str_util::RemoveTrailingWhitespace(&product_name);
+ *is_running_under_gce = (product_name == kExpectedGceProductName ||
+ product_name == kExpectedGoogleProductName);
+ return Status::OK();
+}
+
+#endif
+
+} // namespace
+
+Status IsRunningOnGce(Env* env, bool* is_running_under_gce) {
+ *is_running_under_gce = false;
+#if defined(PLATFORM_WINDOWS)
+ return IsRunningOnWinGce(is_running_under_gce);
+#else
+ return IsRunningOnLinuxGce(env, is_running_under_gce);
+#endif
+}
+
+} // namespace tensorflow
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#ifndef TENSORFLOW_CORE_PLATFORM_CLOUD_GCE_ENV_UTILS_H_
+#define TENSORFLOW_CORE_PLATFORM_CLOUD_GCE_ENV_UTILS_H_
+
+#include "tensorflow/core/lib/core/status.h"
+#include "tensorflow/core/platform/env.h"
+
+namespace tensorflow {
+
+// Check whether the current process is running under GCE.
+Status IsRunningOnGce(Env* env, bool* is_running_under_gce);
+
+} // namespace tensorflow
+
+#endif // TENSORFLOW_CORE_PLATFORM_CLOUD_GCE_ENV_UTILS_H_
--- /dev/null
+/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#include "tensorflow/core/platform/cloud/gce_env_utils.h"
+
+#include "tensorflow/core/lib/core/status_test_util.h"
+#include "tensorflow/core/platform/cloud/fake_env.h"
+#include "tensorflow/core/platform/test.h"
+
+namespace tensorflow {
+
+namespace {
+
+TEST(GcpEnvUtils, IsRunningOnGce) {
+ {
+ test::FakeEnv env(test::FakeEnv::kGoogle);
+ bool is_running_on_gcp = false;
+ TF_EXPECT_OK(IsRunningOnGce(&env, &is_running_on_gcp));
+ EXPECT_TRUE(is_running_on_gcp);
+ }
+ {
+ test::FakeEnv env(test::FakeEnv::kGce);
+ bool is_running_on_gcp = false;
+ TF_EXPECT_OK(IsRunningOnGce(&env, &is_running_on_gcp));
+ EXPECT_TRUE(is_running_on_gcp);
+ }
+ {
+ test::FakeEnv env(test::FakeEnv::kLocal);
+ bool is_running_on_gcp = false;
+ TF_EXPECT_OK(IsRunningOnGce(&env, &is_running_on_gcp));
+ EXPECT_FALSE(is_running_on_gcp);
+ }
+ {
+ test::FakeEnv env(test::FakeEnv::kBad);
+ bool is_running_on_gcp = false;
+ EXPECT_TRUE(errors::IsInternal(IsRunningOnGce(&env, &is_running_on_gcp)));
+ }
+}
+
+} // namespace
+} // namespace tensorflow
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/base64.h"
#include "tensorflow/core/platform/cloud/curl_http_request.h"
+#include "tensorflow/core/platform/cloud/gce_env_utils.h"
#include "tensorflow/core/platform/cloud/retrying_utils.h"
#include "tensorflow/core/platform/env.h"
}
Status GoogleAuthProvider::GetTokenFromGce() {
+ if (!is_running_on_gce_.has_value()) {
+ bool is_running_on_gce = false;
+ TF_RETURN_IF_ERROR(IsRunningOnGce(env_, &is_running_on_gce));
+ is_running_on_gce_ = is_running_on_gce;
+ }
+ if (!is_running_on_gce_.value()) {
+ // Assume bucket is world-accessible. If not, the access will be rejected.
+ current_token_ = "";
+ return Status::OK();
+ }
const auto get_token_from_gce = [this]() {
std::unique_ptr<HttpRequest> request(http_request_factory_->Create());
std::vector<char> response_buffer;
#define TENSORFLOW_CORE_PLATFORM_GOOGLE_AUTH_PROVIDER_H_
#include <memory>
+#include "tensorflow/core/lib/gtl/optional.h"
#include "tensorflow/core/platform/cloud/auth_provider.h"
#include "tensorflow/core/platform/cloud/oauth_client.h"
#include "tensorflow/core/platform/mutex.h"
/// standard gcloud tool's location.
Status GetTokenFromFiles() EXCLUSIVE_LOCKS_REQUIRED(mu_);
- /// Gets the bearer token from Google Compute Engine environment.
+ /// Gets the bearer token from Google Compute Engine environment. May return
+ /// an empty token if the current process is not running under GCE. If that
+ /// happens the caller will try to use the empty token and either succeed
+ /// if the resource is publicly accessible or fail with a permissions error.
Status GetTokenFromGce() EXCLUSIVE_LOCKS_REQUIRED(mu_);
/// Gets the bearer token from the systen env variable, for testing purposes.
Env* env_;
mutex mu_;
string current_token_ GUARDED_BY(mu_);
+ tensorflow::gtl::optional<bool> is_running_on_gce_ GUARDED_BY(mu_);
uint64 expiration_timestamp_sec_ GUARDED_BY(mu_) = 0;
// The initial delay for exponential backoffs when retrying failed calls.
const int64 initial_retry_delay_usec_;
#include <stdlib.h>
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
+#include "tensorflow/core/platform/cloud/fake_env.h"
#include "tensorflow/core/platform/cloud/http_request_fake.h"
#include "tensorflow/core/platform/test.h"
constexpr char kTestData[] = "core/platform/cloud/testdata/";
-class FakeEnv : public EnvWrapper {
- public:
- FakeEnv() : EnvWrapper(Env::Default()) {}
-
- uint64 NowSeconds() override { return now; }
- uint64 now = 10000;
-};
-
class FakeOAuthClient : public OAuthClient {
public:
Status GetTokenFromServiceAccountJson(
auto oauth_client = new FakeOAuthClient;
std::vector<HttpRequest*> requests;
- FakeEnv env;
+ test::FakeEnv env(test::FakeEnv::kGoogle);
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
auto oauth_client = new FakeOAuthClient;
std::vector<HttpRequest*> requests;
- FakeEnv env;
+ test::FakeEnv env(test::FakeEnv::kGoogle);
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
"token_type":"Bearer"
})")});
- FakeEnv env;
+ test::FakeEnv env(test::FakeEnv::kGoogle);
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
auto oauth_client = new FakeOAuthClient;
std::vector<HttpRequest*> empty_requests;
- FakeEnv env;
+ test::FakeEnv env(test::FakeEnv::kGoogle);
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&empty_requests)),
"Header Metadata-Flavor: Google\n",
"", errors::NotFound("404"), 404)});
- FakeEnv env;
+ test::FakeEnv env(test::FakeEnv::kGoogle);
+ GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
+ std::unique_ptr<HttpRequest::Factory>(
+ new FakeHttpRequestFactory(&requests)),
+ &env, 0);
+
+ string token;
+ TF_EXPECT_OK(provider.GetToken(&token));
+ EXPECT_EQ("", token);
+}
+
+TEST_F(GoogleAuthProviderTest, AccessingPublicBucket) {
+ setenv("CLOUDSDK_CONFIG",
+ io::JoinPath(testing::TensorFlowSrcRoot(), kTestData).c_str(), 1);
+
+ auto oauth_client = new FakeOAuthClient;
+ std::vector<HttpRequest*> requests;
+
+ test::FakeEnv env(test::FakeEnv::kLocal);
GoogleAuthProvider provider(std::unique_ptr<OAuthClient>(oauth_client),
std::unique_ptr<HttpRequest::Factory>(
new FakeHttpRequestFactory(&requests)),
string token;
TF_EXPECT_OK(provider.GetToken(&token));
+ // We are assuming we are accessing a public bucket (and we are not running
+ // on GCE) so we an empty token is returned.
EXPECT_EQ("", token);
}
/// The ownership of the returned RandomAccessFile is passed to the caller
/// and the object should be deleted when is not used. The file object
/// shouldn't live longer than the Env object.
- Status NewRandomAccessFile(const string& fname,
- std::unique_ptr<RandomAccessFile>* result);
+ virtual Status NewRandomAccessFile(const string& fname,
+ std::unique_ptr<RandomAccessFile>* result);
/// \brief Creates an object that writes to a new file with the specified
/// name.