66from contextlib import suppress
77from dataclasses import replace
88import datetime as dt
9- from enum import IntEnum
9+ from enum import Enum , IntEnum , auto
1010from functools import partial
1111import logging
1212import re
1313from typing import TYPE_CHECKING , Any , Final , NamedTuple , Protocol , Unpack
1414
15+ from homeassistant .components import mqtt
1516from homeassistant .components .script import DOMAIN as SCRIPT_DOMAIN
1617from homeassistant .core import CALLBACK_TYPE , Event , HomeAssistant , callback
17- from homeassistant .helpers import entity_registry as er , event as evt
18+ from homeassistant .helpers import (
19+ device_registry as dr ,
20+ entity_registry as er ,
21+ event as evt ,
22+ )
1823from homeassistant .helpers .device import (
1924 async_device_info_to_link_from_entity ,
2025 async_entity_id_to_device_id ,
2126)
2227from homeassistant .util import dt as dt_util
28+ from homeassistant .util .json import json_loads_object
2329
2430from .const import (
2531 CONF_END_ACTION ,
3339 DeviceId ,
3440 Effect ,
3541 ExpirationInfo ,
42+ Integration ,
3643 LampieNotificationInfo ,
3744 LampieNotificationOptionsDict ,
3845 LampieSwitchInfo ,
4855 from .coordinator import LampieUpdateCoordinator
4956
5057type ZHAEventData = dict [str , Any ]
58+ type MQTTDeviceName = str
5159
5260_LOGGER = logging .getLogger (__name__ )
5361
62+ MQTT_DOMAIN : Final = "mqtt"
5463ZHA_DOMAIN : Final = "zha"
5564ALREADY_EXPIRED : Final = 0
5665
66+ SWITCH_INTEGRATIONS = {
67+ ZHA_DOMAIN : Integration .ZHA ,
68+ MQTT_DOMAIN : Integration .Z2M ,
69+ }
70+
5771FIRMWARE_SECONDS_MAX = dt .timedelta (seconds = 60 ).total_seconds ()
5872FIRMWARE_MINUTES_MAX = dt .timedelta (minutes = 60 ).total_seconds ()
5973FIRMWARE_HOURS_MAX = dt .timedelta (hours = 134 ).total_seconds ()
86100 "button_6_double" ,
87101}
88102
103+ Z2M_COMMAND_MAP = {
104+ "config_double" : "button_3_double" ,
105+ }
106+
89107
90108class _LEDMode (IntEnum ):
91109 ALL = 1
92110 INDIVIDUAL = 3
93111
94112
113+ class _SwitchKeyType (Enum ):
114+ DEVICE_ID = auto ()
115+ MQTT_NAME = auto ()
116+
117+
118+ class _SwitchKey [T ](NamedTuple ):
119+ type : _SwitchKeyType
120+ identifier : T
121+
122+
95123class _StartScriptResult (NamedTuple ):
96124 led_config : tuple [LEDConfig , ...] | None
97125 block_activation : bool
@@ -102,6 +130,10 @@ class _EndScriptResult(NamedTuple):
102130 block_next : bool
103131
104132
133+ class _UnknownIntegrationError (Exception ):
134+ pass
135+
136+
105137class _LampieUnmanagedSwitchCoordinator :
106138 def async_update_listeners (self ) -> None :
107139 pass
@@ -130,7 +162,8 @@ def __init__(self, hass: HomeAssistant) -> None:
130162 self ._coordinators : dict [Slug , LampieUpdateCoordinator ] = {}
131163 self ._notifications : dict [Slug , LampieNotificationInfo ] = {}
132164 self ._switches : dict [SwitchId , LampieSwitchInfo ] = {}
133- self ._device_switches : dict [DeviceId , SwitchId ] = {}
165+ self ._switch_ids : dict [_SwitchKey [DeviceId | MQTTDeviceName ], SwitchId ] = {}
166+ self ._cancel_mqtt_listener : CALLBACK_TYPE | None = None
134167 self ._cancel_zha_listener : CALLBACK_TYPE = hass .bus .async_listen (
135168 "zha_event" ,
136169 self ._handle_zha_event ,
@@ -145,9 +178,20 @@ def remove_coordinator(self, coordinator: LampieUpdateCoordinator) -> None:
145178 self ._coordinators .pop (coordinator .slug )
146179 self ._update_references ()
147180
181+ async def setup (self ) -> None :
182+ self ._cancel_mqtt_listener = await mqtt .async_subscribe (
183+ self ._hass ,
184+ "zigbee2mqtt/+" ,
185+ self ._handle_z2m_message ,
186+ )
187+
148188 def teardown (self ) -> bool :
149189 if len (self ._coordinators ) == 0 :
150190 self ._cancel_zha_listener ()
191+
192+ if self ._cancel_mqtt_listener :
193+ self ._cancel_mqtt_listener ()
194+
151195 for key , expiration in (
152196 * ((key , info .expiration ) for key , info in self ._notifications .items ()),
153197 * ((key , info .expiration ) for key , info in self ._switches .items ()),
@@ -159,11 +203,17 @@ def teardown(self) -> bool:
159203 def switch_info (self , switch_id : SwitchId ) -> LampieSwitchInfo :
160204 if switch_id not in self ._switches :
161205 entity_registry = er .async_get (self ._hass )
206+ device_registry = dr .async_get (self ._hass )
162207 device_id = async_entity_id_to_device_id (self ._hass , switch_id )
208+ device = device_registry .async_get (device_id )
209+ integration = self ._switch_integration (device ) if device else None
163210 entity_entries = er .async_entries_for_device (entity_registry , device_id )
164211 local_protetction_id = None
165212 disable_clear_notification_id = None
166213
214+ if not integration :
215+ raise _UnknownIntegrationError
216+
167217 for entity_entry in entity_entries :
168218 if entity_entry .unique_id .endswith ("-local_protection" ):
169219 local_protetction_id = entity_entry .entity_id
@@ -172,8 +222,20 @@ def switch_info(self, switch_id: SwitchId) -> LampieSwitchInfo:
172222 ):
173223 disable_clear_notification_id = entity_entry .entity_id
174224
175- self ._device_switches [device_id ] = switch_id
225+ if (
226+ device
227+ and (mqtt_device_name := self ._mqtt_device_name (device )) is not None
228+ ):
229+ self ._switch_ids [
230+ _SwitchKey (_SwitchKeyType .MQTT_NAME , mqtt_device_name )
231+ ] = switch_id
232+
233+ self ._switch_ids [_SwitchKey (_SwitchKeyType .DEVICE_ID , device_id )] = (
234+ switch_id
235+ )
236+
176237 self ._switches [switch_id ] = LampieSwitchInfo (
238+ integration = integration ,
177239 led_config = (),
178240 led_config_source = LEDConfigSource (None ),
179241 local_protetction_id = local_protetction_id ,
@@ -712,9 +774,8 @@ async def _issue_switch_commands(
712774 "all" if led_mode == _LEDMode .ALL else ", " .join (updated_leds ),
713775 )
714776
715- @classmethod
716- def _switch_command_led_params (cls , led : LEDConfig ) -> dict [str , Any ]:
717- firmware_duration = cls ._firmware_duration (led .duration )
777+ def _switch_command_led_params (self , led : LEDConfig ) -> dict [str , Any ]:
778+ firmware_duration = self ._firmware_duration (led .duration )
718779
719780 return {
720781 "led_color" : int (led .color ),
@@ -727,17 +788,24 @@ def _switch_command_led_params(cls, led: LEDConfig) -> dict[str, Any]:
727788 else firmware_duration ,
728789 }
729790
730- @classmethod
731- def _firmware_duration (cls , seconds : int | None ) -> int | None :
791+ def _firmware_duration (self , seconds : int | None ) -> int | None :
732792 """Convert a timeframe to a duration supported by the switch firmware.
733793
794+ Note: Any usage of MQTT switches in the system will result in the
795+ suspension of all firmware durations, and this method will always return
796+ `None`. At the moment Zigbee2MQTT does not support the 0x24 command
797+ which is used for `led_effect_complete`, so there is no way to know the
798+ notification ended.
799+
734800 Args:
735801 seconds: The duration as a number of seconds.
736802
737803 Returns:
738804 The duration parameter value (0-255) if it can be handled by the
739805 firmware or None if it cannot be.
740806 """
807+ if any (key .type == _SwitchKeyType .MQTT_NAME for key in self ._switch_ids ):
808+ return None
741809 if seconds is None or seconds == ALREADY_EXPIRED :
742810 return None
743811 if seconds <= FIRMWARE_SECONDS_MAX :
@@ -753,17 +821,40 @@ def _filter_zha_events(
753821 self ,
754822 event_data : ZHAEventData ,
755823 ) -> bool :
824+ switch_key = _SwitchKey (_SwitchKeyType .DEVICE_ID , event_data ["device_id" ])
756825 return (
757- event_data [ "device_id" ] in self ._device_switches
826+ switch_key in self ._switch_ids
758827 and event_data ["command" ] in DISMISSAL_COMMANDS
759828 )
760829
761830 @callback
762831 async def _handle_zha_event (self , event : Event [ZHAEventData ]) -> None :
832+ await self ._handle_generic_event (
833+ command = event .data ["command" ],
834+ device_id = event .data ["device_id" ],
835+ )
836+
837+ @callback
838+ async def _handle_z2m_message (self , message : mqtt .ReceiveMessage ) -> None :
839+ command_path_base , device_name = message .topic .split ("/" , 1 )
840+ switch_key = _SwitchKey (_SwitchKeyType .MQTT_NAME , device_name )
841+ device_id = self ._switch_ids .get (switch_key )
842+
843+ if command_path_base == "zigbee2mqtt" and device_id is not None :
844+ payload = json_loads_object (message .payload )
845+ action = payload .get ("action" )
846+ command = Z2M_COMMAND_MAP .get (action )
847+
848+ if command :
849+ await self ._handle_generic_event (
850+ command = command ,
851+ device_id = device_id ,
852+ )
853+
854+ async def _handle_generic_event (self , command : str , device_id : str ) -> None :
763855 hass = self ._hass
764- command = event .data ["command" ]
765- device_id = event .data ["device_id" ]
766- switch_id = self ._device_switches [device_id ]
856+ switch_key = _SwitchKey (_SwitchKeyType .DEVICE_ID , device_id )
857+ switch_id = self ._switch_ids [switch_key ]
767858 from_state = self .switch_info (switch_id )
768859 led_config_source = from_state .led_config_source
769860 led_config = [* from_state .led_config ]
@@ -1090,8 +1181,17 @@ def _update_references(self) -> None:
10901181 )
10911182
10921183 for switch_id in switch_ids :
1184+ try :
1185+ switch_info = self .switch_info (switch_id )
1186+ except _UnknownIntegrationError :
1187+ _LOGGER .exception (
1188+ "ignoring switch %s: could not to a valid integration" ,
1189+ switch_id ,
1190+ )
1191+ continue
1192+
10931193 priorities = switch_priorities .get (switch_id ) or [slug ]
1094- expected = [* self . switch_info ( switch_id ) .priorities ]
1194+ expected = [* switch_info .priorities ]
10951195
10961196 if switch_id in processed_switches and expected != priorities :
10971197 _LOGGER .warning (
@@ -1109,13 +1209,34 @@ def _update_references(self) -> None:
11091209 continue
11101210
11111211 self ._switches [switch_id ] = replace (
1112- self . switch_info ( switch_id ) ,
1212+ switch_info ,
11131213 priorities = tuple (priorities ),
11141214 )
11151215 processed_switches .add (switch_id )
11161216
11171217 processed_slugs .append (slug )
11181218
1219+ @classmethod
1220+ def _switch_integration (cls , device : dr .DeviceEnttry ) -> Integration | None :
1221+ id_tuple = next (iter (device .identifiers ))
1222+ domain = id_tuple [0 ]
1223+ return SWITCH_INTEGRATIONS .get (domain )
1224+
1225+ @classmethod
1226+ def _mqtt_device_id (cls , device : dr .DeviceEnttry ) -> str | None :
1227+ id_tuple = next (iter (device .identifiers ))
1228+
1229+ if id_tuple [0 ] != MQTT_DOMAIN :
1230+ return None
1231+
1232+ identifier : str = id_tuple [1 ]
1233+
1234+ return identifier .split ("_" , 1 )[1 ]
1235+
1236+ @classmethod
1237+ def _mqtt_device_name (cls , device : dr .DeviceEnttry ) -> str | None :
1238+ return device .name if cls ._mqtt_device_id (device ) is not None else None
1239+
11191240
11201241def _all_clear (led_config : Sequence [LEDConfig ]) -> bool :
11211242 return all (item .effect == Effect .CLEAR for item in led_config )
0 commit comments