Convert pytuya to asyncio
This commit is contained in:
committed by
rospogrigio
parent
084b3a741a
commit
cad31f1ffe
@@ -52,12 +52,10 @@ localtuya:
|
||||
"""
|
||||
import asyncio
|
||||
import logging
|
||||
from datetime import timedelta, datetime
|
||||
|
||||
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
|
||||
from homeassistant.core import HomeAssistant, callback
|
||||
from homeassistant.const import (
|
||||
CONF_DEVICE_ID,
|
||||
CONF_PLATFORM,
|
||||
CONF_ENTITIES,
|
||||
SERVICE_RELOAD,
|
||||
@@ -73,9 +71,6 @@ from .common import TuyaDevice
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
UNSUB_LISTENER = "unsub_listener"
|
||||
UNSUB_TRACK = "unsub_track"
|
||||
|
||||
POLL_INTERVAL = 30
|
||||
|
||||
CONFIG_SCHEMA = config_schema()
|
||||
|
||||
@@ -136,24 +131,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
|
||||
|
||||
device = TuyaDevice(hass, entry.data)
|
||||
|
||||
async def update_state(now):
|
||||
"""Read device status and update platforms."""
|
||||
status = None
|
||||
try:
|
||||
status = await hass.async_add_executor_job(device.status)
|
||||
except Exception:
|
||||
_LOGGER.debug("update failed")
|
||||
|
||||
signal = f"localtuya_{entry.data[CONF_DEVICE_ID]}"
|
||||
async_dispatcher_send(hass, signal, status)
|
||||
|
||||
unsub_track = async_track_time_interval(
|
||||
hass, update_state, timedelta(seconds=POLL_INTERVAL)
|
||||
)
|
||||
|
||||
hass.data[DOMAIN][entry.entry_id] = {
|
||||
UNSUB_LISTENER: unsub_listener,
|
||||
UNSUB_TRACK: unsub_track,
|
||||
TUYA_DEVICE: device,
|
||||
}
|
||||
|
||||
@@ -166,8 +145,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
|
||||
for entity in entry.data[CONF_ENTITIES]
|
||||
]
|
||||
)
|
||||
|
||||
await update_state(datetime.now())
|
||||
device.connect()
|
||||
|
||||
hass.async_create_task(setup_entities())
|
||||
|
||||
@@ -188,7 +166,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
|
||||
)
|
||||
|
||||
hass.data[DOMAIN][entry.entry_id][UNSUB_LISTENER]()
|
||||
hass.data[DOMAIN][entry.entry_id][UNSUB_TRACK]()
|
||||
hass.data[DOMAIN][entry.entry_id][TUYA_DEVICE].close()
|
||||
if unload_ok:
|
||||
hass.data[DOMAIN].pop(entry.entry_id)
|
||||
|
||||
|
@@ -1,8 +1,9 @@
|
||||
"""Code shared between all platforms."""
|
||||
import asyncio
|
||||
import logging
|
||||
from time import time, sleep
|
||||
from threading import Lock
|
||||
from random import randrange
|
||||
|
||||
from homeassistant.core import callback
|
||||
from homeassistant.helpers.entity import Entity
|
||||
from homeassistant.helpers.dispatcher import (
|
||||
async_dispatcher_connect,
|
||||
@@ -23,6 +24,8 @@ from .const import CONF_LOCAL_KEY, CONF_PROTOCOL_VERSION, DOMAIN, TUYA_DEVICE
|
||||
|
||||
_LOGGER = logging.getLogger(__name__)
|
||||
|
||||
BACKOFF_TIME_UPPER_LIMIT = 300 # Five minutes
|
||||
|
||||
|
||||
def prepare_setup_entities(hass, config_entry, platform):
|
||||
"""Prepare ro setup entities for a platform."""
|
||||
@@ -60,7 +63,7 @@ async def async_setup_entry(
|
||||
# Add DPS used by this platform to the request list
|
||||
for dp_conf in dps_config_fields:
|
||||
if dp_conf in device_config:
|
||||
tuyainterface._interface.add_dps_to_request(device_config[dp_conf])
|
||||
tuyainterface._dps_to_request[device_config[dp_conf]] = None
|
||||
|
||||
entities.append(
|
||||
entity_class(
|
||||
@@ -88,92 +91,98 @@ def get_entity_config(config_entry, dps_id):
|
||||
raise Exception(f"missing entity config for id {dps_id}")
|
||||
|
||||
|
||||
class TuyaDevice:
|
||||
class TuyaDevice(pytuya.TuyaListener):
|
||||
"""Cache wrapper for pytuya.TuyaInterface."""
|
||||
|
||||
def __init__(self, hass, config_entry):
|
||||
"""Initialize the cache."""
|
||||
self._cached_status = ""
|
||||
self._cached_status_time = 0
|
||||
self._interface = pytuya.TuyaInterface(
|
||||
config_entry[CONF_DEVICE_ID],
|
||||
config_entry[CONF_HOST],
|
||||
config_entry[CONF_LOCAL_KEY],
|
||||
float(config_entry[CONF_PROTOCOL_VERSION]),
|
||||
)
|
||||
for entity in config_entry[CONF_ENTITIES]:
|
||||
# this has to be done in case the device type is type_0d
|
||||
self._interface.add_dps_to_request(entity[CONF_ID])
|
||||
self._friendly_name = config_entry[CONF_FRIENDLY_NAME]
|
||||
self._hass = hass
|
||||
self._lock = Lock()
|
||||
self._config_entry = config_entry
|
||||
self._interface = None
|
||||
self._status = {}
|
||||
self._dps_to_request = {}
|
||||
self._connect_task = None
|
||||
self._connection_attempts = 0
|
||||
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return unique device identifier."""
|
||||
return self._interface.id
|
||||
# This has to be done in case the device type is type_0d
|
||||
for entity in config_entry[CONF_ENTITIES]:
|
||||
self._dps_to_request[entity[CONF_ID]] = None
|
||||
|
||||
def connect(self, delay=None):
|
||||
"""Connet to device if not already connected."""
|
||||
if self._connect_task is None:
|
||||
self._connect_task = asyncio.ensure_future(self._make_connection())
|
||||
|
||||
async def _make_connection(self):
|
||||
# Do nothing if already connected
|
||||
if self._interface:
|
||||
return
|
||||
|
||||
# The sleep gives another task the possibility to sweep in and
|
||||
# connect, so we block that here
|
||||
self._interface = True
|
||||
|
||||
backoff = min(
|
||||
randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT
|
||||
)
|
||||
|
||||
_LOGGER.debug("Waiting %d seconds before connecting", backoff)
|
||||
await asyncio.sleep(backoff)
|
||||
|
||||
def __get_status(self):
|
||||
_LOGGER.debug("running def __get_status from TuyaDevice")
|
||||
for i in range(5):
|
||||
try:
|
||||
status = self._interface.status()
|
||||
return status
|
||||
except Exception as e:
|
||||
print(
|
||||
"Failed to update status of device [{}]: [{}]".format(
|
||||
self._interface.address, e
|
||||
_LOGGER.debug("Connecting to %s", self._config_entry[CONF_HOST])
|
||||
self._interface = await pytuya.connect(
|
||||
self._config_entry[CONF_HOST],
|
||||
self._config_entry[CONF_DEVICE_ID],
|
||||
self._config_entry[CONF_LOCAL_KEY],
|
||||
float(self._config_entry[CONF_PROTOCOL_VERSION]),
|
||||
self,
|
||||
)
|
||||
)
|
||||
sleep(1.0)
|
||||
if i + 1 == 3:
|
||||
_LOGGER.error(
|
||||
"Failed to update status of device %s", self._interface.address
|
||||
)
|
||||
# return None
|
||||
raise ConnectionError("Failed to update status .")
|
||||
self._interface.add_dps_to_request(self._dps_to_request)
|
||||
|
||||
def set_dps(self, state, dps_index):
|
||||
_LOGGER.debug("Retrieving initial state")
|
||||
status = await self._interface.status()
|
||||
if status is None:
|
||||
raise Exception("failed to retrieve status")
|
||||
|
||||
self.status_updated(status)
|
||||
self._connection_attempts = 0
|
||||
except Exception:
|
||||
_LOGGER.exception("connect failed")
|
||||
self._connection_attempts += 1
|
||||
self._interface.close()
|
||||
self._interface = None
|
||||
self._hass.loop.call_soon(self.connect)
|
||||
self._connect_task = None
|
||||
|
||||
async def set_dps(self, state, dps_index):
|
||||
"""Change value of a DP of the Tuya device and update the cached status."""
|
||||
# _LOGGER.info("running def set_dps from TuyaDevice")
|
||||
# No need to clear the cache here: let's just update the status of the
|
||||
# changed dps as returned by the interface (see 5 lines below)
|
||||
# self._cached_status = ""
|
||||
# self._cached_status_time = 0
|
||||
for i in range(5):
|
||||
if self._interface is not None:
|
||||
try:
|
||||
result = self._interface.set_dps(state, dps_index)
|
||||
self._cached_status["dps"].update(result["dps"])
|
||||
signal = f"localtuya_{self._interface.id}"
|
||||
async_dispatcher_send(self._hass, signal, self._cached_status)
|
||||
return
|
||||
except Exception as e:
|
||||
print(
|
||||
"Failed to set status of device [{}]: [{}]".format(
|
||||
self._interface.address, e
|
||||
)
|
||||
)
|
||||
if i + 1 == 3:
|
||||
await self._interface.set_dps(state, dps_index)
|
||||
except Exception:
|
||||
_LOGGER.exception("Failed to set DP {dps_index} to state")
|
||||
else:
|
||||
_LOGGER.error(
|
||||
"Failed to set status of device %s", self._interface.address
|
||||
"Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME]
|
||||
)
|
||||
return
|
||||
|
||||
# raise ConnectionError("Failed to set status.")
|
||||
@callback
|
||||
def status_updated(self, status):
|
||||
"""Device updated status."""
|
||||
self._status.update(status["dps"])
|
||||
|
||||
def status(self):
|
||||
"""Get the state of the Tuya device and cache the results."""
|
||||
_LOGGER.debug("running def status(self) from TuyaDevice")
|
||||
self._lock.acquire()
|
||||
try:
|
||||
now = time()
|
||||
if not self._cached_status or now - self._cached_status_time > 10:
|
||||
sleep(0.5)
|
||||
self._cached_status = self.__get_status()
|
||||
self._cached_status_time = time()
|
||||
return self._cached_status
|
||||
finally:
|
||||
self._lock.release()
|
||||
signal = f"localtuya_{self._config_entry[CONF_DEVICE_ID]}"
|
||||
async_dispatcher_send(self._hass, signal, self._status)
|
||||
|
||||
@callback
|
||||
def disconnected(self, exc):
|
||||
"""Device disconnected."""
|
||||
signal = f"localtuya_{self._config_entry[CONF_DEVICE_ID]}"
|
||||
async_dispatcher_send(self._hass, signal, None)
|
||||
|
||||
self._interface = None
|
||||
self.connect()
|
||||
|
||||
|
||||
class LocalTuyaEntity(Entity):
|
||||
@@ -212,7 +221,7 @@ class LocalTuyaEntity(Entity):
|
||||
return {
|
||||
"identifiers": {
|
||||
# Serial numbers are unique identifiers within a specific domain
|
||||
(DOMAIN, f"local_{self._device.unique_id}")
|
||||
(DOMAIN, f"local_{self._config_entry.data[CONF_DEVICE_ID]}")
|
||||
},
|
||||
"name": self._config_entry.data[CONF_FRIENDLY_NAME],
|
||||
"manufacturer": "Unknown",
|
||||
@@ -233,7 +242,7 @@ class LocalTuyaEntity(Entity):
|
||||
@property
|
||||
def unique_id(self):
|
||||
"""Return unique device identifier."""
|
||||
return f"local_{self._device.unique_id}_{self._dps_id}"
|
||||
return f"local_{self._config_entry.data[CONF_DEVICE_ID]}_{self._dps_id}"
|
||||
|
||||
def has_config(self, attr):
|
||||
"""Return if a config parameter has a valid value."""
|
||||
@@ -243,14 +252,11 @@ class LocalTuyaEntity(Entity):
|
||||
@property
|
||||
def available(self):
|
||||
"""Return if device is available or not."""
|
||||
return bool(self._status)
|
||||
return str(self._dps_id) in self._status
|
||||
|
||||
def dps(self, dps_index):
|
||||
"""Return cached value for DPS index."""
|
||||
if "dps" not in self._status:
|
||||
return None
|
||||
|
||||
value = self._status["dps"].get(str(dps_index))
|
||||
value = self._status.get(str(dps_index))
|
||||
if value is None:
|
||||
_LOGGER.warning(
|
||||
"Entity %s is requesting unknown DPS index %s",
|
||||
|
@@ -159,22 +159,23 @@ def config_schema():
|
||||
|
||||
async def validate_input(hass: core.HomeAssistant, data):
|
||||
"""Validate the user input allows us to connect."""
|
||||
tuyainterface = pytuya.TuyaInterface(
|
||||
data[CONF_DEVICE_ID],
|
||||
data[CONF_HOST],
|
||||
data[CONF_LOCAL_KEY],
|
||||
float(data[CONF_PROTOCOL_VERSION]),
|
||||
)
|
||||
detected_dps = {}
|
||||
|
||||
try:
|
||||
detected_dps = await hass.async_add_executor_job(
|
||||
tuyainterface.detect_available_dps
|
||||
interface = await pytuya.connect(
|
||||
data[CONF_HOST],
|
||||
data[CONF_DEVICE_ID],
|
||||
data[CONF_LOCAL_KEY],
|
||||
float(data[CONF_PROTOCOL_VERSION]),
|
||||
)
|
||||
|
||||
detected_dps = await interface.detect_available_dps()
|
||||
except (ConnectionRefusedError, ConnectionResetError):
|
||||
raise CannotConnect
|
||||
except ValueError:
|
||||
raise InvalidAuth
|
||||
finally:
|
||||
interface.close()
|
||||
|
||||
return dps_string_list(detected_dps)
|
||||
|
||||
|
@@ -1,7 +1,7 @@
|
||||
"""Platform to locally control Tuya-based cover devices."""
|
||||
import asyncio
|
||||
import logging
|
||||
from functools import partial
|
||||
from time import sleep
|
||||
|
||||
import voluptuous as vol
|
||||
|
||||
@@ -112,7 +112,7 @@ class LocaltuyaCover(LocalTuyaEntity, CoverEntity):
|
||||
return None
|
||||
return self._current_cover_position == 0
|
||||
|
||||
def set_cover_position(self, **kwargs):
|
||||
async def async_set_cover_position(self, **kwargs):
|
||||
"""Move the cover to a specific position."""
|
||||
_LOGGER.debug("Setting cover position: %r", kwargs[ATTR_POSITION])
|
||||
if self._config[CONF_POSITIONING_MODE] == COVER_MODE_FAKE:
|
||||
@@ -123,36 +123,36 @@ class LocaltuyaCover(LocalTuyaEntity, CoverEntity):
|
||||
mydelay = posdiff / 50.0 * self._config[CONF_SPAN_TIME]
|
||||
if newpos > currpos:
|
||||
_LOGGER.debug("Opening to %f: delay %f", newpos, mydelay)
|
||||
self.open_cover()
|
||||
await self.async_open_cover()
|
||||
else:
|
||||
_LOGGER.debug("Closing to %f: delay %f", newpos, mydelay)
|
||||
self.close_cover()
|
||||
sleep(mydelay)
|
||||
self.stop_cover()
|
||||
await self.async_close_cover()
|
||||
await asyncio.sleep(mydelay)
|
||||
await self.async_stop_cover()
|
||||
self._current_cover_position = 50
|
||||
_LOGGER.debug("Done")
|
||||
|
||||
elif self._config[CONF_POSITIONING_MODE] == COVER_MODE_POSITION:
|
||||
converted_position = int(kwargs[ATTR_POSITION])
|
||||
if 0 <= converted_position <= 100 and self.has_config(CONF_SET_POSITION_DP):
|
||||
self._device.set_dps(
|
||||
await self._device.set_dps(
|
||||
converted_position, self._config[CONF_SET_POSITION_DP]
|
||||
)
|
||||
|
||||
def open_cover(self, **kwargs):
|
||||
async def async_open_cover(self, **kwargs):
|
||||
"""Open the cover."""
|
||||
_LOGGER.debug("Launching command %s to cover ", self._open_cmd)
|
||||
self._device.set_dps(self._open_cmd, self._dps_id)
|
||||
await self._device.set_dps(self._open_cmd, self._dps_id)
|
||||
|
||||
def close_cover(self, **kwargs):
|
||||
async def async_close_cover(self, **kwargs):
|
||||
"""Close cover."""
|
||||
_LOGGER.debug("Launching command %s to cover ", self._close_cmd)
|
||||
self._device.set_dps(self._close_cmd, self._dps_id)
|
||||
await self._device.set_dps(self._close_cmd, self._dps_id)
|
||||
|
||||
def stop_cover(self, **kwargs):
|
||||
async def async_stop_cover(self, **kwargs):
|
||||
"""Stop the cover."""
|
||||
_LOGGER.debug("Launching command %s to cover ", COVER_STOP_CMD)
|
||||
self._device.set_dps(COVER_STOP_CMD, self._dps_id)
|
||||
await self._device.set_dps(COVER_STOP_CMD, self._dps_id)
|
||||
|
||||
def status_updated(self):
|
||||
"""Device status was updated."""
|
||||
|
@@ -86,12 +86,12 @@ class LocaltuyaLight(LocalTuyaEntity, LightEntity):
|
||||
supports = supports | SUPPORT_COLOR_TEMP
|
||||
return supports
|
||||
|
||||
def turn_on(self, **kwargs):
|
||||
async def async_turn_on(self, **kwargs):
|
||||
"""Turn on or control the light."""
|
||||
self._device.set_dps(True, self._dps_id)
|
||||
await self._device.set_dps(True, self._dps_id)
|
||||
|
||||
if ATTR_BRIGHTNESS in kwargs:
|
||||
self._device.set_dps(
|
||||
await self._device.set_dps(
|
||||
max(int(kwargs[ATTR_BRIGHTNESS]), 25), self._config.get(CONF_BRIGHTNESS)
|
||||
)
|
||||
|
||||
@@ -104,11 +104,11 @@ class LocaltuyaLight(LocalTuyaEntity, LightEntity):
|
||||
- (255 / (MAX_MIRED - MIN_MIRED))
|
||||
* (int(kwargs[ATTR_COLOR_TEMP]) - MIN_MIRED)
|
||||
)
|
||||
self._device.set_dps(color_temp, self._config.get(CONF_COLOR_TEMP))
|
||||
await self._device.set_dps(color_temp, self._config.get(CONF_COLOR_TEMP))
|
||||
|
||||
def turn_off(self, **kwargs):
|
||||
async def async_turn_off(self, **kwargs):
|
||||
"""Turn Tuya light off."""
|
||||
self._device.set_dps(False, self._dps_id)
|
||||
await self._device.set_dps(False, self._dps_id)
|
||||
|
||||
def status_updated(self):
|
||||
"""Device status was updated."""
|
||||
|
@@ -36,16 +36,17 @@ Credits
|
||||
Updated pytuya to support devices with Device IDs of 22 characters
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
from hashlib import md5
|
||||
import json
|
||||
import logging
|
||||
import socket
|
||||
import time
|
||||
import binascii
|
||||
import struct
|
||||
import weakref
|
||||
from collections import namedtuple
|
||||
from contextlib import contextmanager
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
from cryptography.hazmat.backends import default_backend
|
||||
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
|
||||
@@ -60,6 +61,7 @@ TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc")
|
||||
|
||||
SET = "set"
|
||||
STATUS = "status"
|
||||
HEARTBEAT = "heartbeat"
|
||||
|
||||
PROTOCOL_VERSION_BYTES_31 = b"3.1"
|
||||
PROTOCOL_VERSION_BYTES_33 = b"3.3"
|
||||
@@ -73,6 +75,7 @@ MESSAGE_END_FMT = ">2I" # 2*uint32: crc, suffix
|
||||
PREFIX_VALUE = 0x000055AA
|
||||
SUFFIX_VALUE = 0x0000AA55
|
||||
|
||||
HEARTBEAT_INTERVAL = 20
|
||||
|
||||
# This is intended to match requests.json payload at
|
||||
# https://github.com/codetheweb/tuyapi :
|
||||
@@ -85,35 +88,18 @@ SUFFIX_VALUE = 0x0000AA55
|
||||
# length, zero padding implies could be more than one byte)
|
||||
PAYLOAD_DICT = {
|
||||
"type_0a": {
|
||||
"status": {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}},
|
||||
"set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
STATUS: {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}},
|
||||
SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
HEARTBEAT: {"hexByte": 0x09, "command": {}},
|
||||
},
|
||||
"type_0d": {
|
||||
"status": {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
"set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
STATUS: {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
|
||||
HEARTBEAT: {"hexByte": 0x09, "command": {}},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@contextmanager
|
||||
def socketcontext(address, port, timeout):
|
||||
"""Context manager which sets up and tears down socket properly."""
|
||||
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
|
||||
s.settimeout(timeout)
|
||||
s.connect((address, port))
|
||||
try:
|
||||
yield s
|
||||
except Exception:
|
||||
# This should probably be a warning or error, but since this happens
|
||||
# every now and the and we do retries on a higher level, use debug level
|
||||
# to not spam log with errors.
|
||||
_LOGGER.debug("Failed to connect to %s. Raising Exception.", address)
|
||||
raise
|
||||
finally:
|
||||
s.close()
|
||||
|
||||
|
||||
def pack_message(msg):
|
||||
"""Pack a TuyaMessage into bytes."""
|
||||
# Create full message excluding CRC and suffix
|
||||
@@ -178,12 +164,118 @@ class AESCipher:
|
||||
return s[: -ord(s[len(s) - 1 :])]
|
||||
|
||||
|
||||
class TuyaInterface:
|
||||
"""Represent a Tuya device."""
|
||||
class MessageDispatcher:
|
||||
"""Buffer and dispatcher for Tuya messages."""
|
||||
|
||||
def __init__(
|
||||
self, dev_id, address, local_key, protocol_version, connection_timeout=5
|
||||
):
|
||||
def __init__(self, listener):
|
||||
"""Initialize a new MessageBuffer."""
|
||||
self.buffer = b""
|
||||
self.listeners = {}
|
||||
self.last_seqno = -1
|
||||
self.listener = listener
|
||||
|
||||
def abort(self):
|
||||
"""Abort all waiting clients."""
|
||||
for key in self.listeners:
|
||||
sem = self.listeners[key]
|
||||
self.listeners[key] = None
|
||||
sem.release()
|
||||
|
||||
async def wait_for(self, seqno, timeout=5):
|
||||
"""Wait for response to a sequence number to be received and return it."""
|
||||
if self.last_seqno >= seqno:
|
||||
return None
|
||||
|
||||
if seqno in self.listeners:
|
||||
raise Exception(f"listener exists for {seqno}")
|
||||
|
||||
_LOGGER.debug("Waiting for sequence number %d", seqno)
|
||||
self.listeners[seqno] = asyncio.Semaphore(0)
|
||||
try:
|
||||
await asyncio.wait_for(self.listeners[seqno].acquire(), timeout=timeout)
|
||||
except asyncio.TimeoutError:
|
||||
del self.listeners[seqno]
|
||||
raise
|
||||
|
||||
return self.listeners.pop(seqno)
|
||||
|
||||
def add_data(self, data):
|
||||
"""Add new data to the buffer and try to parse messages."""
|
||||
self.buffer += data
|
||||
header_len = struct.calcsize(MESSAGE_RECV_HEADER_FMT)
|
||||
|
||||
while self.buffer:
|
||||
# Check if enough data for measage header
|
||||
if len(self.buffer) < header_len:
|
||||
break
|
||||
|
||||
# Parse header and check if enough data according to length in header
|
||||
_, seqno, cmd, length, retcode = struct.unpack_from(
|
||||
MESSAGE_RECV_HEADER_FMT, self.buffer
|
||||
)
|
||||
if len(self.buffer[header_len - 4 :]) < length:
|
||||
break
|
||||
|
||||
# length includes payload length, retcode, crc and suffix
|
||||
payload_length = length - 4 - struct.calcsize(MESSAGE_END_FMT)
|
||||
payload = self.buffer[header_len : header_len + payload_length]
|
||||
|
||||
crc, _ = struct.unpack_from(
|
||||
MESSAGE_END_FMT,
|
||||
self.buffer[header_len + payload_length : header_len + length],
|
||||
)
|
||||
|
||||
self.buffer = self.buffer[header_len + length - 4 :]
|
||||
self._dispatch(TuyaMessage(seqno, cmd, retcode, payload, crc))
|
||||
|
||||
def _dispatch(self, msg):
|
||||
"""Dispatch a message to someone that is listening."""
|
||||
self.last_seqno = max(self.last_seqno, msg.seqno)
|
||||
if msg.seqno in self.listeners:
|
||||
_LOGGER.debug("Dispatching sequence number %d", msg.seqno)
|
||||
sem = self.listeners[msg.seqno]
|
||||
self.listeners[msg.seqno] = msg
|
||||
sem.release()
|
||||
elif msg.cmd == 0x09:
|
||||
_LOGGER.debug("Got heartbeat response")
|
||||
elif msg.cmd == 0x08:
|
||||
_LOGGER.debug("Got status update")
|
||||
self.listener(msg)
|
||||
else:
|
||||
_LOGGER.debug(
|
||||
"Got message type %d for unknown listener %d: %s",
|
||||
msg.command,
|
||||
msg.seqno,
|
||||
msg,
|
||||
)
|
||||
|
||||
|
||||
class TuyaListener(ABC):
|
||||
"""Listener interface for Tuya device changes."""
|
||||
|
||||
@abstractmethod
|
||||
def status_updated(self, status):
|
||||
"""Device updated status."""
|
||||
|
||||
@abstractmethod
|
||||
def disconnected(self, exc):
|
||||
"""Device disconnected."""
|
||||
|
||||
|
||||
class EmptyListener(TuyaListener):
|
||||
"""Listener doing nothing."""
|
||||
|
||||
def status_updated(self, status):
|
||||
"""Device updated status."""
|
||||
|
||||
def disconnected(self, exc):
|
||||
"""Device disconnected."""
|
||||
|
||||
|
||||
class TuyaProtocol(asyncio.Protocol):
|
||||
"""Implementation of the Tuya protocol."""
|
||||
|
||||
def __init__(self, dev_id, local_key, protocol_version, on_connected, listener):
|
||||
"""
|
||||
Initialize a new TuyaInterface.
|
||||
|
||||
@@ -195,37 +287,74 @@ class TuyaInterface:
|
||||
Attributes:
|
||||
port (int): The port to connect to.
|
||||
"""
|
||||
self.loop = asyncio.get_running_loop()
|
||||
self.id = dev_id
|
||||
self.address = address
|
||||
self.local_key = local_key.encode("latin1")
|
||||
self.connection_timeout = connection_timeout
|
||||
self.version = protocol_version
|
||||
self.dev_type = "type_0a"
|
||||
self.dps_to_request = {}
|
||||
self.cipher = AESCipher(self.local_key)
|
||||
self.seqno = 0
|
||||
self.transport = None
|
||||
self.listener = weakref.ref(listener)
|
||||
self.dispatcher = self._setup_dispatcher()
|
||||
self.on_connected = on_connected
|
||||
self.heartbeater = None
|
||||
|
||||
self.port = 6668 # default - do not expect caller to pass in
|
||||
def _setup_dispatcher(self):
|
||||
def _status_update(msg):
|
||||
listener = self.listener()
|
||||
if listener is not None:
|
||||
listener.status_updated(self._decode_payload(msg.payload))
|
||||
|
||||
def exchange(self, command, dps=None):
|
||||
return MessageDispatcher(_status_update)
|
||||
|
||||
def connection_made(self, transport):
|
||||
"""Did connect to the device."""
|
||||
|
||||
async def heartbeat_loop():
|
||||
"""Continuously send heart beat updates."""
|
||||
while True:
|
||||
self.heartbeat()
|
||||
await asyncio.sleep(HEARTBEAT_INTERVAL)
|
||||
|
||||
self.transport = transport
|
||||
self.on_connected.set_result(True)
|
||||
self.heartbeater = self.loop.create_task(heartbeat_loop())
|
||||
|
||||
def data_received(self, data):
|
||||
"""Received data from device."""
|
||||
self.dispatcher.add_data(data)
|
||||
|
||||
def connection_lost(self, exc):
|
||||
"""Disconnected from device."""
|
||||
self.close()
|
||||
listener = self.listener()
|
||||
if listener is not None:
|
||||
listener.disconnected(exc)
|
||||
|
||||
def close(self):
|
||||
"""Close connection and abort all outstanding listeners."""
|
||||
if self.transport is not None:
|
||||
self.dispatcher.abort()
|
||||
self.heartbeater.cancel()
|
||||
transport = self.transport
|
||||
self.transport = None
|
||||
transport.close()
|
||||
|
||||
async def exchange(self, command, dps=None):
|
||||
"""Send and receive a message, returning response from device."""
|
||||
_LOGGER.debug("Sending command %s (device type: %s)", command, self.dev_type)
|
||||
payload = self._generate_payload(command, dps)
|
||||
dev_type = self.dev_type
|
||||
|
||||
with socketcontext(self.address, self.port, self.connection_timeout) as s:
|
||||
s.send(payload)
|
||||
data = s.recv(1024)
|
||||
|
||||
# sometimes the first packet does not contain data (typically 28 bytes):
|
||||
# need to read again
|
||||
if len(data) < 40:
|
||||
time.sleep(0.1)
|
||||
data = s.recv(1024)
|
||||
|
||||
msg = unpack_message(data)
|
||||
# TODO: Verify stuff, e.g. CRC sequence number
|
||||
self.transport.write(payload)
|
||||
msg = await self.dispatcher.wait_for(self.seqno - 1)
|
||||
if msg is None:
|
||||
_LOGGER.debug("Wait was aborted for %d", self.seqno - 1)
|
||||
return None
|
||||
|
||||
# TODO: Verify stuff, e.g. CRC sequence number?
|
||||
payload = self._decode_payload(msg.payload)
|
||||
|
||||
# Perform a new exchange (once) if we switched device type
|
||||
@@ -239,11 +368,16 @@ class TuyaInterface:
|
||||
return self.exchange(command, dps)
|
||||
return payload
|
||||
|
||||
def status(self):
|
||||
async def status(self):
|
||||
"""Return device status."""
|
||||
return self.exchange(STATUS)
|
||||
return await self.exchange(STATUS)
|
||||
|
||||
def set_dps(self, value, dps_index):
|
||||
def heartbeat(self):
|
||||
"""Send a heartbeat message."""
|
||||
# We don't expect a response to this, just send blindly
|
||||
self.transport.write(self._generate_payload(HEARTBEAT))
|
||||
|
||||
async def set_dps(self, value, dps_index):
|
||||
"""
|
||||
Set value (may be any type: bool, int or string) of any dps index.
|
||||
|
||||
@@ -251,9 +385,9 @@ class TuyaInterface:
|
||||
dps_index(int): dps index to set
|
||||
value: new value for the dps index
|
||||
"""
|
||||
return self.exchange(SET, {str(dps_index): value})
|
||||
return await self.exchange(SET, {str(dps_index): value})
|
||||
|
||||
def detect_available_dps(self):
|
||||
async def detect_available_dps(self):
|
||||
"""Return which datapoints are supported by the device."""
|
||||
# type_0d devices need a sort of bruteforce querying in order to detect the
|
||||
# list of available dps experience shows that the dps available are usually
|
||||
@@ -268,7 +402,7 @@ class TuyaInterface:
|
||||
self.dps_to_request = {"1": None}
|
||||
self.add_dps_to_request(range(*dps_range))
|
||||
try:
|
||||
data = self.status()
|
||||
data = await self.status()
|
||||
except Exception as e:
|
||||
_LOGGER.warning("Failed to get status: %s", e)
|
||||
raise
|
||||
@@ -289,7 +423,9 @@ class TuyaInterface:
|
||||
def _decode_payload(self, payload):
|
||||
_LOGGER.debug("decode payload=%r", payload)
|
||||
|
||||
if payload.startswith(PROTOCOL_VERSION_BYTES_31):
|
||||
if not payload:
|
||||
payload = "{}"
|
||||
elif payload.startswith(PROTOCOL_VERSION_BYTES_31):
|
||||
payload = payload[len(PROTOCOL_VERSION_BYTES_31) :] # remove version header
|
||||
# remove (what I'm guessing, but not confirmed is) 16-bytes of MD5
|
||||
# hexdigest of payload
|
||||
@@ -377,4 +513,32 @@ class TuyaInterface:
|
||||
|
||||
def __repr__(self):
|
||||
"""Return internal string representation of object."""
|
||||
return "%r" % ((self.id, self.address),) # FIXME can do better than this
|
||||
return self.id
|
||||
|
||||
|
||||
async def connect(
|
||||
address,
|
||||
device_id,
|
||||
local_key,
|
||||
protocol_version,
|
||||
listener=None,
|
||||
port=6668,
|
||||
timeout=5,
|
||||
):
|
||||
"""Connect to a device."""
|
||||
loop = asyncio.get_running_loop()
|
||||
on_connected = loop.create_future()
|
||||
_, protocol = await loop.create_connection(
|
||||
lambda: TuyaProtocol(
|
||||
device_id,
|
||||
local_key,
|
||||
protocol_version,
|
||||
on_connected,
|
||||
listener or EmptyListener(),
|
||||
),
|
||||
address,
|
||||
port,
|
||||
)
|
||||
|
||||
await asyncio.wait_for(on_connected, timeout=timeout)
|
||||
return protocol
|
||||
|
@@ -65,13 +65,13 @@ class LocaltuyaSwitch(LocalTuyaEntity, SwitchEntity):
|
||||
attrs[ATTR_VOLTAGE] = self.dps(self._config[CONF_VOLTAGE]) / 10
|
||||
return attrs
|
||||
|
||||
def turn_on(self, **kwargs):
|
||||
async def async_turn_on(self, **kwargs):
|
||||
"""Turn Tuya switch on."""
|
||||
self._device.set_dps(True, self._dps_id)
|
||||
await self._device.set_dps(True, self._dps_id)
|
||||
|
||||
def turn_off(self, **kwargs):
|
||||
async def async_turn_off(self, **kwargs):
|
||||
"""Turn Tuya switch off."""
|
||||
self._device.set_dps(False, self._dps_id)
|
||||
await self._device.set_dps(False, self._dps_id)
|
||||
|
||||
def status_updated(self):
|
||||
"""Device status was updated."""
|
||||
|
Reference in New Issue
Block a user