diff --git a/custom_components/localtuya/common.py b/custom_components/localtuya/common.py index 13c99f0..9b8198f 100644 --- a/custom_components/localtuya/common.py +++ b/custom_components/localtuya/common.py @@ -110,18 +110,10 @@ class TuyaDevice(pytuya.TuyaListener): def connect(self): """Connet to device if not already connected.""" - if self._connect_task is None: + if self._connect_task is None or self._interface: self._connect_task = asyncio.ensure_future(self._make_connection()) async def _make_connection(self): - # Do nothing if already connected - if self._interface: - return - - # The sleep gives another task the possibility to sweep in and - # connect, so we block that here - self._interface = True - backoff = min( randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT ) @@ -150,8 +142,9 @@ class TuyaDevice(pytuya.TuyaListener): except Exception: _LOGGER.exception("connect failed") self._connection_attempts += 1 - self._interface.close() - self._interface = None + if self._interface is not None: + self._interface.close() + self._interface = None self._hass.loop.call_soon(self.connect) self._connect_task = None @@ -161,7 +154,7 @@ class TuyaDevice(pytuya.TuyaListener): try: await self._interface.set_dps(state, dps_index) except Exception: - _LOGGER.exception("Failed to set DP {dps_index} to state") + _LOGGER.exception("Failed to set DP %d to %d", dps_index, state) else: _LOGGER.error( "Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME] diff --git a/custom_components/localtuya/pytuya/__init__.py b/custom_components/localtuya/pytuya/__init__.py index c68d362..6e01106 100644 --- a/custom_components/localtuya/pytuya/__init__.py +++ b/custom_components/localtuya/pytuya/__init__.py @@ -167,11 +167,14 @@ class AESCipher: class MessageDispatcher: """Buffer and dispatcher for Tuya messages.""" + # Heartbeats always respond with sequence number 0, so they can't be waited for like + # other messages. This is a hack to allow waiting for heartbeats. + HEARTBEAT_SEQNO = -100 + def __init__(self, listener): """Initialize a new MessageBuffer.""" self.buffer = b"" self.listeners = {} - self.last_seqno = -1 self.listener = listener def abort(self): @@ -183,9 +186,6 @@ class MessageDispatcher: async def wait_for(self, seqno, timeout=5): """Wait for response to a sequence number to be received and return it.""" - if self.last_seqno >= seqno: - return None - if seqno in self.listeners: raise Exception(f"listener exists for {seqno}") @@ -230,7 +230,6 @@ class MessageDispatcher: def _dispatch(self, msg): """Dispatch a message to someone that is listening.""" - self.last_seqno = max(self.last_seqno, msg.seqno) if msg.seqno in self.listeners: _LOGGER.debug("Dispatching sequence number %d", msg.seqno) sem = self.listeners[msg.seqno] @@ -238,6 +237,10 @@ class MessageDispatcher: sem.release() elif msg.cmd == 0x09: _LOGGER.debug("Got heartbeat response") + if self.HEARTBEAT_SEQNO in self.listeners: + sem = self.listeners[self.HEARTBEAT_SEQNO] + self.listeners[self.HEARTBEAT_SEQNO] = msg + sem.release() elif msg.cmd == 0x08: _LOGGER.debug("Got status update") self.listener(msg) @@ -315,8 +318,13 @@ class TuyaProtocol(asyncio.Protocol): async def heartbeat_loop(): """Continuously send heart beat updates.""" while True: - self.heartbeat() + try: + await self.heartbeat() + except Exception as ex: + _LOGGER.error("Heartbeat failed (%s), disconnecting", ex) + break await asyncio.sleep(HEARTBEAT_INTERVAL) + self.close() self.transport = transport self.on_connected.set_result(True) @@ -350,10 +358,13 @@ class TuyaProtocol(asyncio.Protocol): payload = self._generate_payload(command, dps) dev_type = self.dev_type + # Wait for special sequence number if heartbeat + seqno = MessageDispatcher.HEARTBEAT_SEQNO if command == HEARTBEAT else (self.seqno - 1) + self.transport.write(payload) - msg = await self.dispatcher.wait_for(self.seqno - 1) + msg = await self.dispatcher.wait_for(seqno) if msg is None: - _LOGGER.debug("Wait was aborted for %d", self.seqno - 1) + _LOGGER.debug("Wait was aborted for %d", seqno) return None # TODO: Verify stuff, e.g. CRC sequence number? @@ -374,10 +385,11 @@ class TuyaProtocol(asyncio.Protocol): """Return device status.""" return await self.exchange(STATUS) - def heartbeat(self): + async def heartbeat(self): """Send a heartbeat message.""" # We don't expect a response to this, just send blindly - self.transport.write(self._generate_payload(HEARTBEAT)) + #self.transport.write(self._generate_payload(HEARTBEAT)) + return await self.exchange(HEARTBEAT) async def set_dps(self, value, dps_index): """