From: Cody Yu Date: Sat, 25 Jul 2020 04:09:06 +0000 (-0700) Subject: [TOPI] Fix CUDA Library Tuning (#6132) X-Git-Tag: upstream/0.7.0~357 X-Git-Url: http://review.tizen.org/git/?a=commitdiff_plain;h=00578b77530d5272397f76c2966c125888c4ed94;p=platform%2Fupstream%2Ftvm.git [TOPI] Fix CUDA Library Tuning (#6132) --- diff --git a/python/tvm/autotvm/task/space.py b/python/tvm/autotvm/task/space.py index fbf474f..4937661 100644 --- a/python/tvm/autotvm/task/space.py +++ b/python/tvm/autotvm/task/space.py @@ -33,6 +33,7 @@ from collections import namedtuple, OrderedDict import numpy as np from tvm.te import schedule, thread_axis +from tvm.tir import expr from tvm.autotvm.util import get_const_int Axis = namedtuple('Axis', ['space', 'index']) @@ -733,10 +734,12 @@ class ConfigSpace(object): Parameters --------- - flop: int or float + flop: int or float or IntImm or FloatImm number of float operations """ - self.flop += flop + if isinstance(flop, (expr.IntImm, expr.FloatImm)): + flop = flop.value + self.flop += float(flop) def raise_error(self, msg): """register error in config diff --git a/topi/python/topi/cuda/conv2d.py b/topi/python/topi/cuda/conv2d.py index d98d630..973c216 100644 --- a/topi/python/topi/cuda/conv2d.py +++ b/topi/python/topi/cuda/conv2d.py @@ -18,6 +18,7 @@ """Compute definition for conv2d with cuda backend""" from tvm import te from tvm import autotvm +from tvm.autotvm.task.space import OtherOptionEntity from tvm.contrib import cudnn from .. import nn, generic @@ -99,6 +100,10 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1, else: dtype = data.dtype + cfg.define_knob('algo', range(8)) + if cfg.is_fallback: # Let CUDNN choose the best algo + cfg['algo'] = OtherOptionEntity(-1) + return cudnn.conv_forward(data, kernel, [pt, pl], # cudnn padding pt, pl on both sides of input @@ -106,7 +111,7 @@ def conv2d_cudnn(cfg, data, kernel, strides, padding, dilation, groups=1, [dilation_h, dilation_w], conv_mode=1, tensor_format=tensor_format, - algo=-1, # let CUDNN choose the best algo + algo=cfg['algo'].val, conv_dtype=dtype, groups=groups)