diff --git a/custom_components/localtuya/__init__.py b/custom_components/localtuya/__init__.py index 17ea4bb..623c2de 100644 --- a/custom_components/localtuya/__init__.py +++ b/custom_components/localtuya/__init__.py @@ -52,12 +52,10 @@ localtuya: """ import asyncio import logging -from datetime import timedelta, datetime from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.core import HomeAssistant, callback from homeassistant.const import ( - CONF_DEVICE_ID, CONF_PLATFORM, CONF_ENTITIES, SERVICE_RELOAD, @@ -73,9 +71,6 @@ from .common import TuyaDevice _LOGGER = logging.getLogger(__name__) UNSUB_LISTENER = "unsub_listener" -UNSUB_TRACK = "unsub_track" - -POLL_INTERVAL = 30 CONFIG_SCHEMA = config_schema() @@ -136,24 +131,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): device = TuyaDevice(hass, entry.data) - async def update_state(now): - """Read device status and update platforms.""" - status = None - try: - status = await hass.async_add_executor_job(device.status) - except Exception: - _LOGGER.debug("update failed") - - signal = f"localtuya_{entry.data[CONF_DEVICE_ID]}" - async_dispatcher_send(hass, signal, status) - - unsub_track = async_track_time_interval( - hass, update_state, timedelta(seconds=POLL_INTERVAL) - ) - hass.data[DOMAIN][entry.entry_id] = { UNSUB_LISTENER: unsub_listener, - UNSUB_TRACK: unsub_track, TUYA_DEVICE: device, } @@ -166,8 +145,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry): for entity in entry.data[CONF_ENTITIES] ] ) - - await update_state(datetime.now()) + device.connect() hass.async_create_task(setup_entities()) @@ -188,7 +166,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry): ) hass.data[DOMAIN][entry.entry_id][UNSUB_LISTENER]() - hass.data[DOMAIN][entry.entry_id][UNSUB_TRACK]() + hass.data[DOMAIN][entry.entry_id][TUYA_DEVICE].close() if unload_ok: hass.data[DOMAIN].pop(entry.entry_id) diff --git a/custom_components/localtuya/common.py b/custom_components/localtuya/common.py index 8e85fae..aed142d 100644 --- a/custom_components/localtuya/common.py +++ b/custom_components/localtuya/common.py @@ -1,8 +1,9 @@ """Code shared between all platforms.""" +import asyncio import logging -from time import time, sleep -from threading import Lock +from random import randrange +from homeassistant.core import callback from homeassistant.helpers.entity import Entity from homeassistant.helpers.dispatcher import ( async_dispatcher_connect, @@ -23,6 +24,8 @@ from .const import CONF_LOCAL_KEY, CONF_PROTOCOL_VERSION, DOMAIN, TUYA_DEVICE _LOGGER = logging.getLogger(__name__) +BACKOFF_TIME_UPPER_LIMIT = 300 # Five minutes + def prepare_setup_entities(hass, config_entry, platform): """Prepare ro setup entities for a platform.""" @@ -60,7 +63,7 @@ async def async_setup_entry( # Add DPS used by this platform to the request list for dp_conf in dps_config_fields: if dp_conf in device_config: - tuyainterface._interface.add_dps_to_request(device_config[dp_conf]) + tuyainterface._dps_to_request[device_config[dp_conf]] = None entities.append( entity_class( @@ -88,92 +91,98 @@ def get_entity_config(config_entry, dps_id): raise Exception(f"missing entity config for id {dps_id}") -class TuyaDevice: +class TuyaDevice(pytuya.TuyaListener): """Cache wrapper for pytuya.TuyaInterface.""" def __init__(self, hass, config_entry): """Initialize the cache.""" - self._cached_status = "" - self._cached_status_time = 0 - self._interface = pytuya.TuyaInterface( - config_entry[CONF_DEVICE_ID], - config_entry[CONF_HOST], - config_entry[CONF_LOCAL_KEY], - float(config_entry[CONF_PROTOCOL_VERSION]), - ) - for entity in config_entry[CONF_ENTITIES]: - # this has to be done in case the device type is type_0d - self._interface.add_dps_to_request(entity[CONF_ID]) - self._friendly_name = config_entry[CONF_FRIENDLY_NAME] self._hass = hass - self._lock = Lock() + self._config_entry = config_entry + self._interface = None + self._status = {} + self._dps_to_request = {} + self._connect_task = None + self._connection_attempts = 0 - @property - def unique_id(self): - """Return unique device identifier.""" - return self._interface.id + # This has to be done in case the device type is type_0d + for entity in config_entry[CONF_ENTITIES]: + self._dps_to_request[entity[CONF_ID]] = None - def __get_status(self): - _LOGGER.debug("running def __get_status from TuyaDevice") - for i in range(5): - try: - status = self._interface.status() - return status - except Exception as e: - print( - "Failed to update status of device [{}]: [{}]".format( - self._interface.address, e - ) - ) - sleep(1.0) - if i + 1 == 3: - _LOGGER.error( - "Failed to update status of device %s", self._interface.address - ) - # return None - raise ConnectionError("Failed to update status .") + def connect(self, delay=None): + """Connet to device if not already connected.""" + if self._connect_task is None: + self._connect_task = asyncio.ensure_future(self._make_connection()) - def set_dps(self, state, dps_index): - """Change value of a DP of the Tuya device and update the cached status.""" - # _LOGGER.info("running def set_dps from TuyaDevice") - # No need to clear the cache here: let's just update the status of the - # changed dps as returned by the interface (see 5 lines below) - # self._cached_status = "" - # self._cached_status_time = 0 - for i in range(5): - try: - result = self._interface.set_dps(state, dps_index) - self._cached_status["dps"].update(result["dps"]) - signal = f"localtuya_{self._interface.id}" - async_dispatcher_send(self._hass, signal, self._cached_status) - return - except Exception as e: - print( - "Failed to set status of device [{}]: [{}]".format( - self._interface.address, e - ) - ) - if i + 1 == 3: - _LOGGER.error( - "Failed to set status of device %s", self._interface.address - ) - return + async def _make_connection(self): + # Do nothing if already connected + if self._interface: + return - # raise ConnectionError("Failed to set status.") + # The sleep gives another task the possibility to sweep in and + # connect, so we block that here + self._interface = True + + backoff = min( + randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT + ) + + _LOGGER.debug("Waiting %d seconds before connecting", backoff) + await asyncio.sleep(backoff) - def status(self): - """Get the state of the Tuya device and cache the results.""" - _LOGGER.debug("running def status(self) from TuyaDevice") - self._lock.acquire() try: - now = time() - if not self._cached_status or now - self._cached_status_time > 10: - sleep(0.5) - self._cached_status = self.__get_status() - self._cached_status_time = time() - return self._cached_status - finally: - self._lock.release() + _LOGGER.debug("Connecting to %s", self._config_entry[CONF_HOST]) + self._interface = await pytuya.connect( + self._config_entry[CONF_HOST], + self._config_entry[CONF_DEVICE_ID], + self._config_entry[CONF_LOCAL_KEY], + float(self._config_entry[CONF_PROTOCOL_VERSION]), + self, + ) + self._interface.add_dps_to_request(self._dps_to_request) + + _LOGGER.debug("Retrieving initial state") + status = await self._interface.status() + if status is None: + raise Exception("failed to retrieve status") + + self.status_updated(status) + self._connection_attempts = 0 + except Exception: + _LOGGER.exception("connect failed") + self._connection_attempts += 1 + self._interface.close() + self._interface = None + self._hass.loop.call_soon(self.connect) + self._connect_task = None + + async def set_dps(self, state, dps_index): + """Change value of a DP of the Tuya device and update the cached status.""" + if self._interface is not None: + try: + await self._interface.set_dps(state, dps_index) + except Exception: + _LOGGER.exception("Failed to set DP {dps_index} to state") + else: + _LOGGER.error( + "Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME] + ) + + @callback + def status_updated(self, status): + """Device updated status.""" + self._status.update(status["dps"]) + + signal = f"localtuya_{self._config_entry[CONF_DEVICE_ID]}" + async_dispatcher_send(self._hass, signal, self._status) + + @callback + def disconnected(self, exc): + """Device disconnected.""" + signal = f"localtuya_{self._config_entry[CONF_DEVICE_ID]}" + async_dispatcher_send(self._hass, signal, None) + + self._interface = None + self.connect() class LocalTuyaEntity(Entity): @@ -212,7 +221,7 @@ class LocalTuyaEntity(Entity): return { "identifiers": { # Serial numbers are unique identifiers within a specific domain - (DOMAIN, f"local_{self._device.unique_id}") + (DOMAIN, f"local_{self._config_entry.data[CONF_DEVICE_ID]}") }, "name": self._config_entry.data[CONF_FRIENDLY_NAME], "manufacturer": "Unknown", @@ -233,7 +242,7 @@ class LocalTuyaEntity(Entity): @property def unique_id(self): """Return unique device identifier.""" - return f"local_{self._device.unique_id}_{self._dps_id}" + return f"local_{self._config_entry.data[CONF_DEVICE_ID]}_{self._dps_id}" def has_config(self, attr): """Return if a config parameter has a valid value.""" @@ -243,14 +252,11 @@ class LocalTuyaEntity(Entity): @property def available(self): """Return if device is available or not.""" - return bool(self._status) + return str(self._dps_id) in self._status def dps(self, dps_index): """Return cached value for DPS index.""" - if "dps" not in self._status: - return None - - value = self._status["dps"].get(str(dps_index)) + value = self._status.get(str(dps_index)) if value is None: _LOGGER.warning( "Entity %s is requesting unknown DPS index %s", diff --git a/custom_components/localtuya/config_flow.py b/custom_components/localtuya/config_flow.py index a6816b9..02587f8 100644 --- a/custom_components/localtuya/config_flow.py +++ b/custom_components/localtuya/config_flow.py @@ -159,22 +159,23 @@ def config_schema(): async def validate_input(hass: core.HomeAssistant, data): """Validate the user input allows us to connect.""" - tuyainterface = pytuya.TuyaInterface( - data[CONF_DEVICE_ID], - data[CONF_HOST], - data[CONF_LOCAL_KEY], - float(data[CONF_PROTOCOL_VERSION]), - ) detected_dps = {} try: - detected_dps = await hass.async_add_executor_job( - tuyainterface.detect_available_dps + interface = await pytuya.connect( + data[CONF_HOST], + data[CONF_DEVICE_ID], + data[CONF_LOCAL_KEY], + float(data[CONF_PROTOCOL_VERSION]), ) + + detected_dps = await interface.detect_available_dps() except (ConnectionRefusedError, ConnectionResetError): raise CannotConnect except ValueError: raise InvalidAuth + finally: + interface.close() return dps_string_list(detected_dps) diff --git a/custom_components/localtuya/cover.py b/custom_components/localtuya/cover.py index 18a19f8..1a1afc5 100644 --- a/custom_components/localtuya/cover.py +++ b/custom_components/localtuya/cover.py @@ -1,7 +1,7 @@ """Platform to locally control Tuya-based cover devices.""" +import asyncio import logging from functools import partial -from time import sleep import voluptuous as vol @@ -112,7 +112,7 @@ class LocaltuyaCover(LocalTuyaEntity, CoverEntity): return None return self._current_cover_position == 0 - def set_cover_position(self, **kwargs): + async def async_set_cover_position(self, **kwargs): """Move the cover to a specific position.""" _LOGGER.debug("Setting cover position: %r", kwargs[ATTR_POSITION]) if self._config[CONF_POSITIONING_MODE] == COVER_MODE_FAKE: @@ -123,36 +123,36 @@ class LocaltuyaCover(LocalTuyaEntity, CoverEntity): mydelay = posdiff / 50.0 * self._config[CONF_SPAN_TIME] if newpos > currpos: _LOGGER.debug("Opening to %f: delay %f", newpos, mydelay) - self.open_cover() + await self.async_open_cover() else: _LOGGER.debug("Closing to %f: delay %f", newpos, mydelay) - self.close_cover() - sleep(mydelay) - self.stop_cover() + await self.async_close_cover() + await asyncio.sleep(mydelay) + await self.async_stop_cover() self._current_cover_position = 50 _LOGGER.debug("Done") elif self._config[CONF_POSITIONING_MODE] == COVER_MODE_POSITION: converted_position = int(kwargs[ATTR_POSITION]) if 0 <= converted_position <= 100 and self.has_config(CONF_SET_POSITION_DP): - self._device.set_dps( + await self._device.set_dps( converted_position, self._config[CONF_SET_POSITION_DP] ) - def open_cover(self, **kwargs): + async def async_open_cover(self, **kwargs): """Open the cover.""" _LOGGER.debug("Launching command %s to cover ", self._open_cmd) - self._device.set_dps(self._open_cmd, self._dps_id) + await self._device.set_dps(self._open_cmd, self._dps_id) - def close_cover(self, **kwargs): + async def async_close_cover(self, **kwargs): """Close cover.""" _LOGGER.debug("Launching command %s to cover ", self._close_cmd) - self._device.set_dps(self._close_cmd, self._dps_id) + await self._device.set_dps(self._close_cmd, self._dps_id) - def stop_cover(self, **kwargs): + async def async_stop_cover(self, **kwargs): """Stop the cover.""" _LOGGER.debug("Launching command %s to cover ", COVER_STOP_CMD) - self._device.set_dps(COVER_STOP_CMD, self._dps_id) + await self._device.set_dps(COVER_STOP_CMD, self._dps_id) def status_updated(self): """Device status was updated.""" diff --git a/custom_components/localtuya/fan.py b/custom_components/localtuya/fan.py index 7fd8819..8d62a49 100644 --- a/custom_components/localtuya/fan.py +++ b/custom_components/localtuya/fan.py @@ -1,113 +1,113 @@ -"""Platform to locally control Tuya-based fan devices.""" -import logging -from functools import partial - -from homeassistant.components.fan import ( - FanEntity, - DOMAIN, - SPEED_OFF, - SPEED_LOW, - SPEED_MEDIUM, - SPEED_HIGH, - SUPPORT_SET_SPEED, - SUPPORT_OSCILLATE, -) - -from .common import LocalTuyaEntity, async_setup_entry - -_LOGGER = logging.getLogger(__name__) - - -def flow_schema(dps): - """Return schema used in config flow.""" - return {} - - -class LocaltuyaFan(LocalTuyaEntity, FanEntity): - """Representation of a Tuya fan.""" - - def __init__( - self, - device, - config_entry, - fanid, - **kwargs, - ): - """Initialize the entity.""" - super().__init__(device, config_entry, fanid, **kwargs) - self._is_on = False - self._speed = SPEED_OFF - self._oscillating = False - - @property - def oscillating(self): - """Return current oscillating status.""" - return self._oscillating - - @property - def is_on(self): - """Check if Tuya fan is on.""" - return self._is_on - - @property - def speed(self) -> str: - """Return the current speed.""" - return self._speed - - @property - def speed_list(self) -> list: - """Get the list of available speeds.""" - return [SPEED_OFF, SPEED_LOW, SPEED_MEDIUM, SPEED_HIGH] - - def turn_on(self, speed: str = None, **kwargs) -> None: - """Turn on the entity.""" - self._device.set_dps(True, "1") - if speed is not None: - self.set_speed(speed) - else: - self.schedule_update_ha_state() - - def turn_off(self, **kwargs) -> None: - """Turn off the entity.""" - self._device.set_dps(False, "1") - self.schedule_update_ha_state() - - def set_speed(self, speed: str) -> None: - """Set the speed of the fan.""" - self._speed = speed - if speed == SPEED_OFF: - self._device.set_dps(False, "1") - elif speed == SPEED_LOW: - self._device.set_dps("1", "2") - elif speed == SPEED_MEDIUM: - self._device.set_dps("2", "2") - elif speed == SPEED_HIGH: - self._device.set_dps("3", "2") - self.schedule_update_ha_state() - - def oscillate(self, oscillating: bool) -> None: - """Set oscillation.""" - self._oscillating = oscillating - self._device.set_value("8", oscillating) - self.schedule_update_ha_state() - - @property - def supported_features(self) -> int: - """Flag supported features.""" - return SUPPORT_SET_SPEED | SUPPORT_OSCILLATE - - def status_updated(self): - """Get state of Tuya fan.""" - self._is_on = self._status["dps"]["1"] - if not self._status["dps"]["1"]: - self._speed = SPEED_OFF - elif self._status["dps"]["2"] == "1": - self._speed = SPEED_LOW - elif self._status["dps"]["2"] == "2": - self._speed = SPEED_MEDIUM - elif self._status["dps"]["2"] == "3": - self._speed = SPEED_HIGH - self._oscillating = self._status["dps"]["8"] - - -async_setup_entry = partial(async_setup_entry, DOMAIN, LocaltuyaFan, flow_schema) +"""Platform to locally control Tuya-based fan devices.""" +import logging +from functools import partial + +from homeassistant.components.fan import ( + FanEntity, + DOMAIN, + SPEED_OFF, + SPEED_LOW, + SPEED_MEDIUM, + SPEED_HIGH, + SUPPORT_SET_SPEED, + SUPPORT_OSCILLATE, +) + +from .common import LocalTuyaEntity, async_setup_entry + +_LOGGER = logging.getLogger(__name__) + + +def flow_schema(dps): + """Return schema used in config flow.""" + return {} + + +class LocaltuyaFan(LocalTuyaEntity, FanEntity): + """Representation of a Tuya fan.""" + + def __init__( + self, + device, + config_entry, + fanid, + **kwargs, + ): + """Initialize the entity.""" + super().__init__(device, config_entry, fanid, **kwargs) + self._is_on = False + self._speed = SPEED_OFF + self._oscillating = False + + @property + def oscillating(self): + """Return current oscillating status.""" + return self._oscillating + + @property + def is_on(self): + """Check if Tuya fan is on.""" + return self._is_on + + @property + def speed(self) -> str: + """Return the current speed.""" + return self._speed + + @property + def speed_list(self) -> list: + """Get the list of available speeds.""" + return [SPEED_OFF, SPEED_LOW, SPEED_MEDIUM, SPEED_HIGH] + + def turn_on(self, speed: str = None, **kwargs) -> None: + """Turn on the entity.""" + self._device.set_dps(True, "1") + if speed is not None: + self.set_speed(speed) + else: + self.schedule_update_ha_state() + + def turn_off(self, **kwargs) -> None: + """Turn off the entity.""" + self._device.set_dps(False, "1") + self.schedule_update_ha_state() + + def set_speed(self, speed: str) -> None: + """Set the speed of the fan.""" + self._speed = speed + if speed == SPEED_OFF: + self._device.set_dps(False, "1") + elif speed == SPEED_LOW: + self._device.set_dps("1", "2") + elif speed == SPEED_MEDIUM: + self._device.set_dps("2", "2") + elif speed == SPEED_HIGH: + self._device.set_dps("3", "2") + self.schedule_update_ha_state() + + def oscillate(self, oscillating: bool) -> None: + """Set oscillation.""" + self._oscillating = oscillating + self._device.set_value("8", oscillating) + self.schedule_update_ha_state() + + @property + def supported_features(self) -> int: + """Flag supported features.""" + return SUPPORT_SET_SPEED | SUPPORT_OSCILLATE + + def status_updated(self): + """Get state of Tuya fan.""" + self._is_on = self._status["dps"]["1"] + if not self._status["dps"]["1"]: + self._speed = SPEED_OFF + elif self._status["dps"]["2"] == "1": + self._speed = SPEED_LOW + elif self._status["dps"]["2"] == "2": + self._speed = SPEED_MEDIUM + elif self._status["dps"]["2"] == "3": + self._speed = SPEED_HIGH + self._oscillating = self._status["dps"]["8"] + + +async_setup_entry = partial(async_setup_entry, DOMAIN, LocaltuyaFan, flow_schema) diff --git a/custom_components/localtuya/light.py b/custom_components/localtuya/light.py index ef84098..92a645f 100644 --- a/custom_components/localtuya/light.py +++ b/custom_components/localtuya/light.py @@ -86,12 +86,12 @@ class LocaltuyaLight(LocalTuyaEntity, LightEntity): supports = supports | SUPPORT_COLOR_TEMP return supports - def turn_on(self, **kwargs): + async def async_turn_on(self, **kwargs): """Turn on or control the light.""" - self._device.set_dps(True, self._dps_id) + await self._device.set_dps(True, self._dps_id) if ATTR_BRIGHTNESS in kwargs: - self._device.set_dps( + await self._device.set_dps( max(int(kwargs[ATTR_BRIGHTNESS]), 25), self._config.get(CONF_BRIGHTNESS) ) @@ -104,11 +104,11 @@ class LocaltuyaLight(LocalTuyaEntity, LightEntity): - (255 / (MAX_MIRED - MIN_MIRED)) * (int(kwargs[ATTR_COLOR_TEMP]) - MIN_MIRED) ) - self._device.set_dps(color_temp, self._config.get(CONF_COLOR_TEMP)) + await self._device.set_dps(color_temp, self._config.get(CONF_COLOR_TEMP)) - def turn_off(self, **kwargs): + async def async_turn_off(self, **kwargs): """Turn Tuya light off.""" - self._device.set_dps(False, self._dps_id) + await self._device.set_dps(False, self._dps_id) def status_updated(self): """Device status was updated.""" diff --git a/custom_components/localtuya/pytuya/__init__.py b/custom_components/localtuya/pytuya/__init__.py index 98bcc3d..10b071f 100644 --- a/custom_components/localtuya/pytuya/__init__.py +++ b/custom_components/localtuya/pytuya/__init__.py @@ -36,16 +36,17 @@ Credits Updated pytuya to support devices with Device IDs of 22 characters """ +import asyncio import base64 from hashlib import md5 import json import logging -import socket import time import binascii import struct +import weakref from collections import namedtuple -from contextlib import contextmanager +from abc import ABC, abstractmethod from cryptography.hazmat.backends import default_backend from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes @@ -60,6 +61,7 @@ TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc") SET = "set" STATUS = "status" +HEARTBEAT = "heartbeat" PROTOCOL_VERSION_BYTES_31 = b"3.1" PROTOCOL_VERSION_BYTES_33 = b"3.3" @@ -73,6 +75,7 @@ MESSAGE_END_FMT = ">2I" # 2*uint32: crc, suffix PREFIX_VALUE = 0x000055AA SUFFIX_VALUE = 0x0000AA55 +HEARTBEAT_INTERVAL = 20 # This is intended to match requests.json payload at # https://github.com/codetheweb/tuyapi : @@ -85,35 +88,18 @@ SUFFIX_VALUE = 0x0000AA55 # length, zero padding implies could be more than one byte) PAYLOAD_DICT = { "type_0a": { - "status": {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}}, - "set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}}, + STATUS: {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}}, + SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}}, + HEARTBEAT: {"hexByte": 0x09, "command": {}}, }, "type_0d": { - "status": {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}}, - "set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}}, + STATUS: {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}}, + SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}}, + HEARTBEAT: {"hexByte": 0x09, "command": {}}, }, } -@contextmanager -def socketcontext(address, port, timeout): - """Context manager which sets up and tears down socket properly.""" - s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) - s.settimeout(timeout) - s.connect((address, port)) - try: - yield s - except Exception: - # This should probably be a warning or error, but since this happens - # every now and the and we do retries on a higher level, use debug level - # to not spam log with errors. - _LOGGER.debug("Failed to connect to %s. Raising Exception.", address) - raise - finally: - s.close() - - def pack_message(msg): """Pack a TuyaMessage into bytes.""" # Create full message excluding CRC and suffix @@ -178,12 +164,118 @@ class AESCipher: return s[: -ord(s[len(s) - 1 :])] -class TuyaInterface: - """Represent a Tuya device.""" +class MessageDispatcher: + """Buffer and dispatcher for Tuya messages.""" - def __init__( - self, dev_id, address, local_key, protocol_version, connection_timeout=5 - ): + def __init__(self, listener): + """Initialize a new MessageBuffer.""" + self.buffer = b"" + self.listeners = {} + self.last_seqno = -1 + self.listener = listener + + def abort(self): + """Abort all waiting clients.""" + for key in self.listeners: + sem = self.listeners[key] + self.listeners[key] = None + sem.release() + + async def wait_for(self, seqno, timeout=5): + """Wait for response to a sequence number to be received and return it.""" + if self.last_seqno >= seqno: + return None + + if seqno in self.listeners: + raise Exception(f"listener exists for {seqno}") + + _LOGGER.debug("Waiting for sequence number %d", seqno) + self.listeners[seqno] = asyncio.Semaphore(0) + try: + await asyncio.wait_for(self.listeners[seqno].acquire(), timeout=timeout) + except asyncio.TimeoutError: + del self.listeners[seqno] + raise + + return self.listeners.pop(seqno) + + def add_data(self, data): + """Add new data to the buffer and try to parse messages.""" + self.buffer += data + header_len = struct.calcsize(MESSAGE_RECV_HEADER_FMT) + + while self.buffer: + # Check if enough data for measage header + if len(self.buffer) < header_len: + break + + # Parse header and check if enough data according to length in header + _, seqno, cmd, length, retcode = struct.unpack_from( + MESSAGE_RECV_HEADER_FMT, self.buffer + ) + if len(self.buffer[header_len - 4 :]) < length: + break + + # length includes payload length, retcode, crc and suffix + payload_length = length - 4 - struct.calcsize(MESSAGE_END_FMT) + payload = self.buffer[header_len : header_len + payload_length] + + crc, _ = struct.unpack_from( + MESSAGE_END_FMT, + self.buffer[header_len + payload_length : header_len + length], + ) + + self.buffer = self.buffer[header_len + length - 4 :] + self._dispatch(TuyaMessage(seqno, cmd, retcode, payload, crc)) + + def _dispatch(self, msg): + """Dispatch a message to someone that is listening.""" + self.last_seqno = max(self.last_seqno, msg.seqno) + if msg.seqno in self.listeners: + _LOGGER.debug("Dispatching sequence number %d", msg.seqno) + sem = self.listeners[msg.seqno] + self.listeners[msg.seqno] = msg + sem.release() + elif msg.cmd == 0x09: + _LOGGER.debug("Got heartbeat response") + elif msg.cmd == 0x08: + _LOGGER.debug("Got status update") + self.listener(msg) + else: + _LOGGER.debug( + "Got message type %d for unknown listener %d: %s", + msg.command, + msg.seqno, + msg, + ) + + +class TuyaListener(ABC): + """Listener interface for Tuya device changes.""" + + @abstractmethod + def status_updated(self, status): + """Device updated status.""" + + @abstractmethod + def disconnected(self, exc): + """Device disconnected.""" + + +class EmptyListener(TuyaListener): + """Listener doing nothing.""" + + def status_updated(self, status): + """Device updated status.""" + + def disconnected(self, exc): + """Device disconnected.""" + + +class TuyaProtocol(asyncio.Protocol): + """Implementation of the Tuya protocol.""" + + def __init__(self, dev_id, local_key, protocol_version, on_connected, listener): """ Initialize a new TuyaInterface. @@ -195,38 +287,75 @@ class TuyaInterface: Attributes: port (int): The port to connect to. """ + self.loop = asyncio.get_running_loop() self.id = dev_id - self.address = address self.local_key = local_key.encode("latin1") - self.connection_timeout = connection_timeout self.version = protocol_version self.dev_type = "type_0a" self.dps_to_request = {} self.cipher = AESCipher(self.local_key) self.seqno = 0 + self.transport = None + self.listener = weakref.ref(listener) + self.dispatcher = self._setup_dispatcher() + self.on_connected = on_connected + self.heartbeater = None - self.port = 6668 # default - do not expect caller to pass in + def _setup_dispatcher(self): + def _status_update(msg): + listener = self.listener() + if listener is not None: + listener.status_updated(self._decode_payload(msg.payload)) - def exchange(self, command, dps=None): + return MessageDispatcher(_status_update) + + def connection_made(self, transport): + """Did connect to the device.""" + + async def heartbeat_loop(): + """Continuously send heart beat updates.""" + while True: + self.heartbeat() + await asyncio.sleep(HEARTBEAT_INTERVAL) + + self.transport = transport + self.on_connected.set_result(True) + self.heartbeater = self.loop.create_task(heartbeat_loop()) + + def data_received(self, data): + """Received data from device.""" + self.dispatcher.add_data(data) + + def connection_lost(self, exc): + """Disconnected from device.""" + self.close() + listener = self.listener() + if listener is not None: + listener.disconnected(exc) + + def close(self): + """Close connection and abort all outstanding listeners.""" + if self.transport is not None: + self.dispatcher.abort() + self.heartbeater.cancel() + transport = self.transport + self.transport = None + transport.close() + + async def exchange(self, command, dps=None): """Send and receive a message, returning response from device.""" _LOGGER.debug("Sending command %s (device type: %s)", command, self.dev_type) payload = self._generate_payload(command, dps) dev_type = self.dev_type - with socketcontext(self.address, self.port, self.connection_timeout) as s: - s.send(payload) - data = s.recv(1024) + self.transport.write(payload) + msg = await self.dispatcher.wait_for(self.seqno - 1) + if msg is None: + _LOGGER.debug("Wait was aborted for %d", self.seqno - 1) + return None - # sometimes the first packet does not contain data (typically 28 bytes): - # need to read again - if len(data) < 40: - time.sleep(0.1) - data = s.recv(1024) - - msg = unpack_message(data) - # TODO: Verify stuff, e.g. CRC sequence number - - payload = self._decode_payload(msg.payload) + # TODO: Verify stuff, e.g. CRC sequence number? + payload = self._decode_payload(msg.payload) # Perform a new exchange (once) if we switched device type if dev_type != self.dev_type: @@ -239,11 +368,16 @@ class TuyaInterface: return self.exchange(command, dps) return payload - def status(self): + async def status(self): """Return device status.""" - return self.exchange(STATUS) + return await self.exchange(STATUS) - def set_dps(self, value, dps_index): + def heartbeat(self): + """Send a heartbeat message.""" + # We don't expect a response to this, just send blindly + self.transport.write(self._generate_payload(HEARTBEAT)) + + async def set_dps(self, value, dps_index): """ Set value (may be any type: bool, int or string) of any dps index. @@ -251,9 +385,9 @@ class TuyaInterface: dps_index(int): dps index to set value: new value for the dps index """ - return self.exchange(SET, {str(dps_index): value}) + return await self.exchange(SET, {str(dps_index): value}) - def detect_available_dps(self): + async def detect_available_dps(self): """Return which datapoints are supported by the device.""" # type_0d devices need a sort of bruteforce querying in order to detect the # list of available dps experience shows that the dps available are usually @@ -268,7 +402,7 @@ class TuyaInterface: self.dps_to_request = {"1": None} self.add_dps_to_request(range(*dps_range)) try: - data = self.status() + data = await self.status() except Exception as e: _LOGGER.warning("Failed to get status: %s", e) raise @@ -289,7 +423,9 @@ class TuyaInterface: def _decode_payload(self, payload): _LOGGER.debug("decode payload=%r", payload) - if payload.startswith(PROTOCOL_VERSION_BYTES_31): + if not payload: + payload = "{}" + elif payload.startswith(PROTOCOL_VERSION_BYTES_31): payload = payload[len(PROTOCOL_VERSION_BYTES_31) :] # remove version header # remove (what I'm guessing, but not confirmed is) 16-bytes of MD5 # hexdigest of payload @@ -377,4 +513,32 @@ class TuyaInterface: def __repr__(self): """Return internal string representation of object.""" - return "%r" % ((self.id, self.address),) # FIXME can do better than this + return self.id + + +async def connect( + address, + device_id, + local_key, + protocol_version, + listener=None, + port=6668, + timeout=5, +): + """Connect to a device.""" + loop = asyncio.get_running_loop() + on_connected = loop.create_future() + _, protocol = await loop.create_connection( + lambda: TuyaProtocol( + device_id, + local_key, + protocol_version, + on_connected, + listener or EmptyListener(), + ), + address, + port, + ) + + await asyncio.wait_for(on_connected, timeout=timeout) + return protocol diff --git a/custom_components/localtuya/switch.py b/custom_components/localtuya/switch.py index 6a03dca..729618d 100644 --- a/custom_components/localtuya/switch.py +++ b/custom_components/localtuya/switch.py @@ -65,13 +65,13 @@ class LocaltuyaSwitch(LocalTuyaEntity, SwitchEntity): attrs[ATTR_VOLTAGE] = self.dps(self._config[CONF_VOLTAGE]) / 10 return attrs - def turn_on(self, **kwargs): + async def async_turn_on(self, **kwargs): """Turn Tuya switch on.""" - self._device.set_dps(True, self._dps_id) + await self._device.set_dps(True, self._dps_id) - def turn_off(self, **kwargs): + async def async_turn_off(self, **kwargs): """Turn Tuya switch off.""" - self._device.set_dps(False, self._dps_id) + await self._device.set_dps(False, self._dps_id) def status_updated(self): """Device status was updated."""