Support enum arguments for Java binding
authorHamdi Sahloul <hamdisahloul@hotmail.com>
Mon, 27 Aug 2018 19:39:24 +0000 (04:39 +0900)
committerHamdi Sahloul <hamdisahloul@hotmail.com>
Sat, 1 Sep 2018 06:04:55 +0000 (15:04 +0900)
modules/java/generator/gen_java.py

index c17e1b3..c2f6e85 100755 (executable)
@@ -140,17 +140,18 @@ class GeneralInfo():
 
     def fullName(self, isCPP=False):
         result = ".".join([self.fullClass(), self.name])
-        return result if not isCPP else result.replace(".", "::")
+        return result if not isCPP else get_cname(result)
 
     def fullClass(self, isCPP=False):
         result = ".".join([f for f in [self.namespace] + self.classpath.split(".") if len(f)>0])
-        return result if not isCPP else result.replace(".", "::")
+        return result if not isCPP else get_cname(result)
 
 class ConstInfo(GeneralInfo):
-    def __init__(self, decl, addedManually=False, namespaces=[]):
+    def __init__(self, decl, addedManually=False, namespaces=[], enumType=None):
         GeneralInfo.__init__(self, "const", decl, namespaces)
-        self.cname = self.name.replace(".", "::")
+        self.cname = get_cname(self.name)
         self.value = decl[1]
+        self.enumType = enumType
         self.addedManually = addedManually
         if self.namespace in namespaces_dict:
             self.name = '%s_%s' % (namespaces_dict[self.namespace], self.name)
@@ -166,6 +167,25 @@ class ConstInfo(GeneralInfo):
                 return True
         return False
 
+def normalize_field_name(name):
+    return name.replace(".","_").replace("[","").replace("]","").replace("_getNativeObjAddr()","_nativeObj")
+
+def normalize_class_name(name):
+    return re.sub(r"^cv\.", "", name).replace(".", "_")
+
+def get_cname(name):
+    return name.replace(".", "::")
+
+def cast_from(t):
+    if t in type_dict and "cast_from" in type_dict[t]:
+        return type_dict[t]["cast_from"]
+    return t
+
+def cast_to(t):
+    if t in type_dict and "cast_to" in type_dict[t]:
+        return type_dict[t]["cast_to"]
+    return t
+
 class ClassPropInfo():
     def __init__(self, decl): # [f_ctype, f_name, '', '/RW']
         self.ctype = decl[0]
@@ -178,7 +198,7 @@ class ClassPropInfo():
 class ClassInfo(GeneralInfo):
     def __init__(self, decl, namespaces=[]): # [ 'class/struct cname', ': base', [modlist] ]
         GeneralInfo.__init__(self, "class", decl, namespaces)
-        self.cname = self.name.replace(".", "::")
+        self.cname = get_cname(self.name)
         self.methods = []
         self.methods_suffixes = {}
         self.consts = [] # using a list to save the occurrence order
@@ -303,7 +323,7 @@ class ArgInfo():
 class FuncInfo(GeneralInfo):
     def __init__(self, decl, namespaces=[]): # [ funcname, return_ctype, [modifiers], [args] ]
         GeneralInfo.__init__(self, "func", decl, namespaces)
-        self.cname = decl[0].replace(".", "::")
+        self.cname = get_cname(decl[0])
         self.jname = self.name
         self.isconstructor = self.name == self.classname
         if "[" in self.name:
@@ -341,7 +361,6 @@ class JavaWrapperGenerator(object):
         self.classes = { "Mat" : ClassInfo([ 'class Mat', '', [], [] ], self.namespaces) }
         self.module = ""
         self.Module = ""
-        self.enum_types = []
         self.ported_func_list = []
         self.skipped_func_list = []
         self.def_args_hist = {} # { def_args_cnt : funcs_cnt }
@@ -404,8 +423,8 @@ class JavaWrapperGenerator(object):
         )
         logging.info('ok: class %s, name: %s, base: %s', classinfo, name, classinfo.base)
 
-    def add_const(self, decl): # [ "const cname", val, [], [] ]
-        constinfo = ConstInfo(decl, namespaces=self.namespaces)
+    def add_const(self, decl, enumType=None): # [ "const cname", val, [], [] ]
+        constinfo = ConstInfo(decl, namespaces=self.namespaces, enumType=enumType)
         if constinfo.isIgnored():
             logging.info('ignored: %s', constinfo)
         elif not self.isWrapped(constinfo.classname):
@@ -423,12 +442,16 @@ class JavaWrapperGenerator(object):
                 logging.info('ok: %s', constinfo)
 
     def add_enum(self, decl): # [ "enum cname", "", [], [] ]
-        enumname = decl[0].replace("enum ", "").strip()
-        self.enum_types.append(enumname)
+        enumType = decl[0].rsplit(" ", 1)[1]
+        if enumType.endswith("<unnamed>"):
+            enumType = None
+        else:
+            ctype = normalize_class_name(enumType)
+            type_dict[ctype] = { "cast_from" : "int", "cast_to" : get_cname(enumType), "j_type" : "int", "jn_type" : "int", "jni_type" : "jint", "suffix" : "I" }
         const_decls = decl[3]
 
         for decl in const_decls:
-            self.add_const(decl)
+            self.add_const(decl, enumType)
 
     def add_func(self, decl):
         fi = FuncInfo(decl, namespaces=self.namespaces)
@@ -530,7 +553,7 @@ class JavaWrapperGenerator(object):
         if self.isWrapped(t):
             return self.getClass(t).fullName(isCPP=True)
         else:
-            return t
+            return cast_from(t)
 
     def gen_func(self, ci, fi, prop_name=''):
         logging.info("%s", fi)
@@ -563,7 +586,7 @@ class JavaWrapperGenerator(object):
             msg = "// Return type '%s' is not supported, skipping the function\n\n" % fi.ctype
             self.skipped_func_list.append(c_decl + "\n" + msg)
             j_code.write( " "*4 + msg )
-            logging.warning("SKIP:" + c_decl.strip() + "\t due to RET type" + fi.ctype)
+            logging.warning("SKIP:" + c_decl.strip() + "\t due to RET type " + fi.ctype)
             return
         for a in fi.args:
             if a.ctype not in type_dict:
@@ -575,7 +598,7 @@ class JavaWrapperGenerator(object):
                 msg = "// Unknown type '%s' (%s), skipping the function\n\n" % (a.ctype, a.out or "I")
                 self.skipped_func_list.append(c_decl + "\n" + msg)
                 j_code.write( " "*4 + msg )
-                logging.warning("SKIP:" + c_decl.strip() + "\t due to ARG type" + a.ctype + "/" + (a.out or "I"))
+                logging.warning("SKIP:" + c_decl.strip() + "\t due to ARG type " + a.ctype + "/" + (a.out or "I"))
                 return
 
         self.ported_func_list.append(c_decl)
@@ -654,7 +677,7 @@ class JavaWrapperGenerator(object):
                     if "I" in a.out or not a.out or self.isWrapped(a.ctype): # input arg, pass by primitive fields
                         for f in fields:
                             jn_args.append ( ArgInfo([ f[0], a.name + f[1], "", [], "" ]) )
-                            jni_args.append( ArgInfo([ f[0], a.name + f[1].replace(".","_").replace("[","").replace("]","").replace("_getNativeObjAddr()","_nativeObj"), "", [], "" ]) )
+                            jni_args.append( ArgInfo([ f[0], a.name + normalize_field_name(f[1]), "", [], "" ]) )
                     if "O" in a.out and not self.isWrapped(a.ctype): # out arg, pass as double[]
                         jn_args.append ( ArgInfo([ "double[]", "%s_out" % a.name, "", [], "" ]) )
                         jni_args.append ( ArgInfo([ "double[]", "%s_out" % a.name, "", [], "" ]) )
@@ -702,7 +725,7 @@ class JavaWrapperGenerator(object):
                 "    private static native $type $name($args);\n").substitute(\
                 type = type_dict[fi.ctype].get("jn_type", "double[]"), \
                 name = fi.jname + '_' + str(suffix_counter), \
-                args = ", ".join(["%s %s" % (type_dict[a.ctype]["jn_type"], a.name.replace(".","_").replace("[","").replace("]","").replace("_getNativeObjAddr()","_nativeObj")) for a in jn_args])
+                args = ", ".join(["%s %s" % (type_dict[a.ctype]["jn_type"], normalize_field_name(a.name)) for a in jn_args])
             ) );
 
             # java part:
@@ -860,7 +883,7 @@ class JavaWrapperGenerator(object):
                     if not a.out and not "jni_var" in type_dict[a.ctype]:
                         # explicit cast to C type to avoid ambiguous call error on platforms (mingw)
                         # where jni types are different from native types (e.g. jint is not the same as int)
-                        jni_name  = "(%s)%s" % (a.ctype, jni_name)
+                        jni_name  = "(%s)%s" % (cast_to(a.ctype), jni_name)
                 if not a.ctype: # hidden
                     jni_name = a.defval
                 cvargs.append( type_dict[a.ctype].get("jni_name", jni_name) % {"n" : a.name})
@@ -934,11 +957,35 @@ JNIEXPORT $rtype JNICALL Java_org_opencv_${module}_${clazz}_$fname
             %s;\n\n""" % (",\n"+" "*12).join(["%s = %s" % (c.name, c.value) for c in ci.private_consts])
             )
         if ci.consts:
-            logging.info("%s", ci.consts)
-            ci.j_code.write("""
+            enumTypes = set(map(lambda c: c.enumType, ci.consts))
+            grouped_consts = {enumType: [c for c in ci.consts if c.enumType == enumType] for enumType in enumTypes}
+            for typeName, consts in grouped_consts.items():
+                logging.info("%s", consts)
+                if typeName:
+                    typeName = typeName.rsplit(".", 1)[-1]
+###################### Utilize Java enums ######################
+#                    ci.j_code.write("""
+#    public enum {1} {{
+#        {0};
+#
+#        private final int id;
+#        {1}(int id) {{ this.id = id; }}
+#        {1}({1} _this) {{ this.id = _this.id; }}
+#        public int getValue() {{ return id; }}
+#    }}\n\n""".format((",\n"+" "*8).join(["%s(%s)" % (c.name, c.value) for c in consts]), typeName)
+#                    )
+################################################################
+                    ci.j_code.write("""
+    // C++: enum {1}
     public static final int
-            %s;\n\n""" % (",\n"+" "*12).join(["%s = %s" % (c.name, c.value) for c in ci.consts])
-            )
+            {0};\n\n""".format((",\n"+" "*12).join(["%s = %s" % (c.name, c.value) for c in consts]), typeName)
+                    )
+                else:
+                    ci.j_code.write("""
+    // C++: enum <unnamed>
+    public static final int
+            {0};\n\n""".format((",\n"+" "*12).join(["%s = %s" % (c.name, c.value) for c in consts]))
+                    )
         # methods
         for fi in ci.getAllMethods():
             self.gen_func(ci, fi)