import logging

from asgiref.sync import sync_to_async
from channels.db import database_sync_to_async
from channels.generic.websocket import AsyncJsonWebsocketConsumer

from devices.services.firewall.firewall import get_all_aif_with_status_map
from incident.services.ws_incidents import get_incident_count

_log = logging.getLogger(__name__)


class WSNotification(AsyncJsonWebsocketConsumer):
    """Consumer for websocket notifications.

    Authenticated clients join the shared ``notification`` group. Right after
    the handshake they receive an initial snapshot (firewall statuses and
    incident count). After that, every ``notification`` event sent to the
    group is forwarded to them. While connected, the consumer's channel name
    is stored on the user's ``userinfo``; it is cleared on disconnect.
    """

    _group_name = 'notification'

    async def connect(self) -> None:
        """Accept authenticated clients and reject everyone else.

        Unauthenticated connections are closed before ``accept()`` and the
        method returns immediately. They therefore never join the group,
        never receive initial data and never touch ``userinfo``.
        """
        user = self.scope['user']
        if not user.is_authenticated:
            _log.warning('For user[%s] access denied', user)
            await self.close()
            # Without this return, the rejected socket would still be added
            # to the group, accepted, sent initial data, and `userinfo` would
            # be accessed on an anonymous user.
            return

        _log.info('connect to WS, %s, %s', self.channel_name, user)
        await self.channel_layer.group_add(self._group_name, self.channel_name)
        await self.accept()
        await self.send_init_data()
        await self.set_new_channel_name_to_user()

    async def disconnect(self, close_code) -> None:
        """Leave the notification group and clear the stored channel name.

        :param close_code: WebSocket close code passed in by Channels (unused).
        """
        await self.channel_layer.group_discard(self._group_name, self.channel_name)
        # Rejected (unauthenticated) connections never stored a channel name,
        # and an anonymous user has no `userinfo` to update.
        if self.scope['user'].is_authenticated:
            await self.set_new_channel_name_to_user(set_default=True)

    async def notification(self, event: dict) -> None:
        """Forward a ``notification`` group event to this client.

        :param event: channel-layer message. Its ``data`` value (``{}`` when
            missing) is sent to the client as JSON.
        """
        _log.info('Send notification')
        data = event.get('data', {})
        await self.send_json(data)

    async def send_init_data(self) -> None:
        """Send the initial snapshot to the client.

        This is called once, right after the connection is accepted. The
        blocking service helpers run through ``sync_to_async`` so they do not
        block the event loop.
        """
        # NOTE(review): if these helpers hit the ORM, `database_sync_to_async`
        # (already imported) also cleans up stale DB connections. Confirm.
        count = await sync_to_async(get_incident_count)()
        fw_statuses = await sync_to_async(get_all_aif_with_status_map)()
        data = {
            'firewalls_status': fw_statuses,
            'incident_count': count,
        }
        await self.send_json(data)

    @database_sync_to_async
    def set_new_channel_name_to_user(self, set_default: bool = False) -> None:
        """Store this consumer's channel name on the user's ``UserInfo`` record.

        :param set_default: when True, store an empty string instead. This
            clears the name, e.g. on disconnect.
        """
        userinfo = self.scope['user'].userinfo
        userinfo.channel_name = '' if set_default else self.channel_name
        userinfo.save()