From 7efb3024fbb0118fe4da410ef5778587ed2ba0fc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Pierre=20St=C3=A5hl?= Date: Wed, 9 Sep 2020 12:52:46 +0200 Subject: [PATCH] Support arbitrary flow schemas for platforms --- custom_components/localtuya/config_flow.py | 46 ++++++++++++---------- custom_components/localtuya/light.py | 5 ++- custom_components/localtuya/switch.py | 14 ++++--- 3 files changed, 38 insertions(+), 27 deletions(-) diff --git a/custom_components/localtuya/config_flow.py b/custom_components/localtuya/config_flow.py index d8ba181..173cffd 100644 --- a/custom_components/localtuya/config_flow.py +++ b/custom_components/localtuya/config_flow.py @@ -44,22 +44,30 @@ PICK_ENTITY_SCHEMA = vol.Schema( ) -def platform_schema(dps, additional_fields): +def dps_string_list(dps_data): + """Return list of friendly DPS values.""" + return [f"{id} (value: {value})" for id, value in dps_data.items()] + + +def platform_schema(dps_strings, schema): """Generate input validation schema for a platform.""" - dps_list = vol.In([f"{id} (value: {value})" for id, value in dps.items()]) return vol.Schema( { - vol.Required(CONF_ID): dps_list, + vol.Required(CONF_ID): vol.In(dps_strings), vol.Required(CONF_FRIENDLY_NAME): str, } - ).extend({conf: dps_list for conf in additional_fields}) + ).extend(schema) -def strip_dps_values(user_input, fields): +def strip_dps_values(user_input, dps_strings): """Remove values and keep only index for DPS config items.""" - for field in [CONF_ID] + fields: - user_input[field] = user_input[field].split(" ")[0] - return user_input + stripped = {} + for field, value in user_input.items(): + if value in dps_strings: + stripped[field] = user_input[field].split(" ")[0] + else: + stripped[field] = user_input[field] + return stripped async def validate_input(hass: core.HomeAssistant, data): @@ -75,7 +83,7 @@ async def validate_input(hass: core.HomeAssistant, data): raise CannotConnect except ValueError: raise InvalidAuth - return data["dps"] + return dps_string_list(data["dps"]) class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): @@ -87,9 +95,9 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): def __init__(self): """Initialize a new LocaltuyaConfigFlow.""" self.basic_info = None - self.dps_data = None + self.dps_strings = [] self.platform = None - self.platform_dps_fields = None + self.platform_schema = None self.entities = [] async def async_step_user(self, user_input=None): @@ -101,7 +109,7 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): try: self.basic_info = user_input - self.dps_data = await validate_input(self.hass, user_input) + self.dps_strings = await validate_input(self.hass, user_input) return await self.async_step_pick_entity_type() except CannotConnect: errors["base"] = "cannot_connect" @@ -144,16 +152,14 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): ) if not already_configured: user_input[CONF_PLATFORM] = self.platform - self.entities.append( - strip_dps_values(user_input, self.platform_dps_fields) - ) + self.entities.append(strip_dps_values(user_input, self.dps_strings)) return await self.async_step_pick_entity_type() errors["base"] = "entity_already_configured" return self.async_show_form( step_id="add_entity", - data_schema=platform_schema(self.dps_data, self.platform_dps_fields), + data_schema=platform_schema(self.dps_strings, self.platform_schema), errors=errors, description_placeholders={"platform": self.platform}, ) @@ -167,12 +173,12 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): CONF_FRIENDLY_NAME: conf[CONF_FRIENDLY_NAME], CONF_PLATFORM: self.platform, } - for field in self.platform_dps_fields: + for field in self.platform_schema.keys(): converted[str(field)] = conf[field] return converted await self.async_set_unique_id(user_input[CONF_DEVICE_ID]) - self._set_platform(user_input[CONF_PLATFORM]) + self._set_platform(user_input[CONF_PLATFORM], []) if len(user_input.get(CONF_SWITCHES, [])) > 0: for switch_conf in user_input[CONF_SWITCHES].values(): @@ -194,9 +200,9 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): def _set_platform(self, platform): integration_module = ".".join(__name__.split(".")[:-1]) self.platform = platform - self.platform_dps_fields = import_module( + self.platform_schema = import_module( "." + platform, integration_module - ).DPS_FIELDS + ).flow_schema(self.dps_strings) class CannotConnect(exceptions.HomeAssistantError): diff --git a/custom_components/localtuya/light.py b/custom_components/localtuya/light.py index d6ffe99..37c57b0 100644 --- a/custom_components/localtuya/light.py +++ b/custom_components/localtuya/light.py @@ -43,7 +43,10 @@ UPDATE_RETRY_LIMIT = 3 PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(BASE_PLATFORM_SCHEMA) -DPS_FIELDS = [] + +def flow_schema(dps): + """Return schema used in config flow.""" + return {} async def async_setup_entry(hass, config_entry, async_add_entities): diff --git a/custom_components/localtuya/switch.py b/custom_components/localtuya/switch.py index 9ae0473..8b12b6c 100644 --- a/custom_components/localtuya/switch.py +++ b/custom_components/localtuya/switch.py @@ -53,7 +53,7 @@ _LOGGER = logging.getLogger(__name__) DEFAULT_ID = "1" -# TODO: This will eventully merge with DPS_FIELDS +# TODO: This will eventully merge with flow_schema SWITCH_SCHEMA = vol.Schema( { vol.Optional(CONF_ID, default=DEFAULT_ID): cv.string, @@ -75,11 +75,13 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(BASE_PLATFORM_SCHEMA).extend( ) -DPS_FIELDS = [ - vol.Optional(CONF_CURRENT), - vol.Optional(CONF_CURRENT_CONSUMPTION), - vol.Optional(CONF_VOLTAGE), -] +def flow_schema(dps): + """Return schema used in config flow.""" + return { + vol.Optional(CONF_CURRENT): vol.In(dps), + vol.Optional(CONF_CURRENT_CONSUMPTION): vol.In(dps), + vol.Optional(CONF_VOLTAGE): vol.In(dps), + } async def async_setup_entry(hass, config_entry, async_add_entities):