typedef typename TypeParam::Dtype Dtype;
protected:
SoftmaxLayerTest()
- : blob_bottom_(new Blob<Dtype>(2, 10, 1, 1)),
+ : blob_bottom_(new Blob<Dtype>(2, 10, 2, 3)),
blob_top_(new Blob<Dtype>()) {
// fill the values
FillerParameter filler_param;
layer.Forward(this->blob_bottom_vec_, &(this->blob_top_vec_));
// Test sum
for (int i = 0; i < this->blob_bottom_->num(); ++i) {
- Dtype sum = 0;
- for (int j = 0; j < this->blob_top_->channels(); ++j) {
- sum += this->blob_top_->data_at(i, j, 0, 0);
- }
- EXPECT_GE(sum, 0.999);
- EXPECT_LE(sum, 1.001);
- }
- // Test exact values
- for (int i = 0; i < this->blob_bottom_->num(); ++i) {
- Dtype scale = 0;
- for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
- scale += exp(this->blob_bottom_->data_at(i, j, 0, 0));
- }
- for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
- EXPECT_GE(this->blob_top_->data_at(i, j, 0, 0) + 1e-4,
- exp(this->blob_bottom_->data_at(i, j, 0, 0)) / scale)
- << "debug: " << i << " " << j;
- EXPECT_LE(this->blob_top_->data_at(i, j, 0, 0) - 1e-4,
- exp(this->blob_bottom_->data_at(i, j, 0, 0)) / scale)
- << "debug: " << i << " " << j;
+ for (int k = 0; k < this->blob_bottom_->height(); ++k) {
+ for (int l = 0; l < this->blob_bottom_->width(); ++l) {
+ Dtype sum = 0;
+ for (int j = 0; j < this->blob_top_->channels(); ++j) {
+ sum += this->blob_top_->data_at(i, j, k, l);
+ }
+ EXPECT_GE(sum, 0.999);
+ EXPECT_LE(sum, 1.001);
+ // Test exact values
+ Dtype scale = 0;
+ for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
+ scale += exp(this->blob_bottom_->data_at(i, j, k, l));
+ }
+ for (int j = 0; j < this->blob_bottom_->channels(); ++j) {
+ EXPECT_GE(this->blob_top_->data_at(i, j, k, l) + 1e-4,
+ exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)
+ << "debug: " << i << " " << j;
+ EXPECT_LE(this->blob_top_->data_at(i, j, k, l) - 1e-4,
+ exp(this->blob_bottom_->data_at(i, j, k, l)) / scale)
+ << "debug: " << i << " " << j;
+ }
+ }
}
}
}
protected:
SoftmaxWithLossLayerTest()
- : blob_bottom_data_(new Blob<Dtype>(10, 5, 1, 1)),
- blob_bottom_label_(new Blob<Dtype>(10, 1, 1, 1)),
+ : blob_bottom_data_(new Blob<Dtype>(10, 5, 2, 3)),
+ blob_bottom_label_(new Blob<Dtype>(10, 1, 2, 3)),
blob_top_loss_(new Blob<Dtype>()) {
// fill the values
FillerParameter filler_param;