from argparse import ArgumentParser
from text_blocks import *
import glob
import json
import logging
import os
import posixpath
import pygccxml
import pygccxml.declarations as dec


def add_with_sep(s, to_add, separ):
    """
    Concatenate two strings with a separator placed between them.

    The separator is always inserted, even when ``s`` is empty.

    :param s: Leading string
    :param to_add: String appended after the separator
    :param separ: Text placed between s and to_add
    :return: The string ``s + separ + to_add``
    """
    return separ.join((s, to_add))


def generate_function_string(member_function, indent, is_free_function=False):
    """
    Takes a set of PyGCCXML data and returns the PyBind11 code for all
    overloaded versions of the function.

    One binding statement is produced for every declaration returned by
    ``find_all_function_decls`` (the overloads pygccxml reports plus
    ``member_function`` itself). Each statement selects its overload
    explicitly through ``overload_template`` and the declaration's
    ``decl_string``.

    :param member_function: PyGCCXML declaration for an individual function
    :param indent: String prepended to every generated binding statement
    :param is_free_function: Boolean to determine if function should be marked
    as a member of the PyBind11 module or of a class. Essentially, prepends "m" to function signature
    and terminates the statement with ";". For example:
    is_free_function = True
      m.def("ExtractDoubleOrThrow", py::overload_cast<double const &>(ExtractDoubleOrThrow),<...>
    is_free_function = False
      .def("start_time", py::overload_cast<>(start_time),, , doc.PiecewiseTrajectory.start_time.doc)
    :return: A string which contains all PyBind11 declarations for the function and its overloads
    """
    # ``has_static`` is not present on every declaration type. It is read from
    # ``member_function`` (not per overload), so the suffix is the same for
    # every statement generated below; compute it once.
    static = "_static" if getattr(member_function, "has_static", False) else ""
    module = "m" if is_free_function else ""
    ending = ";" if is_free_function else ""

    bindings = []
    for version in find_all_function_decls(member_function):
        # Each argument contributes ", <member_func_arg>", including its
        # default value when the declaration provides one.
        arg_string = "".join(
            ", " + member_func_arg.format(
                arg.name,
                " = %s" % arg.default_value if arg.default_value else "")
            for arg in version.arguments)

        # Functions declared at global scope have "::" as their parent name.
        if version.parent.name == "::":
            ref_string = "&%s" % version.name
        else:
            ref_string = "&%s::%s" % (version.parent.name, version.name)

        # The overload is always selected explicitly from the declaration,
        # even when the name is not overloaded.
        signature = overload_template.format(fun_ref=ref_string,
                                             decl_string=version.decl_string)

        bindings.append(indent + member_func.format(
            module=module,
            static=static,
            fun_name=member_function.name,
            fun_ref=signature,
            args=arg_string,
            classname_doc=version.parent.name.split("<")[0],
            # Docstring references are currently disabled.
            doc="",
            ending=ending))

    return "".join(bindings)



def find_getter(var_data, class_data):
    """
    Look up the conventional getter for a member variable.

    :param var_data: PyGCCXML variable declaration (only ``.name`` is used)
    :param class_data: PyGCCXML class declaration whose member functions are searched
    :return: ``"get_<variable name>"`` when the class declares a member function
             with that name, otherwise ``""``
    """
    candidate = "get_%s" % var_data.name
    declared = any(fxn.name == candidate for fxn in class_data.member_functions())
    return candidate if declared else ""

def find_setter(var_data, class_data):
    """
    Look up the conventional setter for a member variable.

    :param var_data: PyGCCXML variable declaration (only ``.name`` is used)
    :param class_data: PyGCCXML class declaration whose member functions are searched
    :return: ``"set_<variable name>"`` when the class declares a member function
             with that name, otherwise ``""``
    """
    candidate = "set_%s" % var_data.name
    declared = any(fxn.name == candidate for fxn in class_data.member_functions())
    return candidate if declared else ""


def generate_member_var_string(var_data, getter_setter_flag, written_functions):
    '''
    Build the PyBind11 binding text for a single member variable.

    Public variables are bound directly through ``public_member_var_fmt``;
    the "write" slot becomes "only" when the (reference/volatile-stripped)
    type is const. Non-public variables are bound through
    ``private_member_var_fmt`` only when ``getter_setter_flag`` is set and the
    owning class declares ``get_<name>`` and/or ``set_<name>``. The property
    is marked "_readonly" when no setter exists. Accessor names used this way
    are appended to ``written_functions``.

    :param var_data: a PyGCCXML object representing a variable
    :param getter_setter_flag: whether non-public variables may be exposed via get_/set_ accessors
    :param written_functions: list of already-bound function names; accessor names are appended to it
    :return: a string containing the PyBind11 declaration for the member variable,
             or "" when the variable is not exposed
    '''
    name = var_data.name
    owner = var_data.parent.name

    if var_data.access_type != dec.ACCESS_TYPES.PUBLIC:
        if not getter_setter_flag:
            # Not directly accessible from python; its getters/setters may
            # still be bound when the member functions are written.
            return ""

        getter = find_getter(var_data, var_data.parent)
        setter = find_setter(var_data, var_data.parent)
        if not (getter or setter):
            return ""

        # Getter first, then setter, separated by ", ".
        accessor_refs = []
        for accessor in (getter, setter):
            if not accessor:
                continue
            accessor_refs.append(member_reference.format(classname=owner, member=accessor))
            written_functions.append(accessor)

        return private_member_var_fmt.format(
            readonly="" if setter else "_readonly",
            static="_static" if var_data.type_qualifiers.has_static else "",
            var_name=name,
            var_accessors=", ".join(accessor_refs))

    # Look through references and volatile before testing constness.
    bare_type = dec.remove_volatile(dec.remove_reference(var_data.decl_type))
    return public_member_var_fmt.format(
        write="only" if dec.is_const(bare_type) else "write",
        static="_static" if var_data.type_qualifiers.has_static else "",
        var_name=name,
        var_ref=member_reference.format(classname=owner, member=name))



# Collect every declaration sharing a function's name: the overloads
# pygccxml reports first, then the function itself.
def find_all_function_decls(function_data):
    """
    Accepts a pygccxml object of a function and returns it together with
    all of its overloads.

    :param function_data: A PyGCCXML data object for a function
    :return: A new list holding every entry of
             ``function_data.overloads.declarations`` followed by
             ``function_data`` itself
    """
    return [*function_data.overloads.declarations, function_data]


# Takes a list of function data, turns it into one string of PyBind11
# definitions and writes that string to a named file.
def write_free_function_data(namespace, function_data, out_dir, found_includes):
    """
    Takes a list of function objects and writes them out to a PyBind11 module.
    The module is named for the first argument and the file
    ``<namespace>_py.cpp`` is written to ``out_dir``.

    :param namespace: String name of grouping to be used as PyBind11 module name
    :param function_data: Iterable of PyGCCXML objects which describe functions
    :param out_dir: Directory in which to store the resultant file
    :param found_includes: ``#include`` lines to place in the generated file
    :return: None
    """
    indent = " " * 2
    defs = []
    # generate_function_string emits every overload of a name at once, so a
    # name only needs handling the first time it is seen. A set keeps this
    # lookup O(1) instead of scanning a growing list per function.
    written_functions = set()
    for function in function_data:
        if function.name in written_functions:
            continue
        defs.append(generate_function_string(function, indent, is_free_function=True))
        written_functions.add(function.name)
    file_name = posixpath.join(out_dir, "%s_py.cpp" % namespace)
    print("Writing %s to %s " % (namespace, file_name))
    with open(file_name, "w") as namespacepy:
        namespacepy.write(module_cpp.format(namespace=namespace,
                                            defs="".join(defs).strip(),
                                            includes=found_includes))


def write_class_data(class_data, out_dir, found_includes):
    """
    Takes an instance of Class data from PyGCCXML and outputs a single file with
    PyBind11 declarations for the class: constructors, member variables and
    member functions (including overloads).

    The file is named ``<class name up to the first "<">_py.cpp`` and is
    written to ``out_dir``.

    :param class_data: pygccxml data dictionary for class object
    :param out_dir: File location to store the resultant file.
    :param found_includes: ``#include`` lines to place in the generated file
    :return: None
    """
    indent = " " * 4

    # Arguments to the py::class_<>() call: the class name followed by every
    # publicly inherited base class.
    pyclass_args = class_data.name
    for base in class_data.bases:
        if base.access_type == dec.ACCESS_TYPES.PUBLIC:
            pyclass_args = add_with_sep(pyclass_args, base.related_class.name, ", ")

    # Names of functions already bound. generate_member_var_string appends the
    # getters/setters it uses, so they are not bound again as plain methods
    # (it calls .append, so this must stay a list).
    written_functions = []
    file_part = "%s_py" % class_data.name.split("<")[0]

    constructor_lines = []
    for constructor_obj in class_data.constructors():
        logging.debug("Binding constructor %s", constructor_obj.name)
        arg_string = ",".join(arg.decl_string for arg in constructor_obj.argument_types)
        # The second template slot is currently always left empty.
        constructor_lines.append(indent + constructor.format(arg_string, ""))

    member_var_lines = []
    for member_var in class_data.variables():
        binding = generate_member_var_string(member_var, True, written_functions)
        # Unexposed variables yield ""; skip them instead of emitting a stray
        # indent that would shift the following binding.
        if binding:
            member_var_lines.append(indent + binding)

    member_function_lines = []
    for member_function in class_data.member_functions():
        if member_function.name in written_functions:
            continue
        binding = generate_function_string(member_function, indent)
        logging.debug("Generated bindings for %s:\n%s", member_function.name, binding)
        member_function_lines.append(binding)
        written_functions.append(member_function.name)

    # TODO: Necessary?  Determine usefulness of binding operators
    # (class_data.operators()); they are not bound at present.

    file_name = posixpath.join(out_dir, "%s.cpp" % file_part)
    print("Writing %s to %s " % (file_part, file_name))
    with open(file_name, "w") as cpp_file:
        cpp_file.write(cppbody.format(name=file_part,
                                      pyclass_args=pyclass_args,
                                      # Docstring references are currently disabled.
                                      doc="",
                                      constructor="".join(constructor_lines).strip(),
                                      funcs="".join(member_function_lines).strip(),
                                      vars="".join(member_var_lines).strip(),
                                      arith="",
                                      includes=found_includes))


def write_module_data(module_name, results_dict, out_dir):
    """
    Writes out the "folder" level module for wrapping.

    For each generated file name in ``results_dict["out_names"]`` the part
    before the first "." is used to emit a forward declaration
    (``init_fun_forward``) and an init call (``init_fun_signature``). Both
    are placed into the ``common_cpp_body`` template.

    :param module_name: Name of the PyBind11 module; also the output file stem
    :param results_dict: Results from ``find_classes``; only ``"out_names"`` is read
    :param out_dir: Directory in which the module file is written
    :return: The name of the module file to include in the library
    """
    stems = [future_file.split(".")[0] for future_file in results_dict["out_names"]]
    forwards = "".join(init_fun_forward.format(name=stem) for stem in stems)
    init_funs = "".join(init_fun_signature.format(name=stem) for stem in stems)

    file_name = posixpath.join(out_dir, "%s.cpp" % module_name)
    with open(file_name, "w") as module_file:
        module_file.write(common_cpp_body.format(name=module_name,
                                                 forwards=forwards,
                                                 init_funs=init_funs))
    return file_name


def find_classes(src, src_dict, res_dict, free_fun_name):
    """
    Collect what is needed to generate the input file to the CastXML tool.

    A recursive function that walks the dictionary parsed from the input JSON
    and handles three kinds of keys:

    - "classes": each class gets a typedef, a declaration name to look up
      later and an output file ``<ClassName>_py.cpp``. For every entry in its
      "inst" list, the class also gets a typedef and an explicit
      ``template class`` instantiation.
    - "functions": each function's header is included. When "is_template"
      is set, every entry in "inst" (a type string or a list of type
      strings) is instantiated by taking the address of that
      specialization. One output file ``<free_fun_name>_py.cpp`` is recorded.
    - anything else is treated as a namespace and recursed into.

    Every header named by a "file" entry is added to res_dict["#include"]
    only once.

    Assumes:
      - "namespaces" found within an input JSON object are the same as the subdirectory name where
        additional files are found
      - Templated functions are not overloaded

    :param src: File system directory corresponding to ``src_dict``; the namespace
                name is appended to it on each recursive call
    :param src_dict: A JSON dictionary with information about classes, functions, and namespaces.
    :param res_dict: A dictionary which accumulates the objects found in each call (mutated in place)
    :param free_fun_name: The name of the file which will be used to capture free functions for this call
    :return: The results dictionary after summation from all recursive calls below this object.
    """
    # NOTE: despite the key names, res_dict["instantiate"] receives typedefs
    # and function-address lines while res_dict["typedefs"] receives explicit
    # template instantiations; callers read them by key, so names are kept.

    def add_include(header):
        # Emit each header once, whether it comes from a class or a function.
        include_line = "#include \"%s\"\n" % header
        if include_line not in res_dict["#include"]:
            res_dict["#include"].append(include_line)

    for key, data in src_dict.items():
        if key == "classes":
            for class_name, class_info in data.items():
                short_name = class_name.split("::")[-1]
                instantiations = class_info.get("inst")
                if instantiations:
                    for instantiation in instantiations:
                        short_inst = instantiation.split("::")[-1]
                        res_dict["instantiate"].append("typedef {}<{}> {};\n".format(
                            class_name, instantiation, short_name + short_inst))
                        res_dict["typedefs"].append("template class {}<{}>;\n".format(
                            class_name, instantiation))
                        res_dict["decls"].append("%s<%s>" % (short_name, short_inst))
                else:
                    res_dict["instantiate"].append("typedef {} {};\n".format(class_name, short_name))
                    res_dict["decls"].append(short_name)
                add_include(class_info["file"])
                res_dict["out_names"].append("%s_py.cpp" % short_name)
        elif key == "functions":
            res_dict["out_names"].append("%s_py.cpp" % free_fun_name)
            for function_name, function_info in data.items():
                add_include(function_info["file"])
                if not function_info.get("is_template"):
                    continue
                short_name = function_name.split("::")[-1]
                for instantiation in function_info.get("inst", []):
                    if isinstance(instantiation, list):
                        type_args = ", ".join(instantiation)
                        name_suffix = "_".join(instantiation)
                    else:
                        type_args = name_suffix = instantiation
                    # Drop template brackets and scope colons so the suffix
                    # can be part of a variable name.
                    name_suffix = name_suffix.replace("<", "").replace(">", "").replace(":", "")
                    res_dict["instantiate"].append("auto {fakenm} = &{fun_name}<{type}>;\n".format(
                        fakenm="%s_%s" % (short_name, name_suffix),
                        fun_name=function_name,
                        type=type_args))
        else:
            res_dict["namespaces"].append(key)
            res_dict = find_classes(posixpath.join(src, key), data, res_dict, key)
    return res_dict


def parse(options):
    """
    Overall function to perform automatic generation of Pybind11 code from a C++ repository
    - Parses the JSON input
    - Finds classes and functions
    - Writes instantiations and includes into "wrapper.cpp", the only input to CastXML
    - Runs CastXML and uses pygccxml to read the results into a data object
    - Loops through parsed data for previously found classes and all available functions

    The top-level module file is always written. When ``options.no_generation``
    is set, the function then prints the ';'-separated list of files a full run
    would generate and returns without running CastXML.

    :param options: Result of the argparse module found in the running of the file.
    :return: None
    """
    # init the pygccxml stuff
    # Adapted from CPPWG: https://github.com/jmsgrogan/cppwg/blob/265117455ed57eb250643a28ea6029c2bccf3ab3/cppwg/parsers/source_parser.py#L24
    results = {"includes": [],
               "namespaces": [],
               "typedefs": [],
               "#include": [],
               "instantiate": [],
               "decls": [],
               "out_names": []}
    # Only the single JSON file named on the command line is processed.
    # TODO: Glob over the json_path directory as well?
    files_to_parse = [options.json_path]
    # Source path is the directory with the JSON in it
    options.source_dir = posixpath.dirname(options.json_path)
    # Find classes and functions in each JSON file from above
    for found_input_file in files_to_parse:
        # Close the JSON file deterministically instead of leaking the handle.
        with open(found_input_file, "r") as json_file:
            input_data = json.load(json_file)
        results = find_classes(options.source_dir, input_data, results, "free_functions")
        results["includes"].append(posixpath.dirname(found_input_file))
    module_file = write_module_data(options.module_name, results, options.output_dir)
    # Short circuit: prints list of files to be generated by the run, if it were to continue.
    if options.no_generation:
        print(';'.join([module_file] + results["out_names"]))
        return

    # Configures the CastXML calls
    # TODO: Ensure external programs can append additional C++ flag arguments
    castxml_config = pygccxml.parser.xml_generator_configuration_t(xml_generator_path=options.castxml_path,
                                                                   xml_generator="castxml",
                                                                   cflags="-std=c++1z" + " " + options.c_flags,
                                                                   include_paths=results["includes"] + options.includes)
    # Creates the single file that CastXML should parse containing all instantiations and includes
    with open("wrapper.cpp", "w") as wrapper_file:
        wrapper_file.write(wrap_header % ("".join(results["#include"]),
                                          "".join(results["typedefs"]),
                                          "drake_wrap",
                                          "    ".join(results["instantiate"])))
    # Run CastXML and parse back the resulting XML into a Python Object.
    pygccxml.utils.loggers.cxx_parser.setLevel(logging.CRITICAL)
    pygccxml.declarations.scopedef_t.RECURSIVE_DEFAULT = False
    pygccxml.declarations.scopedef_t.ALLOW_EMPTY_MDECL_WRAPPER = True
    total = pygccxml.parser.parse(["wrapper.cpp"],
                                  castxml_config,
                                  compilation_mode=pygccxml.parser.COMPILATION_MODE.ALL_AT_ONCE)

    # The same include block is placed in every generated file.
    includes_str = "".join(results["#include"])

    # Total seems to be a single item list, due to ALL_AT_ONCE mode, capture the data from
    # the first item in the list
    name_data = total[0]
    # If user provides an overall namespace, descend into it here.
    if options.default_namespace:
        name_data = name_data.namespace(options.default_namespace)
    # Capture the highest level list of free_function data
    # TODO: Should we capture free_functions if name_data = total[0]?
    other_free_functions = list(name_data.free_functions())
    # Namespaces and classes pygccxml reports at this level; computed once
    # since nothing below modifies the parsed declarations.
    known_namespaces = {x.name for x in name_data.namespaces()}
    top_level_classes = {x.name for x in name_data.classes()}
    for declared_class in results["decls"]:
        # If class name found in classes found in name_data, write out file for that class
        if declared_class in top_level_classes:
            write_class_data(name_data.class_(declared_class), options.output_dir, includes_str)
            continue
        # Otherwise, look into each namespace found during data gathering
        # that pygccxml also knows about.
        # TODO: Is this check necessary?  How often will a class name be duplicated across namespaces?
        for namespace in results["namespaces"]:
            if namespace not in known_namespaces:
                continue
            namespace_data = name_data.namespace(namespace)
            if declared_class in {x.name for x in namespace_data.classes()}:
                write_class_data(namespace_data.class_(declared_class), options.output_dir, includes_str)
    # Separate write out for namespace free_functions
    # TODO: See https://kwgitlab.kitware.com/joe.snyder/wrapper_generator/-/issues/2 for related issue
    for namespace in results["namespaces"]:
        if namespace not in known_namespaces:
            continue
        namespace_functions = name_data.namespace(namespace).free_functions()
        if namespace_functions:
            # Same "".join include block as every other generated file.
            write_free_function_data(namespace, namespace_functions, options.output_dir, includes_str)
    # Finally, if overall namespace has free functions, write them out.
    if other_free_functions:
        write_free_function_data("free_functions", other_free_functions, options.output_dir, includes_str)


# Command-line entry point. Guarded so that importing this module does not
# parse sys.argv (which exits when required arguments are missing).
if __name__ == "__main__":
    parser = ArgumentParser()
    parser.add_argument("-o", "--output", action="store", dest="output_dir", required=False, default=os.getcwd(),
                        help="Directory in which generated files are written (default: current directory)")
    parser.add_argument("-j", "--input_json", action="store", dest="json_path",
                        help="Path to input JSON file of objects to process", required=True)
    parser.add_argument("--module_name", action="store", dest="module_name",
                        help="Desired name of the output PyBind11 module", required=True)
    parser.add_argument("-g", "--castxml-path", action="store", dest="castxml_path",
                        help="Path to castxml", required=False)
    parser.add_argument("-d", "--default_namespace", action="store",
                        dest="default_namespace", help="value to use as the default namespace", required=False)
    parser.add_argument("-i", "--includes", type=str, help="Path to the includes directory.",
                        action="append", default=[])
    parser.add_argument("--no-generation", "-n", help="Only print name of files to be generated",
                        dest="no_generation", action="store_true", required=False)
    parser.add_argument("-cf", "--cflags", required=False, dest='c_flags', default='',
                        help="Additional compiler flags passed to CastXML")
    options = parser.parse_args()
    parse(options)
