Fixed tox issues

This commit is contained in:
rospogrigio
2023-01-09 23:45:06 +01:00
committed by rospogrigio
parent c9d6bc521e
commit 3bf69d69f0
4 changed files with 214 additions and 126 deletions

View File

@@ -88,7 +88,9 @@ CONFIGURE_DEVICE_SCHEMA = vol.Schema(
vol.Required(CONF_LOCAL_KEY): str, vol.Required(CONF_LOCAL_KEY): str,
vol.Required(CONF_HOST): str, vol.Required(CONF_HOST): str,
vol.Required(CONF_DEVICE_ID): str, vol.Required(CONF_DEVICE_ID): str,
vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.2", "3.3", "3.4"]), vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(
["3.1", "3.2", "3.3", "3.4"]
),
vol.Optional(CONF_SCAN_INTERVAL): int, vol.Optional(CONF_SCAN_INTERVAL): int,
vol.Optional(CONF_MANUAL_DPS): str, vol.Optional(CONF_MANUAL_DPS): str,
vol.Optional(CONF_RESET_DPIDS): str, vol.Optional(CONF_RESET_DPIDS): str,
@@ -101,7 +103,9 @@ DEVICE_SCHEMA = vol.Schema(
vol.Required(CONF_DEVICE_ID): cv.string, vol.Required(CONF_DEVICE_ID): cv.string,
vol.Required(CONF_LOCAL_KEY): cv.string, vol.Required(CONF_LOCAL_KEY): cv.string,
vol.Required(CONF_FRIENDLY_NAME): cv.string, vol.Required(CONF_FRIENDLY_NAME): cv.string,
vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.2", "3.3", "3.4"]), vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(
["3.1", "3.2", "3.3", "3.4"]
),
vol.Optional(CONF_SCAN_INTERVAL): int, vol.Optional(CONF_SCAN_INTERVAL): int,
vol.Optional(CONF_MANUAL_DPS): cv.string, vol.Optional(CONF_MANUAL_DPS): cv.string,
vol.Optional(CONF_RESET_DPIDS): str, vol.Optional(CONF_RESET_DPIDS): str,
@@ -144,7 +148,9 @@ def options_schema(entities):
vol.Required(CONF_FRIENDLY_NAME): str, vol.Required(CONF_FRIENDLY_NAME): str,
vol.Required(CONF_HOST): str, vol.Required(CONF_HOST): str,
vol.Required(CONF_LOCAL_KEY): str, vol.Required(CONF_LOCAL_KEY): str,
vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(["3.1", "3.2", "3.3", "3.4"]), vol.Required(CONF_PROTOCOL_VERSION, default="3.3"): vol.In(
["3.1", "3.2", "3.3", "3.4"]
),
vol.Optional(CONF_SCAN_INTERVAL): int, vol.Optional(CONF_SCAN_INTERVAL): int,
vol.Optional(CONF_MANUAL_DPS): str, vol.Optional(CONF_MANUAL_DPS): str,
vol.Optional(CONF_RESET_DPIDS): str, vol.Optional(CONF_RESET_DPIDS): str,

View File

@@ -46,7 +46,7 @@ import time
import weakref import weakref
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import namedtuple from collections import namedtuple
from hashlib import md5,sha256 from hashlib import md5, sha256
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes from cryptography.hazmat.primitives.ciphers import Cipher, algorithms, modes
@@ -58,11 +58,13 @@ __author__ = "rospogrigio"
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
# Tuya Packet Format # Tuya Packet Format
TuyaHeader = namedtuple('TuyaHeader', 'prefix seqno cmd length') TuyaHeader = namedtuple("TuyaHeader", "prefix seqno cmd length")
MessagePayload = namedtuple("MessagePayload", "cmd payload") MessagePayload = namedtuple("MessagePayload", "cmd payload")
try: try:
TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc crc_good", defaults=(True,)) TuyaMessage = namedtuple(
except: "TuyaMessage", "seqno cmd retcode payload crc crc_good", defaults=(True,)
)
except Exception:
TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc crc_good") TuyaMessage = namedtuple("TuyaMessage", "seqno cmd retcode payload crc crc_good")
# TinyTuya Error Response Codes # TinyTuya Error Response Codes
@@ -99,30 +101,38 @@ error_codes = {
None: "Unknown Error", None: "Unknown Error",
} }
class DecodeError(Exception):
"""Specific Exception caused by decoding error."""
pass
# Tuya Command Types # Tuya Command Types
# Reference: https://github.com/tuya/tuya-iotos-embeded-sdk-wifi-ble-bk7231n/blob/master/sdk/include/lan_protocol.h # Reference:
AP_CONFIG = 0x01 # FRM_TP_CFG_WF # only used for ap 3.0 network config # https://github.com/tuya/tuya-iotos-embeded-sdk-wifi-ble-bk7231n/blob/master/sdk/include/lan_protocol.h
ACTIVE = 0x02 # FRM_TP_ACTV (discard) # WORK_MODE_CMD AP_CONFIG = 0x01 # FRM_TP_CFG_WF # only used for ap 3.0 network config
SESS_KEY_NEG_START = 0x03 # FRM_SECURITY_TYPE3 # negotiate session key ACTIVE = 0x02 # FRM_TP_ACTV (discard) # WORK_MODE_CMD
SESS_KEY_NEG_RESP = 0x04 # FRM_SECURITY_TYPE4 # negotiate session key response SESS_KEY_NEG_START = 0x03 # FRM_SECURITY_TYPE3 # negotiate session key
SESS_KEY_NEG_RESP = 0x04 # FRM_SECURITY_TYPE4 # negotiate session key response
SESS_KEY_NEG_FINISH = 0x05 # FRM_SECURITY_TYPE5 # finalize session key negotiation SESS_KEY_NEG_FINISH = 0x05 # FRM_SECURITY_TYPE5 # finalize session key negotiation
UNBIND = 0x06 # FRM_TP_UNBIND_DEV # DATA_QUERT_CMD - issue command UNBIND = 0x06 # FRM_TP_UNBIND_DEV # DATA_QUERT_CMD - issue command
CONTROL = 0x07 # FRM_TP_CMD # STATE_UPLOAD_CMD CONTROL = 0x07 # FRM_TP_CMD # STATE_UPLOAD_CMD
STATUS = 0x08 # FRM_TP_STAT_REPORT # STATE_QUERY_CMD STATUS = 0x08 # FRM_TP_STAT_REPORT # STATE_QUERY_CMD
HEART_BEAT = 0x09 # FRM_TP_HB HEART_BEAT = 0x09 # FRM_TP_HB
DP_QUERY = 0x0a # 10 # FRM_QUERY_STAT # UPDATE_START_CMD - get data points DP_QUERY = 0x0A # 10 # FRM_QUERY_STAT # UPDATE_START_CMD - get data points
QUERY_WIFI = 0x0b # 11 # FRM_SSID_QUERY (discard) # UPDATE_TRANS_CMD QUERY_WIFI = 0x0B # 11 # FRM_SSID_QUERY (discard) # UPDATE_TRANS_CMD
TOKEN_BIND = 0x0c # 12 # FRM_USER_BIND_REQ # GET_ONLINE_TIME_CMD - system time (GMT) TOKEN_BIND = 0x0C # 12 # FRM_USER_BIND_REQ # GET_ONLINE_TIME_CMD - system time (GMT)
CONTROL_NEW = 0x0d # 13 # FRM_TP_NEW_CMD # FACTORY_MODE_CMD CONTROL_NEW = 0x0D # 13 # FRM_TP_NEW_CMD # FACTORY_MODE_CMD
ENABLE_WIFI = 0x0e # 14 # FRM_ADD_SUB_DEV_CMD # WIFI_TEST_CMD ENABLE_WIFI = 0x0E # 14 # FRM_ADD_SUB_DEV_CMD # WIFI_TEST_CMD
WIFI_INFO = 0x0f # 15 # FRM_CFG_WIFI_INFO WIFI_INFO = 0x0F # 15 # FRM_CFG_WIFI_INFO
DP_QUERY_NEW = 0x10 # 16 # FRM_QUERY_STAT_NEW DP_QUERY_NEW = 0x10 # 16 # FRM_QUERY_STAT_NEW
SCENE_EXECUTE = 0x11 # 17 # FRM_SCENE_EXEC SCENE_EXECUTE = 0x11 # 17 # FRM_SCENE_EXEC
UPDATEDPS = 0x12 # 18 # FRM_LAN_QUERY_DP # Request refresh of DPS UPDATEDPS = 0x12 # 18 # FRM_LAN_QUERY_DP # Request refresh of DPS
UDP_NEW = 0x13 # 19 # FR_TYPE_ENCRYPTION UDP_NEW = 0x13 # 19 # FR_TYPE_ENCRYPTION
AP_CONFIG_NEW = 0x14 # 20 # FRM_AP_CFG_WF_V40 AP_CONFIG_NEW = 0x14 # 20 # FRM_AP_CFG_WF_V40
BOARDCAST_LPV34 = 0x23 # 35 # FR_TYPE_BOARDCAST_LPV34 BOARDCAST_LPV34 = 0x23 # 35 # FR_TYPE_BOARDCAST_LPV34
LAN_EXT_STREAM = 0x40 # 64 # FRM_LAN_EXT_STREAM LAN_EXT_STREAM = 0x40 # 64 # FRM_LAN_EXT_STREAM
PROTOCOL_VERSION_BYTES_31 = b"3.1" PROTOCOL_VERSION_BYTES_31 = b"3.1"
@@ -141,7 +151,15 @@ PREFIX_VALUE = 0x000055AA
PREFIX_BIN = b"\x00\x00U\xaa" PREFIX_BIN = b"\x00\x00U\xaa"
SUFFIX_VALUE = 0x0000AA55 SUFFIX_VALUE = 0x0000AA55
SUFFIX_BIN = b"\x00\x00\xaaU" SUFFIX_BIN = b"\x00\x00\xaaU"
NO_PROTOCOL_HEADER_CMDS = [DP_QUERY, DP_QUERY_NEW, UPDATEDPS, HEART_BEAT, SESS_KEY_NEG_START, SESS_KEY_NEG_RESP, SESS_KEY_NEG_FINISH ] NO_PROTOCOL_HEADER_CMDS = [
DP_QUERY,
DP_QUERY_NEW,
UPDATEDPS,
HEART_BEAT,
SESS_KEY_NEG_START,
SESS_KEY_NEG_RESP,
SESS_KEY_NEG_FINISH,
]
HEARTBEAT_INTERVAL = 10 HEARTBEAT_INTERVAL = 10
@@ -193,15 +211,13 @@ payload_dict = {
"v3.4": { "v3.4": {
CONTROL: { CONTROL: {
"command_override": CONTROL_NEW, # Uses CONTROL_NEW command "command_override": CONTROL_NEW, # Uses CONTROL_NEW command
"command": {"protocol":5, "t": "int", "data": ""} "command": {"protocol": 5, "t": "int", "data": ""},
}, },
DP_QUERY: { "command_override": DP_QUERY_NEW }, DP_QUERY: {"command_override": DP_QUERY_NEW},
} },
} }
class TuyaLoggingAdapter(logging.LoggerAdapter): class TuyaLoggingAdapter(logging.LoggerAdapter):
"""Adapter that adds device id to all log points.""" """Adapter that adds device id to all log points."""
@@ -243,7 +259,7 @@ class ContextualLogger:
return self._logger.exception(msg, *args) return self._logger.exception(msg, *args)
def pack_message(msg,hmac_key=None): def pack_message(msg, hmac_key=None):
"""Pack a TuyaMessage into bytes.""" """Pack a TuyaMessage into bytes."""
end_fmt = MESSAGE_END_FMT_HMAC if hmac_key else MESSAGE_END_FMT end_fmt = MESSAGE_END_FMT_HMAC if hmac_key else MESSAGE_END_FMT
# Create full message excluding CRC and suffix # Create full message excluding CRC and suffix
@@ -262,9 +278,7 @@ def pack_message(msg,hmac_key=None):
else: else:
crc = binascii.crc32(buffer) & 0xFFFFFFFF crc = binascii.crc32(buffer) & 0xFFFFFFFF
# Calculate CRC, add it together with suffix # Calculate CRC, add it together with suffix
buffer += struct.pack( buffer += struct.pack(end_fmt, crc, SUFFIX_VALUE)
end_fmt, crc, SUFFIX_VALUE
)
return buffer return buffer
@@ -277,55 +291,82 @@ def unpack_message(data, hmac_key=None, header=None, no_retcode=False, logger=No
end_len = struct.calcsize(end_fmt) end_len = struct.calcsize(end_fmt)
headret_len = header_len + retcode_len headret_len = header_len + retcode_len
if len(data) < headret_len+end_len: if len(data) < headret_len + end_len:
logger.debug('unpack_message(): not enough data to unpack header! need %d but only have %d', headret_len+end_len, len(data)) logger.debug(
raise DecodeError('Not enough data to unpack header') "unpack_message(): not enough data to unpack header! need %d but only have %d",
headret_len + end_len,
len(data),
)
raise DecodeError("Not enough data to unpack header")
if header is None: if header is None:
header = parse_header(data) header = parse_header(data)
if len(data) < header_len+header.length: if len(data) < header_len + header.length:
logger.debug('unpack_message(): not enough data to unpack payload! need %d but only have %d', header_len+header.length, len(data)) logger.debug(
raise DecodeError('Not enough data to unpack payload') "unpack_message(): not enough data to unpack payload! need %d but only have %d",
header_len + header.length,
len(data),
)
raise DecodeError("Not enough data to unpack payload")
retcode = 0 if no_retcode else struct.unpack(MESSAGE_RETCODE_FMT, data[header_len:headret_len])[0] retcode = (
0
if no_retcode
else struct.unpack(MESSAGE_RETCODE_FMT, data[header_len:headret_len])[0]
)
# the retcode is technically part of the payload, but strip it as we do not want it here # the retcode is technically part of the payload, but strip it as we do not want it here
payload = data[header_len+retcode_len:header_len+header.length] payload = data[header_len + retcode_len : header_len + header.length]
crc, suffix = struct.unpack(end_fmt, payload[-end_len:]) crc, suffix = struct.unpack(end_fmt, payload[-end_len:])
if hmac_key: if hmac_key:
have_crc = hmac.new(hmac_key, data[:(header_len+header.length)-end_len], sha256).digest() have_crc = hmac.new(
hmac_key, data[: (header_len + header.length) - end_len], sha256
).digest()
else: else:
have_crc = binascii.crc32(data[:(header_len+header.length)-end_len]) & 0xFFFFFFFF have_crc = (
binascii.crc32(data[: (header_len + header.length) - end_len]) & 0xFFFFFFFF
)
if suffix != SUFFIX_VALUE: if suffix != SUFFIX_VALUE:
logger.debug('Suffix prefix wrong! %08X != %08X', suffix, SUFFIX_VALUE) logger.debug("Suffix prefix wrong! %08X != %08X", suffix, SUFFIX_VALUE)
if crc != have_crc: if crc != have_crc:
if hmac_key: if hmac_key:
logger.debug('HMAC checksum wrong! %r != %r', binascii.hexlify(have_crc), binascii.hexlify(crc)) logger.debug(
"HMAC checksum wrong! %r != %r",
binascii.hexlify(have_crc),
binascii.hexlify(crc),
)
else: else:
logger.debug('CRC wrong! %08X != %08X', have_crc, crc) logger.debug("CRC wrong! %08X != %08X", have_crc, crc)
return TuyaMessage(
header.seqno, header.cmd, retcode, payload[:-end_len], crc, crc == have_crc
)
return TuyaMessage(header.seqno, header.cmd, retcode, payload[:-end_len], crc, crc == have_crc)
def parse_header(data): def parse_header(data):
"""Unpack bytes into a TuyaHeader."""
header_len = struct.calcsize(MESSAGE_HEADER_FMT) header_len = struct.calcsize(MESSAGE_HEADER_FMT)
if len(data) < header_len: if len(data) < header_len:
raise DecodeError('Not enough data to unpack header') raise DecodeError("Not enough data to unpack header")
prefix, seqno, cmd, payload_len = struct.unpack( prefix, seqno, cmd, payload_len = struct.unpack(
MESSAGE_HEADER_FMT, data[:header_len] MESSAGE_HEADER_FMT, data[:header_len]
) )
if prefix != PREFIX_VALUE: if prefix != PREFIX_VALUE:
#self.debug('Header prefix wrong! %08X != %08X', prefix, PREFIX_VALUE) # self.debug('Header prefix wrong! %08X != %08X', prefix, PREFIX_VALUE)
raise DecodeError('Header prefix wrong! %08X != %08X' % (prefix, PREFIX_VALUE)) raise DecodeError("Header prefix wrong! %08X != %08X" % (prefix, PREFIX_VALUE))
# sanity check. currently the max payload length is somewhere around 300 bytes # sanity check. currently the max payload length is somewhere around 300 bytes
if payload_len > 1000: if payload_len > 1000:
raise DecodeError('Header claims the packet size is over 1000 bytes! It is most likely corrupt. Claimed size: %d bytes' % payload_len) raise DecodeError(
"Header claims the packet size is over 1000 bytes! It is most likely corrupt. Claimed size: %d bytes"
% payload_len
)
return TuyaHeader(prefix, seqno, cmd, payload_len) return TuyaHeader(prefix, seqno, cmd, payload_len)
@@ -341,7 +382,8 @@ class AESCipher:
def encrypt(self, raw, use_base64=True, pad=True): def encrypt(self, raw, use_base64=True, pad=True):
"""Encrypt data to be sent to device.""" """Encrypt data to be sent to device."""
encryptor = self.cipher.encryptor() encryptor = self.cipher.encryptor()
if pad: raw = self._pad(raw) if pad:
raw = self._pad(raw)
crypted_text = encryptor.update(raw) + encryptor.finalize() crypted_text = encryptor.update(raw) + encryptor.finalize()
return base64.b64encode(crypted_text) if use_base64 else crypted_text return base64.b64encode(crypted_text) if use_base64 else crypted_text
@@ -373,13 +415,13 @@ class MessageDispatcher(ContextualLogger):
RESET_SEQNO = -101 RESET_SEQNO = -101
SESS_KEY_SEQNO = -102 SESS_KEY_SEQNO = -102
def __init__(self, dev_id, listener, version, local_key): def __init__(self, dev_id, listener, protocol_version, local_key):
"""Initialize a new MessageBuffer.""" """Initialize a new MessageBuffer."""
super().__init__() super().__init__()
self.buffer = b"" self.buffer = b""
self.listeners = {} self.listeners = {}
self.listener = listener self.listener = listener
self.version = version self.version = protocol_version
self.local_key = local_key self.local_key = local_key
self.set_logger(_LOGGER, dev_id) self.set_logger(_LOGGER, dev_id)
@@ -420,7 +462,9 @@ class MessageDispatcher(ContextualLogger):
header = parse_header(self.buffer) header = parse_header(self.buffer)
hmac_key = self.local_key if self.version == 3.4 else None hmac_key = self.local_key if self.version == 3.4 else None
msg = unpack_message(self.buffer, header=header, hmac_key=hmac_key, logger=self); msg = unpack_message(
self.buffer, header=header, hmac_key=hmac_key, logger=self
)
self.buffer = self.buffer[header_len - 4 + header.length :] self.buffer = self.buffer[header_len - 4 + header.length :]
self._dispatch(msg) self._dispatch(msg)
@@ -514,7 +558,6 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
self.id = dev_id self.id = dev_id
self.local_key = local_key.encode("latin1") self.local_key = local_key.encode("latin1")
self.real_local_key = self.local_key self.real_local_key = self.local_key
self.version = protocol_version
self.dev_type = "type_0a" self.dev_type = "type_0a"
self.dps_to_request = {} self.dps_to_request = {}
self.cipher = AESCipher(self.local_key) self.cipher = AESCipher(self.local_key)
@@ -525,6 +568,8 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
self.on_connected = on_connected self.on_connected = on_connected
self.heartbeater = None self.heartbeater = None
self.dps_cache = {} self.dps_cache = {}
self.local_nonce = b"0123456789abcdef" # not-so-random random key
self.remote_nonce = b""
if protocol_version: if protocol_version:
self.set_version(float(protocol_version)) self.set_version(float(protocol_version))
@@ -533,26 +578,27 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
# them (such as BulbDevice) make connections when called # them (such as BulbDevice) make connections when called
TuyaProtocol.set_version(self, 3.1) TuyaProtocol.set_version(self, 3.1)
def set_version(self, version): def set_version(self, protocol_version):
self.version = version """Set the device version and eventually start available DPs detection."""
self.version_bytes = str(version).encode('latin1') self.version = protocol_version
self.version_bytes = str(protocol_version).encode("latin1")
self.version_header = self.version_bytes + PROTOCOL_3x_HEADER self.version_header = self.version_bytes + PROTOCOL_3x_HEADER
if version == 3.2: # 3.2 behaves like 3.3 with type_0d if protocol_version == 3.2: # 3.2 behaves like 3.3 with type_0d
#self.version = 3.3 # self.version = 3.3
self.dev_type="type_0d" self.dev_type = "type_0d"
if self.dps_to_request == {}: if self.dps_to_request == {}:
self.detect_available_dps() self.detect_available_dps()
elif version == 3.4: elif protocol_version == 3.4:
self.dev_type = "v3.4" self.dev_type = "v3.4"
elif self.dev_type == "v3.4": elif self.dev_type == "v3.4":
self.dev_type = "default" self.dev_type = "default"
def error_json(self, number=None, payload=None): def error_json(self, number=None, payload=None):
"""Return error details in JSON""" """Return error details in JSON."""
try: try:
spayload = json.dumps(payload) spayload = json.dumps(payload)
# spayload = payload.replace('\"','').replace('\'','') # spayload = payload.replace('\"','').replace('\'','')
except: except Exception:
spayload = '""' spayload = '""'
vals = (error_codes[number], str(number), spayload) vals = (error_codes[number], str(number), spayload)
@@ -640,43 +686,51 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
self.transport = None self.transport = None
transport.close() transport.close()
# similar to exchange() but never retries sending and does not decode the response
async def exchange_quick(self, payload, recv_retries): async def exchange_quick(self, payload, recv_retries):
"""Similar to exchange() but never retries sending and does not decode the response."""
if not self.transport: if not self.transport:
self.debug("[" + self.id + "] send quick failed, could not get socket: %s", payload) self.debug(
"[" + self.id + "] send quick failed, could not get socket: %s", payload
)
return None return None
enc_payload = self._encode_message(payload) if type(payload) == MessagePayload else payload enc_payload = (
self._encode_message(payload)
if isinstance(payload, MessagePayload)
else payload
)
# self.debug("Quick-dispatching message %s, seqno %s", binascii.hexlify(enc_payload), self.seqno) # self.debug("Quick-dispatching message %s, seqno %s", binascii.hexlify(enc_payload), self.seqno)
try: try:
self.transport.write(enc_payload) self.transport.write(enc_payload)
except: except Exception:
# self._check_socket_close(True) # self._check_socket_close(True)
self.close() self.close()
return None return None
while recv_retries: while recv_retries:
try: try:
#msg = await self._receive()
seqno = MessageDispatcher.SESS_KEY_SEQNO seqno = MessageDispatcher.SESS_KEY_SEQNO
# seqno = self.seqno - 1
msg = await self.dispatcher.wait_for(seqno, payload.cmd) msg = await self.dispatcher.wait_for(seqno, payload.cmd)
# for 3.4 devices, we get the starting seqno with the SESS_KEY_NEG_RESP message # for 3.4 devices, we get the starting seqno with the SESS_KEY_NEG_RESP message
self.seqno = msg.seqno self.seqno = msg.seqno
except: except Exception:
msg = None msg = None
if msg and len(msg.payload) != 0: if msg and len(msg.payload) != 0:
return msg return msg
recv_retries -= 1 recv_retries -= 1
if recv_retries == 0: if recv_retries == 0:
self.debug("received null payload (%r) but out of recv retries, giving up", msg) self.debug(
"received null payload (%r) but out of recv retries, giving up", msg
)
else: else:
self.debug("received null payload (%r), fetch new one - %s retries remaining", msg, recv_retries) self.debug(
"received null payload (%r), fetch new one - %s retries remaining",
msg,
recv_retries,
)
return None return None
async def exchange(self, command, dps=None): async def exchange(self, command, dps=None):
"""Send and receive a message, returning response from device.""" """Send and receive a message, returning response from device."""
if self.version == 3.4 and self.real_local_key == self.local_key: if self.version == 3.4 and self.real_local_key == self.local_key:
self.debug("3.4 device: negotiating a new session key") self.debug("3.4 device: negotiating a new session key")
await self._negotiate_session_key() await self._negotiate_session_key()
@@ -701,7 +755,7 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
enc_payload = self._encode_message(payload) enc_payload = self._encode_message(payload)
self.transport.write(enc_payload) self.transport.write(enc_payload)
msg = await self.dispatcher.wait_for(seqno, payload.cmd ) msg = await self.dispatcher.wait_for(seqno, payload.cmd)
if msg is None: if msg is None:
self.debug("Wait was aborted for seqno %d", seqno) self.debug("Wait was aborted for seqno %d", seqno)
return None return None
@@ -822,7 +876,7 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
try: try:
# self.debug("decrypting=%r", payload) # self.debug("decrypting=%r", payload)
payload = cipher.decrypt(payload, False, decode_text=False) payload = cipher.decrypt(payload, False, decode_text=False)
except: except Exception:
self.debug("incomplete payload=%r (len:%d)", payload, len(payload)) self.debug("incomplete payload=%r (len:%d)", payload, len(payload))
return self.error_json(ERR_PAYLOAD) return self.error_json(ERR_PAYLOAD)
@@ -835,9 +889,9 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
# Decrypt payload # Decrypt payload
# Remove 16-bytes of MD5 hexdigest of payload # Remove 16-bytes of MD5 hexdigest of payload
payload = cipher.decrypt(payload[16:]) payload = cipher.decrypt(payload[16:])
elif self.version >= 3.2: # 3.2 or 3.3 or 3.4 elif self.version >= 3.2: # 3.2 or 3.3 or 3.4
# Trim header for non-default device type # Trim header for non-default device type
if payload.startswith( self.version_bytes ): if payload.startswith(self.version_bytes):
payload = payload[len(self.version_header) :] payload = payload[len(self.version_header) :]
# self.debug("removing 3.x=%r", payload) # self.debug("removing 3.x=%r", payload)
elif self.dev_type == "type_0d" and (len(payload) & 0x0F) != 0: elif self.dev_type == "type_0d" and (len(payload) & 0x0F) != 0:
@@ -848,7 +902,7 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
try: try:
# self.debug("decrypting=%r", payload) # self.debug("decrypting=%r", payload)
payload = cipher.decrypt(payload, False) payload = cipher.decrypt(payload, False)
except: except Exception:
self.debug("incomplete payload=%r (len:%d)", payload, len(payload)) self.debug("incomplete payload=%r (len:%d)", payload, len(payload))
return self.error_json(ERR_PAYLOAD) return self.error_json(ERR_PAYLOAD)
@@ -858,7 +912,7 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
if not isinstance(payload, str): if not isinstance(payload, str):
try: try:
payload = payload.decode() payload = payload.decode()
except: except Exception:
self.debug("payload was not string type and decoding failed") self.debug("payload was not string type and decoding failed")
return self.error_json(ERR_JSON, payload) return self.error_json(ERR_JSON, payload)
if "data unvalid" in payload: if "data unvalid" in payload:
@@ -877,28 +931,34 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
self.debug("Deciphered data = %r", payload) self.debug("Deciphered data = %r", payload)
try: try:
json_payload = json.loads(payload) json_payload = json.loads(payload)
except: except Exception:
json_payload = self.error_json(ERR_JSON, payload) json_payload = self.error_json(ERR_JSON, payload)
# v3.4 stuffs it into {"data":{"dps":{"1":true}}, ...} # v3.4 stuffs it into {"data":{"dps":{"1":true}}, ...}
if "dps" not in json_payload and "data" in json_payload and "dps" in json_payload['data']: if (
json_payload['dps'] = json_payload['data']['dps'] "dps" not in json_payload
and "data" in json_payload
and "dps" in json_payload["data"]
):
json_payload["dps"] = json_payload["data"]["dps"]
return json_payload return json_payload
async def _negotiate_session_key(self): async def _negotiate_session_key(self):
self.local_nonce = b'0123456789abcdef' # not-so-random random key
self.remote_nonce = b''
self.local_key = self.real_local_key self.local_key = self.real_local_key
rkey = await self.exchange_quick( MessagePayload(SESS_KEY_NEG_START, self.local_nonce), 2 ) rkey = await self.exchange_quick(
if not rkey or type(rkey) != TuyaMessage or len(rkey.payload) < 48: MessagePayload(SESS_KEY_NEG_START, self.local_nonce), 2
)
if not rkey or not isinstance(rkey, TuyaMessage) or len(rkey.payload) < 48:
# error # error
self.debug("session key negotiation failed on step 1") self.debug("session key negotiation failed on step 1")
return False return False
if rkey.cmd != SESS_KEY_NEG_RESP: if rkey.cmd != SESS_KEY_NEG_RESP:
self.debug("session key negotiation step 2 returned wrong command: %d", rkey.cmd) self.debug(
"session key negotiation step 2 returned wrong command: %d", rkey.cmd
)
return False return False
payload = rkey.payload payload = rkey.payload
@@ -906,8 +966,12 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
# self.debug("decrypting %r using %r", payload, self.real_local_key) # self.debug("decrypting %r using %r", payload, self.real_local_key)
cipher = AESCipher(self.real_local_key) cipher = AESCipher(self.real_local_key)
payload = cipher.decrypt(payload, False, decode_text=False) payload = cipher.decrypt(payload, False, decode_text=False)
except: except Exception:
self.debug("session key step 2 decrypt failed, payload=%r (len:%d)", payload, len(payload)) self.debug(
"session key step 2 decrypt failed, payload=%r (len:%d)",
payload,
len(payload),
)
return False return False
self.debug("decrypted session key negotiation step 2: payload=%r", payload) self.debug("decrypted session key negotiation step 2: payload=%r", payload)
@@ -920,23 +984,31 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
hmac_check = hmac.new(self.local_key, self.local_nonce, sha256).digest() hmac_check = hmac.new(self.local_key, self.local_nonce, sha256).digest()
if hmac_check != payload[16:48]: if hmac_check != payload[16:48]:
self.debug("session key negotiation step 2 failed HMAC check! wanted=%r but got=%r", binascii.hexlify(hmac_check), binascii.hexlify(payload[16:48])) self.debug(
"session key negotiation step 2 failed HMAC check! wanted=%r but got=%r",
binascii.hexlify(hmac_check),
binascii.hexlify(payload[16:48]),
)
# self.debug("session local nonce: %r remote nonce: %r", self.local_nonce, self.remote_nonce) # self.debug("session local nonce: %r remote nonce: %r", self.local_nonce, self.remote_nonce)
rkey_hmac = hmac.new(self.local_key, self.remote_nonce, sha256).digest() rkey_hmac = hmac.new(self.local_key, self.remote_nonce, sha256).digest()
await self.exchange_quick( MessagePayload(SESS_KEY_NEG_FINISH, rkey_hmac), None ) await self.exchange_quick(MessagePayload(SESS_KEY_NEG_FINISH, rkey_hmac), None)
self.local_key = bytes( [ a^b for (a,b) in zip(self.local_nonce,self.remote_nonce) ] ) self.local_key = bytes(
[a ^ b for (a, b) in zip(self.local_nonce, self.remote_nonce)]
)
# self.debug("Session nonce XOR'd: %r" % self.local_key) # self.debug("Session nonce XOR'd: %r" % self.local_key)
cipher = AESCipher(self.real_local_key) cipher = AESCipher(self.real_local_key)
self.local_key = self.dispatcher.local_key = cipher.encrypt(self.local_key, False, pad=False) self.local_key = self.dispatcher.local_key = cipher.encrypt(
self.local_key, False, pad=False
)
self.debug("Session key negotiate success! session key: %r", self.local_key) self.debug("Session key negotiate success! session key: %r", self.local_key)
return True return True
# adds protocol header (if needed) and encrypts # adds protocol header (if needed) and encrypts
def _encode_message( self, msg ): def _encode_message(self, msg):
hmac_key = None hmac_key = None
payload = msg.payload payload = msg.payload
self.cipher = AESCipher(self.local_key) self.cipher = AESCipher(self.local_key)
@@ -945,7 +1017,7 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
if msg.cmd not in NO_PROTOCOL_HEADER_CMDS: if msg.cmd not in NO_PROTOCOL_HEADER_CMDS:
# add the 3.x header # add the 3.x header
payload = self.version_header + payload payload = self.version_header + payload
self.debug('final payload for cmd %r: %r', msg.cmd, payload) self.debug("final payload for cmd %r: %r", msg.cmd, payload)
payload = self.cipher.encrypt(payload, False) payload = self.cipher.encrypt(payload, False)
elif self.version >= 3.2: elif self.version >= 3.2:
# expect to connect and then disconnect to set new # expect to connect and then disconnect to set new
@@ -977,7 +1049,7 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
self.cipher = None self.cipher = None
msg = TuyaMessage(self.seqno, msg.cmd, 0, payload, 0, True) msg = TuyaMessage(self.seqno, msg.cmd, 0, payload, 0, True)
self.seqno += 1 # increase message sequence number self.seqno += 1 # increase message sequence number
buffer = pack_message(msg,hmac_key=hmac_key) buffer = pack_message(msg, hmac_key=hmac_key)
# self.debug("payload encrypted with key %r => %r", self.local_key, binascii.hexlify(buffer)) # self.debug("payload encrypted with key %r => %r", self.local_key, binascii.hexlify(buffer))
return buffer return buffer
@@ -997,16 +1069,26 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
json_data = command_override = None json_data = command_override = None
if command in payload_dict[self.dev_type]: if command in payload_dict[self.dev_type]:
if 'command' in payload_dict[self.dev_type][command]: if "command" in payload_dict[self.dev_type][command]:
json_data = payload_dict[self.dev_type][command]['command'] json_data = payload_dict[self.dev_type][command]["command"]
if 'command_override' in payload_dict[self.dev_type][command]: if "command_override" in payload_dict[self.dev_type][command]:
command_override = payload_dict[self.dev_type][command]['command_override'] command_override = payload_dict[self.dev_type][command][
"command_override"
]
if self.dev_type != 'type_0a': if self.dev_type != "type_0a":
if json_data is None and command in payload_dict['type_0a'] and 'command' in payload_dict['type_0a'][command]: if (
json_data = payload_dict['type_0a'][command]['command'] json_data is None
if command_override is None and command in payload_dict['type_0a'] and 'command_override' in payload_dict['type_0a'][command]: and command in payload_dict["type_0a"]
command_override = payload_dict['type_0a'][command]['command_override'] and "command" in payload_dict["type_0a"][command]
):
json_data = payload_dict["type_0a"][command]["command"]
if (
command_override is None
and command in payload_dict["type_0a"]
and "command_override" in payload_dict["type_0a"][command]
):
command_override = payload_dict["type_0a"][command]["command_override"]
if command_override is None: if command_override is None:
command_override = command command_override = command
@@ -1014,7 +1096,6 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
# I have yet to see a device complain about included but unneeded attribs, but they *will* # I have yet to see a device complain about included but unneeded attribs, but they *will*
# complain about missing attribs, so just include them all unless otherwise specified # complain about missing attribs, so just include them all unless otherwise specified
json_data = {"gwId": "", "devId": "", "uid": "", "t": ""} json_data = {"gwId": "", "devId": "", "uid": "", "t": ""}
cmd_data = ""
if "gwId" in json_data: if "gwId" in json_data:
if gwId is not None: if gwId is not None:
@@ -1032,7 +1113,7 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
else: else:
json_data["uid"] = self.id json_data["uid"] = self.id
if "t" in json_data: if "t" in json_data:
if json_data['t'] == "int": if json_data["t"] == "int":
json_data["t"] = int(time.time()) json_data["t"] = int(time.time())
else: else:
json_data["t"] = str(int(time.time())) json_data["t"] = str(int(time.time()))
@@ -1057,7 +1138,6 @@ class TuyaProtocol(asyncio.Protocol, ContextualLogger):
return MessagePayload(command_override, payload) return MessagePayload(command_override, payload)
def __repr__(self): def __repr__(self):
"""Return internal string representation of object.""" """Return internal string representation of object."""
return self.id return self.id

View File

@@ -171,10 +171,12 @@ disable=line-too-long,
deprecated-sys-function, deprecated-sys-function,
exception-escape, exception-escape,
comprehension-escape, comprehension-escape,
unused-variable, unused-variable,
invalid-name, invalid-name,
dangerous-default-value, dangerous-default-value,
unreachable unreachable,
unnecessary-pass,
broad-except
# Enable the message, report, category or checker with the given id(s). You can # Enable the message, report, category or checker with the given id(s). You can
# either give multiple identifier separated by comma (,) or put this option # either give multiple identifier separated by comma (,) or put this option

View File

@@ -1,10 +1,10 @@
[flake8] [flake8]
exclude = .git,.tox exclude = .git,.tox
max-line-length = 88 max-line-length = 120
ignore = E203, W503 ignore = E203, W503
[mypy] [mypy]
python_version = 3.7 python_version = 3.9
ignore_errors = true ignore_errors = true
follow_imports = silent follow_imports = silent
ignore_missing_imports = true ignore_missing_imports = true