Adds ChatDefinition utility (#1279)
The ChatDefinition allows users to turn any flow into a Chat flow by
defining what has to run and what are the inputs and outputs.
The ChatDefinition requires a function, inputs(optional) and
output_key(optional).
The function receives a dictionary as input and can output a string or a
dict. If the output is a dict, then an output_key must be provided.
Anything can run inside the function. You can also pass methods of
pre-built classes like a Chain.
Here's an example of how to use it in a CustomComponent:
```python
from langflow import CustomComponent
from langflow.utils.chat import ChatDefinition
class Component(CustomComponent):
documentation: str = "http://docs.langflow.org/components/custom"
def build(self) -> Data:
def func(inputs, callbacks):
return {"text":'This is a simple example.'}
return ChatDefinition(func=func, inputs=[], output_key="text")
```
This commit is contained in:
commit
a5aef06544
7 changed files with 179 additions and 20 deletions
41
poetry.lock
generated
41
poetry.lock
generated
|
|
@ -3728,6 +3728,43 @@ dev = ["black (>=23.3.0)", "httpx (>=0.24.1)", "mkdocs (>=1.4.3)", "mkdocs-mater
|
|||
server = ["fastapi (>=0.100.0)", "pydantic-settings (>=2.0.1)", "sse-starlette (>=1.6.1)", "starlette-context (>=0.3.6,<0.4)", "uvicorn (>=0.22.0)"]
|
||||
test = ["httpx (>=0.24.1)", "pytest (>=7.4.0)", "scipy (>=1.10)"]
|
||||
|
||||
[[package]]
|
||||
name = "llama-index"
|
||||
version = "0.9.24"
|
||||
description = "Interface between LLMs and your data"
|
||||
optional = false
|
||||
python-versions = ">=3.8.1,<4.0"
|
||||
files = [
|
||||
{file = "llama_index-0.9.24-py3-none-any.whl", hash = "sha256:aeef8a4fb478d45474261289046f37c2805e3bf3453c156c84088c0414465e5e"},
|
||||
{file = "llama_index-0.9.24.tar.gz", hash = "sha256:48175a35c30427f361068693d6f384baf76865831569ca4e04a1a8b6f10ba269"},
|
||||
]
|
||||
|
||||
[package.dependencies]
|
||||
aiohttp = ">=3.8.6,<4.0.0"
|
||||
beautifulsoup4 = ">=4.12.2,<5.0.0"
|
||||
dataclasses-json = "*"
|
||||
deprecated = ">=1.2.9.3"
|
||||
fsspec = ">=2023.5.0"
|
||||
httpx = "*"
|
||||
nest-asyncio = ">=1.5.8,<2.0.0"
|
||||
nltk = ">=3.8.1,<4.0.0"
|
||||
numpy = "*"
|
||||
openai = ">=1.1.0"
|
||||
pandas = "*"
|
||||
requests = ">=2.31.0"
|
||||
SQLAlchemy = {version = ">=1.4.49", extras = ["asyncio"]}
|
||||
tenacity = ">=8.2.0,<9.0.0"
|
||||
tiktoken = ">=0.3.3"
|
||||
typing-extensions = ">=4.5.0"
|
||||
typing-inspect = ">=0.8.0"
|
||||
|
||||
[package.extras]
|
||||
gradientai = ["gradientai (>=1.4.0)"]
|
||||
langchain = ["langchain (>=0.0.303)"]
|
||||
local-models = ["optimum[onnxruntime] (>=1.13.2,<2.0.0)", "sentencepiece (>=0.1.99,<0.2.0)", "transformers[torch] (>=4.34.0,<5.0.0)"]
|
||||
postgres = ["asyncpg (>=0.28.0,<0.29.0)", "pgvector (>=0.1.0,<0.2.0)", "psycopg-binary (>=3.1.12,<4.0.0)"]
|
||||
query-tools = ["guidance (>=0.0.64,<0.0.65)", "jsonpath-ng (>=1.6.0,<2.0.0)", "lm-format-enforcer (>=0.4.3,<0.5.0)", "rank-bm25 (>=0.2.2,<0.3.0)", "scikit-learn", "spacy (>=3.7.1,<4.0.0)"]
|
||||
|
||||
[[package]]
|
||||
name = "locust"
|
||||
version = "2.20.0"
|
||||
|
|
@ -7542,7 +7579,7 @@ files = [
|
|||
]
|
||||
|
||||
[package.dependencies]
|
||||
greenlet = {version = "!=0.4.17", markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\""}
|
||||
greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""}
|
||||
typing-extensions = ">=4.2.0"
|
||||
|
||||
[package.extras]
|
||||
|
|
@ -9184,4 +9221,4 @@ local = ["ctransformers", "llama-cpp-python", "sentence-transformers"]
|
|||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = ">=3.9,<3.11"
|
||||
content-hash = "298c4c74d1e9ed37405bfbeaefb8ecd5d766d970a4e16fb37cbbccbab79573ce"
|
||||
content-hash = "80378e2f6511caf4f46ad7ff47a5e0823aca87e92e6b0f3ced80040bbb891ffb"
|
||||
|
|
|
|||
|
|
@ -105,6 +105,7 @@ pgvector = "^0.2.3"
|
|||
pyautogen = "^0.2.0"
|
||||
langchain-google-genai = "^0.0.2"
|
||||
pytube = "^15.0.0"
|
||||
llama-index = "^0.9.24"
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
pytest-asyncio = "^0.23.1"
|
||||
|
|
|
|||
|
|
@ -2,7 +2,6 @@ from typing import Optional
|
|||
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.llms.bedrock import Bedrock
|
||||
|
||||
from langflow import CustomComponent
|
||||
|
||||
|
||||
|
|
@ -44,7 +43,7 @@ class AmazonBedrockComponent(CustomComponent):
|
|||
model_kwargs: Optional[dict] = None,
|
||||
endpoint_url: Optional[str] = None,
|
||||
streaming: bool = False,
|
||||
cache: bool | None = None,
|
||||
cache: Optional[bool] = None,
|
||||
) -> BaseLLM:
|
||||
try:
|
||||
output = Bedrock(
|
||||
|
|
|
|||
|
|
@ -5,14 +5,15 @@ from typing import Any, Dict, List
|
|||
|
||||
import orjson
|
||||
from fastapi import WebSocket, status
|
||||
from loguru import logger
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from langflow.api.v1.schemas import ChatMessage, ChatResponse, FileResponse
|
||||
from langflow.interface.utils import pil_to_base64
|
||||
from langflow.services import ServiceType, service_manager
|
||||
from langflow.services.base import Service
|
||||
from langflow.services.chat.cache import Subject
|
||||
from langflow.services.chat.utils import process_graph
|
||||
from loguru import logger
|
||||
from starlette.websockets import WebSocketState
|
||||
|
||||
from .cache import cache_service
|
||||
|
||||
|
|
@ -117,7 +118,7 @@ class ChatService(Service):
|
|||
if "after sending" in str(exc):
|
||||
logger.error(f"Error closing connection: {exc}")
|
||||
|
||||
async def process_message(self, client_id: str, payload: Dict, langchain_object: Any):
|
||||
async def process_message(self, client_id: str, payload: Dict, build_result: Any):
|
||||
# Process the graph data and chat message
|
||||
chat_inputs = payload.pop("inputs", {})
|
||||
chatkey = payload.pop("chatKey", None)
|
||||
|
|
@ -134,12 +135,12 @@ class ChatService(Service):
|
|||
logger.debug("Generating result and thought")
|
||||
|
||||
result, intermediate_steps, raw_output = await process_graph(
|
||||
langchain_object=langchain_object,
|
||||
build_result=build_result,
|
||||
chat_inputs=chat_inputs,
|
||||
client_id=client_id,
|
||||
session_id=self.connection_ids[client_id],
|
||||
)
|
||||
self.set_cache(client_id, langchain_object)
|
||||
self.set_cache(client_id, build_result)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
|
|
@ -205,8 +206,8 @@ class ChatService(Service):
|
|||
continue
|
||||
|
||||
with self.chat_cache.set_client_id(client_id):
|
||||
if langchain_object := self.cache_service.get(client_id).get("result"):
|
||||
await self.process_message(client_id, payload, langchain_object)
|
||||
if build_result := self.cache_service.get(client_id).get("result"):
|
||||
await self.process_message(client_id, payload, build_result)
|
||||
|
||||
else:
|
||||
raise RuntimeError(f"Could not find a build result for client_id {client_id}")
|
||||
|
|
|
|||
|
|
@ -1,20 +1,28 @@
|
|||
from typing import Any
|
||||
|
||||
from langchain.agents import AgentExecutor
|
||||
from langchain.chains.base import Chain
|
||||
from langchain_core.runnables import Runnable
|
||||
from loguru import logger
|
||||
|
||||
from langflow.api.v1.schemas import ChatMessage
|
||||
from langflow.interface.utils import try_setting_streaming_options
|
||||
from langflow.processing.base import get_result_and_steps
|
||||
from langflow.utils.chat import ChatDefinition
|
||||
|
||||
LANGCHAIN_RUNNABLES = (Chain, Runnable, AgentExecutor)
|
||||
|
||||
|
||||
async def process_graph(
|
||||
langchain_object,
|
||||
build_result,
|
||||
chat_inputs: ChatMessage,
|
||||
client_id: str,
|
||||
session_id: str,
|
||||
):
|
||||
langchain_object = try_setting_streaming_options(langchain_object)
|
||||
build_result = try_setting_streaming_options(build_result)
|
||||
logger.debug("Loaded langchain object")
|
||||
|
||||
if langchain_object is None:
|
||||
if build_result is None:
|
||||
# Raise user facing error
|
||||
raise ValueError("There was an error loading the langchain_object. Please, check all the nodes and try again.")
|
||||
|
||||
|
|
@ -25,15 +33,36 @@ async def process_graph(
|
|||
chat_inputs.message = {}
|
||||
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps, raw_output = await get_result_and_steps(
|
||||
langchain_object,
|
||||
chat_inputs.message,
|
||||
client_id=client_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
if isinstance(build_result, LANGCHAIN_RUNNABLES):
|
||||
result, intermediate_steps, raw_output = await get_result_and_steps(
|
||||
build_result,
|
||||
chat_inputs.message,
|
||||
client_id=client_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
elif isinstance(build_result, ChatDefinition):
|
||||
raw_output = await run_build_result(
|
||||
build_result,
|
||||
chat_inputs,
|
||||
client_id=client_id,
|
||||
session_id=session_id,
|
||||
)
|
||||
if isinstance(raw_output, dict):
|
||||
if not build_result.output_key:
|
||||
raise ValueError("No output key provided to ChatDefinition when returning a dict.")
|
||||
result = raw_output[build_result.output_key]
|
||||
else:
|
||||
result = raw_output
|
||||
intermediate_steps = []
|
||||
else:
|
||||
raise TypeError(f"Unknown type {type(build_result)}")
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
return result, intermediate_steps, raw_output
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
logger.exception(e)
|
||||
raise e
|
||||
|
||||
|
||||
async def run_build_result(build_result: Any, chat_inputs: ChatMessage, client_id: str, session_id: str):
|
||||
return build_result(inputs=chat_inputs.message)
|
||||
|
|
|
|||
34
src/backend/langflow/utils/chat.py
Normal file
34
src/backend/langflow/utils/chat.py
Normal file
|
|
@ -0,0 +1,34 @@
|
|||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from langchain_core.prompts import PromptTemplate as LCPromptTemplate
|
||||
from langflow.utils.prompt import GenericPromptTemplate
|
||||
from llama_index.prompts import PromptTemplate as LIPromptTemplate
|
||||
|
||||
PromptTemplate = Union[LCPromptTemplate, LIPromptTemplate]
|
||||
|
||||
|
||||
class ChatDefinition:
|
||||
def __init__(
|
||||
self,
|
||||
func: Callable,
|
||||
inputs: list[str],
|
||||
output_key: Optional[str] = None,
|
||||
prompt_template: Optional[PromptTemplate] = None,
|
||||
):
|
||||
self.func = func
|
||||
self.input_keys = inputs
|
||||
self.output_key = output_key
|
||||
self.prompt_template = prompt_template
|
||||
|
||||
@classmethod
|
||||
def from_prompt_template(cls, prompt_template: PromptTemplate, func: Callable, output_key: Optional[str] = None):
|
||||
prompt = GenericPromptTemplate(prompt_template)
|
||||
return cls(
|
||||
func=func,
|
||||
inputs=prompt.input_keys,
|
||||
output_key=output_key,
|
||||
prompt_template=prompt_template,
|
||||
)
|
||||
|
||||
def __call__(self, inputs: dict, callbacks: Optional[Any] = None) -> dict:
|
||||
return self.func(inputs, callbacks)
|
||||
58
src/backend/langflow/utils/prompt.py
Normal file
58
src/backend/langflow/utils/prompt.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
from typing import Any, Union
|
||||
|
||||
from langchain_core.prompts import PromptTemplate as LCPromptTemplate
|
||||
from llama_index.prompts import PromptTemplate as LIPromptTemplate
|
||||
|
||||
PromptTemplateTypes = Union[LCPromptTemplate, LIPromptTemplate]
|
||||
|
||||
|
||||
class GenericPromptTemplate:
|
||||
def __init__(self, prompt_template: PromptTemplateTypes):
|
||||
object.__setattr__(self, "prompt_template", prompt_template)
|
||||
|
||||
@property
|
||||
def input_keys(self):
|
||||
prompt_template = object.__getattribute__(self, "prompt_template")
|
||||
if isinstance(prompt_template, LCPromptTemplate):
|
||||
return prompt_template.input_variables
|
||||
elif isinstance(prompt_template, LIPromptTemplate):
|
||||
return prompt_template.template_vars
|
||||
else:
|
||||
raise TypeError(f"Unknown prompt template type {type(prompt_template)}")
|
||||
|
||||
def to_lc_prompt(self):
|
||||
prompt_template = object.__getattribute__(self, "prompt_template")
|
||||
if isinstance(prompt_template, LCPromptTemplate):
|
||||
return prompt_template
|
||||
elif isinstance(prompt_template, LIPromptTemplate):
|
||||
return LCPromptTemplate.from_template(prompt_template.get_template())
|
||||
else:
|
||||
raise TypeError(f"Unknown prompt template type {type(prompt_template)}")
|
||||
|
||||
def to_li_prompt(self):
|
||||
prompt_template = object.__getattribute__(self, "prompt_template")
|
||||
if isinstance(prompt_template, LIPromptTemplate):
|
||||
return prompt_template
|
||||
elif isinstance(prompt_template, LCPromptTemplate):
|
||||
return LIPromptTemplate(template=prompt_template.template)
|
||||
else:
|
||||
raise TypeError(f"Unknown prompt template type {type(prompt_template)}")
|
||||
|
||||
def __or__(self, other):
|
||||
prompt_template = object.__getattribute__(self, "prompt_template")
|
||||
if isinstance(prompt_template, LIPromptTemplate):
|
||||
return self.to_lc_prompt() | other
|
||||
else:
|
||||
raise TypeError(f"Unknown prompt template type {type(other)}")
|
||||
|
||||
def __getattribute__(self, name: str) -> Any:
|
||||
if name in {
|
||||
"input_keys",
|
||||
"to_lc_prompt",
|
||||
"to_li_prompt",
|
||||
"__or__",
|
||||
"prompt_template",
|
||||
}:
|
||||
return object.__getattribute__(self, name)
|
||||
prompt_template = object.__getattribute__(self, "prompt_template")
|
||||
return getattr(prompt_template, name)
|
||||
Loading…
Add table
Add a link
Reference in a new issue