namespace caffe {
-template <typename Dtype>
-class StochasticPoolingLayerTest : public ::testing::Test {
+template <typename TypeParam>
+class StochasticPoolingLayerTest : public MultiDeviceTest<TypeParam> {
+ typedef typename TypeParam::Dtype Dtype;
+
protected:
StochasticPoolingLayerTest()
: blob_bottom_(new Blob<Dtype>()),
vector<Blob<Dtype>*> blob_top_vec_;
};
-TYPED_TEST_CASE(StochasticPoolingLayerTest, TestDtypes);
+template <typename Dtype>
+class CPUStochasticPoolingLayerTest
+ : public StochasticPoolingLayerTest<CPUDevice<Dtype> > {
+};
+
+TYPED_TEST_CASE(CPUStochasticPoolingLayerTest, TestDtypes);
-TYPED_TEST(StochasticPoolingLayerTest, TestSetup) {
+TYPED_TEST(CPUStochasticPoolingLayerTest, TestSetup) {
LayerParameter layer_param;
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
pooling_param->set_kernel_size(3);
EXPECT_EQ(this->blob_top_->width(), 2);
}
-TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPU) {
- Caffe::set_mode(Caffe::GPU);
+#ifndef CPU_ONLY
+
+template <typename Dtype>
+class GPUStochasticPoolingLayerTest
+ : public StochasticPoolingLayerTest<GPUDevice<Dtype> > {
+};
+
+TYPED_TEST_CASE(GPUStochasticPoolingLayerTest, TestDtypes);
+
+TYPED_TEST(GPUStochasticPoolingLayerTest, TestStochastic) {
LayerParameter layer_param;
layer_param.set_phase(TRAIN);
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
EXPECT_GE(total / this->blob_top_->count(), 0.55);
}
-TYPED_TEST(StochasticPoolingLayerTest, TestStochasticGPUTestPhase) {
- Caffe::set_mode(Caffe::GPU);
+TYPED_TEST(GPUStochasticPoolingLayerTest, TestStochasticTestPhase) {
LayerParameter layer_param;
layer_param.set_phase(TEST);
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
}
}
-TYPED_TEST(StochasticPoolingLayerTest, TestGradientGPU) {
- Caffe::set_mode(Caffe::GPU);
+TYPED_TEST(GPUStochasticPoolingLayerTest, TestGradient) {
LayerParameter layer_param;
layer_param.set_phase(TRAIN);
PoolingParameter* pooling_param = layer_param.mutable_pooling_param();
this->blob_top_vec_);
}
-
+#endif
} // namespace caffe