144 lines
6 KiB
Python
144 lines
6 KiB
Python
import os
|
||
import tempfile
|
||
from unittest import mock
|
||
from unittest.mock import patch
|
||
|
||
import pytest
|
||
import requests
|
||
from rest_framework import status
|
||
|
||
from devices.models.firewall import ArmaIndustrialFirewall
|
||
from devices.services.firewall import FirewallService, ConnectionException, IncompatibilityVersionException, \
|
||
InvalidCredentialException, InvalidResponseException, FailedUploadException, InvalidFileException
|
||
from devices.tasks.firewall import download_files_from_firewall
|
||
from storage.models import DataStorage
|
||
|
||
# ``requests`` exception classes injected as the ``side_effect`` of the mocked
# ``requests.Session.get``; the connection-check test expects every one of them
# to be reported as ``ConnectionException``.
TEST_REQUEST_RESPONSE_LIST = [
    requests.exceptions.ConnectTimeout,
    requests.exceptions.ConnectionError,
    requests.exceptions.Timeout,
]
|
||
|
||
# Exception types that ``FirewallService.check_connection`` is tested to raise.
check_connection_exceptions = [
    InvalidCredentialException,
    IncompatibilityVersionException,
    ConnectionException,
]

# Each entry is ``[expected exception class, JSON body of the mocked upload
# response]``; the body's ``status`` value selects which exception is raised.
upload_file_exception = [
    [InvalidFileException, {'status': 'invalid'}],
    [FailedUploadException, {'status': 'failed'}],
    [InvalidResponseException, {'status': 'blah-blah'}],
]
|
||
|
||
# Each entry is ``[product_version string, expected result of
# FirewallService.firewall_version_validator]``. Order is kept stable so the
# generated pytest parameter ids do not shift.
# NOTE(review): the '3.6-r?-41' and '3.6-r?13232' entries appear to use a
# non-ASCII look-alike for 'c'; confirm whether that is intentional.
TEST_ARMAIF_VERSIONS = [
    ['3.6', True], ['3.5.1', False],
    ['3.6-rc1', True], ['3.6-rc2', True], ['3.6-rc3', True], ['3.6-rc0', True],
    ['3.6-rc1234', True], ['3.6-rс-41', True], ['3.6-rс13232', True],
    ['3.6-rc3123123', True],
    ['3.8', True], ['1238.13-123.fda213', True], ['3.9-rc1', True],
    ['111111111111-22222222', True], ['3.6-kek', True],
    ['2.6', False], ['3.9req', False], ['1234123.rqr123-e12', False],
    ['f89y48fqyhiuyhf8o71y-82y8f82y73f8y', False], ['1-2-3-4-5', False],
    ['3.5', False], ['3.5-dev18723', False],
]
|
||
|
||
# Parent of the directory that contains this module (``..`` relative to it).
BASE_DIR = os.path.abspath(os.path.join(os.path.dirname(os.path.abspath(__file__)), os.pardir))

# Location of on-disk fixtures such as ``config.xml``.
TEST_FILES = os.path.join(BASE_DIR, "tests", "test_files")

# Each entry is ``[firewall ip, path passed to FirewallService.get_addr]``;
# paths with and without a leading slash, and an empty path, are covered.
get_addr = [
    ['1.1.1.1', 'api/test'],
    ['5.1.1.1', '/543'],
    ['1.56.90.255', ''],
]
|
||
|
||
|
||
@pytest.mark.unit
@pytest.mark.django_db
class TestFirewallService:
    """Unit tests for ``FirewallService`` and the firewall download task.

    Network access is mocked by patching ``requests.Session.get`` /
    ``requests.Session.post`` (or the service's ``download_file``), so no real
    firewall is contacted.
    """

    @pytest.fixture(autouse=True)
    def setup_tests(self, api_client, django_user_model, add_user_with_permissions):
        """Create a superuser and a firewall instance for every test."""
        username = 'foo'
        password = 'bar'
        self.user = add_user_with_permissions(username=username, password=password, is_superuser=True)
        # Constructed without calling ``save()``.
        self.firewall = ArmaIndustrialFirewall(
            name='IF', ip='1.1.1.1', key='key', secret='secret', port=1500, type='firewall')

    @pytest.mark.parametrize('req_res', TEST_REQUEST_RESPONSE_LIST)
    def test_check_connection_to_firewall_exception_responses(self, req_res):
        """A ``requests`` network error during the check raises ``ConnectionException``."""
        data_request = {"ip": "1.2.3.4", "key": "123", "secret": "321"}

        with mock.patch('requests.Session.get', side_effect=req_res):
            with pytest.raises(ConnectionException) as e:
                FirewallService().check_connection(data_request)
        assert e.value.detail['detail'] == 'There was a problem connecting to the firewall'

    def test_check_connection_to_firewall_invalid_credentials(self):
        """An HTTP 401 from the firewall raises ``InvalidCredentialException``."""
        data_request = {"ip": "1.2.3.4", "key": "123", "secret": "321"}

        with mock.patch('requests.Session.get') as mock_get:
            mock_get.return_value.status_code = status.HTTP_401_UNAUTHORIZED
            with pytest.raises(InvalidCredentialException) as e:
                FirewallService().check_connection(data_request)
        assert e.value.detail['detail'] == 'Invalid credentials provided to connect to firewall'

    def test_check_connection_to_firewall_incompatible_version(self):
        """A reported ``product_version`` of 3.1 raises ``IncompatibilityVersionException``."""
        data_request = {"ip": "1.2.3.4", "key": "123", "secret": "321"}

        with mock.patch('requests.Session.get') as mock_get:
            mock_get.return_value.status_code = status.HTTP_200_OK
            mock_get.return_value.json.return_value = {'status': 'ok', 'items': {'product_version': '3.1'}}
            with pytest.raises(IncompatibilityVersionException) as e:
                FirewallService().check_connection(data_request)
        assert e.value.detail['detail'] == 'The firewall version is incompatible with the current console version'

    @pytest.mark.parametrize('exc', upload_file_exception)
    def test_upload_firewall_config_with_error(self, exc, api_client):
        """Each non-success upload ``status`` maps to its dedicated exception."""
        exc_class, response_payload = exc
        api_client.force_authenticate(self.user)
        file_path = os.path.join(TEST_FILES, 'config.xml')

        # Open the fixture in a context manager so the handle is always closed
        # (the previous bare ``open()`` leaked it on every parametrized run).
        with open(file_path, 'r') as file, mock.patch('requests.Session.post') as mock_post:
            mock_post.return_value.status_code = status.HTTP_200_OK
            mock_post.return_value.json.return_value = response_payload
            with pytest.raises(exc_class) as e:
                FirewallService(self.firewall).upload_file(file, 'config')
        assert e.value.detail['detail'] == exc_class.default_detail['detail']

    @pytest.mark.parametrize('armaif_version', TEST_ARMAIF_VERSIONS)
    def test_check_armaif_version_validator(self, armaif_version):
        """``firewall_version_validator`` returns the expected verdict for each version string."""
        version, expected = armaif_version
        assert FirewallService.firewall_version_validator(version) == expected

    @pytest.mark.parametrize('addr', get_addr)
    def test_get_addr_firewall(self, addr):
        """``get_addr`` builds ``https://<ip>/<path>`` with any leading slash on the path dropped."""
        ip, path = addr
        firewall = ArmaIndustrialFirewall(
            name='IF', ip=ip, key='key', secret='secret', port=1500, type='firewall')
        assert FirewallService(firewall).get_addr(path) == 'https://{}/{}'.format(firewall.ip, path.lstrip('/'))

    def test_firewall_download_file_task(self):
        """Test download file from firewall(mocked) and add to storage."""
        file_name = 'test_abc.tar'

        # Keep the TemporaryDirectory object alive for the whole test. Taking
        # ``.name`` from a throwaway instance lets it be finalized immediately,
        # deleting the directory right away and leaving whatever the task
        # recreates there behind after the test.
        with tempfile.TemporaryDirectory() as tmp_mediaroot:
            storage_file_name = os.path.join(tmp_mediaroot, file_name)

            with patch('devices.services.firewall.firewall.FirewallService.download_file',
                       lambda *args: (b'firewall__data', file_name)):
                with patch('devices.tasks.firewall.MEDIA_ROOT', tmp_mediaroot):
                    with patch('devices.tasks.firewall.get_storage_path', lambda *args: file_name):
                        pk = download_files_from_firewall(self.firewall, self.user, 'config')

            assert isinstance(pk, int)

            storage = DataStorage.objects.get(pk=pk)
            assert storage.file == file_name

            assert os.path.exists(storage_file_name)
            with open(storage_file_name, 'rb') as file:
                data = file.read()
            assert data == b'firewall__data'
            assert storage.format == DataStorage.Format.XML
|