diff --git a/plugins/module_utils/openvpn_client.py b/plugins/module_utils/openvpn_client.py index 0ea1fd12..855c6211 100644 --- a/plugins/module_utils/openvpn_client.py +++ b/plugins/module_utils/openvpn_client.py @@ -94,7 +94,7 @@ class PFSenseOpenVPNClientModule(PFSenseModuleBase): def __init__(self, module, pfsense=None): super(PFSenseOpenVPNClientModule, self).__init__(module, pfsense) self.name = "pfsense_openvpn" - self.root_elt = self.pfsense.get_element('openvpn') + self.root_elt = self.pfsense.get_element('openvpn', create_node=True) self.obj = dict() ############################## @@ -180,13 +180,19 @@ def _validate_params(self): self.module.fail_json(msg='Cannot find authentication client {0}.'.format(authsrv)) # validate key - if params['shared_key'] is not None: - key = params['shared_key'] - lines = key.splitlines() - if lines[0] == '-----BEGIN OpenVPN Static key V1-----' and lines[-1] == '-----END OpenVPN Static key V1-----': - params['shared_key'] = base64.b64encode(key.encode()).decode() - elif not re.match('LS0tLS1CRUdJTiBPcGVuVlBOIFN0YXRpYyBrZXkgVjEtLS0tLQ', key): - self.module.fail_json(msg='Could not recognize key format: %s' % (key)) + for param in ['shared_key', 'tls']: + if params[param] is not None: + key = params[param] + if key == 'generate': + # generate during params_to_obj + pass + elif re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', key, flags=re.MULTILINE | re.DOTALL): + params[param] = base64.b64encode(key.encode()).decode() + else: + key_decoded = base64.b64decode(key.encode()).decode() + if not re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', + key_decoded, flags=re.MULTILINE | re.DOTALL): + self.module.fail_json(msg='Could not recognize {0} key format: {1}'.format(param, key_decoded)) def _nextvpnid(self): """ find next available vpnid """ @@ -216,16 +222,9 @@ def _find_last_openvpn_idx(self): def _copy_and_update_target(self): """ update the XML target_elt """ - before = self.pfsense.element_to_dict(self.target_elt) - changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - if self._remove_deleted_params(): - changed = True - - self.diff['before'] = before - if changed: - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) - self.result['changed'] = True - else: + (before, changed) = super(PFSenseOpenVPNClientModule, self)._copy_and_update_target() + + if not changed: self.diff['after'] = self.obj return (before, changed) @@ -243,6 +242,16 @@ def _create_target(self): def _find_target(self): """ find the XML target_elt """ (target_elt, self.idx) = self._find_openvpn_client(self.obj['description']) + for param in ['shared_key', 'tls']: + current_elt = self.pfsense.get_element(param, target_elt) + if self.params[param] == 'generate': + if current_elt is None: + (dummy, key, stderr) = self.module.run_command('/usr/local/sbin/openvpn --genkey secret /dev/stdout') + if stderr != "": + self.module.fail_json(msg='generate for "{0}" secret key: {1}'.format(param, stderr)) + self.obj[param] = base64.b64encode(key.encode()).decode() + else: + self.obj[param] = current_elt.text return target_elt def _remove_target_elt(self): diff --git a/plugins/module_utils/openvpn_server.py b/plugins/module_utils/openvpn_server.py index a5c06d35..00b29478 100644 --- a/plugins/module_utils/openvpn_server.py +++ b/plugins/module_utils/openvpn_server.py @@ -106,7 +106,7 @@ def get_argument_spec(): def __init__(self, module, pfsense=None): super(PFSenseOpenVPNServerModule, self).__init__(module, pfsense) self.name = "pfsense_openvpn_server" - self.root_elt = self.pfsense.get_element('openvpn') + self.root_elt = self.pfsense.get_element('openvpn', create_node=True) self.obj = dict() ############################## @@ -213,10 +213,13 @@ def _validate_params(self): for param in ['shared_key', 'tls']: if params[param] is not None: key = params[param] - if re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', key, flags=re.MULTILINE | re.DOTALL): + if key == 'generate': + # generate during params_to_obj + pass + elif re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', key, flags=re.MULTILINE | re.DOTALL): params[param] = base64.b64encode(key.encode()).decode() else: - key_decoded = base64.b64decode(params[param].encode()).decode() + key_decoded = base64.b64decode(key.encode()).decode() if not re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', key_decoded, flags=re.MULTILINE | re.DOTALL): self.module.fail_json(msg='Could not recognize {0} key format: {1}'.format(param, key_decoded)) @@ -283,21 +286,15 @@ def _get_params_to_remove(self): def _copy_and_update_target(self): """ update the XML target_elt """ - before = self.pfsense.element_to_dict(self.target_elt) + (before, changed) = super(PFSenseOpenVPNServerModule, self)._copy_and_update_target() + # Check if local port is used self._openvpn_port_used(self.params['protocol'], self.params['interface'], self.params['local_port'], before['vpnid']) - changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt) - if self._remove_deleted_params(): - changed = True - - self.diff['before'] = before - if changed: - self.diff['after'] = self.pfsense.element_to_dict(self.target_elt) - self.result['changed'] = True - else: + + if not changed: self.diff['after'] = self.obj - self.result['vpnid'] = int(self.diff['before']['vpnid']) + self.result['vpnid'] = int(before['vpnid']) return (before, changed) def _create_target(self): @@ -316,6 +313,16 @@ def _create_target(self): def _find_target(self): """ find the XML target_elt """ (target_elt, self.idx) = self._find_openvpn_server(self.obj['description']) + for param in ['shared_key', 'tls']: + current_elt = self.pfsense.get_element(param, target_elt) + if self.params[param] == 'generate': + if current_elt is None: + (dummy, key, stderr) = self.module.run_command('/usr/local/sbin/openvpn --genkey secret /dev/stdout') + if stderr != "": + self.module.fail_json(msg='generate for "{0}" secret key: {1}'.format(param, stderr)) + self.obj[param] = base64.b64encode(key.encode()).decode() + else: + self.obj[param] = current_elt.text return target_elt ############################## diff --git a/plugins/modules/pfsense_openvpn_client.py b/plugins/modules/pfsense_openvpn_client.py index 9b0fa12f..790d967c 100644 --- a/plugins/modules/pfsense_openvpn_client.py +++ b/plugins/modules/pfsense_openvpn_client.py @@ -87,7 +87,7 @@ default: false type: bool shared_key: - description: Pre-shared key for shared key modes. + description: Pre-shared key for shared key modes. If set to 'generate' it will create a key if one does not already exist. type: str dh_length: description: DH parameter length. diff --git a/plugins/modules/pfsense_openvpn_server.py b/plugins/modules/pfsense_openvpn_server.py index 384126d3..de14159a 100644 --- a/plugins/modules/pfsense_openvpn_server.py +++ b/plugins/modules/pfsense_openvpn_server.py @@ -90,7 +90,7 @@ default: false type: bool shared_key: - description: Pre-shared key for shared key modes. + description: Pre-shared key for shared key modes. If set to 'generate' it will create a key if one does not already exist. type: str dh_length: description: DH parameter length. diff --git a/tests/unit/plugins/modules/test_pfsense_openvpn_server.py b/tests/unit/plugins/modules/test_pfsense_openvpn_server.py index 7f9431b9..4c92060e 100644 --- a/tests/unit/plugins/modules/test_pfsense_openvpn_server.py +++ b/tests/unit/plugins/modules/test_pfsense_openvpn_server.py @@ -4,6 +4,7 @@ from __future__ import (absolute_import, division, print_function) __metaclass__ = type +import base64 import pytest import sys @@ -12,6 +13,7 @@ from ansible_collections.pfsensible.core.plugins.modules import pfsense_openvpn_server from .pfsense_module import TestPFSenseModule +from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch CERTIFICATE = ( "LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tDQpNSUlFQ0RDQ0F2Q2dBd0lCQWdJSUZqRk9oczFuTXpRd0RRWUpLb1pJaHZjTkFRRUxCUUF3WERFVE1CRUdBMVVFDQpBeE1LYjNCbGJuWndiaTFqWVRF" @@ -47,6 +49,21 @@ def __init__(self, *args, **kwargs): self.config_file = 'pfsense_openvpn_config.xml' self.pfmodule = pfsense_openvpn_server.PFSenseOpenVPNServerModule + def setUp(self): + """ mocking up """ + + super(TestPFSenseOpenVPNServerModule, self).setUp() + + self.mock_run_command = patch('ansible.module_utils.basic.AnsibleModule.run_command') + self.run_command = self.mock_run_command.start() + self.run_command.return_value = (0, base64.b64decode(TLSKEY.encode()).decode(), '') + + def tearDown(self): + """ mocking down """ + super(TestPFSenseOpenVPNServerModule, self).tearDown() + + self.run_command.stop() + @staticmethod def runTest(): """ dummy function needed to instantiate this test module from another in python 2.7 """ @@ -87,6 +104,12 @@ def certref(descr): def check_target_elt(self, obj, target_elt): """ check XML definition of target elt """ + # Use "generated" key + if 'shared_key' in obj and obj['shared_key'] == 'generate': + obj['shared_key'] = TLSKEY + if 'tls' in obj and obj['tls'] == 'generate': + obj['tls'] = TLSKEY + self.check_param_equal(obj, target_elt, 'name', xml_field='description') self.check_param_equal(obj, target_elt, 'custom_options') self.check_param_equal(obj, target_elt, 'mode', default='ptp_tls') @@ -100,6 +123,7 @@ def check_target_elt(self, obj, target_elt): self.check_param_equal(obj, target_elt, 'local_port', default=1194) self.check_param_equal(obj, target_elt, 'protocol', default='UDP4') if 'tls' in obj['mode']: + self.check_param_equal(obj, target_elt, 'tls') self.check_param_equal(obj, target_elt, 'tls') self.check_param_equal(obj, target_elt, 'tls_type') self.assert_xml_elt_equal(target_elt, 'caref', self.caref(obj['ca'])) @@ -141,6 +165,11 @@ def test_openvpn_server_create(self): obj = dict(name='ovpns3', mode='p2p_tls', ca='OpenVPN CA', local_port=1196) self.do_module_test(obj, command="create openvpn_server 'ovpns3', description='ovpns3'") + def test_openvpn_server_create_generate(self): + """ test creation of a new OpenVPN server """ + obj = dict(name='ovpns3', mode='p2p_tls', ca='OpenVPN CA', local_port=1196, tls='generate') + self.do_module_test(obj, command="create openvpn_server 'ovpns3', description='ovpns3'") + def test_openvpn_server_delete(self): """ test deletion of a OpenVPN server """ obj = dict(name='ovpns2')