Convert pytuya to asyncio

This commit is contained in:
Pierre Ståhl
2020-10-01 09:40:12 +02:00
committed by rospogrigio
parent 084b3a741a
commit cad31f1ffe
8 changed files with 457 additions and 308 deletions

View File

@@ -36,16 +36,17 @@ Credits
Updated pytuya to support devices with Device IDs of 22 characters
"""
import asyncio
import base64
from hashlib import md5
import json
import logging
import socket
import time
import binascii
import struct
import weakref
from collections import namedtuple
from contextlib import contextmanager
from abc import ABC, abstractmethod
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
@@ -60,6 +61,7 @@ TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc")
SET = "set"
STATUS = "status"
HEARTBEAT = "heartbeat"
PROTOCOL_VERSION_BYTES_31 = b"3.1"
PROTOCOL_VERSION_BYTES_33 = b"3.3"
@@ -73,6 +75,7 @@ MESSAGE_END_FMT = ">2I" # 2*uint32: crc, suffix
PREFIX_VALUE = 0x000055AA
SUFFIX_VALUE = 0x0000AA55
HEARTBEAT_INTERVAL = 20
# This is intended to match requests.json payload at
# https://github.com/codetheweb/tuyapi :
@@ -85,35 +88,18 @@ SUFFIX_VALUE = 0x0000AA55
# length, zero padding implies could be more than one byte)
PAYLOAD_DICT = {
"type_0a": {
"status": {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}},
"set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
STATUS: {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}},
SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
HEARTBEAT: {"hexByte": 0x09, "command": {}},
},
"type_0d": {
"status": {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}},
"set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
STATUS: {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}},
SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
HEARTBEAT: {"hexByte": 0x09, "command": {}},
},
}
@contextmanager
def socketcontext(address, port, timeout):
"""Context manager which sets up and tears down socket properly."""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
s.settimeout(timeout)
s.connect((address, port))
try:
yield s
except Exception:
# This should probably be a warning or error, but since this happens
# every now and the and we do retries on a higher level, use debug level
# to not spam log with errors.
_LOGGER.debug("Failed to connect to %s. Raising Exception.", address)
raise
finally:
s.close()
def pack_message(msg):
"""Pack a TuyaMessage into bytes."""
# Create full message excluding CRC and suffix
@@ -178,12 +164,118 @@ class AESCipher:
return s[: -ord(s[len(s) - 1 :])]
class TuyaInterface:
"""Represent a Tuya device."""
class MessageDispatcher:
"""Buffer and dispatcher for Tuya messages."""
def __init__(
self, dev_id, address, local_key, protocol_version, connection_timeout=5
):
def __init__(self, listener):
"""Initialize a new MessageBuffer."""
self.buffer = b""
self.listeners = {}
self.last_seqno = -1
self.listener = listener
def abort(self):
"""Abort all waiting clients."""
for key in self.listeners:
sem = self.listeners[key]
self.listeners[key] = None
sem.release()
async def wait_for(self, seqno, timeout=5):
"""Wait for response to a sequence number to be received and return it."""
if self.last_seqno >= seqno:
return None
if seqno in self.listeners:
raise Exception(f"listener exists for {seqno}")
_LOGGER.debug("Waiting for sequence number %d", seqno)
self.listeners[seqno] = asyncio.Semaphore(0)
try:
await asyncio.wait_for(self.listeners[seqno].acquire(), timeout=timeout)
except asyncio.TimeoutError:
del self.listeners[seqno]
raise
return self.listeners.pop(seqno)
def add_data(self, data):
"""Add new data to the buffer and try to parse messages."""
self.buffer += data
header_len = struct.calcsize(MESSAGE_RECV_HEADER_FMT)
while self.buffer:
# Check if enough data for measage header
if len(self.buffer) < header_len:
break
# Parse header and check if enough data according to length in header
_, seqno, cmd, length, retcode = struct.unpack_from(
MESSAGE_RECV_HEADER_FMT, self.buffer
)
if len(self.buffer[header_len - 4 :]) < length:
break
# length includes payload length, retcode, crc and suffix
payload_length = length - 4 - struct.calcsize(MESSAGE_END_FMT)
payload = self.buffer[header_len : header_len + payload_length]
crc, _ = struct.unpack_from(
MESSAGE_END_FMT,
self.buffer[header_len + payload_length : header_len + length],
)
self.buffer = self.buffer[header_len + length - 4 :]
self._dispatch(TuyaMessage(seqno, cmd, retcode, payload, crc))
def _dispatch(self, msg):
"""Dispatch a message to someone that is listening."""
self.last_seqno = max(self.last_seqno, msg.seqno)
if msg.seqno in self.listeners:
_LOGGER.debug("Dispatching sequence number %d", msg.seqno)
sem = self.listeners[msg.seqno]
self.listeners[msg.seqno] = msg
sem.release()
elif msg.cmd == 0x09:
_LOGGER.debug("Got heartbeat response")
elif msg.cmd == 0x08:
_LOGGER.debug("Got status update")
self.listener(msg)
else:
_LOGGER.debug(
"Got message type %d for unknown listener %d: %s",
msg.command,
msg.seqno,
msg,
)
class TuyaListener(ABC):
"""Listener interface for Tuya device changes."""
@abstractmethod
def status_updated(self, status):
"""Device updated status."""
@abstractmethod
def disconnected(self, exc):
"""Device disconnected."""
class EmptyListener(TuyaListener):
"""Listener doing nothing."""
def status_updated(self, status):
"""Device updated status."""
def disconnected(self, exc):
"""Device disconnected."""
class TuyaProtocol(asyncio.Protocol):
"""Implementation of the Tuya protocol."""
def __init__(self, dev_id, local_key, protocol_version, on_connected, listener):
"""
Initialize a new TuyaInterface.
@@ -195,38 +287,75 @@ class TuyaInterface:
Attributes:
port (int): The port to connect to.
"""
self.loop = asyncio.get_running_loop()
self.id = dev_id
self.address = address
self.local_key = local_key.encode("latin1")
self.connection_timeout = connection_timeout
self.version = protocol_version
self.dev_type = "type_0a"
self.dps_to_request = {}
self.cipher = AESCipher(self.local_key)
self.seqno = 0
self.transport = None
self.listener = weakref.ref(listener)
self.dispatcher = self._setup_dispatcher()
self.on_connected = on_connected
self.heartbeater = None
self.port = 6668 # default - do not expect caller to pass in
def _setup_dispatcher(self):
def _status_update(msg):
listener = self.listener()
if listener is not None:
listener.status_updated(self._decode_payload(msg.payload))
def exchange(self, command, dps=None):
return MessageDispatcher(_status_update)
def connection_made(self, transport):
"""Did connect to the device."""
async def heartbeat_loop():
"""Continuously send heart beat updates."""
while True:
self.heartbeat()
await asyncio.sleep(HEARTBEAT_INTERVAL)
self.transport = transport
self.on_connected.set_result(True)
self.heartbeater = self.loop.create_task(heartbeat_loop())
def data_received(self, data):
"""Received data from device."""
self.dispatcher.add_data(data)
def connection_lost(self, exc):
"""Disconnected from device."""
self.close()
listener = self.listener()
if listener is not None:
listener.disconnected(exc)
def close(self):
"""Close connection and abort all outstanding listeners."""
if self.transport is not None:
self.dispatcher.abort()
self.heartbeater.cancel()
transport = self.transport
self.transport = None
transport.close()
async def exchange(self, command, dps=None):
"""Send and receive a message, returning response from device."""
_LOGGER.debug("Sending command %s (device type: %s)", command, self.dev_type)
payload = self._generate_payload(command, dps)
dev_type = self.dev_type
with socketcontext(self.address, self.port, self.connection_timeout) as s:
s.send(payload)
data = s.recv(1024)
self.transport.write(payload)
msg = await self.dispatcher.wait_for(self.seqno - 1)
if msg is None:
_LOGGER.debug("Wait was aborted for %d", self.seqno - 1)
return None
# sometimes the first packet does not contain data (typically 28 bytes):
# need to read again
if len(data) < 40:
time.sleep(0.1)
data = s.recv(1024)
msg = unpack_message(data)
# TODO: Verify stuff, e.g. CRC sequence number
payload = self._decode_payload(msg.payload)
# TODO: Verify stuff, e.g. CRC sequence number?
payload = self._decode_payload(msg.payload)
# Perform a new exchange (once) if we switched device type
if dev_type != self.dev_type:
@@ -239,11 +368,16 @@ class TuyaInterface:
return self.exchange(command, dps)
return payload
def status(self):
async def status(self):
"""Return device status."""
return self.exchange(STATUS)
return await self.exchange(STATUS)
def set_dps(self, value, dps_index):
def heartbeat(self):
"""Send a heartbeat message."""
# We don't expect a response to this, just send blindly
self.transport.write(self._generate_payload(HEARTBEAT))
async def set_dps(self, value, dps_index):
"""
Set value (may be any type: bool, int or string) of any dps index.
@@ -251,9 +385,9 @@ class TuyaInterface:
dps_index(int): dps index to set
value: new value for the dps index
"""
return self.exchange(SET, {str(dps_index): value})
return await self.exchange(SET, {str(dps_index): value})
def detect_available_dps(self):
async def detect_available_dps(self):
"""Return which datapoints are supported by the device."""
# type_0d devices need a sort of bruteforce querying in order to detect the
# list of available dps experience shows that the dps available are usually
@@ -268,7 +402,7 @@ class TuyaInterface:
self.dps_to_request = {"1": None}
self.add_dps_to_request(range(*dps_range))
try:
data = self.status()
data = await self.status()
except Exception as e:
_LOGGER.warning("Failed to get status: %s", e)
raise
@@ -289,7 +423,9 @@ class TuyaInterface:
def _decode_payload(self, payload):
_LOGGER.debug("decode payload=%r", payload)
if payload.startswith(PROTOCOL_VERSION_BYTES_31):
if not payload:
payload = "{}"
elif payload.startswith(PROTOCOL_VERSION_BYTES_31):
payload = payload[len(PROTOCOL_VERSION_BYTES_31) :] # remove version header
# remove (what I'm guessing, but not confirmed is) 16-bytes of MD5
# hexdigest of payload
@@ -377,4 +513,32 @@ class TuyaInterface:
def __repr__(self):
"""Return internal string representation of object."""
return "%r" % ((self.id, self.address),) # FIXME can do better than this
return self.id
async def connect(
address,
device_id,
local_key,
protocol_version,
listener=None,
port=6668,
timeout=5,
):
"""Connect to a device."""
loop = asyncio.get_running_loop()
on_connected = loop.create_future()
_, protocol = await loop.create_connection(
lambda: TuyaProtocol(
device_id,
local_key,
protocol_version,
on_connected,
listener or EmptyListener(),
),
address,
port,
)
await asyncio.wait_for(on_connected, timeout=timeout)
return protocol