Convert pytuya to asyncio

This commit is contained in:
Pierre Ståhl
2020-10-01 09:40:12 +02:00
committed by rospogrigio
parent 084b3a741a
commit cad31f1ffe
8 changed files with 457 additions and 308 deletions

View File

@@ -52,12 +52,10 @@ localtuya:
""" """
import asyncio import asyncio
import logging import logging
from datetime import timedelta, datetime
from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry from homeassistant.config_entries import SOURCE_IMPORT, ConfigEntry
from homeassistant.core import HomeAssistant, callback from homeassistant.core import HomeAssistant, callback
from homeassistant.const import ( from homeassistant.const import (
CONF_DEVICE_ID,
CONF_PLATFORM, CONF_PLATFORM,
CONF_ENTITIES, CONF_ENTITIES,
SERVICE_RELOAD, SERVICE_RELOAD,
@@ -73,9 +71,6 @@ from .common import TuyaDevice
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
UNSUB_LISTENER = "unsub_listener" UNSUB_LISTENER = "unsub_listener"
UNSUB_TRACK = "unsub_track"
POLL_INTERVAL = 30
CONFIG_SCHEMA = config_schema() CONFIG_SCHEMA = config_schema()
@@ -136,24 +131,8 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
device = TuyaDevice(hass, entry.data) device = TuyaDevice(hass, entry.data)
async def update_state(now):
"""Read device status and update platforms."""
status = None
try:
status = await hass.async_add_executor_job(device.status)
except Exception:
_LOGGER.debug("update failed")
signal = f"localtuya_{entry.data[CONF_DEVICE_ID]}"
async_dispatcher_send(hass, signal, status)
unsub_track = async_track_time_interval(
hass, update_state, timedelta(seconds=POLL_INTERVAL)
)
hass.data[DOMAIN][entry.entry_id] = { hass.data[DOMAIN][entry.entry_id] = {
UNSUB_LISTENER: unsub_listener, UNSUB_LISTENER: unsub_listener,
UNSUB_TRACK: unsub_track,
TUYA_DEVICE: device, TUYA_DEVICE: device,
} }
@@ -166,8 +145,7 @@ async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry):
for entity in entry.data[CONF_ENTITIES] for entity in entry.data[CONF_ENTITIES]
] ]
) )
device.connect()
await update_state(datetime.now())
hass.async_create_task(setup_entities()) hass.async_create_task(setup_entities())
@@ -188,7 +166,7 @@ async def async_unload_entry(hass: HomeAssistant, entry: ConfigEntry):
) )
hass.data[DOMAIN][entry.entry_id][UNSUB_LISTENER]() hass.data[DOMAIN][entry.entry_id][UNSUB_LISTENER]()
hass.data[DOMAIN][entry.entry_id][UNSUB_TRACK]() hass.data[DOMAIN][entry.entry_id][TUYA_DEVICE].close()
if unload_ok: if unload_ok:
hass.data[DOMAIN].pop(entry.entry_id) hass.data[DOMAIN].pop(entry.entry_id)

View File

@@ -1,8 +1,9 @@
"""Code shared between all platforms.""" """Code shared between all platforms."""
import asyncio
import logging import logging
from time import time, sleep from random import randrange
from threading import Lock
from homeassistant.core import callback
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from homeassistant.helpers.dispatcher import ( from homeassistant.helpers.dispatcher import (
async_dispatcher_connect, async_dispatcher_connect,
@@ -23,6 +24,8 @@ from .const import CONF_LOCAL_KEY, CONF_PROTOCOL_VERSION, DOMAIN, TUYA_DEVICE
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
BACKOFF_TIME_UPPER_LIMIT = 300 # Five minutes
def prepare_setup_entities(hass, config_entry, platform): def prepare_setup_entities(hass, config_entry, platform):
"""Prepare ro setup entities for a platform.""" """Prepare ro setup entities for a platform."""
@@ -60,7 +63,7 @@ async def async_setup_entry(
# Add DPS used by this platform to the request list # Add DPS used by this platform to the request list
for dp_conf in dps_config_fields: for dp_conf in dps_config_fields:
if dp_conf in device_config: if dp_conf in device_config:
tuyainterface._interface.add_dps_to_request(device_config[dp_conf]) tuyainterface._dps_to_request[device_config[dp_conf]] = None
entities.append( entities.append(
entity_class( entity_class(
@@ -88,92 +91,98 @@ def get_entity_config(config_entry, dps_id):
raise Exception(f"missing entity config for id {dps_id}") raise Exception(f"missing entity config for id {dps_id}")
class TuyaDevice: class TuyaDevice(pytuya.TuyaListener):
"""Cache wrapper for pytuya.TuyaInterface.""" """Cache wrapper for pytuya.TuyaInterface."""
def __init__(self, hass, config_entry): def __init__(self, hass, config_entry):
"""Initialize the cache.""" """Initialize the cache."""
self._cached_status = ""
self._cached_status_time = 0
self._interface = pytuya.TuyaInterface(
config_entry[CONF_DEVICE_ID],
config_entry[CONF_HOST],
config_entry[CONF_LOCAL_KEY],
float(config_entry[CONF_PROTOCOL_VERSION]),
)
for entity in config_entry[CONF_ENTITIES]:
# this has to be done in case the device type is type_0d
self._interface.add_dps_to_request(entity[CONF_ID])
self._friendly_name = config_entry[CONF_FRIENDLY_NAME]
self._hass = hass self._hass = hass
self._lock = Lock() self._config_entry = config_entry
self._interface = None
self._status = {}
self._dps_to_request = {}
self._connect_task = None
self._connection_attempts = 0
@property # This has to be done in case the device type is type_0d
def unique_id(self): for entity in config_entry[CONF_ENTITIES]:
"""Return unique device identifier.""" self._dps_to_request[entity[CONF_ID]] = None
return self._interface.id
def __get_status(self): def connect(self, delay=None):
_LOGGER.debug("running def __get_status from TuyaDevice") """Connet to device if not already connected."""
for i in range(5): if self._connect_task is None:
try: self._connect_task = asyncio.ensure_future(self._make_connection())
status = self._interface.status()
return status
except Exception as e:
print(
"Failed to update status of device [{}]: [{}]".format(
self._interface.address, e
)
)
sleep(1.0)
if i + 1 == 3:
_LOGGER.error(
"Failed to update status of device %s", self._interface.address
)
# return None
raise ConnectionError("Failed to update status .")
def set_dps(self, state, dps_index): async def _make_connection(self):
"""Change value of a DP of the Tuya device and update the cached status.""" # Do nothing if already connected
# _LOGGER.info("running def set_dps from TuyaDevice") if self._interface:
# No need to clear the cache here: let's just update the status of the return
# changed dps as returned by the interface (see 5 lines below)
# self._cached_status = ""
# self._cached_status_time = 0
for i in range(5):
try:
result = self._interface.set_dps(state, dps_index)
self._cached_status["dps"].update(result["dps"])
signal = f"localtuya_{self._interface.id}"
async_dispatcher_send(self._hass, signal, self._cached_status)
return
except Exception as e:
print(
"Failed to set status of device [{}]: [{}]".format(
self._interface.address, e
)
)
if i + 1 == 3:
_LOGGER.error(
"Failed to set status of device %s", self._interface.address
)
return
# raise ConnectionError("Failed to set status.") # The sleep gives another task the possibility to sweep in and
# connect, so we block that here
self._interface = True
backoff = min(
randrange(2 ** self._connection_attempts), BACKOFF_TIME_UPPER_LIMIT
)
_LOGGER.debug("Waiting %d seconds before connecting", backoff)
await asyncio.sleep(backoff)
def status(self):
"""Get the state of the Tuya device and cache the results."""
_LOGGER.debug("running def status(self) from TuyaDevice")
self._lock.acquire()
try: try:
now = time() _LOGGER.debug("Connecting to %s", self._config_entry[CONF_HOST])
if not self._cached_status or now - self._cached_status_time > 10: self._interface = await pytuya.connect(
sleep(0.5) self._config_entry[CONF_HOST],
self._cached_status = self.__get_status() self._config_entry[CONF_DEVICE_ID],
self._cached_status_time = time() self._config_entry[CONF_LOCAL_KEY],
return self._cached_status float(self._config_entry[CONF_PROTOCOL_VERSION]),
finally: self,
self._lock.release() )
self._interface.add_dps_to_request(self._dps_to_request)
_LOGGER.debug("Retrieving initial state")
status = await self._interface.status()
if status is None:
raise Exception("failed to retrieve status")
self.status_updated(status)
self._connection_attempts = 0
except Exception:
_LOGGER.exception("connect failed")
self._connection_attempts += 1
self._interface.close()
self._interface = None
self._hass.loop.call_soon(self.connect)
self._connect_task = None
async def set_dps(self, state, dps_index):
"""Change value of a DP of the Tuya device and update the cached status."""
if self._interface is not None:
try:
await self._interface.set_dps(state, dps_index)
except Exception:
_LOGGER.exception("Failed to set DP {dps_index} to state")
else:
_LOGGER.error(
"Not connected to device %s", self._config_entry[CONF_FRIENDLY_NAME]
)
@callback
def status_updated(self, status):
"""Device updated status."""
self._status.update(status["dps"])
signal = f"localtuya_{self._config_entry[CONF_DEVICE_ID]}"
async_dispatcher_send(self._hass, signal, self._status)
@callback
def disconnected(self, exc):
"""Device disconnected."""
signal = f"localtuya_{self._config_entry[CONF_DEVICE_ID]}"
async_dispatcher_send(self._hass, signal, None)
self._interface = None
self.connect()
class LocalTuyaEntity(Entity): class LocalTuyaEntity(Entity):
@@ -212,7 +221,7 @@ class LocalTuyaEntity(Entity):
return { return {
"identifiers": { "identifiers": {
# Serial numbers are unique identifiers within a specific domain # Serial numbers are unique identifiers within a specific domain
(DOMAIN, f"local_{self._device.unique_id}") (DOMAIN, f"local_{self._config_entry.data[CONF_DEVICE_ID]}")
}, },
"name": self._config_entry.data[CONF_FRIENDLY_NAME], "name": self._config_entry.data[CONF_FRIENDLY_NAME],
"manufacturer": "Unknown", "manufacturer": "Unknown",
@@ -233,7 +242,7 @@ class LocalTuyaEntity(Entity):
@property @property
def unique_id(self): def unique_id(self):
"""Return unique device identifier.""" """Return unique device identifier."""
return f"local_{self._device.unique_id}_{self._dps_id}" return f"local_{self._config_entry.data[CONF_DEVICE_ID]}_{self._dps_id}"
def has_config(self, attr): def has_config(self, attr):
"""Return if a config parameter has a valid value.""" """Return if a config parameter has a valid value."""
@@ -243,14 +252,11 @@ class LocalTuyaEntity(Entity):
@property @property
def available(self): def available(self):
"""Return if device is available or not.""" """Return if device is available or not."""
return bool(self._status) return str(self._dps_id) in self._status
def dps(self, dps_index): def dps(self, dps_index):
"""Return cached value for DPS index.""" """Return cached value for DPS index."""
if "dps" not in self._status: value = self._status.get(str(dps_index))
return None
value = self._status["dps"].get(str(dps_index))
if value is None: if value is None:
_LOGGER.warning( _LOGGER.warning(
"Entity %s is requesting unknown DPS index %s", "Entity %s is requesting unknown DPS index %s",

View File

@@ -159,22 +159,23 @@ def config_schema():
async def validate_input(hass: core.HomeAssistant, data): async def validate_input(hass: core.HomeAssistant, data):
"""Validate the user input allows us to connect.""" """Validate the user input allows us to connect."""
tuyainterface = pytuya.TuyaInterface(
data[CONF_DEVICE_ID],
data[CONF_HOST],
data[CONF_LOCAL_KEY],
float(data[CONF_PROTOCOL_VERSION]),
)
detected_dps = {} detected_dps = {}
try: try:
detected_dps = await hass.async_add_executor_job( interface = await pytuya.connect(
tuyainterface.detect_available_dps data[CONF_HOST],
data[CONF_DEVICE_ID],
data[CONF_LOCAL_KEY],
float(data[CONF_PROTOCOL_VERSION]),
) )
detected_dps = await interface.detect_available_dps()
except (ConnectionRefusedError, ConnectionResetError): except (ConnectionRefusedError, ConnectionResetError):
raise CannotConnect raise CannotConnect
except ValueError: except ValueError:
raise InvalidAuth raise InvalidAuth
finally:
interface.close()
return dps_string_list(detected_dps) return dps_string_list(detected_dps)

View File

@@ -1,7 +1,7 @@
"""Platform to locally control Tuya-based cover devices.""" """Platform to locally control Tuya-based cover devices."""
import asyncio
import logging import logging
from functools import partial from functools import partial
from time import sleep
import voluptuous as vol import voluptuous as vol
@@ -112,7 +112,7 @@ class LocaltuyaCover(LocalTuyaEntity, CoverEntity):
return None return None
return self._current_cover_position == 0 return self._current_cover_position == 0
def set_cover_position(self, **kwargs): async def async_set_cover_position(self, **kwargs):
"""Move the cover to a specific position.""" """Move the cover to a specific position."""
_LOGGER.debug("Setting cover position: %r", kwargs[ATTR_POSITION]) _LOGGER.debug("Setting cover position: %r", kwargs[ATTR_POSITION])
if self._config[CONF_POSITIONING_MODE] == COVER_MODE_FAKE: if self._config[CONF_POSITIONING_MODE] == COVER_MODE_FAKE:
@@ -123,36 +123,36 @@ class LocaltuyaCover(LocalTuyaEntity, CoverEntity):
mydelay = posdiff / 50.0 * self._config[CONF_SPAN_TIME] mydelay = posdiff / 50.0 * self._config[CONF_SPAN_TIME]
if newpos > currpos: if newpos > currpos:
_LOGGER.debug("Opening to %f: delay %f", newpos, mydelay) _LOGGER.debug("Opening to %f: delay %f", newpos, mydelay)
self.open_cover() await self.async_open_cover()
else: else:
_LOGGER.debug("Closing to %f: delay %f", newpos, mydelay) _LOGGER.debug("Closing to %f: delay %f", newpos, mydelay)
self.close_cover() await self.async_close_cover()
sleep(mydelay) await asyncio.sleep(mydelay)
self.stop_cover() await self.async_stop_cover()
self._current_cover_position = 50 self._current_cover_position = 50
_LOGGER.debug("Done") _LOGGER.debug("Done")
elif self._config[CONF_POSITIONING_MODE] == COVER_MODE_POSITION: elif self._config[CONF_POSITIONING_MODE] == COVER_MODE_POSITION:
converted_position = int(kwargs[ATTR_POSITION]) converted_position = int(kwargs[ATTR_POSITION])
if 0 <= converted_position <= 100 and self.has_config(CONF_SET_POSITION_DP): if 0 <= converted_position <= 100 and self.has_config(CONF_SET_POSITION_DP):
self._device.set_dps( await self._device.set_dps(
converted_position, self._config[CONF_SET_POSITION_DP] converted_position, self._config[CONF_SET_POSITION_DP]
) )
def open_cover(self, **kwargs): async def async_open_cover(self, **kwargs):
"""Open the cover.""" """Open the cover."""
_LOGGER.debug("Launching command %s to cover ", self._open_cmd) _LOGGER.debug("Launching command %s to cover ", self._open_cmd)
self._device.set_dps(self._open_cmd, self._dps_id) await self._device.set_dps(self._open_cmd, self._dps_id)
def close_cover(self, **kwargs): async def async_close_cover(self, **kwargs):
"""Close cover.""" """Close cover."""
_LOGGER.debug("Launching command %s to cover ", self._close_cmd) _LOGGER.debug("Launching command %s to cover ", self._close_cmd)
self._device.set_dps(self._close_cmd, self._dps_id) await self._device.set_dps(self._close_cmd, self._dps_id)
def stop_cover(self, **kwargs): async def async_stop_cover(self, **kwargs):
"""Stop the cover.""" """Stop the cover."""
_LOGGER.debug("Launching command %s to cover ", COVER_STOP_CMD) _LOGGER.debug("Launching command %s to cover ", COVER_STOP_CMD)
self._device.set_dps(COVER_STOP_CMD, self._dps_id) await self._device.set_dps(COVER_STOP_CMD, self._dps_id)
def status_updated(self): def status_updated(self):
"""Device status was updated.""" """Device status was updated."""

View File

@@ -1,113 +1,113 @@
"""Platform to locally control Tuya-based fan devices.""" """Platform to locally control Tuya-based fan devices."""
import logging import logging
from functools import partial from functools import partial
from homeassistant.components.fan import ( from homeassistant.components.fan import (
FanEntity, FanEntity,
DOMAIN, DOMAIN,
SPEED_OFF, SPEED_OFF,
SPEED_LOW, SPEED_LOW,
SPEED_MEDIUM, SPEED_MEDIUM,
SPEED_HIGH, SPEED_HIGH,
SUPPORT_SET_SPEED, SUPPORT_SET_SPEED,
SUPPORT_OSCILLATE, SUPPORT_OSCILLATE,
) )
from .common import LocalTuyaEntity, async_setup_entry from .common import LocalTuyaEntity, async_setup_entry
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
def flow_schema(dps): def flow_schema(dps):
"""Return schema used in config flow.""" """Return schema used in config flow."""
return {} return {}
class LocaltuyaFan(LocalTuyaEntity, FanEntity): class LocaltuyaFan(LocalTuyaEntity, FanEntity):
"""Representation of a Tuya fan.""" """Representation of a Tuya fan."""
def __init__( def __init__(
self, self,
device, device,
config_entry, config_entry,
fanid, fanid,
**kwargs, **kwargs,
): ):
"""Initialize the entity.""" """Initialize the entity."""
super().__init__(device, config_entry, fanid, **kwargs) super().__init__(device, config_entry, fanid, **kwargs)
self._is_on = False self._is_on = False
self._speed = SPEED_OFF self._speed = SPEED_OFF
self._oscillating = False self._oscillating = False
@property @property
def oscillating(self): def oscillating(self):
"""Return current oscillating status.""" """Return current oscillating status."""
return self._oscillating return self._oscillating
@property @property
def is_on(self): def is_on(self):
"""Check if Tuya fan is on.""" """Check if Tuya fan is on."""
return self._is_on return self._is_on
@property @property
def speed(self) -> str: def speed(self) -> str:
"""Return the current speed.""" """Return the current speed."""
return self._speed return self._speed
@property @property
def speed_list(self) -> list: def speed_list(self) -> list:
"""Get the list of available speeds.""" """Get the list of available speeds."""
return [SPEED_OFF, SPEED_LOW, SPEED_MEDIUM, SPEED_HIGH] return [SPEED_OFF, SPEED_LOW, SPEED_MEDIUM, SPEED_HIGH]
def turn_on(self, speed: str = None, **kwargs) -> None: def turn_on(self, speed: str = None, **kwargs) -> None:
"""Turn on the entity.""" """Turn on the entity."""
self._device.set_dps(True, "1") self._device.set_dps(True, "1")
if speed is not None: if speed is not None:
self.set_speed(speed) self.set_speed(speed)
else: else:
self.schedule_update_ha_state() self.schedule_update_ha_state()
def turn_off(self, **kwargs) -> None: def turn_off(self, **kwargs) -> None:
"""Turn off the entity.""" """Turn off the entity."""
self._device.set_dps(False, "1") self._device.set_dps(False, "1")
self.schedule_update_ha_state() self.schedule_update_ha_state()
def set_speed(self, speed: str) -> None: def set_speed(self, speed: str) -> None:
"""Set the speed of the fan.""" """Set the speed of the fan."""
self._speed = speed self._speed = speed
if speed == SPEED_OFF: if speed == SPEED_OFF:
self._device.set_dps(False, "1") self._device.set_dps(False, "1")
elif speed == SPEED_LOW: elif speed == SPEED_LOW:
self._device.set_dps("1", "2") self._device.set_dps("1", "2")
elif speed == SPEED_MEDIUM: elif speed == SPEED_MEDIUM:
self._device.set_dps("2", "2") self._device.set_dps("2", "2")
elif speed == SPEED_HIGH: elif speed == SPEED_HIGH:
self._device.set_dps("3", "2") self._device.set_dps("3", "2")
self.schedule_update_ha_state() self.schedule_update_ha_state()
def oscillate(self, oscillating: bool) -> None: def oscillate(self, oscillating: bool) -> None:
"""Set oscillation.""" """Set oscillation."""
self._oscillating = oscillating self._oscillating = oscillating
self._device.set_value("8", oscillating) self._device.set_value("8", oscillating)
self.schedule_update_ha_state() self.schedule_update_ha_state()
@property @property
def supported_features(self) -> int: def supported_features(self) -> int:
"""Flag supported features.""" """Flag supported features."""
return SUPPORT_SET_SPEED | SUPPORT_OSCILLATE return SUPPORT_SET_SPEED | SUPPORT_OSCILLATE
def status_updated(self): def status_updated(self):
"""Get state of Tuya fan.""" """Get state of Tuya fan."""
self._is_on = self._status["dps"]["1"] self._is_on = self._status["dps"]["1"]
if not self._status["dps"]["1"]: if not self._status["dps"]["1"]:
self._speed = SPEED_OFF self._speed = SPEED_OFF
elif self._status["dps"]["2"] == "1": elif self._status["dps"]["2"] == "1":
self._speed = SPEED_LOW self._speed = SPEED_LOW
elif self._status["dps"]["2"] == "2": elif self._status["dps"]["2"] == "2":
self._speed = SPEED_MEDIUM self._speed = SPEED_MEDIUM
elif self._status["dps"]["2"] == "3": elif self._status["dps"]["2"] == "3":
self._speed = SPEED_HIGH self._speed = SPEED_HIGH
self._oscillating = self._status["dps"]["8"] self._oscillating = self._status["dps"]["8"]
async_setup_entry = partial(async_setup_entry, DOMAIN, LocaltuyaFan, flow_schema) async_setup_entry = partial(async_setup_entry, DOMAIN, LocaltuyaFan, flow_schema)

View File

@@ -86,12 +86,12 @@ class LocaltuyaLight(LocalTuyaEntity, LightEntity):
supports = supports | SUPPORT_COLOR_TEMP supports = supports | SUPPORT_COLOR_TEMP
return supports return supports
def turn_on(self, **kwargs): async def async_turn_on(self, **kwargs):
"""Turn on or control the light.""" """Turn on or control the light."""
self._device.set_dps(True, self._dps_id) await self._device.set_dps(True, self._dps_id)
if ATTR_BRIGHTNESS in kwargs: if ATTR_BRIGHTNESS in kwargs:
self._device.set_dps( await self._device.set_dps(
max(int(kwargs[ATTR_BRIGHTNESS]), 25), self._config.get(CONF_BRIGHTNESS) max(int(kwargs[ATTR_BRIGHTNESS]), 25), self._config.get(CONF_BRIGHTNESS)
) )
@@ -104,11 +104,11 @@ class LocaltuyaLight(LocalTuyaEntity, LightEntity):
- (255 / (MAX_MIRED - MIN_MIRED)) - (255 / (MAX_MIRED - MIN_MIRED))
* (int(kwargs[ATTR_COLOR_TEMP]) - MIN_MIRED) * (int(kwargs[ATTR_COLOR_TEMP]) - MIN_MIRED)
) )
self._device.set_dps(color_temp, self._config.get(CONF_COLOR_TEMP)) await self._device.set_dps(color_temp, self._config.get(CONF_COLOR_TEMP))
def turn_off(self, **kwargs): async def async_turn_off(self, **kwargs):
"""Turn Tuya light off.""" """Turn Tuya light off."""
self._device.set_dps(False, self._dps_id) await self._device.set_dps(False, self._dps_id)
def status_updated(self): def status_updated(self):
"""Device status was updated.""" """Device status was updated."""

View File

@@ -36,16 +36,17 @@ Credits
Updated pytuya to support devices with Device IDs of 22 characters Updated pytuya to support devices with Device IDs of 22 characters
""" """
import asyncio
import base64 import base64
from hashlib import md5 from hashlib import md5
import json import json
import logging import logging
import socket
import time import time
import binascii import binascii
import struct import struct
import weakref
from collections import namedtuple from collections import namedtuple
from contextlib import contextmanager from abc import ABC, abstractmethod
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
@@ -60,6 +61,7 @@ TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc")
SET = "set" SET = "set"
STATUS = "status" STATUS = "status"
HEARTBEAT = "heartbeat"
PROTOCOL_VERSION_BYTES_31 = b"3.1" PROTOCOL_VERSION_BYTES_31 = b"3.1"
PROTOCOL_VERSION_BYTES_33 = b"3.3" PROTOCOL_VERSION_BYTES_33 = b"3.3"
@@ -73,6 +75,7 @@ MESSAGE_END_FMT = ">2I" # 2*uint32: crc, suffix
PREFIX_VALUE = 0x000055AA PREFIX_VALUE = 0x000055AA
SUFFIX_VALUE = 0x0000AA55 SUFFIX_VALUE = 0x0000AA55
HEARTBEAT_INTERVAL = 20
# This is intended to match requests.json payload at # This is intended to match requests.json payload at
# https://github.com/codetheweb/tuyapi : # https://github.com/codetheweb/tuyapi :
@@ -85,35 +88,18 @@ SUFFIX_VALUE = 0x0000AA55
# length, zero padding implies could be more than one byte) # length, zero padding implies could be more than one byte)
PAYLOAD_DICT = { PAYLOAD_DICT = {
"type_0a": { "type_0a": {
"status": {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}}, STATUS: {"hexByte": 0x0A, "command": {"gwId": "", "devId": ""}},
"set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}}, SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
HEARTBEAT: {"hexByte": 0x09, "command": {}},
}, },
"type_0d": { "type_0d": {
"status": {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}}, STATUS: {"hexByte": 0x0D, "command": {"devId": "", "uid": "", "t": ""}},
"set": {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}}, SET: {"hexByte": 0x07, "command": {"devId": "", "uid": "", "t": ""}},
HEARTBEAT: {"hexByte": 0x09, "command": {}},
}, },
} }
@contextmanager
def socketcontext(address, port, timeout):
"""Context manager which sets up and tears down socket properly."""
s = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
s.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
s.settimeout(timeout)
s.connect((address, port))
try:
yield s
except Exception:
# This should probably be a warning or error, but since this happens
# every now and the and we do retries on a higher level, use debug level
# to not spam log with errors.
_LOGGER.debug("Failed to connect to %s. Raising Exception.", address)
raise
finally:
s.close()
def pack_message(msg): def pack_message(msg):
"""Pack a TuyaMessage into bytes.""" """Pack a TuyaMessage into bytes."""
# Create full message excluding CRC and suffix # Create full message excluding CRC and suffix
@@ -178,12 +164,118 @@ class AESCipher:
return s[: -ord(s[len(s) - 1 :])] return s[: -ord(s[len(s) - 1 :])]
class TuyaInterface: class MessageDispatcher:
"""Represent a Tuya device.""" """Buffer and dispatcher for Tuya messages."""
def __init__( def __init__(self, listener):
self, dev_id, address, local_key, protocol_version, connection_timeout=5 """Initialize a new MessageBuffer."""
): self.buffer = b""
self.listeners = {}
self.last_seqno = -1
self.listener = listener
def abort(self):
"""Abort all waiting clients."""
for key in self.listeners:
sem = self.listeners[key]
self.listeners[key] = None
sem.release()
async def wait_for(self, seqno, timeout=5):
"""Wait for response to a sequence number to be received and return it."""
if self.last_seqno >= seqno:
return None
if seqno in self.listeners:
raise Exception(f"listener exists for {seqno}")
_LOGGER.debug("Waiting for sequence number %d", seqno)
self.listeners[seqno] = asyncio.Semaphore(0)
try:
await asyncio.wait_for(self.listeners[seqno].acquire(), timeout=timeout)
except asyncio.TimeoutError:
del self.listeners[seqno]
raise
return self.listeners.pop(seqno)
def add_data(self, data):
"""Add new data to the buffer and try to parse messages."""
self.buffer += data
header_len = struct.calcsize(MESSAGE_RECV_HEADER_FMT)
while self.buffer:
# Check if enough data for measage header
if len(self.buffer) < header_len:
break
# Parse header and check if enough data according to length in header
_, seqno, cmd, length, retcode = struct.unpack_from(
MESSAGE_RECV_HEADER_FMT, self.buffer
)
if len(self.buffer[header_len - 4 :]) < length:
break
# length includes payload length, retcode, crc and suffix
payload_length = length - 4 - struct.calcsize(MESSAGE_END_FMT)
payload = self.buffer[header_len : header_len + payload_length]
crc, _ = struct.unpack_from(
MESSAGE_END_FMT,
self.buffer[header_len + payload_length : header_len + length],
)
self.buffer = self.buffer[header_len + length - 4 :]
self._dispatch(TuyaMessage(seqno, cmd, retcode, payload, crc))
def _dispatch(self, msg):
"""Dispatch a message to someone that is listening."""
self.last_seqno = max(self.last_seqno, msg.seqno)
if msg.seqno in self.listeners:
_LOGGER.debug("Dispatching sequence number %d", msg.seqno)
sem = self.listeners[msg.seqno]
self.listeners[msg.seqno] = msg
sem.release()
elif msg.cmd == 0x09:
_LOGGER.debug("Got heartbeat response")
elif msg.cmd == 0x08:
_LOGGER.debug("Got status update")
self.listener(msg)
else:
_LOGGER.debug(
"Got message type %d for unknown listener %d: %s",
msg.command,
msg.seqno,
msg,
)
class TuyaListener(ABC):
"""Listener interface for Tuya device changes."""
@abstractmethod
def status_updated(self, status):
"""Device updated status."""
@abstractmethod
def disconnected(self, exc):
"""Device disconnected."""
class EmptyListener(TuyaListener):
"""Listener doing nothing."""
def status_updated(self, status):
"""Device updated status."""
def disconnected(self, exc):
"""Device disconnected."""
class TuyaProtocol(asyncio.Protocol):
"""Implementation of the Tuya protocol."""
def __init__(self, dev_id, local_key, protocol_version, on_connected, listener):
""" """
Initialize a new TuyaInterface. Initialize a new TuyaInterface.
@@ -195,38 +287,75 @@ class TuyaInterface:
Attributes: Attributes:
port (int): The port to connect to. port (int): The port to connect to.
""" """
self.loop = asyncio.get_running_loop()
self.id = dev_id self.id = dev_id
self.address = address
self.local_key = local_key.encode("latin1") self.local_key = local_key.encode("latin1")
self.connection_timeout = connection_timeout
self.version = protocol_version self.version = protocol_version
self.dev_type = "type_0a" self.dev_type = "type_0a"
self.dps_to_request = {} self.dps_to_request = {}
self.cipher = AESCipher(self.local_key) self.cipher = AESCipher(self.local_key)
self.seqno = 0 self.seqno = 0
self.transport = None
self.listener = weakref.ref(listener)
self.dispatcher = self._setup_dispatcher()
self.on_connected = on_connected
self.heartbeater = None
self.port = 6668 # default - do not expect caller to pass in def _setup_dispatcher(self):
def _status_update(msg):
listener = self.listener()
if listener is not None:
listener.status_updated(self._decode_payload(msg.payload))
def exchange(self, command, dps=None): return MessageDispatcher(_status_update)
def connection_made(self, transport):
"""Did connect to the device."""
async def heartbeat_loop():
"""Continuously send heart beat updates."""
while True:
self.heartbeat()
await asyncio.sleep(HEARTBEAT_INTERVAL)
self.transport = transport
self.on_connected.set_result(True)
self.heartbeater = self.loop.create_task(heartbeat_loop())
def data_received(self, data):
"""Received data from device."""
self.dispatcher.add_data(data)
def connection_lost(self, exc):
"""Disconnected from device."""
self.close()
listener = self.listener()
if listener is not None:
listener.disconnected(exc)
def close(self):
"""Close connection and abort all outstanding listeners."""
if self.transport is not None:
self.dispatcher.abort()
self.heartbeater.cancel()
transport = self.transport
self.transport = None
transport.close()
async def exchange(self, command, dps=None):
"""Send and receive a message, returning response from device.""" """Send and receive a message, returning response from device."""
_LOGGER.debug("Sending command %s (device type: %s)", command, self.dev_type) _LOGGER.debug("Sending command %s (device type: %s)", command, self.dev_type)
payload = self._generate_payload(command, dps) payload = self._generate_payload(command, dps)
dev_type = self.dev_type dev_type = self.dev_type
with socketcontext(self.address, self.port, self.connection_timeout) as s: self.transport.write(payload)
s.send(payload) msg = await self.dispatcher.wait_for(self.seqno - 1)
data = s.recv(1024) if msg is None:
_LOGGER.debug("Wait was aborted for %d", self.seqno - 1)
return None
# sometimes the first packet does not contain data (typically 28 bytes): # TODO: Verify stuff, e.g. CRC sequence number?
# need to read again payload = self._decode_payload(msg.payload)
if len(data) < 40:
time.sleep(0.1)
data = s.recv(1024)
msg = unpack_message(data)
# TODO: Verify stuff, e.g. CRC sequence number
payload = self._decode_payload(msg.payload)
# Perform a new exchange (once) if we switched device type # Perform a new exchange (once) if we switched device type
if dev_type != self.dev_type: if dev_type != self.dev_type:
@@ -239,11 +368,16 @@ class TuyaInterface:
return self.exchange(command, dps) return self.exchange(command, dps)
return payload return payload
def status(self): async def status(self):
"""Return device status.""" """Return device status."""
return self.exchange(STATUS) return await self.exchange(STATUS)
def set_dps(self, value, dps_index): def heartbeat(self):
"""Send a heartbeat message."""
# We don't expect a response to this, just send blindly
self.transport.write(self._generate_payload(HEARTBEAT))
async def set_dps(self, value, dps_index):
""" """
Set value (may be any type: bool, int or string) of any dps index. Set value (may be any type: bool, int or string) of any dps index.
@@ -251,9 +385,9 @@ class TuyaInterface:
dps_index(int): dps index to set dps_index(int): dps index to set
value: new value for the dps index value: new value for the dps index
""" """
return self.exchange(SET, {str(dps_index): value}) return await self.exchange(SET, {str(dps_index): value})
def detect_available_dps(self): async def detect_available_dps(self):
"""Return which datapoints are supported by the device.""" """Return which datapoints are supported by the device."""
# type_0d devices need a sort of bruteforce querying in order to detect the # type_0d devices need a sort of bruteforce querying in order to detect the
# list of available dps experience shows that the dps available are usually # list of available dps experience shows that the dps available are usually
@@ -268,7 +402,7 @@ class TuyaInterface:
self.dps_to_request = {"1": None} self.dps_to_request = {"1": None}
self.add_dps_to_request(range(*dps_range)) self.add_dps_to_request(range(*dps_range))
try: try:
data = self.status() data = await self.status()
except Exception as e: except Exception as e:
_LOGGER.warning("Failed to get status: %s", e) _LOGGER.warning("Failed to get status: %s", e)
raise raise
@@ -289,7 +423,9 @@ class TuyaInterface:
def _decode_payload(self, payload): def _decode_payload(self, payload):
_LOGGER.debug("decode payload=%r", payload) _LOGGER.debug("decode payload=%r", payload)
if payload.startswith(PROTOCOL_VERSION_BYTES_31): if not payload:
payload = "{}"
elif payload.startswith(PROTOCOL_VERSION_BYTES_31):
payload = payload[len(PROTOCOL_VERSION_BYTES_31) :] # remove version header payload = payload[len(PROTOCOL_VERSION_BYTES_31) :] # remove version header
# remove (what I'm guessing, but not confirmed is) 16-bytes of MD5 # remove (what I'm guessing, but not confirmed is) 16-bytes of MD5
# hexdigest of payload # hexdigest of payload
@@ -377,4 +513,32 @@ class TuyaInterface:
def __repr__(self): def __repr__(self):
"""Return internal string representation of object.""" """Return internal string representation of object."""
return "%r" % ((self.id, self.address),) # FIXME can do better than this return self.id
async def connect(
address,
device_id,
local_key,
protocol_version,
listener=None,
port=6668,
timeout=5,
):
"""Connect to a device."""
loop = asyncio.get_running_loop()
on_connected = loop.create_future()
_, protocol = await loop.create_connection(
lambda: TuyaProtocol(
device_id,
local_key,
protocol_version,
on_connected,
listener or EmptyListener(),
),
address,
port,
)
await asyncio.wait_for(on_connected, timeout=timeout)
return protocol

View File

@@ -65,13 +65,13 @@ class LocaltuyaSwitch(LocalTuyaEntity, SwitchEntity):
attrs[ATTR_VOLTAGE] = self.dps(self._config[CONF_VOLTAGE]) / 10 attrs[ATTR_VOLTAGE] = self.dps(self._config[CONF_VOLTAGE]) / 10
return attrs return attrs
def turn_on(self, **kwargs): async def async_turn_on(self, **kwargs):
"""Turn Tuya switch on.""" """Turn Tuya switch on."""
self._device.set_dps(True, self._dps_id) await self._device.set_dps(True, self._dps_id)
def turn_off(self, **kwargs): async def async_turn_off(self, **kwargs):
"""Turn Tuya switch off.""" """Turn Tuya switch off."""
self._device.set_dps(False, self._dps_id) await self._device.set_dps(False, self._dps_id)
def status_updated(self): def status_updated(self):
"""Device status was updated.""" """Device status was updated."""