"""Unit tests for ``FirewallService`` and the firewall file-download task."""
import os
import tempfile
from unittest import mock
from unittest.mock import patch

import pytest
import requests
from rest_framework import status

from devices.models.firewall import ArmaIndustrialFirewall
from devices.services.firewall import FirewallService, ConnectionException, IncompatibilityVersionException, \
    InvalidCredentialException, InvalidResponseException, FailedUploadException, InvalidFileException
from devices.tasks.firewall import download_files_from_firewall
from storage.models import DataStorage

# Network-level ``requests`` failures; each is expected to surface as ConnectionException.
TEST_REQUEST_RESPONSE_LIST = [
    requests.exceptions.ConnectTimeout,
    requests.exceptions.ConnectionError,
    requests.exceptions.Timeout
]

check_connection_exceptions = [InvalidCredentialException, IncompatibilityVersionException, ConnectionException]

# Pairs of (expected exception, JSON body returned by the mocked upload endpoint).
upload_file_exception = [
    [InvalidFileException, {'status': 'invalid'}],
    [FailedUploadException, {'status': 'failed'}],
    [InvalidResponseException, {'status': 'blah-blah'}]
]

# Pairs of (version string, expected result of FirewallService.firewall_version_validator).
TEST_ARMAIF_VERSIONS = [
    ['3.6', True],
    ['3.5.1', False],
    ['3.6-rc1', True],
    ['3.6-rc2', True],
    ['3.6-rc3', True],
    ['3.6-rc0', True],
    ['3.6-rc1234', True],
    # NOTE(review): the 'rс' in the next two entries may use a non-ASCII look-alike 'с'; kept byte-for-byte.
    ['3.6-rс-41', True],
    ['3.6-rс13232', True],
    ['3.6-rc3123123', True],
    ['3.8', True],
    ['1238.13-123.fda213', True],
    ['3.9-rc1', True],
    ['111111111111-22222222', True],
    ['3.6-kek', True],
    ['2.6', False],
    ['3.9req', False],
    ['1234123.rqr123-e12', False],
    ['f89y48fqyhiuyhf8o71y-82y8f82y73f8y', False],
    ['1-2-3-4-5', False],
    ['3.5', False],
    ['3.5-dev18723', False],
]

BASE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
TEST_FILES = os.path.join(BASE_DIR, "tests", "test_files")

# Pairs of (firewall ip, path passed to FirewallService.get_addr).
get_addr = [
    ['1.1.1.1', 'api/test'],
    ['5.1.1.1', '/543'],
    ['1.56.90.255', ''],
]


@pytest.mark.unit
@pytest.mark.django_db
class TestFirewallService:
    """Tests for connection checks, uploads, version validation, URL building and downloads."""

    @pytest.fixture(autouse=True)
    def setup_tests(self, api_client, django_user_model, add_user_with_permissions):
        """Create a superuser and build a firewall model instance (not saved here) for each test."""
        username = 'foo'
        password = 'bar'
        self.user = add_user_with_permissions(username=username, password=password, is_superuser=True)
        self.firewall = ArmaIndustrialFirewall(
            name='IF', ip='1.1.1.1', key='key', secret='secret', port=1500, type='firewall')

    @pytest.mark.parametrize('req_res', TEST_REQUEST_RESPONSE_LIST)
    def test_check_connection_to_firewall_exception_responses(self, req_res):
        """A network error from ``requests`` is reported as ConnectionException."""
        data_request = {"ip": "1.2.3.4", "key": "123", "secret": "321"}
        with mock.patch('requests.Session.get', side_effect=req_res):
            with pytest.raises(ConnectionException) as e:
                FirewallService().check_connection(data_request)
        # Asserted outside ``pytest.raises``: code after the raising call inside it never runs.
        assert e.value.detail['detail'] == 'There was a problem connecting to the firewall'

    def test_check_connection_to_firewall_invalid_credentials(self):
        """An HTTP 401 from the firewall is reported as InvalidCredentialException."""
        data_request = {"ip": "1.2.3.4", "key": "123", "secret": "321"}
        with mock.patch('requests.Session.get') as mock_get:
            mock_get.return_value.status_code = status.HTTP_401_UNAUTHORIZED
            with pytest.raises(InvalidCredentialException) as e:
                FirewallService().check_connection(data_request)
        assert e.value.detail['detail'] == 'Invalid credentials provided to connect to firewall'

    def test_check_connection_to_firewall_incompatible_version(self):
        """A reported product version of 3.1 is rejected as incompatible."""
        data_request = {"ip": "1.2.3.4", "key": "123", "secret": "321"}
        with mock.patch('requests.Session.get') as mock_get:
            mock_get.return_value.status_code = 200
            mock_get.return_value.json.return_value = {'status': 'ok', 'items': {'product_version': '3.1'}}
            with pytest.raises(IncompatibilityVersionException) as e:
                FirewallService().check_connection(data_request)
        assert e.value.detail['detail'] == 'The firewall version is incompatible with the current console version'

    @pytest.mark.parametrize('exc', upload_file_exception)
    def test_upload_firewall_config_with_error(self, exc, api_client):
        """Each error ``status`` in the upload response maps to its dedicated exception."""
        api_client.force_authenticate(self.user)
        file_path = os.path.join(TEST_FILES, 'config.xml')
        expected_exception, response_body = exc
        # ``with`` guarantees the fixture file is closed even when the expected exception propagates.
        with open(file_path, 'r') as file, mock.patch('requests.Session.post') as mock_post:
            mock_post.return_value.status_code = 200
            mock_post.return_value.json.return_value = response_body
            with pytest.raises(expected_exception) as e:
                FirewallService(self.firewall).upload_file(file, 'config')
        assert e.value.detail['detail'] == expected_exception.default_detail['detail']

    @pytest.mark.parametrize('armaif_version', TEST_ARMAIF_VERSIONS)
    def test_check_armaif_version_validator(self, armaif_version):
        """The version validator accepts/rejects each sample string as listed."""
        version, expected = armaif_version
        assert FirewallService.firewall_version_validator(version) == expected

    @pytest.mark.parametrize('addr', get_addr)
    def test_get_addr_firewall(self, addr):
        """``get_addr`` builds an https URL from the firewall ip and a path without a leading slash."""
        ip, path = addr
        firewall = ArmaIndustrialFirewall(
            name='IF', ip=ip, key='key', secret='secret', port=1500, type='firewall')
        assert FirewallService(firewall).get_addr(path) == 'https://{}/{}'.format(firewall.ip, path.lstrip('/'))

    def test_firewall_download_file_task(self):
        """Test download file from firewall(mocked) and add to storage."""
        file_name = 'test_abc.tar'
        # Keep the TemporaryDirectory referenced by ``with`` for the whole test. An unreferenced
        # ``TemporaryDirectory()`` is finalized (and its directory deleted) immediately, so
        # anything written at that path afterwards would never be cleaned up. MEDIA_ROOT points
        # at a not-yet-existing subdirectory, so the task still starts without a media root.
        with tempfile.TemporaryDirectory() as tmp_root:
            tmp_mediaroot = os.path.join(tmp_root, 'media')
            storage_file_name = os.path.join(tmp_mediaroot, file_name)
            with patch('devices.services.firewall.firewall.FirewallService.download_file',
                       lambda *args: (b'firewall__data', file_name)):
                with patch('devices.tasks.firewall.MEDIA_ROOT', tmp_mediaroot):
                    with patch('devices.tasks.firewall.get_storage_path', lambda *args: file_name):
                        pk = download_files_from_firewall(self.firewall, self.user, 'config')

            assert isinstance(pk, int)
            storage = DataStorage.objects.get(pk=pk)
            assert storage.file == file_name
            # The written file must be checked before the temp dir is removed on leaving ``with``.
            assert os.path.exists(storage_file_name)
            with open(storage_file_name, 'br') as file:
                data = file.read()
            assert data == b'firewall__data'
            assert storage.format == DataStorage.Format.XML