diff --git a/compose/config/schema.json b/compose/config/schema.json index 74f5edbb..24fd53d1 100644 --- a/compose/config/schema.json +++ b/compose/config/schema.json @@ -75,10 +75,22 @@ "pid": {"type": "string"}, "ports": { - "type": "array", - "items": {"type": "string"}, - "uniqueItems": true, - "format": "ports" + "oneOf": [ + { + "type": "array", + "items": {"type": "string"}, + "uniqueItems": true, + "format": "ports" + }, + { + "type": "string", + "format": "ports" + }, + { + "type": "number", + "format": "ports" + } + ] }, "privileged": {"type": "string"}, diff --git a/compose/config/validation.py b/compose/config/validation.py index 15e0754c..3f46632b 100644 --- a/compose/config/validation.py +++ b/compose/config/validation.py @@ -84,6 +84,13 @@ def process_errors(errors): required.append("Service '{}' has neither an image nor a build path specified. Exactly one must be provided.".format(service_name)) else: required.append(error.message) + elif error.validator == 'oneOf': + config_key = error.path[1] + valid_types = [context.validator_value for context in error.context] + valid_type_msg = " or ".join(valid_types) + type_errors.append("Service '{}' configuration key '{}' contains an invalid type, it should be either {}".format( + service_name, config_key, valid_type_msg) + ) elif error.validator == 'type': msg = "a" if error.validator_value == "array": diff --git a/tests/unit/config_test.py b/tests/unit/config_test.py index 4e982bb4..b4d2ce82 100644 --- a/tests/unit/config_test.py +++ b/tests/unit/config_test.py @@ -75,8 +75,9 @@ class ConfigTest(unittest.TestCase): ) def test_config_invalid_ports_format_validation(self): - with self.assertRaises(ConfigurationError): - for invalid_ports in [{"1": "8000"}, "whatport", "625", "8000:8050"]: + expected_error_msg = "Service 'web' configuration key 'ports' contains an invalid type" + with self.assertRaisesRegexp(ConfigurationError, expected_error_msg): + for invalid_ports in [{"1": "8000"}, False, 0]: config.load( config.ConfigDetails( {'web': {'image': 'busybox', 'ports': invalid_ports}}, @@ -86,7 +87,7 @@ class ConfigTest(unittest.TestCase): ) def test_config_valid_ports_format_validation(self): - valid_ports = [["8000", "9000"], ["8000/8050"], ["8000"]] + valid_ports = [["8000", "9000"], ["8000/8050"], ["8000"], "8000", 8000] for ports in valid_ports: config.load( config.ConfigDetails(