Support arbitrary flow schemas for platforms

This commit is contained in:
Pierre Ståhl
2020-09-09 12:52:46 +02:00
parent 7018625f04
commit 7efb3024fb
3 changed files with 38 additions and 27 deletions

View File

@@ -44,22 +44,30 @@ PICK_ENTITY_SCHEMA = vol.Schema(
)
def platform_schema(dps, additional_fields):
def dps_string_list(dps_data):
"""Return list of friendly DPS values."""
return [f"{id} (value: {value})" for id, value in dps_data.items()]
def platform_schema(dps_strings, schema):
"""Generate input validation schema for a platform."""
dps_list = vol.In([f"{id} (value: {value})" for id, value in dps.items()])
return vol.Schema(
{
vol.Required(CONF_ID): dps_list,
vol.Required(CONF_ID): vol.In(dps_strings),
vol.Required(CONF_FRIENDLY_NAME): str,
}
).extend({conf: dps_list for conf in additional_fields})
).extend(schema)
def strip_dps_values(user_input, fields):
def strip_dps_values(user_input, dps_strings):
"""Remove values and keep only index for DPS config items."""
for field in [CONF_ID] + fields:
user_input[field] = user_input[field].split(" ")[0]
return user_input
stripped = {}
for field, value in user_input.items():
if value in dps_strings:
stripped[field] = user_input[field].split(" ")[0]
else:
stripped[field] = user_input[field]
return stripped
async def validate_input(hass: core.HomeAssistant, data):
@@ -75,7 +83,7 @@ async def validate_input(hass: core.HomeAssistant, data):
raise CannotConnect
except ValueError:
raise InvalidAuth
return data["dps"]
return dps_string_list(data["dps"])
class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
@@ -87,9 +95,9 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
def __init__(self):
"""Initialize a new LocaltuyaConfigFlow."""
self.basic_info = None
self.dps_data = None
self.dps_strings = []
self.platform = None
self.platform_dps_fields = None
self.platform_schema = None
self.entities = []
async def async_step_user(self, user_input=None):
@@ -101,7 +109,7 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
try:
self.basic_info = user_input
self.dps_data = await validate_input(self.hass, user_input)
self.dps_strings = await validate_input(self.hass, user_input)
return await self.async_step_pick_entity_type()
except CannotConnect:
errors["base"] = "cannot_connect"
@@ -144,16 +152,14 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
)
if not already_configured:
user_input[CONF_PLATFORM] = self.platform
self.entities.append(
strip_dps_values(user_input, self.platform_dps_fields)
)
self.entities.append(strip_dps_values(user_input, self.dps_strings))
return await self.async_step_pick_entity_type()
errors["base"] = "entity_already_configured"
return self.async_show_form(
step_id="add_entity",
data_schema=platform_schema(self.dps_data, self.platform_dps_fields),
data_schema=platform_schema(self.dps_strings, self.platform_schema),
errors=errors,
description_placeholders={"platform": self.platform},
)
@@ -167,12 +173,12 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
CONF_FRIENDLY_NAME: conf[CONF_FRIENDLY_NAME],
CONF_PLATFORM: self.platform,
}
for field in self.platform_dps_fields:
for field in self.platform_schema.keys():
converted[str(field)] = conf[field]
return converted
await self.async_set_unique_id(user_input[CONF_DEVICE_ID])
self._set_platform(user_input[CONF_PLATFORM])
self._set_platform(user_input[CONF_PLATFORM], [])
if len(user_input.get(CONF_SWITCHES, [])) > 0:
for switch_conf in user_input[CONF_SWITCHES].values():
@@ -194,9 +200,9 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
def _set_platform(self, platform):
integration_module = ".".join(__name__.split(".")[:-1])
self.platform = platform
self.platform_dps_fields = import_module(
self.platform_schema = import_module(
"." + platform, integration_module
).DPS_FIELDS
).flow_schema(self.dps_strings)
class CannotConnect(exceptions.HomeAssistantError):

View File

@@ -43,7 +43,10 @@ UPDATE_RETRY_LIMIT = 3
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(BASE_PLATFORM_SCHEMA)
DPS_FIELDS = []
def flow_schema(dps):
"""Return schema used in config flow."""
return {}
async def async_setup_entry(hass, config_entry, async_add_entities):

View File

@@ -53,7 +53,7 @@ _LOGGER = logging.getLogger(__name__)
DEFAULT_ID = "1"
# TODO: This will eventully merge with DPS_FIELDS
# TODO: This will eventully merge with flow_schema
SWITCH_SCHEMA = vol.Schema(
{
vol.Optional(CONF_ID, default=DEFAULT_ID): cv.string,
@@ -75,11 +75,13 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(BASE_PLATFORM_SCHEMA).extend(
)
DPS_FIELDS = [
vol.Optional(CONF_CURRENT),
vol.Optional(CONF_CURRENT_CONSUMPTION),
vol.Optional(CONF_VOLTAGE),
]
def flow_schema(dps):
"""Return schema used in config flow."""
return {
vol.Optional(CONF_CURRENT): vol.In(dps),
vol.Optional(CONF_CURRENT_CONSUMPTION): vol.In(dps),
vol.Optional(CONF_VOLTAGE): vol.In(dps),
}
async def async_setup_entry(hass, config_entry, async_add_entities):