Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: openvpn_server - generate #89

Merged
merged 3 commits into from
Jan 4, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 27 additions & 18 deletions plugins/module_utils/openvpn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,7 +94,7 @@ class PFSenseOpenVPNClientModule(PFSenseModuleBase):
def __init__(self, module, pfsense=None):
super(PFSenseOpenVPNClientModule, self).__init__(module, pfsense)
self.name = "pfsense_openvpn"
self.root_elt = self.pfsense.get_element('openvpn')
self.root_elt = self.pfsense.get_element('openvpn', create_node=True)
self.obj = dict()

##############################
Expand Down Expand Up @@ -180,13 +180,19 @@ def _validate_params(self):
self.module.fail_json(msg='Cannot find authentication client {0}.'.format(authsrv))

# validate key
if params['shared_key'] is not None:
key = params['shared_key']
lines = key.splitlines()
if lines[0] == '-----BEGIN OpenVPN Static key V1-----' and lines[-1] == '-----END OpenVPN Static key V1-----':
params['shared_key'] = base64.b64encode(key.encode()).decode()
elif not re.match('LS0tLS1CRUdJTiBPcGVuVlBOIFN0YXRpYyBrZXkgVjEtLS0tLQ', key):
self.module.fail_json(msg='Could not recognize key format: %s' % (key))
for param in ['shared_key', 'tls']:
if params[param] is not None:
key = params[param]
if key == 'generate':
# generate during params_to_obj
pass
elif re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', key, flags=re.MULTILINE | re.DOTALL):
params[param] = base64.b64encode(key.encode()).decode()
else:
key_decoded = base64.b64decode(key.encode()).decode()
if not re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$',
key_decoded, flags=re.MULTILINE | re.DOTALL):
self.module.fail_json(msg='Could not recognize {0} key format: {1}'.format(param, key_decoded))

def _nextvpnid(self):
""" find next available vpnid """
Expand Down Expand Up @@ -216,16 +222,9 @@ def _find_last_openvpn_idx(self):

def _copy_and_update_target(self):
""" update the XML target_elt """
before = self.pfsense.element_to_dict(self.target_elt)
changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt)
if self._remove_deleted_params():
changed = True

self.diff['before'] = before
if changed:
self.diff['after'] = self.pfsense.element_to_dict(self.target_elt)
self.result['changed'] = True
else:
(before, changed) = super(PFSenseOpenVPNClientModule, self)._copy_and_update_target()

if not changed:
self.diff['after'] = self.obj

return (before, changed)
Expand All @@ -243,6 +242,16 @@ def _create_target(self):
def _find_target(self):
""" find the XML target_elt """
(target_elt, self.idx) = self._find_openvpn_client(self.obj['description'])
for param in ['shared_key', 'tls']:
current_elt = self.pfsense.get_element(param, target_elt)
if self.params[param] == 'generate':
if current_elt is None:
(dummy, key, stderr) = self.module.run_command('/usr/local/sbin/openvpn --genkey secret /dev/stdout')
if stderr != "":
self.module.fail_json(msg='generate for "{0}" secret key: {1}'.format(param, stderr))
self.obj[param] = base64.b64encode(key.encode()).decode()
else:
self.obj[param] = current_elt.text
return target_elt

def _remove_target_elt(self):
Expand Down
35 changes: 21 additions & 14 deletions plugins/module_utils/openvpn_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ def get_argument_spec():
def __init__(self, module, pfsense=None):
super(PFSenseOpenVPNServerModule, self).__init__(module, pfsense)
self.name = "pfsense_openvpn_server"
self.root_elt = self.pfsense.get_element('openvpn')
self.root_elt = self.pfsense.get_element('openvpn', create_node=True)
self.obj = dict()

##############################
Expand Down Expand Up @@ -213,10 +213,13 @@ def _validate_params(self):
for param in ['shared_key', 'tls']:
if params[param] is not None:
key = params[param]
if re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', key, flags=re.MULTILINE | re.DOTALL):
if key == 'generate':
# generate during params_to_obj
pass
elif re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$', key, flags=re.MULTILINE | re.DOTALL):
params[param] = base64.b64encode(key.encode()).decode()
else:
key_decoded = base64.b64decode(params[param].encode()).decode()
key_decoded = base64.b64decode(key.encode()).decode()
if not re.search('^-----BEGIN OpenVPN Static key V1-----.*-----END OpenVPN Static key V1-----$',
key_decoded, flags=re.MULTILINE | re.DOTALL):
self.module.fail_json(msg='Could not recognize {0} key format: {1}'.format(param, key_decoded))
Expand Down Expand Up @@ -283,21 +286,15 @@ def _get_params_to_remove(self):

def _copy_and_update_target(self):
""" update the XML target_elt """
before = self.pfsense.element_to_dict(self.target_elt)
(before, changed) = super(PFSenseOpenVPNServerModule, self)._copy_and_update_target()

# Check if local port is used
self._openvpn_port_used(self.params['protocol'], self.params['interface'], self.params['local_port'], before['vpnid'])
changed = self.pfsense.copy_dict_to_element(self.obj, self.target_elt)
if self._remove_deleted_params():
changed = True

self.diff['before'] = before
if changed:
self.diff['after'] = self.pfsense.element_to_dict(self.target_elt)
self.result['changed'] = True
else:

if not changed:
self.diff['after'] = self.obj

self.result['vpnid'] = int(self.diff['before']['vpnid'])
self.result['vpnid'] = int(before['vpnid'])
return (before, changed)

def _create_target(self):
Expand All @@ -316,6 +313,16 @@ def _create_target(self):
def _find_target(self):
""" find the XML target_elt """
(target_elt, self.idx) = self._find_openvpn_server(self.obj['description'])
for param in ['shared_key', 'tls']:
current_elt = self.pfsense.get_element(param, target_elt)
if self.params[param] == 'generate':
if current_elt is None:
(dummy, key, stderr) = self.module.run_command('/usr/local/sbin/openvpn --genkey secret /dev/stdout')
if stderr != "":
self.module.fail_json(msg='generate for "{0}" secret key: {1}'.format(param, stderr))
self.obj[param] = base64.b64encode(key.encode()).decode()
else:
self.obj[param] = current_elt.text
return target_elt

##############################
Expand Down
2 changes: 1 addition & 1 deletion plugins/modules/pfsense_openvpn_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,7 +87,7 @@
default: false
type: bool
shared_key:
description: Pre-shared key for shared key modes.
description: Pre-shared key for shared key modes. If set to 'generate' it will create a key if one does not already exist.
type: str
dh_length:
description: DH parameter length.
Expand Down
2 changes: 1 addition & 1 deletion plugins/modules/pfsense_openvpn_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,7 +90,7 @@
default: false
type: bool
shared_key:
description: Pre-shared key for shared key modes.
description: Pre-shared key for shared key modes. If set to 'generate' it will create a key if one does not already exist.
type: str
dh_length:
description: DH parameter length.
Expand Down
29 changes: 29 additions & 0 deletions tests/unit/plugins/modules/test_pfsense_openvpn_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from __future__ import (absolute_import, division, print_function)
__metaclass__ = type

import base64
import pytest
import sys

Expand All @@ -12,6 +13,7 @@

from ansible_collections.pfsensible.core.plugins.modules import pfsense_openvpn_server
from .pfsense_module import TestPFSenseModule
from ansible_collections.community.internal_test_tools.tests.unit.compat.mock import patch

CERTIFICATE = (
"LS0tLS1CRUdJTiBDRVJUSUZJQ0FURS0tLS0tDQpNSUlFQ0RDQ0F2Q2dBd0lCQWdJSUZqRk9oczFuTXpRd0RRWUpLb1pJaHZjTkFRRUxCUUF3WERFVE1CRUdBMVVFDQpBeE1LYjNCbGJuWndiaTFqWVRF"
Expand Down Expand Up @@ -47,6 +49,21 @@ def __init__(self, *args, **kwargs):
self.config_file = 'pfsense_openvpn_config.xml'
self.pfmodule = pfsense_openvpn_server.PFSenseOpenVPNServerModule

def setUp(self):
""" mocking up """

super(TestPFSenseOpenVPNServerModule, self).setUp()

self.mock_run_command = patch('ansible.module_utils.basic.AnsibleModule.run_command')
self.run_command = self.mock_run_command.start()
self.run_command.return_value = (0, base64.b64decode(TLSKEY.encode()).decode(), '')

def tearDown(self):
""" mocking down """
super(TestPFSenseOpenVPNServerModule, self).tearDown()

self.run_command.stop()

@staticmethod
def runTest():
""" dummy function needed to instantiate this test module from another in python 2.7 """
Expand Down Expand Up @@ -87,6 +104,12 @@ def certref(descr):
def check_target_elt(self, obj, target_elt):
""" check XML definition of target elt """

# Use "generated" key
if 'shared_key' in obj and obj['shared_key'] == 'generate':
obj['shared_key'] = TLSKEY
if 'tls' in obj and obj['tls'] == 'generate':
obj['tls'] = TLSKEY

self.check_param_equal(obj, target_elt, 'name', xml_field='description')
self.check_param_equal(obj, target_elt, 'custom_options')
self.check_param_equal(obj, target_elt, 'mode', default='ptp_tls')
Expand All @@ -100,6 +123,7 @@ def check_target_elt(self, obj, target_elt):
self.check_param_equal(obj, target_elt, 'local_port', default=1194)
self.check_param_equal(obj, target_elt, 'protocol', default='UDP4')
if 'tls' in obj['mode']:
self.check_param_equal(obj, target_elt, 'tls')
self.check_param_equal(obj, target_elt, 'tls')
self.check_param_equal(obj, target_elt, 'tls_type')
self.assert_xml_elt_equal(target_elt, 'caref', self.caref(obj['ca']))
Expand Down Expand Up @@ -141,6 +165,11 @@ def test_openvpn_server_create(self):
obj = dict(name='ovpns3', mode='p2p_tls', ca='OpenVPN CA', local_port=1196)
self.do_module_test(obj, command="create openvpn_server 'ovpns3', description='ovpns3'")

def test_openvpn_server_create_generate(self):
""" test creation of a new OpenVPN server """
obj = dict(name='ovpns3', mode='p2p_tls', ca='OpenVPN CA', local_port=1196, tls='generate')
self.do_module_test(obj, command="create openvpn_server 'ovpns3', description='ovpns3'")

def test_openvpn_server_delete(self):
""" test deletion of a OpenVPN server """
obj = dict(name='ovpns2')
Expand Down