diff --git a/custom_components/proxmox_pve/services.py b/custom_components/proxmox_pve/services.py index 355d3e7..992ced5 100644 --- a/custom_components/proxmox_pve/services.py +++ b/custom_components/proxmox_pve/services.py @@ -5,6 +5,7 @@ import voluptuous as vol from homeassistant.core import HomeAssistant, ServiceCall from homeassistant.helpers import device_registry as dr +from homeassistant.helpers import entity_registry as er from .const import DOMAIN @@ -16,6 +17,7 @@ SERVICE_STOP_HARD = "stop_hard" SERVICE_REBOOT = "reboot" ATTR_DEVICE_ID = "device_id" +ATTR_ENTITY_ID = "entity_id" ATTR_CONFIG_ENTRY_ID = "config_entry_id" ATTR_HOST = "host" ATTR_NODE = "node" @@ -24,10 +26,11 @@ ATTR_TYPE = "type" VALID_TYPES = ("qemu", "lxc") -# HA may pass device_id as str or list[str] (especially when using UI targets). +# Accept device_id passed as str or list[str] (depends on HA UI/script) SERVICE_SCHEMA = vol.Schema( { vol.Optional(ATTR_DEVICE_ID): vol.Any(str, [str]), + vol.Optional(ATTR_ENTITY_ID): vol.Any(str, [str]), vol.Optional(ATTR_CONFIG_ENTRY_ID): str, vol.Optional(ATTR_HOST): str, vol.Optional(ATTR_NODE): str, @@ -38,7 +41,6 @@ SERVICE_SCHEMA = vol.Schema( def _first_str(value: Any) -> str | None: - """Return first string from str or list[str], else None.""" if isinstance(value, str) and value.strip(): return value.strip() if isinstance(value, list) and value: @@ -48,19 +50,54 @@ def _first_str(value: Any) -> str | None: return None -def _get_device_id(call: ServiceCall) -> str | None: +def _first_str_from_target(target: Any, key: str) -> str | None: + if isinstance(target, dict): + return _first_str(target.get(key)) + return None + + +def _get_device_id(hass: HomeAssistant, call: ServiceCall) -> str | None: """ - IMPORTANT: On some HA versions ServiceCall has no .target attribute. - UI targets are still provided, usually as data.device_id (often a list). + Robust device_id extraction across HA versions. + + Priority: + 1) call.target.device_id (newer HA) + 2) call.data.device_id (some wrappers) + 3) call.target.entity_id -> map to device_id + 4) call.data.entity_id -> map to device_id """ - return _first_str(call.data.get(ATTR_DEVICE_ID)) + target = getattr(call, "target", None) + + # 1) target.device_id + dev_id = _first_str_from_target(target, "device_id") + if dev_id: + return dev_id + + # 2) data.device_id + dev_id = _first_str(call.data.get(ATTR_DEVICE_ID)) + if dev_id: + return dev_id + + # 3) target.entity_id -> device_id + ent_id = _first_str_from_target(target, "entity_id") + if ent_id: + ent_reg = er.async_get(hass) + ent = ent_reg.async_get(ent_id) + if ent and ent.device_id: + return ent.device_id + + # 4) data.entity_id -> device_id + ent_id = _first_str(call.data.get(ATTR_ENTITY_ID)) + if ent_id: + ent_reg = er.async_get(hass) + ent = ent_reg.async_get(ent_id) + if ent and ent.device_id: + return ent.device_id + + return None def _parse_guest_identifier(identifier: str) -> Tuple[str, str, int]: - """ - Guest device identifier format: "node:type:vmid" - Example: "pve1:qemu:100" - """ parts = identifier.split(":") if len(parts) != 3: raise ValueError(f"Invalid guest identifier: {identifier}") @@ -72,8 +109,7 @@ def _parse_guest_identifier(identifier: str) -> Tuple[str, str, int]: def _resolve_target(hass: HomeAssistant, call: ServiceCall) -> Tuple[str, str, int]: - """Resolve node/type/vmid from device_id OR node+vmid (+ optional type).""" - device_id = _get_device_id(call) + device_id = _get_device_id(hass, call) node = call.data.get(ATTR_NODE) vmid = call.data.get(ATTR_VMID) vmtype = call.data.get(ATTR_TYPE, "qemu") @@ -84,20 +120,17 @@ def _resolve_target(hass: HomeAssistant, call: ServiceCall) -> Tuple[str, str, i if not device: raise ValueError(f"Device not found: {device_id}") - # Find our guest identifier in device.identifiers for ident_domain, ident_value in device.identifiers: if ident_domain != DOMAIN: continue - # Node devices are "node:" — ignore those if ident_value.startswith("node:"): continue return _parse_guest_identifier(ident_value) raise ValueError(f"Selected device has no Easy Proxmox guest identifier: {device_id}") - # manual mode if not node or vmid is None: - raise ValueError("Provide device_id OR node + vmid (+ optional type/host/config_entry_id).") + raise ValueError("Provide a Device/Entity target OR node + vmid (+ optional type/host/config_entry_id).") if vmtype not in VALID_TYPES: raise ValueError(f"Invalid type: {vmtype} (allowed: {VALID_TYPES})") @@ -113,7 +146,6 @@ def _get_domain_entries(hass: HomeAssistant) -> dict[str, Any]: def _pick_entry_id_for_device(hass: HomeAssistant, device_id: str) -> str: - """Pick correct config_entry_id by using device.config_entries.""" dev_reg = dr.async_get(hass) device = dev_reg.async_get(device_id) if not device: @@ -146,7 +178,6 @@ def _pick_entry_id_by_host(hass: HomeAssistant, host: str) -> str: def _pick_entry_id_by_guest_lookup(hass: HomeAssistant, node: str, vmtype: str, vmid: int) -> str: - """Find correct entry by scanning each entry's resources list.""" domain_entries = _get_domain_entries(hass) matches = [] @@ -169,18 +200,17 @@ def _pick_entry_id_by_guest_lookup(hass: HomeAssistant, node: str, vmtype: str, if not matches: raise ValueError( f"Could not find guest {node}/{vmtype}/{vmid} in any configured Proxmox host. " - "Provide host or config_entry_id, or use device_id." + "Provide host or config_entry_id, or use a Device/Entity target." ) if len(matches) > 1: raise ValueError( f"Guest {node}/{vmtype}/{vmid} exists on multiple configured hosts (ambiguous). " - "Please provide host or config_entry_id, or use device_id." + "Please provide host or config_entry_id, or use a Device/Entity target." ) return matches[0] -def _resolve_entry_id(hass: HomeAssistant, call: ServiceCall, target: Tuple[str, str, int]) -> str: - """Resolve which config entry should execute this service call.""" +def _resolve_entry_id(hass: HomeAssistant, call: ServiceCall, node: str, vmtype: str, vmid: int) -> str: domain_entries = _get_domain_entries(hass) config_entry_id = call.data.get(ATTR_CONFIG_ENTRY_ID) @@ -189,7 +219,7 @@ def _resolve_entry_id(hass: HomeAssistant, call: ServiceCall, target: Tuple[str, raise ValueError(f"config_entry_id '{config_entry_id}' not found or not loaded.") return config_entry_id - device_id = _get_device_id(call) + device_id = _get_device_id(hass, call) if device_id: return _pick_entry_id_for_device(hass, device_id) @@ -197,34 +227,41 @@ def _resolve_entry_id(hass: HomeAssistant, call: ServiceCall, target: Tuple[str, if host: return _pick_entry_id_by_host(hass, host) - node, vmtype, vmid = target return _pick_entry_id_by_guest_lookup(hass, node, vmtype, vmid) async def async_register_services(hass: HomeAssistant) -> None: - """Register services once per HA instance.""" if hass.services.has_service(DOMAIN, SERVICE_START): return async def _call_action(call: ServiceCall, action: str) -> None: node, vmtype, vmid = _resolve_target(hass, call) - entry_id = _resolve_entry_id(hass, call, (node, vmtype, vmid)) + entry_id = _resolve_entry_id(hass, call, node, vmtype, vmid) - domain_entries = _get_domain_entries(hass) - entry_data = domain_entries.get(entry_id) + entry_data = _get_domain_entries(hass).get(entry_id) if not isinstance(entry_data, dict) or not entry_data.get("client"): raise ValueError(f"Selected config entry '{entry_id}' has no client (not loaded).") client = entry_data["client"] - _LOGGER.debug("Service action=%s entry=%s target=%s/%s/%s", action, entry_id, node, vmtype, vmid) + _LOGGER.debug("Service action=%s entry=%s target=%s/%s/%s data=%s", action, entry_id, node, vmtype, vmid, call.data) await client.guest_action(node=node, vmid=vmid, vmtype=vmtype, action=action) - hass.services.async_register(DOMAIN, SERVICE_START, lambda call: _call_action(call, "start"), schema=SERVICE_SCHEMA) - hass.services.async_register( - DOMAIN, SERVICE_SHUTDOWN, lambda call: _call_action(call, "shutdown"), schema=SERVICE_SCHEMA - ) - hass.services.async_register(DOMAIN, SERVICE_STOP_HARD, lambda call: _call_action(call, "stop"), schema=SERVICE_SCHEMA) - hass.services.async_register(DOMAIN, SERVICE_REBOOT, lambda call: _call_action(call, "reboot"), schema=SERVICE_SCHEMA) + async def handle_start(call: ServiceCall) -> None: + await _call_action(call, "start") + + async def handle_shutdown(call: ServiceCall) -> None: + await _call_action(call, "shutdown") + + async def handle_stop_hard(call: ServiceCall) -> None: + await _call_action(call, "stop") + + async def handle_reboot(call: ServiceCall) -> None: + await _call_action(call, "reboot") + + hass.services.async_register(DOMAIN, SERVICE_START, handle_start, schema=SERVICE_SCHEMA) + hass.services.async_register(DOMAIN, SERVICE_SHUTDOWN, handle_shutdown, schema=SERVICE_SCHEMA) + hass.services.async_register(DOMAIN, SERVICE_STOP_HARD, handle_stop_hard, schema=SERVICE_SCHEMA) + hass.services.async_register(DOMAIN, SERVICE_REBOOT, handle_reboot, schema=SERVICE_SCHEMA) async def async_unregister_services(hass: HomeAssistant) -> None: