Commit 4d346c86 authored by John Parent's avatar John Parent
Browse files

rough support for STL binding and opaque types

parent 3ea77a35
......@@ -78,5 +78,6 @@ add_autopybind11_test(enum_deps)
add_autopybind11_test(holder_types)
add_autopybind11_test(ref_support)
add_autopybind11_test(tramp_duplication)
add_autopybind11_test(stl_containers)
add_subdirectory(code_generation_regression)
cmake_minimum_required(VERSION 3.15)
project(stl_containers CXX)
add_library(stl_cpp INTERFACE)
target_include_directories(stl_cpp INTERFACE ${CMAKE_CURRENT_SOURCE_DIR})
target_sources(stl_cpp INTERFACE
stl_container.h)
find_package(AutoPyBind11)
autopybind11_fetch_build_pybind11(PYBIND11_DIR ${PYBIND11_SRC_DIR})
autopybind11_add_module("stl_container"
YAML_INPUT ${CMAKE_CURRENT_SOURCE_DIR}/input_wrap.yml
DESTINATION ${CMAKE_CURRENT_BINARY_DIR}
LINK_LIBRARIES stl_cpp)
stl:
vector:
"VectorInt" :
type: int
buffer_prot:
"VectorDouble" :
type: double
buffer_prot:
stl_bind: True
map:
"Int2String" :
module_local: False
type: ["int","std::string"]
stl_bind: True
"Int2Double" :
type: ["int", "double"]
files:
stl_container.h:
classes:
simple:
#ifndef STL_CONTAINER
#define STL_CONTAINER
#include <vector>
#include <map>
class simple
{
public:
simple() {}
simple(std::map<int,std::vector<int>> &i_map) :i_map_(i_map) {}
void append(int i, int val) {this->i_map_.at(i).push_back(val);}
void add_pair(int i, std::vector<int> v) {this->i_map_[i]=v;}
private:
std::map<int,std::vector<int>> i_map_;
};
#endif // STL_CONTAINER
......@@ -43,6 +43,8 @@ class BindingsGenerator:
self.cns_flag = False
self.validation_filters = []
self.load_validation_filters()
self.stl_binders = []
self.universal_includes = ""
def load_mdx(self, file_name):
with open(file_name, "r") as mdx_file:
......@@ -69,14 +71,14 @@ class BindingsGenerator:
if type(d_in[key]) is dict:
self.flatten(d_in[key], flat)
def write_data_to_file(self, file_name, string):
def write_data_to_file(self, file_name, string, ext):
"""
Writes incoming strings to an output file, globally keep track of objects written?
:param file_name: name of file to write to
:param string: string to write out
:return: None
"""
assert file_name.endswith(".cpp"), cpp_file
assert file_name.endswith(ext), cpp_file
with open(file_name, "w") as cpp_file:
cpp_file.write(string)
if not self.opts.skip_formatting:
......@@ -89,6 +91,8 @@ class BindingsGenerator:
include_str += f
else:
include_str += '#include "%s"\n' % f
if self.universal_includes:
include_str += self.universal_includes + "\n"
return include_str
def generate_keep_alive_string(self, decl_data, keep_alive_data):
......@@ -700,7 +704,7 @@ class BindingsGenerator:
if module_name != free_fun_mod_name:
keys["defs"] = keys["defs"].replace("::%s::" % module_name, "")
self.write_data_to_file(file_name, cpp_body.format(**keys))
self.write_data_to_file(file_name, cpp_body.format(**keys), ".cpp")
self.indent = self.indent[:-2]
def return_policy_reference_internal(self):
......@@ -823,6 +827,151 @@ class BindingsGenerator:
except AttributeError:
return decl
def derive_stl_deps(self, decl_string, incs=set()):
pass
def aquire_stl_pygccxml_data(self, pygccxml_data):
"""
Search for and return the pygccxml representations of
the requested STL container types
:param pygccxml_data: pygcxcxml dervied data representing top level
namespace to be searched for STL data
:return: STL pygccxml data
"""
# We know that any type we are looking for is within the STD namespace
# so we can avoid a recursive search and limit ourselves to the STD namespace
std = pygccxml_data.namespace(name="std")
# further avoid RT increase by leveraging decl_type looking for class_t
stl_data = []
for container in self.stl_binders:
if container.stl_bind:
try:
stl_data.append(
std.decl(
lambda dec: dec.partial_decl_string
== container.get_partial_decl(),
decl_type=dec.class_declaration.class_t,
)
)
except pygccxml.declarations.runtime_errors.declaration_not_found_t:
continue
return stl_data
def generate_stl_class_bindings(self, container, pygccxml_data):
"""
Generate the pybind11 py class definition of an STL container as a string given
data from pygccxml
:param container: An stl contianer binder object
:param pygccxml_data: Pygccxml data defining said STL object
:return: String representation of stl_binder class definiton
"""
return container.generate_class_def(pygccxml_data, self)
def write_stl_binding_data(self, containers, pygccxml_data):
"""
Write out to file bind calls and custom definitions for declared STL
opaque types, file is named "stl_bind.cxx"
:param containers: List of STL container objects to be bound
:param pygccxml_data: pygccxml_data representing compiler info associated with
aforementioned STL container objects
"""
opaque_stl = []
bind_stl = []
for container in containers:
binding_agent = bind_stl
if container.stl_bind:
binding_agent = opaque_stl
binding_agent.append(container)
body_str = ""
for container in bind_stl:
mod_local = (
", py::module_local(false)" if container.mod_loc else ""
)
buf_prot = (
", " + self.opts.py_buff_prot_fmt
if container.buffer_protocol
else ""
)
body_str += (
self.opts.bind_stl_type_fmt.format(
stl_name=container.type,
type=container.cpp_name,
mod="m",
name=container.name,
mod_local=mod_local,
buf_prot=buf_prot,
)
+ "\n"
)
for container, data in zip(opaque_stl, pygccxml_data):
body_str += (
self.generate_stl_class_bindings(container, data) + "\n"
)
class_mod_dec = self.opts.class_module_cpp_fmt.format(
includes=self.get_include_str([self.opts.stl_bind_fmt + "\n"]),
module=self.opts.module_name,
namespace="stl_binders",
publicist_str="",
trampoline_str="",
defs=body_str,
)
self.write_data_to_file("stl_bind.cxx", class_mod_dec, ".cxx")
def generate_stl_includes(self, containers):
"""
Generate string of includes for STL components
:param containers: List of stl container objects
:return: string of stl includes
"""
include_str = []
for container in containers:
fmt_str = self.opts.generic_include.format(
"<{}>".format(container.type)
)
if fmt_str not in include_str:
include_str.append(fmt_str)
return "\n".join(include_str)
def create_stl_binder_header(self, containers):
"""
Generate header supporting Pybind11 Opaque types
Header to be included by each compilation unit in module
so specified stl types will be made opaque for PBR and
avoiding excessive copy operations with large containers
Opaque types in Pybind11 are those for which template type
conversion/inspection is disabled, and the type is passed
to python directly, enabling copying and PBR
:param containers: a list of STL containers to be exposed
as an opaque type
:return: the header file name and a string representing the opaque
declarations
"""
header_str = "#ifndef STL_OPAQUE\n#define STL_OPAQUE\n#include <pybind11/pybind11.h>\n{}\n{}\n#endif\n"
binder_str = ""
file_ = "stl_opaque_dec.h"
self.universal_includes += '#include "{}"\n'.format(file_)
stl_incs = self.generate_stl_includes(containers)
for container in containers:
binder_str += (
self.opts.make_opaque_fmt.format(container.cpp_name) + "\n"
)
return header_str.format(stl_incs, binder_str), file_
def write_stl_opaque_header(self, containers, outdir):
"""
Write out the common header declaring STL opaque types
:param containers: list of contianers to be made opaque
:param outdir: directory header is to be written to
:return: Location of headerfile
"""
opaque_header, file_name = self.create_stl_binder_header(containers)
file_name = posixpath.join(outdir, file_name)
self.write_data_to_file(file_name, opaque_header, ".h")
return file_name
def include_eigen(self, x, includes):
if not self.filter_component_deps(x, self.depends_on_eigen, True):
eigen_str = "#include <pybind11/eigen.h>\n"
......@@ -2171,6 +2320,23 @@ class BindingsGenerator:
module_data = {"forwards": [], "init_funs": []}
no_forward_list = set()
nmspcs = set()
mod_includes = ""
if self.stl_binders:
header_name = self.write_stl_opaque_header(
self.stl_binders, out_dir
)
mod_includes += '#include "{}"\n'.format(header_name)
module_data["forwards"].append(
self.opts.init_fun_forward_fmt.format(
name="stl_binders", module=self.opts.module_name
)
)
module_data["init_funs"].append(
self.opts.init_fun_signature_fmt.format(
name="stl_binders", module=self.opts.module_name
)
)
# Note that this assumes self.name_tree will only be populated
# if we are not ignoring the namespace structure
if self.name_tree:
......@@ -2192,6 +2358,7 @@ class BindingsGenerator:
),
)
)
for future_file in results_dict["out_names"]:
name = future_file.split(".")[0]
# we define the namespace structure in above call to build_namespc_str
......@@ -2217,6 +2384,7 @@ class BindingsGenerator:
)
module_cpp_file = posixpath.join(out_dir, "%s.cpp" % module_name)
module_cpp_text = self.opts.common_cpp_body_fmt.format(
includes=mod_includes,
name=module_name,
forwards="".join(
module_data["forwards"]
......@@ -2225,7 +2393,7 @@ class BindingsGenerator:
init_funs="".join(module_data["init_funs"]),
autobind_calls="".join(results_dict["all_auto_bind"]),
)
self.write_data_to_file(module_cpp_file, module_cpp_text)
self.write_data_to_file(module_cpp_file, module_cpp_text, ".cpp")
return module_cpp_file
def format_member_insts(
......@@ -2359,7 +2527,7 @@ class BindingsGenerator:
self.add_namespace(curr_nmspc, name)
)
# Add the dependent file
res_dict["to_include"].add(curr_file)
res_dict["to_include"].add('"' + curr_file + '"')
# Then write the future file
future_file = self.find_future_file_name(
is_class, name, free_fun_name, curr_nmspc
......@@ -2482,6 +2650,14 @@ class BindingsGenerator:
curr_file,
mod_tree,
)
elif key == "stl":
stl_binders_ = self.customizer.construct_std_bind(inner_dict)
for x in stl_binders_:
if x.stl_bind:
res_dict["to_include"].add("<{}>".format(x.type))
res_dict["class_insts"].append(x.cpp_name)
self.stl_binders.extend(stl_binders_)
return mod_tree
def clean_flags(self, rsp_path):
......@@ -2854,6 +3030,7 @@ class BindingsGenerator:
"functions",
"enums",
"customization",
"stl",
}
for key in keys_left_to_check:
if key in [x.name for x in pygccxml_data.namespaces()]:
......@@ -2883,6 +3060,21 @@ class BindingsGenerator:
free_fun_name=free_fun_name,
)
if not curr_nmspc:
# We have moved fully through the input yaml, all recursive calls have
# returned, now before writing out all non class data, we make a check
# for any STL binding data, if found in our yaml, we compose all STL binding
# into a single file and write out.
if "stl" in yaml_dict:
import pdb
pdb.set_trace()
# First we aquire pygccxml data for types just being made opaque
# Direct STL bindings via Pybind11 do not require a custom class definition
# as they are interpreted and derived into pythonic types via Pybind11
# Opaque types require the 'user' (here APB) to perform this step.
stl_pygcc = self.aquire_stl_pygccxml_data(pygccxml_data)
self.write_stl_binding_data(self.stl_binders, stl_pygcc)
# write out all the non class_data for each namespace
# need to do it all at once here rather than incrementally
# as previous to handle complex mixed custom/c++ namespaces->modules
......@@ -2911,7 +3103,7 @@ class BindingsGenerator:
return name if len(name) > 1 else name[0]
def generate_wrapper_cpp(self, module_info):
includes_fmt = '#include "%s"\n'
includes_fmt = "#include %s\n"
class_decs_fmt = "template class %s;\n"
func_ptr_assign_fmt = "auto %s = &%s;\n"
includes_str = "".join(
......@@ -2945,7 +3137,7 @@ class BindingsGenerator:
class_decs=class_decs,
func_ptr_assigns=func_ptr_assigns,
)
self.write_data_to_file(wrapper_cpp_file, wrapper_cpp_text)
self.write_data_to_file(wrapper_cpp_file, wrapper_cpp_text, ".cpp")
def compile_and_parse_wrapper(self, module_info, rsp_includes, rsp_defs):
# Need the castxml path at this point
......@@ -3027,8 +3219,15 @@ class BindingsGenerator:
file_name = posixpath.join(
self.opts.output_dir, "%s.cpp" % self.opts.module_name
)
stl_bind = None
if self.stl_binders:
stl_bind = ["stl_bind.cxx"]
print(
"%".join([file_name] + list(module_info["out_names"].keys()))
"%".join(
[file_name]
+ list(module_info["out_names"].keys())
+ stl_bind
)
+ ";"
+ "%".join([module_index_file])
)
......@@ -3423,6 +3622,11 @@ def main(argv=None):
arg.add(
"--class_init_call_fmt", required=False, default=tb.class_init_call
)
arg.add("--make_opaque_fmt", required=False, default=tb.make_opaque)
arg.add("--stl_bind_fmt", required=False, default=tb.include_stl_bind)
arg.add("--py_buffer_prot_fmt", required=False, default=tb.buffer_prot)
arg.add("--bind_stl_type_fmt", required=False, default=tb.bind_stl_type)
arg.add("--generic_include", required=False, default=tb.generic_include)
options = arg.parse_args(argv)
# CastXML Check for a relatively recent version
rtn = subprocess.check_output([options.castxml_path, "--version"])
......
# Class that helps with customization of C++ functions, enums, classes, etc.
import autopybind11.text_blocks as tb
class Customizer:
def __init__(self, mod_tree):
self.module_structure = mod_tree
def fetch_customization_field(self, yaml_dict, field):
"""
Class helper method - takes a dictionary representation of
input yaml and extracts given key defined by field from
customization field, performs validation and returns the
values mapped by field
:param yaml_dict: A dictionary to be parsed
:param field: A key to be checked against the dictionary yaml_dict
:returns: The values associated with key field is field is
present and defined in the given dictionary, else None
"""
if (
yaml_dict
and "customization" in yaml_dict
......@@ -117,6 +130,48 @@ class Customizer:
)
return keep_alive_list if keep_alive_list else []
def fetch_field(self, yaml_dict, field):
"""
Return all values associated with key field in the provided
dictionary
:param yaml_dict: The dictionary from which the values will be
extracted
:param field: the key with which the values to be extracted
are associated
:return: the value mapped to the key field
"""
if field in yaml_dict and yaml_dict[field]:
return yaml_dict[field]
return None
def construct_std_bind(self, yaml_dict):
"""
From stl_bind field in yaml_dict
compose a set of stl_bind objects defining
STL members to be bound by APB
:param yaml_dict: dictionary representing
the STL members to be bound
:return: list of stl_bind objects
"""
stl_yaml = yaml_dict
stl_list = []
for type_ in stl_yaml:
for name in stl_yaml[type_]:
insts = stl_yaml[type_][name]["type"]
stl_bind = self.fetch_field(stl_yaml[type_][name], "stl_bind")
buff_prot = self.fetch_field(
stl_yaml[type_][name], "buffer_prot"
)
module_local = self.fetch_field(
stl_yaml[type_][name], "module_local"
)
stl_list.append(
stl_binder(
type_, name, insts, stl_bind, module_local, buff_prot
)
)
return stl_list
# If c++ object is given a custom namespace location
# extract absolute namespace path, return namespace
# and location in yaml representation of hierarchy
......@@ -176,3 +231,125 @@ class Customizer:
if not self.module_structure:
return None
return self.module_structure
class stl_binder(object):
def __init__(
self, type_, name_, insts, stl=False, module_local=True, buff=False
):
self.name = name_
self.type = type_
self.stl_bind = stl
self.mod_loc = module_local
self.inst_type = insts
self.buffer_protocol = buff
self.cpp_name = self.__create_cpp_name(self.type, self.inst_type)
self.type_str = ",".join(insts) if type(insts) is list else insts
def __eq__(self, other):
return (
other.inst_type == self.inst_type
and other.type == self.type
and other.name == self.name
)
def __create_cpp_name(self, type_, inst):
inst_str = ""
if type(self.inst_type) is list:
inst_str = ", ".join(inst)
else:
inst_str = inst
return "std::" + type_ + "< " + inst_str + " >"
def get_partial_decl(self):
return "::" + self.cpp_name
def generate_class_def(self, pygccxml_data, indent, binding_gen):
for constructorObj in pygccxml_data.constructors(
binding_gen.public_filter
):
arg_string = ""
arg_name_string = ","
for arg in constructorObj.arguments:
arg_string += arg.decl_type.decl_string + ","
default_val = ""
if arg.default_value:
fmt_string = binding_gen.opts.arg_val_cast_fmt
# startswith is used to prevent a (<type>){val, val, val} casting string
# which is a non-standard explicit type conversion syntax error.
if (
" " in arg.decl_type.decl_string
and not arg.default_value.startswith("{")
):
fmt_string = binding_gen.opts.nullptr_arg_val_fmt
default_val = fmt_string.format(
type=arg.decl_type.decl_string.replace(
"const &", ""
).strip(),
val=arg.default_value,
)
arg_name_string += binding_gen.opts.member_func_arg_fmt.format(
arg.name, default_val
)
arg_name_string += ","
arg_string = arg_string.rstrip(",")
arg_name_string = arg_name_string.rstrip(",")
constructor_str += binding_gen.opts.constructor_fmt.format(
arg_string, arg_name_string, ", " if False else "", "",
)
constructor_str += binding_gen.indent
member_var_string = ""
for member_var in sorted(
[x for x in pygccxml_data.variables()], key=lambda c: c.name
):
member_var_string += binding_gen.generate_member_var_string(
member_var, []
)
member_var_string += binding_gen.indent
member_string = ""
for member_function in sorted(
[x for x in pygccxml_data.member_functions()], key=lambda c: c.name
):
member_string += binding_gen.generate_function_string(
member_function,
publicist_name="",
keep_alive="",
return_value_policy="",
pass_by_ref="",
)
member_string += binding_gen.indent
# TODO: Necessary? Determine usefulness of listing operators
operator_string = ""
total_operators = set(pygccxml_data.operators())
for operator in sorted(total_operators, key=lambda c: c.name):
# If ostream is first argument, we assume this is a string representation
# function.
operator_string += binding_gen.generate_operator_string(
pygccxml_data, operator, is_stream=""
)
operator_string += binding_gen.indent
return tb.class_info_body.format(
using=self.name,
name=self.name,
pyclass_args=self.cpp_name,
attrs="",
doc=', \nR"""(%s)"""' % " \n".join(pygccxml_data.comment.text)
if pygccxml_data.comment.text
else "",
mod_loc=self.mod_loc,
keepalive="",
dependency_calls="",
no_delete="",
constructor=constructor_str.strip(),
funcs=member_string.strip(),
vars=member_var_string.strip(),
opers=operator_string.strip(),
enums="",
)
......@@ -14,7 +14,7 @@ common_cpp_body = """
#include <pybind11/pybind11.h>
#include <pybind11/iostream.h>
#include <pybind11/stl.h>
{includes}
namespace py = pybind11;