blob: a9333e2d7dd02d0061b49857e4c059790e1fed6f [file] [log] [blame]
# Copyright 2012 The Chromium Authors. All rights reserved.
# Use of this source code is governed by a BSD-style license that can be
# found in the LICENSE file.
import fnmatch
import importlib
import inspect
import os
import re
import sys
from py_utils import camel_case
def DiscoverModules(start_dir, top_level_dir, pattern='*'):
  """Discover and import all modules in |start_dir| which match |pattern|.

  Files whose names start with '.' or '_' (including __init__.py) and files
  without a '.py' extension are skipped. Directories and the files inside each
  directory are visited in sorted order, so the result is deterministic.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing. Module names
        are derived from each file's path relative to this directory.
    pattern: Unix shell-style pattern for filtering the filenames to import.

  Returns:
    list of imported modules.
  """
  # start_dir and top_level_dir must be consistent with each other.
  start_dir = os.path.realpath(start_dir)
  top_level_dir = os.path.realpath(top_level_dir)

  modules = []
  # Sort the directories by path so traversal order does not depend on the
  # order in which the filesystem lists directory entries.
  sub_paths = sorted(os.walk(start_dir), key=lambda paths_tuple: paths_tuple[0])
  for dir_path, _, filenames in sub_paths:
    # Sort the filenames within each directory for the same reason.
    for filename in sorted(filenames):
      # Filter out unwanted filenames.
      if filename.startswith(('.', '_')):
        continue
      if os.path.splitext(filename)[1] != '.py':
        continue
      if not fnmatch.fnmatch(filename, pattern):
        continue

      # Find the module.
      module_rel_path = os.path.relpath(
          os.path.join(dir_path, filename), top_level_dir)
      module_name = re.sub(r'[/\\]', '.', os.path.splitext(module_rel_path)[0])

      # Import the module. Make sure that top_level_dir is the first path in
      # sys.path in case there are naming conflicts in module parts.
      original_sys_path = sys.path[:]
      try:
        sys.path.insert(0, top_level_dir)
        modules.append(importlib.import_module(module_name))
      finally:
        # Restore in place. Rebinding sys.path to the saved copy would leave
        # top_level_dir inserted in the original list object, which other
        # code may still hold a reference to.
        sys.path[:] = original_sys_path
  return modules
def AssertNoKeyConflicts(classes_by_key_1, classes_by_key_2):
  """Asserts that every key shared by both dicts maps to the same class.

  Args:
    classes_by_key_1: dict of {key: class}.
    classes_by_key_2: dict of {key: class}.

  Raises:
    AssertionError: if a key present in both dicts maps to two different
        objects (only while assertions are enabled, i.e. not under -O).
  """
  for key, first_cls in classes_by_key_1.items():
    if key not in classes_by_key_2:
      continue
    second_cls = classes_by_key_2[key]
    assert first_cls is second_cls, (
        'Found conflicting classes for the same key: '
        'key=%s, class_1=%s, class_2=%s' % (key, first_cls, second_cls))
# TODO(dtu): Normalize all discoverable classes to have corresponding module
# and class names, then always index by class name.
def DiscoverClasses(start_dir,
                    top_level_dir,
                    base_class,
                    pattern='*',
                    index_by_class_name=True,
                    directly_constructable=False):
  """Discover all classes in |start_dir| which subclass |base_class|.

  |base_class| itself is excluded, as are classes whose names start with '_'
  and classes that a module merely imports rather than defines.

  Args:
    start_dir: The directory to recursively search.
    top_level_dir: The top level of the package, for importing.
    base_class: The base class to search for.
    pattern: Unix shell-style pattern for filtering the filenames to import.
    index_by_class_name: If True, use class name converted to
        lowercase_with_underscores instead of module name in return dict keys.
    directly_constructable: If True, will only return classes that can be
        constructed without arguments.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}. When
    keyed by module name, a later module with the same short name replaces an
    earlier one.
  """
  classes = {}
  for module in DiscoverModules(start_dir, top_level_dir, pattern):
    new_classes = DiscoverClassesInModule(
        module, base_class, index_by_class_name, directly_constructable)
    # TODO(crbug.com/548652): we should remove index_by_class_name once
    # benchmark_smoke_unittest in chromium/src/tools/perf no longer relies on
    # naming collisions to reduce the number of smoke-tested benchmarks.
    if index_by_class_name:
      AssertNoKeyConflicts(classes, new_classes)
    # Merge in place; later entries win on duplicate keys. Rebuilding the dict
    # from concatenated item lists on every module was needlessly quadratic.
    classes.update(new_classes)
  return classes
# TODO(crbug.com/548652): we should remove index_by_class_name once
# benchmark_smoke_unittest in chromium/src/tools/perf no longer relies on
# naming collisions to reduce the number of smoke-tested benchmarks.
def DiscoverClassesInModule(module,
                            base_class,
                            index_by_class_name=False,
                            directly_constructable=False):
  """Collects the public subclasses of |base_class| defined in |module|.

  A member of |module| is kept only if it is a class, subclasses (but is not)
  |base_class|, has a name that does not start with '_', and was defined in
  |module| itself rather than imported into it.

  Args:
    module: The module whose members are inspected.
    base_class: Only subclasses of this class are collected.
    index_by_class_name: If True, key each class by its name converted to
        lowercase_with_underscores; otherwise key it by the last component of
        the module's dotted name.
    directly_constructable: If True, keep only classes for which
        IsDirectlyConstructable() returns True.

  Returns:
    dict of {module_name: class} or {underscored_class_name: class}. When
    keyed by module name, every match shares one key, so the last match in
    inspect.getmembers() order (sorted by member name) is the one kept.
  """
  found = {}
  for _, candidate in inspect.getmembers(module):
    # `and` short-circuits, so issubclass() only ever sees real classes.
    is_eligible = (
        inspect.isclass(candidate)
        and issubclass(candidate, base_class)
        and candidate is not base_class
        and not candidate.__name__.startswith('_')
        # Skip classes that were imported into |module| from elsewhere.
        and candidate.__module__ == module.__name__)
    if not is_eligible:
      continue

    if index_by_class_name:
      key = camel_case.ToUnderscore(candidate.__name__)
    else:
      key = module.__name__.rpartition('.')[2]

    if directly_constructable and not IsDirectlyConstructable(candidate):
      continue

    if index_by_class_name and key in found:
      assert found[key] is candidate, (
          'Duplicate key_name with different objs detected: '
          'key=%s, obj1=%s, obj2=%s' % (key, found[key], candidate))
      continue
    found[key] = candidate
  return found
def IsDirectlyConstructable(cls):
  """Returns True if an instance of |cls| can be constructed without arguments.

  Args:
    cls: The class to check.

  Returns:
    True when |cls|.__init__ needs no arguments besides |self|: every other
    positional parameter has a default and every keyword-only parameter has a
    default.
  """
  assert inspect.isclass(cls)
  if not hasattr(cls, '__init__'):
    # Case |class A: pass|.
    return True
  if cls.__init__ is object.__init__:
    # Case |class A(object): pass|.
    return True
  # Case |class (object):| with |__init__| other than |object.__init__|.
  # inspect.getargspec() was removed in Python 3.11 and raises ValueError for
  # keyword-only arguments or annotations, so prefer getfullargspec() and only
  # fall back to getargspec() where getfullargspec() does not exist.
  getfullargspec = getattr(inspect, 'getfullargspec', None)
  if getfullargspec is None:
    args, _, _, defaults = inspect.getargspec(cls.__init__)
    required_kwonly = []
  else:
    spec = getfullargspec(cls.__init__)
    args, defaults = spec.args, spec.defaults
    kwonly_defaults = spec.kwonlydefaults or {}
    required_kwonly = [
        name for name in spec.kwonlyargs if name not in kwonly_defaults]
  if defaults is None:
    defaults = ()
  # Return true if |self| is the only positional arg without a default and no
  # keyword-only arg is required.
  return len(args) == len(defaults) + 1 and not required_kwonly
# Running count behind _GetUniqueModuleName(); kept in a one-element list so
# the function can update it by mutation instead of a `global` rebinding.
_COUNTER = [0]


def _GetUniqueModuleName():
  """Bumps the counter and returns 'module_<n>' for the new count <n>."""
  next_value = _COUNTER[0] + 1
  _COUNTER[0] = next_value
  return 'module_{}'.format(next_value)