Store product key in config entry (#170)

This commit is contained in:
Pierre Ståhl
2020-11-19 10:13:00 +01:00
committed by GitHub
parent 5eeb4c4af3
commit 27af622405
5 changed files with 42 additions and 18 deletions

View File

@@ -71,7 +71,7 @@ from homeassistant.helpers.reload import async_integration_yaml_config
from .common import TuyaDevice from .common import TuyaDevice
from .config_flow import config_schema from .config_flow import config_schema
from .const import DATA_DISCOVERY, DOMAIN, TUYA_DEVICE from .const import CONF_PRODUCT_KEY, DATA_DISCOVERY, DOMAIN, TUYA_DEVICE
from .discovery import TuyaDiscovery from .discovery import TuyaDiscovery
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -129,6 +129,7 @@ async def async_setup(hass: HomeAssistant, config: dict):
"""Update address of device if it has changed.""" """Update address of device if it has changed."""
device_ip = device["ip"] device_ip = device["ip"]
device_id = device["gwId"] device_id = device["gwId"]
product_key = device["productKey"]
# If device is not in cache, check if a config entry exists # If device is not in cache, check if a config entry exists
if device_id not in device_cache: if device_id not in device_cache:
@@ -138,16 +139,27 @@ async def async_setup(hass: HomeAssistant, config: dict):
# potential update below # potential update below
device_cache[device_id] = entry.data[CONF_HOST] device_cache[device_id] = entry.data[CONF_HOST]
# If device is in cache and address changed... if device_id not in device_cache:
if device_id in device_cache and device_cache[device_id] != device_ip: return
_LOGGER.debug("Device %s changed IP to %s", device_id, device_ip)
entry = _entry_by_device_id(device_id) entry = _entry_by_device_id(device_id)
if entry: if entry is None:
hass.config_entries.async_update_entry( return
entry, data={**entry.data, CONF_HOST: device_ip}
) updates = {}
device_cache[device_id] = device_ip
if device_cache[device_id] != device_ip:
updates[CONF_HOST] = device_ip
device_cache[device_id] = device_ip
if entry.data.get(CONF_PRODUCT_KEY) != product_key:
updates[CONF_PRODUCT_KEY] = product_key
if updates:
_LOGGER.debug("Update keys for device %s: %s", updates)
hass.config_entries.async_update_entry(
entry, data={**entry.data, **updates}
)
discovery = TuyaDiscovery(_device_discovered) discovery = TuyaDiscovery(_device_discovered)

View File

@@ -19,7 +19,13 @@ from homeassistant.helpers.dispatcher import (
from homeassistant.helpers.entity import Entity from homeassistant.helpers.entity import Entity
from . import pytuya from . import pytuya
from .const import CONF_LOCAL_KEY, CONF_PROTOCOL_VERSION, DOMAIN, TUYA_DEVICE from .const import (
CONF_LOCAL_KEY,
CONF_PRODUCT_KEY,
CONF_PROTOCOL_VERSION,
DOMAIN,
TUYA_DEVICE,
)
_LOGGER = logging.getLogger(__name__) _LOGGER = logging.getLogger(__name__)
@@ -257,7 +263,7 @@ class LocalTuyaEntity(Entity, pytuya.ContextualLogger):
}, },
"name": self._config_entry.data[CONF_FRIENDLY_NAME], "name": self._config_entry.data[CONF_FRIENDLY_NAME],
"manufacturer": "Unknown", "manufacturer": "Unknown",
"model": "Tuya generic", "model": self._config_entry.data.get(CONF_PRODUCT_KEY, "Tuya generic"),
"sw_version": self._config_entry.data[CONF_PROTOCOL_VERSION], "sw_version": self._config_entry.data[CONF_PROTOCOL_VERSION],
} }

View File

@@ -20,6 +20,7 @@ from . import pytuya
from .const import CONF_DPS_STRINGS # pylint: disable=unused-import from .const import CONF_DPS_STRINGS # pylint: disable=unused-import
from .const import ( from .const import (
CONF_LOCAL_KEY, CONF_LOCAL_KEY,
CONF_PRODUCT_KEY,
CONF_PROTOCOL_VERSION, CONF_PROTOCOL_VERSION,
DATA_DISCOVERY, DATA_DISCOVERY,
DOMAIN, DOMAIN,
@@ -212,7 +213,8 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
errors = {} errors = {}
if user_input is not None: if user_input is not None:
if user_input[DISCOVERED_DEVICE] != CUSTOM_DEVICE: if user_input[DISCOVERED_DEVICE] != CUSTOM_DEVICE:
self.selected_device = user_input[DISCOVERED_DEVICE].split(" ")[0] device = user_input[DISCOVERED_DEVICE].split(" ")[0]
self.selected_device = self.devices[device]
return await self.async_step_basic_info() return await self.async_step_basic_info()
# Use cache if available or fallback to manual discovery # Use cache if available or fallback to manual discovery
@@ -251,6 +253,10 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
try: try:
self.basic_info = user_input self.basic_info = user_input
if self.selected_device is not None:
self.basic_info[CONF_PRODUCT_KEY] = self.selected_device[
"productKey"
]
self.dps_strings = await validate_input(self.hass, user_input) self.dps_strings = await validate_input(self.hass, user_input)
return await self.async_step_pick_entity_type() return await self.async_step_pick_entity_type()
except CannotConnect: except CannotConnect:
@@ -266,10 +272,9 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
defaults = {} defaults = {}
defaults.update(user_input or {}) defaults.update(user_input or {})
if self.selected_device is not None: if self.selected_device is not None:
device = self.devices[self.selected_device] defaults[CONF_HOST] = self.selected_device.get("ip")
defaults[CONF_HOST] = device.get("ip") defaults[CONF_DEVICE_ID] = self.selected_device.get("gwId")
defaults[CONF_DEVICE_ID] = device.get("gwId") defaults[CONF_PROTOCOL_VERSION] = self.selected_device.get("version")
defaults[CONF_PROTOCOL_VERSION] = device.get("version")
return self.async_show_form( return self.async_show_form(
step_id="basic_info", step_id="basic_info",

View File

@@ -7,6 +7,7 @@ ATTR_VOLTAGE = "voltage"
CONF_LOCAL_KEY = "local_key" CONF_LOCAL_KEY = "local_key"
CONF_PROTOCOL_VERSION = "protocol_version" CONF_PROTOCOL_VERSION = "protocol_version"
CONF_DPS_STRINGS = "dps_strings" CONF_DPS_STRINGS = "dps_strings"
CONF_PRODUCT_KEY = "product_key"
# light # light
CONF_BRIGHTNESS_LOWER = "brightness_lower" CONF_BRIGHTNESS_LOWER = "brightness_lower"

View File

@@ -1,7 +1,7 @@
"""Platform to locally control Tuya-based cover devices.""" """Platform to locally control Tuya-based cover devices."""
import asyncio import asyncio
import time
import logging import logging
import time
from functools import partial from functools import partial
import voluptuous as vol import voluptuous as vol