Disconnect on missing heartbeats

This commit is contained in:
Pierre Ståhl
2020-10-05 12:03:05 +02:00
committed by rospogrigio
parent f971646333
commit b2c4e93a47
2 changed files with 27 additions and 22 deletions

View File

@@ -110,18 +110,10 @@ class TuyaDevice(pytuya.TuyaListener):
def connect(self): def connect(self):
"""Connet to device if not already connected.""" """Connet to device if not already connected."""
if self._connect_task is None: if self._connect_task is None or self._interface:
self._connect_task = asyncio.ensure_future(self._make_connection()) self._connect_task = asyncio.ensure_future(self._make_connection())
async def _make_connection(self): async def _make_connection(self):
# Do nothing if already connected
if self._interface:
return
# The sleep gives another task the possibility to sweep in and
# connect, so we block that here
self._interface = True
backoff = min( backoff = min(
randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT
) )
@@ -150,8 +142,9 @@ class TuyaDevice(pytuya.TuyaListener):
except Exception: except Exception:
_LOGGER.exception("connect failed") _LOGGER.exception("connect failed")
self._connection_attempts += 1 self._connection_attempts += 1
self._interface.close() if self._interface is not None:
self._interface = None self._interface.close()
self._interface = None
self._hass.loop.call_soon(self.connect) self._hass.loop.call_soon(self.connect)
self._connect_task = None self._connect_task = None
@@ -161,7 +154,7 @@ class TuyaDevice(pytuya.TuyaListener):
try: try:
await self._interface.set_dps(state, dps_index) await self._interface.set_dps(state, dps_index)
except Exception: except Exception:
_LOGGER.exception("Failed to set DP {dps_index} to state") _LOGGER.exception("Failed to set DP %d to %d", dps_index, state)
else: else:
_LOGGER.error( _LOGGER.error(
"Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME] "Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME]

View File

@@ -167,11 +167,14 @@ class AESCipher:
class MessageDispatcher: class MessageDispatcher:
"""Buffer and dispatcher for Tuya messages.""" """Buffer and dispatcher for Tuya messages."""
# Heartbeats always respond with sequence number 0, so they can't be waited for like
# other messages. This is a hack to allow waiting for heartbeats.
HEARTBEAT_SEQNO = -100
def __init__(self, listener): def __init__(self, listener):
"""Initialize a new MessageBuffer.""" """Initialize a new MessageBuffer."""
self.buffer = b"" self.buffer = b""
self.listeners = {} self.listeners = {}
self.last_seqno = -1
self.listener = listener self.listener = listener
def abort(self): def abort(self):
@@ -183,9 +186,6 @@ class MessageDispatcher:
async def wait_for(self, seqno, timeout=5): async def wait_for(self, seqno, timeout=5):
"""Wait for response to a sequence number to be received and return it.""" """Wait for response to a sequence number to be received and return it."""
if self.last_seqno >= seqno:
return None
if seqno in self.listeners: if seqno in self.listeners:
raise Exception(f"listener exists for {seqno}") raise Exception(f"listener exists for {seqno}")
@@ -230,7 +230,6 @@ class MessageDispatcher:
def _dispatch(self, msg): def _dispatch(self, msg):
"""Dispatch a message to someone that is listening.""" """Dispatch a message to someone that is listening."""
self.last_seqno = max(self.last_seqno, msg.seqno)
if msg.seqno in self.listeners: if msg.seqno in self.listeners:
_LOGGER.debug("Dispatching sequence number %d", msg.seqno) _LOGGER.debug("Dispatching sequence number %d", msg.seqno)
sem = self.listeners[msg.seqno] sem = self.listeners[msg.seqno]
@@ -238,6 +237,10 @@ class MessageDispatcher:
sem.release() sem.release()
elif msg.cmd == 0x09: elif msg.cmd == 0x09:
_LOGGER.debug("Got heartbeat response") _LOGGER.debug("Got heartbeat response")
if self.HEARTBEAT_SEQNO in self.listeners:
sem = self.listeners[self.HEARTBEAT_SEQNO]
self.listeners[self.HEARTBEAT_SEQNO] = msg
sem.release()
elif msg.cmd == 0x08: elif msg.cmd == 0x08:
_LOGGER.debug("Got status update") _LOGGER.debug("Got status update")
self.listener(msg) self.listener(msg)
@@ -315,8 +318,13 @@ class TuyaProtocol(asyncio.Protocol):
async def heartbeat_loop(): async def heartbeat_loop():
"""Continuously send heart beat updates.""" """Continuously send heart beat updates."""
while True: while True:
self.heartbeat() try:
await self.heartbeat()
except Exception as ex:
_LOGGER.error("Heartbeat failed (%s), disconnecting", ex)
break
await asyncio.sleep(HEARTBEAT_INTERVAL) await asyncio.sleep(HEARTBEAT_INTERVAL)
self.close()
self.transport = transport self.transport = transport
self.on_connected.set_result(True) self.on_connected.set_result(True)
@@ -350,10 +358,13 @@ class TuyaProtocol(asyncio.Protocol):
payload = self._generate_payload(command, dps) payload = self._generate_payload(command, dps)
dev_type = self.dev_type dev_type = self.dev_type
# Wait for special sequence number if heartbeat
seqno = MessageDispatcher.HEARTBEAT_SEQNO if command == HEARTBEAT else (self.seqno - 1)
self.transport.write(payload) self.transport.write(payload)
msg = await self.dispatcher.wait_for(self.seqno - 1) msg = await self.dispatcher.wait_for(seqno)
if msg is None: if msg is None:
_LOGGER.debug("Wait was aborted for %d", self.seqno - 1) _LOGGER.debug("Wait was aborted for %d", seqno)
return None return None
# TODO: Verify stuff, e.g. CRC sequence number? # TODO: Verify stuff, e.g. CRC sequence number?
@@ -374,10 +385,11 @@ class TuyaProtocol(asyncio.Protocol):
"""Return device status.""" """Return device status."""
return await self.exchange(STATUS) return await self.exchange(STATUS)
def heartbeat(self): async def heartbeat(self):
"""Send a heartbeat message.""" """Send a heartbeat message."""
# We don't expect a response to this, just send blindly # We don't expect a response to this, just send blindly
self.transport.write(self._generate_payload(HEARTBEAT)) #self.transport.write(self._generate_payload(HEARTBEAT))
return await self.exchange(HEARTBEAT)
async def set_dps(self, value, dps_index): async def set_dps(self, value, dps_index):
""" """