Implement environment singleton to be accessed throughout the code
Load and parse environment file from working dir Signed-off-by: Joffrey F <joffrey@docker.com>
This commit is contained in:
parent
d1d8df7f72
commit
c69d8a3bd2
13 changed files with 151 additions and 64 deletions
|
|
@ -21,7 +21,7 @@ log = logging.getLogger(__name__)
|
|||
def project_from_options(project_dir, options):
|
||||
return get_project(
|
||||
project_dir,
|
||||
get_config_path_from_options(options),
|
||||
get_config_path_from_options(project_dir, options),
|
||||
project_name=options.get('--project-name'),
|
||||
verbose=options.get('--verbose'),
|
||||
host=options.get('--host'),
|
||||
|
|
@ -29,12 +29,13 @@ def project_from_options(project_dir, options):
|
|||
)
|
||||
|
||||
|
||||
def get_config_path_from_options(options):
|
||||
def get_config_path_from_options(base_dir, options):
|
||||
file_option = options.get('--file')
|
||||
if file_option:
|
||||
return file_option
|
||||
|
||||
config_files = os.environ.get('COMPOSE_FILE')
|
||||
environment = config.environment.get_instance(base_dir)
|
||||
config_files = environment.get('COMPOSE_FILE')
|
||||
if config_files:
|
||||
return config_files.split(os.pathsep)
|
||||
return None
|
||||
|
|
@ -57,8 +58,9 @@ def get_project(project_dir, config_path=None, project_name=None, verbose=False,
|
|||
config_details = config.find(project_dir, config_path)
|
||||
project_name = get_project_name(config_details.working_dir, project_name)
|
||||
config_data = config.load(config_details)
|
||||
environment = config.environment.get_instance(project_dir)
|
||||
|
||||
api_version = os.environ.get(
|
||||
api_version = environment.get(
|
||||
'COMPOSE_API_VERSION',
|
||||
API_VERSIONS[config_data.version])
|
||||
client = get_client(
|
||||
|
|
@ -73,7 +75,8 @@ def get_project_name(working_dir, project_name=None):
|
|||
def normalize_name(name):
|
||||
return re.sub(r'[^a-z0-9]', '', name.lower())
|
||||
|
||||
project_name = project_name or os.environ.get('COMPOSE_PROJECT_NAME')
|
||||
environment = config.environment.get_instance(working_dir)
|
||||
project_name = project_name or environment.get('COMPOSE_PROJECT_NAME')
|
||||
if project_name:
|
||||
return normalize_name(project_name)
|
||||
|
||||
|
|
|
|||
|
|
@ -222,7 +222,7 @@ class TopLevelCommand(object):
|
|||
--services Print the service names, one per line.
|
||||
|
||||
"""
|
||||
config_path = get_config_path_from_options(config_options)
|
||||
config_path = get_config_path_from_options(self.project_dir, config_options)
|
||||
compose_config = config.load(config.find(self.project_dir, config_path))
|
||||
|
||||
if options['--quiet']:
|
||||
|
|
|
|||
|
|
@ -2,6 +2,7 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import unicode_literals
|
||||
|
||||
from . import environment
|
||||
from .config import ConfigurationError
|
||||
from .config import DOCKER_CONFIG_KEYS
|
||||
from .config import find
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@ from cached_property import cached_property
|
|||
from ..const import COMPOSEFILE_V1 as V1
|
||||
from ..const import COMPOSEFILE_V2_0 as V2_0
|
||||
from ..utils import build_string_dict
|
||||
from .environment import Environment
|
||||
from .errors import CircularReference
|
||||
from .errors import ComposeFileNotFound
|
||||
from .errors import ConfigurationError
|
||||
|
|
@ -211,7 +212,8 @@ def find(base_dir, filenames):
|
|||
if filenames == ['-']:
|
||||
return ConfigDetails(
|
||||
os.getcwd(),
|
||||
[ConfigFile(None, yaml.safe_load(sys.stdin))])
|
||||
[ConfigFile(None, yaml.safe_load(sys.stdin))],
|
||||
)
|
||||
|
||||
if filenames:
|
||||
filenames = [os.path.join(base_dir, f) for f in filenames]
|
||||
|
|
@ -221,7 +223,8 @@ def find(base_dir, filenames):
|
|||
log.debug("Using configuration files: {}".format(",".join(filenames)))
|
||||
return ConfigDetails(
|
||||
os.path.dirname(filenames[0]),
|
||||
[ConfigFile.from_filename(f) for f in filenames])
|
||||
[ConfigFile.from_filename(f) for f in filenames],
|
||||
)
|
||||
|
||||
|
||||
def validate_config_version(config_files):
|
||||
|
|
@ -288,6 +291,10 @@ def load(config_details):
|
|||
"""
|
||||
validate_config_version(config_details.config_files)
|
||||
|
||||
# load environment in working dir for later use in interpolation
|
||||
# it is done here to avoid having to pass down working_dir
|
||||
Environment.get_instance(config_details.working_dir)
|
||||
|
||||
processed_files = [
|
||||
process_config_file(config_file)
|
||||
for config_file in config_details.config_files
|
||||
|
|
@ -302,9 +309,8 @@ def load(config_details):
|
|||
config_details.config_files, 'get_networks', 'Network'
|
||||
)
|
||||
service_dicts = load_services(
|
||||
config_details.working_dir,
|
||||
main_file,
|
||||
[file.get_service_dicts() for file in config_details.config_files])
|
||||
config_details, main_file,
|
||||
)
|
||||
|
||||
if main_file.version != V1:
|
||||
for service_dict in service_dicts:
|
||||
|
|
@ -348,14 +354,16 @@ def load_mapping(config_files, get_func, entity_type):
|
|||
return mapping
|
||||
|
||||
|
||||
def load_services(working_dir, config_file, service_configs):
|
||||
def load_services(config_details, config_file):
|
||||
def build_service(service_name, service_dict, service_names):
|
||||
service_config = ServiceConfig.with_abs_paths(
|
||||
working_dir,
|
||||
config_details.working_dir,
|
||||
config_file.filename,
|
||||
service_name,
|
||||
service_dict)
|
||||
resolver = ServiceExtendsResolver(service_config, config_file)
|
||||
resolver = ServiceExtendsResolver(
|
||||
service_config, config_file
|
||||
)
|
||||
service_dict = process_service(resolver.run())
|
||||
|
||||
service_config = service_config._replace(config=service_dict)
|
||||
|
|
@ -383,6 +391,10 @@ def load_services(working_dir, config_file, service_configs):
|
|||
for name in all_service_names
|
||||
}
|
||||
|
||||
service_configs = [
|
||||
file.get_service_dicts() for file in config_details.config_files
|
||||
]
|
||||
|
||||
service_config = service_configs[0]
|
||||
for next_config in service_configs[1:]:
|
||||
service_config = merge_services(service_config, next_config)
|
||||
|
|
@ -462,8 +474,8 @@ class ServiceExtendsResolver(object):
|
|||
extends_file = ConfigFile.from_filename(config_path)
|
||||
validate_config_version([self.config_file, extends_file])
|
||||
extended_file = process_config_file(
|
||||
extends_file,
|
||||
service_name=service_name)
|
||||
extends_file, service_name=service_name
|
||||
)
|
||||
service_config = extended_file.get_service(service_name)
|
||||
|
||||
return config_path, service_config, service_name
|
||||
|
|
@ -476,7 +488,8 @@ class ServiceExtendsResolver(object):
|
|||
service_name,
|
||||
service_dict),
|
||||
self.config_file,
|
||||
already_seen=self.already_seen + [self.signature])
|
||||
already_seen=self.already_seen + [self.signature],
|
||||
)
|
||||
|
||||
service_config = resolver.run()
|
||||
other_service_dict = process_service(service_config)
|
||||
|
|
@ -824,10 +837,11 @@ def parse_ulimits(ulimits):
|
|||
|
||||
|
||||
def resolve_env_var(key, val):
|
||||
environment = Environment.get_instance()
|
||||
if val is not None:
|
||||
return key, val
|
||||
elif key in os.environ:
|
||||
return key, os.environ[key]
|
||||
elif key in environment:
|
||||
return key, environment[key]
|
||||
else:
|
||||
return key, None
|
||||
|
||||
|
|
|
|||
69
compose/config/environment.py
Normal file
69
compose/config/environment.py
Normal file
|
|
@ -0,0 +1,69 @@
|
|||
from __future__ import absolute_import
|
||||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
import os
|
||||
|
||||
from .errors import ConfigurationError
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BlankDefaultDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BlankDefaultDict, self).__init__(*args, **kwargs)
|
||||
self.missing_keys = []
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
return super(BlankDefaultDict, self).__getitem__(key)
|
||||
except KeyError:
|
||||
if key not in self.missing_keys:
|
||||
log.warn(
|
||||
"The {} variable is not set. Defaulting to a blank string."
|
||||
.format(key)
|
||||
)
|
||||
self.missing_keys.append(key)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
class Environment(BlankDefaultDict):
|
||||
__instance = None
|
||||
|
||||
@classmethod
|
||||
def get_instance(cls, base_dir='.'):
|
||||
if cls.__instance:
|
||||
return cls.__instance
|
||||
|
||||
instance = cls(base_dir)
|
||||
cls.__instance = instance
|
||||
return instance
|
||||
|
||||
@classmethod
|
||||
def reset(cls):
|
||||
cls.__instance = None
|
||||
|
||||
def __init__(self, base_dir):
|
||||
super(Environment, self).__init__()
|
||||
self.load_environment_file(os.path.join(base_dir, '.env'))
|
||||
self.update(os.environ)
|
||||
|
||||
def load_environment_file(self, path):
|
||||
if not os.path.exists(path):
|
||||
return
|
||||
mapping = {}
|
||||
with open(path, 'r') as f:
|
||||
for line in f.readlines():
|
||||
line = line.strip()
|
||||
if '=' not in line:
|
||||
raise ConfigurationError(
|
||||
'Invalid environment variable mapping in env file. '
|
||||
'Missing "=" in "{0}"'.format(line)
|
||||
)
|
||||
mapping.__setitem__(*line.split('=', 1))
|
||||
self.update(mapping)
|
||||
|
||||
|
||||
def get_instance(base_dir=None):
|
||||
return Environment.get_instance(base_dir)
|
||||
|
|
@ -2,17 +2,17 @@ from __future__ import absolute_import
|
|||
from __future__ import unicode_literals
|
||||
|
||||
import logging
|
||||
import os
|
||||
from string import Template
|
||||
|
||||
import six
|
||||
|
||||
from .environment import Environment
|
||||
from .errors import ConfigurationError
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def interpolate_environment_variables(config, section):
|
||||
mapping = BlankDefaultDict(os.environ)
|
||||
mapping = Environment.get_instance()
|
||||
|
||||
def process_item(name, config_dict):
|
||||
return dict(
|
||||
|
|
@ -60,25 +60,6 @@ def interpolate(string, mapping):
|
|||
raise InvalidInterpolation(string)
|
||||
|
||||
|
||||
class BlankDefaultDict(dict):
|
||||
def __init__(self, *args, **kwargs):
|
||||
super(BlankDefaultDict, self).__init__(*args, **kwargs)
|
||||
self.missing_keys = []
|
||||
|
||||
def __getitem__(self, key):
|
||||
try:
|
||||
return super(BlankDefaultDict, self).__getitem__(key)
|
||||
except KeyError:
|
||||
if key not in self.missing_keys:
|
||||
log.warn(
|
||||
"The {} variable is not set. Defaulting to a blank string."
|
||||
.format(key)
|
||||
)
|
||||
self.missing_keys.append(key)
|
||||
|
||||
return ""
|
||||
|
||||
|
||||
class InvalidInterpolation(Exception):
|
||||
def __init__(self, string):
|
||||
self.string = string
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue