Adds ChatDefinition utility (#1279)

The ChatDefinition allows users to turn any flow into a Chat flow by
defining what has to run and what the inputs and outputs are.

The ChatDefinition requires a function, plus optional inputs and an
optional output_key.
The function receives a dictionary as input and can return either a
string or a dict. If the output is a dict, an output_key must be
provided.

Anything can run inside the function. You can also pass methods of
pre-built classes like a Chain.
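
For instance, here is a minimal sketch (not part of this commit) of
calling a ChatDefinition directly, showing the string-output case,
which needs no output_key, and the dict-output case, which does:

```python
from langflow.utils.chat import ChatDefinition

# String output: no output_key needed
greet = ChatDefinition(func=lambda inputs, callbacks=None: "Hello!", inputs=[])
print(greet(inputs={}))  # -> "Hello!"

# Dict output: output_key tells the chat which value to display
echo = ChatDefinition(
    func=lambda inputs, callbacks=None: {"text": inputs.get("message", "")},
    inputs=["message"],
    output_key="text",
)
print(echo(inputs={"message": "Hi"}))  # -> {"text": "Hi"}
```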

Here's an example of how to use it in a CustomComponent:

```python
from langflow import CustomComponent
from langflow.utils.chat import ChatDefinition

class Component(CustomComponent):
    documentation: str = "http://docs.langflow.org/components/custom"

    def build(self) -> ChatDefinition:
        def func(inputs: dict, callbacks=None):
            # Returning a dict, so an output_key must be provided below
            return {"text": "This is a simple example."}

        return ChatDefinition(func=func, inputs=[], output_key="text")
```
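
The class also exposes a `from_prompt_template` constructor that derives
the input keys from a prompt template (see the `langflow/utils/chat.py`
diff below). A rough sketch of how it could be used, assuming a
LangChain `PromptTemplate` and a stand-in `answer` function instead of a
real model call:

```python
from langchain_core.prompts import PromptTemplate
from langflow.utils.chat import ChatDefinition

prompt = PromptTemplate.from_template("Answer briefly: {question}")

def answer(inputs, callbacks=None):
    # Stand-in for a real LLM call; just echoes the formatted prompt
    return {"text": prompt.format(**inputs)}

chat = ChatDefinition.from_prompt_template(prompt_template=prompt, func=answer, output_key="text")
print(chat.input_keys)                            # ["question"]
print(chat(inputs={"question": "What is Langflow?"}))
```
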
Gabriel Luiz Freitas Almeida 2024-01-06 16:17:14 -03:00 committed by GitHub
commit a5aef06544
7 changed files with 179 additions and 20 deletions

poetry.lock (generated)

@@ -3728,6 +3728,43 @@ dev = ["black (>=23.3.0)", "httpx (>=0.24.1)", "mkdocs (>=1.4.3)", "mkdocs-mater
server = ["fastapi (>=0.100.0)", "pydantic-settings (>=2.0.1)", "sse-starlette (>=1.6.1)", "starlette-context (>=0.3.6,<0.4)", "uvicorn (>=0.22.0)"]
test = ["httpx (>=0.24.1)", "pytest (>=7.4.0)", "scipy (>=1.10)"]
[[package]]
name = "llama-index"
version = "0.9.24"
description = "Interface between LLMs and your data"
optional = false
python-versions = ">=3.8.1,<4.0"
files = [
{file = "llama_index-0.9.24-py3-none-any.whl", hash = "sha256:aeef8a4fb478d45474261289046f37c2805e3bf3453c156c84088c0414465e5e"},
{file = "llama_index-0.9.24.tar.gz", hash = "sha256:48175a35c30427f361068693d6f384baf76865831569ca4e04a1a8b6f10ba269"},
]
[package.dependencies]
aiohttp = ">=3.8.6,<4.0.0"
beautifulsoup4 = ">=4.12.2,<5.0.0"
dataclasses-json = "*"
deprecated = ">=1.2.9.3"
fsspec = ">=2023.5.0"
httpx = "*"
nest-asyncio = ">=1.5.8,<2.0.0"
nltk = ">=3.8.1,<4.0.0"
numpy = "*"
openai = ">=1.1.0"
pandas = "*"
requests = ">=2.31.0"
SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]}
tenacity = ">=8.2.0,<9.0.0"
tiktoken = ">=0.3.3"
typing-extensions = ">=4.5.0"
typing-inspect = ">=0.8.0"
[package.extras]
gradientai = ["gradientai (>=1.4.0)"]
langchain = ["langchain (>=0.0.303)"]
local-models = ["optimum[onnxruntime] (>=1.13.2,<2.0.0)", "sentencepiece (>=0.1.99,<0.2.0)", "transformers[torch] (>=4.34.0,<5.0.0)"]
postgres = ["asyncpg (>=0.28.0,<0.29.0)", "pgvector (>=0.1.0,<0.2.0)", "psycopg-binary (>=3.1.12,<4.0.0)"]
query-tools = ["guidance (>=0.0.64,<0.0.65)", "jsonpath-ng (>=1.6.0,<2.0.0)", "lm-format-enforcer (>=0.4.3,<0.5.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "scikit-learn", "spacy (>=3.7.1,<4.0.0)"]
[[package]]
name = "locust"
version = "2.20.0"
@@ -7542,7 +7579,7 @@ files = [
]
[package.dependencies]
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
typing-extensions = ">=4.2.0"
[package.extras]
@@ -9184,4 +9221,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.9,<3.11"
content-hash = "298c4c74d1e9ed37405bfbeaefb8ecd5d766d970a4e16fb37cbbccbab79573ce"
content-hash = "80378e2f6511caf4f46ad7ff47a5e0823aca87e92e6b0f3ced80040bbb891ffb"


@@ -105,6 +105,7 @@ pgvector = "^0.2.3"
pyautogen = "^0.2.0"
langchain-google-genai = "^0.0.2"
pytube = "^15.0.0"
llama-index = "^0.9.24"
[tool.poetry.group.dev.dependencies]
pytest-asyncio = "^0.23.1"


@@ -2,7 +2,6 @@ from typing import Optional
from langchain.llms.base import BaseLLM
from langchain.llms.bedrock import Bedrock
from langflow import CustomComponent
@@ -44,7 +43,7 @@ class AmazonBedrockComponent(CustomComponent):
model_kwargs: Optional[dict] = None,
endpoint_url: Optional[str] = None,
streaming: bool = False,
cache: bool | None = None,
cache: Optional[bool] = None,
) -> BaseLLM:
try:
output = Bedrock(


@@ -5,14 +5,15 @@ from typing import Any, Dict, List
import orjson
from fastapi import WebSocket, status
from loguru import logger
from starlette.websockets import WebSocketState
from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
from langflow.interface.utils import pil_to_base64
from langflow.services import ServiceType, service_manager
from langflow.services.base import Service
from langflow.services.chat.cache import Subject
from langflow.services.chat.utils import process_graph
from loguru import logger
from starlette.websockets import WebSocketState
from .cache import cache_service
@@ -117,7 +118,7 @@ class ChatService(Service):
if "after sending" in str(exc):
logger.error(f"Error closing connection: {exc}")
async def process_message(self, client_id: str, payload: Dict, langchain_object: Any):
async def process_message(self, client_id: str, payload: Dict, build_result: Any):
# Process the graph data and chat message
chat_inputs = payload.pop("inputs", {})
chatkey = payload.pop("chatKey", None)
@@ -134,12 +135,12 @@
logger.debug("Generating result and thought")
result, intermediate_steps, raw_output = await process_graph(
langchain_object=langchain_object,
build_result=build_result,
chat_inputs=chat_inputs,
client_id=client_id,
session_id=self.connection_ids[client_id],
)
self.set_cache(client_id, langchain_object)
self.set_cache(client_id, build_result)
except Exception as e:
# Log stack trace
logger.exception(e)
@@ -205,8 +206,8 @@
continue
with self.chat_cache.set_client_id(client_id):
if langchain_object := self.cache_service.get(client_id).get("result"):
await self.process_message(client_id, payload, langchain_object)
if build_result := self.cache_service.get(client_id).get("result"):
await self.process_message(client_id, payload, build_result)
else:
raise RuntimeError(f"Could not find a build result for client_id {client_id}")


@@ -1,20 +1,28 @@
from typing import Any
from langchain.agents import AgentExecutor
from langchain.chains.base import Chain
from langchain_core.runnables import Runnable
from loguru import logger
from langflow.api.v1.schemas import ChatMessage
from langflow.interface.utils import try_setting_streaming_options
from langflow.processing.base import get_result_and_steps
from langflow.utils.chat import ChatDefinition
LANGCHAIN_RUNNABLES = (Chain, Runnable, AgentExecutor)
async def process_graph(
langchain_object,
build_result,
chat_inputs: ChatMessage,
client_id: str,
session_id: str,
):
langchain_object = try_setting_streaming_options(langchain_object)
build_result = try_setting_streaming_options(build_result)
logger.debug("Loaded langchain object")
if langchain_object is None:
if build_result is None:
# Raise user facing error
raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.")
@@ -25,15 +33,36 @@ async def process_graph(
chat_inputs.message = {}
logger.debug("Generating result and thought")
result, intermediate_steps, raw_output = await get_result_and_steps(
langchain_object,
chat_inputs.message,
client_id=client_id,
session_id=session_id,
)
if isinstance(build_result, LANGCHAIN_RUNNABLES):
result, intermediate_steps, raw_output = await get_result_and_steps(
build_result,
chat_inputs.message,
client_id=client_id,
session_id=session_id,
)
elif isinstance(build_result, ChatDefinition):
raw_output = await run_build_result(
build_result,
chat_inputs,
client_id=client_id,
session_id=session_id,
)
if isinstance(raw_output, dict):
if not build_result.output_key:
raise ValueError("No output key provided to ChatDefinition when returning a dict.")
result = raw_output[build_result.output_key]
else:
result = raw_output
intermediate_steps = []
else:
raise TypeError(f"Unknown type {type(build_result)}")
logger.debug("Generated result and intermediate_steps")
return result, intermediate_steps, raw_output
except Exception as e:
# Log stack trace
logger.exception(e)
raise e
async def run_build_result(build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str):
return build_result(inputs=chat_inputs.message)


@@ -0,0 +1,34 @@
from typing import Any, Callable, Optional, Union
from langchain_core.prompts import PromptTemplate as LCPromptTemplate
from langflow.utils.prompt import GenericPromptTemplate
from llama_index.prompts import PromptTemplate as LIPromptTemplate
PromptTemplate = Union[LCPromptTemplate, LIPromptTemplate]
class ChatDefinition:
def __init__(
self,
func: Callable,
inputs: list[str],
output_key: Optional[str] = None,
prompt_template: Optional[PromptTemplate] = None,
):
self.func = func
self.input_keys = inputs
self.output_key = output_key
self.prompt_template = prompt_template
@classmethod
def from_prompt_template(cls, prompt_template: PromptTemplate, func: Callable, output_key: Optional[str] = None):
prompt = GenericPromptTemplate(prompt_template)
return cls(
func=func,
inputs=prompt.input_keys,
output_key=output_key,
prompt_template=prompt_template,
)
def __call__(self, inputs: dict, callbacks: Optional[Any] = None) -> dict:
return self.func(inputs, callbacks)


@@ -0,0 +1,58 @@
from typing import Any, Union
from langchain_core.prompts import PromptTemplate as LCPromptTemplate
from llama_index.prompts import PromptTemplate as LIPromptTemplate
PromptTemplateTypes = Union[LCPromptTemplate, LIPromptTemplate]
class GenericPromptTemplate:
def __init__(self, prompt_template: PromptTemplateTypes):
object.__setattr__(self, "prompt_template", prompt_template)
@property
def input_keys(self):
prompt_template = object.__getattribute__(self, "prompt_template")
if isinstance(prompt_template, LCPromptTemplate):
return prompt_template.input_variables
elif isinstance(prompt_template, LIPromptTemplate):
return prompt_template.template_vars
else:
raise TypeError(f"Unknown prompt template type {type(prompt_template)}")
def to_lc_prompt(self):
prompt_template = object.__getattribute__(self, "prompt_template")
if isinstance(prompt_template, LCPromptTemplate):
return prompt_template
elif isinstance(prompt_template, LIPromptTemplate):
return LCPromptTemplate.from_template(prompt_template.get_template())
else:
raise TypeError(f"Unknown prompt template type {type(prompt_template)}")
def to_li_prompt(self):
prompt_template = object.__getattribute__(self, "prompt_template")
if isinstance(prompt_template, LIPromptTemplate):
return prompt_template
elif isinstance(prompt_template, LCPromptTemplate):
return LIPromptTemplate(template=prompt_template.template)
else:
raise TypeError(f"Unknown prompt template type {type(prompt_template)}")
def __or__(self, other):
prompt_template = object.__getattribute__(self, "prompt_template")
if isinstance(prompt_template, LIPromptTemplate):
return self.to_lc_prompt() | other
else:
raise TypeError(f"Unknown prompt template type {type(other)}")
def __getattribute__(self, name: str) -> Any:
if name in {
"input_keys",
"to_lc_prompt",
"to_li_prompt",
"__or__",
"prompt_template",
}:
return object.__getattribute__(self, name)
prompt_template = object.__getattribute__(self, "prompt_template")
return getattr(prompt_template, name)