[AutoTVM] fix argument type for curve feature (#3004)
authorLianmin Zheng <lianminzheng@gmail.com>
Thu, 11 Apr 2019 02:58:54 +0000 (10:58 +0800)
committerGitHub <noreply@github.com>
Thu, 11 Apr 2019 02:58:54 +0000 (10:58 +0800)
src/autotvm/touch_extractor.cc
tests/python/unittest/test_autotvm_feature.py

index e24e757..002b970 100644 (file)
@@ -514,10 +514,10 @@ TVM_REGISTER_API("autotvm.feature.GetItervarFeatureFlatten")
 TVM_REGISTER_API("autotvm.feature.GetCurveSampleFeatureFlatten")
 .set_body([](TVMArgs args, TVMRetValue *ret) {
   Stmt stmt = args[0];
-  bool take_log = args[1];
+  int sample_n = args[1];
   std::vector<float> ret_feature;
 
-  GetCurveSampleFeatureFlatten(stmt, take_log, &ret_feature);
+  GetCurveSampleFeatureFlatten(stmt, sample_n, &ret_feature);
 
   TVMByteArray arr;
   arr.size = sizeof(float) * ret_feature.size();
index 401a8d3..e0736c2 100644 (file)
@@ -61,6 +61,23 @@ def test_iter_feature_gemm():
             assert ans[pair[0]] == pair[1:], "%s: %s vs %s" % (pair[0], ans[pair[0]], pair[1:])
 
 
+def test_curve_feature_gemm():
+    N = 128
+
+    k = tvm.reduce_axis((0, N), 'k')
+    A = tvm.placeholder((N, N), name='A')
+    B = tvm.placeholder((N, N), name='B')
+    C = tvm.compute(
+        A.shape,
+        lambda y, x: tvm.sum(A[y, k] * B[k, x], axis=k),
+        name='C')
+
+    s = tvm.create_schedule(C.op)
+
+    feas = feature.get_buffer_curve_sample_flatten(s, [A, B, C], sample_n=30)
+    # sample_n * #buffers * #curves * 2 numbers per curve
+    assert len(feas) == 30 * 3 * 4 * 2
+
 def test_feature_shape():
     """test the dimensions of flatten feature are the same"""
 
@@ -112,4 +129,6 @@ def test_feature_shape():
 
 if __name__ == "__main__":
     test_iter_feature_gemm()
+    test_curve_feature_gemm()
     test_feature_shape()
+