1 # Copyright 2020 The Pigweed Authors
3 # Licensed under the Apache License, Version 2.0 (the "License"); you may not
4 # use this file except in compliance with the License. You may obtain a copy of
7 # https://www.apache.org/licenses/LICENSE-2.0
9 # Unless required by applicable law or agreed to in writing, software
10 # distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
11 # WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
12 # License for the specific language governing permissions and limitations under
14 """Tools for compiling and importing Python protos on the fly."""
19 from pathlib import Path
23 from types import ModuleType
24 from typing import (Dict, Generic, Iterable, Iterator, List, NamedTuple, Set,
25 Tuple, TypeVar, Union)
# Logger shared by every function in this module.
_LOG = logging.getLogger(name=__name__)

# Type alias for arguments that accept either a pathlib.Path or a plain string.
PathOrStr = Union[Path, str]
33 output_dir: PathOrStr,
34 proto_files: Iterable[PathOrStr],
35 includes: Iterable[PathOrStr] = ()) -> None:
36 """Compiles proto files for Python by invoking the protobuf compiler.
38 Proto files not covered by one of the provided include paths will have their
39 directory added as an include path.
41 proto_paths: List[Path] = [Path(f).resolve() for f in proto_files]
42 include_paths: Set[Path] = set(Path(d).resolve() for d in includes)
44 for path in proto_paths:
45 if not any(include in path.parents for include in include_paths):
46 include_paths.add(path.parent)
48 cmd: Tuple[PathOrStr, ...] = (
51 os.path.abspath(output_dir),
52 *(f'-I{d}' for d in include_paths),
56 _LOG.debug('%s', ' '.join(shlex.quote(str(c)) for c in cmd))
57 process = subprocess.run(cmd, capture_output=True)
59 if process.returncode:
60 _LOG.error('protoc invocation failed!\n%s\n%s',
61 ' '.join(shlex.quote(str(c)) for c in cmd),
62 process.stderr.decode())
63 process.check_returncode()
66 def _import_module(name: str, path: str) -> ModuleType:
67 spec = importlib.util.spec_from_file_location(name, path)
68 module = importlib.util.module_from_spec(spec)
69 spec.loader.exec_module(module) # type: ignore[union-attr]
73 def import_modules(directory: PathOrStr) -> Iterator:
74 """Imports modules in a directory and yields them."""
75 parent = os.path.dirname(directory)
77 for dirpath, _, files in os.walk(directory):
78 path_parts = os.path.relpath(dirpath, parent).split(os.sep)
81 name, ext = os.path.splitext(file)
84 yield _import_module(f'{".".join(path_parts)}.{name}',
85 os.path.join(dirpath, file))
88 def compile_and_import(proto_files: Iterable[PathOrStr],
89 includes: Iterable[PathOrStr] = (),
90 output_dir: PathOrStr = None) -> Iterator:
91 """Compiles protos and imports their modules; yields the proto modules.
94 proto_files: paths to .proto files to compile
95 includes: include paths to use for .proto compilation
96 output_dir: where to place the generated modules; a temporary directory is
100 the generated protobuf Python modules
104 compile_protos(output_dir, proto_files, includes)
105 yield from import_modules(output_dir)
107 with tempfile.TemporaryDirectory(prefix='compiled_protos_') as tempdir:
108 compile_protos(tempdir, proto_files, includes)
109 yield from import_modules(tempdir)
def compile_and_import_file(
        proto_file: PathOrStr,
        includes: Iterable[PathOrStr] = (),
        output_dir: Union[PathOrStr, None] = None) -> ModuleType:
    """Compiles and imports the module for a single .proto file.

    Args:
      proto_file: path to the .proto file to compile
      includes: include paths to use for .proto compilation
      output_dir: where to place the generated module; forwarded unchanged to
          compile_and_import (None means compile_and_import picks the location)

    Returns:
      the first module yielded by compile_and_import for this file

    Raises:
      StopIteration: if compile_and_import yields no modules at all
    """
    modules = compile_and_import([proto_file], includes, output_dir)
    try:
        return next(modules)
    finally:
        # compile_and_import is a generator. When output_dir is None it holds
        # a TemporaryDirectory open while suspended. Closing it here runs that
        # cleanup deterministically instead of whenever the abandoned
        # generator happens to be garbage-collected.
        modules.close()  # type: ignore[attr-defined]
119 def compile_and_import_strings(contents: Iterable[str],
120 includes: Iterable[PathOrStr] = (),
121 output_dir: PathOrStr = None) -> Iterator:
122 """Compiles protos in one or more strings."""
124 if isinstance(contents, str):
125 contents = [contents]
127 with tempfile.TemporaryDirectory(prefix='proto_sources_') as path:
130 for proto in contents:
131 # Use a hash of the proto so the same contents map to the same file
132 # name. The protobuf package complains if it seems the same contents
133 # in files with different names.
134 protos.append(Path(path, f'protobuf_{hash(proto):x}.proto'))
135 protos[-1].write_text(proto)
137 yield from compile_and_import(protos, includes, output_dir)
143 class _NestedPackage(Generic[T]):
144 """Facilitates navigating protobuf packages as attributes."""
145 def __init__(self, package: str):
146 self._packages: Dict[str, _NestedPackage[T]] = {}
147 self._items: List[T] = []
148 self._package = package
150 def _add_package(self, subpackage: str, package: '_NestedPackage') -> None:
151 self._packages[subpackage] = package
153 def _add_item(self, item) -> None:
154 if item not in self._items: # Don't store the same item multiple times.
155 self._items.append(item)
157 def __getattr__(self, attr: str):
158 """Look up subpackages or package members."""
159 if attr in self._packages:
160 return self._packages[attr]
162 for item in self._items:
163 if hasattr(item, attr):
164 return getattr(item, attr)
166 raise AttributeError(
167 f'Proto package "{self._package}" does not contain "{attr}"')
169 def __getitem__(self, subpackage: str) -> '_NestedPackage[T]':
170 """Support accessing nested packages by name."""
173 for package in subpackage.split('.'):
174 result = result._packages[package]
178 def __dir__(self) -> List[str]:
179 """List subpackages and members of modules as attributes."""
180 attributes = list(self._packages)
182 for item in self._items:
183 for attr, value in vars(item).items():
184 # Exclude private variables and modules from dir().
185 if not attr.startswith('_') and not isinstance(
187 attributes.append(attr)
191 def __iter__(self) -> Iterator['_NestedPackage[T]']:
192 """Iterate over nested packages."""
193 return iter(self._packages.values())
195 def __repr__(self) -> str:
196 msg = [f'ProtoPackage({self._package!r}']
199 i for i in vars(self)
200 if i not in self._packages and not i.startswith('_')
203 msg.append(f'members={str(public_members)}')
206 msg.append(f'subpackages={str(list(self._packages))}')
208 return ', '.join(msg) + ')'
210 def __str__(self) -> str:
class Packages(NamedTuple):
    """Items in a protobuf package structure; returned from as_packages."""
    # Maps each full dotted package name (e.g. 'foo.bar') to the list of items
    # placed in that package, in insertion order.
    items_by_package: Dict[str, List]
    # Root of the nested package tree; subpackages and package members are
    # reachable through attribute access from here.
    packages: _NestedPackage
220 def as_packages(items: Iterable[Tuple[str, T]],
221 packages: Packages = None) -> Packages:
222 """Places items in a proto-style package structure navigable by attributes.
225 items: (package, item) tuples to insert into the package structure
226 packages: if provided, update this Packages instead of creating a new one
229 packages = Packages({}, _NestedPackage(''))
231 for package, item in items:
232 packages.items_by_package.setdefault(package, []).append(item)
234 entry = packages.packages
235 subpackages = package.split('.')
237 # pylint: disable=protected-access
238 for i, subpackage in enumerate(subpackages, 1):
239 if subpackage not in entry._packages:
240 entry._add_package(subpackage,
241 _NestedPackage('.'.join(subpackages[:i])))
243 entry = entry._packages[subpackage]
245 entry._add_item(item)
246 # pylint: enable=protected-access
# Inputs accepted by Library.from_paths: a path to a .proto file (given as a
# str or Path) or an already-imported proto module.
PathOrModule = Union[str, Path, ModuleType]
255 """A collection of protocol buffer modules sorted by package.
257 In Python, each .proto file is compiled into a Python module. The Library
258 class makes it simple to navigate a collection of Python modules
259 corresponding to .proto files, without relying on the location of these
262 Proto messages and other types can be directly accessed by their protocol
263 buffer package name. For example, the foo.bar.Baz message can be accessed
264 in a Library called `protos` as:
266 protos.packages.foo.bar.Baz
268 A Library also provides the modules_by_package dictionary, for looking up
269 the list of modules in a particular package, and the modules() generator
270 for iterating over all modules.
273 def from_paths(cls, protos: Iterable[PathOrModule]) -> 'Library':
274 """Creates a Library from paths to proto files or proto modules."""
275 paths: List[PathOrStr] = []
276 modules: List[ModuleType] = []
279 if isinstance(proto, (Path, str)):
282 modules.append(proto)
285 modules += compile_and_import(paths)
286 return Library(modules)
289 def from_strings(cls,
290 contents: Iterable[str],
291 includes: Iterable[PathOrStr] = (),
292 output_dir: PathOrStr = None) -> 'Library':
293 """Creates a proto library from protos in the provided strings."""
294 return cls(compile_and_import_strings(contents, includes, output_dir))
296 def __init__(self, modules: Iterable[ModuleType]):
297 """Constructs a Library from an iterable of modules.
299 A Library can be constructed with modules dynamically compiled by
300 compile_and_import. For example:
302 protos = Library(compile_and_import(list_of_proto_files))
304 self.modules_by_package, self.packages = as_packages(
305 (m.DESCRIPTOR.package, m) # type: ignore[attr-defined]
308 def modules(self) -> Iterable:
309 """Allows iterating over all protobuf modules in this library."""
310 for module_list in self.modules_by_package.values():
311 yield from module_list
314 def proto_repr(message) -> str:
315 """Creates a repr-like string for a protobuf."""
318 for field in message.DESCRIPTOR.fields:
319 value = getattr(message, field.name)
321 # Include fields if has_<field>() is true or the value is non-default.
322 if hasattr(message, 'has_' + field.name):
323 if not getattr(message, 'has_' + field.name)():
325 elif value == field.default_value:
328 fields.append(f'{field.name}={value!r}')
330 return f'{message.DESCRIPTOR.full_name}({", ".join(fields)})'