Support arbitrary flow schemas for platforms

This commit is contained in:
Pierre Ståhl
2020-09-09 12:52:46 +02:00
parent 7018625f04
commit 7efb3024fb
3 changed files with 38 additions and 27 deletions

View File

@@ -44,22 +44,30 @@ PICK_ENTITY_SCHEMA = vol.Schema(
) )
def platform_schema(dps, additional_fields): def dps_string_list(dps_data):
"""Return list of friendly DPS values."""
return [f"{id} (value: {value})" for id, value in dps_data.items()]
def platform_schema(dps_strings, schema):
"""Generate input validation schema for a platform.""" """Generate input validation schema for a platform."""
dps_list = vol.In([f"{id} (value: {value})" for id, value in dps.items()])
return vol.Schema( return vol.Schema(
{ {
vol.Required(CONF_ID): dps_list, vol.Required(CONF_ID): vol.In(dps_strings),
vol.Required(CONF_FRIENDLY_NAME): str, vol.Required(CONF_FRIENDLY_NAME): str,
} }
).extend({conf: dps_list for conf in additional_fields}) ).extend(schema)
def strip_dps_values(user_input, fields): def strip_dps_values(user_input, dps_strings):
"""Remove values and keep only index for DPS config items.""" """Remove values and keep only index for DPS config items."""
for field in [CONF_ID] + fields: stripped = {}
user_input[field] = user_input[field].split(" ")[0] for field, value in user_input.items():
return user_input if value in dps_strings:
stripped[field] = user_input[field].split(" ")[0]
else:
stripped[field] = user_input[field]
return stripped
async def validate_input(hass: core.HomeAssistant, data): async def validate_input(hass: core.HomeAssistant, data):
@@ -75,7 +83,7 @@ async def validate_input(hass: core.HomeAssistant, data):
raise CannotConnect raise CannotConnect
except ValueError: except ValueError:
raise InvalidAuth raise InvalidAuth
return data["dps"] return dps_string_list(data["dps"])
class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN): class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
@@ -87,9 +95,9 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
def __init__(self): def __init__(self):
"""Initialize a new LocaltuyaConfigFlow.""" """Initialize a new LocaltuyaConfigFlow."""
self.basic_info = None self.basic_info = None
self.dps_data = None self.dps_strings = []
self.platform = None self.platform = None
self.platform_dps_fields = None self.platform_schema = None
self.entities = [] self.entities = []
async def async_step_user(self, user_input=None): async def async_step_user(self, user_input=None):
@@ -101,7 +109,7 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
try: try:
self.basic_info = user_input self.basic_info = user_input
self.dps_data = await validate_input(self.hass, user_input) self.dps_strings = await validate_input(self.hass, user_input)
return await self.async_step_pick_entity_type() return await self.async_step_pick_entity_type()
except CannotConnect: except CannotConnect:
errors["base"] = "cannot_connect" errors["base"] = "cannot_connect"
@@ -144,16 +152,14 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
) )
if not already_configured: if not already_configured:
user_input[CONF_PLATFORM] = self.platform user_input[CONF_PLATFORM] = self.platform
self.entities.append( self.entities.append(strip_dps_values(user_input, self.dps_strings))
strip_dps_values(user_input, self.platform_dps_fields)
)
return await self.async_step_pick_entity_type() return await self.async_step_pick_entity_type()
errors["base"] = "entity_already_configured" errors["base"] = "entity_already_configured"
return self.async_show_form( return self.async_show_form(
step_id="add_entity", step_id="add_entity",
data_schema=platform_schema(self.dps_data, self.platform_dps_fields), data_schema=platform_schema(self.dps_strings, self.platform_schema),
errors=errors, errors=errors,
description_placeholders={"platform": self.platform}, description_placeholders={"platform": self.platform},
) )
@@ -167,12 +173,12 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
CONF_FRIENDLY_NAME: conf[CONF_FRIENDLY_NAME], CONF_FRIENDLY_NAME: conf[CONF_FRIENDLY_NAME],
CONF_PLATFORM: self.platform, CONF_PLATFORM: self.platform,
} }
for field in self.platform_dps_fields: for field in self.platform_schema.keys():
converted[str(field)] = conf[field] converted[str(field)] = conf[field]
return converted return converted
await self.async_set_unique_id(user_input[CONF_DEVICE_ID]) await self.async_set_unique_id(user_input[CONF_DEVICE_ID])
self._set_platform(user_input[CONF_PLATFORM]) self._set_platform(user_input[CONF_PLATFORM], [])
if len(user_input.get(CONF_SWITCHES, [])) > 0: if len(user_input.get(CONF_SWITCHES, [])) > 0:
for switch_conf in user_input[CONF_SWITCHES].values(): for switch_conf in user_input[CONF_SWITCHES].values():
@@ -194,9 +200,9 @@ class LocaltuyaConfigFlow(config_entries.ConfigFlow, domain=DOMAIN):
def _set_platform(self, platform): def _set_platform(self, platform):
integration_module = ".".join(__name__.split(".")[:-1]) integration_module = ".".join(__name__.split(".")[:-1])
self.platform = platform self.platform = platform
self.platform_dps_fields = import_module( self.platform_schema = import_module(
"." + platform, integration_module "." + platform, integration_module
).DPS_FIELDS ).flow_schema(self.dps_strings)
class CannotConnect(exceptions.HomeAssistantError): class CannotConnect(exceptions.HomeAssistantError):

View File

@@ -43,7 +43,10 @@ UPDATE_RETRY_LIMIT = 3
PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(BASE_PLATFORM_SCHEMA) PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(BASE_PLATFORM_SCHEMA)
DPS_FIELDS = []
def flow_schema(dps):
"""Return schema used in config flow."""
return {}
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(hass, config_entry, async_add_entities):

View File

@@ -53,7 +53,7 @@ _LOGGER = logging.getLogger(__name__)
DEFAULT_ID = "1" DEFAULT_ID = "1"
# TODO: This will eventully merge with DPS_FIELDS # TODO: This will eventully merge with flow_schema
SWITCH_SCHEMA = vol.Schema( SWITCH_SCHEMA = vol.Schema(
{ {
vol.Optional(CONF_ID, default=DEFAULT_ID): cv.string, vol.Optional(CONF_ID, default=DEFAULT_ID): cv.string,
@@ -75,11 +75,13 @@ PLATFORM_SCHEMA = PLATFORM_SCHEMA.extend(BASE_PLATFORM_SCHEMA).extend(
) )
DPS_FIELDS = [ def flow_schema(dps):
vol.Optional(CONF_CURRENT), """Return schema used in config flow."""
vol.Optional(CONF_CURRENT_CONSUMPTION), return {
vol.Optional(CONF_VOLTAGE), vol.Optional(CONF_CURRENT): vol.In(dps),
] vol.Optional(CONF_CURRENT_CONSUMPTION): vol.In(dps),
vol.Optional(CONF_VOLTAGE): vol.In(dps),
}
async def async_setup_entry(hass, config_entry, async_add_entities): async def async_setup_entry(hass, config_entry, async_add_entities):