from configargparse import ArgParser, YAMLConfigFileParser
from op_names import names_dict, arg_dependent_ops
import text_blocks as tb
import fileinput
import glob
import json
import logging
import os
import posixpath
import pygccxml
import pygccxml.declarations as dec


class BindingsGenerator:
    def __init__(self, opts, starting_indent=''):
        """
        Stores the generator configuration and resets the per-run bookkeeping.

        :param opts: option object; the other methods read their format
                     strings, paths and flags from its attributes
        :param starting_indent: string used as the initial indentation prefix
        :return: None
        """
        # Names of the files created so far; consulted by write_data_to_file
        self.written_files = []
        # Current indentation prefix appended between generated definitions
        self.indent = starting_indent
        self.opts = opts

    def write_data_to_file(self, format_obj, file_name, file_part, data, found_includes):
        """
        Writes incoming strings to an output file, keeping track (in
        ``self.written_files``) of the files already created during this run.

        The first time a file name is seen, the file is created from the template.
        On every later call for the same name, ``data`` is inserted directly
        after the first line that consists solely of "{"; the existing lines are
        kept, with their trailing whitespace removed.

        :param format_obj: string body from templates to format over if a new file;
                           formatted with ``namespace``, ``defs`` and ``includes``
        :param file_name: name of file to write to
        :param file_part: name of module to add if file is new
        :param data: string data to put into file.
        :param found_includes: text of the includes to put into the resultant C++
        :return: None
        """
        # Checks to see if we have written the file already this run
        if file_name not in self.written_files:
            # If not, write out all of the structure that an individual object needs
            with open(file_name, "w") as cpp_file:
                cpp_file.write(format_obj.format(namespace=file_part,
                                                 defs=data,
                                                 includes=found_includes))
                # Keep track that we've written it.
                self.written_files.append(file_name)
            return

        # Already written once, now we're going to stuff the definitions in.
        # A flag makes sure the data is not stuffed in more than once.
        content_added = False
        # fileinput with inplace=True "reads" the file from a temporary copy and
        # redirects stdout into the file given.  The context manager guarantees
        # that stdout is restored and the fileinput state is released even when
        # writing fails part-way through.
        with fileinput.input(file_name, inplace=True) as existing_lines:
            for line in existing_lines:
                # rstrip to remove our new lines and use the newlines from print
                print(line.rstrip())
                # "{" means we've found the opening of the definition body, and
                # can safely print our additional information to the file
                if not content_added and line.strip() == "{":
                    print(data)
                    content_added = True

    def generate_function_string(self, member_function, is_free_function=False, py_name=''):
        """
        Takes a set of PyGCCXML data and returns the PyBind11 code for
        that function

        :param member_function: PyGCCXML declaration for an individual function
        :param is_free_function: Boolean to determine if function should be marked
        as a member of the PyBind11 module or of a class. When True the ``module``
        field of the template is filled with "m" and the ``ending`` field with ";",
        otherwise both are empty.
        Illustrative output (the exact text depends on the configured templates):
        is_free_function = True
          m.def("ExtractDoubleOrThrow", py::overload_cast<double const &>(ExtractDoubleOrThrow),<...>
        is_free_function = False
          .def("start_time", py::overload_cast<>(start_time),<...>
        :param py_name: Name of the function on the python side. Defaults to the same
        name as the C++ function
        :return: A string which contains the PyBind11 declaration for a single function
        """
        # If a different name is requested on the Python side, set it here
        fun_name = py_name if py_name else member_function.name

        # Capture each argument and a default value, if found.  Every argument
        # is preceded by ", " so the result can follow the function reference.
        arg_string = "".join(
            ", " + self.opts.member_func_arg_fmt.format(
                arg.name,
                " = %s" % arg.default_value if arg.default_value else "")
            for arg in member_function.arguments)

        # Functions whose parent is the global namespace are referenced by
        # their bare name, everything else is qualified with the parent
        if member_function.parent.name == "::":
            ref_string = "&%s" % member_function.name
        else:
            ref_string = "&%s::%s" % (member_function.parent.decl_string, member_function.name)

        signature = self.opts.overload_template_fmt.format(fun_ref=ref_string,
                                                           decl_string=member_function.decl_string)

        # Only some declaration kinds expose ``has_static``; anything without
        # the attribute is treated as non-static
        static = "_static" if getattr(member_function, "has_static", False) else ""

        # Doc strings are currently disabled, so the ``doc`` field is always
        # empty.  The intended value was
        #   ", doc.<parent name without template args>.<function name>.doc"
        # NOTE: the formatted string is returned unmodified; an earlier
        # ``member_string.strip(",")`` here discarded its result and therefore
        # never changed the output.
        return self.opts.member_func_fmt.format(
            module="m" if is_free_function else "",
            static=static,
            fun_name=fun_name,
            fun_ref=signature,
            args=arg_string,
            classname_doc=member_function.parent.decl_string,
            doc="",
            ending=";" if is_free_function else "")

    def find_getter(self, var_data, class_data):
        """
        Looks for a member function named "get_<variable name>" on a class.

        :param var_data: object whose ``name`` is the variable name
        :param class_data: object whose ``member_functions()`` yields
                           declarations carrying a ``name``
        :return: the getter name when such a function exists, otherwise ""
        """
        wanted = "get_" + var_data.name
        for candidate in class_data.member_functions():
            if candidate.name == wanted:
                return wanted
        return ""

    def find_setter(self, var_data, class_data):
        """
        Looks for a member function named "set_<variable name>" on a class.

        :param var_data: object whose ``name`` is the variable name
        :param class_data: object whose ``member_functions()`` yields
                           declarations carrying a ``name``
        :return: the setter name when such a function exists, otherwise ""
        """
        wanted = "set_" + var_data.name
        for candidate in class_data.member_functions():
            if candidate.name == wanted:
                return wanted
        return ""

    def generate_operator_string(self, oper_data, is_member_fxn=True):
        """
        Accepts a pygccxml data object which represents an operator and returns
        the binding code for it, using the python-side name looked up in
        ``names_dict``.

        :param oper_data: A pygccxml object representing an operator
        :param is_member_fxn: whether the operator is a member function; a member
                              operator counts as unary with zero arguments, a
                              non-member one with exactly one argument
        :return: a string containing the PyBind11 declaration for the operator,
                 or "" for ``operator=``, which is never bound
        :raises RuntimeError: when the looked-up python name is empty
        """
        # Assignment has no python counterpart here, so nothing is generated
        if oper_data.name == "operator=":
            return ""

        arg_count = len(oper_data.arguments)
        symbol = oper_data.symbol

        if symbol in arg_dependent_ops:
            # Some C++ operators have two python names depending on the number
            # of arguments: entry 0 is the unary form, entry 1 the other form
            unary_arg_count = 0 if is_member_fxn else 1
            variant = 0 if arg_count == unary_arg_count else 1
            py_name = names_dict[symbol][variant]
        else:
            py_name = names_dict[symbol]

        if not py_name:
            raise RuntimeError("py_name not set")

        return self.generate_function_string(oper_data, py_name=py_name)

    def generate_member_var_string(self, var_data, written_functions):
        '''
        Builds the binding code for one member variable.

        Public variables are bound directly, as read-only when their type
        (ignoring reference and volatile) is const.  Non-public variables are
        only bound when ``opts.pm_flag`` is set and a ``get_<name>`` and/or
        ``set_<name>`` member function exists; they are then exposed through
        those accessors.

        :param var_data: a PyGCCXML object representing a variable
        :param written_functions: Growing list of bound function names. The names
        of any getter/setter used here are appended so that they are not bound a
        second time when the member functions are written
        :return: a string containing the PyBind11 declaration for the member
        variable, or "" when the variable is not exposed
        '''
        var_name = var_data.name
        owner_decl = var_data.parent.decl_string

        if var_data.access_type == dec.ACCESS_TYPES.PUBLIC:
            # A const type means the variable cannot be written from python
            bare_type = dec.remove_volatile(dec.remove_reference(var_data.decl_type))
            write_kind = "only" if dec.is_const(bare_type) else "write"

            # String representing a reference to the variable itself
            var_ref = self.opts.member_reference_fmt.format(classname=owner_decl,
                                                            member=var_name)
            static_suffix = "_static" if var_data.type_qualifiers.has_static else ""

            return self.opts.public_member_var_fmt.format(write=write_kind,
                                                          static=static_suffix,
                                                          var_name=var_name,
                                                          var_ref=var_ref)

        # Not public: without the private-member option the variable is not
        # directly accessible from python (its getters and setters may still
        # be bound when the member functions are written)
        if not self.opts.pm_flag:
            return ""

        getter_name = self.find_getter(var_data, var_data.parent)
        setter_name = self.find_setter(var_data, var_data.parent)
        if not (getter_name or setter_name):
            return ""

        # Reference the accessors that exist, getter first, and record them as
        # already bound
        accessors = ""
        for accessor_name in (getter_name, setter_name):
            if not accessor_name:
                continue
            accessor_ref = self.opts.member_reference_fmt.format(classname=owner_decl,
                                                                 member=accessor_name)
            accessors += (", " if accessors else "") + accessor_ref
            written_functions.append(accessor_name)

        # Without a setter the property can only be read
        readonly_suffix = "" if setter_name else "_readonly"
        static_suffix = "_static" if var_data.type_qualifiers.has_static else ""

        return self.opts.private_member_var_fmt.format(readonly=readonly_suffix,
                                                       static=static_suffix,
                                                       var_name=var_name,
                                                       var_accessors=accessors)

    # Takes a list of function data, turns it into one long string, and
    # writes that string to a named file.
    def write_free_function_data(self, namespace, function_data, out_dir, found_includes):
        """
        Takes a list of function objects and writes them out to a PyBind11 module.
        The module is named "<namespace>_py" and the file "<namespace>_py.cpp" is
        written to the out_dir

        :param namespace: String name of grouping to be used as PyBind11 module name
        :param function_data: List of PyGCCXML objects which describe functions
        :param out_dir: File location to store the resultant file.
        :param found_includes: text of the includes, passed on to the file template
        :return: None
        """
        previous_indent = self.indent
        self.indent += " " * 2
        try:
            # Each definition is followed by the current indent; the surrounding
            # whitespace is stripped again before writing
            defs = "".join(
                self.generate_function_string(function, is_free_function=True) + self.indent
                for function in function_data)
            file_name = posixpath.join(out_dir, "%s_py.cpp" % namespace)
            print("Writing %s to %s " % (namespace, file_name))

            self.write_data_to_file(self.opts.module_cpp_fmt, file_name, namespace+"_py",
                                    defs.strip(), found_includes)
        finally:
            # Restore the indent even when generation or writing fails
            self.indent = previous_indent

    def write_class_data(self, class_data, out_dir, found_includes, written_files):
        """
        Takes an instance of Class data from PyGCCXML and outputs a single file with
        PyBind11 declarations for the class.  Includes constructors, member
        variables, member functions and operators.

        The file is named after the class name without template arguments
        ("<name>_py.cpp"); template instantiations of the same class therefore
        share one file, and later instantiations are inserted into it by
        write_data_to_file.

        :param class_data: pygccxml declaration for the class
        :param out_dir:  File location to store the resultant file.
        :param found_includes: text of the includes, passed on to the file template
        :param written_files: a list of files that have been written so far; it is
                              accepted for the callers' benefit but not read here
                              (the bookkeeping lives in ``self.written_files``)
        :return: None
        """
        previous_indent = self.indent
        # Increase the indent
        self.indent += " " * 4
        try:
            # First get the arguments to the py::class_<>() call
            # The first of which is the class name
            pyclass_args = class_data.decl_string

            # Next any super classes, if the relationship is public
            for b in class_data.bases:
                if b.access_type == dec.ACCESS_TYPES.PUBLIC:
                    pyclass_args = pyclass_args + ", " + b.related_class.decl_string

            # List to stuff names into which will prevent re-writing
            written_functions = []
            # First split over < to capture name of object without any potential
            # template arguments.
            file_parts = class_data.name.split("<")
            file_part = "%s_py" % file_parts[0]
            class_name = file_part
            # if template arguments exist, expand class name used in py::class_ definition
            # file_parts[1][:-1] <- drops the ending > from the template value.
            if len(file_parts) > 1:
                class_name = "%s_%s_py" % (file_parts[0], file_parts[1][:-1])

            constructor_str = ""
            for constructorObj in class_data.constructors():
                print(constructorObj.name)
                if constructorObj.access_type != "public":
                    continue
                arg_string = ",".join(arg.decl_string for arg in constructorObj.argument_types)
                # The second template field is reserved and currently always empty
                constructor_str += self.opts.constructor_fmt.format(arg_string, "")
                constructor_str += self.indent

            member_var_string = ""
            for member_var in class_data.variables():
                member_var_string += self.generate_member_var_string(member_var, written_functions)
                member_var_string += self.indent
            print(member_var_string)

            member_string = ""
            for member_function in class_data.member_functions():
                # Skip accessors already bound through a member variable and
                # anything that is not public
                if member_function.name in written_functions or member_function.access_type != "public":
                    continue
                member_string += self.generate_function_string(member_function)
                member_string += self.indent
                print(member_string)

            # TODO: Necessary?  Determine usefulness of listing operators
            # NOTE(review): unlike constructors and member functions, operators
            # are not filtered by access type - confirm this is intended.
            operator_string = ""
            for operator in class_data.operators():
                operator_string += self.generate_operator_string(operator)
                operator_string += self.indent

            file_name = posixpath.join(out_dir, "%s.cpp" % file_part)
            print("Writing %s to %s " % (file_part, file_name))
            # Doc strings are currently disabled, so ``doc`` is always empty
            wrapper_content = self.opts.class_info_body_fmt.format(name=class_name,
                                                                   pyclass_args=pyclass_args,
                                                                   doc="",
                                                                   constructor=constructor_str.strip(),
                                                                   funcs=member_string.strip(),
                                                                   vars=member_var_string.strip(),
                                                                   opers=operator_string.strip()
                                                                   )
            self.write_data_to_file(self.opts.module_cpp_fmt, file_name, file_part,
                                    wrapper_content.strip(), found_includes)
        finally:
            # Restore the indent even when generation or writing fails
            self.indent = previous_indent

    def write_module_data(self, module_name, results_dict, out_dir):
        """
        Writes out the "folder" level module for wrapping.

        Two files are produced in out_dir: "<module_name>_py.cpp", created from
        the module template with an empty body and the collected includes, and
        "<module_name>.cpp", which holds one forward declaration and one init
        call for every entry of ``results_dict["out_names"]``.

        :param module_name: name of the module; also the stem of both file names
        :param results_dict: dictionary with the "out_names" and "#include" lists
        :param out_dir: directory the files are written to
        :return: The name of the module file to include in the library
        """
        # The init function names are derived from the part of each output
        # file name before the first "."
        stems = [out_name.split(".")[0] for out_name in results_dict["out_names"]]
        forward_decls = [self.opts.init_fun_forward_fmt.format(name=stem) for stem in stems]
        init_calls = [self.opts.init_fun_signature_fmt.format(name=stem) for stem in stems]

        self.write_data_to_file(self.opts.module_cpp_fmt,
                                posixpath.join(out_dir, "%s_py.cpp" % module_name),
                                module_name + "_py",
                                "",
                                "".join(results_dict["#include"]))

        file_name = posixpath.join(out_dir, "%s.cpp" % module_name)
        module_body = self.opts.common_cpp_body_fmt.format(name=module_name,
                                                           forwards="".join(forward_decls),
                                                           init_funs="".join(init_calls))
        with open(file_name, "w") as module_file:
            module_file.write(module_body)
        return file_name

    def find_classes(self, src, src_dict, res_dict, free_fun_name):
        """
          Used to generate the input file to the CastXML tool.
          Walks the dictionaries of the input JSON recursively, looking for the
          keys "classes" and "functions"; every other key is treated as a nested
          namespace and descended into.
          For the objects found it records typedefs / explicit instantiations for
          templates, the "#include" line of the declaring file and the names of
          the output files.

          Assumes:
            - "namespaces" found within an input JSON object are the same as the subdirectory name where
              additional files are found
            - Templated functions are not overloaded

          :param src: directory path; only extended with the namespace name and passed on
                      to the recursive calls
          :param src_dict: A JSON dictionary with information about classes, functions, and namespaces.
          :param res_dict: A dictionary of lists ("instantiate", "typedefs", "decls", "#include",
                           "out_names", "namespaces") that is extended in place
          :param free_fun_name: The name of the file which will be used to capture free functions for this call
          :return: The results dictionary after summation from all recursive calls below this object.
        """
        for key, data in src_dict.items():
            if key == "classes":
                for class_name, class_info in data.items():
                    short_class = class_name.split("::")[-1]
                    template_args = class_info["inst"]
                    if len(template_args) == 0:
                        # Plain class: a single typedef under its unqualified name
                        res_dict["instantiate"].append(
                            ("typedef {} {};\n").format(class_name, short_class))
                        res_dict["decls"].append(short_class)
                    else:
                        # Class template: one typedef, one explicit instantiation
                        # and one declaration name per listed argument
                        for template_arg in template_args:
                            short_arg = template_arg.split("::")[-1]
                            res_dict["instantiate"].append(
                                ("typedef {}<{}> {};\n").format(class_name,
                                                                template_arg,
                                                                short_class + short_arg))
                            res_dict["typedefs"].append(
                                ("template class {}<{}>;\n").format(class_name, template_arg))
                            res_dict["decls"].append("%s<%s>" % (short_class, short_arg))
                    res_dict["#include"].append("#include \"%s\"\n" % class_info["file"])
                    res_dict["out_names"].append("%s_py.cpp" % short_class)
            elif key == "functions":
                res_dict["out_names"].append("%s_py.cpp" % free_fun_name)
                for function_name, function_info in data.items():
                    include_line = "#include \"%s\"\n" % function_info["file"]
                    if include_line not in res_dict["#include"]:
                        res_dict["#include"].append(include_line)
                    if not function_info["is_template"]:
                        continue
                    short_function = function_name.split("::")[-1]
                    for template_arg in function_info["inst"]:
                        if isinstance(template_arg, list):
                            # Several template arguments for one instantiation
                            type_text = ", ".join(template_arg)
                            label = "_".join(template_arg)
                        else:
                            type_text = template_arg
                            label = template_arg
                        # clean up other special characters for name
                        for special in "<>:":
                            label = label.replace(special, "")
                        res_dict["instantiate"].append(
                            ("auto {fakenm} = &{fun_name}<{type}>;\n").format(
                                fakenm="%s_%s" % (short_function, label),
                                fun_name=function_name,
                                type=type_text))
            else:
                # Anything else is a nested namespace; its free functions are
                # collected under the namespace name
                res_dict["namespaces"].append(key)
                res_dict = self.find_classes(posixpath.join(src, key), data, res_dict, key)
        return res_dict

    def clean_flags(self, rsp_path):
        """
        Reads the compiler flags from a response file.

        Every line starts with a label followed by values separated by spaces
        and/or semicolons:
          - "includes:" lists include directories
          - "defines:" lists preprocessor definitions, each turned into "-D<value>"
          - "c_std:" gives the language standard flag (first value is used)
        A later line with the same label replaces the earlier one.  Blank lines
        and empty values (from repeated or trailing separators) are ignored; any
        other label is reported on stdout and skipped.

        :param rsp_path: path to the response file; an empty path means that no
                         response file was given
        :return: tuple of (list of include directories,
                           string of the standard flag followed by the definitions)
        """
        rsp_includes = []
        rsp_defs = ""
        c_std_flag = ''

        # The response file is optional (the option defaults to an empty path)
        if not rsp_path:
            return rsp_includes, rsp_defs

        with open(rsp_path, 'r') as fp:
            for raw_line in fp:
                # split() without arguments drops the empty tokens that "; " or
                # a trailing ";" would otherwise produce
                tokens = raw_line.replace(';', ' ').split()
                if not tokens:
                    continue
                label, values = tokens[0], tokens[1:]

                if label == "includes:":
                    rsp_includes = values

                elif label == "defines:":
                    rsp_defs = " ".join("-D" + def_ for def_ in values)

                elif label == "c_std:":
                    # A label without a value leaves the flag untouched
                    if values:
                        c_std_flag = values[0]

                else:
                    print("ERROR: invalid first token in response file: %s" % label)

        return rsp_includes, (c_std_flag + ' ' + rsp_defs).strip()

    def parse(self):
        """
        Overall function to perform automatic generation of Pybind11 code from a C++ repository
        - Parses the JSON input
        - Finds classes and functions
        - Writes the module level files
        - If ``opts.no_generation`` is set: prints the ";"-separated list of files
          that would be generated and stops
        - Writes instantiations and includes into "wrapper.cpp", the only input to CastXML
        - Runs CastXML and uses pygccxml to read results into data object
        - Loops through parsed data for previously found classes and all available functions

        As a side effect ``opts.source_dir`` is set to the directory of the JSON file.

        :return: None
        """
        # init the pygccxml stuff
        # Adapted from CPPWG: https://github.com/jmsgrogan/cppwg/blob/265117455ed57eb250643a28ea6029c2bccf3ab3/cppwg/parsers/source_parser.py#L24
        results = {"includes": [],
                   "namespaces": [],
                   "typedefs": [],
                   "#include": [],
                   "instantiate": [],
                   "decls": [],
                   "out_names": []}
        files_to_parse = []
        # If a single json is marked, only look over that particular file
        # TODO: Glob over the json_path directory as well?
        files_to_parse.append(self.opts.json_path)
        # Source path is the directory with the JSON in it
        self.opts.source_dir = posixpath.dirname(self.opts.json_path)
        # Find classes and functions in each JSON file from above
        for found_input_file in files_to_parse:
            # Read inside a context manager so the handle is closed right away
            with open(found_input_file, "r") as json_file:
                json_data = json.load(json_file)
            results = self.find_classes(self.opts.source_dir, json_data, results, "free_functions")
            results["includes"].append(posixpath.dirname(found_input_file))
        module_file = self.write_module_data(self.opts.module_name, results, self.opts.output_dir)
        # Short circuit: prints list of files to be generated by the run, if it were to continue.
        if self.opts.no_generation:
            print(';'.join([module_file]+results["out_names"]))
            return

        rsp_includes, rsp_defs = self.clean_flags(self.opts.rsp_path)
        print(self.opts)
        # Configures the CastXML calls
        # TODO: Ensure external programs can append additional C++ flag arguments
        castxml_config = pygccxml.parser.xml_generator_configuration_t(xml_generator_path=self.opts.castxml_path,
                                                                       xml_generator="castxml",
                                                                       cflags=rsp_defs,
                                                                       start_with_declarations=[self.opts.default_namespace],
                                                                       include_paths=results["includes"] + rsp_includes)

        # The include lines are not modified past this point, so the joined
        # text is built once and reused for every generated file
        include_block = "".join(results["#include"])

        # Creates the single file that CastXML should parse containing all instantiations and includes
        with open(os.path.join(self.opts.output_dir, "wrapper.cpp"), "w") as wrapper_file:
            wrapper_file.write(self.opts.wrap_header_fmt % (include_block,
                                                            "".join(results["typedefs"]),
                                                            "drake_wrap",
                                                            "    ".join(results["instantiate"])))
        # Run CastXML and parse back the resulting XML into a Python Object.
        pygccxml.utils.loggers.cxx_parser.setLevel(logging.CRITICAL)
        pygccxml.declarations.scopedef_t.RECURSIVE_DEFAULT = False
        pygccxml.declarations.scopedef_t.ALLOW_EMPTY_MDECL_WRAPPER = True
        # TODO: If a file #included in wrapper.cpp #includes a file that has functions/classes
        # that are not to be wrapped, they will still be bound, since they are present in wrapper.cpp
        # NOTE(review): the file above is written under output_dir but is handed
        # to the parser as the bare name "wrapper.cpp"; this presumably relies on
        # output_dir being the working directory or on an include path - confirm.
        total = pygccxml.parser.parse(["wrapper.cpp"],
                                      castxml_config,
                                      compilation_mode=pygccxml.parser.COMPILATION_MODE.ALL_AT_ONCE)

        # Total seems to be a single item list, due to ALL_AT_ONCE mode, capture the data
        # from the first item in the list
        written_files = []
        name_data = total[0]
        # If user provides an overall namespace, descend into it here.
        if self.opts.default_namespace:
            name_data = total[0].namespace(self.opts.default_namespace)
        # Capture the highest level list of free_function data
        # TODO: Should we capture free_functions if name_data = total[0]?
        other_free_functions = list(name_data.free_functions())
        # First capture list of the namespaces that pygccxml say exists
        known_namespaces = [x.name for x in name_data.namespaces()]
        for declared_class in results["decls"]:
            # If class name found in classes found in name_data, write out file for that class
            if declared_class in [x.name for x in name_data.classes()]:
                self.write_class_data(name_data.class_(declared_class), self.opts.output_dir,
                                      include_block, written_files)
            else:
                # Otherwise, loop through parsed namespaces to check for presence of the class
                # Only if the namespace was found during data gathering do we look into it for the class
                # TODO: Is this check necessary?  How often will a function name be duplicated across namespaces?
                for namespace in results["namespaces"]:
                    if namespace in known_namespaces:
                        # If found in the sub-namespace list of classes, write the data
                        if declared_class in [x.name for x in name_data.namespace(namespace).classes()]:
                            self.write_class_data(name_data.namespace(namespace).class_(declared_class),
                                                  self.opts.output_dir, include_block, written_files)
        # Separate write out for namespace free_functions
        # TODO: See https://kwgitlab.kitware.com/joe.snyder/wrapper_generator/-/issues/2 for related issue
        for namespace in results["namespaces"]:
            if namespace in known_namespaces and name_data.namespace(namespace).free_functions():
                self.write_free_function_data(namespace, name_data.namespace(namespace).free_functions(),
                                              self.opts.output_dir, include_block)
        # Finally, if overall namespace has free functions, write them out.
        if other_free_functions:
            self.write_free_function_data(self.opts.module_name, other_free_functions, self.opts.output_dir,
                                          include_block)


# Command line interface; every option may also come from a YAML config file
arg = ArgParser(config_file_parser_class=YAMLConfigFileParser)

# Locations and general behaviour
arg.add("-o", "--output", dest="output_dir", action="store",
        default=os.getcwd(), required=False)
arg.add("-j", "--input_json", dest="json_path", action="store", required=True,
        help="Path to input JSON file of objects to process")
arg.add("--module_name", dest="module_name", action="store", required=True,
        help="Desired name of the output PyBind11 module")
arg.add("-g", "--castxml-path", dest="castxml_path", action="store", required=False,
        help="Path to castxml")
arg.add("-cg", "--config-path", dest="config_dir", is_config_file=True, required=False,
        help="config file path")
arg.add("-d", "--default_namespace", dest="default_namespace", action="store",
        default="", required=False,
        help="value to use as the default namespace")

arg.add("--no-generation", "-n", dest="no_generation", action="store_true", required=False,
        help="Only print name of files to be generated")
arg.add("-rs", "--input_response", dest='rsp_path', default='', required=False)
arg.add("-pm", "--private_members", dest='pm_flag', action='store_true',
        default=False, required=False)

# The formatted strings that will write the pybind code are also configurable.
# Each entry is (option flag, default template, extra keyword arguments) and
# the options are registered in the order listed.
_FORMAT_OPTIONS = (
    ("--common_cpp_body_fmt", tb.common_cpp_body, {}),
    ("--class_info_body_fmt", tb.class_info_body, {}),
    ("--init_fun_signature_fmt", tb.init_fun_signature, {}),
    ("--init_fun_forward_fmt", tb.init_fun_forward, {}),
    ("--cppbody_fmt", tb.cppbody, {"type": str}),
    ("--module_cpp_fmt", tb.module_cpp, {}),
    ("--member_func_fmt", tb.member_func, {}),
    ("--constructor_fmt", tb.constructor, {}),
    ("--member_func_arg_fmt", tb.member_func_arg, {}),
    ("--public_member_var_fmt", tb.public_member_var, {}),
    ("--private_member_var_fmt", tb.private_member_var, {}),
    ("--member_reference_fmt", tb.member_reference, {}),
    ("--overload_template_fmt", tb.overload_template, {}),
    ("--wrap_header_fmt", tb.wrap_header, {}),
    ("--operator_fmt", tb.operator_template, {}),
    ("--call_operator_fmt", tb.call_template, {}),
)
for _flag, _default, _extra in _FORMAT_OPTIONS:
    arg.add(_flag, required=False, default=_default, **_extra)

# Parse the options and run the generator
options = arg.parse_args()

BindingsGenerator(options).parse()
