[TOPI] Fix CUDA Library Tuning (#6132)
author Cody Yu <comaniac0422@gmail.com>
Sat, 25 Jul 2020 04:09:06 +0000 (21:09 -0700)
committer GitHub <noreply@github.com>
Sat, 25 Jul 2020 04:09:06 +0000 (21:09 -0700)
python/tvm/autotvm/task/space.py
topi/python/topi/cuda/conv2d.py

index fbf474f..4937661 100644 (file)
@@ -33,6 +33,7 @@ from collections import namedtuple, OrderedDict
 import numpy as np
 
 from tvm.te import schedule, thread_axis
+from tvm.tir import expr
 from tvm.autotvm.util import get_const_int
 
 Axis = namedtuple('Axis', ['space', 'index'])
@@ -733,10 +734,12 @@ class ConfigSpace(object):
 
         Parameters
         ---------
-        flop: int or float
+        flop: int or float or IntImm or FloatImm
             number of float operations
         """
-        self.flop += flop
+        if isinstance(flop, (expr.IntImm, expr.FloatImm)):
+            flop = flop.value
+        self.flop += float(flop)
 
     def raise_error(self, msg):
         """register error in config
index d98d630..973c216 100644 (file)
@@ -18,6 +18,7 @@
 """Compute definition for conv2d with cuda backend"""
 from tvm import te
 from tvm import autotvm
+from tvm.autotvm.task.space import OtherOptionEntity
 from tvm.contrib import cudnn
 
 from .. import nn, generic
@@ -99,6 +100,10 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1,
     else:
         dtype = data.dtype
 
+    cfg.define_knob('algo', range(8))
+    if cfg.is_fallback: # Let CUDNN choose the best algo
+        cfg['algo'] = OtherOptionEntity(-1)
+
     return cudnn.conv_forward(data,
                               kernel,
                               [pt, pl], # cudnn padding pt, pl on both sides of input
@@ -106,7 +111,7 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1,
                               [dilation_h, dilation_w],
                               conv_mode=1,
                               tensor_format=tensor_format,
-                              algo=-1,         # let CUDNN choose the best algo
+                              algo=cfg['algo'].val,
                               conv_dtype=dtype,
                               groups=groups)