gemm util
authorYangqing Jia <jiayq84@gmail.com>
Wed, 18 Sep 2013 20:51:30 +0000 (13:51 -0700)
committerYangqing Jia <jiayq84@gmail.com>
Wed, 18 Sep 2013 20:51:30 +0000 (13:51 -0700)
src/caffeine/layer.hpp
src/caffeine/proto/layer_param.proto
src/caffeine/test/test_util_gemm.cpp [new file with mode: 0644]
src/caffeine/util/gemm.cpp [new file with mode: 0644]
src/caffeine/util/gemm.hpp [new file with mode: 0644]

index 6bdacb1..4a9247d 100644 (file)
@@ -32,6 +32,7 @@ class Layer {
       const bool propagate_down,
       vector<Blob<Dtype>*>* bottom);
 
+  // Returns the vector of parameters.
   vector<Blob<Dtype> >& params() { return blobs_; };
 
  protected:
index 806c405..27246ff 100644 (file)
@@ -31,14 +31,13 @@ message LayerParameter {
   optional FillerParameter bias_filler = 6; // The filler for the bias
 
   optional uint32 pad = 7 [default = 0]; // The padding size
-  optional uint32 kernelsize = 8; // The kernel size
-  optional uint32 group = 9 [default = 1]; // The group size for group conv
-  optional uint32 stride = 10 [default = 1]; // The stride
-  optional string pool = 11 [default = 'max']; // The pooling method
-  optional float dropout_ratio = 12 [default = 0.5]; // dropout ratio
+  optional float pad_value = 8 [default = 0]; // The padding value
+  optional uint32 kernelsize = 9; // The kernel size
+  optional uint32 group = 10 [default = 1]; // The group size for group conv
+  optional uint32 stride = 11 [default = 1]; // The stride
+  optional string pool = 12 [default = 'max']; // The pooling method
+  optional float dropout_ratio = 13 [default = 0.5]; // dropout ratio
 
-  optional float alpha = 13 [default = 1.]; // for local response norm
-  optional float beta = 14 [default = 0.75]; // for local response norm
-
-  repeated BlobProto blobs = 50; // for possible data.
+  optional float alpha = 14 [default = 1.]; // for local response norm
+  optional float beta = 15 [default = 0.75]; // for local response norm
 }
diff --git a/src/caffeine/test/test_util_gemm.cpp b/src/caffeine/test/test_util_gemm.cpp
new file mode 100644 (file)
index 0000000..9ea7160
--- /dev/null
@@ -0,0 +1,92 @@
+#include <cstring>
+#include <cuda_runtime.h>
+#include <mkl.h>
+#include <cublas_v2.h>
+
+#include "gtest/gtest.h"
+#include "caffeine/blob.hpp"
+#include "caffeine/util/gemm.hpp"
+
+namespace caffeine {
+
+extern cudaDeviceProp CAFFEINE_TEST_CUDA_PROP;
+
+typedef ::testing::Types<float, double> Dtypes;
+
+template <typename Dtype>
+class GemmTest : public ::testing::Test {};
+
+TYPED_TEST_CASE(GemmTest, Dtypes);
+
+TYPED_TEST(GemmTest, TestGemm) {
+  Blob<TypeParam> A(1,1,2,3);
+  Blob<TypeParam> B(1,1,3,4);
+  Blob<TypeParam> C(1,1,2,4);
+  TypeParam data[12] = {1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12};
+  TypeParam A_reshape_data[6] = {1, 4, 2, 5, 3, 6};
+  TypeParam B_reshape_data[12] = {1,5,9,2,6,10,3,7,11,4,8,12};
+  TypeParam result[8] = {38,44,50,56,83,98,113,128};
+  memcpy(A.mutable_cpu_data(), data, 6 * sizeof(TypeParam));
+  memcpy(B.mutable_cpu_data(), data, 12 * sizeof(TypeParam));
+
+  if (sizeof(TypeParam) == 4 || CAFFEINE_TEST_CUDA_PROP.major >= 2) {
+    //[1,2,3; 4 5 6] * [1,2,3,4; 5,6,7,8; 9,10,11,12];
+    decaf_cpu_gemm<TypeParam>(CblasNoTrans, CblasNoTrans, 2, 4, 3, 1.,
+        A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data());
+    for (int i = 0; i < 8; ++i) {
+      EXPECT_EQ(C.cpu_data()[i], result[i]);
+    }
+    decaf_gpu_gemm<TypeParam>(CblasNoTrans, CblasNoTrans, 2, 4, 3, 1.,
+        A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data());
+    for (int i = 0; i < 8; ++i) {
+      EXPECT_EQ(C.cpu_data()[i], result[i]);
+    }
+
+    // Test when we have a transposed A
+    A.Reshape(1,1,3,2);
+    memcpy(A.mutable_cpu_data(), A_reshape_data, 6 * sizeof(TypeParam));
+    decaf_cpu_gemm<TypeParam>(CblasTrans, CblasNoTrans, 2, 4, 3, 1.,
+        A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data());
+    for (int i = 0; i < 8; ++i) {
+      EXPECT_EQ(C.cpu_data()[i], result[i]);
+    }
+    decaf_gpu_gemm<TypeParam>(CblasTrans, CblasNoTrans, 2, 4, 3, 1.,
+        A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data());
+    for (int i = 0; i < 8; ++i) {
+      EXPECT_EQ(C.cpu_data()[i], result[i]);
+    }
+
+    // Test when we have a transposed A and a transposed B too
+    B.Reshape(1,1,4,3);
+    memcpy(B.mutable_cpu_data(), B_reshape_data, 12 * sizeof(TypeParam));
+    decaf_cpu_gemm<TypeParam>(CblasTrans, CblasTrans, 2, 4, 3, 1.,
+        A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data());
+    for (int i = 0; i < 8; ++i) {
+      EXPECT_EQ(C.cpu_data()[i], result[i]);
+    }
+    decaf_gpu_gemm<TypeParam>(CblasTrans, CblasTrans, 2, 4, 3, 1.,
+        A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data());
+    for (int i = 0; i < 8; ++i) {
+      EXPECT_EQ(C.cpu_data()[i], result[i]);
+    }
+
+    // Test when we have a transposed B
+    A.Reshape(1,1,2,3);
+    memcpy(A.mutable_cpu_data(), data, 6 * sizeof(TypeParam));
+    decaf_cpu_gemm<TypeParam>(CblasNoTrans, CblasTrans, 2, 4, 3, 1.,
+        A.cpu_data(), B.cpu_data(), 0., C.mutable_cpu_data());
+    for (int i = 0; i < 8; ++i) {
+      EXPECT_EQ(C.cpu_data()[i], result[i]);
+    }
+    decaf_gpu_gemm<TypeParam>(CblasNoTrans, CblasTrans, 2, 4, 3, 1.,
+        A.gpu_data(), B.gpu_data(), 0., C.mutable_gpu_data());
+    for (int i = 0; i < 8; ++i) {
+      EXPECT_EQ(C.cpu_data()[i], result[i]);
+    }
+  } else {
+    LOG(ERROR) << "Skipping test due to old architecture.";
+  }
+}
+
+
+}
diff --git a/src/caffeine/util/gemm.cpp b/src/caffeine/util/gemm.cpp
new file mode 100644 (file)
index 0000000..f6b9a1d
--- /dev/null
@@ -0,0 +1,63 @@
+#include <mkl.h>
+#include <cublas_v2.h>
+#include "caffeine/common.hpp"
+#include "caffeine/util/gemm.hpp"
+
+namespace caffeine {
+
+template<>
+void decaf_cpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+    const float alpha, const float* A, const float* B, const float beta,
+    float* C) {
+  int lda = (TransA == CblasNoTrans) ? K : M;
+  int ldb = (TransB == CblasNoTrans) ? N : K;
+  cblas_sgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
+      ldb, beta, C, N);
+}
+
+template<>
+void decaf_cpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+    const double alpha, const double* A, const double* B, const double beta,
+    double* C) {
+  int lda = (TransA == CblasNoTrans) ? K : M;
+  int ldb = (TransB == CblasNoTrans) ? N : K;
+  cblas_dgemm(CblasRowMajor, TransA, TransB, M, N, K, alpha, A, lda, B,
+      ldb, beta, C, N);
+}
+
+template <>
+void decaf_gpu_gemm<float>(const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+    const float alpha, const float* A, const float* B, const float beta,
+    float* C) {
+  // Note that cublas follows fortran order.
+  int lda = (TransA == CblasNoTrans) ? K : M;
+  int ldb = (TransB == CblasNoTrans) ? N : K;
+  cublasOperation_t cuTransA = 
+      (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  CUBLAS_CHECK(cublasSgemm(Caffeine::cublas_handle(), cuTransB, cuTransA,
+      N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));  
+}
+
+template <>
+void decaf_gpu_gemm<double>(const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+    const double alpha, const double* A, const double* B, const double beta,
+    double* C) {
+  // Note that cublas follows fortran order.
+  int lda = (TransA == CblasNoTrans) ? K : M;
+  int ldb = (TransB == CblasNoTrans) ? N : K;
+  cublasOperation_t cuTransA = 
+      (TransA == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  cublasOperation_t cuTransB =
+      (TransB == CblasNoTrans) ? CUBLAS_OP_N : CUBLAS_OP_T;
+  CUBLAS_CHECK(cublasDgemm(Caffeine::cublas_handle(), cuTransA, cuTransB,
+      N, M, K, &alpha, B, ldb, A, lda, &beta, C, N));  
+}
+
+
+}  // namespace caffeine
\ No newline at end of file
diff --git a/src/caffeine/util/gemm.hpp b/src/caffeine/util/gemm.hpp
new file mode 100644 (file)
index 0000000..f6af9c3
--- /dev/null
@@ -0,0 +1,28 @@
+#ifndef CAFFEINE_UTIL_GEMM_H_
+#define CAFFEINE_UTIL_GEMM_H_
+
+#include <mkl.h>
+#include <cublas_v2.h>
+
+namespace caffeine {
+
+// Decaf gemm provides a simpler interface to the gemm functions, with the
+// limitation that the data has to be contiguous in memory.
+template <typename Dtype>
+inline void decaf_cpu_gemm(const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+    const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta,
+    Dtype* C);
+
+// Decaf gpu gemm provides an interface that is almost the same as the cpu
+// gemm function - following the c convention and calling the fortran-order
+// gpu code under the hood.
+template <typename Dtype>
+void decaf_gpu_gemm(const CBLAS_TRANSPOSE TransA,
+    const CBLAS_TRANSPOSE TransB, const int M, const int N, const int K,
+    const Dtype alpha, const Dtype* A, const Dtype* B, const Dtype beta,
+    Dtype* C);
+
+}  // namespace caffeine
+
+#endif  // CAFFEINE_UTIL_GEMM_H_