Commit f407e53

Merge pull request #159 from acon96/release/v0.3.1

Release v0.3.1

acon96 authored Jun 8, 2024
2 parents cbc7ced + 50bcd2e commit f407e53
Showing 12 changed files with 224 additions and 97 deletions.
12 changes: 0 additions & 12 deletions .github/workflows/create-release.yml
@@ -20,12 +20,8 @@ jobs:
       matrix:
         include:
           # ARM variants
-          - home_assistant_version: "2023.12.4"
-            arch: "aarch64"
           - home_assistant_version: "2024.2.1"
             arch: "aarch64"
-          - home_assistant_version: "2023.12.4"
-            arch: "armhf"
           - home_assistant_version: "2024.2.1"
             arch: "armhf"

@@ -34,18 +30,10 @@ jobs:
             suffix: "-noavx"
             arch: "amd64"
             extra_defines: "-DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DLLAMA_F16C=OFF"
-          - home_assistant_version: "2023.12.4"
-            arch: "amd64"
-            suffix: "-noavx"
-            extra_defines: "-DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DLLAMA_F16C=OFF"
           - home_assistant_version: "2024.2.1"
             arch: "i386"
             suffix: "-noavx"
             extra_defines: "-DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DLLAMA_F16C=OFF"
-          - home_assistant_version: "2023.12.4"
-            arch: "i386"
-            suffix: "-noavx"
-            extra_defines: "-DLLAMA_AVX=OFF -DLLAMA_AVX2=OFF -DLLAMA_FMA=OFF -DLLAMA_F16C=OFF"
 
       # AVX2 and AVX512
       - home_assistant_version: "2024.2.1"

1 change: 1 addition & 0 deletions README.md
@@ -132,6 +132,7 @@ In order to facilitate running the project entirely on the system where Home Ass
 ## Version History
 | Version | Description |
 |---------|-------------|
+| v0.3.1 | Adds basic area support in prompting, a fix for broken requirements, a fix for an issue with formatted tools, and a fix for the custom API not registering properly on startup |
 | v0.3 | Adds support for Home Assistant LLM APIs, improved model prompting and tool formatting options, and automatic detection of GGUF quantization levels on HuggingFace |
 | v0.2.17 | Disable native llama.cpp wheel optimizations, add Command R prompt format |
 | v0.2.16 | Fix for missing huggingface_hub package preventing startup |

27 changes: 23 additions & 4 deletions custom_components/llama_conversation/__init__.py
@@ -2,6 +2,7 @@
 from __future__ import annotations
 
 import logging
+from typing import Final
 
 import homeassistant.components.conversation as ha_conversation
 from homeassistant.config_entries import ConfigEntry
@@ -31,7 +32,7 @@
     BACKEND_TYPE_GENERIC_OPENAI,
     BACKEND_TYPE_LLAMA_CPP_PYTHON_SERVER,
     BACKEND_TYPE_OLLAMA,
-    ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS,
+    ALLOWED_SERVICE_CALL_ARGUMENTS,
     DOMAIN,
     HOME_LLM_API_ID,
     SERVICE_TOOL_NAME,
@@ -54,6 +55,10 @@ async def update_listener(hass: HomeAssistant, entry: ConfigEntry):
 async def async_setup_entry(hass: HomeAssistant, entry: ConfigEntry) -> bool:
     """Set up Local LLM Conversation from a config entry."""
 
+    # make sure the API is registered
+    if not any([x.id == HOME_LLM_API_ID for x in llm.async_get_apis(hass)]):
+        llm.async_register_api(hass, HomeLLMAPI(hass))
+
     def create_agent(backend_type):
         agent_cls = None
 
@@ -107,8 +112,8 @@ async def async_migrate_entry(hass, config_entry: ConfigEntry):
 class HassServiceTool(llm.Tool):
     """Tool to get the current time."""
 
-    name = SERVICE_TOOL_NAME
-    description = "Executes a Home Assistant service"
+    name: Final[str] = SERVICE_TOOL_NAME
+    description: Final[str] = "Executes a Home Assistant service"
 
     # Optional. A voluptuous schema of the input parameters.
     parameters = vol.Schema({
@@ -125,15 +130,29 @@ class HassServiceTool(llm.Tool):
         vol.Optional('item'): str,
     })
 
+    ALLOWED_SERVICES: Final[list[str]] = [
+        "turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed", "open_cover", "close_cover", "stop_cover",
+        "lock", "unlock", "start", "stop", "return_to_base", "pause", "cancel", "add_item"
+    ]
+    ALLOWED_DOMAINS: Final[list[str]] = [
+        "light", "switch", "button", "fan", "cover", "lock", "media_player", "climate", "vacuum", "todo", "timer", "script",
+    ]
+
     async def async_call(
         self, hass: HomeAssistant, tool_input: llm.ToolInput, llm_context: llm.LLMContext
     ) -> JsonObjectType:
         """Call the tool."""
         domain, service = tuple(tool_input.tool_args["service"].split("."))
         target_device = tool_input.tool_args["target_device"]
 
+        if domain not in self.ALLOWED_DOMAINS or service not in self.ALLOWED_SERVICES:
+            return { "result": "unknown service" }
+
+        if domain == "script" and service not in ["reload", "turn_on", "turn_off", "toggle"]:
+            return { "result": "unknown service" }
+
         service_data = {ATTR_ENTITY_ID: target_device}
-        for attr in ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS:
+        for attr in ALLOWED_SERVICE_CALL_ARGUMENTS:
             if attr in tool_input.tool_args.keys():
                 service_data[attr] = tool_input.tool_args[attr]
         try:

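The key change in this file is the allowlist guard in `HassServiceTool.async_call`: a service call is rejected before it ever reaches `hass.services.async_call` unless both its domain and service are explicitly permitted, and scripts are limited further. A minimal standalone sketch of that check (the `validate_service_call` helper is illustrative, not part of the integration):

```python
# Standalone sketch of the allowlist guard added to HassServiceTool.async_call.
# The two allowlists mirror the diff above; validate_service_call itself is a
# hypothetical helper for illustration only.
ALLOWED_SERVICES = [
    "turn_on", "turn_off", "toggle", "press", "increase_speed", "decrease_speed",
    "open_cover", "close_cover", "stop_cover", "lock", "unlock", "start", "stop",
    "return_to_base", "pause", "cancel", "add_item",
]
ALLOWED_DOMAINS = [
    "light", "switch", "button", "fan", "cover", "lock", "media_player",
    "climate", "vacuum", "todo", "timer", "script",
]

def validate_service_call(service_call: str) -> bool:
    """Return True only for 'domain.service' strings the tool may execute."""
    domain, _, service = service_call.partition(".")
    if domain not in ALLOWED_DOMAINS or service not in ALLOWED_SERVICES:
        return False
    # scripts are restricted to an even narrower set of services
    if domain == "script" and service not in ["reload", "turn_on", "turn_off", "toggle"]:
        return False
    return True

assert validate_service_call("light.turn_on")
assert not validate_service_call("homeassistant.restart")  # domain not allowed
assert not validate_service_call("script.start")           # blocked for scripts
```

Note that the model's output is treated as untrusted input: anything outside the allowlists returns `{"result": "unknown service"}` instead of raising.
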
88 changes: 68 additions & 20 deletions custom_components/llama_conversation/agent.py
@@ -23,14 +23,14 @@
 from homeassistant.const import ATTR_ENTITY_ID, CONF_HOST, CONF_PORT, CONF_SSL, MATCH_ALL, CONF_LLM_HASS_API
 from homeassistant.core import HomeAssistant, callback
 from homeassistant.exceptions import ConfigEntryNotReady, ConfigEntryError, TemplateError, HomeAssistantError
-from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm
+from homeassistant.helpers import config_validation as cv, intent, template, entity_registry as er, llm, area_registry as ar
 from homeassistant.helpers.event import async_track_state_change, async_call_later
 from homeassistant.util import ulid, color
 
 import voluptuous_serialize
 
 from .utils import closest_color, flatten_vol_schema, custom_custom_serializer, install_llama_cpp_python, \
-    validate_llama_cpp_python_installation
+    validate_llama_cpp_python_installation, format_url
 from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
@@ -106,14 +106,14 @@
     TEXT_GEN_WEBUI_CHAT_MODE_CHAT,
     TEXT_GEN_WEBUI_CHAT_MODE_INSTRUCT,
     TEXT_GEN_WEBUI_CHAT_MODE_CHAT_INSTRUCT,
-    ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS,
     DOMAIN,
     HOME_LLM_API_ID,
     SERVICE_TOOL_NAME,
     PROMPT_TEMPLATE_DESCRIPTIONS,
     TOOL_FORMAT_FULL,
     TOOL_FORMAT_REDUCED,
     TOOL_FORMAT_MINIMAL,
+    ALLOWED_SERVICE_CALL_ARGUMENTS,
 )
 
 # make type checking work for llama-cpp-python without importing it directly at runtime
@@ -254,11 +254,12 @@ async def async_process(
             )
         except HomeAssistantError as err:
             _LOGGER.error("Error getting LLM API: %s", err)
+            intent_response = intent.IntentResponse(language=user_input.language)
             intent_response.async_set_error(
                 intent.IntentResponseErrorCode.UNKNOWN,
                 f"Error preparing LLM API: {err}",
             )
-            return conversation.ConversationResult(
+            return ConversationResult(
                 response=intent_response, conversation_id=user_input.conversation_id
             )
 
@@ -445,6 +446,7 @@ def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
         entity_states = {}
         domains = set()
         entity_registry = er.async_get(self.hass)
+        area_registry = ar.async_get(self.hass)
 
         for state in self.hass.states.async_all():
             if not async_should_expose(self.hass, CONVERSATION_DOMAIN, state.entity_id):
@@ -456,11 +458,15 @@ def _async_get_exposed_entities(self) -> tuple[dict[str, str], list[str]]:
             attributes["state"] = state.state
             if entity and entity.aliases:
                 attributes["aliases"] = entity.aliases
+
+            if entity and entity.area_id:
+                area = area_registry.async_get_area(entity.area_id)
+                attributes["area_id"] = area.id
+                attributes["area_name"] = area.name
+
             entity_states[state.entity_id] = attributes
             domains.add(state.domain)
 
-        # _LOGGER.debug(f"Exposed entities: {entity_states}")
-
         return entity_states, list(domains)
 
     def _format_prompt(
@@ -556,6 +562,9 @@ def _generate_icl_examples(self, num_examples, entity_names):
         entity_names = entity_names[:]
         entity_domains = set([x.split(".")[0] for x in entity_names])
 
+        area_registry = ar.async_get(self.hass)
+        all_areas = list(area_registry.async_list_areas())
+
         in_context_examples = [
             x for x in self.in_context_examples
             if x["type"] in entity_domains
@@ -575,7 +584,7 @@
         response = chosen_example["response"]
 
         random_device = [ x for x in entity_names if x.split(".")[0] == chosen_example["type"] ][0]
-        random_area = "bedroom" # todo, pick a random area
+        random_area = random.choice(all_areas).name
         random_brightness = round(random.random(), 2)
         random_color = random.choice(list(color.COLORS.keys()))
 
@@ -619,8 +628,8 @@ def _generate_system_prompt(self, prompt_template: str, llm_api: llm.APIInstance
         extra_attributes_to_expose = self.entry.options \
             .get(CONF_EXTRA_ATTRIBUTES_TO_EXPOSE, DEFAULT_EXTRA_ATTRIBUTES_TO_EXPOSE)
 
-        def expose_attributes(attributes):
-            result = attributes["state"]
+        def expose_attributes(attributes) -> list[str]:
+            result = []
             for attribute_name in extra_attributes_to_expose:
                 if attribute_name not in attributes:
                     continue
@@ -644,19 +653,38 @@ def expose_attributes(attributes):
                 elif attribute_name == "humidity":
                     value = f"{value}%"
 
-                result = result + ";" + str(value)
+                result.append(str(value))
             return result
 
-        device_states = []
+        devices = []
+        formatted_devices = ""
 
         # expose devices and their alias as well
        for name, attributes in entities_to_expose.items():
-            device_states.append(f"{name} '{attributes.get('friendly_name')}' = {expose_attributes(attributes)}")
+            state = attributes["state"]
+            exposed_attributes = expose_attributes(attributes)
+            str_attributes = ";".join([state] + exposed_attributes)
+
+            formatted_devices = formatted_devices + f"{name} '{attributes.get('friendly_name')}' = {str_attributes}\n"
+            devices.append({
+                "entity_id": name,
+                "name": attributes.get('friendly_name'),
+                "state": state,
+                "attributes": exposed_attributes,
+                "area_name": attributes.get("area_name"),
+                "area_id": attributes.get("area_id")
+            })
             if "aliases" in attributes:
                 for alias in attributes["aliases"]:
-                    device_states.append(f"{name} '{alias}' = {expose_attributes(attributes)}")
-
-        formatted_states = "\n".join(device_states) + "\n"
+                    formatted_devices = formatted_devices + f"{name} '{alias}' = {str_attributes}\n"
+                    devices.append({
+                        "entity_id": name,
+                        "name": alias,
+                        "state": state,
+                        "attributes": exposed_attributes,
+                        "area_name": attributes.get("area_name"),
+                        "area_id": attributes.get("area_id")
+                    })
 
         if llm_api:
             if llm_api.api.id == HOME_LLM_API_ID:
@@ -670,7 +698,7 @@ def expose_attributes(attributes):
 
             for name, service in service_dict.get(domain, {}).items():
                 args = flatten_vol_schema(service.schema)
-                args_to_expose = set(args).intersection(ALLOWED_LEGACY_SERVICE_CALL_ARGUMENTS)
+                args_to_expose = set(args).intersection(ALLOWED_SERVICE_CALL_ARGUMENTS)
                 service_schema = vol.Schema({
                     vol.Optional(arg): str for arg in args_to_expose
                 })
@@ -681,17 +709,26 @@ def expose_attributes(attributes):
                     self._format_tool(*tool)
                     for tool in all_services
                 ]
+
             else:
                 tools = [
                     self._format_tool(tool.name, tool.parameters, tool.description)
                     for tool in llm_api.tools
                 ]
+
+            if self.entry.options.get(CONF_TOOL_FORMAT, DEFAULT_TOOL_FORMAT) == TOOL_FORMAT_MINIMAL:
+                formatted_tools = ", ".join(tools)
+            else:
+                formatted_tools = json.dumps(tools)
         else:
-            tools = "No tools were provided. If the user requests you interact with a device, tell them you are unable to do so."
+            tools = ["No tools were provided. If the user requests you interact with a device, tell them you are unable to do so."]
+            formatted_tools = tools[0]
 
         render_variables = {
-            "devices": formatted_states,
+            "devices": devices,
+            "formatted_devices": formatted_devices,
             "tools": tools,
+            "formatted_tools": formatted_tools,
             "response_examples": []
         }
 
@@ -1042,7 +1079,13 @@ class GenericOpenAIAPIAgent(LocalLLMAgent):
     model_name: str
 
     def _load_model(self, entry: ConfigEntry) -> None:
-        self.api_host = f"{'https' if entry.data[CONF_SSL] else 'http'}://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
+        self.api_host = format_url(
+            hostname=entry.data[CONF_HOST],
+            port=entry.data[CONF_PORT],
+            ssl=entry.data[CONF_SSL],
+            path=""
+        )
+
         self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
         self.model_name = entry.data.get(CONF_CHAT_MODEL)
 
@@ -1249,7 +1292,12 @@ class OllamaAPIAgent(LocalLLMAgent):
     model_name: str
 
     def _load_model(self, entry: ConfigEntry) -> None:
-        self.api_host = f"{'https' if entry.data[CONF_SSL] else 'http'}://{entry.data[CONF_HOST]}:{entry.data[CONF_PORT]}"
+        self.api_host = format_url(
+            hostname=entry.data[CONF_HOST],
+            port=entry.data[CONF_PORT],
+            ssl=entry.data[CONF_SSL],
+            path=""
+        )
         self.api_key = entry.data.get(CONF_OPENAI_API_KEY)
         self.model_name = entry.data.get(CONF_CHAT_MODEL)
 
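Both OpenAI-compatible agents now build their base URL through a shared `format_url` helper from `utils.py` instead of duplicating an f-string. The helper's implementation is not part of this diff; judging from the call sites (keyword arguments `hostname`, `port`, `ssl`, and `path`), it plausibly looks something like this sketch:

```python
# Hypothetical reconstruction of utils.format_url, inferred from the call
# sites in this diff -- the real implementation may differ.
def format_url(*, hostname: str, port: str, ssl: bool, path: str) -> str:
    """Build a URL such as 'https://host:port/path' from its parts."""
    scheme = "https" if ssl else "http"
    return f"{scheme}://{hostname}:{port}{path}"

# Example: the Ollama agent's base URL, where path is empty
assert format_url(hostname="localhost", port="11434", ssl=False, path="") \
    == "http://localhost:11434"
```

Centralizing this also gives one place to handle edge cases such as trailing slashes or default ports for every backend.
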
18 changes: 15 additions & 3 deletions custom_components/llama_conversation/config_flow.py
@@ -38,7 +38,7 @@
 from homeassistant.util.package import is_installed
 from importlib.metadata import version
 
-from .utils import download_model_from_hf, install_llama_cpp_python, MissingQuantizationException
+from .utils import download_model_from_hf, install_llama_cpp_python, format_url, MissingQuantizationException
 from .const import (
     CONF_CHAT_MODEL,
     CONF_MAX_TOKENS,
@@ -489,6 +489,8 @@ async def async_step_download(
         self.download_task = None
         return self.async_show_progress_done(next_step_id=next_step)
 
+    # TODO: add validation for the generic OpenAI API by hitting the `/v1/models` endpoint to check the connection
+
     def _validate_text_generation_webui(self, user_input: dict) -> tuple:
         """
         Validates a connection to text-generation-webui and that the model exists on the remote server
@@ -503,7 +505,12 @@ def _validate_text_generation_webui(self, user_input: dict) -> tuple:
             headers["Authorization"] = f"Bearer {api_key}"
 
         models_result = requests.get(
-            f"{'https' if self.model_config[CONF_SSL] else 'http'}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/v1/internal/model/list",
+            format_url(
+                hostname=self.model_config[CONF_HOST],
+                port=self.model_config[CONF_PORT],
+                ssl=self.model_config[CONF_SSL],
+                path="/v1/internal/model/list"
+            ),
             timeout=5, # quick timeout
             headers=headers
         )
@@ -535,7 +542,12 @@ def _validate_ollama(self, user_input: dict) -> tuple:
             headers["Authorization"] = f"Bearer {api_key}"
 
         models_result = requests.get(
-            f"{'https' if self.model_config[CONF_SSL] else 'http'}://{self.model_config[CONF_HOST]}:{self.model_config[CONF_PORT]}/api/tags",
+            format_url(
+                hostname=self.model_config[CONF_HOST],
+                port=self.model_config[CONF_PORT],
+                ssl=self.model_config[CONF_SSL],
+                path="/api/tags"
+            ),
             timeout=5, # quick timeout
             headers=headers
         )
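The TODO above notes that the generic OpenAI backend still has no connection check. A validator in the same style as `_validate_ollama`, probing the standard `/v1/models` endpoint, might look like this sketch (hypothetical; not part of this commit):

```python
# Hypothetical validator for the generic OpenAI-compatible backend, modeled
# on _validate_ollama above; not part of this commit.
import requests

def validate_generic_openai(hostname: str, port: str, ssl: bool, api_key: str | None) -> list[str]:
    """Return the model IDs the server advertises, or raise on failure."""
    headers = {}
    if api_key:
        headers["Authorization"] = f"Bearer {api_key}"

    models_result = requests.get(
        format_url(hostname=hostname, port=port, ssl=ssl, path="/v1/models"),
        timeout=5,  # quick timeout
        headers=headers,
    )
    models_result.raise_for_status()

    # OpenAI-compatible servers return {"object": "list", "data": [{"id": ...}, ...]}
    return [model["id"] for model in models_result.json().get("data", [])]
```

A config flow step could then compare the configured model name against the returned IDs before finishing setup.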