Commit 2470d8a3 authored by tao558's avatar tao558
Browse files

More work on trampolines

parent aa977426
......@@ -33,7 +33,6 @@ class BindingsGenerator:
include_str += "#include \"%s\"\n" % f
return include_str
def generate_function_string(self, member_function, is_free_function=False, py_name=''):
"""
Takes a set of PyGCCXML data and returns the PyBind11 code for
......@@ -251,16 +250,16 @@ class BindingsGenerator:
if declaration.decl_string == "::":
continue
keys["defs"] += self.opts.enum_header_fmt.format(class_name=declaration.decl_string,
name=declaration.name,
type=enum_type if enum_type else "py::arithmetic()",
doc="")
name=declaration.name,
type=enum_type if enum_type else "py::arithmetic()",
doc="")
for enum_obj in declaration.get_name2value_dict().keys():
scope_name = enum_obj
if declaration.decl_string != "::":
scope_name = "%s::%s" % (declaration.decl_string, enum_obj)
keys["defs"] += self.opts.enum_val_fmt.format(short_name=enum_obj,
scoped_name=scope_name,
doc="")
scoped_name=scope_name,
doc="")
keys["defs"] += ".export_values();\n"
......@@ -269,6 +268,71 @@ class BindingsGenerator:
file_name = posixpath.join(out_dir, "%s_py.cpp" % module_name)
self.write_data_to_file(file_name, self.opts.non_class_module_cpp_fmt.format(**keys))
def format_tramp_override_sig(self, fun, tramp_name):
fun_str = str(fun)
# First, we have to remove the [member_function] part
# return_type nmspc1::Base::foo(args) [member_function] ->
# return_type nmspc1::Base::foo(args)
fun_str = fun_str.replace(" [member function]", "")
# Next we have to remove the current class name
# return_type nmspc1::Base::foo() ->
# return_type foo(args)
to_be_replaced = fun.parent.decl_string.strip("::") + "::" + fun.name
return fun_str.replace(to_be_replaced, fun.name)
def get_tramp_overload_macro_args(self, fun):
keys = {
"return_type": fun.return_type,
"parent_class": fun.parent.decl_string,
"cpp_fxn_name": fun.name,
"arg_str": ""
}
# Go through and construct argument string
# IMMEDIATE TODO: start here. The args have extra newlines between them,
# no newlines between function definitions
for i, arg in enumerate(fun.arguments):
if i:
keys["arg_str"] += ",\n" + self.indent
keys["arg_str"] += arg.name
return self.opts.pybind_overload_macro_args_fmt.format(**keys).strip()
def get_tramp_overrides(self, class_inst, tramp_name):
overrides_acc = ""
only_virt = lambda f: f.virtuality != "not virtual"
virtual_member_funs = class_inst.member_functions(only_virt)
for fun in virtual_member_funs:
keys = dict()
keys["fxn_sig"] = self.format_tramp_override_sig(fun, tramp_name)
is_pure_virt = fun.virtuality == "pure virtual"
keys["pure"] = "_PURE" if is_pure_virt else ""
self.indent += " " * 2
keys["macro_args"] = self.get_tramp_overload_macro_args(fun)
self.indent = self.indent[:-2]
overrides_acc += self.opts.tramp_override_fmt.format(**keys)
overrides_acc += "\n" * 2
return overrides_acc.strip()
def get_trampoline_string(self, class_inst, tramp_name):
"""
Assumes that the instance has at least 1 virtual method
"""
keys = dict()
keys["tramp_name"] = tramp_name
keys["class_decl"] = class_inst.decl_string
keys["ctor_name"] = class_inst.name
keys["virtual_overrides"] = self.get_tramp_overrides(class_inst, tramp_name)
return self.opts.trampoline_def_fmt.format(**keys)
def write_class_data(self, cpp_class_name, instance_list, out_dir, found_includes):
"""
Takes an instance of Class data from PyGCCXML and outputs a single file with
......@@ -286,10 +350,9 @@ class BindingsGenerator:
newlines_between_classes = "\n" * 3
keys = {
"includes": self.get_include_str(found_includes),
"trampoline_defs": "",
"trampoline_str": "",
"namespace": cpp_class_name + "_py",
"defs": "",
"trampoline_impl": ""
"defs": ""
}
file_name = posixpath.join(out_dir, cpp_class_name + "_py.cpp")
......@@ -310,6 +373,12 @@ class BindingsGenerator:
if b.access_type == dec.ACCESS_TYPES.PUBLIC:
pyclass_args += ", " + b.related_class.decl_string
num_virt_funs = len(instance_data.member_functions(lambda f: f.virtuality != "not virtual"))
# If there exist any virtual functions, we'll need to write out a trampoline implementation
if num_virt_funs > 0:
tramp_name = cpp_class_name + "_trampoline"
keys["trampoline_str"] = self.get_trampoline_string(instance_data, tramp_name)
constructor_str = ""
# List to stuff names into which will prevent re-writing
written_functions = []
......@@ -332,6 +401,7 @@ class BindingsGenerator:
member_string = ""
for member_function in instance_data.member_functions():
# import pdb; pdb.set_trace()
if member_function.name in written_functions or member_function.access_type != "public":
continue
member_string += self.generate_function_string(member_function)
......@@ -608,7 +678,6 @@ class BindingsGenerator:
rsp_includes, rsp_defs = self.clean_flags(self.opts.rsp_path)
self.generate_wrapper_cpp(module_info)
name_data = self.compile_and_parse_wrapper(rsp_includes, rsp_defs)
print(self.opts)
classes_to_find = set(module_info["non_template_classes"] + module_info["class_insts"])
classes = name_data.classes(lambda c: c.name in classes_to_find, recursive=True)
......@@ -659,7 +728,9 @@ arg.add("--operator_fmt", required=False, default=tb.operator_template)
arg.add("--call_operator_fmt", required=False, default=tb.call_template)
arg.add("--enum_header_fmt", required=False, default=tb.enum_header)
arg.add("--enum_val_fmt", required=False, default=tb.enum_val)
arg.add("--trampoline_fmt", required=False, default=tb.trampoline)
arg.add("--tramp_override_fmt", required=False, default=tb.tramp_override)
arg.add("--trampoline_def_fmt", required=False, default=tb.trampoline_def)
arg.add("--pybind_overload_macro_args_fmt", required=False, default=tb.pybind_overload_macro_args)
options = arg.parse_args()
......
......@@ -7,8 +7,8 @@ class Base
{
public:
virtual std::string virt1() const = 0;
virtual std::string virt2() const = 0;
virtual std::string virt3() const = 0;
virtual std::string virt2(float f) const = 0;
virtual std::string virt3(float f, std::string s) const = 0;
};
#endif
\ No newline at end of file
......@@ -14,7 +14,7 @@ public:
return std::string("Derived1.virt1()");
}
virtual std::string virt2() const override
virtual std::string virt2(float f) const override
{
return std::string("Derived1.virt2()");
}
......
......@@ -15,12 +15,12 @@ public:
return std::string("Derived2.virt1()");
}
virtual std::string virt2() const override
virtual std::string virt2(float f) const override
{
return std::string("Derived2.virt2()");
}
virtual std::string virt3() const override
virtual std::string virt3(float f, std::string s) const override
{
return std::string("Derived2.virt3()");
}
......
......@@ -42,8 +42,6 @@ void auto_bind_{name}(py::module &m)
// }} // namespace pydrake
// }} // namespace drake
"""
trampoline = "class "
class_info_body = """
py::class_<{pyclass_args}>(m, "{name}"{doc})
{constructor}
......@@ -52,20 +50,40 @@ class_info_body = """
{opers}
;
"""
class_module_cpp = """
#include <pybind11/pybind11.h>
pybind_overload_macro_args = """{return_type},
{parent_class},
{cpp_fxn_name},
{arg_str}
"""
tramp_override = """{fxn_sig} override
{{
PYBIND11_OVERLOAD{pure}
(
{macro_args}
);
}}"""
trampoline_def = """class {tramp_name}
: public {class_decl}
{{
using {class_decl}::{ctor_name};
{virtual_overrides}
}};
"""
class_module_cpp = """#include <pybind11/pybind11.h>
{includes}
{trampoline_defs}
{trampoline_str}
namespace py = pybind11;
void auto_bind_{namespace}(py::module &m)
{{
{defs}
}}
{trampoline_impl}
"""
non_class_module_cpp = """
#include <pybind11/pybind11.h>
non_class_module_cpp = """#include <pybind11/pybind11.h>
{includes}
namespace py = pybind11;
void auto_bind_{namespace}(py::module &m)
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment