From f218dd4ebe03b8a116239a7c785e13235cde6793 Mon Sep 17 00:00:00 2001 From: Allen Porter Date: Thu, 2 Jul 2026 14:17:00 -0700 Subject: [PATCH] refactor: decouple B01 (Q7/Q10) protocol layer from transport layer --- roborock/devices/device_manager.py | 15 +- roborock/devices/rpc/b01_q10_channel.py | 54 +++- roborock/devices/rpc/b01_q7_channel.py | 92 +++++- roborock/devices/traits/b01/q10/__init__.py | 9 +- roborock/devices/traits/b01/q10/command.py | 11 +- roborock/devices/traits/b01/q7/__init__.py | 39 ++- .../devices/traits/b01/q7/clean_summary.py | 14 +- roborock/devices/traits/b01/q7/map.py | 12 +- roborock/devices/traits/b01/q7/map_content.py | 15 +- tests/devices/rpc/test_b01_q10_channel.py | 102 ++++++ tests/devices/rpc/test_b01_q7_channel.py | 295 ++++++++++++++++++ tests/devices/traits/b01/q10/conftest.py | 54 ++++ tests/devices/traits/b01/q10/test_map.py | 39 ++- tests/devices/traits/b01/q10/test_remote.py | 27 +- tests/devices/traits/b01/q10/test_settings.py | 48 +-- tests/devices/traits/b01/q10/test_status.py | 73 ++--- tests/devices/traits/b01/q10/test_vacuum.py | 35 +-- tests/devices/traits/b01/q7/__init__.py | 52 +-- tests/devices/traits/b01/q7/conftest.py | 64 ++-- .../traits/b01/q7/test_clean_summary.py | 21 +- tests/devices/traits/b01/q7/test_init.py | 250 +++++---------- tests/devices/traits/b01/q7/test_map.py | 41 +-- .../devices/traits/b01/q7/test_map_content.py | 72 ++--- 23 files changed, 906 insertions(+), 528 deletions(-) create mode 100644 tests/devices/rpc/test_b01_q10_channel.py create mode 100644 tests/devices/rpc/test_b01_q7_channel.py create mode 100644 tests/devices/traits/b01/q10/conftest.py diff --git a/roborock/devices/device_manager.py b/roborock/devices/device_manager.py index 5f82bfca..fe497b71 100644 --- a/roborock/devices/device_manager.py +++ b/roborock/devices/device_manager.py @@ -26,6 +26,8 @@ from roborock.web_api import RoborockApiClient, UserWebApiClient from .cache import Cache, DeviceCache, NoCache +from .rpc.b01_q7_channel import create_b01_q7_channel +from .rpc.b01_q10_channel import create_b01_q10_channel from .rpc.v1_channel import create_v1_channel from .traits import Trait, a01, b01, v1 from .transport.channel import Channel @@ -254,13 +256,22 @@ def device_creator(home_data: HomeData, device: HomeDataDevice, product: HomeDat channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device) trait = a01.create(product, channel) case DeviceVersion.B01: - channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device) + mqtt_channel = create_mqtt_channel(user_data, mqtt_params, mqtt_session, device) model_part = product.model.split(".")[-1] if "ss" in model_part: + b01_q10_channel = create_b01_q10_channel(mqtt_channel) + channel = b01_q10_channel trait = b01.q10.create(channel) elif "sc" in model_part: # Q7 devices start with 'sc' in their model naming. - trait = b01.q7.create(product, device, channel) + b01_q7_channel = create_b01_q7_channel(device, product, mqtt_channel) + channel = b01_q7_channel + trait = b01.q7.create( + product, + device, + rpc_channel=b01_q7_channel, + map_rpc_channel=b01_q7_channel, + ) else: raise UnsupportedDeviceError(f"Device {device.name} has unsupported B01 model: {product.model}") case _: diff --git a/roborock/devices/rpc/b01_q10_channel.py b/roborock/devices/rpc/b01_q10_channel.py index 49758326..a8062da9 100644 --- a/roborock/devices/rpc/b01_q10_channel.py +++ b/roborock/devices/rpc/b01_q10_channel.py @@ -1,9 +1,13 @@ """Thin wrapper around the MQTT channel for Roborock B01 Q10 devices.""" +from __future__ import annotations + import logging -from collections.abc import AsyncGenerator +from collections.abc import AsyncGenerator, Callable +from typing import Protocol from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.devices.transport.channel import Channel from roborock.devices.transport.mqtt_channel import MqttChannel from roborock.exceptions import RoborockException from roborock.protocols.b01_q10_protocol import ( @@ -12,10 +16,53 @@ decode_message, encode_mqtt_payload, ) +from roborock.roborock_message import RoborockMessage _LOGGER = logging.getLogger(__name__) +class Q10RpcChannel(Protocol): + """Protocol for Q10 RPC channels.""" + + async def send_command( + self, + command: B01_Q10_DP, + params: ParamsType = None, + ) -> None: + """Send a command on the MQTT channel, without waiting for a response.""" + ... + + +class B01Q10Channel(Channel, Q10RpcChannel): + """Unified B01 Q10 channel wrapping MQTT transport.""" + + def __init__(self, mqtt_channel: MqttChannel) -> None: + self._mqtt_channel = mqtt_channel + + @property + def is_connected(self) -> bool: + return self._mqtt_channel.is_connected + + @property + def is_local_connected(self) -> bool: + return False + + async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: + return await self._mqtt_channel.subscribe(callback) + + async def subscribe_stream(self) -> AsyncGenerator[Q10Message, None]: + """Stream decoded Q10 messages received via MQTT.""" + async for msg in stream_decoded_messages(self._mqtt_channel): + yield msg + + async def send_command( + self, + command: B01_Q10_DP, + params: ParamsType = None, + ) -> None: + await send_command(self._mqtt_channel, command, params) + + async def stream_decoded_messages( mqtt_channel: MqttChannel, ) -> AsyncGenerator[Q10Message, None]: @@ -59,3 +106,8 @@ async def send_command( ex, ) raise + + +def create_b01_q10_channel(mqtt_channel: MqttChannel) -> B01Q10Channel: + """Create a B01Q10Channel wrapping MQTT transport.""" + return B01Q10Channel(mqtt_channel) diff --git a/roborock/devices/rpc/b01_q7_channel.py b/roborock/devices/rpc/b01_q7_channel.py index b9926e3e..b53f706e 100644 --- a/roborock/devices/rpc/b01_q7_channel.py +++ b/roborock/devices/rpc/b01_q7_channel.py @@ -6,14 +6,20 @@ import json import logging from collections.abc import Callable -from typing import TypeAlias, TypeVar +from typing import Protocol, TypeAlias, TypeVar +from roborock.data import HomeDataDevice, HomeDataProduct +from roborock.devices.transport.channel import Channel from roborock.devices.transport.mqtt_channel import MqttChannel from roborock.exceptions import RoborockException from roborock.protocols.b01_q7_protocol import ( + B01_Q7_DPS, B01_VERSION, + CommandType, MapKey, + ParamsType, Q7RequestMessage, + create_map_key, decode_map_payload, decode_rpc_response, encode_mqtt_payload, @@ -26,6 +32,30 @@ DecodedB01Response: TypeAlias = dict[str, object] | str +class Q7RpcChannel(Protocol): + """Protocol for Q7 RPC channels.""" + + async def send_command( + self, + command: CommandType, + params: ParamsType = None, + ) -> DecodedB01Response: + """Send a command and get a decoded response.""" + ... + + +class Q7MapRpcChannel(Protocol): + """Protocol for Q7 map RPC channels.""" + + async def send_map_command( + self, + command: CommandType, + params: ParamsType = None, + ) -> bytes: + """Send a map command and get decoded bytes.""" + ... + + def _matches_map_response(response_message: RoborockMessage, *, version: bytes | None) -> bytes | None: """Return raw map payload bytes for matching MAP_RESPONSE messages.""" if ( @@ -120,39 +150,55 @@ def find_response(response_message: RoborockMessage) -> DecodedB01Response | Non raise RoborockException(f"B01 command timed out after {_TIMEOUT}s ({request_message})") from ex except RoborockException as ex: _LOGGER.warning( - "Error sending B01 decoded command (%ss): %s", + "Error sending B01 decoded command (%s): %s", request_message, ex, ) raise except Exception as ex: _LOGGER.exception( - "Error sending B01 decoded command (%ss): %s", + "Error sending B01 decoded command (%s): %s", request_message, ex, ) raise -class MapRpcChannel: - """RPC channel for map-related commands on B01/Q7 devices.""" +class B01Q7Channel(Channel, Q7RpcChannel, Q7MapRpcChannel): + """Unified B01 Q7 channel wrapping MQTT transport.""" def __init__(self, mqtt_channel: MqttChannel, map_key: MapKey) -> None: self._mqtt_channel = mqtt_channel self._map_key = map_key - async def send_map_command(self, request_message: Q7RequestMessage) -> bytes: - """Send a map upload command and return decoded SCMap bytes. - - This publishes the request and waits for a matching ``MAP_RESPONSE`` message - with the correct protocol version. The raw ``MAP_RESPONSE`` payload bytes are - then decoded/inflated via :func:`decode_map_payload` using this channel's - ``map_key``, and the resulting SCMap bytes are returned. - - The returned value is the decoded map data bytes suitable for passing to the - map parser library, not the raw MQTT ``MAP_RESPONSE`` payload bytes. - """ + @property + def is_connected(self) -> bool: + return self._mqtt_channel.is_connected + + @property + def is_local_connected(self) -> bool: + return False + + async def subscribe(self, callback: Callable[[RoborockMessage], None]) -> Callable[[], None]: + return await self._mqtt_channel.subscribe(callback) + + async def send_command( + self, + command: CommandType, + params: ParamsType = None, + ) -> DecodedB01Response: + return await send_decoded_command( + self._mqtt_channel, + Q7RequestMessage(dps=B01_Q7_DPS, command=command, params=params), + ) + async def send_map_command( + self, + command: CommandType, + params: ParamsType = None, + ) -> bytes: + """Send a map upload command and return decoded SCMap bytes.""" + request_message = Q7RequestMessage(dps=B01_Q7_DPS, command=command, params=params) try: raw_payload = await _send_command( self._mqtt_channel, @@ -163,3 +209,17 @@ async def send_map_command(self, request_message: Q7RequestMessage) -> bytes: raise RoborockException(f"B01 map command timed out after {_TIMEOUT}s ({request_message})") from ex return decode_map_payload(raw_payload, map_key=self._map_key) + + +def create_b01_q7_channel( + device: HomeDataDevice, + product: HomeDataProduct, + mqtt_channel: MqttChannel, +) -> B01Q7Channel: + """Create a B01Q7Channel for the given device.""" + if device.sn is None or product.model is None: + raise RoborockException( + f"Device serial number and product model are required (sn: {device.sn}, model: {product.model})" + ) + map_key = create_map_key(serial=device.sn, model=product.model) + return B01Q7Channel(mqtt_channel, map_key) diff --git a/roborock/devices/traits/b01/q10/__init__.py b/roborock/devices/traits/b01/q10/__init__.py index 8f4a5202..452dd551 100644 --- a/roborock/devices/traits/b01/q10/__init__.py +++ b/roborock/devices/traits/b01/q10/__init__.py @@ -4,9 +4,8 @@ import logging from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP -from roborock.devices.rpc.b01_q10_channel import stream_decoded_messages +from roborock.devices.rpc.b01_q10_channel import B01Q10Channel from roborock.devices.traits import Trait -from roborock.devices.transport.mqtt_channel import MqttChannel from roborock.map.b01_q10_map_parser import Q10MapPacket, Q10TracePacket from roborock.protocols.b01_q10_protocol import Q10DpsUpdate, Q10Message @@ -78,7 +77,7 @@ class Q10PropertiesApi(Trait): map: MapContentTrait """Trait for fetching the current parsed map (image + rooms).""" - def __init__(self, channel: MqttChannel) -> None: + def __init__(self, channel: B01Q10Channel) -> None: """Initialize the B01Props API.""" self._channel = channel self.command = CommandTrait(channel) @@ -127,7 +126,7 @@ async def refresh(self) -> None: async def _subscribe_loop(self) -> None: """Persistent loop dispatching decoded messages to the read-model traits.""" - async for message in stream_decoded_messages(self._channel): + async for message in self._channel.subscribe_stream(): self._handle_message(message) def _handle_message(self, message: Q10Message) -> None: @@ -150,6 +149,6 @@ def _handle_message(self, message: Q10Message) -> None: trait.update_from_dps(message.dps) -def create(channel: MqttChannel) -> Q10PropertiesApi: +def create(channel: B01Q10Channel) -> Q10PropertiesApi: """Create traits for B01 devices.""" return Q10PropertiesApi(channel) diff --git a/roborock/devices/traits/b01/q10/command.py b/roborock/devices/traits/b01/q10/command.py index 4aa3a593..b8f06445 100644 --- a/roborock/devices/traits/b01/q10/command.py +++ b/roborock/devices/traits/b01/q10/command.py @@ -1,8 +1,7 @@ from typing import Any from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP -from roborock.devices.rpc.b01_q10_channel import send_command -from roborock.devices.transport.mqtt_channel import MqttChannel +from roborock.devices.rpc.b01_q10_channel import Q10RpcChannel from roborock.protocols.b01_q10_protocol import ParamsType @@ -15,9 +14,9 @@ class CommandTrait: available. """ - def __init__(self, channel: MqttChannel) -> None: + def __init__(self, rpc_channel: Q10RpcChannel) -> None: """Initialize the CommandTrait.""" - self._channel = channel + self._rpc_channel = rpc_channel async def send(self, command: B01_Q10_DP, params: ParamsType = None) -> Any: """Send a command to the device. @@ -27,6 +26,6 @@ async def send(self, command: B01_Q10_DP, params: ParamsType = None) -> Any: caller to ensure that any traits affected by the command are refreshed as needed. """ - if not self._channel: + if not self._rpc_channel: raise ValueError("Device trait in invalid state") - return await send_command(self._channel, command, params=params) + return await self._rpc_channel.send_command(command, params=params) diff --git a/roborock/devices/traits/b01/q7/__init__.py b/roborock/devices/traits/b01/q7/__init__.py index a48dc344..325c84e1 100644 --- a/roborock/devices/traits/b01/q7/__init__.py +++ b/roborock/devices/traits/b01/q7/__init__.py @@ -18,11 +18,10 @@ SCWindMapping, WaterLevelMapping, ) -from roborock.devices.rpc.b01_q7_channel import MapRpcChannel, send_decoded_command +from roborock.devices.rpc.b01_q7_channel import Q7MapRpcChannel, Q7RpcChannel from roborock.devices.traits import Trait -from roborock.devices.transport.mqtt_channel import MqttChannel from roborock.exceptions import RoborockException -from roborock.protocols.b01_q7_protocol import B01_Q7_DPS, CommandType, ParamsType, Q7RequestMessage, create_map_key +from roborock.protocols.b01_q7_protocol import CommandType, ParamsType from roborock.roborock_message import RoborockB01Props from roborock.roborock_typing import RoborockB01Q7Methods @@ -53,10 +52,14 @@ class Q7PropertiesApi(Trait): """Trait for fetching parsed current map content.""" def __init__( - self, channel: MqttChannel, map_rpc_channel: MapRpcChannel, device: HomeDataDevice, product: HomeDataProduct + self, + rpc_channel: Q7RpcChannel, + map_rpc_channel: Q7MapRpcChannel, + device: HomeDataDevice, + product: HomeDataProduct, ) -> None: """Initialize the Q7 API.""" - self._channel = channel + self._rpc_channel = rpc_channel self._map_rpc_channel = map_rpc_channel self._device = device self._product = product @@ -64,8 +67,8 @@ def __init__( if not device.sn or not product.model: raise ValueError("B01 Q7 map content requires device serial number and product model metadata") - self.clean_summary = CleanSummaryTrait(channel) - self.map = MapTrait(channel) + self.clean_summary = CleanSummaryTrait(rpc_channel) + self.map = MapTrait(rpc_channel) self.map_content = MapContentTrait( self._map_rpc_channel, self.map, @@ -199,17 +202,23 @@ async def find_me(self) -> None: async def send(self, command: CommandType, params: ParamsType) -> Any: """Send a command to the device.""" - return await send_decoded_command( - self._channel, - Q7RequestMessage(dps=B01_Q7_DPS, command=command, params=params), - ) + return await self._rpc_channel.send_command(command, params) -def create(product: HomeDataProduct, device: HomeDataDevice, channel: MqttChannel) -> Q7PropertiesApi: +def create( + product: HomeDataProduct, + device: HomeDataDevice, + rpc_channel: Q7RpcChannel, + map_rpc_channel: Q7MapRpcChannel, +) -> Q7PropertiesApi: """Create traits for B01 Q7 devices.""" if device.sn is None or product.model is None: raise RoborockException( - f"Device serial number and product model are required (sn:: {device.sn}, model: {product.model})" + f"Device serial number and product model are required (sn: {device.sn}, model: {product.model})" ) - map_rpc_channel = MapRpcChannel(channel, map_key=create_map_key(serial=device.sn, model=product.model)) - return Q7PropertiesApi(channel, device=device, product=product, map_rpc_channel=map_rpc_channel) + return Q7PropertiesApi( + rpc_channel=rpc_channel, + map_rpc_channel=map_rpc_channel, + device=device, + product=product, + ) diff --git a/roborock/devices/traits/b01/q7/clean_summary.py b/roborock/devices/traits/b01/q7/clean_summary.py index 65fea0e8..c2b62cfa 100644 --- a/roborock/devices/traits/b01/q7/clean_summary.py +++ b/roborock/devices/traits/b01/q7/clean_summary.py @@ -7,11 +7,9 @@ import logging from roborock import CleanRecordDetail, CleanRecordList, CleanRecordSummary -from roborock.devices.rpc.b01_q7_channel import send_decoded_command +from roborock.devices.rpc.b01_q7_channel import Q7RpcChannel from roborock.devices.traits import Trait -from roborock.devices.transport.mqtt_channel import MqttChannel from roborock.exceptions import RoborockException -from roborock.protocols.b01_q7_protocol import B01_Q7_DPS, Q7RequestMessage from roborock.roborock_typing import RoborockB01Q7Methods __all__ = [ @@ -24,11 +22,11 @@ class CleanSummaryTrait(CleanRecordSummary, Trait): """B01/Q7 clean summary + clean record access (via record list service).""" - def __init__(self, channel: MqttChannel) -> None: + def __init__(self, channel: Q7RpcChannel) -> None: """Initialize the clean summary trait. Args: - channel: MQTT channel used to communicate with the device. + channel: RPC channel used to communicate with the device. """ super().__init__() self._channel = channel @@ -46,9 +44,9 @@ async def refresh(self) -> None: async def _get_record_list(self) -> CleanRecordList: """Fetch the raw device clean record list (`service.get_record_list`).""" - result = await send_decoded_command( - self._channel, - Q7RequestMessage(dps=B01_Q7_DPS, command=RoborockB01Q7Methods.GET_RECORD_LIST, params={}), + result = await self._channel.send_command( + command=RoborockB01Q7Methods.GET_RECORD_LIST, + params={}, ) if not isinstance(result, dict): diff --git a/roborock/devices/traits/b01/q7/map.py b/roborock/devices/traits/b01/q7/map.py index f367e407..50bea18f 100644 --- a/roborock/devices/traits/b01/q7/map.py +++ b/roborock/devices/traits/b01/q7/map.py @@ -1,11 +1,9 @@ """Map trait for B01 Q7 devices.""" from roborock.data import Q7MapList -from roborock.devices.rpc.b01_q7_channel import send_decoded_command +from roborock.devices.rpc.b01_q7_channel import Q7RpcChannel from roborock.devices.traits import Trait -from roborock.devices.transport.mqtt_channel import MqttChannel from roborock.exceptions import RoborockException -from roborock.protocols.b01_q7_protocol import B01_Q7_DPS, Q7RequestMessage from roborock.roborock_typing import RoborockB01Q7Methods @@ -16,15 +14,15 @@ class MapTrait(Q7MapList, Trait): current map ID to fetch. """ - def __init__(self, channel: MqttChannel) -> None: + def __init__(self, channel: Q7RpcChannel) -> None: super().__init__() self._channel = channel async def refresh(self) -> None: """Refresh cached map list metadata from the device.""" - response = await send_decoded_command( - self._channel, - Q7RequestMessage(dps=B01_Q7_DPS, command=RoborockB01Q7Methods.GET_MAP_LIST, params={}), + response = await self._channel.send_command( + command=RoborockB01Q7Methods.GET_MAP_LIST, + params={}, ) if not isinstance(response, dict): raise RoborockException( diff --git a/roborock/devices/traits/b01/q7/map_content.py b/roborock/devices/traits/b01/q7/map_content.py index afc36d02..0becf91a 100644 --- a/roborock/devices/traits/b01/q7/map_content.py +++ b/roborock/devices/traits/b01/q7/map_content.py @@ -14,11 +14,10 @@ from vacuum_map_parser_base.map_data import MapData from roborock.data import RoborockBase -from roborock.devices.rpc.b01_q7_channel import MapRpcChannel +from roborock.devices.rpc.b01_q7_channel import Q7MapRpcChannel from roborock.devices.traits import Trait from roborock.exceptions import RoborockException from roborock.map.b01_map_parser import B01MapParser, B01MapParserConfig -from roborock.protocols.b01_q7_protocol import B01_Q7_DPS, Q7RequestMessage from roborock.roborock_typing import RoborockB01Q7Methods from .map import MapTrait @@ -55,7 +54,7 @@ class MapContentTrait(MapContent, Trait): def __init__( self, - map_rpc_channel: MapRpcChannel, + map_rpc_channel: Q7MapRpcChannel, map_trait: MapTrait, *, map_parser_config: B01MapParserConfig | None = None, @@ -77,13 +76,11 @@ async def refresh(self) -> None: if (map_id := self._map_trait.current_map_id) is None: raise RoborockException("Unable to determine current map ID") - request = Q7RequestMessage( - dps=B01_Q7_DPS, - command=RoborockB01Q7Methods.UPLOAD_BY_MAPID, - params={"map_id": map_id}, - ) async with self._map_command_lock: - raw_payload = await self._map_rpc_channel.send_map_command(request) + raw_payload = await self._map_rpc_channel.send_map_command( + RoborockB01Q7Methods.UPLOAD_BY_MAPID, + {"map_id": map_id}, + ) try: parsed_data = self._map_parser.parse(raw_payload) diff --git a/tests/devices/rpc/test_b01_q10_channel.py b/tests/devices/rpc/test_b01_q10_channel.py new file mode 100644 index 00000000..0ffcc71f --- /dev/null +++ b/tests/devices/rpc/test_b01_q10_channel.py @@ -0,0 +1,102 @@ +import json +from collections.abc import AsyncGenerator + +import pytest + +from roborock.data.b01_q10.b01_q10_code_mappings import B01_Q10_DP +from roborock.devices.rpc.b01_q10_channel import create_b01_q10_channel +from roborock.exceptions import RoborockException +from roborock.protocols.b01_q10_protocol import Q10DpsUpdate +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from tests.fixtures.channel_fixtures import FakeChannel + + +@pytest.fixture(name="fake_channel") +def fake_channel_fixture() -> FakeChannel: + return FakeChannel() + + +async def test_create_b01_q10_channel(fake_channel: FakeChannel) -> None: + channel = create_b01_q10_channel(fake_channel) # type: ignore[arg-type] + assert channel is not None + assert channel.is_connected is False + assert channel.is_local_connected is False + + +async def test_q10_channel_send_command(fake_channel: FakeChannel) -> None: + channel = create_b01_q10_channel(fake_channel) # type: ignore[arg-type] + await channel.send_command(B01_Q10_DP.VOLUME, 50) + + assert len(fake_channel.published_messages) == 1 + message = fake_channel.published_messages[0] + assert message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert message.payload is not None + payload_data = json.loads(message.payload.decode()) + assert payload_data == {"dps": {"26": 50}} + + +async def test_q10_channel_send_command_error(fake_channel: FakeChannel) -> None: + channel = create_b01_q10_channel(fake_channel) # type: ignore[arg-type] + fake_channel.publish_side_effect = RoborockException("Publish error") + + with pytest.raises(RoborockException, match="Publish error"): + await channel.send_command(B01_Q10_DP.VOLUME, 50) + + +async def test_q10_channel_subscribe_stream(fake_channel: FakeChannel) -> None: + channel = create_b01_q10_channel(fake_channel) # type: ignore[arg-type] + + async def simulate_messages() -> AsyncGenerator[RoborockMessage, None]: + dps_payload = {"dps": {"26": 30}} + # Valid message + yield RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=json.dumps(dps_payload).encode(), + version=b"B01", + ) + # Invalid message (causes decode failure/RoborockException) + yield RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=b"invalid-json{", + version=b"B01", + ) + # JSON but no dps key (returns None from decoder) + yield RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=b'{"not_dps": 1}', + version=b"B01", + ) + # Another valid message + dps_payload_2 = {"dps": {"26": 40}} + yield RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=json.dumps(dps_payload_2).encode(), + version=b"B01", + ) + + # Patch fake_channel.subscribe_stream + setattr(fake_channel, "subscribe_stream", MockStream(simulate_messages())) + + messages = [] + async for msg in channel.subscribe_stream(): + messages.append(msg) + + assert len(messages) == 2 + assert isinstance(messages[0], Q10DpsUpdate) + assert messages[0].dps[B01_Q10_DP.VOLUME] == 30 + assert isinstance(messages[1], Q10DpsUpdate) + assert messages[1].dps[B01_Q10_DP.VOLUME] == 40 + + +class MockStream: + def __init__(self, generator: AsyncGenerator[RoborockMessage, None]) -> None: + self._generator = generator + + def __call__(self, *args, **kwargs): + return self + + def __aiter__(self): + return self + + async def __anext__(self): + return await self._generator.__anext__() diff --git a/tests/devices/rpc/test_b01_q7_channel.py b/tests/devices/rpc/test_b01_q7_channel.py new file mode 100644 index 00000000..21c6d31a --- /dev/null +++ b/tests/devices/rpc/test_b01_q7_channel.py @@ -0,0 +1,295 @@ +import asyncio +import json +import math +import time +from collections.abc import Generator +from typing import Any, cast +from unittest.mock import patch + +import pytest +from Crypto.Cipher import AES +from Crypto.Util.Padding import pad, unpad + +from roborock.data import HomeDataDevice, HomeDataProduct, RoborockCategory +from roborock.devices.rpc.b01_q7_channel import ( + create_b01_q7_channel, + send_decoded_command, +) +from roborock.exceptions import RoborockException +from roborock.protocols.b01_q7_protocol import B01_VERSION, Q7RequestMessage +from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from tests.fixtures.channel_fixtures import FakeChannel + + +class B01MessageBuilder: + """Helper class to build B01 RPC response messages for tests.""" + + def __init__(self) -> None: + self.msg_id = 123456789 + self.seq = 2020 + + def build(self, data: dict[str, Any] | str, code: int | None = None) -> RoborockMessage: + """Build an encoded B01 RPC response message.""" + message: dict[str, Any] = { + "msgId": str(self.msg_id), + "data": data, + } + if code is not None: + message["code"] = code + return self._build_dps(message) + + def _build_dps(self, message: dict[str, Any] | str) -> RoborockMessage: + """Build an encoded B01 RPC response message.""" + dps_payload = {"dps": {"10000": json.dumps(message)}} + self.seq += 1 + return RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=pad( + json.dumps(dps_payload).encode(), + AES.block_size, + ), + version=b"B01", + seq=self.seq, + ) + + def build_map_response(self, payload: bytes) -> RoborockMessage: + """Build a dummy MAP_RESPONSE message.""" + self.seq += 1 + return RoborockMessage( + protocol=RoborockMessageProtocol.MAP_RESPONSE, + payload=payload, + version=b"B01", + seq=self.seq, + ) + + +@pytest.fixture(name="fake_channel") +def fake_channel_fixture() -> FakeChannel: + return FakeChannel() + + +@pytest.fixture(name="product") +def product_fixture() -> HomeDataProduct: + return HomeDataProduct( + id="product-id-q7", + name="Roborock Q7", + model="roborock.vacuum.sc05", + category=RoborockCategory.VACUUM, + ) + + +@pytest.fixture(name="device") +def device_fixture() -> HomeDataDevice: + return HomeDataDevice( + duid="abc123", + name="Q7", + local_key="key123key123key1", + product_id="product-id-q7", + sn="testsn012345", + ) + + +@pytest.fixture(name="expected_msg_id", autouse=True) +def next_message_id_fixture() -> Generator[int, None, None]: + expected_msg_id = math.floor(time.time()) + with patch("roborock.protocols.b01_q7_protocol.get_next_int", return_value=expected_msg_id): + yield expected_msg_id + + +@pytest.fixture(name="message_builder") +def message_builder_fixture(expected_msg_id: int) -> B01MessageBuilder: + builder = B01MessageBuilder() + builder.msg_id = expected_msg_id + return builder + + +async def test_create_b01_q7_channel( + device: HomeDataDevice, product: HomeDataProduct, fake_channel: FakeChannel +) -> None: + channel = create_b01_q7_channel(device, product, fake_channel) # type: ignore[arg-type] + assert channel is not None + assert channel.is_connected is False + assert channel.is_local_connected is False + + +async def test_create_b01_q7_channel_missing_metadata( + device: HomeDataDevice, + product: HomeDataProduct, + fake_channel: FakeChannel, +) -> None: + """Test creating Q7 channel without required metadata raises RoborockException.""" + device.sn = None + with pytest.raises(RoborockException, match="Device serial number and product model are required"): + create_b01_q7_channel(device, product, fake_channel) # type: ignore[arg-type] + + +async def test_q7_channel_send_command( + device: HomeDataDevice, + product: HomeDataProduct, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, +) -> None: + channel = create_b01_q7_channel(device, product, fake_channel) # type: ignore[arg-type] + fake_channel.response_queue.append(message_builder.build({"status": 1})) + + result = await channel.send_command("prop.get", {"property": ["status"]}) + assert result == {"status": 1} + + assert len(fake_channel.published_messages) == 1 + message = fake_channel.published_messages[0] + assert message.protocol == RoborockMessageProtocol.RPC_REQUEST + assert message.version == B01_VERSION + + assert message.payload is not None + payload_data = json.loads(unpad(message.payload, AES.block_size)) + assert payload_data["dps"]["10000"]["method"] == "prop.get" + + +async def test_q7_channel_send_map_command( + device: HomeDataDevice, + product: HomeDataProduct, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, +) -> None: + channel = create_b01_q7_channel(device, product, fake_channel) # type: ignore[arg-type] + unrelated_msg = message_builder.build({"status": 1}) + fake_channel.response_queue.append(unrelated_msg) + + with patch( + "roborock.devices.rpc.b01_q7_channel.decode_map_payload", + return_value=b"inflated-payload", + ) as mock_decode: + task = asyncio.create_task(channel.send_map_command("service.upload_by_mapid", {"map_id": 123})) + await asyncio.sleep(0) + + fake_channel.notify_subscribers(message_builder.build_map_response(b"raw-map-payload")) + result = await task + assert result == b"inflated-payload" + mock_decode.assert_called_once() + + +async def test_send_decoded_command_non_dict_response(fake_channel: FakeChannel, message_builder: B01MessageBuilder): + """Test validity of handling non-dict responses.""" + message = message_builder.build("some_string_error") + fake_channel.response_queue.append(message) + + with pytest.raises(RoborockException, match="Unexpected data type for response"): + await send_decoded_command(fake_channel, Q7RequestMessage(dps=10000, command="prop.get", params=[])) # type: ignore[arg-type] + + +async def test_send_decoded_command_error_code(fake_channel: FakeChannel, message_builder: B01MessageBuilder): + """Test that non-zero error codes from device are properly handled.""" + message = message_builder.build({}, code=5001) + fake_channel.response_queue.append(message) + + with pytest.raises(RoborockException, match="B01 command failed with code 5001"): + await send_decoded_command(fake_channel, Q7RequestMessage(dps=10000, command="prop.get", params=[])) # type: ignore[arg-type] + + +async def test_send_decoded_command_allows_ok_string_ack(fake_channel: FakeChannel, message_builder: B01MessageBuilder): + """Command ACKs may return plain string payloads like ``ok``.""" + message = message_builder.build("ok") + fake_channel.response_queue.append(message) + + result = await send_decoded_command( + cast(Any, fake_channel), + Q7RequestMessage(dps=10000, command="service.set_room_clean", params=[]), # type: ignore[arg-type] + ) + + assert result == "ok" + + +async def test_send_command_timeout( + device: HomeDataDevice, + product: HomeDataProduct, + fake_channel: FakeChannel, +) -> None: + """Test timeout behavior on regular send_command.""" + channel = create_b01_q7_channel(device, product, fake_channel) # type: ignore[arg-type] + + with patch("roborock.devices.rpc.b01_q7_channel._TIMEOUT", 0.01): + with pytest.raises(RoborockException, match="B01 command timed out after"): + await channel.send_command("prop.get", {"property": ["status"]}) + + +async def test_send_map_command_timeout( + device: HomeDataDevice, + product: HomeDataProduct, + fake_channel: FakeChannel, +) -> None: + """Test timeout behavior on send_map_command.""" + channel = create_b01_q7_channel(device, product, fake_channel) # type: ignore[arg-type] + + with patch("roborock.devices.rpc.b01_q7_channel._TIMEOUT", 0.01): + with pytest.raises(RoborockException, match="B01 map command timed out after"): + await channel.send_map_command("service.upload_by_mapid", {"map_id": 123}) + + +@pytest.mark.parametrize( + "bad_payload", + [ + # Undecryptable/unpad-failing message + b"short", + # Non-string dps value + pad(b'{"dps": {"10000": 123}}', AES.block_size), + # Invalid JSON string in dps value + pad(b'{"dps": {"10000": "invalid-json{"}}', AES.block_size), + # Message with incorrect msgId + pad(b'{"dps": {"10000": "{\\"msgId\\": \\"999\\"}"}}', AES.block_size), + ], +) +async def test_send_command_ignores_invalid_response( + device: HomeDataDevice, + product: HomeDataProduct, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, + bad_payload: bytes, +) -> None: + """Test that invalid, unparsable, or wrong-msgId messages are skipped gracefully.""" + channel = create_b01_q7_channel(device, product, fake_channel) # type: ignore[arg-type] + + bad_message = RoborockMessage( + protocol=RoborockMessageProtocol.RPC_RESPONSE, + payload=bad_payload, + version=b"B01", + ) + fake_channel.response_queue.append(bad_message) + + task = asyncio.create_task(channel.send_command("prop.get", {"property": ["status"]})) + await asyncio.sleep(0) + + fake_channel.notify_subscribers(message_builder.build({"status": 1})) + result = await task + assert result == {"status": 1} + + +async def test_send_command_ignores_messages_after_resolve( + device: HomeDataDevice, + product: HomeDataProduct, + fake_channel: FakeChannel, + message_builder: B01MessageBuilder, +) -> None: + """Test that messages arriving after the command has resolved are ignored.""" + channel = create_b01_q7_channel(device, product, fake_channel) # type: ignore[arg-type] + + fake_channel.response_queue.append(message_builder.build({"status": 1})) + + task = asyncio.create_task(channel.send_command("prop.get", {"property": ["status"]})) + await asyncio.sleep(0) + + fake_channel.notify_subscribers(message_builder.build({"status": 1})) + result = await task + assert result == {"status": 1} + + +async def test_send_command_general_exception( + device: HomeDataDevice, + product: HomeDataProduct, + fake_channel: FakeChannel, +) -> None: + """Test that non-RoborockException errors are propagated and logged.""" + channel = create_b01_q7_channel(device, product, fake_channel) # type: ignore[arg-type] + fake_channel.publish_side_effect = RuntimeError("Generic publish crash") + + with pytest.raises(RuntimeError, match="Generic publish crash"): + await channel.send_command("prop.get", {"property": ["status"]}) diff --git a/tests/devices/traits/b01/q10/conftest.py b/tests/devices/traits/b01/q10/conftest.py new file mode 100644 index 00000000..85cf0a02 --- /dev/null +++ b/tests/devices/traits/b01/q10/conftest.py @@ -0,0 +1,54 @@ +from collections.abc import AsyncGenerator +from typing import Any + +import pytest + +from roborock.devices.rpc.b01_q10_channel import B01Q10Channel, Q10RpcChannel +from roborock.devices.traits.b01.q10 import Q10PropertiesApi, create +from roborock.protocols.b01_q10_protocol import Q10Message + + +class FakeQ10RpcChannel(Q10RpcChannel): + """Plaintext mock RPC channel for Q10.""" + + def __init__(self) -> None: + self.published_commands: list[tuple[Any, Any]] = [] + + async def send_command(self, command: Any, params: Any = None) -> None: + self.published_commands.append((command, params)) + + +class FakeB01Q10Channel(B01Q10Channel): + """Plaintext mock transport channel for Q10.""" + + def __init__(self) -> None: + self.published_commands: list[tuple[Any, Any]] = [] + self.messages_to_stream: list[Q10Message] = [] + + @property + def is_connected(self) -> bool: + return True + + @property + def is_local_connected(self) -> bool: + return False + + async def subscribe(self, callback: Any) -> Any: + return lambda: None + + async def subscribe_stream(self) -> AsyncGenerator[Q10Message, None]: + for msg in self.messages_to_stream: + yield msg + + async def send_command(self, command: Any, params: Any = None) -> None: + self.published_commands.append((command, params)) + + +@pytest.fixture(name="fake_channel") +def fake_channel_fixture() -> FakeB01Q10Channel: + return FakeB01Q10Channel() + + +@pytest.fixture(name="q10_api") +def q10_api_fixture(fake_channel: FakeB01Q10Channel) -> Q10PropertiesApi: + return create(fake_channel) diff --git a/tests/devices/traits/b01/q10/test_map.py b/tests/devices/traits/b01/q10/test_map.py index 470f3f99..fb07e585 100644 --- a/tests/devices/traits/b01/q10/test_map.py +++ b/tests/devices/traits/b01/q10/test_map.py @@ -9,7 +9,7 @@ from collections.abc import AsyncGenerator from pathlib import Path from typing import cast -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import pytest @@ -17,18 +17,14 @@ from roborock.devices.traits.b01.q10 import Q10PropertiesApi, create from roborock.devices.traits.b01.q10.map import MapContentTrait from roborock.map.b01_q10_map_parser import Q10Point, parse_map_packet, parse_trace_packet -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from roborock.protocols.b01_q10_protocol import Q10Message + +from .conftest import FakeB01Q10Channel FIXTURE = Path("tests/map/testdata/b01_q10_map.bin") TRACE_FIXTURE = Path("tests/map/testdata/b01_q10_trace.bin") -def _map_message( - payload: bytes, protocol: RoborockMessageProtocol = RoborockMessageProtocol.MAP_RESPONSE -) -> RoborockMessage: - return RoborockMessage(protocol=protocol, payload=payload, version=b"B01") - - def test_update_from_map_packet_populates_image_and_rooms() -> None: """A parsed 01 01 map packet populates the image, rooms and map data.""" packet = parse_map_packet(FIXTURE.read_bytes()) @@ -124,24 +120,25 @@ async def test_await_q10_map_push_can_fall_back_to_cached_map_on_timeout() -> No # --- Integration through the Q10PropertiesApi subscribe loop ----------------- -@pytest.fixture -def message_queue() -> asyncio.Queue[RoborockMessage]: +@pytest.fixture(name="message_queue") +def message_queue_fixture() -> asyncio.Queue[Q10Message]: return asyncio.Queue() -@pytest.fixture -def mock_channel(message_queue: asyncio.Queue[RoborockMessage]) -> AsyncMock: - async def mock_stream() -> AsyncGenerator[RoborockMessage, None]: +@pytest.fixture(name="mock_channel") +def mock_channel_fixture(message_queue: asyncio.Queue[Q10Message]) -> FakeB01Q10Channel: + channel = FakeB01Q10Channel() + + async def mock_stream() -> AsyncGenerator[Q10Message, None]: while True: yield await message_queue.get() - channel = AsyncMock() - channel.subscribe_stream = Mock(return_value=mock_stream()) + setattr(channel, "subscribe_stream", Mock(side_effect=mock_stream)) return channel -@pytest.fixture -async def q10_api(mock_channel: AsyncMock) -> AsyncGenerator[Q10PropertiesApi, None]: +@pytest.fixture(name="q10_api") +async def q10_api_fixture(mock_channel: FakeB01Q10Channel) -> AsyncGenerator[Q10PropertiesApi, None]: api = create(mock_channel) await api.start() yield api @@ -156,12 +153,12 @@ async def _wait_for(predicate, timeout: float = 2.0) -> None: async def test_subscribe_loop_routes_map_push( q10_api: Q10PropertiesApi, - message_queue: asyncio.Queue[RoborockMessage], + message_queue: asyncio.Queue[Q10Message], ) -> None: """A map pushed onto the stream is routed to the map trait by the loop.""" assert q10_api.map.image_content is None - message_queue.put_nowait(_map_message(FIXTURE.read_bytes())) + message_queue.put_nowait(parse_map_packet(FIXTURE.read_bytes())) await _wait_for(lambda: q10_api.map.image_content is not None) assert {room.id: room.name for room in q10_api.map.rooms} == {2: "Living Room", 3: "Bedroom"} @@ -169,12 +166,12 @@ async def test_subscribe_loop_routes_map_push( async def test_subscribe_loop_routes_trace_push( q10_api: Q10PropertiesApi, - message_queue: asyncio.Queue[RoborockMessage], + message_queue: asyncio.Queue[Q10Message], ) -> None: """A trace pushed onto the stream is routed to the map trait by the loop.""" assert not q10_api.map.path - message_queue.put_nowait(_map_message(TRACE_FIXTURE.read_bytes())) + message_queue.put_nowait(parse_trace_packet(TRACE_FIXTURE.read_bytes())) await _wait_for(lambda: bool(q10_api.map.path)) assert q10_api.map.robot_position is not None diff --git a/tests/devices/traits/b01/q10/test_remote.py b/tests/devices/traits/b01/q10/test_remote.py index 4de8c986..2eac1403 100644 --- a/tests/devices/traits/b01/q10/test_remote.py +++ b/tests/devices/traits/b01/q10/test_remote.py @@ -1,4 +1,3 @@ -import json from collections.abc import Awaitable, Callable from typing import Any @@ -6,17 +5,8 @@ from roborock.devices.traits.b01.q10 import Q10PropertiesApi from roborock.devices.traits.b01.q10.remote import RemoteTrait -from tests.fixtures.channel_fixtures import FakeChannel - -@pytest.fixture(name="fake_channel") -def fake_channel_fixture() -> FakeChannel: - return FakeChannel() - - -@pytest.fixture(name="q10_api") -def q10_api_fixture(fake_channel: FakeChannel) -> Q10PropertiesApi: - return Q10PropertiesApi(fake_channel) # type: ignore[arg-type] +from .conftest import FakeB01Q10Channel @pytest.fixture(name="remote") @@ -36,15 +26,18 @@ def remote_fixture(q10_api: Q10PropertiesApi) -> RemoteTrait: ) async def test_remote_commands( remote: RemoteTrait, - fake_channel: FakeChannel, + fake_channel: FakeB01Q10Channel, command_fn: Callable[[RemoteTrait], Awaitable[None]], expected_payload: dict[str, Any], ) -> None: """Test sending a remote start command.""" await command_fn(remote) - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - assert message.payload - payload_data = json.loads(message.payload.decode()) - assert payload_data == {"dps": expected_payload} + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + + dp_code = int(list(expected_payload.keys())[0]) + expected_params = list(expected_payload.values())[0] + + assert command.code == dp_code + assert params == expected_params diff --git a/tests/devices/traits/b01/q10/test_settings.py b/tests/devices/traits/b01/q10/test_settings.py index 7b3474c5..7d29347b 100644 --- a/tests/devices/traits/b01/q10/test_settings.py +++ b/tests/devices/traits/b01/q10/test_settings.py @@ -1,8 +1,5 @@ """Tests for the Q10 B01 setting writer traits.""" -import json -from typing import cast - import pytest from roborock.data.b01_q10.b01_q10_code_mappings import YXDeviceDustCollectionFrequency @@ -12,31 +9,30 @@ from roborock.devices.traits.b01.q10.do_not_disturb import DoNotDisturbTrait from roborock.devices.traits.b01.q10.dust_collection import DustCollectionTrait from roborock.devices.traits.b01.q10.volume import SoundVolumeTrait -from roborock.devices.transport.mqtt_channel import MqttChannel -from tests.fixtures.channel_fixtures import FakeChannel + +from .conftest import FakeQ10RpcChannel -@pytest.fixture -def fake_channel() -> FakeChannel: - return FakeChannel() +@pytest.fixture(name="fake_rpc_channel") +def fake_rpc_channel_fixture() -> FakeQ10RpcChannel: + return FakeQ10RpcChannel() -@pytest.fixture -def command(fake_channel: FakeChannel) -> CommandTrait: - return CommandTrait(cast(MqttChannel, fake_channel)) +@pytest.fixture(name="command") +def command_fixture(fake_rpc_channel: FakeQ10RpcChannel) -> CommandTrait: + return CommandTrait(fake_rpc_channel) -def _sent_dps(fake_channel: FakeChannel) -> dict: - assert len(fake_channel.published_messages) == 1 - payload = fake_channel.published_messages[0].payload - assert payload is not None - return json.loads(payload)["dps"] +def _sent_dps(fake_rpc_channel: FakeQ10RpcChannel) -> dict: + assert len(fake_rpc_channel.published_commands) == 1 + command, params = fake_rpc_channel.published_commands[0] + return {str(command.code): params} -async def test_set_volume_uses_common_wrapper(fake_channel: FakeChannel, command: CommandTrait) -> None: +async def test_set_volume_uses_common_wrapper(fake_rpc_channel: FakeQ10RpcChannel, command: CommandTrait) -> None: """Volume writes are wrapped in dpCommon (101) -> {"26": value}.""" await SoundVolumeTrait(command).set_volume(55) - assert _sent_dps(fake_channel) == {"101": {"26": 55}} + assert _sent_dps(fake_rpc_channel) == {"101": {"26": 55}} @pytest.mark.parametrize("volume", [-1, 101, 1000]) @@ -55,16 +51,20 @@ async def test_set_volume_rejects_out_of_range(command: CommandTrait, volume: in ], ) async def test_switch_enable_writes_common_wrapped_dp( - fake_channel: FakeChannel, command: CommandTrait, trait_cls: type, method: str, code: str + fake_rpc_channel: FakeQ10RpcChannel, + command: CommandTrait, + trait_cls: type, + method: str, + code: str, ) -> None: """Each switch trait's enable() writes its data point as int 1 under dpCommon.""" await getattr(trait_cls(command), method)() - assert _sent_dps(fake_channel) == {"101": {code: 1}} + assert _sent_dps(fake_rpc_channel) == {"101": {code: 1}} -async def test_switch_disable_sends_zero(fake_channel: FakeChannel, command: CommandTrait) -> None: +async def test_switch_disable_sends_zero(fake_rpc_channel: FakeQ10RpcChannel, command: CommandTrait) -> None: await ChildLockTrait(command).disable() - assert _sent_dps(fake_channel) == {"101": {"47": 0}} + assert _sent_dps(fake_rpc_channel) == {"101": {"47": 0}} @pytest.mark.parametrize( @@ -77,11 +77,11 @@ async def test_switch_disable_sends_zero(fake_channel: FakeChannel, command: Com ], ) async def test_set_dust_frequency_writes_interval_code( - fake_channel: FakeChannel, + fake_rpc_channel: FakeQ10RpcChannel, command: CommandTrait, frequency: YXDeviceDustCollectionFrequency, code: int, ) -> None: """Frequency enum writes its interval code under dpDustSetting (50).""" await DustCollectionTrait(command).set_frequency(frequency) - assert _sent_dps(fake_channel) == {"101": {"50": code}} + assert _sent_dps(fake_rpc_channel) == {"101": {"50": code}} diff --git a/tests/devices/traits/b01/q10/test_status.py b/tests/devices/traits/b01/q10/test_status.py index 98732ce6..8105142a 100644 --- a/tests/devices/traits/b01/q10/test_status.py +++ b/tests/devices/traits/b01/q10/test_status.py @@ -1,11 +1,10 @@ """Tests for the Q10 B01 status trait.""" import asyncio -import json import pathlib from collections.abc import AsyncGenerator from typing import Any -from unittest.mock import AsyncMock, Mock +from unittest.mock import Mock import pytest @@ -23,42 +22,38 @@ ) from roborock.data.b01_q10.b01_q10_containers import dpNetInfo, dpNotDisturbExpand, dpTimeZone from roborock.devices.traits.b01.q10 import Q10PropertiesApi, create +from roborock.protocols.b01_q10_protocol import Q10Message, decode_message from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol +from .conftest import FakeB01Q10Channel + TEST_DATA_DIR = pathlib.Path("tests/protocols/testdata/b01_q10_protocol") TESTDATA_DP_STATUS_DP_CLEAN_TASK_TYPE = (TEST_DATA_DIR / "dpStatus-dpCleanTaskType.json").read_bytes() TESTDATA_DP_REQUEST_DPS = (TEST_DATA_DIR / "dpRequetdps.json").read_bytes() -@pytest.fixture -def mock_channel(): - """Fixture for a mocked MQTT channel.""" - mock = AsyncMock() - return mock - - -@pytest.fixture -def message_queue() -> asyncio.Queue[RoborockMessage]: +@pytest.fixture(name="message_queue") +def message_queue_fixture() -> asyncio.Queue[Q10Message]: """Fixture for a message queue used by the mock stream.""" return asyncio.Queue() -@pytest.fixture -def mock_subscribe_stream(mock_channel: AsyncMock, message_queue: asyncio.Queue[RoborockMessage]) -> Mock: +@pytest.fixture(name="mock_channel") +def mock_channel_fixture(message_queue: asyncio.Queue[Q10Message]) -> FakeB01Q10Channel: """Fixture to mock the subscribe_stream method to yield from a queue.""" + channel = FakeB01Q10Channel() - async def mock_stream() -> AsyncGenerator[RoborockMessage, None]: + async def mock_stream() -> AsyncGenerator[Q10Message, None]: while True: yield await message_queue.get() - mock = Mock(return_value=mock_stream()) - mock_channel.subscribe_stream = mock - return mock + setattr(channel, "subscribe_stream", Mock(side_effect=mock_stream)) + return channel -@pytest.fixture -async def q10_api(mock_channel: AsyncMock, mock_subscribe_stream: Mock) -> AsyncGenerator[Q10PropertiesApi, None]: +@pytest.fixture(name="q10_api") +async def q10_api_fixture(mock_channel: FakeB01Q10Channel) -> AsyncGenerator[Q10PropertiesApi, None]: """Fixture to create and manage the Q10PropertiesApi.""" api = create(mock_channel) await api.start() @@ -66,20 +61,20 @@ async def q10_api(mock_channel: AsyncMock, mock_subscribe_stream: Mock) -> Async await api.close() -def build_message(payload: bytes) -> RoborockMessage: - """Helper to build a RoborockMessage for testing.""" - return RoborockMessage( +def build_q10_message(payload: bytes) -> Q10Message: + """Helper to build a Q10Message for testing.""" + msg = RoborockMessage( protocol=RoborockMessageProtocol.RPC_RESPONSE, payload=payload, version=b"B01", ) + result = decode_message(msg) + assert result is not None + return result async def wait_for_attribute_value(obj: Any, attribute: str, value: Any, timeout: float = 2.0) -> None: - """Wait for an attribute on an object to reach a specific value. - - This is a temporary polling solution until listeners are implemented. - """ + """Wait for an attribute on an object to reach a specific value.""" for _ in range(int(timeout / 0.1)): if getattr(obj, attribute) == value: return @@ -89,12 +84,12 @@ async def wait_for_attribute_value(obj: Any, attribute: str, value: Any, timeout async def test_status_trait_streaming( q10_api: Q10PropertiesApi, - message_queue: asyncio.Queue[RoborockMessage], + message_queue: asyncio.Queue[Q10Message], ) -> None: """Test that the StatusTrait updates its state from streaming messages.""" # status (121) = 8 (CHARGING_STATE) # clean_task_type (138) = 0 (IDLE) - message = build_message(TESTDATA_DP_STATUS_DP_CLEAN_TASK_TYPE) + message = build_q10_message(TESTDATA_DP_STATUS_DP_CLEAN_TASK_TYPE) assert q10_api.status.status is None assert q10_api.status.clean_task_type is None @@ -112,8 +107,8 @@ async def test_status_trait_streaming( async def test_status_trait_refresh( q10_api: Q10PropertiesApi, - mock_channel: AsyncMock, - message_queue: asyncio.Queue[RoborockMessage], + mock_channel: FakeB01Q10Channel, + message_queue: asyncio.Queue[Q10Message], ) -> None: """Test that the StatusTrait sends a refresh command and updates state.""" assert q10_api.status.battery is None @@ -128,18 +123,14 @@ async def test_status_trait_refresh( # battery (122) = 100 # status (121) = 8 (CHARGING_STATE) # fan_level (123) = 2 (BALANCED) - message = build_message(TESTDATA_DP_REQUEST_DPS) + message = build_q10_message(TESTDATA_DP_REQUEST_DPS) # Send a refresh command await q10_api.refresh() - mock_channel.publish.assert_called_once() - sent_message = mock_channel.publish.call_args[0][0] - assert sent_message.protocol == RoborockMessageProtocol.RPC_REQUEST - # Verify refresh payload - data = json.loads(sent_message.payload) - assert data - assert data.get("dps") - assert data.get("dps").get("102") == {} # REQUEST_DPS code is 102 + assert len(mock_channel.published_commands) == 1 + command, params = mock_channel.published_commands[0] + assert command == B01_Q10_DP.REQUEST_DPS + assert params == {} # Push the response message into the queue message_queue.put_nowait(message) @@ -199,11 +190,11 @@ async def test_status_trait_refresh( async def test_status_trait_vacuum_only_refresh( q10_api: Q10PropertiesApi, - message_queue: asyncio.Queue[RoborockMessage], + message_queue: asyncio.Queue[Q10Message], ) -> None: """Test decoding a full status dump from a vacuum-only (no mop) Q10.""" payload = (TEST_DATA_DIR / "dpRequestDps_vacuum_only.json").read_bytes() - message_queue.put_nowait(build_message(payload)) + message_queue.put_nowait(build_q10_message(payload)) await wait_for_attribute_value(q10_api.status, "battery", 75) diff --git a/tests/devices/traits/b01/q10/test_vacuum.py b/tests/devices/traits/b01/q10/test_vacuum.py index e736d28c..e20d36d1 100644 --- a/tests/devices/traits/b01/q10/test_vacuum.py +++ b/tests/devices/traits/b01/q10/test_vacuum.py @@ -1,4 +1,3 @@ -import json from collections.abc import Awaitable, Callable from typing import Any @@ -7,21 +6,12 @@ from roborock.data.b01_q10.b01_q10_code_mappings import YXCleanType, YXFanLevel from roborock.devices.traits.b01.q10 import Q10PropertiesApi from roborock.devices.traits.b01.q10.vacuum import VacuumTrait -from tests.fixtures.channel_fixtures import FakeChannel +from .conftest import FakeB01Q10Channel -@pytest.fixture(name="fake_channel") -def fake_channel_fixture() -> FakeChannel: - return FakeChannel() - -@pytest.fixture(name="q10_api") -def q10_api_fixture(fake_channel: FakeChannel) -> Q10PropertiesApi: - return Q10PropertiesApi(fake_channel) # type: ignore[arg-type] - - -@pytest.fixture(name="vacuumm") -def vacuumm_fixture(q10_api: Q10PropertiesApi) -> VacuumTrait: +@pytest.fixture(name="vacuum") +def vacuum_fixture(q10_api: Q10PropertiesApi) -> VacuumTrait: return q10_api.vacuum @@ -41,16 +31,19 @@ def vacuumm_fixture(q10_api: Q10PropertiesApi) -> VacuumTrait: ], ) async def test_vacuum_commands( - vacuumm: VacuumTrait, - fake_channel: FakeChannel, + vacuum: VacuumTrait, + fake_channel: FakeB01Q10Channel, command_fn: Callable[[VacuumTrait], Awaitable[None]], expected_payload: dict[str, Any], ) -> None: """Test sending a vacuum start command.""" - await command_fn(vacuumm) + await command_fn(vacuum) + + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + + dp_code = int(list(expected_payload.keys())[0]) + expected_params = list(expected_payload.values())[0] - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - assert message.payload - payload_data = json.loads(message.payload.decode()) - assert payload_data == {"dps": expected_payload} + assert command.code == dp_code + assert params == expected_params diff --git a/tests/devices/traits/b01/q7/__init__.py b/tests/devices/traits/b01/q7/__init__.py index 128a0924..1fcbe4aa 100644 --- a/tests/devices/traits/b01/q7/__init__.py +++ b/tests/devices/traits/b01/q7/__init__.py @@ -1,51 +1 @@ -import json -from typing import Any - -from Crypto.Cipher import AES -from Crypto.Util.Padding import pad - -from roborock.devices.traits.b01.q7 import Q7PropertiesApi -from roborock.roborock_message import RoborockMessage, RoborockMessageProtocol -from tests.fixtures.channel_fixtures import FakeChannel - - -class B01MessageBuilder: - """Helper class to build B01 RPC response messages for tests.""" - - def __init__(self) -> None: - self.msg_id = 123456789 - self.seq = 2020 - - def build(self, data: dict[str, Any] | str, code: int | None = None) -> RoborockMessage: - """Build an encoded B01 RPC response message.""" - message: dict[str, Any] = { - "msgId": str(self.msg_id), - "data": data, - } - if code is not None: - message["code"] = code - return self._build_dps(message) - - def _build_dps(self, message: dict[str, Any] | str) -> RoborockMessage: - """Build an encoded B01 RPC response message.""" - dps_payload = {"dps": {"10000": json.dumps(message)}} - self.seq += 1 - return RoborockMessage( - protocol=RoborockMessageProtocol.RPC_RESPONSE, - payload=pad( - json.dumps(dps_payload).encode(), - AES.block_size, - ), - version=b"B01", - seq=self.seq, - ) - - def build_map_response(self, payload: bytes) -> RoborockMessage: - """Build a dummy MAP_RESPONSE message.""" - self.seq += 1 - return RoborockMessage( - protocol=RoborockMessageProtocol.MAP_RESPONSE, - payload=payload, - version=b"B01", - seq=self.seq, - ) +# Q7 traits test package diff --git a/tests/devices/traits/b01/q7/conftest.py b/tests/devices/traits/b01/q7/conftest.py index caf2097d..efc9d166 100644 --- a/tests/devices/traits/b01/q7/conftest.py +++ b/tests/devices/traits/b01/q7/conftest.py @@ -1,20 +1,38 @@ -import math -import time -from collections.abc import Generator -from unittest.mock import patch +from typing import Any import pytest from roborock.data import HomeDataDevice, HomeDataProduct, RoborockCategory +from roborock.devices.rpc.b01_q7_channel import Q7MapRpcChannel, Q7RpcChannel from roborock.devices.traits.b01.q7 import Q7PropertiesApi, create -from tests.fixtures.channel_fixtures import FakeChannel -from . import B01MessageBuilder + +class FakeQ7Channel(Q7RpcChannel, Q7MapRpcChannel): + """A plaintext mock for Q7 Rpc and Map Channel.""" + + def __init__(self) -> None: + self.published_commands: list[tuple[Any, Any]] = [] + self.response_queue: list[Any] = [] + self.side_effect: Exception | None = None + + async def send_command(self, command: Any, params: Any = None) -> Any: + if self.side_effect: + raise self.side_effect + self.published_commands.append((command, params)) + if self.response_queue: + return self.response_queue.pop(0) + return {} + + async def send_map_command(self, command: Any, params: Any = None) -> bytes: + self.published_commands.append((command, params)) + if self.response_queue: + return self.response_queue.pop(0) + return b"" @pytest.fixture(name="fake_channel") -def fake_channel_fixture() -> FakeChannel: - return FakeChannel() +def fake_channel_fixture() -> FakeQ7Channel: + return FakeQ7Channel() @pytest.fixture(name="product") @@ -39,27 +57,9 @@ def device_fixture() -> HomeDataDevice: @pytest.fixture(name="q7_api") -def q7_api_fixture(fake_channel: FakeChannel, device: HomeDataDevice, product: HomeDataProduct) -> Q7PropertiesApi: - return create(product, device, fake_channel) # type: ignore[arg-type] - - -@pytest.fixture(name="expected_msg_id", autouse=True) -def next_message_id_fixture() -> Generator[int, None, None]: - """Fixture to patch get_next_int to return the expected message ID. - - We pick an arbitrary number, but just need it to ensure we can craft a fake - response with the message id matched to the outgoing RPC. - """ - - expected_msg_id = math.floor(time.time()) - - # Patch get_next_int to return our expected msg_id so the channel waits for it - with patch("roborock.protocols.b01_q7_protocol.get_next_int", return_value=expected_msg_id): - yield expected_msg_id - - -@pytest.fixture(name="message_builder") -def message_builder_fixture(expected_msg_id: int) -> B01MessageBuilder: - builder = B01MessageBuilder() - builder.msg_id = expected_msg_id - return builder +def q7_api_fixture( + fake_channel: FakeQ7Channel, + device: HomeDataDevice, + product: HomeDataProduct, +) -> Q7PropertiesApi: + return create(product, device, fake_channel, fake_channel) diff --git a/tests/devices/traits/b01/q7/test_clean_summary.py b/tests/devices/traits/b01/q7/test_clean_summary.py index 0a2bccbb..ae77208e 100644 --- a/tests/devices/traits/b01/q7/test_clean_summary.py +++ b/tests/devices/traits/b01/q7/test_clean_summary.py @@ -8,9 +8,8 @@ from roborock.data.b01_q7 import CleanRecordList from roborock.devices.traits.b01.q7.clean_summary import CleanSummaryTrait from roborock.exceptions import RoborockException -from tests.fixtures.channel_fixtures import FakeChannel -from . import B01MessageBuilder +from .conftest import FakeQ7Channel CLEAN_RECORD_LIST_DATA = { "total_time": 34980, @@ -60,17 +59,16 @@ @pytest.fixture(name="clean_summary_trait") -def clean_summary_trait_fixture(fake_channel: FakeChannel) -> CleanSummaryTrait: - return CleanSummaryTrait(fake_channel) # type: ignore[arg-type] +def clean_summary_trait_fixture(fake_channel: FakeQ7Channel) -> CleanSummaryTrait: + return CleanSummaryTrait(fake_channel) async def test_refresh_success( clean_summary_trait: CleanSummaryTrait, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ) -> None: """Test successfully refreshing clean summary.""" - fake_channel.response_queue.append(message_builder.build(CLEAN_RECORD_LIST_DATA)) + fake_channel.response_queue.append(CLEAN_RECORD_LIST_DATA) await clean_summary_trait.refresh() assert clean_summary_trait.total_time == 34980 @@ -82,8 +80,7 @@ async def test_refresh_success( async def test_refresh_with_no_records( clean_summary_trait: CleanSummaryTrait, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ) -> None: """Test refreshing with no records.""" empty_response = { @@ -92,7 +89,7 @@ async def test_refresh_with_no_records( "total_count": 0, "record_list": [], } - fake_channel.response_queue.append(message_builder.build(empty_response)) + fake_channel.response_queue.append(empty_response) await clean_summary_trait.refresh() assert clean_summary_trait.total_time == 0 @@ -103,10 +100,10 @@ async def test_refresh_with_no_records( async def test_refresh_propagates_exceptions( clean_summary_trait: CleanSummaryTrait, - fake_channel: FakeChannel, + fake_channel: FakeQ7Channel, ) -> None: """Test that exceptions from channel are propagated during refresh.""" - fake_channel.publish_side_effect = RoborockException("Communication error") + fake_channel.side_effect = RoborockException("Communication error") with pytest.raises(RoborockException, match="Communication error"): await clean_summary_trait.refresh() diff --git a/tests/devices/traits/b01/q7/test_init.py b/tests/devices/traits/b01/q7/test_init.py index 0c04a261..c273d2d4 100644 --- a/tests/devices/traits/b01/q7/test_init.py +++ b/tests/devices/traits/b01/q7/test_init.py @@ -1,9 +1,6 @@ -import json -from typing import Any, cast +from typing import Any import pytest -from Crypto.Cipher import AES -from Crypto.Util.Padding import unpad from roborock.data.b01_q7 import ( CleanTaskTypeMapping, @@ -13,19 +10,14 @@ WaterLevelMapping, WorkStatusMapping, ) -from roborock.devices.rpc.b01_q7_channel import send_decoded_command from roborock.devices.traits.b01.q7 import Q7PropertiesApi -from roborock.exceptions import RoborockException -from roborock.protocols.b01_q7_protocol import B01_VERSION, Q7RequestMessage -from roborock.roborock_message import RoborockB01Props, RoborockMessageProtocol -from tests.fixtures.channel_fixtures import FakeChannel +from roborock.roborock_message import RoborockB01Props +from roborock.roborock_typing import RoborockB01Q7Methods -from . import B01MessageBuilder +from .conftest import FakeQ7Channel -async def test_q7_api_query_values( - q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder -): +async def test_q7_api_query_values(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): """Test that Q7PropertiesApi correctly converts raw values.""" response_data = { "status": 1, @@ -33,7 +25,7 @@ async def test_q7_api_query_values( "battery": 100, } - fake_channel.response_queue.append(message_builder.build(response_data)) + fake_channel.response_queue.append(response_data) result = await q7_api.query_values( [ @@ -46,19 +38,10 @@ async def test_q7_api_query_values( assert result.status == WorkStatusMapping.WAITING_FOR_ORDERS assert result.wind == SCWindMapping.STANDARD - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - assert message.protocol == RoborockMessageProtocol.RPC_REQUEST - assert message.version == B01_VERSION - - assert message.payload is not None - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert "dps" in payload_data - assert "10000" in payload_data["dps"] - inner = payload_data["dps"]["10000"] - assert inner["method"] == "prop.get" - assert inner["msgId"] == str(message_builder.msg_id) - assert inner["params"] == {"property": [RoborockB01Props.STATUS, RoborockB01Props.WIND]} + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.GET_PROP + assert params == {"property": [RoborockB01Props.STATUS, RoborockB01Props.WIND]} @pytest.mark.parametrize( @@ -81,11 +64,10 @@ async def test_q7_response_value_mapping( response_data: dict[str, Any], expected_status: WorkStatusMapping, q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ): """Test Q7PropertiesApi value mapping for different statuses.""" - fake_channel.response_queue.append(message_builder.build(response_data)) + fake_channel.response_queue.append(response_data) result = await q7_api.query_values(query) @@ -93,81 +75,42 @@ async def test_q7_response_value_mapping( assert result.status == expected_status -async def test_send_decoded_command_non_dict_response(fake_channel: FakeChannel, message_builder: B01MessageBuilder): - """Test validity of handling non-dict responses (should not timeout).""" - message = message_builder.build("some_string_error") - fake_channel.response_queue.append(message) - - with pytest.raises(RoborockException, match="Unexpected data type for response"): - await send_decoded_command(fake_channel, Q7RequestMessage(dps=10000, command="prop.get", params=[])) # type: ignore[arg-type] - - -async def test_send_decoded_command_error_code(fake_channel: FakeChannel, message_builder: B01MessageBuilder): - """Test that non-zero error codes from device are properly handled.""" - message = message_builder.build({}, code=5001) - fake_channel.response_queue.append(message) - - with pytest.raises(RoborockException, match="B01 command failed with code 5001"): - await send_decoded_command(fake_channel, Q7RequestMessage(dps=10000, command="prop.get", params=[])) # type: ignore[arg-type] - - -async def test_send_decoded_command_allows_ok_string_ack(fake_channel: FakeChannel, message_builder: B01MessageBuilder): - """Command ACKs may return plain string payloads like ``ok``.""" - message = message_builder.build("ok") - fake_channel.response_queue.append(message) - - result = await send_decoded_command( - cast(Any, fake_channel), - Q7RequestMessage(dps=10000, command="service.set_room_clean", params=[]), # type: ignore[arg-type] - ) - - assert result == "ok" - - -async def test_q7_api_set_fan_speed( - q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder -): +async def test_q7_api_set_fan_speed(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): """Test setting fan speed.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.set_fan_speed(SCWindMapping.STRONG) - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "prop.set" - assert payload_data["dps"]["10000"]["params"] == {RoborockB01Props.WIND: SCWindMapping.STRONG.code} + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_PROP + assert params == {RoborockB01Props.WIND: SCWindMapping.STRONG.code} -async def test_q7_api_set_water_level( - q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder -): +async def test_q7_api_set_water_level(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): """Test setting water level.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.set_water_level(WaterLevelMapping.HIGH) - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "prop.set" - assert payload_data["dps"]["10000"]["params"] == {RoborockB01Props.WATER: WaterLevelMapping.HIGH.code} + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_PROP + assert params == {RoborockB01Props.WATER: WaterLevelMapping.HIGH.code} @pytest.mark.parametrize("volume", [0, 50, 100]) async def test_q7_api_set_volume( volume: int, q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ): """Test setting the robot voice volume.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.set_volume(volume) - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "prop.set" - assert payload_data["dps"]["10000"]["params"] == {RoborockB01Props.VOLUME: volume} + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_PROP + assert params == {RoborockB01Props.VOLUME: volume} @pytest.mark.parametrize( @@ -178,18 +121,16 @@ async def test_q7_api_set_child_lock( enabled: bool, expected_code: int, q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ): """Test toggling the child lock.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.set_child_lock(enabled) - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "prop.set" - assert payload_data["dps"]["10000"]["params"] == {RoborockB01Props.CHILD_LOCK: expected_code} + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_PROP + assert params == {RoborockB01Props.CHILD_LOCK: expected_code} @pytest.mark.parametrize("enabled, expected_is_open", [(True, 1), (False, 0)]) @@ -197,17 +138,16 @@ async def test_q7_api_set_do_not_disturb( enabled: bool, expected_is_open: int, q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ): """Test do-not-disturb is set as a whole via service.set_quiet_time.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.set_do_not_disturb(enabled, 1200, 420) - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "service.set_quiet_time" - assert payload_data["dps"]["10000"]["params"] == { + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_QUIET_TIME + assert params == { "is_open": expected_is_open, "quiet_begin_time": 1200, "quiet_end_time": 420, @@ -222,13 +162,13 @@ async def test_q7_api_set_do_not_disturb_invalid_time( begin_time: int, end_time: int, q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, + fake_channel: FakeQ7Channel, ): """Test out-of-range times raise ValueError and nothing is sent.""" with pytest.raises(ValueError, match="minutes since midnight"): await q7_api.set_do_not_disturb(True, begin_time, end_time) - assert len(fake_channel.published_messages) == 0 + assert len(fake_channel.published_commands) == 0 @pytest.mark.parametrize( @@ -243,112 +183,94 @@ async def test_q7_api_set_mode( mode: CleanTypeMapping, expected_code: int, q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ): """Test setting cleaning mode (vacuum, mop, or both).""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.set_mode(mode) - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "prop.set" - assert payload_data["dps"]["10000"]["params"] == {RoborockB01Props.MODE: expected_code} + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_PROP + assert params == {RoborockB01Props.MODE: expected_code} -async def test_q7_api_start_clean( - q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder -): +async def test_q7_api_start_clean(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): """Test starting cleaning.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.start_clean() - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "service.set_room_clean" - assert payload_data["dps"]["10000"]["params"] == { + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_ROOM_CLEAN + assert params == { "clean_type": CleanTaskTypeMapping.ALL.code, "ctrl_value": SCDeviceCleanParam.START.code, "room_ids": [], } -async def test_q7_api_pause_clean( - q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder -): +async def test_q7_api_pause_clean(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): """Test pausing cleaning.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.pause_clean() - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "service.set_room_clean" - assert payload_data["dps"]["10000"]["params"] == { + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_ROOM_CLEAN + assert params == { "clean_type": CleanTaskTypeMapping.ALL.code, "ctrl_value": SCDeviceCleanParam.PAUSE.code, "room_ids": [], } -async def test_q7_api_stop_clean( - q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder -): +async def test_q7_api_stop_clean(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): """Test stopping cleaning.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.stop_clean() - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "service.set_room_clean" - assert payload_data["dps"]["10000"]["params"] == { + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_ROOM_CLEAN + assert params == { "clean_type": CleanTaskTypeMapping.ALL.code, "ctrl_value": SCDeviceCleanParam.STOP.code, "room_ids": [], } -async def test_q7_api_return_to_dock( - q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder -): +async def test_q7_api_return_to_dock(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): """Test returning to dock.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.return_to_dock() - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "service.start_recharge" - assert payload_data["dps"]["10000"]["params"] == {} + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.START_RECHARGE + assert params == {} -async def test_q7_api_find_me(q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder): +async def test_q7_api_find_me(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): """Test locating the device.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.find_me() - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "service.find_device" - assert payload_data["dps"]["10000"]["params"] == {} + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.FIND_DEVICE + assert params == {} -async def test_q7_api_clean_segments( - q7_api: Q7PropertiesApi, fake_channel: FakeChannel, message_builder: B01MessageBuilder -): +async def test_q7_api_clean_segments(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): """Test room/segment cleaning helper for Q7.""" - fake_channel.response_queue.append(message_builder.build({"result": "ok"})) + fake_channel.response_queue.append({"result": "ok"}) await q7_api.clean_segments([10, 11]) - assert len(fake_channel.published_messages) == 1 - message = fake_channel.published_messages[0] - payload_data = json.loads(unpad(message.payload, AES.block_size)) - assert payload_data["dps"]["10000"]["method"] == "service.set_room_clean" - assert payload_data["dps"]["10000"]["params"] == { + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.SET_ROOM_CLEAN + assert params == { "clean_type": CleanTaskTypeMapping.ROOM.code, "ctrl_value": SCDeviceCleanParam.START.code, "room_ids": [10, 11], diff --git a/tests/devices/traits/b01/q7/test_map.py b/tests/devices/traits/b01/q7/test_map.py index 04b6a7df..e2011501 100644 --- a/tests/devices/traits/b01/q7/test_map.py +++ b/tests/devices/traits/b01/q7/test_map.py @@ -1,36 +1,17 @@ -from roborock.data import Q7MapList, Q7MapListEntry +from roborock.data.b01_q7.b01_q7_containers import Q7MapListEntry from roborock.devices.traits.b01.q7 import Q7PropertiesApi -from tests.fixtures.channel_fixtures import FakeChannel +from roborock.roborock_typing import RoborockB01Q7Methods -from . import B01MessageBuilder +from .conftest import FakeQ7Channel -async def test_q7_api_map_trait_refresh_populates_cached_values( - q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, -): - """Map trait follows refresh + cached-value access pattern.""" - fake_channel.response_queue.append(message_builder.build({"map_list": [{"id": 101, "cur": True}]})) - - assert q7_api.map.map_list == [] - assert q7_api.map.current_map_id is None - +async def test_q7_map_refresh(q7_api: Q7PropertiesApi, fake_channel: FakeQ7Channel): + """Test retrieving lists of saved maps.""" + fake_channel.response_queue.append({"map_list": [{"id": 111, "name": "Map 1"}]}) await q7_api.map.refresh() - assert len(fake_channel.published_messages) == 1 - assert q7_api.map.map_list[0].id == 101 - assert q7_api.map.map_list[0].cur is True - assert q7_api.map.current_map_id == 101 - - -def test_q7_map_list_current_map_id_prefers_marked_current(): - """Current-map resolution prefers the entry marked current.""" - map_list = Q7MapList( - map_list=[ - Q7MapListEntry(id=111, cur=False), - Q7MapListEntry(id=222, cur=True), - ] - ) - - assert map_list.current_map_id == 222 + assert len(fake_channel.published_commands) == 1 + command, params = fake_channel.published_commands[0] + assert command == RoborockB01Q7Methods.GET_MAP_LIST + assert params == {} + assert q7_api.map.map_list == [Q7MapListEntry(id=111)] diff --git a/tests/devices/traits/b01/q7/test_map_content.py b/tests/devices/traits/b01/q7/test_map_content.py index 7ec11112..b8753d77 100644 --- a/tests/devices/traits/b01/q7/test_map_content.py +++ b/tests/devices/traits/b01/q7/test_map_content.py @@ -1,28 +1,24 @@ -import json from unittest.mock import patch import pytest -from Crypto.Cipher import AES -from Crypto.Util.Padding import unpad from vacuum_map_parser_base.map_data import MapData from roborock.devices.traits.b01.q7 import Q7PropertiesApi from roborock.exceptions import RoborockException from roborock.map.b01_map_parser import ParsedMapData -from tests.fixtures.channel_fixtures import FakeChannel +from roborock.roborock_typing import RoborockB01Q7Methods -from . import B01MessageBuilder +from .conftest import FakeQ7Channel async def test_q7_map_content_refresh_populates_cached_values( q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ): fake_channel.response_queue.extend( [ - message_builder.build({"map_list": [{"id": 1772093512, "cur": True}]}), - message_builder.build_map_response(b"raw-map-payload"), + {"map_list": [{"id": 1772093512, "cur": True}]}, + b"inflated-payload", ] ) @@ -34,16 +30,10 @@ async def test_q7_map_content_refresh_populates_cached_values( image_content=b"pngbytes", map_data=dummy_map_data, ) - with ( - patch( - "roborock.devices.rpc.b01_q7_channel.decode_map_payload", - return_value=b"inflated-payload", - ), - patch( - "roborock.devices.traits.b01.q7.map_content.B01MapParser.parse", - return_value=parsed_map_data, - ) as parse, - ): + with patch( + "roborock.devices.traits.b01.q7.map_content.B01MapParser.parse", + return_value=parsed_map_data, + ) as parse: await q7_api.map_content.refresh() assert q7_api.map_content.image_content == b"pngbytes" @@ -52,27 +42,24 @@ async def test_q7_map_content_refresh_populates_cached_values( parse.assert_called_once_with(b"inflated-payload") - assert len(fake_channel.published_messages) == 2 - first = fake_channel.published_messages[0] - first_payload = json.loads(unpad(first.payload, AES.block_size)) - assert first_payload["dps"]["10000"]["method"] == "service.get_map_list" + assert len(fake_channel.published_commands) == 2 + cmd, params = fake_channel.published_commands[0] + assert cmd == RoborockB01Q7Methods.GET_MAP_LIST - second = fake_channel.published_messages[1] - second_payload = json.loads(unpad(second.payload, AES.block_size)) - assert second_payload["dps"]["10000"]["method"] == "service.upload_by_mapid" - assert second_payload["dps"]["10000"]["params"] == {"map_id": 1772093512} + map_cmd, map_params = fake_channel.published_commands[1] + assert map_cmd == RoborockB01Q7Methods.UPLOAD_BY_MAPID + assert map_params == {"map_id": 1772093512} async def test_q7_map_content_refresh_falls_back_to_first_map( q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ): """If no current map marker exists, first map in list is used.""" fake_channel.response_queue.extend( [ - message_builder.build({"map_list": [{"id": 111}, {"id": 222, "cur": False}]}), - message_builder.build_map_response(b"raw-map-payload"), + {"map_list": [{"id": 111}, {"id": 222, "cur": False}]}, + b"inflated-payload", ] ) @@ -80,30 +67,23 @@ async def test_q7_map_content_refresh_falls_back_to_first_map( await q7_api.map.refresh() dummy_map_data = MapData() - with ( - patch( - "roborock.devices.rpc.b01_q7_channel.decode_map_payload", - return_value=b"inflated-payload", - ), - patch( - "roborock.devices.traits.b01.q7.map_content.B01MapParser.parse", - return_value=type("X", (), {"image_content": b"pngbytes", "map_data": dummy_map_data})(), - ), + with patch( + "roborock.devices.traits.b01.q7.map_content.B01MapParser.parse", + return_value=type("X", (), {"image_content": b"pngbytes", "map_data": dummy_map_data})(), ): await q7_api.map_content.refresh() - second = fake_channel.published_messages[1] - second_payload = json.loads(unpad(second.payload, AES.block_size)) - assert second_payload["dps"]["10000"]["params"] == {"map_id": 111} + assert len(fake_channel.published_commands) == 2 + map_cmd, map_params = fake_channel.published_commands[1] + assert map_params == {"map_id": 111} async def test_q7_map_content_refresh_errors_without_map_list( q7_api: Q7PropertiesApi, - fake_channel: FakeChannel, - message_builder: B01MessageBuilder, + fake_channel: FakeQ7Channel, ): """Refresh should fail clearly when map list is unusable.""" - fake_channel.response_queue.append(message_builder.build({"map_list": []})) + fake_channel.response_queue.extend([{"map_list": []}]) with pytest.raises(RoborockException, match="Unable to determine current map ID"): await q7_api.map_content.refresh()