template <typename NativeT, typename FnType>
Status Populate(const FnType& generator);
+ // A parallel version of Populate(). This can be used if the generator is
+ // thread-safe and the values for the shape's different elements are
+ // independent.
+ template <typename NativeT, typename FnType>
+ Status PopulateParallel(const FnType& generator);
+
// Fills this literal with the given value.
template <typename NativeT>
void PopulateWithValue(NativeT value);
// buffer).
void DeallocateBuffers();
+ // Implementation details shared between Populate() and PopulateParallel()
+ template <typename NativeT, typename FnType>
+ Status PopulateInternal(const FnType& generator, bool parallel);
+
Shape shape_;
ShapeTree<Piece> pieces_;
}
// Shared implementation behind Populate() and PopulateParallel(): fills every
// element of this dense literal with generator(multi_index). When `parallel`
// is true the per-row initialization is dispatched via
// ShapeUtil::ForEachIndexParallel, so `generator` must be thread-safe.
// NOTE(review): parts of the body appear elided in this chunk — `literal_data`
// and the loop header over `i` (the minor-dimension scan) are not visible.
// Confirm against the full file.
template <typename NativeT, typename FnType>
-Status Literal::Populate(const FnType& generator) {
+Status Literal::PopulateInternal(const FnType& generator, bool parallel) {
  const Shape& this_shape = shape();
  const int64 rank = ShapeUtil::Rank(this_shape);
  TF_RET_CHECK(LayoutUtil::IsDenseArray(this_shape));
  if (rank > 0) {
    StrideConfig stride_config(this_shape, this_shape,
                               AsInt64Slice(this_shape.dimensions()));
-    DimensionVector minor_scan_indexes(rank, 0);
    int64 minor_dimension_size =
        ShapeUtil::GetDimension(this_shape, stride_config.minor_dimension);
    // Initializes one minor-dimension "row" per invocation. The scratch index
    // vector was moved inside the lambda (see the -/+ lines) so that each
    // concurrent invocation owns its own copy — required for the parallel
    // path to be data-race free.
    auto init_function = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+      DimensionVector minor_scan_indexes(rank, 0);
      const int64 index =
          IndexUtil::MultidimensionalIndexToLinearIndex(shape(), indexes);
      std::copy(indexes.begin(), indexes.end(), minor_scan_indexes.begin());
        minor_scan_indexes[stride_config.minor_dimension] = i;
        literal_data.at(index + i) = generator(minor_scan_indexes);
      }
-      return true;
    };
-    ShapeUtil::ForEachIndex(this_shape, stride_config.base,
-                            stride_config.dimensions, stride_config.step,
-                            init_function);
+    if (parallel) {
+      ShapeUtil::ForEachIndexParallel(this_shape, stride_config.base,
+                                      stride_config.dimensions,
+                                      stride_config.step, init_function);
+    } else {
+      // ForEachIndex expects a "should continue" result, so wrap the
+      // void-returning init_function in an always-true adapter.
+      ShapeUtil::ForEachIndex(
+          this_shape, stride_config.base, stride_config.dimensions,
+          stride_config.step,
+          [&init_function](tensorflow::gtl::ArraySlice<int64> indexes) {
+            init_function(indexes);
+            return true;
+          });
+    }
  } else {
    // For scalars.
    literal_data.at(0) = generator({});
  }
  return Status::OK();
}
+// Sequential element initialization: delegates to the shared implementation
+// with the parallel path disabled, so `generator` need not be thread-safe.
+template <typename NativeT, typename FnType>
+Status Literal::Populate(const FnType& generator) {
+  const bool use_parallel_iteration = false;
+  return PopulateInternal<NativeT>(generator, use_parallel_iteration);
+}
+
+// Multi-threaded variant of Populate(). Requires a thread-safe generator
+// whose per-element results do not depend on evaluation order.
+template <typename NativeT, typename FnType>
+Status Literal::PopulateParallel(const FnType& generator) {
+  const bool use_parallel_iteration = true;
+  return PopulateInternal<NativeT>(generator, use_parallel_iteration);
+}
// Fills this literal with `value`.
// NOTE(review): the body appears elided in this chunk (it is empty here) —
// confirm the real implementation against the full file before relying on it.
template <typename NativeT>
void Literal::PopulateWithValue(NativeT value) {
}
}
+TEST_F(LiteralUtilTest, PopulateParallel) {
+  // Each case is a (dimensions, minor-to-major layout) pair. The set covers
+  // scalars, zero-element shapes, and ranks 1 through 4 with mixed layouts.
+  struct TestCase {
+    std::vector<int64> dimensions;
+    std::vector<int64> layout;
+  };
+  const TestCase test_cases[] = {
+      {{}, {}},
+      {{0}, {0}},
+      {{16}, {0}},
+      {{2, 0}, {1, 0}},
+      {{4, 16}, {1, 0}},
+      {{21, 12}, {0, 1}},
+      {{6, 11, 17}, {2, 0, 1}},
+      {{6, 11, 5, 17}, {3, 2, 0, 1}},
+  };
+  for (const auto& test_case : test_cases) {
+    Shape shape = ShapeUtil::MakeShapeWithLayout(
+        primitive_util::NativeToPrimitiveType<uint32>(), test_case.dimensions,
+        test_case.layout);
+    auto literal = Literal::CreateFromShape(shape);
+    auto value_at = [&](ArraySlice<int64> indexes) -> uint32 {
+      // Offsets from linear index just to avoid R0 literals to be initialized
+      // with zero.
+      return IndexUtil::MultidimensionalIndexToLinearIndex(literal->shape(),
+                                                           indexes) +
+             17;
+    };
+    TF_EXPECT_OK(literal->PopulateParallel<uint32>(value_at));
+
+    // Re-walk every index sequentially and compare each stored element
+    // against what the generator would have produced.
+    const std::vector<int64> zero_base(test_case.dimensions.size(), 0);
+    const std::vector<int64> unit_step(test_case.dimensions.size(), 1);
+    bool all_match = true;
+    auto verify_cell = [&](ArraySlice<int64> indexes) {
+      const uint32 stored = literal->Get<uint32>(indexes);
+      all_match = all_match && (stored == value_at(indexes));
+      return all_match;
+    };
+    ShapeUtil::ForEachIndex(literal->shape(), zero_base, test_case.dimensions,
+                            unit_step, verify_cell);
+    EXPECT_TRUE(all_match);
+  }
+}
+
TEST_F(LiteralUtilTest, ConvertR4) {
// clang-format off
auto original = Literal::CreateR4WithLayout<int8>({{
const Literal& lhs_literal = parent_->GetEvaluatedLiteralFor(lhs);
const Literal& rhs_literal = parent_->GetEvaluatedLiteralFor(rhs);
- // Dimension number applicable for input (lhs).
- const int64 input_batch_dim = dnums.input_batch_dimension();
- const int64 input_z_dim = dnums.input_feature_dimension();
- // Dimension number applicable for kernel (rhs).
- const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
- const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
- // Dimension number applicable for output.
- const int64 output_batch_dim = dnums.output_batch_dimension();
- const int64 output_z_dim = dnums.output_feature_dimension();
-
- const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
-
std::vector<int64> window_dimension_sizes;
for (auto i : dnums.kernel_spatial_dimensions()) {
window_dimension_sizes.push_back(ShapeUtil::GetDimension(rhs_shape, i));
DimensionVector lhs_dim_multipliers = MakeDimMultipliers(lhs_shape);
DimensionVector rhs_dim_multipliers = MakeDimMultipliers(rhs_shape);
- DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size());
-
auto lhs_literal_data = lhs_literal.data<ReturnT>();
auto rhs_literal_data = rhs_literal.data<ReturnT>();
- auto func = [&](ArraySlice<int64> out_index) {
+ auto func = [&window_shape, &dnums, &lhs_shape, &rhs_shape, &window,
+ &lhs_dim_multipliers, &rhs_dim_multipliers, lhs_literal_data,
+ rhs_literal_data](ArraySlice<int64> out_index) {
+ // Dimension number applicable for input (lhs).
+ const int64 input_batch_dim = dnums.input_batch_dimension();
+ const int64 input_z_dim = dnums.input_feature_dimension();
+ // Dimension number applicable for kernel (rhs).
+ const int64 kernel_input_z_dim = dnums.kernel_input_feature_dimension();
+ const int64 kernel_output_z_dim = dnums.kernel_output_feature_dimension();
+ // Dimension number applicable for output.
+ const int64 output_batch_dim = dnums.output_batch_dimension();
+ const int64 output_z_dim = dnums.output_feature_dimension();
+
+ const int64 z_size = ShapeUtil::GetDimension(lhs_shape, input_z_dim);
+
ElementwiseT result_val = static_cast<ElementwiseT>(0);
- std::fill(rhs_spatial_index.begin(), rhs_spatial_index.end(), 0);
+ DimensionVector rhs_spatial_index(dnums.kernel_spatial_dimensions_size(),
+ 0);
// Convolve input feature with kernel.
do {
};
auto result = Literal::CreateFromShape(result_shape);
- TF_RETURN_IF_ERROR(result->Populate<ReturnT>(func));
+ TF_RETURN_IF_ERROR(result->PopulateParallel<ReturnT>(func));
parent_->evaluated_[conv] = std::move(result);
return Status::OK();
#include "tensorflow/compiler/xla/statusor.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
+#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/optional.h"
+#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/types.h"
tensorflow::gtl::ArraySlice<int64> count,
tensorflow::gtl::ArraySlice<int64> incr,
const FnType& visitor_function) {
- if (ShapeUtil::HasZeroElements(shape)) {
- return Status::OK();
- }
- CHECK_EQ(Rank(shape), base.size());
- CHECK_EQ(incr.size(), base.size());
- CHECK_EQ(count.size(), base.size());
- const int64 rank = LayoutUtil::MinorToMajor(shape).size();
- // Allows handling R0 arrays, such that the visitor function will be called
- // once with the proper empty indexes.
- int64 n = -1;
- std::vector<int64> indexes(base.begin(), base.end());
- while (n < rank) {
- TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
- if (!should_continue) {
- break;
- }
- // Increments dimensions in minor to major order.
- for (n = 0; n < rank; ++n) {
- int64 dim = LayoutUtil::Minor(shape.layout(), n);
- indexes[dim] += incr[dim];
- if (indexes[dim] < base[dim] + count[dim]) {
- break;
- }
- indexes[dim] = base[dim];
- }
- }
-
- return Status::OK();
+ return ForEachIndexInternal(shape, base, count, incr, visitor_function);
}
// Simple ergonomic wrapper around ShapeUtil::ForEachIndexWithStatus.
.IgnoreError();
}
+  // A parallel version of ForEachIndex(WithStatus). This can only be used if
+  // the visitor_function is thread-safe and the order of iteration does not
+  // matter.
+  //
+  // visitor_function must be a callable of type
+  // void(ArraySlice<int64>) or compatible.
+  template <typename FnType>
+  static void ForEachIndexParallel(const Shape& shape,
+                                   tensorflow::gtl::ArraySlice<int64> base,
+                                   tensorflow::gtl::ArraySlice<int64> count,
+                                   tensorflow::gtl::ArraySlice<int64> incr,
+                                   const FnType& visitor_function) {
+    // One worker per schedulable core; the pool's destructor joins all
+    // scheduled work before this function returns.
+    const int num_threads = tensorflow::port::NumSchedulableCPUs();
+    tensorflow::thread::ThreadPool pool(tensorflow::Env::Default(), "test",
+                                        num_threads);
+    // Adapt the void-returning visitor to the StatusOr<bool> "keep going"
+    // protocol that ForEachIndexInternal expects.
+    auto adapted_visitor =
+        [&visitor_function](tensorflow::gtl::ArraySlice<int64> indexes)
+        -> StatusOr<bool> {
+      visitor_function(indexes);
+      return true;
+    };
+    // If a pool is provided, ForEachIndexInternal can never fail.
+    CHECK(
+        ForEachIndexInternal(shape, base, count, incr, adapted_visitor, &pool)
+            .ok());
+  }
+
private:
// Validates all of the non-layout properties of the shape -- this is a helper
// used by both the layout-optional and layout-required public method.
static Status ValidateShapeWithOptionalLayoutInternal(const Shape& shape);
+  // Visits every index of `shape` starting at `base`, covering `count`
+  // elements per dimension with stride `incr`, advancing in minor-to-major
+  // order. If `pool` is non-null each visit is scheduled on the pool instead
+  // of run inline; completion is then only guaranteed once the pool drains
+  // (e.g. via its destructor in ForEachIndexParallel).
+  template <typename FnType>
+  static Status ForEachIndexInternal(
+      const Shape& shape, tensorflow::gtl::ArraySlice<int64> base,
+      tensorflow::gtl::ArraySlice<int64> count,
+      tensorflow::gtl::ArraySlice<int64> incr, const FnType& visitor_function,
+      tensorflow::thread::ThreadPool* pool = nullptr) {
+    if (ShapeUtil::HasZeroElements(shape)) {
+      return Status::OK();
+    }
+    CHECK_EQ(Rank(shape), base.size());
+    CHECK_EQ(incr.size(), base.size());
+    CHECK_EQ(count.size(), base.size());
+    const int64 rank = LayoutUtil::MinorToMajor(shape).size();
+    // Allows handling R0 arrays, such that the visitor function will be called
+    // once with the proper empty indexes.
+    int64 n = -1;
+    std::vector<int64> indexes(base.begin(), base.end());
+    while (n < rank) {
+      if (pool != nullptr) {
+        // `indexes` is captured by copy: the loop below keeps mutating it
+        // while the scheduled task may run concurrently.
+        // NOTE(review): in the pooled mode the visitor's should-continue
+        // result is discarded, so early termination is not supported there.
+        pool->Schedule(
+            [indexes, visitor_function] { visitor_function(indexes); });
+      } else {
+        TF_ASSIGN_OR_RETURN(bool should_continue, visitor_function(indexes));
+        if (!should_continue) {
+          break;
+        }
+      }
+      // Increments dimensions in minor to major order.
+      for (n = 0; n < rank; ++n) {
+        int64 dim = LayoutUtil::Minor(shape.layout(), n);
+        indexes[dim] += incr[dim];
+        if (indexes[dim] < base[dim] + count[dim]) {
+          break;
+        }
+        indexes[dim] = base[dim];
+      }
+    }
+
+    return Status::OK();
+  }
+
TF_DISALLOW_COPY_AND_ASSIGN(ShapeUtil);
};
EXPECT_EQ(invocations, 5);
}
+TEST(ShapeUtilTest, ForEachIndexParallel) {
+  // Each (row, col) cell is written by exactly one visitor invocation, so the
+  // parallel traversal needs no synchronization inside the visitor.
+  Shape shape = ShapeUtil::MakeShape(F32, {10, 10});
+  int64 output[10][10];
+  int init = 5;
+  auto write_cell = [&](tensorflow::gtl::ArraySlice<int64> indexes) {
+    output[indexes[0]][indexes[1]] = init + indexes[0] + indexes[1];
+  };
+
+  ShapeUtil::ForEachIndexParallel(shape, /*base=*/{0, 0}, /*count=*/{10, 10},
+                                  /*incr=*/{1, 1}, write_cell);
+
+  // ForEachIndexParallel joins its worker pool before returning, so every
+  // cell must already hold its expected value here.
+  for (int row = 0; row < 10; ++row) {
+    for (int col = 0; col < 10; ++col) {
+      EXPECT_EQ(output[row][col], init + row + col);
+    }
+  }
+}
+
TEST(ShapeUtilTest, DimensionsUnmodifiedByReshape_1x1x1x1_to_1x1x1) {
// All output dimensions should be unmodified. One of the input dimensions is
// modified because the input rank is larger by one.