Convert pytuya to asyncio
This commit is contained in:
committed by
rospogrigio
parent
084b3a741a
commit
cad31f1ffe
@@ -36,16 +36,17 @@ Credits
|
||||
Updated pytuya to support devices with Device IDs of 22 characters
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from hashlib import md5
|
||||
import json
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
import binascii
|
||||
import struct
|
||||
import weakref
|
||||
from collections import namedtuple
|
||||
from contextlib import contextmanager
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
@@ -60,6 +61,7 @@ TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc")
|
||||
|
||||
SET = "set"
|
||||
STATUS = "status"
|
||||
HEARTBEAT = "heartbeat"
|
||||
|
||||
PROTOCOL_VERSION_BYTES_31 = b"3.1"
|
||||
PROTOCOL_VERSION_BYTES_33 = b"3.3"
|
||||
@@ -73,6 +75,7 @@ MESSAGE_END_FMT = ">2I" # 2*uint32: crc, suffix
|
||||
PREFIX_VALUE = 0x000055AA
|
||||
SUFFIX_VALUE = 0x0000AA55
|
||||
|
||||
HEARTBEAT_INTERVAL = 20
|
||||
|
||||
# This is intended to match requests.json payload at
|
||||
# https://github.com/codetheweb/tuyapi :
|
||||
@@ -85,35 +88,18 @@ SUFFIX_VALUE = 0x0000AA55
|
||||
# length, zero padding implies could be more than one byte)
|
||||
PAYLOAD_DICT = {
|
||||
"type_0a": {
|
||||
"status": {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}},
|
||||
"set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
STATUS: {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}},
|
||||
SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
HEARTBEAT: {"hexByte": 0x09, "command": {}},
|
||||
},
|
||||
"type_0d": {
|
||||
"status": {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
"set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
STATUS: {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
HEARTBEAT: {"hexByte": 0x09, "command": {}},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def socketcontext(address, port, timeout):
|
||||
"""Context manager which sets up and tears down socket properly."""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
s.settimeout(timeout)
|
||||
s.connect((address, port))
|
||||
try:
|
||||
yield s
|
||||
except Exception:
|
||||
# This should probably be a warning or error, but since this happens
|
||||
# every now and the and we do retries on a higher level, use debug level
|
||||
# to not spam log with errors.
|
||||
_LOGGER.debug("Failed to connect to %s. Raising Exception.", address)
|
||||
raise
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
|
||||
def pack_message(msg):
|
||||
"""Pack a TuyaMessage into bytes."""
|
||||
# Create full message excluding CRC and suffix
|
||||
@@ -178,12 +164,118 @@ class AESCipher:
|
||||
return s[: -ord(s[len(s) - 1 :])]
|
||||
|
||||
|
||||
class TuyaInterface:
|
||||
"""Represent a Tuya device."""
|
||||
class MessageDispatcher:
|
||||
"""Buffer and dispatcher for Tuya messages."""
|
||||
|
||||
def __init__(
|
||||
self, dev_id, address, local_key, protocol_version, connection_timeout=5
|
||||
):
|
||||
def __init__(self, listener):
|
||||
"""Initialize a new MessageBuffer."""
|
||||
self.buffer = b""
|
||||
self.listeners = {}
|
||||
self.last_seqno = -1
|
||||
self.listener = listener
|
||||
|
||||
def abort(self):
|
||||
"""Abort all waiting clients."""
|
||||
for key in self.listeners:
|
||||
sem = self.listeners[key]
|
||||
self.listeners[key] = None
|
||||
sem.release()
|
||||
|
||||
async def wait_for(self, seqno, timeout=5):
|
||||
"""Wait for response to a sequence number to be received and return it."""
|
||||
if self.last_seqno >= seqno:
|
||||
return None
|
||||
|
||||
if seqno in self.listeners:
|
||||
raise Exception(f"listener exists for {seqno}")
|
||||
|
||||
_LOGGER.debug("Waiting for sequence number %d", seqno)
|
||||
self.listeners[seqno] = asyncio.Semaphore(0)
|
||||
try:
|
||||
await asyncio.wait_for(self.listeners[seqno].acquire(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
del self.listeners[seqno]
|
||||
raise
|
||||
|
||||
return self.listeners.pop(seqno)
|
||||
|
||||
def add_data(self, data):
|
||||
"""Add new data to the buffer and try to parse messages."""
|
||||
self.buffer += data
|
||||
header_len = struct.calcsize(MESSAGE_RECV_HEADER_FMT)
|
||||
|
||||
while self.buffer:
|
||||
# Check if enough data for measage header
|
||||
if len(self.buffer) < header_len:
|
||||
break
|
||||
|
||||
# Parse header and check if enough data according to length in header
|
||||
_, seqno, cmd, length, retcode = struct.unpack_from(
|
||||
MESSAGE_RECV_HEADER_FMT, self.buffer
|
||||
)
|
||||
if len(self.buffer[header_len - 4 :]) < length:
|
||||
break
|
||||
|
||||
# length includes payload length, retcode, crc and suffix
|
||||
payload_length = length - 4 - struct.calcsize(MESSAGE_END_FMT)
|
||||
payload = self.buffer[header_len : header_len + payload_length]
|
||||
|
||||
crc, _ = struct.unpack_from(
|
||||
MESSAGE_END_FMT,
|
||||
self.buffer[header_len + payload_length : header_len + length],
|
||||
)
|
||||
|
||||
self.buffer = self.buffer[header_len + length - 4 :]
|
||||
self._dispatch(TuyaMessage(seqno, cmd, retcode, payload, crc))
|
||||
|
||||
def _dispatch(self, msg):
|
||||
"""Dispatch a message to someone that is listening."""
|
||||
self.last_seqno = max(self.last_seqno, msg.seqno)
|
||||
if msg.seqno in self.listeners:
|
||||
_LOGGER.debug("Dispatching sequence number %d", msg.seqno)
|
||||
sem = self.listeners[msg.seqno]
|
||||
self.listeners[msg.seqno] = msg
|
||||
sem.release()
|
||||
elif msg.cmd == 0x09:
|
||||
_LOGGER.debug("Got heartbeat response")
|
||||
elif msg.cmd == 0x08:
|
||||
_LOGGER.debug("Got status update")
|
||||
self.listener(msg)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Got message type %d for unknown listener %d: %s",
|
||||
msg.command,
|
||||
msg.seqno,
|
||||
msg,
|
||||
)
|
||||
|
||||
|
||||
class TuyaListener(ABC):
|
||||
"""Listener interface for Tuya device changes."""
|
||||
|
||||
@abstractmethod
|
||||
def status_updated(self, status):
|
||||
"""Device updated status."""
|
||||
|
||||
@abstractmethod
|
||||
def disconnected(self, exc):
|
||||
"""Device disconnected."""
|
||||
|
||||
|
||||
class EmptyListener(TuyaListener):
|
||||
"""Listener doing nothing."""
|
||||
|
||||
def status_updated(self, status):
|
||||
"""Device updated status."""
|
||||
|
||||
def disconnected(self, exc):
|
||||
"""Device disconnected."""
|
||||
|
||||
|
||||
class TuyaProtocol(asyncio.Protocol):
|
||||
"""Implementation of the Tuya protocol."""
|
||||
|
||||
def __init__(self, dev_id, local_key, protocol_version, on_connected, listener):
|
||||
"""
|
||||
Initialize a new TuyaInterface.
|
||||
|
||||
@@ -195,38 +287,75 @@ class TuyaInterface:
|
||||
Attributes:
|
||||
port (int): The port to connect to.
|
||||
"""
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.id = dev_id
|
||||
self.address = address
|
||||
self.local_key = local_key.encode("latin1")
|
||||
self.connection_timeout = connection_timeout
|
||||
self.version = protocol_version
|
||||
self.dev_type = "type_0a"
|
||||
self.dps_to_request = {}
|
||||
self.cipher = AESCipher(self.local_key)
|
||||
self.seqno = 0
|
||||
self.transport = None
|
||||
self.listener = weakref.ref(listener)
|
||||
self.dispatcher = self._setup_dispatcher()
|
||||
self.on_connected = on_connected
|
||||
self.heartbeater = None
|
||||
|
||||
self.port = 6668 # default - do not expect caller to pass in
|
||||
def _setup_dispatcher(self):
|
||||
def _status_update(msg):
|
||||
listener = self.listener()
|
||||
if listener is not None:
|
||||
listener.status_updated(self._decode_payload(msg.payload))
|
||||
|
||||
def exchange(self, command, dps=None):
|
||||
return MessageDispatcher(_status_update)
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Did connect to the device."""
|
||||
|
||||
async def heartbeat_loop():
|
||||
"""Continuously send heart beat updates."""
|
||||
while True:
|
||||
self.heartbeat()
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
self.transport = transport
|
||||
self.on_connected.set_result(True)
|
||||
self.heartbeater = self.loop.create_task(heartbeat_loop())
|
||||
|
||||
def data_received(self, data):
|
||||
"""Received data from device."""
|
||||
self.dispatcher.add_data(data)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""Disconnected from device."""
|
||||
self.close()
|
||||
listener = self.listener()
|
||||
if listener is not None:
|
||||
listener.disconnected(exc)
|
||||
|
||||
def close(self):
|
||||
"""Close connection and abort all outstanding listeners."""
|
||||
if self.transport is not None:
|
||||
self.dispatcher.abort()
|
||||
self.heartbeater.cancel()
|
||||
transport = self.transport
|
||||
self.transport = None
|
||||
transport.close()
|
||||
|
||||
async def exchange(self, command, dps=None):
|
||||
"""Send and receive a message, returning response from device."""
|
||||
_LOGGER.debug("Sending command %s (device type: %s)", command, self.dev_type)
|
||||
payload = self._generate_payload(command, dps)
|
||||
dev_type = self.dev_type
|
||||
|
||||
with socketcontext(self.address, self.port, self.connection_timeout) as s:
|
||||
s.send(payload)
|
||||
data = s.recv(1024)
|
||||
self.transport.write(payload)
|
||||
msg = await self.dispatcher.wait_for(self.seqno - 1)
|
||||
if msg is None:
|
||||
_LOGGER.debug("Wait was aborted for %d", self.seqno - 1)
|
||||
return None
|
||||
|
||||
# sometimes the first packet does not contain data (typically 28 bytes):
|
||||
# need to read again
|
||||
if len(data) < 40:
|
||||
time.sleep(0.1)
|
||||
data = s.recv(1024)
|
||||
|
||||
msg = unpack_message(data)
|
||||
# TODO: Verify stuff, e.g. CRC sequence number
|
||||
|
||||
payload = self._decode_payload(msg.payload)
|
||||
# TODO: Verify stuff, e.g. CRC sequence number?
|
||||
payload = self._decode_payload(msg.payload)
|
||||
|
||||
# Perform a new exchange (once) if we switched device type
|
||||
if dev_type != self.dev_type:
|
||||
@@ -239,11 +368,16 @@ class TuyaInterface:
|
||||
return self.exchange(command, dps)
|
||||
return payload
|
||||
|
||||
def status(self):
|
||||
async def status(self):
|
||||
"""Return device status."""
|
||||
return self.exchange(STATUS)
|
||||
return await self.exchange(STATUS)
|
||||
|
||||
def set_dps(self, value, dps_index):
|
||||
def heartbeat(self):
|
||||
"""Send a heartbeat message."""
|
||||
# We don't expect a response to this, just send blindly
|
||||
self.transport.write(self._generate_payload(HEARTBEAT))
|
||||
|
||||
async def set_dps(self, value, dps_index):
|
||||
"""
|
||||
Set value (may be any type: bool, int or string) of any dps index.
|
||||
|
||||
@@ -251,9 +385,9 @@ class TuyaInterface:
|
||||
dps_index(int): dps index to set
|
||||
value: new value for the dps index
|
||||
"""
|
||||
return self.exchange(SET, {str(dps_index): value})
|
||||
return await self.exchange(SET, {str(dps_index): value})
|
||||
|
||||
def detect_available_dps(self):
|
||||
async def detect_available_dps(self):
|
||||
"""Return which datapoints are supported by the device."""
|
||||
# type_0d devices need a sort of bruteforce querying in order to detect the
|
||||
# list of available dps experience shows that the dps available are usually
|
||||
@@ -268,7 +402,7 @@ class TuyaInterface:
|
||||
self.dps_to_request = {"1": None}
|
||||
self.add_dps_to_request(range(*dps_range))
|
||||
try:
|
||||
data = self.status()
|
||||
data = await self.status()
|
||||
except Exception as e:
|
||||
_LOGGER.warning("Failed to get status: %s", e)
|
||||
raise
|
||||
@@ -289,7 +423,9 @@ class TuyaInterface:
|
||||
def _decode_payload(self, payload):
|
||||
_LOGGER.debug("decode payload=%r", payload)
|
||||
|
||||
if payload.startswith(PROTOCOL_VERSION_BYTES_31):
|
||||
if not payload:
|
||||
payload = "{}"
|
||||
elif payload.startswith(PROTOCOL_VERSION_BYTES_31):
|
||||
payload = payload[len(PROTOCOL_VERSION_BYTES_31) :] # remove version header
|
||||
# remove (what I'm guessing, but not confirmed is) 16-bytes of MD5
|
||||
# hexdigest of payload
|
||||
@@ -377,4 +513,32 @@ class TuyaInterface:
|
||||
|
||||
def __repr__(self):
|
||||
"""Return internal string representation of object."""
|
||||
return "%r" % ((self.id, self.address),) # FIXME can do better than this
|
||||
return self.id
|
||||
|
||||
|
||||
async def connect(
|
||||
address,
|
||||
device_id,
|
||||
local_key,
|
||||
protocol_version,
|
||||
listener=None,
|
||||
port=6668,
|
||||
timeout=5,
|
||||
):
|
||||
"""Connect to a device."""
|
||||
loop = asyncio.get_running_loop()
|
||||
on_connected = loop.create_future()
|
||||
_, protocol = await loop.create_connection(
|
||||
lambda: TuyaProtocol(
|
||||
device_id,
|
||||
local_key,
|
||||
protocol_version,
|
||||
on_connected,
|
||||
listener or EmptyListener(),
|
||||
),
|
||||
address,
|
||||
port,
|
||||
)
|
||||
|
||||
await asyncio.wait_for(on_connected, timeout=timeout)
|
||||
return protocol
|
||||
|
Reference in New Issue
Block a user