Commit eb4e7270 authored by Tom Osika's avatar Tom Osika
Browse files

Merge branch 'dev/add-publicist' into 'master'

Dev/add publicist

See merge request autopybind11/autopybind11!24
parents a20efa78 ff931a54
Pipeline #183115 failed with stage
in 0 seconds
......@@ -40,7 +40,7 @@ class BindingsGenerator:
include_str += "#include \"%s\"\n" % f
return include_str
def generate_function_string(self, member_function, is_free_function=False, py_name=''):
def generate_function_string(self, member_function, is_free_function=False, py_name='', publicist_name=''):
"""
Takes a set of PyGCCXML data and returns the PyBind11 code for
that function
......@@ -67,11 +67,21 @@ class BindingsGenerator:
next_arg_str = self.opts.member_func_arg_fmt.format(arg.name,
" = %s" % arg.default_value if arg.default_value else "")
arg_string = arg_string + ", " + next_arg_str
if member_function.parent.name == "::":
ref_string = "&%s" % (member_function.name)
else:
ref_string = "&%s::%s" % (member_function.parent.decl_string, member_function.name)
parent = ""
if not is_free_function and self.protected_filter(member_function):
if publicist_name:
parent = publicist_name
else:
msg = "No publicist name set for protected virtual method %s" % member_function.name
raise RuntimeError(msg)
else:
parent = member_function.parent.decl_string
ref_string = "&%s::%s" % (parent, member_function.name)
signature = self.opts.overload_template_fmt.format(fun_ref=ref_string,
decl_string=member_function.decl_string)
......@@ -288,6 +298,9 @@ class BindingsGenerator:
def public_filter(self, x):
return x.access_type == "public"
def protected_filter(self, x):
return x.access_type == "protected"
def private_filter(self, x):
return x.access_type == "private"
......@@ -391,7 +404,27 @@ class BindingsGenerator:
keys["ctor_name"] = cpp_class_name
keys["virtual_overrides"] = self.get_tramp_overrides(class_inst, alias, methods)
return self.opts.trampoline_def_fmt.format(**keys)
return self.opts.trampoline_def_fmt.format(**keys).strip()
def get_publicist_using_directives(self, class_inst, methods):
directives = ""
for m in methods:
keys = dict()
keys["class_decl"] = class_inst.decl_string
keys["fxn_name"] = m.name
directives += self.opts.publicist_using_directives_fmt.format(**keys)
directives += "\n" + self.indent
return directives.strip()
def get_publicist_string(self, class_inst, publicist_name, methods):
keys = dict()
keys["publicist_name"] = publicist_name
keys["class_decl"] = class_inst.decl_string
keys["using_directives"] = self.get_publicist_using_directives(class_inst, methods)
return self.opts.publicist_def_fmt.format(**keys).strip()
def write_class_data(self, cpp_class_name, instance_list, out_dir, found_includes, desired_name):
"""
......@@ -412,6 +445,7 @@ class BindingsGenerator:
keys = {
"includes": self.get_include_str(found_includes),
"trampoline_str": "",
"publicist_str": "",
"namespace": cpp_class_name + "_py",
"defs": ""
}
......@@ -456,6 +490,19 @@ class BindingsGenerator:
# Also need to add to pyclass_args
pyclass_args += ", " + tramp_name
# Now we'll deal with any protected functions, as we want these to be visible
# to python subclasses. Especially if they are virtual and appear in the
# trampoline for overriding.
prot_methods = instance_data.member_functions(self.protected_filter)
publicist_name = ""
if prot_methods:
self.indent = self.indent[:-2]
publicist_name = pyclass_name_stem + "_publicist"
publicist_str = self.get_publicist_string(instance_data, publicist_name, prot_methods)
keys["publicist_str"] += publicist_str
keys["publicist_str"] += newlines
self.indent += " " * 2
constructor_str = ""
# List to stuff names into which will prevent re-writing
written_functions = []
......@@ -501,9 +548,9 @@ class BindingsGenerator:
member_string = ""
# import pdb; pdb.set_trace()
for member_function in instance_data.member_functions():
if member_function.name in written_functions or member_function.access_type != "public":
if member_function.name in written_functions or self.private_filter(member_function):
continue
member_string += self.generate_function_string(member_function)
member_string += self.generate_function_string(member_function, publicist_name=publicist_name)
member_string += self.indent
# TODO: Necessary? Determine usefulness of listing operators
......@@ -887,6 +934,8 @@ def main():
arg.add("--trampoline_def_fmt", required=False, default=tb.trampoline_def)
arg.add("--pybind_overload_macro_args_fmt", required=False, default=tb.pybind_overload_macro_args)
arg.add("--copy_constructor_tramp_fmt", required=False, default=tb.copy_constructor_tramp)
arg.add("--publicist_using_directives_fmt", required=False, default=tb.publicist_using_directives)
arg.add("--publicist_def_fmt", required=False, default=tb.publicist_def)
options = arg.parse_args()
......
......@@ -78,9 +78,20 @@ public:
}};
"""
publicist_using_directives = """using {class_decl}::{fxn_name};"""
publicist_def = """class {publicist_name}
: public {class_decl}
{{
public:
{using_directives}
}};
"""
class_module_cpp = """#include <pybind11/pybind11.h>
{includes}
{trampoline_str}
{publicist_str}
namespace py = pybind11;
void auto_bind_{namespace}(py::module &m)
{{
......@@ -105,7 +116,7 @@ member_func_arg = """py::arg(\"{}\"){}"""
public_member_var = """.def_read{write}{static}(\"{var_name}\", {var_ref})\n"""
private_member_var = """.def_property{readonly}{static}(\"{var_name}\", {var_accessors})\n"""
member_reference = "&{classname}::{member}"
overload_template = """({decl_string}) ({fun_ref})"""
overload_template = """static_cast<{decl_string}>({fun_ref})"""
operator_template = """.def({arg1} {symbol} {arg2})\n"""
call_template = """.def("__call__", []({arg_str}){{ {fn_call} }})\n"""
......
......@@ -37,6 +37,10 @@ add_test(NAME non_template_tramp_copy_constructor_unittest
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
# Templated test
add_test(NAME template_basic_behavior_unittest
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tests/template/basicClassTests.py
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
add_test(NAME template_trampoline_unittest
COMMAND ${PYTHON_EXECUTABLE} ${CMAKE_CURRENT_SOURCE_DIR}/tests/template/trampolineTests.py
WORKING_DIRECTORY ${CMAKE_CURRENT_BINARY_DIR})
......
......@@ -12,6 +12,13 @@ public:
float var0;
virtual std::string whoami() const = 0;
// Use this to actually call the protected virtual function
// Can't use an external function since the method is protected
inline std::string call_prot_virt_fxn() const
{
return this -> prot_virt_fxn();
}
// Use this to actually call the private virtual function
// Can't use an external function since the method is private
std::string call_priv_virt_fxn() const
......@@ -19,6 +26,13 @@ public:
return this -> priv_virt_fxn();
}
protected:
virtual std::string prot_virt_fxn() const = 0;
inline std::string prot_fxn() const
{
return std::string("Base.prot_fxn()");
}
private:
// Make sure private pure virtual functions are included in trampoline.
// The way that trampolines work actually requires these be overriden,
......
......@@ -30,6 +30,12 @@ public:
return std::string("Derived1.virt1()");
}
protected:
virtual std::string prot_virt_fxn() const override
{
return std::string("Derived1.prot_virt_fxn()");
}
private:
// Make sure private non-pure virtual functions aren't included in trampoline
inline virtual std::string priv_virt_fxn() const override
......
......@@ -27,6 +27,12 @@ public:
return std::string("TD1.virt1()");
}
protected:
virtual std::string prot_virt_fxn() const override
{
return std::string("TD1.prot_virt_fxn()");
}
private:
// Make sure private non-pure virtual functions aren't included in trampoline
inline virtual std::string priv_virt_fxn() const override
......
......@@ -76,7 +76,7 @@ class basicClassBehavior(unittest.TestCase):
# Not testing virtual-ness here, just that it was inherited
# and can be called by the usual means. Notice that the implementation
# was provided in Derived1
def test_inherited_mthds(self):
def test_public_inherited_mthds(self):
self.assertEqual(im.Derived2_py().virt1(6.28), "Derived1.virt1()")
# Same for InheritsAll. All of the implementations should be the same as
......@@ -85,6 +85,16 @@ class basicClassBehavior(unittest.TestCase):
self.assertEqual(im.InheritsAll_py().virt1(6.28), "Derived1.virt1()")
self.assertEqual(im.InheritsAll_py().virt2(-6.28, "py_string"), "Derived2.virt2()")
# Likewise with above, but with a protected member.
# Should be "public" on the Python side. Implemented in Base class
def test_prot_inherited_mthds(self):
exp_string = "Base.prot_fxn()"
self.assertEqual(im.Base_py().prot_fxn(), exp_string)
self.assertEqual(im.AbstractDerived1_py().prot_fxn(), exp_string)
self.assertEqual(im.Derived1_py().prot_fxn(), exp_string)
self.assertEqual(im.Derived2_py().prot_fxn(), exp_string)
self.assertEqual(im.InheritsAll_py().prot_fxn(), exp_string)
# Also not testing virtual-ness here. Making sure that
# subclasses can be substitute for a base class in a function argument
# and that the correct behavior is observed
......
......@@ -20,6 +20,10 @@ class DerivedFromBase(im.Base_py):
def whoami(self):
return "DerivedFromBase"
# Override prot_virt_fxn
def prot_virt_fxn(self):
return "DerivedFromBase.prot_virt_fxn()"
# Override priv_virt_fxn
def priv_virt_fxn(self):
return "DerivedFromBase.priv_virt_fxn()"
......@@ -33,6 +37,9 @@ class DerivedFromAbstractDerived1(im.AbstractDerived1_py):
return "DerivedFromAbstractDerived1"
# Override priv_virt_fxn
def prot_virt_fxn(self):
return "DerivedFromAbstractDerived1.prot_virt_fxn()"
def priv_virt_fxn(self):
return "DerivedFromAbstractDerived1.priv_virt_fxn()"
......@@ -50,6 +57,9 @@ class DerivedFromD1(im.Derived1_py):
def virt1(self, f):
return "DerivedFromD1.virt1()"
def prot_virt_fxn(self):
return "DerivedFromD1.prot_virt_fxn()"
def priv_virt_fxn(self):
return "DerivedFromD1.priv_virt_fxn()"
......@@ -71,6 +81,9 @@ class DerivedFromD2(im.Derived2_py):
def virt2(self, f, s):
return "DerivedFromD2.virt2()"
def prot_virt_fxn(self):
return "DerivedFromD2.prot_virt_fxn()"
def priv_virt_fxn(self):
return "DerivedFromD2.priv_virt_fxn()"
......@@ -93,5 +106,8 @@ class DerivedFromIA(im.InheritsAll_py):
def virt2(self, f, s):
return "DerivedFromIA.virt2()"
def prot_virt_fxn(self):
return "DerivedFromIA.prot_virt_fxn()"
def priv_virt_fxn(self):
return "DerivedFromIA.priv_virt_fxn()"
......@@ -15,7 +15,7 @@ import inheritance_module as im
# Virtual-ness should be honored by Python for both C++ classes exposed in Python,
# and for Python classes inheriting from C++ classes
class trampolines(unittest.TestCase):
def test_virt_overrides_cpp(self):
def test_public_virt_overrides_cpp(self):
# First call whoami from base reference. Downcasting should ensure that
# this always prints the lowest implementation in the inheritance tree
# Note that InheritsAll should always match derived2's output.
......@@ -28,6 +28,16 @@ class trampolines(unittest.TestCase):
self.assertEqual(im.call_virt_whoami(im.Derived2_py()), "Derived2")
self.assertEqual(im.call_virt_whoami(im.InheritsAll_py()), "Derived2")
def test_prot_virt_overrides_cpp(self):
# Pure virtual can't be called
self.assertRaises(RuntimeError, im.Base_py().call_prot_virt_fxn)
self.assertRaises(RuntimeError, im.AbstractDerived1_py().call_prot_virt_fxn)
# Derived 1 sets the implementation, everyone else inherits
self.assertEqual(im.Derived1_py().call_prot_virt_fxn(), "Derived1.prot_virt_fxn()")
self.assertEqual(im.Derived2_py().call_prot_virt_fxn(), "Derived1.prot_virt_fxn()")
self.assertEqual(im.InheritsAll_py().call_prot_virt_fxn(), "Derived1.prot_virt_fxn()")
def test_priv_virt_overrides_cpp(self):
# Pure virtual can't be called
self.assertRaises(RuntimeError, im.Base_py().call_priv_virt_fxn)
......@@ -44,7 +54,7 @@ class trampolines(unittest.TestCase):
self.assertIn("call_virt_from_derived2", dir(im))
# Test virtual functions can be overriden in python
def test_virt_overrides_py(self):
def test_public_virt_overrides_py(self):
# First DerivedFromBase
self.assertEqual(im.call_virt_whoami(DerivedFromBase()), "DerivedFromBase")
......@@ -79,6 +89,24 @@ class trampolines(unittest.TestCase):
exp_string += "DerivedFromIA.virt2()"
self.assertEqual(im.call_virt_from_derived2(DerivedFromIA()), exp_string)
def test_prot_virt_overrides_py(self):
# First classes derived from C++ abstract classes
exp_string = "DerivedFromBase.prot_virt_fxn()"
self.assertEqual(DerivedFromBase().call_prot_virt_fxn(), exp_string)
exp_string = "DerivedFromAbstractDerived1.prot_virt_fxn()"
self.assertEqual(DerivedFromAbstractDerived1().call_prot_virt_fxn(), exp_string)
# Now classes derived from concrete C++ classes
exp_string = "DerivedFromD1.prot_virt_fxn()"
self.assertEqual(DerivedFromD1().call_prot_virt_fxn(), exp_string)
exp_string = "DerivedFromD2.prot_virt_fxn()"
self.assertEqual(DerivedFromD2().call_prot_virt_fxn(), exp_string)
exp_string = "DerivedFromIA.prot_virt_fxn()"
self.assertEqual(DerivedFromIA().call_prot_virt_fxn(), exp_string)
def test_priv_virt_overrides_py(self):
# First DerivedFromBase, which is one of 2 that should work
exp_string = "DerivedFromBase.priv_virt_fxn()"
......
# Distributed under the OSI-approved BSD 3-Clause License. See accompanying
# file Copyright.txt
import unittest
import sys
import os
# Make it so we search where we are running.
sys.path.append(os.getcwd())
import inheritance_module as im
# Not going to test all of the same functionality as in non_template/basicClassTests,
# as basic class functionality has been tested thoroughly elsewhere in the examples.
# Just test new functionality not tested elsewhere
class basicClassBehavior(unittest.TestCase):
# Should be "public" on the Python side. Implemented in Base class
def test_prot_inherited_mthds(self):
exp_string = "Base.prot_fxn()"
self.assertEqual(im.TAD1_float_py().prot_fxn(), exp_string)
self.assertEqual(im.TAD1_int_py().prot_fxn(), exp_string)
self.assertEqual(im.TD1_float_py().prot_fxn(), exp_string)
self.assertEqual(im.TD1_int_py().prot_fxn(), exp_string)
self.assertEqual(im.TD2_float_double_py().prot_fxn(), exp_string)
self.assertEqual(im.TD2_int_double_py().prot_fxn(), exp_string)
if __name__ == '__main__':
unittest.main()
......@@ -20,6 +20,10 @@ class DerivedFromTAD1Float(im.TAD1_float_py):
def whoami(self):
return "DerivedFromTAD1Float"
# Override prot_virt_fxn
def prot_virt_fxn(self):
return "DerivedFromTAD1Float.prot_virt_fxn()"
# Override priv_virt_fxn
def priv_virt_fxn(self):
return "DerivedFromTAD1Float.priv_virt_fxn()"
......@@ -32,6 +36,10 @@ class DerivedFromTAD1Int(im.TAD1_int_py):
def whoami(self):
return "DerivedFromTAD1Int"
# Override prot_virt_fxn
def prot_virt_fxn(self):
return "DerivedFromTAD1Int.prot_virt_fxn()"
# Override priv_virt_fxn
def priv_virt_fxn(self):
return "DerivedFromTAD1Int.priv_virt_fxn()"
......@@ -51,6 +59,10 @@ class DerivedFromTD1Float(im.TD1_float_py):
def virt1(self, f):
return "DerivedFromTD1Float.virt1()"
# Override prot_virt_fxn
def prot_virt_fxn(self):
return "DerivedFromTD1Float.prot_virt_fxn()"
# Override priv_virt_fxn
def priv_virt_fxn(self):
return "DerivedFromTD1Float.priv_virt_fxn()"
......@@ -70,6 +82,10 @@ class DerivedFromTD1Int(im.TD1_int_py):
def virt1(self, f):
return "DerivedFromTD1Int.virt1()"
# Override prot_virt_fxn
def prot_virt_fxn(self):
return "DerivedFromTD1Int.prot_virt_fxn()"
# Override priv_virt_fxn
def priv_virt_fxn(self):
return "DerivedFromTD1Int.priv_virt_fxn()"
......@@ -93,6 +109,10 @@ class DerivedFromTD2FloatDouble(im.TD2_float_double_py):
def virt2(self, f, s):
return "DerivedFromTD2FloatDouble.virt2()"
# Override prot_virt_fxn
def prot_virt_fxn(self):
return "DerivedFromTD2FloatDouble.prot_virt_fxn()"
# Override priv_virt_fxn
def priv_virt_fxn(self):
return "DerivedFromTD2FloatDouble.priv_virt_fxn()"
......@@ -116,7 +136,10 @@ class DerivedFromTD2IntDouble(im.TD2_int_double_py):
def virt2(self, f, s):
return "DerivedFromTD2IntDouble.virt2()"
# Override prot_virt_fxn
def prot_virt_fxn(self):
return "DerivedFromTD2IntDouble.prot_virt_fxn()"
# Override priv_virt_fxn
def priv_virt_fxn(self):
return "DerivedFromTD2IntDouble.priv_virt_fxn()"
......@@ -15,7 +15,7 @@ import inheritance_module as im
# Virtual-ness should be honored by Python for both C++ classes exposed in Python,
# and for Python classes inheriting from C++ classes
class trampolines(unittest.TestCase):
def test_virt_overrides_cpp(self):
def test_public_virt_overrides_cpp(self):
# First call whoami from base reference. Downcasting should ensure that
# this always prints the lowest implementation in the inheritance tree
# Method isn't implemented in abstract class, so this should raise an exception
......@@ -28,6 +28,17 @@ class trampolines(unittest.TestCase):
self.assertEqual(im.call_virt_whoami(im.TD2_float_double_py()), "TD2")
self.assertEqual(im.call_virt_whoami(im.TD2_int_double_py()), "TD2")
def test_prot_virt_overrides_cpp(self):
# Not implemented in abstract class
self.assertRaises(RuntimeError, im.TAD1_float_py().call_prot_virt_fxn)
self.assertRaises(RuntimeError, im.TAD1_int_py().call_prot_virt_fxn)
# TD1 sets the implementation, everyone else inherits
self.assertEqual(im.TD1_float_py().call_prot_virt_fxn(), "TD1.prot_virt_fxn()")
self.assertEqual(im.TD1_int_py().call_prot_virt_fxn(), "TD1.prot_virt_fxn()")
self.assertEqual(im.TD2_float_double_py().call_prot_virt_fxn(), "TD1.prot_virt_fxn()")
self.assertEqual(im.TD2_int_double_py().call_prot_virt_fxn(), "TD1.prot_virt_fxn()")
def test_priv_virt_overrides_cpp(self):
# Not implemented in abstract class
self.assertRaises(RuntimeError, im.TAD1_float_py().call_priv_virt_fxn)
......@@ -45,7 +56,7 @@ class trampolines(unittest.TestCase):
self.assertIn("call_virt_from_td2", dir(im))
# Test virtual functions can be overriden in python
def test_virt_overrides_py(self):
def test_public_virt_overrides_py(self):
# Check that virtual methods can be accessed
# from a reference to any base for a python class
......@@ -91,6 +102,27 @@ class trampolines(unittest.TestCase):
exp_string += "DerivedFromTD2IntDouble.virt2()"
self.assertEqual(im.call_virt_from_td2(DerivedFromTD2IntDouble()), exp_string)
def test_prot_virt_overrides_py(self):
# Check python classes deriving from C++ abstract class
exp_string = "DerivedFromTAD1Float.prot_virt_fxn()"
self.assertEqual(DerivedFromTAD1Float().call_prot_virt_fxn(), exp_string)
exp_string = "DerivedFromTAD1Int.prot_virt_fxn()"
self.assertEqual(DerivedFromTAD1Int().call_prot_virt_fxn(), exp_string)
# Now classes derived from concrete C++ classes
exp_string = "DerivedFromTD1Float.prot_virt_fxn()"
self.assertEqual(DerivedFromTD1Float().call_prot_virt_fxn(), exp_string)
exp_string = "DerivedFromTD1Int.prot_virt_fxn()"
self.assertEqual(DerivedFromTD1Int().call_prot_virt_fxn(), exp_string)
exp_string = "DerivedFromTD2FloatDouble.prot_virt_fxn()"
self.assertEqual(DerivedFromTD2FloatDouble().call_prot_virt_fxn(), exp_string)
exp_string = "DerivedFromTD2IntDouble.prot_virt_fxn()"
self.assertEqual(DerivedFromTD2IntDouble().call_prot_virt_fxn(), exp_string)
def test_priv_virt_overrides_py(self):
# Check python classes deriving from C++ abstract class
exp_string = "DerivedFromTAD1Float.priv_virt_fxn()"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment