Skip to content

Commit bddfb48

Browse files
committed
Start mocking out MQTT support
1 parent be5acd2 commit bddfb48

File tree

11 files changed

+544
-27
lines changed

11 files changed

+544
-27
lines changed

custom_components/lampie/__init__.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,10 +45,13 @@ async def async_setup_entry(hass: HomeAssistant, entry: LampieConfigEntry) -> bo
4545
_LOGGER.debug("setup %s with config:%s", entry.title, entry.data)
4646

4747
if DOMAIN not in hass.data:
48-
hass.data[DOMAIN] = LampieOrchestrator(hass)
48+
orchestrator = LampieOrchestrator(hass)
49+
hass.data[DOMAIN] = orchestrator
50+
51+
await orchestrator.setup()
4952

5053
coordinator = LampieUpdateCoordinator(hass, entry)
51-
orchestrator: LampieOrchestrator = hass.data[DOMAIN]
54+
orchestrator = hass.data[DOMAIN]
5255
orchestrator.add_coordinator(coordinator)
5356
entry.runtime_data = LampieConfigEntryRuntimeData(
5457
orchestrator=orchestrator,

custom_components/lampie/orchestrator.py

Lines changed: 136 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -6,20 +6,26 @@
66
from contextlib import suppress
77
from dataclasses import replace
88
import datetime as dt
9-
from enum import IntEnum
9+
from enum import Enum, IntEnum, auto
1010
from functools import partial
1111
import logging
1212
import re
1313
from typing import TYPE_CHECKING, Any, Final, NamedTuple, Protocol, Unpack
1414

15+
from homeassistant.components import mqtt
1516
from homeassistant.components.script import DOMAIN as SCRIPT_DOMAIN
1617
from homeassistant.core import CALLBACK_TYPE, Event, HomeAssistant, callback
17-
from homeassistant.helpers import entity_registry as er, event as evt
18+
from homeassistant.helpers import (
19+
device_registry as dr,
20+
entity_registry as er,
21+
event as evt,
22+
)
1823
from homeassistant.helpers.device import (
1924
async_device_info_to_link_from_entity,
2025
async_entity_id_to_device_id,
2126
)
2227
from homeassistant.util import dt as dt_util
28+
from homeassistant.util.json import json_loads_object
2329

2430
from .const import (
2531
CONF_END_ACTION,
@@ -33,6 +39,7 @@
3339
DeviceId,
3440
Effect,
3541
ExpirationInfo,
42+
Integration,
3643
LampieNotificationInfo,
3744
LampieNotificationOptionsDict,
3845
LampieSwitchInfo,
@@ -48,12 +55,19 @@
4855
from .coordinator import LampieUpdateCoordinator
4956

5057
type ZHAEventData = dict[str, Any]
58+
type MQTTDeviceName = str
5159

5260
_LOGGER = logging.getLogger(__name__)
5361

62+
MQTT_DOMAIN: Final = "mqtt"
5463
ZHA_DOMAIN: Final = "zha"
5564
ALREADY_EXPIRED: Final = 0
5665

66+
SWITCH_INTEGRATIONS = {
67+
ZHA_DOMAIN: Integration.ZHA,
68+
MQTT_DOMAIN: Integration.Z2M,
69+
}
70+
5771
FIRMWARE_SECONDS_MAX = dt.timedelta(seconds=60).total_seconds()
5872
FIRMWARE_MINUTES_MAX = dt.timedelta(minutes=60).total_seconds()
5973
FIRMWARE_HOURS_MAX = dt.timedelta(hours=134).total_seconds()
@@ -86,12 +100,26 @@
86100
"button_6_double",
87101
}
88102

103+
Z2M_COMMAND_MAP = {
104+
"config_double": "button_3_double",
105+
}
106+
89107

90108
class _LEDMode(IntEnum):
91109
ALL = 1
92110
INDIVIDUAL = 3
93111

94112

113+
class _SwitchKeyType(Enum):
114+
DEVICE_ID = auto()
115+
MQTT_NAME = auto()
116+
117+
118+
class _SwitchKey[T](NamedTuple):
119+
type: _SwitchKeyType
120+
identifier: T
121+
122+
95123
class _StartScriptResult(NamedTuple):
96124
led_config: tuple[LEDConfig, ...] | None
97125
block_activation: bool
@@ -102,6 +130,10 @@ class _EndScriptResult(NamedTuple):
102130
block_next: bool
103131

104132

133+
class _UnknownIntegrationError(Exception):
134+
pass
135+
136+
105137
class _LampieUnmanagedSwitchCoordinator:
106138
def async_update_listeners(self) -> None:
107139
pass
@@ -130,7 +162,8 @@ def __init__(self, hass: HomeAssistant) -> None:
130162
self._coordinators: dict[Slug, LampieUpdateCoordinator] = {}
131163
self._notifications: dict[Slug, LampieNotificationInfo] = {}
132164
self._switches: dict[SwitchId, LampieSwitchInfo] = {}
133-
self._device_switches: dict[DeviceId, SwitchId] = {}
165+
self._switch_ids: dict[_SwitchKey[DeviceId | MQTTDeviceName], SwitchId] = {}
166+
self._cancel_mqtt_listener: CALLBACK_TYPE | None = None
134167
self._cancel_zha_listener: CALLBACK_TYPE = hass.bus.async_listen(
135168
"zha_event",
136169
self._handle_zha_event,
@@ -145,9 +178,20 @@ def remove_coordinator(self, coordinator: LampieUpdateCoordinator) -> None:
145178
self._coordinators.pop(coordinator.slug)
146179
self._update_references()
147180

181+
async def setup(self) -> None:
182+
self._cancel_mqtt_listener = await mqtt.async_subscribe(
183+
self._hass,
184+
"zigbee2mqtt/+",
185+
self._handle_z2m_message,
186+
)
187+
148188
def teardown(self) -> bool:
149189
if len(self._coordinators) == 0:
150190
self._cancel_zha_listener()
191+
192+
if self._cancel_mqtt_listener:
193+
self._cancel_mqtt_listener()
194+
151195
for key, expiration in (
152196
*((key, info.expiration) for key, info in self._notifications.items()),
153197
*((key, info.expiration) for key, info in self._switches.items()),
@@ -159,11 +203,17 @@ def teardown(self) -> bool:
159203
def switch_info(self, switch_id: SwitchId) -> LampieSwitchInfo:
160204
if switch_id not in self._switches:
161205
entity_registry = er.async_get(self._hass)
206+
device_registry = dr.async_get(self._hass)
162207
device_id = async_entity_id_to_device_id(self._hass, switch_id)
208+
device = device_registry.async_get(device_id)
209+
integration = self._switch_integration(device) if device else None
163210
entity_entries = er.async_entries_for_device(entity_registry, device_id)
164211
local_protetction_id = None
165212
disable_clear_notification_id = None
166213

214+
if not integration:
215+
raise _UnknownIntegrationError
216+
167217
for entity_entry in entity_entries:
168218
if entity_entry.unique_id.endswith("-local_protection"):
169219
local_protetction_id = entity_entry.entity_id
@@ -172,8 +222,20 @@ def switch_info(self, switch_id: SwitchId) -> LampieSwitchInfo:
172222
):
173223
disable_clear_notification_id = entity_entry.entity_id
174224

175-
self._device_switches[device_id] = switch_id
225+
if (
226+
device
227+
and (mqtt_device_name := self._mqtt_device_name(device)) is not None
228+
):
229+
self._switch_ids[
230+
_SwitchKey(_SwitchKeyType.MQTT_NAME, mqtt_device_name)
231+
] = switch_id
232+
233+
self._switch_ids[_SwitchKey(_SwitchKeyType.DEVICE_ID, device_id)] = (
234+
switch_id
235+
)
236+
176237
self._switches[switch_id] = LampieSwitchInfo(
238+
integration=integration,
177239
led_config=(),
178240
led_config_source=LEDConfigSource(None),
179241
local_protetction_id=local_protetction_id,
@@ -712,9 +774,8 @@ async def _issue_switch_commands(
712774
"all" if led_mode == _LEDMode.ALL else ", ".join(updated_leds),
713775
)
714776

715-
@classmethod
716-
def _switch_command_led_params(cls, led: LEDConfig) -> dict[str, Any]:
717-
firmware_duration = cls._firmware_duration(led.duration)
777+
def _switch_command_led_params(self, led: LEDConfig) -> dict[str, Any]:
778+
firmware_duration = self._firmware_duration(led.duration)
718779

719780
return {
720781
"led_color": int(led.color),
@@ -727,17 +788,24 @@ def _switch_command_led_params(cls, led: LEDConfig) -> dict[str, Any]:
727788
else firmware_duration,
728789
}
729790

730-
@classmethod
731-
def _firmware_duration(cls, seconds: int | None) -> int | None:
791+
def _firmware_duration(self, seconds: int | None) -> int | None:
732792
"""Convert a timeframe to a duration supported by the switch firmware.
733793
794+
Note: Any usage of MQTT switches in the system will result in the
795+
suspension of all firmware durations, and this method will always return
796+
`None`. At the moment Zigbee2MQTT does not support the 0x24 command
797+
which is used for `led_effect_complete`, so there is no way to know the
798+
notification ended.
799+
734800
Args:
735801
seconds: The duration as a number of seconds.
736802
737803
Returns:
738804
The duration parameter value (0-255) if it can be handled by the
739805
firmware or None if it cannot be.
740806
"""
807+
if any(key.type == _SwitchKeyType.MQTT_NAME for key in self._switch_ids):
808+
return None
741809
if seconds is None or seconds == ALREADY_EXPIRED:
742810
return None
743811
if seconds <= FIRMWARE_SECONDS_MAX:
@@ -753,17 +821,40 @@ def _filter_zha_events(
753821
self,
754822
event_data: ZHAEventData,
755823
) -> bool:
824+
switch_key = _SwitchKey(_SwitchKeyType.DEVICE_ID, event_data["device_id"])
756825
return (
757-
event_data["device_id"] in self._device_switches
826+
switch_key in self._switch_ids
758827
and event_data["command"] in DISMISSAL_COMMANDS
759828
)
760829

761830
@callback
762831
async def _handle_zha_event(self, event: Event[ZHAEventData]) -> None:
832+
await self._handle_generic_event(
833+
command=event.data["command"],
834+
device_id=event.data["device_id"],
835+
)
836+
837+
@callback
838+
async def _handle_z2m_message(self, message: mqtt.ReceiveMessage) -> None:
839+
command_path_base, device_name = message.topic.split("/", 1)
840+
switch_key = _SwitchKey(_SwitchKeyType.MQTT_NAME, device_name)
841+
device_id = self._switch_ids.get(switch_key)
842+
843+
if command_path_base == "zigbee2mqtt" and device_id is not None:
844+
payload = json_loads_object(message.payload)
845+
action = payload.get("action")
846+
command = Z2M_COMMAND_MAP.get(action)
847+
848+
if command:
849+
await self._handle_generic_event(
850+
command=command,
851+
device_id=device_id,
852+
)
853+
854+
async def _handle_generic_event(self, command: str, device_id: str) -> None:
763855
hass = self._hass
764-
command = event.data["command"]
765-
device_id = event.data["device_id"]
766-
switch_id = self._device_switches[device_id]
856+
switch_key = _SwitchKey(_SwitchKeyType.DEVICE_ID, device_id)
857+
switch_id = self._switch_ids[switch_key]
767858
from_state = self.switch_info(switch_id)
768859
led_config_source = from_state.led_config_source
769860
led_config = [*from_state.led_config]
@@ -1090,8 +1181,17 @@ def _update_references(self) -> None:
10901181
)
10911182

10921183
for switch_id in switch_ids:
1184+
try:
1185+
switch_info = self.switch_info(switch_id)
1186+
except _UnknownIntegrationError:
1187+
_LOGGER.exception(
1188+
"ignoring switch %s: could not to a valid integration",
1189+
switch_id,
1190+
)
1191+
continue
1192+
10931193
priorities = switch_priorities.get(switch_id) or [slug]
1094-
expected = [*self.switch_info(switch_id).priorities]
1194+
expected = [*switch_info.priorities]
10951195

10961196
if switch_id in processed_switches and expected != priorities:
10971197
_LOGGER.warning(
@@ -1109,13 +1209,34 @@ def _update_references(self) -> None:
11091209
continue
11101210

11111211
self._switches[switch_id] = replace(
1112-
self.switch_info(switch_id),
1212+
switch_info,
11131213
priorities=tuple(priorities),
11141214
)
11151215
processed_switches.add(switch_id)
11161216

11171217
processed_slugs.append(slug)
11181218

1219+
@classmethod
1220+
def _switch_integration(cls, device: dr.DeviceEnttry) -> Integration | None:
1221+
id_tuple = next(iter(device.identifiers))
1222+
domain = id_tuple[0]
1223+
return SWITCH_INTEGRATIONS.get(domain)
1224+
1225+
@classmethod
1226+
def _mqtt_device_id(cls, device: dr.DeviceEnttry) -> str | None:
1227+
id_tuple = next(iter(device.identifiers))
1228+
1229+
if id_tuple[0] != MQTT_DOMAIN:
1230+
return None
1231+
1232+
identifier: str = id_tuple[1]
1233+
1234+
return identifier.split("_", 1)[1]
1235+
1236+
@classmethod
1237+
def _mqtt_device_name(cls, device: dr.DeviceEnttry) -> str | None:
1238+
return device.name if cls._mqtt_device_id(device) is not None else None
1239+
11191240

11201241
def _all_clear(led_config: Sequence[LEDConfig]) -> bool:
11211242
return all(item.effect == Effect.CLEAR for item in led_config)

custom_components/lampie/services.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ override:
1616
entity:
1717
multiple: true
1818
filter:
19-
integration: zha
2019
domain:
2120
- light
2221
- fan

custom_components/lampie/types.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -219,6 +219,13 @@ class InvalidColor(Exception):
219219
index: int | None = None
220220

221221

222+
class Integration(StrEnum):
223+
"""Switch integration type."""
224+
225+
ZHA = auto()
226+
Z2M = auto()
227+
228+
222229
@dataclass(frozen=True)
223230
class ExpirationInfo:
224231
"""Storage of expiration info."""
@@ -255,6 +262,7 @@ class LampieSwitchInfo:
255262
disable_clear_notification_id: EntityId | None = None
256263
priorities: tuple[Slug, ...] = field(default_factory=tuple)
257264
expiration: ExpirationInfo = field(default_factory=ExpirationInfo)
265+
integration: Integration = Integration.ZHA
258266

259267

260268
class LampieSwitchOptionsDict(TypedDict):

tests/conftest.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
"""Fixtures for testing."""
22

3+
from collections.abc import Generator
34
import logging
5+
from unittest.mock import AsyncMock, Mock, patch
46

57
from freezegun.api import FrozenDateTimeFactory
68
from homeassistant.core import HomeAssistant
@@ -52,6 +54,19 @@ def auto_enable_custom_integrations(enable_custom_integrations):
5254
return
5355

5456

57+
@pytest.fixture(name="mqtt_subscribe", autouse=True)
58+
def auto_patch_mqtt_async_subscribe() -> Generator[AsyncMock]:
59+
"""Patch mqtt.async_subscribe."""
60+
61+
unsub = Mock()
62+
63+
with patch(
64+
"homeassistant.components.mqtt.async_subscribe", return_value=unsub
65+
) as mock_subscribe:
66+
mock_subscribe._unsub = unsub
67+
yield mock_subscribe
68+
69+
5570
@pytest.fixture
5671
def snapshot(snapshot: SnapshotAssertion) -> SnapshotAssertion:
5772
"""Return snapshot assertion fixture with the Home Assistant extension."""

tests/snapshots/test_diagnostics.ambr

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,6 +93,7 @@
9393
'cancel_listener': None,
9494
'started_at': None,
9595
}),
96+
'integration': 'zha',
9697
'led_config': list([
9798
dict({
9899
'brightness': 100.0,

0 commit comments

Comments
 (0)