Merge branch 'dev' into authentication
This commit is contained in:
commit
2fa7810ff8
440 changed files with 47130 additions and 10958 deletions
|
|
@ -1,6 +1,7 @@
|
|||
from importlib import metadata
|
||||
from langflow.cache import cache_manager
|
||||
from langflow.processing.process import load_flow_from_json
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
try:
|
||||
__version__ = metadata.version(__package__)
|
||||
|
|
@ -9,4 +10,4 @@ except metadata.PackageNotFoundError:
|
|||
__version__ = ""
|
||||
del metadata # optional, avoids polluting the results of dir(__package__)
|
||||
|
||||
__all__ = ["load_flow_from_json", "cache_manager"]
|
||||
__all__ = ["load_flow_from_json", "cache_manager", "CustomComponent"]
|
||||
|
|
|
|||
|
|
@ -1,8 +1,9 @@
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
from fastapi import FastAPI
|
||||
import httpx
|
||||
from multiprocess import Process, cpu_count # type: ignore
|
||||
from langflow.utils.util import get_number_of_workers
|
||||
from multiprocess import Process # type: ignore
|
||||
import platform
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
|
@ -11,9 +12,7 @@ from rich.panel import Panel
|
|||
from rich import box
|
||||
from rich import print as rprint
|
||||
import typer
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
from fastapi.responses import FileResponse
|
||||
from langflow.main import create_app
|
||||
from langflow.main import setup_app
|
||||
from langflow.settings import settings
|
||||
from langflow.utils.logger import configure, logger
|
||||
import webbrowser
|
||||
|
|
@ -22,25 +21,53 @@ from dotenv import load_dotenv
|
|||
app = typer.Typer()
|
||||
|
||||
|
||||
def get_number_of_workers(workers=None):
|
||||
if workers == -1:
|
||||
workers = (cpu_count() * 2) + 1
|
||||
return workers
|
||||
|
||||
|
||||
def update_settings(
|
||||
config: str,
|
||||
cache: str,
|
||||
dev: bool = False,
|
||||
database_url: Optional[str] = None,
|
||||
remove_api_keys: bool = False,
|
||||
components_path: Optional[Path] = None,
|
||||
):
|
||||
"""Update the settings from a config file."""
|
||||
|
||||
# Check for database_url in the environment variables
|
||||
database_url = database_url or os.getenv("langflow_database_url")
|
||||
|
||||
if config:
|
||||
logger.debug(f"Loading settings from {config}")
|
||||
settings.update_from_yaml(config, dev=dev)
|
||||
if database_url:
|
||||
settings.update_settings(database_url=database_url)
|
||||
if remove_api_keys:
|
||||
logger.debug(f"Setting remove_api_keys to {remove_api_keys}")
|
||||
settings.update_settings(remove_api_keys=remove_api_keys)
|
||||
if cache:
|
||||
logger.debug(f"Setting cache to {cache}")
|
||||
settings.update_settings(cache=cache)
|
||||
if components_path:
|
||||
logger.debug(f"Adding component path {components_path}")
|
||||
settings.update_settings(components_path=components_path)
|
||||
|
||||
|
||||
def load_params():
|
||||
"""
|
||||
Load the parameters from the environment variables.
|
||||
"""
|
||||
global_vars = globals()
|
||||
|
||||
for key, value in global_vars.items():
|
||||
env_key = f"LANGFLOW_{key.upper()}"
|
||||
if env_key in os.environ:
|
||||
if isinstance(value, bool):
|
||||
# Handle booleans
|
||||
global_vars[key] = os.getenv(env_key, str(value)).lower() == "true"
|
||||
elif isinstance(value, int):
|
||||
# Handle integers
|
||||
global_vars[key] = int(os.getenv(env_key, str(value)))
|
||||
elif isinstance(value, str) or value is None:
|
||||
# Handle strings and None values
|
||||
global_vars[key] = os.getenv(env_key, str(value))
|
||||
|
||||
|
||||
def serve_on_jcloud():
|
||||
|
|
@ -91,56 +118,81 @@ def serve_on_jcloud():
|
|||
|
||||
@app.command()
|
||||
def serve(
|
||||
host: str = typer.Option("127.0.0.1", help="Host to bind the server to."),
|
||||
workers: int = typer.Option(1, help="Number of worker processes."),
|
||||
timeout: int = typer.Option(60, help="Worker timeout in seconds."),
|
||||
port: int = typer.Option(7860, help="Port to listen on."),
|
||||
host: str = typer.Option(
|
||||
"127.0.0.1", help="Host to bind the server to.", envvar="LANGFLOW_HOST"
|
||||
),
|
||||
workers: int = typer.Option(
|
||||
2, help="Number of worker processes.", envvar="LANGFLOW_WORKERS"
|
||||
),
|
||||
timeout: int = typer.Option(300, help="Worker timeout in seconds."),
|
||||
port: int = typer.Option(7860, help="Port to listen on.", envvar="LANGFLOW_PORT"),
|
||||
components_path: Optional[Path] = typer.Option(
|
||||
Path(__file__).parent / "components",
|
||||
help="Path to the directory containing custom components.",
|
||||
envvar="LANGFLOW_COMPONENTS_PATH",
|
||||
),
|
||||
config: str = typer.Option("config.yaml", help="Path to the configuration file."),
|
||||
# .env file param
|
||||
env_file: Path = typer.Option(
|
||||
".env", help="Path to the .env file containing environment variables."
|
||||
),
|
||||
log_level: str = typer.Option("critical", help="Logging level."),
|
||||
log_file: Path = typer.Option("logs/langflow.log", help="Path to the log file."),
|
||||
log_level: str = typer.Option(
|
||||
"critical", help="Logging level.", envvar="LANGFLOW_LOG_LEVEL"
|
||||
),
|
||||
log_file: Path = typer.Option(
|
||||
"logs/langflow.log", help="Path to the log file.", envvar="LANGFLOW_LOG_FILE"
|
||||
),
|
||||
cache: str = typer.Option(
|
||||
envvar="LANGFLOW_LANGCHAIN_CACHE",
|
||||
help="Type of cache to use. (InMemoryCache, SQLiteCache)",
|
||||
default="SQLiteCache",
|
||||
),
|
||||
jcloud: bool = typer.Option(False, help="Deploy on Jina AI Cloud"),
|
||||
dev: bool = typer.Option(False, help="Run in development mode (may contain bugs)"),
|
||||
database_url: str = typer.Option(
|
||||
None,
|
||||
help="Database URL to connect to. If not provided, a local SQLite database will be used.",
|
||||
envvar="LANGFLOW_DATABASE_URL",
|
||||
),
|
||||
path: str = typer.Option(
|
||||
None,
|
||||
help="Path to the frontend directory containing build files. This is for development purposes only.",
|
||||
envvar="LANGFLOW_FRONTEND_PATH",
|
||||
),
|
||||
open_browser: bool = typer.Option(
|
||||
True, help="Open the browser after starting the server."
|
||||
True,
|
||||
help="Open the browser after starting the server.",
|
||||
envvar="LANGFLOW_OPEN_BROWSER",
|
||||
),
|
||||
remove_api_keys: bool = typer.Option(
|
||||
False, help="Remove API keys from the projects saved in the database."
|
||||
False,
|
||||
help="Remove API keys from the projects saved in the database.",
|
||||
envvar="LANGFLOW_REMOVE_API_KEYS",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Run the Langflow server.
|
||||
"""
|
||||
# override env variables with .env file
|
||||
if env_file:
|
||||
load_dotenv(env_file, override=True)
|
||||
load_params()
|
||||
|
||||
if jcloud:
|
||||
return serve_on_jcloud()
|
||||
|
||||
load_dotenv(env_file)
|
||||
|
||||
configure(log_level=log_level, log_file=log_file)
|
||||
update_settings(
|
||||
config, dev=dev, database_url=database_url, remove_api_keys=remove_api_keys
|
||||
config,
|
||||
dev=dev,
|
||||
database_url=database_url,
|
||||
remove_api_keys=remove_api_keys,
|
||||
cache=cache,
|
||||
components_path=components_path,
|
||||
)
|
||||
# get the directory of the current file
|
||||
if not path:
|
||||
frontend_path = Path(__file__).parent
|
||||
static_files_dir = frontend_path / "frontend"
|
||||
else:
|
||||
static_files_dir = Path(path)
|
||||
|
||||
app = create_app()
|
||||
setup_static_files(app, static_files_dir)
|
||||
# create path object if path is provided
|
||||
static_files_dir: Optional[Path] = Path(path) if path else None
|
||||
app = setup_app(static_files_dir=static_files_dir)
|
||||
# check if port is being used
|
||||
if is_port_in_use(port, host):
|
||||
port = get_free_port(port)
|
||||
|
|
@ -152,6 +204,17 @@ def serve(
|
|||
"timeout": timeout,
|
||||
}
|
||||
|
||||
if platform.system() in ["Windows"]:
|
||||
# Run using uvicorn on MacOS and Windows
|
||||
# Windows doesn't support gunicorn
|
||||
# MacOS requires an env variable to be set to use gunicorn
|
||||
run_on_windows(host, port, log_level, options, app)
|
||||
else:
|
||||
# Run using gunicorn on Linux
|
||||
run_on_mac_or_linux(host, port, log_level, options, app, open_browser)
|
||||
|
||||
|
||||
def run_on_mac_or_linux(host, port, log_level, options, app, open_browser=True):
|
||||
webapp_process = Process(
|
||||
target=run_langflow, args=(host, port, log_level, options, app)
|
||||
)
|
||||
|
|
@ -169,27 +232,12 @@ def serve(
|
|||
webbrowser.open(f"http://{host}:{port}")
|
||||
|
||||
|
||||
def setup_static_files(app: FastAPI, static_files_dir: Path):
|
||||
def run_on_windows(host, port, log_level, options, app):
|
||||
"""
|
||||
Setup the static files directory.
|
||||
|
||||
Args:
|
||||
app (FastAPI): FastAPI app.
|
||||
path (str): Path to the static files directory.
|
||||
Run the Langflow server on Windows.
|
||||
"""
|
||||
app.mount(
|
||||
"/",
|
||||
StaticFiles(directory=static_files_dir, html=True),
|
||||
name="static",
|
||||
)
|
||||
|
||||
@app.exception_handler(404)
|
||||
async def custom_404_handler(request, __):
|
||||
path = static_files_dir / "index.html"
|
||||
|
||||
if not path.exists():
|
||||
raise RuntimeError(f"File at path {path} does not exist.")
|
||||
return FileResponse(path)
|
||||
print_banner(host, port)
|
||||
run_langflow(host, port, log_level, options, app)
|
||||
|
||||
|
||||
def is_port_in_use(port, host="localhost"):
|
||||
|
|
@ -225,7 +273,7 @@ def get_free_port(port):
|
|||
def print_banner(host, port):
|
||||
# console = Console()
|
||||
|
||||
word = "LangFlow"
|
||||
word = "Langflow"
|
||||
colors = ["#3300cc"]
|
||||
|
||||
styled_word = ""
|
||||
|
|
@ -258,7 +306,7 @@ def run_langflow(host, port, log_level, options, app):
|
|||
Run Langflow server on localhost
|
||||
"""
|
||||
try:
|
||||
if platform.system() in ["Darwin", "Windows"]:
|
||||
if platform.system() in ["Windows"]:
|
||||
# Run using uvicorn on MacOS and Windows
|
||||
# Windows doesn't support gunicorn
|
||||
# MacOS requires an env variable to be set to use gunicorn
|
||||
|
|
@ -272,7 +320,7 @@ def run_langflow(host, port, log_level, options, app):
|
|||
except KeyboardInterrupt:
|
||||
pass
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
logger.exception(e)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langflow.api.v1 import (
|
|||
validate_router,
|
||||
flows_router,
|
||||
flow_styles_router,
|
||||
component_router,
|
||||
)
|
||||
|
||||
router = APIRouter(
|
||||
|
|
@ -14,5 +15,6 @@ router = APIRouter(
|
|||
router.include_router(chat_router)
|
||||
router.include_router(endpoints_router)
|
||||
router.include_router(validate_router)
|
||||
router.include_router(component_router)
|
||||
router.include_router(flows_router)
|
||||
router.include_router(flow_styles_router)
|
||||
|
|
|
|||
|
|
@ -22,3 +22,47 @@ def remove_api_keys(flow: dict):
|
|||
value["value"] = None
|
||||
|
||||
return flow
|
||||
|
||||
|
||||
def build_input_keys_response(langchain_object, artifacts):
|
||||
"""Build the input keys response."""
|
||||
|
||||
input_keys_response = {
|
||||
"input_keys": {key: "" for key in langchain_object.input_keys},
|
||||
"memory_keys": [],
|
||||
"handle_keys": artifacts.get("handle_keys", []),
|
||||
}
|
||||
|
||||
# Set the input keys values from artifacts
|
||||
for key, value in artifacts.items():
|
||||
if key in input_keys_response["input_keys"]:
|
||||
input_keys_response["input_keys"][key] = value
|
||||
# If the object has memory, that memory will have a memory_variables attribute
|
||||
# memory variables should be removed from the input keys
|
||||
if hasattr(langchain_object, "memory") and hasattr(
|
||||
langchain_object.memory, "memory_variables"
|
||||
):
|
||||
# Remove memory variables from input keys
|
||||
input_keys_response["input_keys"] = {
|
||||
key: value
|
||||
for key, value in input_keys_response["input_keys"].items()
|
||||
if key not in langchain_object.memory.memory_variables
|
||||
}
|
||||
# Add memory variables to memory_keys
|
||||
input_keys_response["memory_keys"] = langchain_object.memory.memory_variables
|
||||
|
||||
if hasattr(langchain_object, "prompt") and hasattr(
|
||||
langchain_object.prompt, "template"
|
||||
):
|
||||
input_keys_response["template"] = langchain_object.prompt.template
|
||||
|
||||
return input_keys_response
|
||||
|
||||
|
||||
def merge_nested_dicts(dict1, dict2):
|
||||
for key, value in dict2.items():
|
||||
if isinstance(value, dict) and isinstance(dict1.get(key), dict):
|
||||
dict1[key] = merge_nested_dicts(dict1[key], value)
|
||||
else:
|
||||
dict1[key] = value
|
||||
return dict1
|
||||
|
|
|
|||
|
|
@ -3,10 +3,12 @@ from langflow.api.v1.validate import router as validate_router
|
|||
from langflow.api.v1.chat import router as chat_router
|
||||
from langflow.api.v1.flows import router as flows_router
|
||||
from langflow.api.v1.flow_styles import router as flow_styles_router
|
||||
from langflow.api.v1.components import router as component_router
|
||||
|
||||
__all__ = [
|
||||
"chat_router",
|
||||
"endpoints_router",
|
||||
"component_router",
|
||||
"validate_router",
|
||||
"flows_router",
|
||||
"flow_styles_router",
|
||||
|
|
|
|||
|
|
@ -1,6 +1,8 @@
|
|||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from langflow.interface.utils import extract_input_variables_from_prompt
|
||||
from langchain.prompts import PromptTemplate
|
||||
|
||||
|
||||
class CacheResponse(BaseModel):
|
||||
|
|
@ -11,8 +13,14 @@ class Code(BaseModel):
|
|||
code: str
|
||||
|
||||
|
||||
class Prompt(BaseModel):
|
||||
class FrontendNodeRequest(FrontendNode):
|
||||
template: dict # type: ignore
|
||||
|
||||
|
||||
class ValidatePromptRequest(BaseModel):
|
||||
name: str
|
||||
template: str
|
||||
frontend_node: FrontendNodeRequest
|
||||
|
||||
|
||||
# Build ValidationResponse class for {"imports": {"errors": []}, "function": {"errors": []}}
|
||||
|
|
@ -31,6 +39,7 @@ class CodeValidationResponse(BaseModel):
|
|||
|
||||
class PromptValidationResponse(BaseModel):
|
||||
input_variables: list
|
||||
frontend_node: FrontendNodeRequest
|
||||
|
||||
|
||||
INVALID_CHARACTERS = {
|
||||
|
|
@ -51,34 +60,93 @@ INVALID_CHARACTERS = {
|
|||
"}",
|
||||
}
|
||||
|
||||
INVALID_NAMES = {
|
||||
"input_variables",
|
||||
"output_parser",
|
||||
"partial_variables",
|
||||
"template",
|
||||
"template_format",
|
||||
"validate_template",
|
||||
}
|
||||
|
||||
|
||||
def validate_prompt(template: str):
|
||||
input_variables = extract_input_variables_from_prompt(template)
|
||||
|
||||
# Check if there are invalid characters in the input_variables
|
||||
input_variables = check_input_variables(input_variables)
|
||||
if any(var in INVALID_NAMES for var in input_variables):
|
||||
raise ValueError(
|
||||
f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. "
|
||||
)
|
||||
|
||||
return PromptValidationResponse(input_variables=input_variables)
|
||||
try:
|
||||
PromptTemplate(template=template, input_variables=input_variables)
|
||||
except Exception as exc:
|
||||
raise ValueError(str(exc)) from exc
|
||||
|
||||
return input_variables
|
||||
|
||||
|
||||
def check_input_variables(input_variables: list):
|
||||
invalid_chars = []
|
||||
fixed_variables = []
|
||||
wrong_variables = []
|
||||
empty_variables = []
|
||||
for variable in input_variables:
|
||||
new_var = variable
|
||||
for char in INVALID_CHARACTERS:
|
||||
if char in variable:
|
||||
invalid_chars.append(char)
|
||||
new_var = new_var.replace(char, "")
|
||||
|
||||
# if variable is empty, then we should add that to the wrong variables
|
||||
if not variable:
|
||||
empty_variables.append(variable)
|
||||
continue
|
||||
|
||||
# if variable starts with a number we should add that to the invalid chars
|
||||
# and wrong variables
|
||||
if variable[0].isdigit():
|
||||
invalid_chars.append(variable[0])
|
||||
new_var = new_var.replace(variable[0], "")
|
||||
wrong_variables.append(variable)
|
||||
else:
|
||||
for char in INVALID_CHARACTERS:
|
||||
if char in variable:
|
||||
invalid_chars.append(char)
|
||||
new_var = new_var.replace(char, "")
|
||||
wrong_variables.append(variable)
|
||||
fixed_variables.append(new_var)
|
||||
if new_var != variable:
|
||||
input_variables.remove(variable)
|
||||
input_variables.append(new_var)
|
||||
# If any of the input_variables is not in the fixed_variables, then it means that
|
||||
# there are invalid characters in the input_variables
|
||||
if any(var not in fixed_variables for var in input_variables):
|
||||
raise ValueError(
|
||||
f"Invalid input variables: {input_variables}. Please, use something like {fixed_variables} instead."
|
||||
)
|
||||
|
||||
if any(var not in fixed_variables for var in input_variables):
|
||||
error_message = build_error_message(
|
||||
input_variables,
|
||||
invalid_chars,
|
||||
wrong_variables,
|
||||
fixed_variables,
|
||||
empty_variables,
|
||||
)
|
||||
raise ValueError(error_message)
|
||||
return input_variables
|
||||
|
||||
|
||||
def build_error_message(
|
||||
input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables
|
||||
):
|
||||
input_variables_str = ", ".join([f"'{var}'" for var in input_variables])
|
||||
error_string = f"Invalid input variables: {input_variables_str}. "
|
||||
|
||||
if wrong_variables and invalid_chars:
|
||||
# fix the wrong variables replacing invalid chars and find them in the fixed variables
|
||||
error_string_vars = "You can fix them by replacing the invalid characters: "
|
||||
wvars = wrong_variables.copy()
|
||||
for i, wrong_var in enumerate(wvars):
|
||||
for char in invalid_chars:
|
||||
wrong_var = wrong_var.replace(char, "")
|
||||
if wrong_var in fixed_variables:
|
||||
error_string_vars += f"'{wrong_variables[i]}' -> '{wrong_var}'"
|
||||
error_string += error_string_vars
|
||||
elif empty_variables:
|
||||
error_string += f" There are {len(empty_variables)} empty variable{'s' if len(empty_variables) > 1 else ''}."
|
||||
elif len(set(fixed_variables)) != len(fixed_variables):
|
||||
error_string += "There are duplicate variables."
|
||||
return error_string
|
||||
|
|
|
|||
|
|
@ -1,22 +1,132 @@
|
|||
import asyncio
|
||||
from typing import Any
|
||||
|
||||
from langchain.callbacks.base import AsyncCallbackHandler, BaseCallbackHandler
|
||||
|
||||
from langflow.api.v1.schemas import ChatResponse
|
||||
|
||||
|
||||
from typing import Any, Dict, List, Union
|
||||
from fastapi import WebSocket
|
||||
|
||||
|
||||
from langchain.schema import AgentAction, LLMResult, AgentFinish
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
|
||||
# https://github.com/hwchase17/chat-langchain/blob/master/callback.py
|
||||
class AsyncStreamingLLMCallbackHandler(AsyncCallbackHandler):
|
||||
"""Callback handler for streaming LLM responses."""
|
||||
|
||||
def __init__(self, websocket):
|
||||
def __init__(self, websocket: WebSocket):
|
||||
self.websocket = websocket
|
||||
|
||||
async def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
|
||||
resp = ChatResponse(message=token, type="stream", intermediate_steps="")
|
||||
await self.websocket.send_json(resp.dict())
|
||||
|
||||
async def on_llm_start(
|
||||
self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM starts running."""
|
||||
|
||||
async def on_llm_end(self, response: LLMResult, **kwargs: Any) -> Any:
|
||||
"""Run when LLM ends running."""
|
||||
|
||||
async def on_llm_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when LLM errors."""
|
||||
|
||||
async def on_chain_start(
|
||||
self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when chain starts running."""
|
||||
|
||||
async def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
|
||||
"""Run when chain ends running."""
|
||||
|
||||
async def on_chain_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when chain errors."""
|
||||
|
||||
async def on_tool_start(
|
||||
self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when tool starts running."""
|
||||
resp = ChatResponse(
|
||||
message="",
|
||||
type="stream",
|
||||
intermediate_steps=f"Tool input: {input_str}",
|
||||
)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
|
||||
async def on_tool_end(self, output: str, **kwargs: Any) -> Any:
|
||||
"""Run when tool ends running."""
|
||||
observation_prefix = kwargs.get("observation_prefix", "Tool output: ")
|
||||
split_output = output.split()
|
||||
first_word = split_output[0]
|
||||
rest_of_output = split_output[1:]
|
||||
# Create a formatted message.
|
||||
intermediate_steps = f"{observation_prefix}{first_word}"
|
||||
|
||||
# Create a ChatResponse instance.
|
||||
resp = ChatResponse(
|
||||
message="",
|
||||
type="stream",
|
||||
intermediate_steps=intermediate_steps,
|
||||
)
|
||||
rest_of_resps = [
|
||||
ChatResponse(
|
||||
message="",
|
||||
type="stream",
|
||||
intermediate_steps=f"{word}",
|
||||
)
|
||||
for word in rest_of_output
|
||||
]
|
||||
resps = [resp] + rest_of_resps
|
||||
# Try to send the response, handle potential errors.
|
||||
|
||||
try:
|
||||
# This is to emulate the stream of tokens
|
||||
for resp in resps:
|
||||
await self.websocket.send_json(resp.dict())
|
||||
except Exception as exc:
|
||||
logger.error(f"Error sending response: {exc}")
|
||||
|
||||
async def on_tool_error(
|
||||
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
|
||||
) -> Any:
|
||||
"""Run when tool errors."""
|
||||
|
||||
async def on_text(self, text: str, **kwargs: Any) -> Any:
|
||||
"""Run on arbitrary text."""
|
||||
# This runs when first sending the prompt
|
||||
# to the LLM, adding it will send the final prompt
|
||||
# to the frontend
|
||||
|
||||
async def on_agent_action(self, action: AgentAction, **kwargs: Any):
|
||||
log = f"Thought: {action.log}"
|
||||
# if there are line breaks, split them and send them
|
||||
# as separate messages
|
||||
if "\n" in log:
|
||||
logs = log.split("\n")
|
||||
for log in logs:
|
||||
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
else:
|
||||
resp = ChatResponse(message="", type="stream", intermediate_steps=log)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
|
||||
async def on_agent_finish(self, finish: AgentFinish, **kwargs: Any) -> Any:
|
||||
"""Run on agent end."""
|
||||
resp = ChatResponse(
|
||||
message="",
|
||||
type="stream",
|
||||
intermediate_steps=finish.log,
|
||||
)
|
||||
await self.websocket.send_json(resp.dict())
|
||||
|
||||
|
||||
class StreamingLLMCallbackHandler(BaseCallbackHandler):
|
||||
"""Callback handler for streaming LLM responses."""
|
||||
|
|
|
|||
|
|
@ -1,13 +1,7 @@
|
|||
import json
|
||||
from fastapi import (
|
||||
APIRouter,
|
||||
HTTPException,
|
||||
WebSocket,
|
||||
WebSocketException,
|
||||
status,
|
||||
)
|
||||
from fastapi import APIRouter, HTTPException, WebSocket, WebSocketException, status
|
||||
from fastapi.responses import StreamingResponse
|
||||
from langflow.api.v1.schemas import BuiltResponse, InitResponse
|
||||
from langflow.api.utils import build_input_keys_response
|
||||
from langflow.api.v1.schemas import BuildStatus, BuiltResponse, InitResponse, StreamData
|
||||
|
||||
from langflow.chat.manager import ChatManager
|
||||
from langflow.graph.graph.base import Graph
|
||||
|
|
@ -26,26 +20,43 @@ async def chat(client_id: str, websocket: WebSocket):
|
|||
if client_id in chat_manager.in_memory_cache:
|
||||
await chat_manager.handle_websocket(client_id, websocket)
|
||||
else:
|
||||
# We accept the connection but close it immediately
|
||||
# if the flow is not built yet
|
||||
await websocket.accept()
|
||||
message = "Please, build the flow before sending messages"
|
||||
await websocket.close(code=status.WS_1008_POLICY_VIOLATION, reason=message)
|
||||
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=message)
|
||||
except WebSocketException as exc:
|
||||
logger.error(exc)
|
||||
logger.error(f"Websocket error: {exc}")
|
||||
await websocket.close(code=status.WS_1011_INTERNAL_ERROR, reason=str(exc))
|
||||
|
||||
|
||||
@router.post("/build/init", response_model=InitResponse, status_code=201)
|
||||
async def init_build(graph_data: dict):
|
||||
@router.post("/build/init/{flow_id}", response_model=InitResponse, status_code=201)
|
||||
async def init_build(graph_data: dict, flow_id: str):
|
||||
"""Initialize the build by storing graph data and returning a unique session ID."""
|
||||
|
||||
try:
|
||||
flow_id = graph_data.get("id")
|
||||
if flow_id is None:
|
||||
raise ValueError("No ID provided")
|
||||
flow_data_store[flow_id] = graph_data
|
||||
# Check if already building
|
||||
if (
|
||||
flow_id in flow_data_store
|
||||
and flow_data_store[flow_id]["status"] == BuildStatus.IN_PROGRESS
|
||||
):
|
||||
return InitResponse(flowId=flow_id)
|
||||
|
||||
# Delete from cache if already exists
|
||||
if flow_id in chat_manager.in_memory_cache:
|
||||
with chat_manager.in_memory_cache._lock:
|
||||
chat_manager.in_memory_cache.delete(flow_id)
|
||||
logger.debug(f"Deleted flow {flow_id} from cache")
|
||||
flow_data_store[flow_id] = {
|
||||
"graph_data": graph_data,
|
||||
"status": BuildStatus.STARTED,
|
||||
}
|
||||
|
||||
return InitResponse(flowId=flow_id)
|
||||
except Exception as exc:
|
||||
logger.error(exc)
|
||||
logger.error(f"Error initializing build: {exc}")
|
||||
return HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
|
@ -53,8 +64,9 @@ async def init_build(graph_data: dict):
|
|||
async def build_status(flow_id: str):
|
||||
"""Check the flow_id is in the flow_data_store."""
|
||||
try:
|
||||
built = flow_id in flow_data_store and not isinstance(
|
||||
flow_data_store[flow_id], dict
|
||||
built = (
|
||||
flow_id in flow_data_store
|
||||
and flow_data_store[flow_id]["status"] == BuildStatus.SUCCESS
|
||||
)
|
||||
|
||||
return BuiltResponse(
|
||||
|
|
@ -62,7 +74,7 @@ async def build_status(flow_id: str):
|
|||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(exc)
|
||||
logger.error(f"Error checking build status: {exc}")
|
||||
return HTTPException(status_code=500, detail=str(exc))
|
||||
|
||||
|
||||
|
|
@ -71,52 +83,99 @@ async def stream_build(flow_id: str):
|
|||
"""Stream the build process based on stored flow data."""
|
||||
|
||||
async def event_stream(flow_id):
|
||||
final_response = json.dumps({"end_of_stream": True})
|
||||
final_response = {"end_of_stream": True}
|
||||
artifacts = {}
|
||||
try:
|
||||
if flow_id not in flow_data_store:
|
||||
error_message = "Invalid session ID"
|
||||
yield f"data: {json.dumps({'error': error_message})}\n\n"
|
||||
yield str(StreamData(event="error", data={"error": error_message}))
|
||||
return
|
||||
|
||||
graph_data = flow_data_store[flow_id].get("data")
|
||||
if flow_data_store[flow_id].get("status") == BuildStatus.IN_PROGRESS:
|
||||
error_message = "Already building"
|
||||
yield str(StreamData(event="error", data={"error": error_message}))
|
||||
return
|
||||
|
||||
graph_data = flow_data_store[flow_id].get("graph_data")
|
||||
|
||||
if not graph_data:
|
||||
error_message = "No data provided"
|
||||
yield f"data: {json.dumps({'error': error_message})}\n\n"
|
||||
yield str(StreamData(event="error", data={"error": error_message}))
|
||||
return
|
||||
|
||||
logger.debug("Building langchain object")
|
||||
graph = Graph.from_payload(graph_data)
|
||||
for node in graph.generator_build():
|
||||
try:
|
||||
# Some error could happen when building the graph
|
||||
graph = Graph.from_payload(graph_data)
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
error_message = str(exc)
|
||||
yield str(StreamData(event="error", data={"error": error_message}))
|
||||
return
|
||||
|
||||
number_of_nodes = len(graph.nodes)
|
||||
flow_data_store[flow_id]["status"] = BuildStatus.IN_PROGRESS
|
||||
|
||||
for i, vertex in enumerate(graph.generator_build(), 1):
|
||||
try:
|
||||
node.build()
|
||||
params = node._built_object_repr()
|
||||
log_dict = {
|
||||
"log": f"Building node {vertex.vertex_type}",
|
||||
}
|
||||
yield str(StreamData(event="log", data=log_dict))
|
||||
vertex.build()
|
||||
params = vertex._built_object_repr()
|
||||
valid = True
|
||||
logger.debug(
|
||||
f"Building node {params[:50]}{'...' if len(params) > 50 else ''}"
|
||||
f"Building node {str(params)[:50]}{'...' if len(str(params)) > 50 else ''}"
|
||||
)
|
||||
if vertex.artifacts:
|
||||
# The artifacts will be prompt variables
|
||||
# passed to build_input_keys_response
|
||||
# to set the input_keys values
|
||||
artifacts.update(vertex.artifacts)
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
params = str(exc)
|
||||
valid = False
|
||||
flow_data_store[flow_id]["status"] = BuildStatus.FAILURE
|
||||
|
||||
response = json.dumps(
|
||||
{
|
||||
"valid": valid,
|
||||
"params": params,
|
||||
"id": node.id,
|
||||
}
|
||||
response = {
|
||||
"valid": valid,
|
||||
"params": params,
|
||||
"id": vertex.id,
|
||||
"progress": round(i / number_of_nodes, 2),
|
||||
}
|
||||
|
||||
yield str(StreamData(event="message", data=response))
|
||||
|
||||
langchain_object = graph.build()
|
||||
# Now we need to check the input_keys to send them to the client
|
||||
if hasattr(langchain_object, "input_keys"):
|
||||
input_keys_response = build_input_keys_response(
|
||||
langchain_object, artifacts
|
||||
)
|
||||
yield f"data: {response}\n\n"
|
||||
else:
|
||||
input_keys_response = {
|
||||
"input_keys": {},
|
||||
"memory_keys": [],
|
||||
"handle_keys": [],
|
||||
}
|
||||
yield str(StreamData(event="message", data=input_keys_response))
|
||||
|
||||
chat_manager.set_cache(flow_id, graph.build())
|
||||
chat_manager.set_cache(flow_id, langchain_object)
|
||||
# We need to reset the chat history
|
||||
chat_manager.chat_history.empty_history(flow_id)
|
||||
flow_data_store[flow_id]["status"] = BuildStatus.SUCCESS
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
logger.error("Error while building the flow: %s", exc)
|
||||
yield f"error: {json.dumps({'error': str(exc)})}\n\n"
|
||||
flow_data_store[flow_id]["status"] = BuildStatus.FAILURE
|
||||
yield str(StreamData(event="error", data={"error": str(exc)}))
|
||||
finally:
|
||||
yield f"data: {final_response}\n\n"
|
||||
yield str(StreamData(event="message", data=final_response))
|
||||
|
||||
try:
|
||||
return StreamingResponse(event_stream(flow_id), media_type="text/event-stream")
|
||||
except Exception as exc:
|
||||
logger.error(exc)
|
||||
logger.error(f"Error streaming build: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc))
|
||||
|
|
|
|||
77
src/backend/langflow/api/v1/components.py
Normal file
77
src/backend/langflow/api/v1/components.py
Normal file
|
|
@ -0,0 +1,77 @@
|
|||
from datetime import timezone
|
||||
from typing import List
|
||||
from uuid import UUID
|
||||
from langflow.database.models.component import Component, ComponentModel
|
||||
from langflow.database.base import get_session
|
||||
from sqlmodel import Session, select
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
COMPONENT_NOT_FOUND = "Component not found"
|
||||
COMPONENT_ALREADY_EXISTS = "A component with the same id already exists."
|
||||
COMPONENT_DELETED = "Component deleted"
|
||||
|
||||
|
||||
router = APIRouter(prefix="/components", tags=["Components"])
|
||||
|
||||
|
||||
@router.post("/", response_model=Component)
|
||||
def create_component(component: ComponentModel, db: Session = Depends(get_session)):
|
||||
db_component = Component(**component.dict())
|
||||
try:
|
||||
db.add(db_component)
|
||||
db.commit()
|
||||
db.refresh(db_component)
|
||||
except IntegrityError as e:
|
||||
db.rollback()
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail=COMPONENT_ALREADY_EXISTS,
|
||||
) from e
|
||||
return db_component
|
||||
|
||||
|
||||
@router.get("/{component_id}", response_model=Component)
|
||||
def read_component(component_id: UUID, db: Session = Depends(get_session)):
|
||||
if component := db.get(Component, component_id):
|
||||
return component
|
||||
else:
|
||||
raise HTTPException(status_code=404, detail=COMPONENT_NOT_FOUND)
|
||||
|
||||
|
||||
@router.get("/", response_model=List[Component])
|
||||
def read_components(skip: int = 0, limit: int = 50, db: Session = Depends(get_session)):
|
||||
query = select(Component)
|
||||
query = query.offset(skip).limit(limit)
|
||||
|
||||
return db.execute(query).fetchall()
|
||||
|
||||
|
||||
@router.patch("/{component_id}", response_model=Component)
|
||||
def update_component(
|
||||
component_id: UUID, component: ComponentModel, db: Session = Depends(get_session)
|
||||
):
|
||||
db_component = db.get(Component, component_id)
|
||||
if not db_component:
|
||||
raise HTTPException(status_code=404, detail=COMPONENT_NOT_FOUND)
|
||||
component_data = component.dict(exclude_unset=True)
|
||||
|
||||
for key, value in component_data.items():
|
||||
setattr(db_component, key, value)
|
||||
|
||||
db_component.update_at = datetime.now(timezone.utc)
|
||||
db.commit()
|
||||
db.refresh(db_component)
|
||||
return db_component
|
||||
|
||||
|
||||
@router.delete("/{component_id}")
|
||||
def delete_component(component_id: UUID, db: Session = Depends(get_session)):
|
||||
component = db.get(Component, component_id)
|
||||
if not component:
|
||||
raise HTTPException(status_code=404, detail=COMPONENT_NOT_FOUND)
|
||||
db.delete(component)
|
||||
db.commit()
|
||||
return {"detail": COMPONENT_DELETED}
|
||||
|
|
@ -1,15 +1,34 @@
|
|||
from http import HTTPStatus
|
||||
from typing import Annotated, Optional
|
||||
|
||||
from langflow.cache.utils import save_uploaded_file
|
||||
from langflow.database.models.flow import Flow
|
||||
from langflow.processing.process import process_graph_cached, process_tweaks
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.settings import settings
|
||||
|
||||
from fastapi import APIRouter, Depends, HTTPException
|
||||
from fastapi import APIRouter, Depends, HTTPException, UploadFile, Body
|
||||
|
||||
from langflow.api.v1.schemas import (
|
||||
PredictRequest,
|
||||
PredictResponse,
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
from langflow.interface.custom.directory_reader import (
|
||||
CustomComponentPathValueError,
|
||||
)
|
||||
|
||||
from langflow.api.v1.schemas import (
|
||||
ProcessResponse,
|
||||
UploadFileResponse,
|
||||
CustomComponentCode,
|
||||
)
|
||||
|
||||
from langflow.api.utils import merge_nested_dicts
|
||||
|
||||
from langflow.interface.types import (
|
||||
build_langchain_types_dict,
|
||||
build_langchain_template_custom_component,
|
||||
build_langchain_custom_component_list_from_path,
|
||||
)
|
||||
|
||||
from langflow.interface.types import build_langchain_types_dict
|
||||
from langflow.database.base import get_session
|
||||
from sqlmodel import Session
|
||||
|
||||
|
|
@ -19,17 +38,61 @@ router = APIRouter(tags=["Base"])
|
|||
|
||||
@router.get("/all")
|
||||
def get_all():
|
||||
return build_langchain_types_dict()
|
||||
native_components = build_langchain_types_dict()
|
||||
|
||||
# custom_components is a list of dicts
|
||||
# need to merge all the keys into one dict
|
||||
custom_components_from_file = {}
|
||||
if settings.components_path:
|
||||
custom_component_dicts = [
|
||||
build_langchain_custom_component_list_from_path(str(path))
|
||||
for path in settings.components_path
|
||||
]
|
||||
for custom_component_dict in custom_component_dicts:
|
||||
custom_components_from_file = merge_nested_dicts(
|
||||
custom_components_from_file, custom_component_dict
|
||||
)
|
||||
return merge_nested_dicts(native_components, custom_components_from_file)
|
||||
|
||||
|
||||
@router.post("/predict/{flow_id}", response_model=PredictResponse)
|
||||
async def predict_flow(
|
||||
predict_request: PredictRequest,
|
||||
@router.get("/load_custom_component_from_path")
|
||||
def get_load_custom_component_from_path(path: str):
|
||||
try:
|
||||
data = build_langchain_custom_component_list_from_path(path)
|
||||
except CustomComponentPathValueError as err:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={"error": type(err).__name__, "traceback": str(err)},
|
||||
) from err
|
||||
|
||||
return data
|
||||
|
||||
|
||||
@router.get("/load_custom_component_from_path_TEST")
|
||||
def get_load_custom_component_from_path_test(path: str):
|
||||
from langflow.interface.custom.directory_reader import (
|
||||
DirectoryReader,
|
||||
)
|
||||
|
||||
reader = DirectoryReader(path, False)
|
||||
file_list = reader.get_files()
|
||||
data = reader.build_component_menu_list(file_list)
|
||||
|
||||
return reader.filter_loaded_components(data, True)
|
||||
|
||||
|
||||
# For backwards compatibility we will keep the old endpoint
|
||||
@router.post("/predict/{flow_id}", response_model=ProcessResponse)
|
||||
@router.post("/process/{flow_id}", response_model=ProcessResponse)
|
||||
async def process_flow(
|
||||
flow_id: str,
|
||||
inputs: Optional[dict] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
clear_cache: Annotated[bool, Body(embed=True)] = False, # noqa: F821
|
||||
session: Session = Depends(get_session),
|
||||
):
|
||||
"""
|
||||
Endpoint to process a message using the flow passed in the bearer token.
|
||||
Endpoint to process an input with a given flow_id.
|
||||
"""
|
||||
|
||||
try:
|
||||
|
|
@ -40,15 +103,14 @@ async def predict_flow(
|
|||
if flow.data is None:
|
||||
raise ValueError(f"Flow {flow_id} has no data")
|
||||
graph_data = flow.data
|
||||
if predict_request.tweaks:
|
||||
if tweaks:
|
||||
try:
|
||||
graph_data = process_tweaks(graph_data, predict_request.tweaks)
|
||||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error processing tweaks: {exc}")
|
||||
response = process_graph_cached(graph_data, predict_request.message)
|
||||
return PredictResponse(
|
||||
result=response.get("result", ""),
|
||||
intermediate_steps=response.get("thought", ""),
|
||||
response = process_graph_cached(graph_data, inputs, clear_cache)
|
||||
return ProcessResponse(
|
||||
result=response,
|
||||
)
|
||||
except Exception as e:
|
||||
# Log stack trace
|
||||
|
|
@ -56,9 +118,38 @@ async def predict_flow(
|
|||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
@router.post(
|
||||
"/upload/{flow_id}",
|
||||
response_model=UploadFileResponse,
|
||||
status_code=HTTPStatus.CREATED,
|
||||
)
|
||||
async def create_upload_file(file: UploadFile, flow_id: str):
|
||||
# Cache file
|
||||
try:
|
||||
file_path = save_uploaded_file(file.file, folder_name=flow_id)
|
||||
|
||||
return UploadFileResponse(
|
||||
flowId=flow_id,
|
||||
file_path=file_path,
|
||||
)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error saving file: {exc}")
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
# get endpoint to return version of langflow
|
||||
@router.get("/version")
|
||||
def get_version():
|
||||
from langflow import __version__
|
||||
|
||||
return {"version": __version__}
|
||||
|
||||
|
||||
@router.post("/custom_component", status_code=HTTPStatus.OK)
|
||||
async def custom_component(
|
||||
raw_code: CustomComponentCode,
|
||||
):
|
||||
extractor = CustomComponent(code=raw_code.code)
|
||||
extractor.is_check_valid()
|
||||
|
||||
return build_langchain_template_custom_component(extractor)
|
||||
|
|
|
|||
|
|
@ -61,7 +61,7 @@ def update_flow(
|
|||
if not db_flow:
|
||||
raise HTTPException(status_code=404, detail="Flow not found")
|
||||
flow_data = flow.dict(exclude_unset=True)
|
||||
if not settings.remove_api_keys:
|
||||
if settings.remove_api_keys:
|
||||
flow_data = remove_api_keys(flow_data)
|
||||
for key, value in flow_data.items():
|
||||
setattr(db_flow, key, value)
|
||||
|
|
|
|||
|
|
@ -1,6 +1,18 @@
|
|||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from langflow.database.models.flow import FlowCreate, FlowRead
|
||||
from pydantic import BaseModel, Field, validator
|
||||
import json
|
||||
|
||||
|
||||
class BuildStatus(Enum):
|
||||
"""Status of the build."""
|
||||
|
||||
SUCCESS = "success"
|
||||
FAILURE = "failure"
|
||||
STARTED = "started"
|
||||
IN_PROGRESS = "in_progress"
|
||||
|
||||
|
||||
class GraphData(BaseModel):
|
||||
|
|
@ -11,7 +23,7 @@ class GraphData(BaseModel):
|
|||
|
||||
|
||||
class ExportedFlow(BaseModel):
|
||||
"""Exported flow from LangFlow."""
|
||||
"""Exported flow from Langflow."""
|
||||
|
||||
description: str
|
||||
name: str
|
||||
|
|
@ -19,41 +31,29 @@ class ExportedFlow(BaseModel):
|
|||
data: GraphData
|
||||
|
||||
|
||||
class PredictRequest(BaseModel):
|
||||
"""Predict request schema."""
|
||||
class InputRequest(BaseModel):
|
||||
input: dict
|
||||
|
||||
message: str
|
||||
|
||||
class TweaksRequest(BaseModel):
|
||||
tweaks: Optional[Dict[str, Dict[str, str]]] = Field(default_factory=dict)
|
||||
|
||||
class Config:
|
||||
schema_extra = {
|
||||
"example": {
|
||||
"message": "Hello, how are you?",
|
||||
"tweaks": {
|
||||
"dndnode_986363f0-4677-4035-9f38-74b94af5dd78": {
|
||||
"name": "A tool name",
|
||||
"description": "A tool description",
|
||||
},
|
||||
"dndnode_986363f0-4677-4035-9f38-74b94af57378": {
|
||||
"template": "A {template}",
|
||||
},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
class UpdateTemplateRequest(BaseModel):
|
||||
template: dict
|
||||
|
||||
|
||||
class PredictResponse(BaseModel):
|
||||
"""Predict response schema."""
|
||||
class ProcessResponse(BaseModel):
|
||||
"""Process response schema."""
|
||||
|
||||
result: str
|
||||
intermediate_steps: str = ""
|
||||
result: dict
|
||||
|
||||
|
||||
class ChatMessage(BaseModel):
|
||||
"""Chat message schema."""
|
||||
|
||||
is_bot: bool = False
|
||||
message: Union[str, None] = None
|
||||
message: Union[str, None, dict] = None
|
||||
type: str = "human"
|
||||
|
||||
|
||||
|
|
@ -101,3 +101,35 @@ class InitResponse(BaseModel):
|
|||
|
||||
class BuiltResponse(BaseModel):
|
||||
built: bool
|
||||
|
||||
|
||||
class UploadFileResponse(BaseModel):
|
||||
"""Upload file response schema."""
|
||||
|
||||
flowId: str
|
||||
file_path: Path
|
||||
|
||||
|
||||
class StreamData(BaseModel):
|
||||
event: str
|
||||
data: dict
|
||||
|
||||
def __str__(self) -> str:
|
||||
return f"event: {self.event}\ndata: {json.dumps(self.data)}\n\n"
|
||||
|
||||
|
||||
class CustomComponentCode(BaseModel):
|
||||
code: str
|
||||
|
||||
|
||||
class CustomComponentResponseError(BaseModel):
|
||||
detail: str
|
||||
traceback: str
|
||||
|
||||
|
||||
class ComponentListCreate(BaseModel):
|
||||
flows: List[FlowCreate]
|
||||
|
||||
|
||||
class ComponentListRead(BaseModel):
|
||||
flows: List[FlowRead]
|
||||
|
|
|
|||
|
|
@ -1,16 +1,13 @@
|
|||
import json
|
||||
|
||||
from fastapi import APIRouter, HTTPException
|
||||
|
||||
from langflow.api.v1.base import (
|
||||
Code,
|
||||
CodeValidationResponse,
|
||||
Prompt,
|
||||
ValidatePromptRequest,
|
||||
PromptValidationResponse,
|
||||
validate_prompt,
|
||||
)
|
||||
from langflow.graph.vertex.types import VectorStoreVertex
|
||||
from langflow.graph import Graph
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.validate import validate_code
|
||||
|
||||
|
|
@ -31,27 +28,100 @@ def post_validate_code(code: Code):
|
|||
|
||||
|
||||
@router.post("/prompt", status_code=200, response_model=PromptValidationResponse)
|
||||
def post_validate_prompt(prompt: Prompt):
|
||||
def post_validate_prompt(prompt_request: ValidatePromptRequest):
|
||||
try:
|
||||
return validate_prompt(prompt.template)
|
||||
input_variables = validate_prompt(prompt_request.template)
|
||||
|
||||
old_custom_fields = get_old_custom_fields(prompt_request)
|
||||
|
||||
add_new_variables_to_template(input_variables, prompt_request)
|
||||
|
||||
remove_old_variables_from_template(
|
||||
old_custom_fields, input_variables, prompt_request
|
||||
)
|
||||
|
||||
update_input_variables_field(input_variables, prompt_request)
|
||||
|
||||
return PromptValidationResponse(
|
||||
input_variables=input_variables,
|
||||
frontend_node=prompt_request.frontend_node,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
raise HTTPException(status_code=500, detail=str(e)) from e
|
||||
|
||||
|
||||
# validate node
|
||||
@router.post("/node/{node_id}", status_code=200)
|
||||
def post_validate_node(node_id: str, data: dict):
|
||||
def get_old_custom_fields(prompt_request):
|
||||
try:
|
||||
# build graph
|
||||
graph = Graph.from_payload(data)
|
||||
# validate node
|
||||
node = graph.get_node(node_id)
|
||||
if node is None:
|
||||
raise ValueError(f"Node {node_id} not found")
|
||||
if not isinstance(node, VectorStoreVertex):
|
||||
node.build()
|
||||
return json.dumps({"valid": True, "params": str(node._built_object_repr())})
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
return json.dumps({"valid": False, "params": str(e)})
|
||||
old_custom_fields = prompt_request.frontend_node.custom_fields[
|
||||
prompt_request.name
|
||||
].copy()
|
||||
except KeyError:
|
||||
old_custom_fields = []
|
||||
prompt_request.frontend_node.custom_fields[prompt_request.name] = []
|
||||
return old_custom_fields
|
||||
|
||||
|
||||
def add_new_variables_to_template(input_variables, prompt_request):
|
||||
for variable in input_variables:
|
||||
try:
|
||||
template_field = TemplateField(
|
||||
name=variable,
|
||||
display_name=variable,
|
||||
field_type="str",
|
||||
show=True,
|
||||
advanced=False,
|
||||
multiline=True,
|
||||
input_types=["Document", "BaseOutputParser"],
|
||||
value="", # Set the value to empty string
|
||||
)
|
||||
if variable in prompt_request.frontend_node.template:
|
||||
# Set the new field with the old value
|
||||
template_field.value = prompt_request.frontend_node.template[variable][
|
||||
"value"
|
||||
]
|
||||
|
||||
prompt_request.frontend_node.template[variable] = template_field.to_dict()
|
||||
|
||||
# Check if variable is not already in the list before appending
|
||||
if (
|
||||
variable
|
||||
not in prompt_request.frontend_node.custom_fields[prompt_request.name]
|
||||
):
|
||||
prompt_request.frontend_node.custom_fields[prompt_request.name].append(
|
||||
variable
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def remove_old_variables_from_template(
|
||||
old_custom_fields, input_variables, prompt_request
|
||||
):
|
||||
for variable in old_custom_fields:
|
||||
if variable not in input_variables:
|
||||
try:
|
||||
# Remove the variable from custom_fields associated with the given name
|
||||
if (
|
||||
variable
|
||||
in prompt_request.frontend_node.custom_fields[prompt_request.name]
|
||||
):
|
||||
prompt_request.frontend_node.custom_fields[
|
||||
prompt_request.name
|
||||
].remove(variable)
|
||||
|
||||
# Remove the variable from the template
|
||||
prompt_request.frontend_node.template.pop(variable, None)
|
||||
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise HTTPException(status_code=500, detail=str(exc)) from exc
|
||||
|
||||
|
||||
def update_input_variables_field(input_variables, prompt_request):
|
||||
if "input_variables" in prompt_request.frontend_node.template:
|
||||
prompt_request.frontend_node.template["input_variables"][
|
||||
"value"
|
||||
] = input_variables
|
||||
|
|
|
|||
3
src/backend/langflow/cache/base.py
vendored
3
src/backend/langflow/cache/base.py
vendored
|
|
@ -62,9 +62,6 @@ class BaseCache(abc.ABC):
|
|||
|
||||
Args:
|
||||
key: The key of the item to retrieve.
|
||||
|
||||
Returns:
|
||||
The value associated with the key, or None if the key is not found.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
|
|
|
|||
51
src/backend/langflow/cache/utils.py
vendored
51
src/backend/langflow/cache/utils.py
vendored
|
|
@ -8,15 +8,17 @@ import tempfile
|
|||
from collections import OrderedDict
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict
|
||||
|
||||
from appdirs import user_cache_dir
|
||||
|
||||
CACHE: Dict[str, Any] = {}
|
||||
|
||||
CACHE_DIR = user_cache_dir("langflow", "langflow")
|
||||
|
||||
|
||||
def create_cache_folder(func):
|
||||
def wrapper(*args, **kwargs):
|
||||
# Get the destination folder
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX
|
||||
cache_path = Path(CACHE_DIR) / PREFIX
|
||||
|
||||
# Create the destination folder if it doesn't exist
|
||||
os.makedirs(cache_path, exist_ok=True)
|
||||
|
|
@ -118,7 +120,7 @@ def save_binary_file(content: str, file_name: str, accepted_types: list[str]) ->
|
|||
raise ValueError(f"File {file_name} is not accepted")
|
||||
|
||||
# Get the destination folder
|
||||
cache_path = Path(tempfile.gettempdir()) / PREFIX
|
||||
cache_path = Path(CACHE_DIR) / PREFIX
|
||||
if not content:
|
||||
raise ValueError("Please, reload the file in the loader.")
|
||||
data = content.split(",")[1]
|
||||
|
|
@ -132,3 +134,46 @@ def save_binary_file(content: str, file_name: str, accepted_types: list[str]) ->
|
|||
file.write(decoded_bytes)
|
||||
|
||||
return file_path
|
||||
|
||||
|
||||
@create_cache_folder
|
||||
def save_uploaded_file(file, folder_name):
|
||||
"""
|
||||
Save an uploaded file to the specified folder with a hash of its content as the file name.
|
||||
|
||||
Args:
|
||||
file: The uploaded file object.
|
||||
folder_name: The name of the folder to save the file in.
|
||||
|
||||
Returns:
|
||||
The path to the saved file.
|
||||
"""
|
||||
cache_path = Path(CACHE_DIR)
|
||||
folder_path = cache_path / folder_name
|
||||
|
||||
# Create the folder if it doesn't exist
|
||||
if not folder_path.exists():
|
||||
folder_path.mkdir()
|
||||
|
||||
# Create a hash of the file content
|
||||
sha256_hash = hashlib.sha256()
|
||||
# Reset the file cursor to the beginning of the file
|
||||
file.seek(0)
|
||||
# Iterate over the uploaded file in small chunks to conserve memory
|
||||
while chunk := file.read(8192): # Read 8KB at a time (adjust as needed)
|
||||
sha256_hash.update(chunk)
|
||||
|
||||
# Use the hex digest of the hash as the file name
|
||||
hex_dig = sha256_hash.hexdigest()
|
||||
file_name = hex_dig
|
||||
|
||||
# Reset the file cursor to the beginning of the file
|
||||
file.seek(0)
|
||||
|
||||
# Save the file with the hash as its name
|
||||
file_path = folder_path / file_name
|
||||
with open(file_path, "wb") as new_file:
|
||||
while chunk := file.read(8192):
|
||||
new_file.write(chunk)
|
||||
|
||||
return file_path
|
||||
|
|
|
|||
2
src/backend/langflow/chat/config.py
Normal file
2
src/backend/langflow/chat/config.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
class ChatConfig:
|
||||
streaming: bool = True
|
||||
|
|
@ -104,16 +104,22 @@ class ChatManager:
|
|||
|
||||
async def close_connection(self, client_id: str, code: int, reason: str):
|
||||
if websocket := self.active_connections[client_id]:
|
||||
await websocket.close(code=code, reason=reason)
|
||||
self.disconnect(client_id)
|
||||
try:
|
||||
await websocket.close(code=code, reason=reason)
|
||||
self.disconnect(client_id)
|
||||
except RuntimeError as exc:
|
||||
# This is to catch the following error:
|
||||
# Unexpected ASGI message 'websocket.close', after sending 'websocket.close'
|
||||
if "after sending" in str(exc):
|
||||
logger.error(f"Error closing connection: {exc}")
|
||||
|
||||
async def process_message(
|
||||
self, client_id: str, payload: Dict, langchain_object: Any
|
||||
):
|
||||
# Process the graph data and chat message
|
||||
chat_message = payload.pop("message", "")
|
||||
chat_message = ChatMessage(message=chat_message)
|
||||
self.chat_history.add_message(client_id, chat_message)
|
||||
chat_inputs = payload.pop("inputs", "")
|
||||
chat_inputs = ChatMessage(message=chat_inputs)
|
||||
self.chat_history.add_message(client_id, chat_inputs)
|
||||
|
||||
# graph_data = payload
|
||||
start_resp = ChatResponse(message=None, type="start", intermediate_steps="")
|
||||
|
|
@ -126,7 +132,7 @@ class ChatManager:
|
|||
|
||||
result, intermediate_steps = await process_graph(
|
||||
langchain_object=langchain_object,
|
||||
chat_message=chat_message,
|
||||
chat_inputs=chat_inputs,
|
||||
websocket=self.active_connections[client_id],
|
||||
)
|
||||
except Exception as e:
|
||||
|
|
@ -144,6 +150,8 @@ class ChatManager:
|
|||
if isinstance(msg, FileResponse):
|
||||
if msg.data_type == "image":
|
||||
# Base64 encode the image
|
||||
if isinstance(msg.data, str):
|
||||
continue
|
||||
msg.data = pil_to_base64(msg.data)
|
||||
file_responses.append(msg)
|
||||
if msg.type == "start":
|
||||
|
|
@ -189,13 +197,13 @@ class ChatManager:
|
|||
langchain_object = self.in_memory_cache.get(client_id)
|
||||
await self.process_message(client_id, payload, langchain_object)
|
||||
|
||||
except Exception as e:
|
||||
except Exception as exc:
|
||||
# Handle any exceptions that might occur
|
||||
logger.error(e)
|
||||
logger.error(f"Error handling websocket: {exc}")
|
||||
await self.close_connection(
|
||||
client_id=client_id,
|
||||
code=status.WS_1011_INTERNAL_ERROR,
|
||||
reason=str(e)[:120],
|
||||
reason=str(exc)[:120],
|
||||
)
|
||||
finally:
|
||||
try:
|
||||
|
|
@ -204,6 +212,6 @@ class ChatManager:
|
|||
code=status.WS_1000_NORMAL_CLOSURE,
|
||||
reason="Client disconnected",
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error closing connection: {exc}")
|
||||
self.disconnect(client_id)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,7 @@ from langflow.utils.logger import logger
|
|||
|
||||
async def process_graph(
|
||||
langchain_object,
|
||||
chat_message: ChatMessage,
|
||||
chat_inputs: ChatMessage,
|
||||
websocket: WebSocket,
|
||||
):
|
||||
langchain_object = try_setting_streaming_options(langchain_object, websocket)
|
||||
|
|
@ -21,9 +21,13 @@ async def process_graph(
|
|||
|
||||
# Generate result and thought
|
||||
try:
|
||||
if not chat_inputs.message:
|
||||
logger.debug("No message provided")
|
||||
raise ValueError("No message provided")
|
||||
|
||||
logger.debug("Generating result and thought")
|
||||
result, intermediate_steps = await get_result_and_steps(
|
||||
langchain_object, chat_message.message or "", websocket=websocket
|
||||
langchain_object, chat_inputs.message, websocket=websocket
|
||||
)
|
||||
logger.debug("Generated result and intermediate_steps")
|
||||
return result, intermediate_steps
|
||||
|
|
|
|||
|
|
@ -1,136 +1,297 @@
|
|||
---
|
||||
agents:
|
||||
- ZeroShotAgent
|
||||
- JsonAgent
|
||||
- CSVAgent
|
||||
- AgentInitializer
|
||||
- VectorStoreAgent
|
||||
- VectorStoreRouterAgent
|
||||
- SQLAgent
|
||||
ZeroShotAgent:
|
||||
documentation: "https://python.langchain.com/docs/modules/agents/how_to/custom_mrkl_agent"
|
||||
JsonAgent:
|
||||
documentation: "https://python.langchain.com/docs/modules/agents/toolkits/openapi"
|
||||
CSVAgent:
|
||||
documentation: "https://python.langchain.com/docs/modules/agents/toolkits/csv"
|
||||
AgentInitializer:
|
||||
documentation: "https://python.langchain.com/docs/modules/agents/agent_types/"
|
||||
VectorStoreAgent:
|
||||
documentation: ""
|
||||
VectorStoreRouterAgent:
|
||||
documentation: ""
|
||||
SQLAgent:
|
||||
documentation: ""
|
||||
chains:
|
||||
- LLMChain
|
||||
- LLMMathChain
|
||||
- LLMCheckerChain
|
||||
- ConversationChain
|
||||
- SeriesCharacterChain
|
||||
- MidJourneyPromptChain
|
||||
- TimeTravelGuideChain
|
||||
- SQLDatabaseChain
|
||||
- RetrievalQA
|
||||
- RetrievalQAWithSourcesChain
|
||||
- ConversationalRetrievalChain
|
||||
- CombineDocsChain
|
||||
LLMChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/foundational/llm_chain"
|
||||
LLMMathChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_math"
|
||||
LLMCheckerChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/additional/llm_checker"
|
||||
ConversationChain:
|
||||
documentation: ""
|
||||
SeriesCharacterChain:
|
||||
documentation: ""
|
||||
MidJourneyPromptChain:
|
||||
documentation: ""
|
||||
TimeTravelGuideChain:
|
||||
documentation: ""
|
||||
SQLDatabaseChain:
|
||||
documentation: ""
|
||||
RetrievalQA:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/popular/vector_db_qa"
|
||||
RetrievalQAWithSourcesChain:
|
||||
documentation: ""
|
||||
ConversationalRetrievalChain:
|
||||
documentation: "https://python.langchain.com/docs/modules/chains/popular/chat_vector_db"
|
||||
CombineDocsChain:
|
||||
documentation: ""
|
||||
documentloaders:
|
||||
- AirbyteJSONLoader
|
||||
- CoNLLULoader
|
||||
- CSVLoader
|
||||
- UnstructuredEmailLoader
|
||||
- EverNoteLoader
|
||||
- FacebookChatLoader
|
||||
- GutenbergLoader
|
||||
- BSHTMLLoader
|
||||
- UnstructuredHTMLLoader
|
||||
# - UnstructuredImageLoader # Issue with Python 3.11 (https://github.com/Unstructured-IO/unstructured-inference/issues/83)
|
||||
- UnstructuredMarkdownLoader
|
||||
- PyPDFLoader
|
||||
- UnstructuredPowerPointLoader
|
||||
- SRTLoader
|
||||
- TelegramChatLoader
|
||||
- TextLoader
|
||||
- UnstructuredWordDocumentLoader
|
||||
- WebBaseLoader
|
||||
- AZLyricsLoader
|
||||
- CollegeConfidentialLoader
|
||||
- HNLoader
|
||||
- IFixitLoader
|
||||
- IMSDbLoader
|
||||
- GitbookLoader
|
||||
- ReadTheDocsLoader
|
||||
- SlackDirectoryLoader
|
||||
- NotionDirectoryLoader
|
||||
AirbyteJSONLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/airbyte_json"
|
||||
CoNLLULoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/conll-u"
|
||||
CSVLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/csv"
|
||||
UnstructuredEmailLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/email"
|
||||
EverNoteLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/evernote"
|
||||
FacebookChatLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/facebook_chat"
|
||||
GutenbergLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/gutenberg"
|
||||
BSHTMLLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/how_to/html"
|
||||
UnstructuredHTMLLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/how_to/html"
|
||||
UnstructuredMarkdownLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/how_to/markdown"
|
||||
PyPDFDirectoryLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/how_to/pdf"
|
||||
PyPDFLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/how_to/pdf"
|
||||
UnstructuredPowerPointLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/microsoft_powerpoint"
|
||||
SRTLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/subtitle"
|
||||
TelegramChatLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/telegram"
|
||||
TextLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/"
|
||||
UnstructuredWordDocumentLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/microsoft_word"
|
||||
WebBaseLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/web_base"
|
||||
AZLyricsLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/azlyrics"
|
||||
CollegeConfidentialLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/college_confidential"
|
||||
HNLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/hacker_news"
|
||||
IFixitLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/ifixit"
|
||||
IMSDbLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/imsdb"
|
||||
GitbookLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/gitbook"
|
||||
ReadTheDocsLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/readthedocs_documentation"
|
||||
SlackDirectoryLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/slack"
|
||||
NotionDirectoryLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/notion"
|
||||
DirectoryLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/how_to/file_directory"
|
||||
GitLoader:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_loaders/integrations/git"
|
||||
embeddings:
|
||||
- OpenAIEmbeddings
|
||||
- HuggingFaceEmbeddings
|
||||
- CohereEmbeddings
|
||||
OpenAIEmbeddings:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/openai"
|
||||
HuggingFaceEmbeddings:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/sentence_transformers"
|
||||
CohereEmbeddings:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/text_embedding/integrations/cohere"
|
||||
llms:
|
||||
- OpenAI
|
||||
# - AzureOpenAI
|
||||
# - AzureChatOpenAI
|
||||
- ChatOpenAI
|
||||
- LlamaCpp
|
||||
- CTransformers
|
||||
- Cohere
|
||||
- Anthropic
|
||||
- ChatAnthropic
|
||||
- HuggingFaceHub
|
||||
OpenAI:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/llms/integrations/openai"
|
||||
ChatOpenAI:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/integrations/openai"
|
||||
LlamaCpp:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/llms/integrations/llamacpp"
|
||||
CTransformers:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/llms/integrations/ctransformers"
|
||||
Cohere:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/llms/integrations/cohere"
|
||||
Anthropic:
|
||||
documentation: ""
|
||||
ChatAnthropic:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/integrations/anthropic"
|
||||
HuggingFaceHub:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/llms/integrations/huggingface_hub"
|
||||
VertexAI:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/llms/integrations/google_vertex_ai_palm"
|
||||
###
|
||||
# There's a bug in this component deactivating until we get it sorted: _language_models.py", line 804, in send_message
|
||||
# is_blocked=safety_attributes.get("blocked", False),
|
||||
# AttributeError: 'list' object has no attribute 'get'
|
||||
# ChatVertexAI:
|
||||
# documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/integrations/google_vertex_ai_palm"
|
||||
###
|
||||
memories:
|
||||
- ConversationBufferMemory
|
||||
- ConversationSummaryMemory
|
||||
- ConversationKGMemory
|
||||
# https://github.com/supabase-community/supabase-py/issues/482
|
||||
# ZepChatMessageHistory:
|
||||
# documentation: "https://python.langchain.com/docs/modules/memory/integrations/zep_memory"
|
||||
ConversationEntityMemory:
|
||||
documentation: "https://python.langchain.com/docs/modules/memory/integrations/entity_memory_with_sqlite"
|
||||
# https://github.com/hwchase17/langchain/issues/6091
|
||||
# SQLiteEntityStore:
|
||||
# documentation: "https://python.langchain.com/docs/modules/memory/integrations/entity_memory_with_sqlite"
|
||||
PostgresChatMessageHistory:
|
||||
documentation: "https://python.langchain.com/docs/modules/memory/integrations/postgres_chat_message_history"
|
||||
ConversationBufferMemory:
|
||||
documentation: "https://python.langchain.com/docs/modules/memory/how_to/buffer"
|
||||
ConversationSummaryMemory:
|
||||
documentation: "https://python.langchain.com/docs/modules/memory/how_to/summary"
|
||||
ConversationKGMemory:
|
||||
documentation: "https://python.langchain.com/docs/modules/memory/how_to/kg"
|
||||
ConversationBufferWindowMemory:
|
||||
documentation: "https://python.langchain.com/docs/modules/memory/how_to/buffer_window"
|
||||
VectorStoreRetrieverMemory:
|
||||
documentation: "https://python.langchain.com/docs/modules/memory/how_to/vectorstore_retriever_memory"
|
||||
MongoDBChatMessageHistory:
|
||||
documentation: "https://python.langchain.com/docs/modules/memory/integrations/mongodb_chat_message_history"
|
||||
MotorheadMemory:
|
||||
documentation: "https://python.langchain.com/docs/integrations/memory/motorhead_memory"
|
||||
prompts:
|
||||
- PromptTemplate
|
||||
- FewShotPromptTemplate
|
||||
- ZeroShotPrompt
|
||||
ChatMessagePromptTemplate:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/prompts/prompt_templates/msg_prompt_templates"
|
||||
HumanMessagePromptTemplate:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/how_to/prompts"
|
||||
SystemMessagePromptTemplate:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/how_to/prompts"
|
||||
ChatPromptTemplate:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/models/chat/how_to/prompts"
|
||||
PromptTemplate:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/prompts/prompt_templates/"
|
||||
textsplitters:
|
||||
- CharacterTextSplitter
|
||||
- RecursiveCharacterTextSplitter
|
||||
# - LatexTextSplitter
|
||||
# - PythonCodeTextSplitter
|
||||
CharacterTextSplitter:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/character_text_splitter"
|
||||
RecursiveCharacterTextSplitter:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/document_transformers/text_splitters/recursive_text_splitter"
|
||||
toolkits:
|
||||
- OpenAPIToolkit
|
||||
- JsonToolkit
|
||||
- VectorStoreInfo
|
||||
- VectorStoreRouterToolkit
|
||||
- VectorStoreToolkit
|
||||
OpenAPIToolkit:
|
||||
documentation: ""
|
||||
JsonToolkit:
|
||||
documentation: ""
|
||||
VectorStoreInfo:
|
||||
documentation: ""
|
||||
VectorStoreRouterToolkit:
|
||||
documentation: ""
|
||||
VectorStoreToolkit:
|
||||
documentation: ""
|
||||
tools:
|
||||
- Search
|
||||
- PAL-MATH
|
||||
- Calculator
|
||||
- Serper Search
|
||||
- Tool
|
||||
- PythonFunctionTool
|
||||
- PythonFunction
|
||||
- JsonSpec
|
||||
- News API
|
||||
- TMDB API
|
||||
- Podcast API
|
||||
- QuerySQLDataBaseTool
|
||||
- InfoSQLDatabaseTool
|
||||
- ListSQLDatabaseTool
|
||||
# - QueryCheckerTool
|
||||
- BingSearchRun
|
||||
- GoogleSearchRun
|
||||
- GoogleSearchResults
|
||||
- GoogleSerperRun
|
||||
- JsonListKeysTool
|
||||
- JsonGetValueTool
|
||||
- PythonREPLTool
|
||||
- PythonAstREPLTool
|
||||
- RequestsGetTool
|
||||
- RequestsPostTool
|
||||
- RequestsPatchTool
|
||||
- RequestsPutTool
|
||||
- RequestsDeleteTool
|
||||
- WikipediaQueryRun
|
||||
- WolframAlphaQueryRun
|
||||
Search:
|
||||
documentation: ""
|
||||
PAL-MATH:
|
||||
documentation: ""
|
||||
Calculator:
|
||||
documentation: ""
|
||||
Serper Search:
|
||||
documentation: ""
|
||||
Tool:
|
||||
documentation: ""
|
||||
PythonFunctionTool:
|
||||
documentation: ""
|
||||
PythonFunction:
|
||||
documentation: ""
|
||||
JsonSpec:
|
||||
documentation: ""
|
||||
News API:
|
||||
documentation: ""
|
||||
TMDB API:
|
||||
documentation: ""
|
||||
Podcast API:
|
||||
documentation: ""
|
||||
QuerySQLDataBaseTool:
|
||||
documentation: ""
|
||||
InfoSQLDatabaseTool:
|
||||
documentation: ""
|
||||
ListSQLDatabaseTool:
|
||||
documentation: ""
|
||||
BingSearchRun:
|
||||
documentation: ""
|
||||
GoogleSearchRun:
|
||||
documentation: ""
|
||||
GoogleSearchResults:
|
||||
documentation: ""
|
||||
GoogleSerperRun:
|
||||
documentation: ""
|
||||
JsonListKeysTool:
|
||||
documentation: ""
|
||||
JsonGetValueTool:
|
||||
documentation: ""
|
||||
PythonREPLTool:
|
||||
documentation: ""
|
||||
PythonAstREPLTool:
|
||||
documentation: ""
|
||||
RequestsGetTool:
|
||||
documentation: ""
|
||||
RequestsPostTool:
|
||||
documentation: ""
|
||||
RequestsPatchTool:
|
||||
documentation: ""
|
||||
RequestsPutTool:
|
||||
documentation: ""
|
||||
RequestsDeleteTool:
|
||||
documentation: ""
|
||||
WikipediaQueryRun:
|
||||
documentation: ""
|
||||
WolframAlphaQueryRun:
|
||||
documentation: ""
|
||||
utilities:
|
||||
- BingSearchAPIWrapper
|
||||
- GoogleSearchAPIWrapper
|
||||
- GoogleSerperAPIWrapper
|
||||
- SearxResults
|
||||
- SearxSearchWrapper
|
||||
- SerpAPIWrapper
|
||||
- WikipediaAPIWrapper
|
||||
- WolframAlphaAPIWrapper
|
||||
# - ZapierNLAWrapper
|
||||
- SQLDatabase
|
||||
BingSearchAPIWrapper:
|
||||
documentation: ""
|
||||
GoogleSearchAPIWrapper:
|
||||
documentation: ""
|
||||
GoogleSerperAPIWrapper:
|
||||
documentation: ""
|
||||
SearxResults:
|
||||
documentation: ""
|
||||
SearxSearchWrapper:
|
||||
documentation: ""
|
||||
SerpAPIWrapper:
|
||||
documentation: ""
|
||||
WikipediaAPIWrapper:
|
||||
documentation: ""
|
||||
WolframAlphaAPIWrapper:
|
||||
documentation: ""
|
||||
retrievers:
|
||||
MultiQueryRetriever:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/retrievers/how_to/MultiQueryRetriever"
|
||||
# https://github.com/supabase-community/supabase-py/issues/482
|
||||
# ZepRetriever:
|
||||
# documentation: "https://python.langchain.com/docs/modules/data_connection/retrievers/integrations/zep_memorystore"
|
||||
vectorstores:
|
||||
- Chroma
|
||||
- Qdrant
|
||||
- Weaviate
|
||||
- FAISS
|
||||
Chroma:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/chroma"
|
||||
Qdrant:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/qdrant"
|
||||
Weaviate:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/weaviate"
|
||||
FAISS:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/faiss"
|
||||
Pinecone:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/pinecone"
|
||||
SupabaseVectorStore:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/supabase"
|
||||
MongoDBAtlasVectorSearch:
|
||||
documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/mongodb_atlas"
|
||||
# Requires docarray >=0.32.0 but langchain-serve requires jina 3.15.2 which doesn't support docarray >=0.32.0
|
||||
# DocArrayInMemorySearch:
|
||||
# documentation: "https://python.langchain.com/docs/modules/data_connection/vectorstores/integrations/docarray_in_memory"
|
||||
wrappers:
|
||||
- RequestsWrapper
|
||||
# - ChatPromptTemplate
|
||||
# - SystemMessagePromptTemplate
|
||||
# - HumanMessagePromptTemplate
|
||||
RequestsWrapper:
|
||||
documentation: ""
|
||||
SQLDatabase:
|
||||
documentation: ""
|
||||
output_parsers:
|
||||
StructuredOutputParser:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/output_parsers/structured"
|
||||
ResponseSchema:
|
||||
documentation: "https://python.langchain.com/docs/modules/model_io/output_parsers/structured"
|
||||
custom_components:
|
||||
CustomComponent:
|
||||
documentation: ""
|
||||
|
|
|
|||
|
|
@ -2,9 +2,9 @@ from langflow.template import frontend_node
|
|||
|
||||
# These should always be instantiated
|
||||
CUSTOM_NODES = {
|
||||
"prompts": {
|
||||
"ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode(),
|
||||
},
|
||||
# "prompts": {
|
||||
# "ZeroShotPrompt": frontend_node.prompts.ZeroShotPromptNode(),
|
||||
# },
|
||||
"tools": {
|
||||
"PythonFunctionTool": frontend_node.tools.PythonFunctionToolNode(),
|
||||
"PythonFunction": frontend_node.tools.PythonFunctionNode(),
|
||||
|
|
@ -21,12 +21,19 @@ CUSTOM_NODES = {
|
|||
"utilities": {
|
||||
"SQLDatabase": frontend_node.agents.SQLDatabaseNode(),
|
||||
},
|
||||
"memories": {
|
||||
"PostgresChatMessageHistory": frontend_node.memories.PostgresChatMessageHistoryFrontendNode(),
|
||||
"MongoDBChatMessageHistory": frontend_node.memories.MongoDBChatMessageHistoryFrontendNode(),
|
||||
},
|
||||
"chains": {
|
||||
"SeriesCharacterChain": frontend_node.chains.SeriesCharacterChainNode(),
|
||||
"TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(),
|
||||
"MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(),
|
||||
"load_qa_chain": frontend_node.chains.CombineDocsChainNode(),
|
||||
},
|
||||
"custom_components": {
|
||||
"CustomComponent": frontend_node.custom_components.CustomComponentFrontendNode(),
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,18 +1,51 @@
|
|||
from contextlib import contextmanager
|
||||
from langflow.settings import settings
|
||||
from sqlmodel import SQLModel, Session, create_engine
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
|
||||
if settings.database_url.startswith("sqlite"):
|
||||
if settings.database_url and settings.database_url.startswith("sqlite"):
|
||||
connect_args = {"check_same_thread": False}
|
||||
else:
|
||||
connect_args = {}
|
||||
if not settings.database_url:
|
||||
raise RuntimeError("No database_url provided")
|
||||
engine = create_engine(settings.database_url, connect_args=connect_args)
|
||||
|
||||
|
||||
def create_db_and_tables():
|
||||
SQLModel.metadata.create_all(engine)
|
||||
logger.debug("Creating database and tables")
|
||||
try:
|
||||
SQLModel.metadata.create_all(engine)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error creating database and tables: {exc}")
|
||||
raise RuntimeError("Error creating database and tables") from exc
|
||||
# Now check if the table Flow exists, if not, something went wrong
|
||||
# and we need to create the tables again.
|
||||
from sqlalchemy import inspect
|
||||
|
||||
inspector = inspect(engine)
|
||||
if "flow" not in inspector.get_table_names():
|
||||
logger.error("Something went wrong creating the database and tables.")
|
||||
logger.error("Please check your database settings.")
|
||||
|
||||
raise RuntimeError("Something went wrong creating the database and tables.")
|
||||
else:
|
||||
logger.debug("Database and tables created successfully")
|
||||
|
||||
|
||||
@contextmanager
|
||||
def session_getter():
|
||||
try:
|
||||
session = Session(engine)
|
||||
yield session
|
||||
except Exception as e:
|
||||
print("Session rollback because of exception:", e)
|
||||
session.rollback()
|
||||
raise
|
||||
finally:
|
||||
session.close()
|
||||
|
||||
|
||||
def get_session():
|
||||
with Session(engine) as session:
|
||||
with session_getter() as session:
|
||||
yield session
|
||||
|
|
|
|||
29
src/backend/langflow/database/models/component.py
Normal file
29
src/backend/langflow/database/models/component.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from langflow.database.models.base import SQLModelSerializable, SQLModel
|
||||
from sqlmodel import Field
|
||||
from typing import Optional
|
||||
from datetime import datetime
|
||||
import uuid
|
||||
|
||||
|
||||
class Component(SQLModelSerializable, table=True):
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, primary_key=True)
|
||||
frontend_node_id: uuid.UUID = Field(index=True)
|
||||
name: str = Field(index=True)
|
||||
description: Optional[str] = Field(default=None)
|
||||
python_code: Optional[str] = Field(default=None)
|
||||
return_type: Optional[str] = Field(default=None)
|
||||
is_disabled: bool = Field(default=False)
|
||||
is_read_only: bool = Field(default=False)
|
||||
create_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
update_at: datetime = Field(default_factory=datetime.utcnow)
|
||||
|
||||
|
||||
class ComponentModel(SQLModel):
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
||||
frontend_node_id: uuid.UUID = Field(default=uuid.uuid4())
|
||||
name: str = Field(default="")
|
||||
description: Optional[str] = None
|
||||
python_code: Optional[str] = None
|
||||
return_type: Optional[str] = None
|
||||
is_disabled: bool = False
|
||||
is_read_only: bool = False
|
||||
|
|
@ -14,6 +14,7 @@ from langflow.graph.vertex.types import (
|
|||
ToolkitVertex,
|
||||
VectorStoreVertex,
|
||||
WrapperVertex,
|
||||
RetrieverVertex,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -32,4 +33,5 @@ __all__ = [
|
|||
"ToolkitVertex",
|
||||
"VectorStoreVertex",
|
||||
"WrapperVertex",
|
||||
"RetrieverVertex",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -6,9 +6,15 @@ if TYPE_CHECKING:
|
|||
|
||||
|
||||
class Edge:
|
||||
def __init__(self, source: "Vertex", target: "Vertex"):
|
||||
def __init__(self, source: "Vertex", target: "Vertex", edge: dict):
|
||||
self.source: "Vertex" = source
|
||||
self.target: "Vertex" = target
|
||||
self.source_handle = edge.get("sourceHandle", "")
|
||||
self.target_handle = edge.get("targetHandle", "")
|
||||
# 'BaseLoader;BaseOutputParser|documents|PromptTemplate-zmTlD'
|
||||
# target_param is documents
|
||||
self.target_param = self.target_handle.split("|")[1]
|
||||
|
||||
self.validate_edge()
|
||||
|
||||
def validate_edge(self) -> None:
|
||||
|
|
@ -27,12 +33,7 @@ class Edge:
|
|||
# Get what type of input the target node is expecting
|
||||
|
||||
self.matched_type = next(
|
||||
(
|
||||
output
|
||||
for output in self.source_types
|
||||
for target_req in self.target_reqs
|
||||
if output in target_req
|
||||
),
|
||||
(output for output in self.source_types if output in self.target_reqs),
|
||||
None,
|
||||
)
|
||||
no_matched_type = self.matched_type is None
|
||||
|
|
@ -47,6 +48,16 @@ class Edge:
|
|||
|
||||
def __repr__(self) -> str:
|
||||
return (
|
||||
f"Edge(source={self.source.id}, target={self.target.id}, valid={self.valid}"
|
||||
f"Edge(source={self.source.id}, target={self.target.id}, target_param={self.target_param}"
|
||||
f", matched_type={self.matched_type})"
|
||||
)
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return hash(self.__repr__())
|
||||
|
||||
def __eq__(self, __value: object) -> bool:
|
||||
return (
|
||||
self.__repr__() == __value.__repr__()
|
||||
if isinstance(__value, Edge)
|
||||
else False
|
||||
)
|
||||
|
|
|
|||
|
|
@ -11,6 +11,7 @@ from langflow.graph.vertex.types import (
|
|||
from langflow.interface.tools.constants import FILE_TOOLS
|
||||
from langflow.utils import payload
|
||||
from langflow.utils.logger import logger
|
||||
from langchain.chains.base import Chain
|
||||
|
||||
|
||||
class Graph:
|
||||
|
|
@ -59,7 +60,7 @@ class Graph:
|
|||
# the toolkit node
|
||||
self._build_node_params()
|
||||
# remove invalid nodes
|
||||
self._remove_invalid_nodes()
|
||||
self._validate_nodes()
|
||||
|
||||
def _build_node_params(self) -> None:
|
||||
"""Identifies and handles the LLM node within the graph."""
|
||||
|
|
@ -74,14 +75,15 @@ class Graph:
|
|||
if isinstance(node, ToolkitVertex):
|
||||
node.params["llm"] = llm_node
|
||||
|
||||
def _remove_invalid_nodes(self) -> None:
|
||||
"""Removes invalid nodes from the graph."""
|
||||
self.nodes = [
|
||||
node
|
||||
for node in self.nodes
|
||||
if self._validate_node(node)
|
||||
or (len(self.nodes) == 1 and len(self.edges) == 0)
|
||||
]
|
||||
def _validate_nodes(self) -> None:
|
||||
"""Check that all nodes have edges"""
|
||||
if len(self.nodes) == 1:
|
||||
return
|
||||
for node in self.nodes:
|
||||
if not self._validate_node(node):
|
||||
raise ValueError(
|
||||
f"{node.vertex_type} is not connected to any other components"
|
||||
)
|
||||
|
||||
def _validate_node(self, node: Vertex) -> bool:
|
||||
"""Validates a node."""
|
||||
|
|
@ -99,7 +101,7 @@ class Graph:
|
|||
]
|
||||
return connected_nodes
|
||||
|
||||
def build(self) -> List[Vertex]:
|
||||
def build(self) -> Chain:
|
||||
"""Builds the graph."""
|
||||
# Get root node
|
||||
root_node = payload.get_root_node(self)
|
||||
|
|
@ -145,7 +147,7 @@ class Graph:
|
|||
def generator_build(self) -> Generator:
|
||||
"""Builds each vertex in the graph and yields it."""
|
||||
sorted_vertices = self.topological_sort()
|
||||
logger.info("Sorted vertices: %s", sorted_vertices)
|
||||
logger.debug("Sorted vertices: %s", sorted_vertices)
|
||||
yield from sorted_vertices
|
||||
|
||||
def get_node_neighbors(self, node: Vertex) -> Dict[Vertex, int]:
|
||||
|
|
@ -178,7 +180,7 @@ class Graph:
|
|||
raise ValueError(f"Source node {edge['source']} not found")
|
||||
if target is None:
|
||||
raise ValueError(f"Target node {edge['target']} not found")
|
||||
edges.append(Edge(source, target))
|
||||
edges.append(Edge(source, target, edge))
|
||||
return edges
|
||||
|
||||
def _get_vertex_class(self, node_type: str, node_lc_type: str) -> Type[Vertex]:
|
||||
|
|
@ -213,3 +215,10 @@ class Graph:
|
|||
if node_type in node_types:
|
||||
children.append(node)
|
||||
return children
|
||||
|
||||
def __repr__(self):
|
||||
node_ids = [node.id for node in self.nodes]
|
||||
edges_repr = "\n".join(
|
||||
[f"{edge.source.id} --> {edge.target.id}" for edge in self.edges]
|
||||
)
|
||||
return f"Graph:\nNodes: {node_ids}\nConnections:\n{edges_repr}"
|
||||
|
|
|
|||
|
|
@ -1,18 +1,5 @@
|
|||
from langflow.graph.vertex.base import Vertex
|
||||
from langflow.graph.vertex.types import (
|
||||
AgentVertex,
|
||||
ChainVertex,
|
||||
DocumentLoaderVertex,
|
||||
EmbeddingVertex,
|
||||
LLMVertex,
|
||||
MemoryVertex,
|
||||
PromptVertex,
|
||||
TextSplitterVertex,
|
||||
ToolVertex,
|
||||
ToolkitVertex,
|
||||
VectorStoreVertex,
|
||||
WrapperVertex,
|
||||
)
|
||||
from langflow.graph.vertex import types
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.document_loaders.base import documentloader_creator
|
||||
|
|
@ -25,25 +12,26 @@ from langflow.interface.toolkits.base import toolkits_creator
|
|||
from langflow.interface.tools.base import tool_creator
|
||||
from langflow.interface.vector_store.base import vectorstore_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
|
||||
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from typing import Dict, Type
|
||||
|
||||
|
||||
DIRECT_TYPES = ["str", "bool", "code", "int", "float", "Any", "prompt"]
|
||||
|
||||
|
||||
VERTEX_TYPE_MAP: Dict[str, Type[Vertex]] = {
|
||||
**{t: PromptVertex for t in prompt_creator.to_list()},
|
||||
**{t: AgentVertex for t in agent_creator.to_list()},
|
||||
**{t: ChainVertex for t in chain_creator.to_list()},
|
||||
**{t: ToolVertex for t in tool_creator.to_list()},
|
||||
**{t: ToolkitVertex for t in toolkits_creator.to_list()},
|
||||
**{t: WrapperVertex for t in wrapper_creator.to_list()},
|
||||
**{t: LLMVertex for t in llm_creator.to_list()},
|
||||
**{t: MemoryVertex for t in memory_creator.to_list()},
|
||||
**{t: EmbeddingVertex for t in embedding_creator.to_list()},
|
||||
**{t: VectorStoreVertex for t in vectorstore_creator.to_list()},
|
||||
**{t: DocumentLoaderVertex for t in documentloader_creator.to_list()},
|
||||
**{t: TextSplitterVertex for t in textsplitter_creator.to_list()},
|
||||
**{t: types.PromptVertex for t in prompt_creator.to_list()},
|
||||
**{t: types.AgentVertex for t in agent_creator.to_list()},
|
||||
**{t: types.ChainVertex for t in chain_creator.to_list()},
|
||||
**{t: types.ToolVertex for t in tool_creator.to_list()},
|
||||
**{t: types.ToolkitVertex for t in toolkits_creator.to_list()},
|
||||
**{t: types.WrapperVertex for t in wrapper_creator.to_list()},
|
||||
**{t: types.LLMVertex for t in llm_creator.to_list()},
|
||||
**{t: types.MemoryVertex for t in memory_creator.to_list()},
|
||||
**{t: types.EmbeddingVertex for t in embedding_creator.to_list()},
|
||||
**{t: types.VectorStoreVertex for t in vectorstore_creator.to_list()},
|
||||
**{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()},
|
||||
**{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()},
|
||||
**{t: types.OutputParserVertex for t in output_parser_creator.to_list()},
|
||||
**{t: types.CustomComponentVertex for t in custom_component_creator.to_list()},
|
||||
**{t: types.RetrieverVertex for t in retriever_creator.to_list()},
|
||||
}
|
||||
|
|
|
|||
|
|
@ -1,15 +1,12 @@
|
|||
from langflow.cache import utils as cache_utils
|
||||
from langflow.graph.vertex.constants import DIRECT_TYPES
|
||||
from langflow.interface import loading
|
||||
from langflow.interface.initialize import loading
|
||||
from langflow.interface.listing import ALL_TYPES_DICT
|
||||
from langflow.utils.constants import DIRECT_TYPES
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import sync_to_async
|
||||
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
import types
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
|
|
@ -26,6 +23,7 @@ class Vertex:
|
|||
self._parse_data()
|
||||
self._built_object = None
|
||||
self._built = False
|
||||
self.artifacts: Dict[str, Any] = {}
|
||||
|
||||
def _parse_data(self) -> None:
|
||||
self.data = self._data["data"]
|
||||
|
|
@ -46,6 +44,14 @@ class Vertex:
|
|||
for key, value in template_dicts.items()
|
||||
if not value["required"]
|
||||
]
|
||||
# Add the template_dicts[key]["input_types"] to the optional_inputs
|
||||
self.optional_inputs.extend(
|
||||
[
|
||||
input_type
|
||||
for value in template_dicts.values()
|
||||
for input_type in value.get("input_types", [])
|
||||
]
|
||||
)
|
||||
|
||||
template_dict = self.data["node"]["template"]
|
||||
self.vertex_type = (
|
||||
|
|
@ -61,6 +67,7 @@ class Vertex:
|
|||
break
|
||||
|
||||
def _build_params(self):
|
||||
# sourcery skip: merge-list-append, remove-redundant-if
|
||||
# Some params are required, some are optional
|
||||
# but most importantly, some params are python base classes
|
||||
# like str and others are LangChain objects like LLMChain, BasePromptTemplate
|
||||
|
|
@ -81,8 +88,19 @@ class Vertex:
|
|||
if isinstance(value, dict)
|
||||
}
|
||||
params = {}
|
||||
|
||||
for edge in self.edges:
|
||||
param_key = edge.target_param
|
||||
if param_key in template_dict:
|
||||
if template_dict[param_key]["list"]:
|
||||
if param_key not in params:
|
||||
params[param_key] = []
|
||||
params[param_key].append(edge.source)
|
||||
elif edge.target.id == self.id:
|
||||
params[param_key] = edge.source
|
||||
|
||||
for key, value in template_dict.items():
|
||||
if key == "_type":
|
||||
if key == "_type" or not value.get("show"):
|
||||
continue
|
||||
# If the type is not transformable to a python base class
|
||||
# then we need to get the edge that connects to this node
|
||||
|
|
@ -90,132 +108,136 @@ class Vertex:
|
|||
# Load the type in value.get('suffixes') using
|
||||
# what is inside value.get('content')
|
||||
# value.get('value') is the file name
|
||||
file_name = value.get("value")
|
||||
content = value.get("content")
|
||||
type_to_load = value.get("suffixes")
|
||||
file_path = cache_utils.save_binary_file(
|
||||
content=content, file_name=file_name, accepted_types=type_to_load
|
||||
)
|
||||
file_path = value.get("file_path")
|
||||
|
||||
params[key] = file_path
|
||||
elif value.get("type") in DIRECT_TYPES and params.get(key) is None:
|
||||
params[key] = value.get("value")
|
||||
|
||||
elif value.get("type") not in DIRECT_TYPES:
|
||||
# Get the edge that connects to this node
|
||||
edges = [
|
||||
edge
|
||||
for edge in self.edges
|
||||
if edge.target == self and edge.matched_type in value["type"]
|
||||
]
|
||||
|
||||
# Get the output of the node that the edge connects to
|
||||
# if the value['list'] is True, then there will be more
|
||||
# than one time setting to params[key]
|
||||
# so we need to append to a list if it exists
|
||||
# or create a new list if it doesn't
|
||||
|
||||
if value["required"] and not edges:
|
||||
# If a required parameter is not found, raise an error
|
||||
raise ValueError(
|
||||
f"Required input {key} for module {self.vertex_type} not found"
|
||||
)
|
||||
elif value["list"]:
|
||||
# If this is a list parameter, append all sources to a list
|
||||
params[key] = [edge.source for edge in edges]
|
||||
elif edges:
|
||||
# If a single parameter is found, use its source
|
||||
params[key] = edges[0].source
|
||||
|
||||
elif value["required"] or value.get("value"):
|
||||
# If value does not have value this still passes
|
||||
# but then gives a keyError
|
||||
# so we need to check if value has value
|
||||
new_value = value.get("value")
|
||||
if new_value is None:
|
||||
warnings.warn(f"Value for {key} in {self.vertex_type} is None. ")
|
||||
if value.get("type") == "int":
|
||||
with contextlib.suppress(TypeError, ValueError):
|
||||
new_value = int(new_value) # type: ignore
|
||||
params[key] = new_value
|
||||
|
||||
if not value.get("required") and params.get(key) is None:
|
||||
if value.get("default"):
|
||||
params[key] = value.get("default")
|
||||
else:
|
||||
params.pop(key, None)
|
||||
# Add _type to params
|
||||
self.params = params
|
||||
|
||||
def _build(self):
|
||||
# The params dict is used to build the module
|
||||
# it contains values and keys that point to nodes which
|
||||
# have their own params dict
|
||||
# When build is called, we iterate through the params dict
|
||||
# and if the value is a node, we call build on that node
|
||||
# and use the output of that build as the value for the param
|
||||
# if the value is not a node, then we use the value as the param
|
||||
# and continue
|
||||
# Another aspect is that the node_type is the class that we need to import
|
||||
# and instantiate with these built params
|
||||
"""
|
||||
Initiate the build process.
|
||||
"""
|
||||
logger.debug(f"Building {self.vertex_type}")
|
||||
# Build each node in the params dict
|
||||
self._build_each_node_in_params_dict()
|
||||
self._get_and_instantiate_class()
|
||||
self._validate_built_object()
|
||||
|
||||
self._built = True
|
||||
|
||||
def _build_each_node_in_params_dict(self):
|
||||
"""
|
||||
Iterates over each node in the params dictionary and builds it.
|
||||
"""
|
||||
for key, value in self.params.copy().items():
|
||||
# Check if Node or list of Nodes and not self
|
||||
# to avoid recursion
|
||||
if isinstance(value, Vertex):
|
||||
if self._is_node(value):
|
||||
if value == self:
|
||||
del self.params[key]
|
||||
continue
|
||||
result = value.build()
|
||||
# If the key is "func", then we need to use the run method
|
||||
if key == "func":
|
||||
if not isinstance(result, types.FunctionType):
|
||||
# func can be
|
||||
# PythonFunction(code='\ndef upper_case(text: str) -> str:\n return text.upper()\n')
|
||||
# so we need to check if there is an attribute called run
|
||||
if hasattr(result, "run"):
|
||||
result = result.run # type: ignore
|
||||
elif hasattr(result, "get_function"):
|
||||
result = result.get_function() # type: ignore
|
||||
elif inspect.iscoroutinefunction(result):
|
||||
self.params["coroutine"] = result
|
||||
else:
|
||||
# turn result which is a function into a coroutine
|
||||
# so that it can be awaited
|
||||
self.params["coroutine"] = sync_to_async(result)
|
||||
if isinstance(result, list):
|
||||
# If the result is a list, then we need to extend the list
|
||||
# with the result but first check if the key exists
|
||||
# if it doesn't, then we need to create a new list
|
||||
if isinstance(self.params[key], list):
|
||||
self.params[key].extend(result)
|
||||
self._build_node_and_update_params(key, value)
|
||||
elif isinstance(value, list) and self._is_list_of_nodes(value):
|
||||
self._build_list_of_nodes_and_update_params(key, value)
|
||||
|
||||
self.params[key] = result
|
||||
elif isinstance(value, list) and all(
|
||||
isinstance(node, Vertex) for node in value
|
||||
):
|
||||
self.params[key] = []
|
||||
for node in value:
|
||||
built = node.build()
|
||||
if isinstance(built, list):
|
||||
self.params[key].extend(built)
|
||||
else:
|
||||
self.params[key].append(built)
|
||||
def _is_node(self, value):
|
||||
"""
|
||||
Checks if the provided value is an instance of Vertex.
|
||||
"""
|
||||
return isinstance(value, Vertex)
|
||||
|
||||
# Get the class from LANGCHAIN_TYPES_DICT
|
||||
# and instantiate it with the params
|
||||
# and return the instance
|
||||
def _is_list_of_nodes(self, value):
|
||||
"""
|
||||
Checks if the provided value is a list of Vertex instances.
|
||||
"""
|
||||
return all(self._is_node(node) for node in value)
|
||||
|
||||
def _build_node_and_update_params(self, key, node):
|
||||
"""
|
||||
Builds a given node and updates the params dictionary accordingly.
|
||||
"""
|
||||
result = node.build()
|
||||
self._handle_func(key, result)
|
||||
if isinstance(result, list):
|
||||
self._extend_params_list_with_result(key, result)
|
||||
self.params[key] = result
|
||||
|
||||
def _build_list_of_nodes_and_update_params(self, key, nodes):
|
||||
"""
|
||||
Iterates over a list of nodes, builds each and updates the params dictionary.
|
||||
"""
|
||||
self.params[key] = []
|
||||
for node in nodes:
|
||||
built = node.build()
|
||||
if isinstance(built, list):
|
||||
if key not in self.params:
|
||||
self.params[key] = []
|
||||
self.params[key].extend(built)
|
||||
else:
|
||||
self.params[key].append(built)
|
||||
|
||||
def _handle_func(self, key, result):
|
||||
"""
|
||||
Handles 'func' key by checking if the result is a function and setting it as coroutine.
|
||||
"""
|
||||
if key == "func":
|
||||
if not isinstance(result, types.FunctionType):
|
||||
if hasattr(result, "run"):
|
||||
result = result.run # type: ignore
|
||||
elif hasattr(result, "get_function"):
|
||||
result = result.get_function() # type: ignore
|
||||
elif inspect.iscoroutinefunction(result):
|
||||
self.params["coroutine"] = result
|
||||
else:
|
||||
self.params["coroutine"] = sync_to_async(result)
|
||||
|
||||
def _extend_params_list_with_result(self, key, result):
|
||||
"""
|
||||
Extends a list in the params dictionary with the given result if it exists.
|
||||
"""
|
||||
if isinstance(self.params[key], list):
|
||||
self.params[key].extend(result)
|
||||
|
||||
def _get_and_instantiate_class(self):
|
||||
"""
|
||||
Gets the class from a dictionary and instantiates it with the params.
|
||||
"""
|
||||
if self.base_type is None:
|
||||
raise ValueError(f"Base type for node {self.vertex_type} not found")
|
||||
try:
|
||||
self._built_object = loading.instantiate_class(
|
||||
result = loading.instantiate_class(
|
||||
node_type=self.vertex_type,
|
||||
base_type=self.base_type,
|
||||
params=self.params,
|
||||
)
|
||||
self._update_built_object_and_artifacts(result)
|
||||
except Exception as exc:
|
||||
raise ValueError(
|
||||
f"Error building node {self.vertex_type}: {str(exc)}"
|
||||
) from exc
|
||||
|
||||
def _update_built_object_and_artifacts(self, result):
|
||||
"""
|
||||
Updates the built object and its artifacts.
|
||||
"""
|
||||
if isinstance(result, tuple):
|
||||
self._built_object, self.artifacts = result
|
||||
else:
|
||||
self._built_object = result
|
||||
|
||||
def _validate_built_object(self):
|
||||
"""
|
||||
Checks if the built object is None and raises a ValueError if so.
|
||||
"""
|
||||
if self._built_object is None:
|
||||
raise ValueError(f"Node type {self.vertex_type} not found")
|
||||
|
||||
self._built = True
|
||||
|
||||
def build(self, force: bool = False) -> Any:
|
||||
if not self._built or force:
|
||||
self._build()
|
||||
|
|
@ -223,10 +245,11 @@ class Vertex:
|
|||
return self._built_object
|
||||
|
||||
def add_edge(self, edge: "Edge") -> None:
|
||||
self.edges.append(edge)
|
||||
if edge not in self.edges:
|
||||
self.edges.append(edge)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"Node(id={self.id}, data={self.data})"
|
||||
return f"Vertex(id={self.id}, data={self.data})"
|
||||
|
||||
def __eq__(self, __o: object) -> bool:
|
||||
return self.id == __o.id if isinstance(__o, Vertex) else False
|
||||
|
|
@ -235,4 +258,5 @@ class Vertex:
|
|||
return id(self)
|
||||
|
||||
def _built_object_repr(self):
|
||||
return repr(self._built_object)
|
||||
# Add a message with an emoji, stars for sucess,
|
||||
return "Built sucessfully ✨" if self._built_object else "Failed to build 😵💫"
|
||||
|
|
|
|||
|
|
@ -1 +1 @@
|
|||
DIRECT_TYPES = ["str", "bool", "code", "int", "float", "Any", "prompt"]
|
||||
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
import ast
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from langflow.graph.vertex.base import Vertex
|
||||
|
|
@ -79,7 +80,7 @@ class WrapperVertex(Vertex):
|
|||
def build(self, force: bool = False) -> Any:
|
||||
if not self._built or force:
|
||||
if "headers" in self.params:
|
||||
self.params["headers"] = eval(self.params["headers"])
|
||||
self.params["headers"] = ast.literal_eval(self.params["headers"])
|
||||
self._build()
|
||||
return self._built_object
|
||||
|
||||
|
|
@ -91,8 +92,13 @@ class DocumentLoaderVertex(Vertex):
|
|||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
|
||||
if self._built_object:
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
|
||||
self._built_object
|
||||
)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
Documents: {self._built_object[:3]}..."""
|
||||
return f"{self.vertex_type}()"
|
||||
|
||||
|
|
@ -106,15 +112,17 @@ class VectorStoreVertex(Vertex):
|
|||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="vectorstores")
|
||||
|
||||
def _built_object_repr(self):
|
||||
return "Vector stores can take time to build. It will build on the first query."
|
||||
|
||||
|
||||
class MemoryVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="memory")
|
||||
|
||||
|
||||
class RetrieverVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="retrievers")
|
||||
|
||||
|
||||
class TextSplitterVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="textsplitters")
|
||||
|
|
@ -122,8 +130,13 @@ class TextSplitterVertex(Vertex):
|
|||
def _built_object_repr(self):
|
||||
# This built_object is a list of documents. Maybe we should
|
||||
# show how many documents are in the list?
|
||||
|
||||
if self._built_object:
|
||||
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
|
||||
self._built_object
|
||||
)
|
||||
return f"""{self.vertex_type}({len(self._built_object)} documents)
|
||||
\nAvg. Document Length (characters): {int(avg_length)}
|
||||
\nDocuments: {self._built_object[:3]}..."""
|
||||
return f"{self.vertex_type}()"
|
||||
|
||||
|
|
@ -183,11 +196,55 @@ class PromptVertex(Vertex):
|
|||
]
|
||||
else:
|
||||
prompt_params = ["template"]
|
||||
for param in prompt_params:
|
||||
prompt_text = self.params[param]
|
||||
variables = extract_input_variables_from_prompt(prompt_text)
|
||||
self.params["input_variables"].extend(variables)
|
||||
self.params["input_variables"] = list(set(self.params["input_variables"]))
|
||||
|
||||
if "prompt" not in self.params and "messages" not in self.params:
|
||||
for param in prompt_params:
|
||||
prompt_text = self.params[param]
|
||||
variables = extract_input_variables_from_prompt(prompt_text)
|
||||
self.params["input_variables"].extend(variables)
|
||||
self.params["input_variables"] = list(
|
||||
set(self.params["input_variables"])
|
||||
)
|
||||
else:
|
||||
self.params.pop("input_variables", None)
|
||||
|
||||
self._build()
|
||||
return self._built_object
|
||||
|
||||
def _built_object_repr(self):
|
||||
if (
|
||||
not self.artifacts
|
||||
or self._built_object is None
|
||||
or not hasattr(self._built_object, "format")
|
||||
):
|
||||
return super()._built_object_repr()
|
||||
# We'll build the prompt with the artifacts
|
||||
# to show the user what the prompt looks like
|
||||
# with the variables filled in
|
||||
artifacts = self.artifacts.copy()
|
||||
# Remove the handle_keys from the artifacts
|
||||
# so the prompt format doesn't break
|
||||
artifacts.pop("handle_keys", None)
|
||||
try:
|
||||
template = self._built_object.format(**artifacts)
|
||||
return (
|
||||
template
|
||||
if isinstance(template, str)
|
||||
else f"{self.vertex_type}({template})"
|
||||
)
|
||||
except KeyError:
|
||||
return str(self._built_object)
|
||||
|
||||
|
||||
class OutputParserVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="output_parsers")
|
||||
|
||||
|
||||
class CustomComponentVertex(Vertex):
|
||||
def __init__(self, data: Dict):
|
||||
super().__init__(data, base_type="custom_components")
|
||||
|
||||
def _built_object_repr(self):
|
||||
if self.artifacts and "repr" in self.artifacts:
|
||||
return self.artifacts["repr"] or super()._built_object_repr()
|
||||
|
|
|
|||
|
|
@ -6,13 +6,20 @@ from langflow.custom.customs import get_custom_nodes
|
|||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.settings import settings
|
||||
from langflow.template.frontend_node.agents import AgentFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
|
||||
|
||||
class AgentCreator(LangChainTypeCreator):
|
||||
type_name: str = "agents"
|
||||
|
||||
from_method_nodes = {"ZeroShotAgent": "from_llm_and_tools"}
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> type[AgentFrontendNode]:
|
||||
return AgentFrontendNode
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
|
|
@ -27,6 +34,13 @@ class AgentCreator(LangChainTypeCreator):
|
|||
try:
|
||||
if name in get_custom_nodes(self.type_name).keys():
|
||||
return get_custom_nodes(self.type_name)[name]
|
||||
elif name in self.from_method_nodes:
|
||||
return build_template_from_method(
|
||||
name,
|
||||
type_to_cls_dict=self.type_to_loader_dict,
|
||||
add_function=True,
|
||||
method_name=self.from_method_nodes[name],
|
||||
)
|
||||
return build_template_from_class(
|
||||
name, self.type_to_loader_dict, add_function=True
|
||||
)
|
||||
|
|
|
|||
|
|
@ -6,6 +6,7 @@ from langchain.agents import (
|
|||
Tool,
|
||||
ZeroShotAgent,
|
||||
initialize_agent,
|
||||
AgentType,
|
||||
)
|
||||
from langchain.agents.agent_toolkits import (
|
||||
SQLDatabaseToolkit,
|
||||
|
|
@ -156,7 +157,7 @@ class VectorStoreAgent(CustomAgentExecutor):
|
|||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, verbose=True
|
||||
agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
|
||||
)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
|
|
@ -192,7 +193,7 @@ class SQLAgent(CustomAgentExecutor):
|
|||
from langchain.tools.sql_database.tool import (
|
||||
InfoSQLDatabaseTool,
|
||||
ListSQLDatabaseTool,
|
||||
QueryCheckerTool,
|
||||
QuerySQLCheckerTool,
|
||||
QuerySQLDataBaseTool,
|
||||
)
|
||||
|
||||
|
|
@ -207,7 +208,7 @@ class SQLAgent(CustomAgentExecutor):
|
|||
QuerySQLDataBaseTool(db=db), # type: ignore
|
||||
InfoSQLDatabaseTool(db=db), # type: ignore
|
||||
ListSQLDatabaseTool(db=db), # type: ignore
|
||||
QueryCheckerTool(db=db, llm_chain=llmchain, llm=llm), # type: ignore
|
||||
QuerySQLCheckerTool(db=db, llm_chain=llmchain, llm=llm), # type: ignore
|
||||
]
|
||||
|
||||
prefix = SQL_PREFIX.format(dialect=toolkit.dialect, top_k=10)
|
||||
|
|
@ -231,6 +232,7 @@ class SQLAgent(CustomAgentExecutor):
|
|||
verbose=True,
|
||||
max_iterations=15,
|
||||
early_stopping_method="force",
|
||||
handle_parsing_errors=True,
|
||||
)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
|
|
@ -275,7 +277,7 @@ class VectorStoreRouterAgent(CustomAgentExecutor):
|
|||
llm_chain=llm_chain, allowed_tools=tool_names, **kwargs # type: ignore
|
||||
)
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, verbose=True
|
||||
agent=agent, tools=tools, verbose=True, handle_parsing_errors=True
|
||||
)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
|
|
@ -297,6 +299,9 @@ class InitializeAgent(CustomAgentExecutor):
|
|||
agent: str,
|
||||
memory: Optional[BaseChatMemory] = None,
|
||||
):
|
||||
# Find which value in the AgentType enum corresponds to the string
|
||||
# passed in as agent
|
||||
agent = AgentType(agent)
|
||||
return initialize_agent(
|
||||
tools=tools,
|
||||
llm=llm,
|
||||
|
|
@ -304,6 +309,7 @@ class InitializeAgent(CustomAgentExecutor):
|
|||
agent=agent, # type: ignore
|
||||
memory=memory,
|
||||
return_intermediate_steps=True,
|
||||
handle_parsing_errors=True,
|
||||
)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ from langflow.template.field.base import TemplateField
|
|||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.template.base import Template
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.settings import settings
|
||||
|
||||
# Assuming necessary imports for Field, Template, and FrontendNode classes
|
||||
|
||||
|
|
@ -15,12 +16,29 @@ from langflow.utils.logger import logger
|
|||
class LangChainTypeCreator(BaseModel, ABC):
|
||||
type_name: str
|
||||
type_dict: Optional[Dict] = None
|
||||
name_docs_dict: Optional[Dict[str, str]] = None
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[FrontendNode]:
|
||||
"""The class type of the FrontendNode created in frontend_node."""
|
||||
return FrontendNode
|
||||
|
||||
@property
|
||||
def docs_map(self) -> Dict[str, str]:
|
||||
"""A dict with the name of the component as key and the documentation link as value."""
|
||||
if self.name_docs_dict is None:
|
||||
try:
|
||||
type_settings = getattr(settings, self.type_name)
|
||||
self.name_docs_dict = {
|
||||
name: value_dict["documentation"]
|
||||
for name, value_dict in type_settings.items()
|
||||
}
|
||||
except AttributeError as exc:
|
||||
logger.error(f"Error getting settings for {self.type_name}: {exc}")
|
||||
|
||||
self.name_docs_dict = {}
|
||||
return self.name_docs_dict
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
|
|
@ -68,7 +86,7 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
value=value.get("value", None),
|
||||
suffixes=value.get("suffixes", []),
|
||||
file_types=value.get("fileTypes", []),
|
||||
content=value.get("content", None),
|
||||
file_path=value.get("file_path", None),
|
||||
)
|
||||
for key, value in signature["template"].items()
|
||||
if key != "_type"
|
||||
|
|
@ -83,7 +101,7 @@ class LangChainTypeCreator(BaseModel, ABC):
|
|||
|
||||
signature.add_extra_fields()
|
||||
signature.add_extra_base_classes()
|
||||
|
||||
signature.set_documentation(self.docs_map.get(name, ""))
|
||||
return signature
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -23,6 +23,7 @@ class ChainCreator(LangChainTypeCreator):
|
|||
from_method_nodes = {
|
||||
"ConversationalRetrievalChain": "from_llm",
|
||||
"LLMCheckerChain": "from_llm",
|
||||
"SQLDatabaseChain": "from_llm",
|
||||
}
|
||||
|
||||
@property
|
||||
|
|
|
|||
4
src/backend/langflow/interface/custom/__init__.py
Normal file
4
src/backend/langflow/interface/custom/__init__.py
Normal file
|
|
@ -0,0 +1,4 @@
|
|||
from langflow.interface.custom.base import CustomComponentCreator
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
__all__ = ["CustomComponentCreator", "CustomComponent"]
|
||||
48
src/backend/langflow/interface/custom/base.py
Normal file
48
src/backend/langflow/interface/custom/base.py
Normal file
|
|
@ -0,0 +1,48 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
|
||||
# from langflow.interface.custom.custom import CustomComponent
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.template.frontend_node.custom_components import (
|
||||
CustomComponentFrontendNode,
|
||||
)
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
# Assuming necessary imports for Field, Template, and FrontendNode classes
|
||||
|
||||
|
||||
class CustomComponentCreator(LangChainTypeCreator):
|
||||
type_name: str = "custom_components"
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[CustomComponentFrontendNode]:
|
||||
return CustomComponentFrontendNode
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict: dict[str, Any] = {
|
||||
"CustomComponent": CustomComponent,
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
|
||||
try:
|
||||
if name in get_custom_nodes(self.type_name).keys():
|
||||
return get_custom_nodes(self.type_name)[name]
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"CustomComponent {name} not found: {exc}") from exc
|
||||
except AttributeError as exc:
|
||||
logger.error(f"CustomComponent {name} not loaded: {exc}")
|
||||
return None
|
||||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
return list(self.type_to_loader_dict.keys())
|
||||
|
||||
|
||||
custom_component_creator = CustomComponentCreator()
|
||||
272
src/backend/langflow/interface/custom/code_parser.py
Normal file
272
src/backend/langflow/interface/custom/code_parser.py
Normal file
|
|
@ -0,0 +1,272 @@
|
|||
import ast
|
||||
import inspect
|
||||
import traceback
|
||||
|
||||
from typing import Dict, Any, List, Type, Union
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.schema import CallableCodeDetails, ClassCodeDetails
|
||||
|
||||
|
||||
class CodeSyntaxError(HTTPException):
|
||||
pass
|
||||
|
||||
|
||||
class CodeParser:
|
||||
"""
|
||||
A parser for Python source code, extracting code details.
|
||||
"""
|
||||
|
||||
def __init__(self, code: Union[str, Type]) -> None:
|
||||
"""
|
||||
Initializes the parser with the provided code.
|
||||
"""
|
||||
if isinstance(code, type):
|
||||
if not inspect.isclass(code):
|
||||
raise ValueError("The provided code must be a class.")
|
||||
# If the code is a class, get its source code
|
||||
code = inspect.getsource(code)
|
||||
self.code = code
|
||||
self.data: Dict[str, Any] = {
|
||||
"imports": [],
|
||||
"functions": [],
|
||||
"classes": [],
|
||||
"global_vars": [],
|
||||
}
|
||||
self.handlers = {
|
||||
ast.Import: self.parse_imports,
|
||||
ast.ImportFrom: self.parse_imports,
|
||||
ast.FunctionDef: self.parse_functions,
|
||||
ast.ClassDef: self.parse_classes,
|
||||
ast.Assign: self.parse_global_vars,
|
||||
}
|
||||
|
||||
def __get_tree(self):
|
||||
"""
|
||||
Parses the provided code to validate its syntax.
|
||||
It tries to parse the code into an abstract syntax tree (AST).
|
||||
"""
|
||||
try:
|
||||
tree = ast.parse(self.code)
|
||||
except SyntaxError as err:
|
||||
raise CodeSyntaxError(
|
||||
status_code=400,
|
||||
detail={"error": err.msg, "traceback": traceback.format_exc()},
|
||||
) from err
|
||||
|
||||
return tree
|
||||
|
||||
def parse_node(self, node: Union[ast.stmt, ast.AST]) -> None:
|
||||
"""
|
||||
Parses an AST node and updates the data
|
||||
dictionary with the relevant information.
|
||||
"""
|
||||
if handler := self.handlers.get(type(node)): # type: ignore
|
||||
handler(node) # type: ignore
|
||||
|
||||
def parse_imports(self, node: Union[ast.Import, ast.ImportFrom]) -> None:
|
||||
"""
|
||||
Extracts "imports" from the code.
|
||||
"""
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
self.data["imports"].append(alias.name)
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
for alias in node.names:
|
||||
self.data["imports"].append((node.module, alias.name))
|
||||
|
||||
def parse_functions(self, node: ast.FunctionDef) -> None:
|
||||
"""
|
||||
Extracts "functions" from the code.
|
||||
"""
|
||||
self.data["functions"].append(self.parse_callable_details(node))
|
||||
|
||||
def parse_arg(self, arg, default):
|
||||
"""
|
||||
Parses an argument and its default value.
|
||||
"""
|
||||
arg_dict = {"name": arg.arg, "default": default}
|
||||
if arg.annotation:
|
||||
arg_dict["type"] = ast.unparse(arg.annotation)
|
||||
return arg_dict
|
||||
|
||||
def parse_callable_details(self, node: ast.FunctionDef) -> Dict[str, Any]:
|
||||
"""
|
||||
Extracts details from a single function or method node.
|
||||
"""
|
||||
func = CallableCodeDetails(
|
||||
name=node.name,
|
||||
doc=ast.get_docstring(node),
|
||||
args=[],
|
||||
body=[],
|
||||
return_type=ast.unparse(node.returns) if node.returns else None,
|
||||
)
|
||||
|
||||
func.args = self.parse_function_args(node)
|
||||
func.body = self.parse_function_body(node)
|
||||
|
||||
return func.dict()
|
||||
|
||||
def parse_function_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Parses the arguments of a function or method node.
|
||||
"""
|
||||
args = []
|
||||
|
||||
args += self.parse_positional_args(node)
|
||||
args += self.parse_varargs(node)
|
||||
args += self.parse_keyword_args(node)
|
||||
args += self.parse_kwargs(node)
|
||||
|
||||
return args
|
||||
|
||||
def parse_positional_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Parses the positional arguments of a function or method node.
|
||||
"""
|
||||
num_args = len(node.args.args)
|
||||
num_defaults = len(node.args.defaults)
|
||||
num_missing_defaults = num_args - num_defaults
|
||||
missing_defaults = [None] * num_missing_defaults
|
||||
default_values = [
|
||||
ast.unparse(default).strip("'") if default else None
|
||||
for default in node.args.defaults
|
||||
]
|
||||
# Now check all default values to see if there
|
||||
# are any "None" values in the middle
|
||||
default_values = [
|
||||
None if value == "None" else value for value in default_values
|
||||
]
|
||||
|
||||
defaults = missing_defaults + default_values
|
||||
|
||||
args = [
|
||||
self.parse_arg(arg, default)
|
||||
for arg, default in zip(node.args.args, defaults)
|
||||
]
|
||||
return args
|
||||
|
||||
def parse_varargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Parses the *args argument of a function or method node.
|
||||
"""
|
||||
args = []
|
||||
|
||||
if node.args.vararg:
|
||||
args.append(self.parse_arg(node.args.vararg, None))
|
||||
|
||||
return args
|
||||
|
||||
def parse_keyword_args(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Parses the keyword-only arguments of a function or method node.
|
||||
"""
|
||||
kw_defaults = [None] * (
|
||||
len(node.args.kwonlyargs) - len(node.args.kw_defaults)
|
||||
) + [
|
||||
ast.unparse(default) if default else None
|
||||
for default in node.args.kw_defaults
|
||||
]
|
||||
|
||||
args = [
|
||||
self.parse_arg(arg, default)
|
||||
for arg, default in zip(node.args.kwonlyargs, kw_defaults)
|
||||
]
|
||||
return args
|
||||
|
||||
def parse_kwargs(self, node: ast.FunctionDef) -> List[Dict[str, Any]]:
|
||||
"""
|
||||
Parses the **kwargs argument of a function or method node.
|
||||
"""
|
||||
args = []
|
||||
|
||||
if node.args.kwarg:
|
||||
args.append(self.parse_arg(node.args.kwarg, None))
|
||||
|
||||
return args
|
||||
|
||||
def parse_function_body(self, node: ast.FunctionDef) -> List[str]:
|
||||
"""
|
||||
Parses the body of a function or method node.
|
||||
"""
|
||||
return [ast.unparse(line) for line in node.body]
|
||||
|
||||
def parse_assign(self, stmt):
|
||||
"""
|
||||
Parses an Assign statement and returns a dictionary
|
||||
with the target's name and value.
|
||||
"""
|
||||
for target in stmt.targets:
|
||||
if isinstance(target, ast.Name):
|
||||
return {"name": target.id, "value": ast.unparse(stmt.value)}
|
||||
|
||||
def parse_ann_assign(self, stmt):
|
||||
"""
|
||||
Parses an AnnAssign statement and returns a dictionary
|
||||
with the target's name, value, and annotation.
|
||||
"""
|
||||
if isinstance(stmt.target, ast.Name):
|
||||
return {
|
||||
"name": stmt.target.id,
|
||||
"value": ast.unparse(stmt.value) if stmt.value else None,
|
||||
"annotation": ast.unparse(stmt.annotation),
|
||||
}
|
||||
|
||||
def parse_function_def(self, stmt):
|
||||
"""
|
||||
Parses a FunctionDef statement and returns the parsed
|
||||
method and a boolean indicating if it's an __init__ method.
|
||||
"""
|
||||
method = self.parse_callable_details(stmt)
|
||||
return (method, True) if stmt.name == "__init__" else (method, False)
|
||||
|
||||
def parse_classes(self, node: ast.ClassDef) -> None:
|
||||
"""
|
||||
Extracts "classes" from the code, including inheritance and init methods.
|
||||
"""
|
||||
|
||||
class_details = ClassCodeDetails(
|
||||
name=node.name,
|
||||
doc=ast.get_docstring(node),
|
||||
bases=[ast.unparse(base) for base in node.bases],
|
||||
attributes=[],
|
||||
methods=[],
|
||||
init=None,
|
||||
)
|
||||
|
||||
for stmt in node.body:
|
||||
if isinstance(stmt, ast.Assign):
|
||||
if attr := self.parse_assign(stmt):
|
||||
class_details.attributes.append(attr)
|
||||
elif isinstance(stmt, ast.AnnAssign):
|
||||
if attr := self.parse_ann_assign(stmt):
|
||||
class_details.attributes.append(attr)
|
||||
elif isinstance(stmt, ast.FunctionDef):
|
||||
method, is_init = self.parse_function_def(stmt)
|
||||
if is_init:
|
||||
class_details.init = method
|
||||
else:
|
||||
class_details.methods.append(method)
|
||||
|
||||
self.data["classes"].append(class_details.dict())
|
||||
|
||||
def parse_global_vars(self, node: ast.Assign) -> None:
|
||||
"""
|
||||
Extracts global variables from the code.
|
||||
"""
|
||||
global_var = {
|
||||
"targets": [
|
||||
t.id if hasattr(t, "id") else ast.dump(t) for t in node.targets
|
||||
],
|
||||
"value": ast.unparse(node.value),
|
||||
}
|
||||
self.data["global_vars"].append(global_var)
|
||||
|
||||
def parse_code(self) -> Dict[str, Any]:
|
||||
"""
|
||||
Runs all parsing operations and returns the resulting data.
|
||||
"""
|
||||
tree = self.__get_tree()
|
||||
|
||||
for node in ast.walk(tree):
|
||||
self.parse_node(node)
|
||||
return self.data
|
||||
72
src/backend/langflow/interface/custom/component.py
Normal file
72
src/backend/langflow/interface/custom/component.py
Normal file
|
|
@ -0,0 +1,72 @@
|
|||
import ast
|
||||
from typing import Optional
|
||||
from pydantic import BaseModel
|
||||
from fastapi import HTTPException
|
||||
|
||||
from langflow.utils import validate
|
||||
from langflow.interface.custom.code_parser import CodeParser
|
||||
|
||||
|
||||
class ComponentCodeNullError(HTTPException):
|
||||
pass
|
||||
|
||||
|
||||
class ComponentFunctionEntrypointNameNullError(HTTPException):
|
||||
pass
|
||||
|
||||
|
||||
class Component(BaseModel):
|
||||
ERROR_CODE_NULL = "Python code must be provided."
|
||||
ERROR_FUNCTION_ENTRYPOINT_NAME_NULL = (
|
||||
"The name of the entrypoint function must be provided."
|
||||
)
|
||||
|
||||
code: Optional[str]
|
||||
function_entrypoint_name = "build"
|
||||
field_config: dict = {}
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
def get_code_tree(self, code: str):
|
||||
parser = CodeParser(code)
|
||||
return parser.parse_code()
|
||||
|
||||
def get_function(self):
|
||||
if not self.code:
|
||||
raise ComponentCodeNullError(
|
||||
status_code=400,
|
||||
detail={"error": self.ERROR_CODE_NULL, "traceback": ""},
|
||||
)
|
||||
|
||||
if not self.function_entrypoint_name:
|
||||
raise ComponentFunctionEntrypointNameNullError(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": self.ERROR_FUNCTION_ENTRYPOINT_NAME_NULL,
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
def build_template_config(self, attributes) -> dict:
|
||||
template_config = {}
|
||||
|
||||
for item in attributes:
|
||||
item_name = item.get("name")
|
||||
|
||||
if item_value := item.get("value"):
|
||||
if "display_name" in item_name:
|
||||
template_config["display_name"] = ast.literal_eval(item_value)
|
||||
|
||||
elif "description" in item_name:
|
||||
template_config["description"] = ast.literal_eval(item_value)
|
||||
|
||||
elif "field_config" in item_name:
|
||||
template_config["field_config"] = ast.literal_eval(item_value)
|
||||
|
||||
return template_config
|
||||
|
||||
def build(self):
|
||||
raise NotImplementedError
|
||||
58
src/backend/langflow/interface/custom/constants.py
Normal file
58
src/backend/langflow/interface/custom/constants.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
from langchain import PromptTemplate
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.schema import BaseRetriever, Document
|
||||
from langchain.text_splitter import TextSplitter
|
||||
from langchain.tools import Tool
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
|
||||
|
||||
LANGCHAIN_BASE_TYPES = {
|
||||
"Chain": Chain,
|
||||
"Tool": Tool,
|
||||
"BaseLLM": BaseLLM,
|
||||
"PromptTemplate": PromptTemplate,
|
||||
"BaseLoader": BaseLoader,
|
||||
"Document": Document,
|
||||
"TextSplitter": TextSplitter,
|
||||
"VectorStore": VectorStore,
|
||||
"Embeddings": Embeddings,
|
||||
"BaseRetriever": BaseRetriever,
|
||||
}
|
||||
|
||||
# Langchain base types plus Python base types
|
||||
CUSTOM_COMPONENT_SUPPORTED_TYPES = {
|
||||
**LANGCHAIN_BASE_TYPES,
|
||||
"str": str,
|
||||
"int": int,
|
||||
"float": float,
|
||||
"bool": bool,
|
||||
"list": list,
|
||||
"dict": dict,
|
||||
}
|
||||
|
||||
|
||||
DEFAULT_CUSTOM_COMPONENT_CODE = """from langflow import CustomComponent
|
||||
|
||||
from langchain.llms.base import BaseLLM
|
||||
from langchain.chains import LLMChain
|
||||
from langchain import PromptTemplate
|
||||
from langchain.schema import Document
|
||||
|
||||
import requests
|
||||
|
||||
class YourComponent(CustomComponent):
|
||||
display_name: str = "Custom Component"
|
||||
description: str = "Create any custom component you want!"
|
||||
|
||||
def build_config(self):
|
||||
return { "url": { "multiline": True, "required": True } }
|
||||
|
||||
def build(self, url: str, llm: BaseLLM, prompt: PromptTemplate) -> Document:
|
||||
response = requests.get(url)
|
||||
chain = LLMChain(llm=llm, prompt=prompt)
|
||||
result = chain.run(response.text[:300])
|
||||
return Document(page_content=str(result))
|
||||
"""
|
||||
194
src/backend/langflow/interface/custom/custom_component.py
Normal file
194
src/backend/langflow/interface/custom/custom_component.py
Normal file
|
|
@ -0,0 +1,194 @@
|
|||
from typing import Any, Callable, List, Optional
|
||||
from fastapi import HTTPException
|
||||
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
from langflow.interface.custom.component import Component
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
|
||||
from langflow.utils import validate
|
||||
|
||||
from langflow.database.base import session_getter
|
||||
from langflow.database.models.flow import Flow
|
||||
from pydantic import Extra
|
||||
|
||||
|
||||
class CustomComponent(Component, extra=Extra.allow):
|
||||
code: Optional[str]
|
||||
field_config: dict = {}
|
||||
code_class_base_inheritance = "CustomComponent"
|
||||
function_entrypoint_name = "build"
|
||||
function: Optional[Callable] = None
|
||||
return_type_valid_list = list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())
|
||||
repr_value: Optional[str] = ""
|
||||
|
||||
def __init__(self, **data):
|
||||
super().__init__(**data)
|
||||
|
||||
def custom_repr(self):
|
||||
return str(self.repr_value)
|
||||
|
||||
def build_config(self):
|
||||
return self.field_config
|
||||
|
||||
def _class_template_validation(self, code: str):
|
||||
TYPE_HINT_LIST = ["Optional", "Prompt", "PromptTemplate", "LLMChain"]
|
||||
|
||||
if not code:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": self.ERROR_CODE_NULL,
|
||||
"traceback": "",
|
||||
},
|
||||
)
|
||||
|
||||
reader = DirectoryReader("", False)
|
||||
|
||||
for type_hint in TYPE_HINT_LIST:
|
||||
if reader.is_type_hint_used_but_not_imported(type_hint, code):
|
||||
error_detail = {
|
||||
"error": "Type hint Error",
|
||||
"traceback": f"Type hint '{type_hint}' is used but not imported in the code.",
|
||||
}
|
||||
raise HTTPException(status_code=400, detail=error_detail)
|
||||
|
||||
def is_check_valid(self) -> bool:
|
||||
return self._class_template_validation(self.code) if self.code else False
|
||||
|
||||
def get_code_tree(self, code: str):
|
||||
return super().get_code_tree(code)
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_args(self) -> str:
|
||||
if not self.code:
|
||||
return ""
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
if not component_classes:
|
||||
return ""
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
if not build_methods:
|
||||
return ""
|
||||
|
||||
build_method = build_methods[0]
|
||||
|
||||
return build_method["args"]
|
||||
|
||||
@property
|
||||
def get_function_entrypoint_return_type(self) -> str:
|
||||
if not self.code:
|
||||
return ""
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
component_classes = [
|
||||
cls
|
||||
for cls in tree["classes"]
|
||||
if self.code_class_base_inheritance in cls["bases"]
|
||||
]
|
||||
if not component_classes:
|
||||
return ""
|
||||
|
||||
# Assume the first Component class is the one we're interested in
|
||||
component_class = component_classes[0]
|
||||
build_methods = [
|
||||
method
|
||||
for method in component_class["methods"]
|
||||
if method["name"] == self.function_entrypoint_name
|
||||
]
|
||||
|
||||
if not build_methods:
|
||||
return ""
|
||||
|
||||
build_method = build_methods[0]
|
||||
|
||||
return build_method["return_type"]
|
||||
|
||||
@property
|
||||
def get_main_class_name(self):
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
base_name = self.code_class_base_inheritance
|
||||
method_name = self.function_entrypoint_name
|
||||
|
||||
classes = []
|
||||
for item in tree.get("classes"):
|
||||
if base_name in item["bases"]:
|
||||
method_names = [method["name"] for method in item["methods"]]
|
||||
if method_name in method_names:
|
||||
classes.append(item["name"])
|
||||
|
||||
# Get just the first item
|
||||
return next(iter(classes), "")
|
||||
|
||||
@property
|
||||
def build_template_config(self):
|
||||
tree = self.get_code_tree(self.code)
|
||||
|
||||
attributes = [
|
||||
main_class["attributes"]
|
||||
for main_class in tree.get("classes")
|
||||
if main_class["name"] == self.get_main_class_name
|
||||
]
|
||||
# Get just the first item
|
||||
attributes = next(iter(attributes), [])
|
||||
|
||||
return super().build_template_config(attributes)
|
||||
|
||||
@property
|
||||
def get_function(self):
|
||||
return validate.create_function(self.code, self.function_entrypoint_name)
|
||||
|
||||
def load_flow(self, flow_id: str, tweaks: Optional[dict] = None) -> Any:
|
||||
from langflow.processing.process import build_sorted_vertices_with_caching
|
||||
from langflow.processing.process import process_tweaks
|
||||
|
||||
with session_getter() as session:
|
||||
graph_data = flow.data if (flow := session.get(Flow, flow_id)) else None
|
||||
if not graph_data:
|
||||
raise ValueError(f"Flow {flow_id} not found")
|
||||
if tweaks:
|
||||
graph_data = process_tweaks(graph_data=graph_data, tweaks=tweaks)
|
||||
return build_sorted_vertices_with_caching(graph_data)
|
||||
|
||||
def list_flows(self, *, get_session: Optional[Callable] = None) -> List[Flow]:
|
||||
get_session = get_session or session_getter
|
||||
with get_session() as session:
|
||||
flows = session.query(Flow).all()
|
||||
return flows
|
||||
|
||||
def get_flow(
|
||||
self,
|
||||
*,
|
||||
flow_name: Optional[str] = None,
|
||||
flow_id: Optional[str] = None,
|
||||
tweaks: Optional[dict] = None,
|
||||
get_session: Optional[Callable] = None,
|
||||
) -> Flow:
|
||||
get_session = get_session or session_getter
|
||||
|
||||
with get_session() as session:
|
||||
if flow_id:
|
||||
flow = session.query(Flow).get(flow_id)
|
||||
elif flow_name:
|
||||
flow = session.query(Flow).filter(Flow.name == flow_name).first()
|
||||
else:
|
||||
raise ValueError("Either flow_name or flow_id must be provided")
|
||||
|
||||
if not flow:
|
||||
raise ValueError(f"Flow {flow_name or flow_id} not found")
|
||||
return self.load_flow(flow.id, tweaks)
|
||||
|
||||
def build(self):
|
||||
raise NotImplementedError
|
||||
239
src/backend/langflow/interface/custom/directory_reader.py
Normal file
239
src/backend/langflow/interface/custom/directory_reader.py
Normal file
|
|
@ -0,0 +1,239 @@
|
|||
import os
|
||||
import ast
|
||||
import zlib
|
||||
|
||||
|
||||
class CustomComponentPathValueError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class StringCompressor:
|
||||
def __init__(self, input_string):
|
||||
"""Initialize StringCompressor with a string to compress."""
|
||||
self.input_string = input_string
|
||||
|
||||
def compress_string(self):
|
||||
"""
|
||||
Compress the initial string and return the compressed data.
|
||||
"""
|
||||
# Convert string to bytes
|
||||
byte_data = self.input_string.encode("utf-8")
|
||||
# Compress the bytes
|
||||
self.compressed_data = zlib.compress(byte_data)
|
||||
|
||||
return self.compressed_data
|
||||
|
||||
def decompress_string(self):
|
||||
"""
|
||||
Decompress the compressed data and return the original string.
|
||||
"""
|
||||
# Decompress the bytes
|
||||
decompressed_data = zlib.decompress(self.compressed_data)
|
||||
# Convert bytes back to string
|
||||
return decompressed_data.decode("utf-8")
|
||||
|
||||
|
||||
class DirectoryReader:
|
||||
# Ensure the base path to read the files that contain
|
||||
# the custom components from this directory.
|
||||
base_path = ""
|
||||
|
||||
def __init__(self, directory_path, compress_code_field=False):
|
||||
"""
|
||||
Initialize DirectoryReader with a directory path
|
||||
and a flag indicating whether to compress the code.
|
||||
"""
|
||||
self.directory_path = directory_path
|
||||
self.compress_code_field = compress_code_field
|
||||
|
||||
def get_safe_path(self):
|
||||
"""Check if the path is valid and return it, or None if it's not."""
|
||||
return self.directory_path if self.is_valid_path() else None
|
||||
|
||||
def is_valid_path(self) -> bool:
|
||||
"""Check if the directory path is valid by comparing it to the base path."""
|
||||
fullpath = os.path.normpath(os.path.join(self.directory_path))
|
||||
return fullpath.startswith(self.base_path)
|
||||
|
||||
def is_empty_file(self, file_content):
|
||||
"""
|
||||
Check if the file content is empty.
|
||||
"""
|
||||
return len(file_content.strip()) == 0
|
||||
|
||||
def filter_loaded_components(self, data: dict, with_errors: bool) -> dict:
|
||||
items = [
|
||||
{
|
||||
"name": menu["name"],
|
||||
"path": menu["path"],
|
||||
"components": [
|
||||
component
|
||||
for component in menu["components"]
|
||||
if (component["error"] if with_errors else not component["error"])
|
||||
],
|
||||
}
|
||||
for menu in data["menu"]
|
||||
]
|
||||
filtred = [menu for menu in items if menu["components"]]
|
||||
return {"menu": filtred}
|
||||
|
||||
def validate_code(self, file_content):
|
||||
"""
|
||||
Validate the Python code by trying to parse it with ast.parse.
|
||||
"""
|
||||
try:
|
||||
ast.parse(file_content)
|
||||
return True
|
||||
except SyntaxError:
|
||||
return False
|
||||
|
||||
def validate_build(self, file_content):
|
||||
"""
|
||||
Check if the file content contains a function named 'build'.
|
||||
"""
|
||||
return "def build" in file_content
|
||||
|
||||
def read_file_content(self, file_path):
|
||||
"""
|
||||
Read and return the content of a file.
|
||||
"""
|
||||
if not os.path.isfile(file_path):
|
||||
return None
|
||||
with open(file_path, "r") as file:
|
||||
return file.read()
|
||||
|
||||
def get_files(self):
|
||||
"""
|
||||
Walk through the directory path and return a list of all .py files.
|
||||
"""
|
||||
if not (safe_path := self.get_safe_path()):
|
||||
raise CustomComponentPathValueError(
|
||||
f"The path needs to start with '{self.base_path}'."
|
||||
)
|
||||
|
||||
file_list = []
|
||||
for root, _, files in os.walk(safe_path):
|
||||
file_list.extend(
|
||||
os.path.join(root, filename)
|
||||
for filename in files
|
||||
if filename.endswith(".py")
|
||||
)
|
||||
return file_list
|
||||
|
||||
def find_menu(self, response, menu_name):
|
||||
"""
|
||||
Find and return a menu by its name in the response.
|
||||
"""
|
||||
return next(
|
||||
(menu for menu in response["menu"] if menu["name"] == menu_name),
|
||||
None,
|
||||
)
|
||||
|
||||
def _is_type_hint_imported(self, type_hint_name: str, code: str) -> bool:
|
||||
"""
|
||||
Check if a specific type hint is imported
|
||||
from the typing module in the given code.
|
||||
"""
|
||||
module = ast.parse(code)
|
||||
|
||||
return any(
|
||||
isinstance(node, ast.ImportFrom)
|
||||
and node.module == "typing"
|
||||
and any(alias.name == type_hint_name for alias in node.names)
|
||||
for node in ast.walk(module)
|
||||
)
|
||||
|
||||
def _is_type_hint_used_in_args(self, type_hint_name: str, code: str) -> bool:
|
||||
"""
|
||||
Check if a specific type hint is used in the
|
||||
function definitions within the given code.
|
||||
"""
|
||||
module = ast.parse(code)
|
||||
|
||||
for node in ast.walk(module):
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
for arg in node.args.args:
|
||||
if self._is_type_hint_in_arg_annotation(
|
||||
arg.annotation, type_hint_name
|
||||
):
|
||||
return True
|
||||
return False
|
||||
|
||||
def _is_type_hint_in_arg_annotation(self, annotation, type_hint_name: str) -> bool:
|
||||
"""
|
||||
Helper function to check if a type hint exists in an annotation.
|
||||
"""
|
||||
return (
|
||||
annotation is not None
|
||||
and isinstance(annotation, ast.Subscript)
|
||||
and isinstance(annotation.value, ast.Name)
|
||||
and annotation.value.id == type_hint_name
|
||||
)
|
||||
|
||||
def is_type_hint_used_but_not_imported(
|
||||
self, type_hint_name: str, code: str
|
||||
) -> bool:
|
||||
"""
|
||||
Check if a type hint is used but not imported in the given code.
|
||||
"""
|
||||
try:
|
||||
return self._is_type_hint_used_in_args(
|
||||
type_hint_name, code
|
||||
) and not self._is_type_hint_imported(type_hint_name, code)
|
||||
except SyntaxError:
|
||||
# Returns True if there's something wrong with the code
|
||||
# TODO : Find a better way to handle this
|
||||
return True
|
||||
|
||||
def process_file(self, file_path):
|
||||
"""
|
||||
Process a file by validating its content and
|
||||
returning the result and content/error message.
|
||||
"""
|
||||
file_content = self.read_file_content(file_path)
|
||||
|
||||
if file_content is None:
|
||||
return False, f"Could not read {file_path}"
|
||||
elif self.is_empty_file(file_content):
|
||||
return False, "Empty file"
|
||||
elif not self.validate_code(file_content):
|
||||
return False, "Syntax error"
|
||||
elif not self.validate_build(file_content):
|
||||
return False, "Missing build function"
|
||||
elif self.is_type_hint_used_but_not_imported("Optional", file_content):
|
||||
return False, "Type hint 'Optional' is used but not imported in the code."
|
||||
else:
|
||||
if self.compress_code_field:
|
||||
file_content = str(StringCompressor(file_content).compress_string())
|
||||
return True, file_content
|
||||
|
||||
def build_component_menu_list(self, file_paths):
|
||||
"""
|
||||
Build a list of menus with their components
|
||||
from the .py files in the directory.
|
||||
"""
|
||||
response = {"menu": []}
|
||||
|
||||
for file_path in file_paths:
|
||||
menu_name = os.path.basename(os.path.dirname(file_path))
|
||||
filename = os.path.basename(file_path)
|
||||
validation_result, result_content = self.process_file(file_path)
|
||||
|
||||
menu_result = self.find_menu(response, menu_name) or {
|
||||
"name": menu_name,
|
||||
"path": os.path.dirname(file_path),
|
||||
"components": [],
|
||||
}
|
||||
|
||||
component_info = {
|
||||
"name": filename.split(".")[0],
|
||||
"file": filename,
|
||||
"code": result_content if validation_result else "",
|
||||
"error": "" if validation_result else result_content,
|
||||
}
|
||||
menu_result["components"].append(component_info)
|
||||
|
||||
if menu_result not in response["menu"]:
|
||||
response["menu"].append(menu_result)
|
||||
|
||||
return response
|
||||
29
src/backend/langflow/interface/custom/schema.py
Normal file
29
src/backend/langflow/interface/custom/schema.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
from typing import Optional
|
||||
|
||||
|
||||
class ClassCodeDetails(BaseModel):
|
||||
"""
|
||||
A dataclass for storing details about a class.
|
||||
"""
|
||||
|
||||
name: str
|
||||
doc: Optional[str]
|
||||
bases: list
|
||||
attributes: list
|
||||
methods: list
|
||||
init: Optional[dict] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class CallableCodeDetails(BaseModel):
|
||||
"""
|
||||
A dataclass for storing details about a callable.
|
||||
"""
|
||||
|
||||
name: str
|
||||
doc: Optional[str]
|
||||
args: list
|
||||
body: list
|
||||
return_type: Optional[str]
|
||||
|
|
@ -10,8 +10,12 @@ from langchain import (
|
|||
text_splitter,
|
||||
)
|
||||
from langchain.agents import agent_toolkits
|
||||
from langchain.chat_models import AzureChatOpenAI, ChatOpenAI
|
||||
from langchain.chat_models import ChatAnthropic
|
||||
from langchain.chat_models import (
|
||||
AzureChatOpenAI,
|
||||
ChatOpenAI,
|
||||
ChatVertexAI,
|
||||
ChatAnthropic,
|
||||
)
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.interface.agents.custom import CUSTOM_AGENTS
|
||||
|
|
@ -22,6 +26,7 @@ llm_type_to_cls_dict = llms.type_to_cls_dict
|
|||
llm_type_to_cls_dict["anthropic-chat"] = ChatAnthropic # type: ignore
|
||||
llm_type_to_cls_dict["azure-chat"] = AzureChatOpenAI # type: ignore
|
||||
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
|
||||
llm_type_to_cls_dict["vertexai-chat"] = ChatVertexAI # type: ignore
|
||||
|
||||
|
||||
# Toolkits
|
||||
|
|
|
|||
|
|
@ -9,7 +9,9 @@ from langchain.base_language import BaseLanguageModel
|
|||
from langchain.chains.base import Chain
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.tools import BaseTool
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
from langflow.utils import validate
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
|
||||
|
||||
def import_module(module_path: str) -> Any:
|
||||
|
|
@ -44,6 +46,9 @@ def import_by_type(_type: str, name: str) -> Any:
|
|||
"documentloaders": import_documentloader,
|
||||
"textsplitters": import_textsplitter,
|
||||
"utilities": import_utility,
|
||||
"output_parsers": import_output_parser,
|
||||
"retrievers": import_retriever,
|
||||
"custom_components": import_custom_component,
|
||||
}
|
||||
if _type == "llms":
|
||||
key = "chat" if "chat" in name.lower() else "llm"
|
||||
|
|
@ -54,11 +59,28 @@ def import_by_type(_type: str, name: str) -> Any:
|
|||
return loaded_func(name)
|
||||
|
||||
|
||||
def import_custom_component(custom_component: str) -> CustomComponent:
|
||||
"""Import custom component from custom component name"""
|
||||
return import_class(
|
||||
f"langflow.interface.custom.custom_component.{custom_component}"
|
||||
)
|
||||
|
||||
|
||||
def import_output_parser(output_parser: str) -> Any:
|
||||
"""Import output parser from output parser name"""
|
||||
return import_module(f"from langchain.output_parsers import {output_parser}")
|
||||
|
||||
|
||||
def import_chat_llm(llm: str) -> BaseChatModel:
|
||||
"""Import chat llm from llm name"""
|
||||
return import_class(f"langchain.chat_models.{llm}")
|
||||
|
||||
|
||||
def import_retriever(retriever: str) -> Any:
|
||||
"""Import retriever from retriever name"""
|
||||
return import_module(f"from langchain.retrievers import {retriever}")
|
||||
|
||||
|
||||
def import_memory(memory: str) -> Any:
|
||||
"""Import memory from memory name"""
|
||||
return import_module(f"from langchain.memory import {memory}")
|
||||
|
|
@ -84,7 +106,11 @@ def import_prompt(prompt: str) -> Type[PromptTemplate]:
|
|||
|
||||
def import_wrapper(wrapper: str) -> Any:
|
||||
"""Import wrapper from wrapper name"""
|
||||
return import_module(f"from langchain.requests import {wrapper}")
|
||||
if (
|
||||
isinstance(wrapper_creator.type_dict, dict)
|
||||
and wrapper in wrapper_creator.type_dict
|
||||
):
|
||||
return wrapper_creator.type_dict.get(wrapper)
|
||||
|
||||
|
||||
def import_toolkit(toolkit: str) -> Any:
|
||||
|
|
@ -155,3 +181,8 @@ def get_function(code):
|
|||
function_name = validate.extract_function_name(code)
|
||||
|
||||
return validate.create_function(code, function_name)
|
||||
|
||||
|
||||
def get_function_custom(code):
|
||||
class_name = validate.extract_class_name(code)
|
||||
return validate.create_class(code, class_name)
|
||||
|
|
|
|||
0
src/backend/langflow/interface/initialize/__init__.py
Normal file
0
src/backend/langflow/interface/initialize/__init__.py
Normal file
9
src/backend/langflow/interface/initialize/llm.py
Normal file
9
src/backend/langflow/interface/initialize/llm.py
Normal file
|
|
@ -0,0 +1,9 @@
|
|||
def initialize_vertexai(class_object, params):
|
||||
if credentials_path := params.get("credentials"):
|
||||
from google.oauth2 import service_account # type: ignore
|
||||
|
||||
credentials_object = service_account.Credentials.from_service_account_file(
|
||||
filename=credentials_path
|
||||
)
|
||||
params["credentials"] = credentials_object
|
||||
return class_object(**params)
|
||||
474
src/backend/langflow/interface/initialize/loading.py
Normal file
474
src/backend/langflow/interface/initialize/loading.py
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
import json
|
||||
from typing import Any, Callable, Dict, Sequence, Type
|
||||
|
||||
from langchain.agents import agent as agent_module
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.agents.tools import BaseTool
|
||||
from langflow.interface.initialize.llm import initialize_vertexai
|
||||
from langflow.interface.initialize.utils import handle_format_kwargs, handle_node_type
|
||||
|
||||
from langflow.interface.initialize.vector_store import vecstore_initializer
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.interface.importing.utils import (
|
||||
get_function,
|
||||
get_function_custom,
|
||||
import_by_type,
|
||||
)
|
||||
from langflow.interface.custom_lists import CUSTOM_NODES
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.utils import load_file_into_dict
|
||||
from langflow.utils import validate
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from langchain.document_loaders.base import BaseLoader
|
||||
|
||||
|
||||
def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
if node_type in CUSTOM_NODES:
|
||||
if custom_node := CUSTOM_NODES.get(node_type):
|
||||
if hasattr(custom_node, "initialize"):
|
||||
return custom_node.initialize(**params)
|
||||
return custom_node(**params)
|
||||
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
return instantiate_based_on_type(class_object, base_type, node_type, params)
|
||||
|
||||
|
||||
def convert_params_to_sets(params):
|
||||
"""Convert certain params to sets"""
|
||||
if "allowed_special" in params:
|
||||
params["allowed_special"] = set(params["allowed_special"])
|
||||
if "disallowed_special" in params:
|
||||
params["disallowed_special"] = set(params["disallowed_special"])
|
||||
return params
|
||||
|
||||
|
||||
def convert_kwargs(params):
|
||||
# if *kwargs are passed as a string, convert to dict
|
||||
# first find any key that has kwargs or config in it
|
||||
kwargs_keys = [key for key in params.keys() if "kwargs" in key or "config" in key]
|
||||
for key in kwargs_keys:
|
||||
if isinstance(params[key], str):
|
||||
try:
|
||||
params[key] = json.loads(params[key])
|
||||
except json.JSONDecodeError:
|
||||
# if the string is not a valid json string, we will
|
||||
# remove the key from the params
|
||||
params.pop(key, None)
|
||||
return params
|
||||
|
||||
|
||||
def instantiate_based_on_type(class_object, base_type, node_type, params):
|
||||
if base_type == "agents":
|
||||
return instantiate_agent(node_type, class_object, params)
|
||||
elif base_type == "prompts":
|
||||
return instantiate_prompt(node_type, class_object, params)
|
||||
elif base_type == "tools":
|
||||
tool = instantiate_tool(node_type, class_object, params)
|
||||
if hasattr(tool, "name") and isinstance(tool, BaseTool):
|
||||
# tool name shouldn't contain spaces
|
||||
tool.name = tool.name.replace(" ", "_")
|
||||
return tool
|
||||
elif base_type == "toolkits":
|
||||
return instantiate_toolkit(node_type, class_object, params)
|
||||
elif base_type == "embeddings":
|
||||
return instantiate_embedding(class_object, params)
|
||||
elif base_type == "vectorstores":
|
||||
return instantiate_vectorstore(class_object, params)
|
||||
elif base_type == "documentloaders":
|
||||
return instantiate_documentloader(class_object, params)
|
||||
elif base_type == "textsplitters":
|
||||
return instantiate_textsplitter(class_object, params)
|
||||
elif base_type == "utilities":
|
||||
return instantiate_utility(node_type, class_object, params)
|
||||
elif base_type == "chains":
|
||||
return instantiate_chains(node_type, class_object, params)
|
||||
elif base_type == "output_parsers":
|
||||
return instantiate_output_parser(node_type, class_object, params)
|
||||
elif base_type == "llms":
|
||||
return instantiate_llm(node_type, class_object, params)
|
||||
elif base_type == "retrievers":
|
||||
return instantiate_retriever(node_type, class_object, params)
|
||||
elif base_type == "memory":
|
||||
return instantiate_memory(node_type, class_object, params)
|
||||
elif base_type == "custom_components":
|
||||
return instantiate_custom_component(node_type, class_object, params)
|
||||
elif base_type == "wrappers":
|
||||
return instantiate_wrapper(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_custom_component(node_type, class_object, params):
|
||||
class_object = get_function_custom(params.pop("code"))
|
||||
custom_component = class_object()
|
||||
built_object = custom_component.build(**params)
|
||||
return built_object, {"repr": custom_component.custom_repr()}
|
||||
|
||||
|
||||
def instantiate_wrapper(node_type, class_object, params):
|
||||
if node_type in wrapper_creator.from_method_nodes:
|
||||
method = wrapper_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
return class_method(**params)
|
||||
raise ValueError(f"Method {method} not found in {class_object}")
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_output_parser(node_type, class_object, params):
|
||||
if node_type in output_parser_creator.from_method_nodes:
|
||||
method = output_parser_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
return class_method(**params)
|
||||
raise ValueError(f"Method {method} not found in {class_object}")
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_llm(node_type, class_object, params: Dict):
|
||||
# This is a workaround so JinaChat works until streaming is implemented
|
||||
# if "openai_api_base" in params and "jina" in params["openai_api_base"]:
|
||||
# False if condition is True
|
||||
if node_type == "VertexAI":
|
||||
return initialize_vertexai(class_object=class_object, params=params)
|
||||
# max_tokens sometimes is a string and should be an int
|
||||
if "max_tokens" in params:
|
||||
if isinstance(params["max_tokens"], str) and params["max_tokens"].isdigit():
|
||||
params["max_tokens"] = int(params["max_tokens"])
|
||||
elif not isinstance(params.get("max_tokens"), int):
|
||||
params.pop("max_tokens", None)
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_memory(node_type, class_object, params):
|
||||
# process input_key and output_key to remove them if
|
||||
# they are empty strings
|
||||
if node_type == "ConversationEntityMemory":
|
||||
params.pop("memory_key", None)
|
||||
|
||||
for key in ["input_key", "output_key"]:
|
||||
if key in params and (params[key] == "" or not params[key]):
|
||||
params.pop(key)
|
||||
|
||||
try:
|
||||
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
|
||||
params["retriever"] = params["retriever"].as_retriever()
|
||||
return class_object(**params)
|
||||
# I want to catch a specific attribute error that happens
|
||||
# when the object does not have a cursor attribute
|
||||
except Exception as exc:
|
||||
if "object has no attribute 'cursor'" in str(
|
||||
exc
|
||||
) or 'object has no field "conn"' in str(exc):
|
||||
raise AttributeError(
|
||||
(
|
||||
"Failed to build connection to database."
|
||||
f" Please check your connection string and try again. Error: {exc}"
|
||||
)
|
||||
) from exc
|
||||
raise exc
|
||||
|
||||
|
||||
def instantiate_retriever(node_type, class_object, params):
|
||||
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
|
||||
params["retriever"] = params["retriever"].as_retriever()
|
||||
if node_type in retriever_creator.from_method_nodes:
|
||||
method = retriever_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
return class_method(**params)
|
||||
raise ValueError(f"Method {method} not found in {class_object}")
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_chains(node_type, class_object: Type[Chain], params: Dict):
|
||||
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
|
||||
params["retriever"] = params["retriever"].as_retriever()
|
||||
if node_type in chain_creator.from_method_nodes:
|
||||
method = chain_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
return class_method(**params)
|
||||
raise ValueError(f"Method {method} not found in {class_object}")
|
||||
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params: Dict):
|
||||
if node_type in agent_creator.from_method_nodes:
|
||||
method = agent_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
agent = class_method(**params)
|
||||
tools = params.get("tools", [])
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent, tools=tools, handle_parsing_errors=True
|
||||
)
|
||||
return load_agent_executor(class_object, params)
|
||||
|
||||
|
||||
def instantiate_prompt(node_type, class_object, params: Dict):
|
||||
params, prompt = handle_node_type(node_type, class_object, params)
|
||||
format_kwargs = handle_format_kwargs(prompt, params)
|
||||
return prompt, format_kwargs
|
||||
|
||||
|
||||
def instantiate_tool(node_type, class_object: Type[BaseTool], params: Dict):
|
||||
if node_type == "JsonSpec":
|
||||
if file_dict := load_file_into_dict(params.pop("path")):
|
||||
params["dict_"] = file_dict
|
||||
else:
|
||||
raise ValueError("Invalid file")
|
||||
return class_object(**params)
|
||||
elif node_type == "PythonFunctionTool":
|
||||
params["func"] = get_function(params.get("code"))
|
||||
return class_object(**params)
|
||||
elif node_type == "PythonFunction":
|
||||
function_string = params["code"]
|
||||
if isinstance(function_string, str):
|
||||
return validate.eval_function(function_string)
|
||||
raise ValueError("Function should be a string")
|
||||
elif node_type.lower() == "tool":
|
||||
return class_object(**params)
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_toolkit(node_type, class_object: Type[BaseToolkit], params: Dict):
|
||||
loaded_toolkit = class_object(**params)
|
||||
# Commenting this out for now to use toolkits as normal tools
|
||||
# if toolkits_creator.has_create_function(node_type):
|
||||
# return load_toolkits_executor(node_type, loaded_toolkit, params)
|
||||
if isinstance(loaded_toolkit, BaseToolkit):
|
||||
return loaded_toolkit.get_tools()
|
||||
return loaded_toolkit
|
||||
|
||||
|
||||
def instantiate_embedding(class_object, params: Dict):
|
||||
params.pop("model", None)
|
||||
params.pop("headers", None)
|
||||
try:
|
||||
return class_object(**params)
|
||||
except ValidationError:
|
||||
params = {
|
||||
key: value
|
||||
for key, value in params.items()
|
||||
if key in class_object.__fields__
|
||||
}
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
|
||||
search_kwargs = params.pop("search_kwargs", {})
|
||||
if initializer := vecstore_initializer.get(class_object.__name__):
|
||||
vecstore = initializer(class_object, params)
|
||||
else:
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
vecstore = class_object.from_documents(**params)
|
||||
|
||||
# ! This might not work. Need to test
|
||||
if search_kwargs and hasattr(vecstore, "as_retriever"):
|
||||
vecstore = vecstore.as_retriever(search_kwargs=search_kwargs)
|
||||
|
||||
return vecstore
|
||||
|
||||
|
||||
def instantiate_documentloader(class_object: Type[BaseLoader], params: Dict):
|
||||
if "file_filter" in params:
|
||||
# file_filter will be a string but we need a function
|
||||
# that will be used to filter the files using file_filter
|
||||
# like lambda x: x.endswith(".txt") but as we don't know
|
||||
# anything besides the string, we will simply check if the string is
|
||||
# in x and if it is, we will return True
|
||||
file_filter = params.pop("file_filter")
|
||||
extensions = file_filter.split(",")
|
||||
params["file_filter"] = lambda x: any(
|
||||
extension.strip() in x for extension in extensions
|
||||
)
|
||||
metadata = params.pop("metadata", None)
|
||||
if metadata and isinstance(metadata, str):
|
||||
try:
|
||||
metadata = json.loads(metadata)
|
||||
except json.JSONDecodeError as exc:
|
||||
raise ValueError(
|
||||
"The metadata you provided is not a valid JSON string."
|
||||
) from exc
|
||||
docs = class_object(**params).load()
|
||||
# Now if metadata is an empty dict, we will not add it to the documents
|
||||
if metadata:
|
||||
for doc in docs:
|
||||
# If the document already has metadata, we will not overwrite it
|
||||
if not doc.metadata:
|
||||
doc.metadata = metadata
|
||||
else:
|
||||
doc.metadata.update(metadata)
|
||||
|
||||
return docs
|
||||
|
||||
|
||||
def instantiate_textsplitter(
|
||||
class_object,
|
||||
params: Dict,
|
||||
):
|
||||
try:
|
||||
documents = params.pop("documents")
|
||||
if not isinstance(documents, list):
|
||||
documents = [documents]
|
||||
except KeyError as exc:
|
||||
raise ValueError(
|
||||
"The source you provided did not load correctly or was empty."
|
||||
"Try changing the chunk_size of the Text Splitter."
|
||||
) from exc
|
||||
|
||||
if (
|
||||
"separator_type" in params and params["separator_type"] == "Text"
|
||||
) or "separator_type" not in params:
|
||||
params.pop("separator_type", None)
|
||||
# separators might come in as an escaped string like \\n
|
||||
# so we need to convert it to a string
|
||||
if "separators" in params:
|
||||
params["separators"] = (
|
||||
params["separators"].encode().decode("unicode-escape")
|
||||
)
|
||||
text_splitter = class_object(**params)
|
||||
else:
|
||||
from langchain.text_splitter import Language
|
||||
|
||||
language = params.pop("separator_type", None)
|
||||
params["language"] = Language(language)
|
||||
params.pop("separators", None)
|
||||
|
||||
text_splitter = class_object.from_language(**params)
|
||||
return text_splitter.split_documents(documents)
|
||||
|
||||
|
||||
def instantiate_utility(node_type, class_object, params: Dict):
|
||||
if node_type == "SQLDatabase":
|
||||
return class_object.from_uri(params.pop("uri"))
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def replace_zero_shot_prompt_with_prompt_template(nodes):
|
||||
"""Replace ZeroShotPrompt with PromptTemplate"""
|
||||
for node in nodes:
|
||||
if node["data"]["type"] == "ZeroShotPrompt":
|
||||
# Build Prompt Template
|
||||
tools = [
|
||||
tool
|
||||
for tool in nodes
|
||||
if tool["type"] != "chatOutputNode"
|
||||
and "Tool" in tool["data"]["node"]["base_classes"]
|
||||
]
|
||||
node["data"] = build_prompt_template(prompt=node["data"], tools=tools)
|
||||
break
|
||||
return nodes
|
||||
|
||||
|
||||
def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs):
|
||||
"""Load agent executor from agent class, tools and chain"""
|
||||
allowed_tools: Sequence[BaseTool] = params.get("allowed_tools", [])
|
||||
llm_chain = params["llm_chain"]
|
||||
# agent has hidden args for memory. might need to be support
|
||||
# memory = params["memory"]
|
||||
# if allowed_tools is not a list or set, make it a list
|
||||
if not isinstance(allowed_tools, (list, set)) and isinstance(
|
||||
allowed_tools, BaseTool
|
||||
):
|
||||
allowed_tools = [allowed_tools]
|
||||
tool_names = [tool.name for tool in allowed_tools]
|
||||
# Agent class requires an output_parser but Agent classes
|
||||
# have a default output_parser.
|
||||
agent = agent_class(allowed_tools=tool_names, llm_chain=llm_chain) # type: ignore
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=allowed_tools,
|
||||
handle_parsing_errors=True,
|
||||
# memory=memory,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_toolkits_executor(node_type: str, toolkit: BaseToolkit, params: dict):
|
||||
create_function: Callable = toolkits_creator.get_create_function(node_type)
|
||||
if llm := params.get("llm"):
|
||||
return create_function(llm=llm, toolkit=toolkit)
|
||||
|
||||
|
||||
def build_prompt_template(prompt, tools):
|
||||
"""Build PromptTemplate from ZeroShotPrompt"""
|
||||
prefix = prompt["node"]["template"]["prefix"]["value"]
|
||||
suffix = prompt["node"]["template"]["suffix"]["value"]
|
||||
format_instructions = prompt["node"]["template"]["format_instructions"]["value"]
|
||||
|
||||
tool_strings = "\n".join(
|
||||
[
|
||||
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
value = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
|
||||
prompt["type"] = "PromptTemplate"
|
||||
|
||||
prompt["node"] = {
|
||||
"template": {
|
||||
"_type": "prompt",
|
||||
"input_variables": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"list": True,
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
},
|
||||
"output_parser": {
|
||||
"type": "BaseOutputParser",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": None,
|
||||
},
|
||||
"template": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": True,
|
||||
"multiline": True,
|
||||
"value": value,
|
||||
},
|
||||
"template_format": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": "f-string",
|
||||
},
|
||||
"validate_template": {
|
||||
"type": "bool",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": True,
|
||||
},
|
||||
},
|
||||
"description": "Schema to represent a prompt for an LLM.",
|
||||
"base_classes": ["BasePromptTemplate"],
|
||||
}
|
||||
|
||||
return prompt
|
||||
103
src/backend/langflow/interface/initialize/utils.py
Normal file
103
src/backend/langflow/interface/initialize/utils.py
Normal file
|
|
@ -0,0 +1,103 @@
|
|||
import contextlib
|
||||
import json
|
||||
from typing import Any, Dict, List
|
||||
|
||||
from langchain.agents import ZeroShotAgent
|
||||
|
||||
|
||||
from langchain.schema import Document, BaseOutputParser
|
||||
|
||||
|
||||
def handle_node_type(node_type, class_object, params: Dict):
|
||||
if node_type == "ZeroShotPrompt":
|
||||
params = check_tools_in_params(params)
|
||||
prompt = ZeroShotAgent.create_prompt(**params)
|
||||
elif "MessagePromptTemplate" in node_type:
|
||||
prompt = instantiate_from_template(class_object, params)
|
||||
elif node_type == "ChatPromptTemplate":
|
||||
prompt = class_object.from_messages(**params)
|
||||
else:
|
||||
prompt = class_object(**params)
|
||||
return params, prompt
|
||||
|
||||
|
||||
def check_tools_in_params(params: Dict):
|
||||
if "tools" not in params:
|
||||
params["tools"] = []
|
||||
return params
|
||||
|
||||
|
||||
def instantiate_from_template(class_object, params: Dict):
|
||||
from_template_params = {
|
||||
"template": params.pop("prompt", params.pop("template", ""))
|
||||
}
|
||||
if not from_template_params.get("template"):
|
||||
raise ValueError("Prompt template is required")
|
||||
return class_object.from_template(**from_template_params)
|
||||
|
||||
|
||||
def handle_format_kwargs(prompt, params: Dict):
|
||||
format_kwargs: Dict[str, Any] = {}
|
||||
for input_variable in prompt.input_variables:
|
||||
if input_variable in params:
|
||||
format_kwargs = handle_variable(params, input_variable, format_kwargs)
|
||||
return format_kwargs
|
||||
|
||||
|
||||
def handle_variable(params: Dict, input_variable: str, format_kwargs: Dict):
|
||||
variable = params[input_variable]
|
||||
if isinstance(variable, str):
|
||||
format_kwargs[input_variable] = variable
|
||||
elif isinstance(variable, BaseOutputParser) and hasattr(
|
||||
variable, "get_format_instructions"
|
||||
):
|
||||
format_kwargs[input_variable] = variable.get_format_instructions()
|
||||
elif is_instance_of_list_or_document(variable):
|
||||
format_kwargs = format_document(variable, input_variable, format_kwargs)
|
||||
if needs_handle_keys(variable):
|
||||
format_kwargs = add_handle_keys(input_variable, format_kwargs)
|
||||
return format_kwargs
|
||||
|
||||
|
||||
def is_instance_of_list_or_document(variable):
|
||||
return (
|
||||
isinstance(variable, List)
|
||||
and all(isinstance(item, Document) for item in variable)
|
||||
or isinstance(variable, Document)
|
||||
)
|
||||
|
||||
|
||||
def format_document(variable, input_variable: str, format_kwargs: Dict):
|
||||
variable = variable if isinstance(variable, List) else [variable]
|
||||
content = format_content(variable)
|
||||
format_kwargs[input_variable] = content
|
||||
return format_kwargs
|
||||
|
||||
|
||||
def format_content(variable):
|
||||
if len(variable) > 1:
|
||||
return "\n".join([item.page_content for item in variable if item.page_content])
|
||||
content = variable[0].page_content
|
||||
return try_to_load_json(content)
|
||||
|
||||
|
||||
def try_to_load_json(content):
|
||||
with contextlib.suppress(json.JSONDecodeError):
|
||||
content = json.loads(content)
|
||||
if isinstance(content, list):
|
||||
content = ",".join([str(item) for item in content])
|
||||
return content
|
||||
|
||||
|
||||
def needs_handle_keys(variable):
|
||||
return is_instance_of_list_or_document(variable) or (
|
||||
isinstance(variable, BaseOutputParser)
|
||||
and hasattr(variable, "get_format_instructions")
|
||||
)
|
||||
|
||||
|
||||
def add_handle_keys(input_variable: str, format_kwargs: Dict):
|
||||
if "handle_keys" not in format_kwargs:
|
||||
format_kwargs["handle_keys"] = []
|
||||
format_kwargs["handle_keys"].append(input_variable)
|
||||
return format_kwargs
|
||||
223
src/backend/langflow/interface/initialize/vector_store.py
Normal file
223
src/backend/langflow/interface/initialize/vector_store.py
Normal file
|
|
@ -0,0 +1,223 @@
|
|||
import json
|
||||
from typing import Any, Callable, Dict, Type
|
||||
from langchain.vectorstores import (
|
||||
Pinecone,
|
||||
Qdrant,
|
||||
Chroma,
|
||||
FAISS,
|
||||
Weaviate,
|
||||
SupabaseVectorStore,
|
||||
MongoDBAtlasVectorSearch,
|
||||
)
|
||||
|
||||
import os
|
||||
|
||||
|
||||
def docs_in_params(params: dict) -> bool:
|
||||
"""Check if params has documents OR texts and one of them is not an empty list,
|
||||
If any of them is not an empty list, return True, else return False"""
|
||||
return ("documents" in params and params["documents"]) or (
|
||||
"texts" in params and params["texts"]
|
||||
)
|
||||
|
||||
|
||||
def initialize_mongodb(class_object: Type[MongoDBAtlasVectorSearch], params: dict):
|
||||
"""Initialize mongodb and return the class object"""
|
||||
|
||||
MONGODB_ATLAS_CLUSTER_URI = params.pop("mongodb_atlas_cluster_uri")
|
||||
if not MONGODB_ATLAS_CLUSTER_URI:
|
||||
raise ValueError("Mongodb atlas cluster uri must be provided in the params")
|
||||
from pymongo import MongoClient
|
||||
import certifi
|
||||
|
||||
client: MongoClient = MongoClient(
|
||||
MONGODB_ATLAS_CLUSTER_URI, tlsCAFile=certifi.where()
|
||||
)
|
||||
db_name = params.pop("db_name", None)
|
||||
collection_name = params.pop("collection_name", None)
|
||||
if not db_name or not collection_name:
|
||||
raise ValueError("db_name and collection_name must be provided in the params")
|
||||
|
||||
index_name = params.pop("index_name", None)
|
||||
if not index_name:
|
||||
raise ValueError("index_name must be provided in the params")
|
||||
|
||||
collection = client[db_name][collection_name]
|
||||
if not docs_in_params(params):
|
||||
# __init__ requires collection, embedding and index_name
|
||||
init_args = {
|
||||
"collection": collection,
|
||||
"index_name": index_name,
|
||||
"embedding": params.get("embedding"),
|
||||
}
|
||||
|
||||
return class_object(**init_args)
|
||||
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
|
||||
params["collection"] = collection
|
||||
params["index_name"] = index_name
|
||||
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def initialize_supabase(class_object: Type[SupabaseVectorStore], params: dict):
|
||||
"""Initialize supabase and return the class object"""
|
||||
from supabase.client import Client, create_client
|
||||
|
||||
if "supabase_url" not in params or "supabase_service_key" not in params:
|
||||
raise ValueError("Supabase url and service key must be provided in the params")
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
|
||||
client_kwargs = {
|
||||
"supabase_url": params.pop("supabase_url"),
|
||||
"supabase_key": params.pop("supabase_service_key"),
|
||||
}
|
||||
|
||||
supabase: Client = create_client(**client_kwargs)
|
||||
if not docs_in_params(params):
|
||||
params.pop("documents", None)
|
||||
params.pop("texts", None)
|
||||
return class_object(client=supabase, **params)
|
||||
# If there are docs in the params, create a new index
|
||||
|
||||
return class_object.from_documents(client=supabase, **params)
|
||||
|
||||
|
||||
def initialize_weaviate(class_object: Type[Weaviate], params: dict):
|
||||
"""Initialize weaviate and return the class object"""
|
||||
if not docs_in_params(params):
|
||||
import weaviate # type: ignore
|
||||
|
||||
client_kwargs_json = params.get("client_kwargs", "{}")
|
||||
client_kwargs = json.loads(client_kwargs_json)
|
||||
client_params = {
|
||||
"url": params.get("weaviate_url"),
|
||||
}
|
||||
client_params.update(client_kwargs)
|
||||
weaviate_client = weaviate.Client(**client_params)
|
||||
|
||||
new_params = {
|
||||
"client": weaviate_client,
|
||||
"index_name": params.get("index_name"),
|
||||
"text_key": params.get("text_key"),
|
||||
}
|
||||
return class_object(**new_params)
|
||||
# If there are docs in the params, create a new index
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def initialize_faiss(class_object: Type[FAISS], params: dict):
|
||||
"""Initialize faiss and return the class object"""
|
||||
|
||||
if not docs_in_params(params):
|
||||
return class_object.load_local
|
||||
|
||||
save_local = params.get("save_local")
|
||||
faiss_index = class_object(**params)
|
||||
if save_local:
|
||||
faiss_index.save_local(folder_path=save_local)
|
||||
return faiss_index
|
||||
|
||||
|
||||
def initialize_pinecone(class_object: Type[Pinecone], params: dict):
|
||||
"""Initialize pinecone and return the class object"""
|
||||
|
||||
import pinecone # type: ignore
|
||||
|
||||
pinecone_api_key = params.get("pinecone_api_key")
|
||||
pinecone_env = params.get("pinecone_env")
|
||||
|
||||
if pinecone_api_key is None or pinecone_env is None:
|
||||
if os.getenv("PINECONE_API_KEY") is not None:
|
||||
pinecone_api_key = os.getenv("PINECONE_API_KEY")
|
||||
if os.getenv("PINECONE_ENV") is not None:
|
||||
pinecone_env = os.getenv("PINECONE_ENV")
|
||||
|
||||
if pinecone_api_key is None or pinecone_env is None:
|
||||
raise ValueError(
|
||||
"Pinecone API key and environment must be provided in the params"
|
||||
)
|
||||
|
||||
# initialize pinecone
|
||||
pinecone.init(
|
||||
api_key=pinecone_api_key, # find at app.pinecone.io
|
||||
environment=pinecone_env, # next to api key in console
|
||||
)
|
||||
|
||||
# If there are no docs in the params, return an existing index
|
||||
# but first remove any texts or docs keys from the params
|
||||
if not docs_in_params(params):
|
||||
existing_index_params = {
|
||||
"embedding": params.pop("embedding"),
|
||||
}
|
||||
if "index_name" in params:
|
||||
existing_index_params["index_name"] = params.pop("index_name")
|
||||
if "namespace" in params:
|
||||
existing_index_params["namespace"] = params.pop("namespace")
|
||||
|
||||
return class_object.from_existing_index(**existing_index_params)
|
||||
# If there are docs in the params, create a new index
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def initialize_chroma(class_object: Type[Chroma], params: dict):
|
||||
"""Initialize a ChromaDB object from the params"""
|
||||
persist = params.pop("persist", False)
|
||||
if not docs_in_params(params):
|
||||
params.pop("documents", None)
|
||||
params.pop("texts", None)
|
||||
params["embedding_function"] = params.pop("embedding")
|
||||
chromadb = class_object(**params)
|
||||
else:
|
||||
if "texts" in params:
|
||||
params["documents"] = params.pop("texts")
|
||||
for doc in params["documents"]:
|
||||
if doc.metadata is None:
|
||||
doc.metadata = {}
|
||||
for key, value in doc.metadata.items():
|
||||
if value is None:
|
||||
doc.metadata[key] = ""
|
||||
chromadb = class_object.from_documents(**params)
|
||||
if persist:
|
||||
chromadb.persist()
|
||||
return chromadb
|
||||
|
||||
|
||||
def initialize_qdrant(class_object: Type[Qdrant], params: dict):
|
||||
if not docs_in_params(params):
|
||||
if "location" not in params and "api_key" not in params:
|
||||
raise ValueError("Location and API key must be provided in the params")
|
||||
from qdrant_client import QdrantClient
|
||||
|
||||
client_params = {
|
||||
"location": params.pop("location"),
|
||||
"api_key": params.pop("api_key"),
|
||||
}
|
||||
lc_params = {
|
||||
"collection_name": params.pop("collection_name"),
|
||||
"embeddings": params.pop("embedding"),
|
||||
}
|
||||
client = QdrantClient(**client_params)
|
||||
|
||||
return class_object(client=client, **lc_params)
|
||||
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
vecstore_initializer: Dict[str, Callable[[Type[Any], dict], Any]] = {
|
||||
"Pinecone": initialize_pinecone,
|
||||
"Chroma": initialize_chroma,
|
||||
"Qdrant": initialize_qdrant,
|
||||
"Weaviate": initialize_weaviate,
|
||||
"FAISS": initialize_faiss,
|
||||
"SupabaseVectorStore": initialize_supabase,
|
||||
"MongoDBAtlasVectorSearch": initialize_mongodb,
|
||||
}
|
||||
|
|
@ -11,6 +11,9 @@ from langflow.interface.tools.base import tool_creator
|
|||
from langflow.interface.utilities.base import utility_creator
|
||||
from langflow.interface.vector_store.base import vectorstore_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
|
||||
|
||||
def get_type_dict():
|
||||
|
|
@ -28,6 +31,9 @@ def get_type_dict():
|
|||
"embeddings": embedding_creator.to_list(),
|
||||
"textSplitters": textsplitter_creator.to_list(),
|
||||
"utilities": utility_creator.to_list(),
|
||||
"outputParsers": output_parser_creator.to_list(),
|
||||
"retrievers": retriever_creator.to_list(),
|
||||
"custom_components": custom_component_creator.to_list(),
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -1,380 +0,0 @@
|
|||
import json
|
||||
from typing import Any, Callable, Dict, Optional
|
||||
|
||||
from langchain.agents import ZeroShotAgent
|
||||
from langchain.agents import agent as agent_module
|
||||
from langchain.agents.agent import AgentExecutor
|
||||
from langchain.agents.agent_toolkits.base import BaseToolkit
|
||||
from langchain.agents.load_tools import (
|
||||
_BASE_TOOLS,
|
||||
_EXTRA_LLM_TOOLS,
|
||||
_EXTRA_OPTIONAL_TOOLS,
|
||||
_LLM_TOOLS,
|
||||
)
|
||||
from langchain.agents.loading import load_agent_from_config
|
||||
from langchain.agents.tools import Tool
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from langchain.callbacks.base import BaseCallbackManager
|
||||
from langchain.chains.loading import load_chain_from_config
|
||||
from langchain.llms.loading import load_llm_from_config
|
||||
from pydantic import ValidationError
|
||||
|
||||
from langflow.interface.custom_lists import CUSTOM_NODES
|
||||
from langflow.interface.importing.utils import get_function, import_by_type
|
||||
from langflow.interface.toolkits.base import toolkits_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.types import get_type_list
|
||||
from langflow.interface.utils import load_file_into_dict
|
||||
from langflow.utils import util, validate
|
||||
|
||||
|
||||
def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
|
||||
"""Instantiate class from module type and key, and params"""
|
||||
params = convert_params_to_sets(params)
|
||||
params = convert_kwargs(params)
|
||||
if node_type in CUSTOM_NODES:
|
||||
if custom_node := CUSTOM_NODES.get(node_type):
|
||||
if hasattr(custom_node, "initialize"):
|
||||
return custom_node.initialize(**params)
|
||||
return custom_node(**params)
|
||||
|
||||
class_object = import_by_type(_type=base_type, name=node_type)
|
||||
return instantiate_based_on_type(class_object, base_type, node_type, params)
|
||||
|
||||
|
||||
def convert_params_to_sets(params):
|
||||
"""Convert certain params to sets"""
|
||||
if "allowed_special" in params:
|
||||
params["allowed_special"] = set(params["allowed_special"])
|
||||
if "disallowed_special" in params:
|
||||
params["disallowed_special"] = set(params["disallowed_special"])
|
||||
return params
|
||||
|
||||
|
||||
def convert_kwargs(params):
|
||||
# if *kwargs are passed as a string, convert to dict
|
||||
# first find any key that has kwargs in it
|
||||
kwargs_keys = [key for key in params.keys() if "kwargs" in key]
|
||||
for key in kwargs_keys:
|
||||
if isinstance(params[key], str):
|
||||
params[key] = json.loads(params[key])
|
||||
return params
|
||||
|
||||
|
||||
def instantiate_based_on_type(class_object, base_type, node_type, params):
|
||||
if base_type == "agents":
|
||||
return instantiate_agent(class_object, params)
|
||||
elif base_type == "prompts":
|
||||
return instantiate_prompt(node_type, class_object, params)
|
||||
elif base_type == "tools":
|
||||
return instantiate_tool(node_type, class_object, params)
|
||||
elif base_type == "toolkits":
|
||||
return instantiate_toolkit(node_type, class_object, params)
|
||||
elif base_type == "embeddings":
|
||||
return instantiate_embedding(class_object, params)
|
||||
elif base_type == "vectorstores":
|
||||
return instantiate_vectorstore(class_object, params)
|
||||
elif base_type == "documentloaders":
|
||||
return instantiate_documentloader(class_object, params)
|
||||
elif base_type == "textsplitters":
|
||||
return instantiate_textsplitter(class_object, params)
|
||||
elif base_type == "utilities":
|
||||
return instantiate_utility(node_type, class_object, params)
|
||||
elif base_type == "chains":
|
||||
return instantiate_chains(node_type, class_object, params)
|
||||
else:
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_chains(node_type, class_object, params):
|
||||
if "retriever" in params and hasattr(params["retriever"], "as_retriever"):
|
||||
params["retriever"] = params["retriever"].as_retriever()
|
||||
if node_type in chain_creator.from_method_nodes:
|
||||
method = chain_creator.from_method_nodes[node_type]
|
||||
if class_method := getattr(class_object, method, None):
|
||||
return class_method(**params)
|
||||
raise ValueError(f"Method {method} not found in {class_object}")
|
||||
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_agent(class_object, params):
|
||||
return load_agent_executor(class_object, params)
|
||||
|
||||
|
||||
def instantiate_prompt(node_type, class_object, params):
|
||||
if node_type == "ZeroShotPrompt":
|
||||
if "tools" not in params:
|
||||
params["tools"] = []
|
||||
return ZeroShotAgent.create_prompt(**params)
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_tool(node_type, class_object, params):
|
||||
if node_type == "JsonSpec":
|
||||
params["dict_"] = load_file_into_dict(params.pop("path"))
|
||||
return class_object(**params)
|
||||
elif node_type == "PythonFunctionTool":
|
||||
params["func"] = get_function(params.get("code"))
|
||||
return class_object(**params)
|
||||
# For backward compatibility
|
||||
elif node_type == "PythonFunction":
|
||||
function_string = params["code"]
|
||||
if isinstance(function_string, str):
|
||||
return validate.eval_function(function_string)
|
||||
raise ValueError("Function should be a string")
|
||||
elif node_type.lower() == "tool":
|
||||
return class_object(**params)
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_toolkit(node_type, class_object, params):
|
||||
loaded_toolkit = class_object(**params)
|
||||
# Commenting this out for now to use toolkits as normal tools
|
||||
# if toolkits_creator.has_create_function(node_type):
|
||||
# return load_toolkits_executor(node_type, loaded_toolkit, params)
|
||||
if isinstance(loaded_toolkit, BaseToolkit):
|
||||
return loaded_toolkit.get_tools()
|
||||
return loaded_toolkit
|
||||
|
||||
|
||||
def instantiate_embedding(class_object, params):
|
||||
params.pop("model", None)
|
||||
params.pop("headers", None)
|
||||
try:
|
||||
return class_object(**params)
|
||||
except ValidationError:
|
||||
params = {
|
||||
key: value
|
||||
for key, value in params.items()
|
||||
if key in class_object.__fields__
|
||||
}
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def instantiate_vectorstore(class_object, params):
|
||||
if len(params.get("documents", [])) == 0:
|
||||
raise ValueError(
|
||||
"The source you provided did not load correctly or was empty."
|
||||
"This may cause an error in the vectorstore."
|
||||
)
|
||||
# Chroma requires all metadata values to not be None
|
||||
if class_object.__name__ == "Chroma":
|
||||
for doc in params["documents"]:
|
||||
if doc.metadata is None:
|
||||
doc.metadata = {}
|
||||
for key, value in doc.metadata.items():
|
||||
if value is None:
|
||||
doc.metadata[key] = ""
|
||||
return class_object.from_documents(**params)
|
||||
|
||||
|
||||
def instantiate_documentloader(class_object, params):
|
||||
return class_object(**params).load()
|
||||
|
||||
|
||||
def instantiate_textsplitter(class_object, params):
|
||||
try:
|
||||
documents = params.pop("documents")
|
||||
except KeyError as e:
|
||||
raise ValueError(
|
||||
"The source you provided did not load correctly or was empty."
|
||||
"Try changing the chunk_size of the Text Splitter."
|
||||
) from e
|
||||
text_splitter = class_object(**params)
|
||||
return text_splitter.split_documents(documents)
|
||||
|
||||
|
||||
def instantiate_utility(node_type, class_object, params):
|
||||
if node_type == "SQLDatabase":
|
||||
return class_object.from_uri(params.pop("uri"))
|
||||
return class_object(**params)
|
||||
|
||||
|
||||
def replace_zero_shot_prompt_with_prompt_template(nodes):
|
||||
"""Replace ZeroShotPrompt with PromptTemplate"""
|
||||
for node in nodes:
|
||||
if node["data"]["type"] == "ZeroShotPrompt":
|
||||
# Build Prompt Template
|
||||
tools = [
|
||||
tool
|
||||
for tool in nodes
|
||||
if tool["type"] != "chatOutputNode"
|
||||
and "Tool" in tool["data"]["node"]["base_classes"]
|
||||
]
|
||||
node["data"] = build_prompt_template(prompt=node["data"], tools=tools)
|
||||
break
|
||||
return nodes
|
||||
|
||||
|
||||
def load_langchain_type_from_config(config: Dict[str, Any]):
|
||||
"""Load langchain type from config"""
|
||||
# Get type list
|
||||
type_list = get_type_list()
|
||||
if config["_type"] in type_list["agents"]:
|
||||
config = util.update_verbose(config, new_value=False)
|
||||
return load_agent_executor_from_config(config, verbose=True)
|
||||
elif config["_type"] in type_list["chains"]:
|
||||
config = util.update_verbose(config, new_value=False)
|
||||
return load_chain_from_config(config, verbose=True)
|
||||
elif config["_type"] in type_list["llms"]:
|
||||
config = util.update_verbose(config, new_value=True)
|
||||
return load_llm_from_config(config)
|
||||
else:
|
||||
raise ValueError("Type should be either agent, chain or llm")
|
||||
|
||||
|
||||
def load_agent_executor_from_config(
|
||||
config: dict,
|
||||
llm: Optional[BaseLanguageModel] = None,
|
||||
tools: Optional[list[Tool]] = None,
|
||||
callback_manager: Optional[BaseCallbackManager] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
tools = load_tools_from_config(config["allowed_tools"])
|
||||
config["allowed_tools"] = [tool.name for tool in tools] if tools else []
|
||||
agent_obj = load_agent_from_config(config, llm, tools, **kwargs)
|
||||
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent_obj,
|
||||
tools=tools,
|
||||
callback_manager=callback_manager,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs):
|
||||
"""Load agent executor from agent class, tools and chain"""
|
||||
allowed_tools = params.get("allowed_tools", [])
|
||||
llm_chain = params["llm_chain"]
|
||||
# if allowed_tools is not a list or set, make it a list
|
||||
if not isinstance(allowed_tools, (list, set)):
|
||||
allowed_tools = [allowed_tools]
|
||||
tool_names = [tool.name for tool in allowed_tools]
|
||||
# Agent class requires an output_parser but Agent classes
|
||||
# have a default output_parser.
|
||||
agent = agent_class(allowed_tools=tool_names, llm_chain=llm_chain) # type: ignore
|
||||
return AgentExecutor.from_agent_and_tools(
|
||||
agent=agent,
|
||||
tools=allowed_tools,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def load_toolkits_executor(node_type: str, toolkit: BaseToolkit, params: dict):
|
||||
create_function: Callable = toolkits_creator.get_create_function(node_type)
|
||||
if llm := params.get("llm"):
|
||||
return create_function(llm=llm, toolkit=toolkit)
|
||||
|
||||
|
||||
def load_tools_from_config(tool_list: list[dict]) -> list:
|
||||
"""Load tools based on a config list.
|
||||
|
||||
Args:
|
||||
config: config list.
|
||||
|
||||
Returns:
|
||||
List of tools.
|
||||
"""
|
||||
tools = []
|
||||
for tool in tool_list:
|
||||
tool_type = tool.pop("_type")
|
||||
llm_config = tool.pop("llm", None)
|
||||
llm = load_llm_from_config(llm_config) if llm_config else None
|
||||
kwargs = tool
|
||||
if tool_type in _BASE_TOOLS:
|
||||
tools.append(_BASE_TOOLS[tool_type]())
|
||||
elif tool_type in _LLM_TOOLS:
|
||||
if llm is None:
|
||||
raise ValueError(f"Tool {tool_type} requires an LLM to be provided")
|
||||
tools.append(_LLM_TOOLS[tool_type](llm))
|
||||
elif tool_type in _EXTRA_LLM_TOOLS:
|
||||
if llm is None:
|
||||
raise ValueError(f"Tool {tool_type} requires an LLM to be provided")
|
||||
_get_llm_tool_func, extra_keys = _EXTRA_LLM_TOOLS[tool_type]
|
||||
if missing_keys := set(extra_keys).difference(kwargs):
|
||||
raise ValueError(
|
||||
f"Tool {tool_type} requires some parameters that were not "
|
||||
f"provided: {missing_keys}"
|
||||
)
|
||||
tools.append(_get_llm_tool_func(llm=llm, **kwargs))
|
||||
elif tool_type in _EXTRA_OPTIONAL_TOOLS:
|
||||
_get_tool_func, extra_keys = _EXTRA_OPTIONAL_TOOLS[tool_type]
|
||||
kwargs = {k: value for k, value in kwargs.items() if value}
|
||||
tools.append(_get_tool_func(**kwargs))
|
||||
else:
|
||||
raise ValueError(f"Got unknown tool {tool_type}")
|
||||
return tools
|
||||
|
||||
|
||||
def build_prompt_template(prompt, tools):
|
||||
"""Build PromptTemplate from ZeroShotPrompt"""
|
||||
prefix = prompt["node"]["template"]["prefix"]["value"]
|
||||
suffix = prompt["node"]["template"]["suffix"]["value"]
|
||||
format_instructions = prompt["node"]["template"]["format_instructions"]["value"]
|
||||
|
||||
tool_strings = "\n".join(
|
||||
[
|
||||
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
|
||||
for tool in tools
|
||||
]
|
||||
)
|
||||
tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools])
|
||||
format_instructions = format_instructions.format(tool_names=tool_names)
|
||||
value = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
|
||||
|
||||
prompt["type"] = "PromptTemplate"
|
||||
|
||||
prompt["node"] = {
|
||||
"template": {
|
||||
"_type": "prompt",
|
||||
"input_variables": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"list": True,
|
||||
"show": False,
|
||||
"multiline": False,
|
||||
},
|
||||
"output_parser": {
|
||||
"type": "BaseOutputParser",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": None,
|
||||
},
|
||||
"template": {
|
||||
"type": "str",
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": True,
|
||||
"multiline": True,
|
||||
"value": value,
|
||||
},
|
||||
"template_format": {
|
||||
"type": "str",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": "f-string",
|
||||
},
|
||||
"validate_template": {
|
||||
"type": "bool",
|
||||
"required": False,
|
||||
"placeholder": "",
|
||||
"list": False,
|
||||
"show": False,
|
||||
"multline": False,
|
||||
"value": True,
|
||||
},
|
||||
},
|
||||
"description": "Schema to represent a prompt for an LLM.",
|
||||
"base_classes": ["BasePromptTemplate"],
|
||||
}
|
||||
|
||||
return prompt
|
||||
|
|
@ -6,12 +6,18 @@ from langflow.settings import settings
|
|||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.memories import MemoryFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
from langflow.custom.customs import get_custom_nodes
|
||||
|
||||
|
||||
class MemoryCreator(LangChainTypeCreator):
|
||||
type_name: str = "memories"
|
||||
|
||||
from_method_nodes = {
|
||||
"ZepChatMessageHistory": "__init__",
|
||||
"SQLiteEntityStore": "__init__",
|
||||
}
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[FrontendNode]:
|
||||
"""The class type of the FrontendNode created in frontend_node."""
|
||||
|
|
@ -26,6 +32,14 @@ class MemoryCreator(LangChainTypeCreator):
|
|||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
"""Get the signature of a memory."""
|
||||
try:
|
||||
if name in get_custom_nodes(self.type_name).keys():
|
||||
return get_custom_nodes(self.type_name)[name]
|
||||
elif name in self.from_method_nodes:
|
||||
return build_template_from_method(
|
||||
name,
|
||||
type_to_cls_dict=memory_type_to_cls_dict,
|
||||
method_name=self.from_method_nodes[name],
|
||||
)
|
||||
return build_template_from_class(name, memory_type_to_cls_dict)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Memory not found") from exc
|
||||
|
|
|
|||
64
src/backend/langflow/interface/output_parsers/base.py
Normal file
64
src/backend/langflow/interface/output_parsers/base.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from typing import Dict, List, Optional, Type
|
||||
|
||||
from langchain import output_parsers
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.template.frontend_node.output_parsers import OutputParserFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
|
||||
|
||||
class OutputParserCreator(LangChainTypeCreator):
|
||||
type_name: str = "output_parsers"
|
||||
from_method_nodes = {
|
||||
"StructuredOutputParser": "from_response_schemas",
|
||||
}
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[OutputParserFrontendNode]:
|
||||
return OutputParserFrontendNode
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict = {
|
||||
output_parser_name: import_class(
|
||||
f"langchain.output_parsers.{output_parser_name}"
|
||||
)
|
||||
# if output_parser_name is not lower case it is a class
|
||||
for output_parser_name in output_parsers.__all__
|
||||
}
|
||||
self.type_dict = {
|
||||
name: output_parser
|
||||
for name, output_parser in self.type_dict.items()
|
||||
if name in settings.output_parsers or settings.dev
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
try:
|
||||
if name in self.from_method_nodes:
|
||||
return build_template_from_method(
|
||||
name,
|
||||
type_to_cls_dict=self.type_to_loader_dict,
|
||||
method_name=self.from_method_nodes[name],
|
||||
)
|
||||
else:
|
||||
return build_template_from_class(
|
||||
name,
|
||||
type_to_cls_dict=self.type_to_loader_dict,
|
||||
)
|
||||
except ValueError as exc:
|
||||
# raise ValueError("OutputParser not found") from exc
|
||||
logger.error(f"OutputParser {name} not found: {exc}")
|
||||
except AttributeError as exc:
|
||||
logger.error(f"OutputParser {name} not loaded: {exc}")
|
||||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
return list(self.type_to_loader_dict.keys())
|
||||
|
||||
|
||||
output_parser_creator = OutputParserCreator()
|
||||
0
src/backend/langflow/interface/retrievers/__init__.py
Normal file
0
src/backend/langflow/interface/retrievers/__init__.py
Normal file
58
src/backend/langflow/interface/retrievers/base.py
Normal file
58
src/backend/langflow/interface/retrievers/base.py
Normal file
|
|
@ -0,0 +1,58 @@
|
|||
from typing import Any, Dict, List, Optional, Type
|
||||
|
||||
from langchain import retrievers
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.settings import settings
|
||||
from langflow.template.frontend_node.retrievers import RetrieverFrontendNode
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import build_template_from_method, build_template_from_class
|
||||
|
||||
|
||||
class RetrieverCreator(LangChainTypeCreator):
|
||||
type_name: str = "retrievers"
|
||||
|
||||
from_method_nodes = {"MultiQueryRetriever": "from_llm", "ZepRetriever": "__init__"}
|
||||
|
||||
@property
|
||||
def frontend_node_class(self) -> Type[RetrieverFrontendNode]:
|
||||
return RetrieverFrontendNode
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict: dict[str, Any] = {
|
||||
retriever_name: import_class(f"langchain.retrievers.{retriever_name}")
|
||||
for retriever_name in retrievers.__all__
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
"""Get the signature of an embedding."""
|
||||
try:
|
||||
if name in self.from_method_nodes:
|
||||
return build_template_from_method(
|
||||
name,
|
||||
type_to_cls_dict=self.type_to_loader_dict,
|
||||
method_name=self.from_method_nodes[name],
|
||||
)
|
||||
else:
|
||||
return build_template_from_class(
|
||||
name, type_to_cls_dict=self.type_to_loader_dict
|
||||
)
|
||||
except ValueError as exc:
|
||||
raise ValueError(f"Retriever {name} not found") from exc
|
||||
except AttributeError as exc:
|
||||
logger.error(f"Retriever {name} not loaded: {exc}")
|
||||
return None
|
||||
|
||||
def to_list(self) -> List[str]:
|
||||
return [
|
||||
retriever
|
||||
for retriever in self.type_to_loader_dict.keys()
|
||||
if retriever in settings.retrievers or settings.dev
|
||||
]
|
||||
|
||||
|
||||
retriever_creator = RetrieverCreator()
|
||||
|
|
@ -14,6 +14,23 @@ def build_langchain_object_with_caching(data_graph):
|
|||
return graph.build()
|
||||
|
||||
|
||||
@memoize_dict(maxsize=10)
|
||||
def build_sorted_vertices_with_caching(data_graph):
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
"""
|
||||
|
||||
logger.debug("Building langchain object")
|
||||
graph = Graph.from_payload(data_graph)
|
||||
sorted_vertices = graph.topological_sort()
|
||||
artifacts = {}
|
||||
for vertex in sorted_vertices:
|
||||
vertex.build()
|
||||
if vertex.artifacts:
|
||||
artifacts.update(vertex.artifacts)
|
||||
return graph.build(), artifacts
|
||||
|
||||
|
||||
def build_langchain_object(data_graph):
|
||||
"""
|
||||
Build langchain object from data_graph.
|
||||
|
|
@ -62,6 +79,10 @@ def update_memory_keys(langchain_object, possible_new_mem_key):
|
|||
if key not in [langchain_object.memory.memory_key, possible_new_mem_key]
|
||||
][0]
|
||||
|
||||
langchain_object.memory.input_key = input_key
|
||||
langchain_object.memory.output_key = output_key
|
||||
langchain_object.memory.memory_key = possible_new_mem_key
|
||||
keys = [input_key, output_key, possible_new_mem_key]
|
||||
attrs = ["input_key", "output_key", "memory_key"]
|
||||
for key, attr in zip(keys, attrs):
|
||||
try:
|
||||
setattr(langchain_object.memory, attr, key)
|
||||
except ValueError as exc:
|
||||
logger.debug(f"{langchain_object.memory} has no attribute {attr} ({exc})")
|
||||
|
|
|
|||
|
|
@ -55,7 +55,7 @@ TOOL_INPUTS = {
|
|||
show=True,
|
||||
value="",
|
||||
suffixes=[".json", ".yaml", ".yml"],
|
||||
fileTypes=["json", "yaml", "yml"],
|
||||
file_types=["json", "yaml", "yml"],
|
||||
),
|
||||
}
|
||||
|
||||
|
|
@ -90,7 +90,7 @@ class ToolCreator(LangChainTypeCreator):
|
|||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
"""Get the signature of a tool."""
|
||||
|
||||
base_classes = ["Tool"]
|
||||
base_classes = ["Tool", "BaseTool"]
|
||||
fields = []
|
||||
params = []
|
||||
tool_params = {}
|
||||
|
|
|
|||
|
|
@ -9,7 +9,10 @@ from langchain.agents.load_tools import (
|
|||
from langchain.tools.json.tool import JsonSpec
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
from langflow.interface.tools.custom import PythonFunctionTool, PythonFunction
|
||||
from langflow.interface.tools.custom import (
|
||||
PythonFunctionTool,
|
||||
PythonFunction,
|
||||
)
|
||||
|
||||
FILE_TOOLS = {"JsonSpec": JsonSpec}
|
||||
CUSTOM_TOOLS = {
|
||||
|
|
|
|||
|
|
@ -34,8 +34,6 @@ class Function(BaseModel):
|
|||
|
||||
|
||||
class PythonFunctionTool(Function, Tool):
|
||||
"""Python function"""
|
||||
|
||||
name: str = "Custom Tool"
|
||||
description: str
|
||||
code: str
|
||||
|
|
@ -49,6 +47,4 @@ class PythonFunctionTool(Function, Tool):
|
|||
|
||||
|
||||
class PythonFunction(Function):
|
||||
"""Python function"""
|
||||
|
||||
code: str
|
||||
|
|
|
|||
|
|
@ -1,7 +1,10 @@
|
|||
from typing import Any
|
||||
from langflow.interface.agents.base import agent_creator
|
||||
from langflow.interface.chains.base import chain_creator
|
||||
from langflow.interface.custom.constants import CUSTOM_COMPONENT_SUPPORTED_TYPES
|
||||
from langflow.interface.document_loaders.base import documentloader_creator
|
||||
from langflow.interface.embeddings.base import embedding_creator
|
||||
from langflow.interface.importing.utils import get_function_custom
|
||||
from langflow.interface.llms.base import llm_creator
|
||||
from langflow.interface.memories.base import memory_creator
|
||||
from langflow.interface.prompts.base import prompt_creator
|
||||
|
|
@ -11,8 +14,29 @@ from langflow.interface.tools.base import tool_creator
|
|||
from langflow.interface.utilities.base import utility_creator
|
||||
from langflow.interface.vector_store.base import vectorstore_creator
|
||||
from langflow.interface.wrappers.base import wrapper_creator
|
||||
from langflow.interface.output_parsers.base import output_parser_creator
|
||||
from langflow.interface.custom.base import custom_component_creator
|
||||
from langflow.interface.custom.custom_component import CustomComponent
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.constants import CLASSES_TO_REMOVE
|
||||
from langflow.template.frontend_node.custom_components import (
|
||||
CustomComponentFrontendNode,
|
||||
)
|
||||
from langflow.interface.retrievers.base import retriever_creator
|
||||
|
||||
from langflow.interface.custom.directory_reader import DirectoryReader
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import get_base_classes
|
||||
from langflow.api.utils import merge_nested_dicts
|
||||
|
||||
import re
|
||||
import warnings
|
||||
import traceback
|
||||
from fastapi import HTTPException
|
||||
|
||||
|
||||
# Used to get the base_classes list
|
||||
def get_type_list():
|
||||
"""Get a list of all langchain types"""
|
||||
all_types = build_langchain_types_dict()
|
||||
|
|
@ -27,7 +51,6 @@ def get_type_list():
|
|||
|
||||
def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
|
||||
"""Build a dictionary of all langchain types"""
|
||||
|
||||
all_types = {}
|
||||
|
||||
creators = [
|
||||
|
|
@ -44,6 +67,9 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
|
|||
documentloader_creator,
|
||||
textsplitter_creator,
|
||||
utility_creator,
|
||||
output_parser_creator,
|
||||
retriever_creator,
|
||||
custom_component_creator,
|
||||
]
|
||||
|
||||
all_types = {}
|
||||
|
|
@ -51,4 +77,315 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
|
|||
created_types = creator.to_dict()
|
||||
if created_types[creator.type_name].values():
|
||||
all_types.update(created_types)
|
||||
|
||||
return all_types
|
||||
|
||||
|
||||
def process_type(field_type: str):
|
||||
return "prompt" if field_type == "Prompt" else field_type
|
||||
|
||||
|
||||
# TODO: Move to correct place
|
||||
def add_new_custom_field(
|
||||
template,
|
||||
field_name: str,
|
||||
field_type: str,
|
||||
field_value: Any,
|
||||
field_required: bool,
|
||||
field_config: dict,
|
||||
):
|
||||
# Check field_config if any of the keys are in it
|
||||
# if it is, update the value
|
||||
display_name = field_config.pop("display_name", field_name)
|
||||
field_type = field_config.pop("field_type", field_type)
|
||||
field_type = process_type(field_type)
|
||||
field_value = field_config.pop("value", field_value)
|
||||
field_advanced = field_config.pop("advanced", False)
|
||||
|
||||
if field_type == "bool" and field_value is None:
|
||||
field_value = False
|
||||
|
||||
# If options is a list, then it's a dropdown
|
||||
# If options is None, then it's a list of strings
|
||||
is_list = isinstance(field_config.get("options"), list)
|
||||
field_config["is_list"] = is_list or field_config.get("is_list", False)
|
||||
|
||||
if "name" in field_config:
|
||||
warnings.warn(
|
||||
"The 'name' key in field_config is used to build the object and can't be changed."
|
||||
)
|
||||
field_config.pop("name", None)
|
||||
|
||||
required = field_config.pop("required", field_required)
|
||||
placeholder = field_config.pop("placeholder", "")
|
||||
|
||||
new_field = TemplateField(
|
||||
name=field_name,
|
||||
field_type=field_type,
|
||||
value=field_value,
|
||||
show=True,
|
||||
required=required,
|
||||
advanced=field_advanced,
|
||||
placeholder=placeholder,
|
||||
display_name=display_name,
|
||||
**field_config,
|
||||
)
|
||||
template.get("template")[field_name] = new_field.to_dict()
|
||||
template.get("custom_fields")[field_name] = None
|
||||
|
||||
return template
|
||||
|
||||
|
||||
# TODO: Move to correct place
|
||||
def add_code_field(template, raw_code, field_config):
|
||||
# Field with the Python code to allow update
|
||||
|
||||
code_field = {
|
||||
"code": {
|
||||
"dynamic": True,
|
||||
"required": True,
|
||||
"placeholder": "",
|
||||
"show": True,
|
||||
"multiline": True,
|
||||
"value": raw_code,
|
||||
"password": False,
|
||||
"name": "code",
|
||||
"advanced": field_config.pop("advanced", False),
|
||||
"type": "code",
|
||||
"list": False,
|
||||
}
|
||||
}
|
||||
template.get("template")["code"] = code_field.get("code")
|
||||
|
||||
return template
|
||||
|
||||
|
||||
def extract_type_from_optional(field_type):
|
||||
"""
|
||||
Extract the type from a string formatted as "Optional[<type>]".
|
||||
|
||||
Parameters:
|
||||
field_type (str): The string from which to extract the type.
|
||||
|
||||
Returns:
|
||||
str: The extracted type, or an empty string if no type was found.
|
||||
"""
|
||||
match = re.search(r"\[(.*?)\]", field_type)
|
||||
return match[1] if match else None
|
||||
|
||||
|
||||
def build_frontend_node(custom_component: CustomComponent):
|
||||
"""Build a frontend node for a custom component"""
|
||||
try:
|
||||
return (
|
||||
CustomComponentFrontendNode().to_dict().get(type(custom_component).__name__)
|
||||
)
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while building base frontend node: {exc}")
|
||||
return None
|
||||
|
||||
|
||||
def update_display_name_and_description(frontend_node, template_config):
|
||||
"""Update the display name and description of a frontend node"""
|
||||
if "display_name" in template_config:
|
||||
frontend_node["display_name"] = template_config["display_name"]
|
||||
|
||||
if "description" in template_config:
|
||||
frontend_node["description"] = template_config["description"]
|
||||
|
||||
|
||||
def build_field_config(custom_component: CustomComponent):
|
||||
"""Build the field configuration for a custom component"""
|
||||
|
||||
try:
|
||||
custom_class = get_function_custom(custom_component.code)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while getting custom function: {str(exc)}")
|
||||
return {}
|
||||
|
||||
try:
|
||||
return custom_class().build_config()
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while building field config: {str(exc)}")
|
||||
return {}
|
||||
|
||||
|
||||
def add_extra_fields(frontend_node, field_config, function_args):
|
||||
"""Add extra fields to the frontend node"""
|
||||
if function_args is None or function_args == "":
|
||||
return
|
||||
|
||||
# sort function_args which is a list of dicts
|
||||
function_args.sort(key=lambda x: x["name"])
|
||||
|
||||
for extra_field in function_args:
|
||||
if "name" not in extra_field or extra_field["name"] == "self":
|
||||
continue
|
||||
|
||||
field_name, field_type, field_value, field_required = get_field_properties(
|
||||
extra_field
|
||||
)
|
||||
config = field_config.get(field_name, {})
|
||||
frontend_node = add_new_custom_field(
|
||||
frontend_node,
|
||||
field_name,
|
||||
field_type,
|
||||
field_value,
|
||||
field_required,
|
||||
config,
|
||||
)
|
||||
|
||||
|
||||
def get_field_properties(extra_field):
|
||||
"""Get the properties of an extra field"""
|
||||
field_name = extra_field["name"]
|
||||
field_type = extra_field.get("type", "str")
|
||||
field_value = extra_field.get("default", "")
|
||||
field_required = "optional" not in field_type.lower()
|
||||
|
||||
if not field_required:
|
||||
field_type = extract_type_from_optional(field_type)
|
||||
|
||||
return field_name, field_type, field_value, field_required
|
||||
|
||||
|
||||
def add_base_classes(frontend_node, return_type):
|
||||
"""Add base classes to the frontend node"""
|
||||
if return_type not in CUSTOM_COMPONENT_SUPPORTED_TYPES or return_type is None:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": (
|
||||
"Invalid return type should be one of: "
|
||||
f"{list(CUSTOM_COMPONENT_SUPPORTED_TYPES.keys())}"
|
||||
),
|
||||
"traceback": traceback.format_exc(),
|
||||
},
|
||||
)
|
||||
|
||||
return_type_instance = CUSTOM_COMPONENT_SUPPORTED_TYPES.get(return_type)
|
||||
base_classes = get_base_classes(return_type_instance)
|
||||
|
||||
for base_class in base_classes:
|
||||
if base_class not in CLASSES_TO_REMOVE:
|
||||
frontend_node.get("base_classes").append(base_class)
|
||||
|
||||
|
||||
def build_langchain_template_custom_component(custom_component: CustomComponent):
|
||||
"""Build a custom component template for the langchain"""
|
||||
frontend_node = build_frontend_node(custom_component)
|
||||
|
||||
if frontend_node is None:
|
||||
return None
|
||||
|
||||
template_config = custom_component.build_template_config
|
||||
|
||||
update_display_name_and_description(frontend_node, template_config)
|
||||
|
||||
field_config = build_field_config(custom_component)
|
||||
add_extra_fields(
|
||||
frontend_node, field_config, custom_component.get_function_entrypoint_args
|
||||
)
|
||||
|
||||
frontend_node = add_code_field(
|
||||
frontend_node, custom_component.code, field_config.get("code", {})
|
||||
)
|
||||
|
||||
add_base_classes(
|
||||
frontend_node, custom_component.get_function_entrypoint_return_type
|
||||
)
|
||||
|
||||
return frontend_node
|
||||
|
||||
|
||||
def load_files_from_path(path: str):
|
||||
"""Load all files from a given path"""
|
||||
reader = DirectoryReader(path, False)
|
||||
|
||||
return reader.get_files()
|
||||
|
||||
|
||||
def build_and_validate_all_files(reader, file_list):
|
||||
"""Build and validate all files"""
|
||||
data = reader.build_component_menu_list(file_list)
|
||||
|
||||
valid_components = reader.filter_loaded_components(data=data, with_errors=False)
|
||||
invalid_components = reader.filter_loaded_components(data=data, with_errors=True)
|
||||
|
||||
return valid_components, invalid_components
|
||||
|
||||
|
||||
def build_valid_menu(valid_components):
|
||||
"""Build the valid menu"""
|
||||
valid_menu = {}
|
||||
for menu_item in valid_components["menu"]:
|
||||
menu_name = menu_item["name"]
|
||||
valid_menu[menu_name] = {}
|
||||
|
||||
for component in menu_item["components"]:
|
||||
try:
|
||||
component_name = component["name"]
|
||||
component_code = component["code"]
|
||||
|
||||
component_extractor = CustomComponent(code=component_code)
|
||||
component_extractor.is_check_valid()
|
||||
component_template = build_langchain_template_custom_component(
|
||||
component_extractor
|
||||
)
|
||||
|
||||
valid_menu[menu_name][component_name] = component_template
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(f"Error while building custom component: {exc}")
|
||||
|
||||
return valid_menu
|
||||
|
||||
|
||||
def build_invalid_menu(invalid_components):
|
||||
"""Build the invalid menu"""
|
||||
invalid_menu = {}
|
||||
for menu_item in invalid_components["menu"]:
|
||||
menu_name = menu_item["name"]
|
||||
invalid_menu[menu_name] = {}
|
||||
|
||||
for component in menu_item["components"]:
|
||||
try:
|
||||
component_name = component["name"]
|
||||
component_code = component["code"]
|
||||
|
||||
component_template = (
|
||||
CustomComponentFrontendNode(
|
||||
description="ERROR - Check your Python Code",
|
||||
display_name=f"ERROR - {component_name}",
|
||||
)
|
||||
.to_dict()
|
||||
.get(type(CustomComponent()).__name__)
|
||||
)
|
||||
|
||||
component_template["error"] = component.get("error", None)
|
||||
component_template.get("template").get("code")["value"] = component_code
|
||||
|
||||
invalid_menu[menu_name][component_name] = component_template
|
||||
|
||||
except Exception as exc:
|
||||
logger.error(
|
||||
f"Error while creating custom component [{component_name}]: {str(exc)}"
|
||||
)
|
||||
|
||||
return invalid_menu
|
||||
|
||||
|
||||
def build_langchain_custom_component_list_from_path(path: str):
|
||||
"""Build a list of custom components for the langchain from a given path"""
|
||||
file_list = load_files_from_path(path)
|
||||
reader = DirectoryReader(path, False)
|
||||
|
||||
valid_components, invalid_components = build_and_validate_all_files(
|
||||
reader, file_list
|
||||
)
|
||||
|
||||
valid_menu = build_valid_menu(valid_components)
|
||||
invalid_menu = build_invalid_menu(invalid_components)
|
||||
|
||||
return merge_nested_dicts(valid_menu, invalid_menu)
|
||||
|
|
|
|||
|
|
@ -4,26 +4,27 @@ import os
|
|||
from io import BytesIO
|
||||
import re
|
||||
|
||||
|
||||
import yaml
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from PIL.Image import Image
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.chat.config import ChatConfig
|
||||
|
||||
|
||||
def load_file_into_dict(file_path: str) -> dict:
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
file_extension = os.path.splitext(file_path)[1].lower()
|
||||
|
||||
if file_extension == ".json":
|
||||
with open(file_path, "r") as json_file:
|
||||
data = json.load(json_file)
|
||||
elif file_extension in [".yaml", ".yml"]:
|
||||
with open(file_path, "r") as yaml_file:
|
||||
data = yaml.safe_load(yaml_file)
|
||||
else:
|
||||
raise ValueError("Unsupported file type. Please provide a JSON or YAML file.")
|
||||
|
||||
# Files names are UUID, so we can't find the extension
|
||||
with open(file_path, "r") as file:
|
||||
try:
|
||||
data = json.load(file)
|
||||
except json.JSONDecodeError:
|
||||
file.seek(0)
|
||||
data = yaml.safe_load(file)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Invalid file type. Expected .json or .yaml.") from exc
|
||||
return data
|
||||
|
||||
|
||||
|
|
@ -48,9 +49,9 @@ def try_setting_streaming_options(langchain_object, websocket):
|
|||
|
||||
if isinstance(llm, BaseLanguageModel):
|
||||
if hasattr(llm, "streaming") and isinstance(llm.streaming, bool):
|
||||
llm.streaming = True
|
||||
llm.streaming = ChatConfig.streaming
|
||||
elif hasattr(llm, "stream") and isinstance(llm.stream, bool):
|
||||
llm.stream = True
|
||||
llm.stream = ChatConfig.streaming
|
||||
|
||||
return langchain_object
|
||||
|
||||
|
|
@ -58,3 +59,29 @@ def try_setting_streaming_options(langchain_object, websocket):
|
|||
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
|
||||
"""Extract input variables from prompt."""
|
||||
return re.findall(r"{(.*?)}", prompt)
|
||||
|
||||
|
||||
def setup_llm_caching():
|
||||
"""Setup LLM caching."""
|
||||
|
||||
from langflow.settings import settings
|
||||
|
||||
try:
|
||||
set_langchain_cache(settings)
|
||||
except ImportError:
|
||||
logger.warning(f"Could not import {settings.cache}. ")
|
||||
except Exception as exc:
|
||||
logger.warning(f"Could not setup LLM caching. Error: {exc}")
|
||||
|
||||
|
||||
# TODO Rename this here and in `setup_llm_caching`
|
||||
def set_langchain_cache(settings):
|
||||
import langchain
|
||||
from langflow.interface.importing.utils import import_class
|
||||
|
||||
cache_type = os.getenv("LANGFLOW_LANGCHAIN_CACHE")
|
||||
cache_class = import_class(f"langchain.cache.{cache_type or settings.cache}")
|
||||
|
||||
logger.debug(f"Setting up LLM caching with {cache_class.__name__}")
|
||||
langchain.llm_cache = cache_class()
|
||||
logger.info(f"LLM caching setup with {cache_class.__name__}")
|
||||
|
|
|
|||
|
|
@ -1,25 +1,36 @@
|
|||
from typing import Dict, List, Optional
|
||||
|
||||
from langchain import requests
|
||||
from langchain import requests, sql_database
|
||||
|
||||
from langflow.interface.base import LangChainTypeCreator
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.utils.util import build_template_from_class
|
||||
from langflow.utils.util import build_template_from_class, build_template_from_method
|
||||
|
||||
|
||||
class WrapperCreator(LangChainTypeCreator):
|
||||
type_name: str = "wrappers"
|
||||
|
||||
from_method_nodes = {"SQLDatabase": "from_uri"}
|
||||
|
||||
@property
|
||||
def type_to_loader_dict(self) -> Dict:
|
||||
if self.type_dict is None:
|
||||
self.type_dict = {
|
||||
wrapper.__name__: wrapper for wrapper in [requests.TextRequestsWrapper]
|
||||
wrapper.__name__: wrapper
|
||||
for wrapper in [requests.TextRequestsWrapper, sql_database.SQLDatabase]
|
||||
}
|
||||
return self.type_dict
|
||||
|
||||
def get_signature(self, name: str) -> Optional[Dict]:
|
||||
try:
|
||||
if name in self.from_method_nodes:
|
||||
return build_template_from_method(
|
||||
name,
|
||||
type_to_cls_dict=self.type_to_loader_dict,
|
||||
add_function=True,
|
||||
method_name=self.from_method_nodes[name],
|
||||
)
|
||||
|
||||
return build_template_from_class(name, self.type_to_loader_dict)
|
||||
except ValueError as exc:
|
||||
raise ValueError("Wrapper not found") from exc
|
||||
|
|
|
|||
|
|
@ -1,16 +1,15 @@
|
|||
# This file is used by lc-serve to load the mounted app and serve it.
|
||||
|
||||
from pathlib import Path
|
||||
import os
|
||||
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
# Use the JCLOUD_WORKSPACE for db URL if it's provided by JCloud.
|
||||
if "JCLOUD_WORKSPACE" in os.environ:
|
||||
os.environ[
|
||||
"LANGFLOW_DATABASE_URL"
|
||||
] = f"sqlite:///{os.environ['JCLOUD_WORKSPACE']}/langflow.db"
|
||||
|
||||
from langflow.main import create_app
|
||||
from langflow.main import setup_app
|
||||
from langflow.utils.logger import configure
|
||||
|
||||
app = create_app()
|
||||
path = Path(__file__).parent
|
||||
static_files_dir = path / "frontend"
|
||||
app.mount(
|
||||
"/",
|
||||
StaticFiles(directory=static_files_dir, html=True),
|
||||
name="static",
|
||||
)
|
||||
configure(log_level="DEBUG")
|
||||
app = setup_app()
|
||||
|
|
|
|||
|
|
@ -1,13 +1,21 @@
|
|||
from pathlib import Path
|
||||
from typing import Optional
|
||||
from fastapi import FastAPI
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import FileResponse
|
||||
from fastapi.staticfiles import StaticFiles
|
||||
|
||||
from langflow.api import router
|
||||
from langflow.routers import login, users, items, health
|
||||
from langflow.database.base import create_db_and_tables
|
||||
from langflow.interface.utils import setup_llm_caching
|
||||
from langflow.utils.logger import configure
|
||||
|
||||
|
||||
def create_app():
|
||||
"""Create the FastAPI app and include the router."""
|
||||
configure()
|
||||
|
||||
app = FastAPI()
|
||||
|
||||
origins = ["*"]
|
||||
|
|
@ -27,12 +35,62 @@ def create_app():
|
|||
app.include_router(router)
|
||||
|
||||
app.on_event("startup")(create_db_and_tables)
|
||||
app.on_event("startup")(setup_llm_caching)
|
||||
return app
|
||||
|
||||
|
||||
app = create_app()
|
||||
def setup_static_files(app: FastAPI, static_files_dir: Path):
|
||||
"""
|
||||
Setup the static files directory.
|
||||
Args:
|
||||
app (FastAPI): FastAPI app.
|
||||
path (str): Path to the static files directory.
|
||||
"""
|
||||
app.mount(
|
||||
"/",
|
||||
StaticFiles(directory=static_files_dir, html=True),
|
||||
name="static",
|
||||
)
|
||||
|
||||
@app.exception_handler(404)
|
||||
async def custom_404_handler(request, __):
|
||||
path = static_files_dir / "index.html"
|
||||
|
||||
if not path.exists():
|
||||
raise RuntimeError(f"File at path {path} does not exist.")
|
||||
return FileResponse(path)
|
||||
|
||||
|
||||
def get_static_files_dir():
|
||||
"""Get the static files directory relative to Langflow's main.py file."""
|
||||
frontend_path = Path(__file__).parent
|
||||
return frontend_path / "frontend"
|
||||
|
||||
|
||||
def setup_app(static_files_dir: Optional[Path] = None) -> FastAPI:
|
||||
"""Setup the FastAPI app."""
|
||||
# get the directory of the current file
|
||||
if not static_files_dir:
|
||||
static_files_dir = get_static_files_dir()
|
||||
|
||||
if not static_files_dir or not static_files_dir.exists():
|
||||
raise RuntimeError(
|
||||
f"Static files directory {static_files_dir} does not exist.")
|
||||
app = create_app()
|
||||
setup_static_files(app, static_files_dir)
|
||||
return app
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
import uvicorn
|
||||
uvicorn.run(app, host="127.0.0.1", port=7860)
|
||||
from langflow.utils.util import get_number_of_workers
|
||||
|
||||
configure()
|
||||
uvicorn.run(
|
||||
create_app,
|
||||
host="127.0.0.1",
|
||||
port=7860,
|
||||
workers=get_number_of_workers(),
|
||||
log_level="debug",
|
||||
reload=True,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,3 +1,4 @@
|
|||
from typing import Union
|
||||
from langflow.api.v1.callback import (
|
||||
AsyncStreamingLLMCallbackHandler,
|
||||
StreamingLLMCallbackHandler,
|
||||
|
|
@ -6,39 +7,31 @@ from langflow.processing.process import fix_memory_inputs, format_actions
|
|||
from langflow.utils.logger import logger
|
||||
|
||||
|
||||
async def get_result_and_steps(langchain_object, message: str, **kwargs):
|
||||
async def get_result_and_steps(langchain_object, inputs: Union[dict, str], **kwargs):
|
||||
"""Get result and thought from extracted json"""
|
||||
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
chat_input = None
|
||||
memory_key = ""
|
||||
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
|
||||
if hasattr(langchain_object, "input_keys"):
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
else:
|
||||
chat_input = message # type: ignore
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = True
|
||||
try:
|
||||
fix_memory_inputs(langchain_object)
|
||||
except Exception as exc:
|
||||
logger.error(f"Error fixing memory inputs: {exc}")
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
try:
|
||||
async_callbacks = [AsyncStreamingLLMCallbackHandler(**kwargs)]
|
||||
output = await langchain_object.acall(chat_input, callbacks=async_callbacks)
|
||||
output = await langchain_object.acall(inputs, callbacks=async_callbacks)
|
||||
except Exception as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
sync_callbacks = [StreamingLLMCallbackHandler(**kwargs)]
|
||||
output = langchain_object(chat_input, callbacks=sync_callbacks)
|
||||
output = langchain_object(inputs, callbacks=sync_callbacks)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
|
|
@ -49,7 +42,12 @@ async def get_result_and_steps(langchain_object, message: str, **kwargs):
|
|||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
thought = format_actions(intermediate_steps) if intermediate_steps else ""
|
||||
try:
|
||||
thought = format_actions(intermediate_steps) if intermediate_steps else ""
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
thought = ""
|
||||
except Exception as exc:
|
||||
logger.exception(exc)
|
||||
raise ValueError(f"Error: {str(exc)}") from exc
|
||||
return result, thought
|
||||
|
|
|
|||
|
|
@ -1,17 +1,16 @@
|
|||
import contextlib
|
||||
import io
|
||||
from pathlib import Path
|
||||
from langchain.schema import AgentAction
|
||||
import json
|
||||
from langflow.interface.run import (
|
||||
build_langchain_object_with_caching,
|
||||
build_sorted_vertices_with_caching,
|
||||
get_memory_key,
|
||||
update_memory_keys,
|
||||
)
|
||||
from langflow.utils.logger import logger
|
||||
from langflow.graph import Graph
|
||||
|
||||
|
||||
from typing import Any, Dict, List, Tuple
|
||||
from langchain.chains.base import Chain
|
||||
from langchain.vectorstores.base import VectorStore
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
|
||||
def fix_memory_inputs(langchain_object):
|
||||
|
|
@ -20,22 +19,26 @@ def fix_memory_inputs(langchain_object):
|
|||
object's input variables. If so, it does nothing. Otherwise, it gets a possible new memory key using the
|
||||
get_memory_key function and updates the memory keys using the update_memory_keys function.
|
||||
"""
|
||||
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
|
||||
try:
|
||||
if langchain_object.memory.memory_key in langchain_object.input_variables:
|
||||
return
|
||||
except AttributeError:
|
||||
input_variables = (
|
||||
langchain_object.prompt.input_variables
|
||||
if hasattr(langchain_object, "prompt")
|
||||
else langchain_object.input_keys
|
||||
)
|
||||
if langchain_object.memory.memory_key in input_variables:
|
||||
return
|
||||
if not hasattr(langchain_object, "memory") or langchain_object.memory is None:
|
||||
return
|
||||
try:
|
||||
if (
|
||||
hasattr(langchain_object.memory, "memory_key")
|
||||
and langchain_object.memory.memory_key in langchain_object.input_variables
|
||||
):
|
||||
return
|
||||
except AttributeError:
|
||||
input_variables = (
|
||||
langchain_object.prompt.input_variables
|
||||
if hasattr(langchain_object, "prompt")
|
||||
else langchain_object.input_keys
|
||||
)
|
||||
if langchain_object.memory.memory_key in input_variables:
|
||||
return
|
||||
|
||||
possible_new_mem_key = get_memory_key(langchain_object)
|
||||
if possible_new_mem_key is not None:
|
||||
update_memory_keys(langchain_object, possible_new_mem_key)
|
||||
possible_new_mem_key = get_memory_key(langchain_object)
|
||||
if possible_new_mem_key is not None:
|
||||
update_memory_keys(langchain_object, possible_new_mem_key)
|
||||
|
||||
|
||||
def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
|
||||
|
|
@ -54,69 +57,59 @@ def format_actions(actions: List[Tuple[AgentAction, str]]) -> str:
|
|||
return "\n".join(output)
|
||||
|
||||
|
||||
def get_result_and_thought(langchain_object, message: str):
|
||||
def get_result_and_thought(langchain_object: Any, inputs: dict):
|
||||
"""Get result and thought from extracted json"""
|
||||
try:
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
chat_input = None
|
||||
memory_key = ""
|
||||
if hasattr(langchain_object, "memory") and langchain_object.memory is not None:
|
||||
memory_key = langchain_object.memory.memory_key
|
||||
|
||||
if hasattr(langchain_object, "input_keys"):
|
||||
for key in langchain_object.input_keys:
|
||||
if key not in [memory_key, "chat_history"]:
|
||||
chat_input = {key: message}
|
||||
else:
|
||||
chat_input = message # type: ignore
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
langchain_object.return_intermediate_steps = True
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
|
||||
with io.StringIO() as output_buffer, contextlib.redirect_stdout(output_buffer):
|
||||
try:
|
||||
# if hasattr(langchain_object, "acall"):
|
||||
# output = await langchain_object.acall(chat_input)
|
||||
# else:
|
||||
output = langchain_object(chat_input)
|
||||
except ValueError as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
output = langchain_object.run(chat_input)
|
||||
|
||||
intermediate_steps = (
|
||||
output.get("intermediate_steps", []) if isinstance(output, dict) else []
|
||||
)
|
||||
|
||||
result = (
|
||||
output.get(langchain_object.output_keys[0])
|
||||
if isinstance(output, dict)
|
||||
else output
|
||||
)
|
||||
if intermediate_steps:
|
||||
thought = format_actions(intermediate_steps)
|
||||
else:
|
||||
thought = output_buffer.getvalue()
|
||||
try:
|
||||
output = langchain_object(inputs, return_only_outputs=True)
|
||||
except ValueError as exc:
|
||||
# make the error message more informative
|
||||
logger.debug(f"Error: {str(exc)}")
|
||||
output = langchain_object.run(inputs)
|
||||
|
||||
except Exception as exc:
|
||||
raise ValueError(f"Error: {str(exc)}") from exc
|
||||
return result, thought
|
||||
return output
|
||||
|
||||
|
||||
def process_graph_cached(data_graph: Dict[str, Any], message: str):
|
||||
def get_input_str_if_only_one_input(inputs: dict) -> Optional[str]:
|
||||
"""Get input string if only one input is provided"""
|
||||
return list(inputs.values())[0] if len(inputs) == 1 else None
|
||||
|
||||
|
||||
def process_graph_cached(
|
||||
data_graph: Dict[str, Any], inputs: Optional[dict] = None, clear_cache=False
|
||||
):
|
||||
"""
|
||||
Process graph by extracting input variables and replacing ZeroShotPrompt
|
||||
with PromptTemplate,then run the graph and return the result and thought.
|
||||
"""
|
||||
# Load langchain object
|
||||
langchain_object = build_langchain_object_with_caching(data_graph)
|
||||
logger.debug("Loaded langchain object")
|
||||
if clear_cache:
|
||||
build_sorted_vertices_with_caching.clear_cache()
|
||||
logger.debug("Cleared cache")
|
||||
langchain_object, artifacts = build_sorted_vertices_with_caching(data_graph)
|
||||
logger.debug("Loaded LangChain object")
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
|
||||
# Add artifacts to inputs
|
||||
# artifacts can be documents loaded when building
|
||||
# the flow
|
||||
for (
|
||||
key,
|
||||
value,
|
||||
) in artifacts.items():
|
||||
if key not in inputs or not inputs[key]:
|
||||
inputs[key] = value
|
||||
|
||||
if langchain_object is None:
|
||||
# Raise user facing error
|
||||
|
|
@ -125,63 +118,124 @@ def process_graph_cached(data_graph: Dict[str, Any], message: str):
|
|||
)
|
||||
|
||||
# Generate result and thought
|
||||
logger.debug("Generating result and thought")
|
||||
result, thought = get_result_and_thought(langchain_object, message)
|
||||
logger.debug("Generated result and thought")
|
||||
return {"result": str(result), "thought": thought.strip()}
|
||||
if isinstance(langchain_object, Chain):
|
||||
if inputs is None:
|
||||
raise ValueError("Inputs must be provided for a Chain")
|
||||
logger.debug("Generating result and thought")
|
||||
result = get_result_and_thought(langchain_object, inputs)
|
||||
logger.debug("Generated result and thought")
|
||||
elif isinstance(langchain_object, VectorStore):
|
||||
result = langchain_object.search(**inputs)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown langchain_object type: {type(langchain_object).__name__}"
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
def load_flow_from_json(path: str, build=True):
|
||||
"""Load flow from json file"""
|
||||
# This is done to avoid circular imports
|
||||
def load_flow_from_json(
|
||||
flow: Union[Path, str, dict], tweaks: Optional[dict] = None, build=True
|
||||
):
|
||||
"""
|
||||
Load flow from a JSON file or a JSON object.
|
||||
|
||||
with open(path, "r", encoding="utf-8") as f:
|
||||
flow_graph = json.load(f)
|
||||
data_graph = flow_graph["data"]
|
||||
nodes = data_graph["nodes"]
|
||||
# Substitute ZeroShotPrompt with PromptTemplate
|
||||
# nodes = replace_zero_shot_prompt_with_prompt_template(nodes)
|
||||
# Add input variables
|
||||
# nodes = payload.extract_input_variables(nodes)
|
||||
:param flow: JSON file path or JSON object
|
||||
:param tweaks: Optional tweaks to be processed
|
||||
:param build: If True, build the graph, otherwise return the graph object
|
||||
:return: Langchain object or Graph object depending on the build parameter
|
||||
"""
|
||||
# If input is a file path, load JSON from the file
|
||||
if isinstance(flow, (str, Path)):
|
||||
with open(flow, "r", encoding="utf-8") as f:
|
||||
flow_graph = json.load(f)
|
||||
# If input is a dictionary, assume it's a JSON object
|
||||
elif isinstance(flow, dict):
|
||||
flow_graph = flow
|
||||
else:
|
||||
raise TypeError(
|
||||
"Input must be either a file path (str) or a JSON object (dict)"
|
||||
)
|
||||
|
||||
# Nodes, edges and root node
|
||||
edges = data_graph["edges"]
|
||||
graph_data = flow_graph["data"]
|
||||
if tweaks is not None:
|
||||
graph_data = process_tweaks(graph_data, tweaks)
|
||||
nodes = graph_data["nodes"]
|
||||
edges = graph_data["edges"]
|
||||
graph = Graph(nodes, edges)
|
||||
|
||||
if build:
|
||||
langchain_object = graph.build()
|
||||
|
||||
if hasattr(langchain_object, "verbose"):
|
||||
langchain_object.verbose = True
|
||||
|
||||
if hasattr(langchain_object, "return_intermediate_steps"):
|
||||
# https://github.com/hwchase17/langchain/issues/2068
|
||||
# Deactivating until we have a frontend solution
|
||||
# to display intermediate steps
|
||||
langchain_object.return_intermediate_steps = False
|
||||
|
||||
fix_memory_inputs(langchain_object)
|
||||
return langchain_object
|
||||
|
||||
return graph
|
||||
|
||||
|
||||
def process_tweaks(graph_data: Dict, tweaks: Dict):
|
||||
"""This function is used to tweak the graph data using the node id and the tweaks dict"""
|
||||
# the tweaks dict is a dict of dicts
|
||||
# the key is the node id and the value is a dict of the tweaks
|
||||
# the dict of tweaks contains the name of a certain parameter and the value to be tweaked
|
||||
def validate_input(
|
||||
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
|
||||
) -> List[Dict[str, Any]]:
|
||||
if not isinstance(graph_data, dict) or not isinstance(tweaks, dict):
|
||||
raise ValueError("graph_data and tweaks should be dictionaries")
|
||||
|
||||
nodes = graph_data.get("data", {}).get("nodes") or graph_data.get("nodes")
|
||||
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError(
|
||||
"graph_data should contain a list of nodes under 'data' key or directly under 'nodes' key"
|
||||
)
|
||||
|
||||
return nodes
|
||||
|
||||
|
||||
def apply_tweaks(node: Dict[str, Any], node_tweaks: Dict[str, Any]) -> None:
|
||||
template_data = node.get("data", {}).get("node", {}).get("template")
|
||||
|
||||
if not isinstance(template_data, dict):
|
||||
logger.warning(
|
||||
f"Template data for node {node.get('id')} should be a dictionary"
|
||||
)
|
||||
return
|
||||
|
||||
for tweak_name, tweak_value in node_tweaks.items():
|
||||
if tweak_name and tweak_value and tweak_name in template_data:
|
||||
key = tweak_name if tweak_name == "file_path" else "value"
|
||||
template_data[tweak_name][key] = tweak_value
|
||||
|
||||
|
||||
def process_tweaks(
|
||||
graph_data: Dict[str, Any], tweaks: Dict[str, Dict[str, Any]]
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
This function is used to tweak the graph data using the node id and the tweaks dict.
|
||||
|
||||
:param graph_data: The dictionary containing the graph data. It must contain a 'data' key with
|
||||
'nodes' as its child or directly contain 'nodes' key. Each node should have an 'id' and 'data'.
|
||||
:param tweaks: A dictionary where the key is the node id and the value is a dictionary of the tweaks.
|
||||
The inner dictionary contains the name of a certain parameter as the key and the value to be tweaked.
|
||||
|
||||
:return: The modified graph_data dictionary.
|
||||
|
||||
:raises ValueError: If the input is not in the expected format.
|
||||
"""
|
||||
nodes = validate_input(graph_data, tweaks)
|
||||
|
||||
# We need to process the graph data to add the tweaks
|
||||
if "data" not in graph_data and "nodes" in graph_data:
|
||||
nodes = graph_data["nodes"]
|
||||
else:
|
||||
nodes = graph_data["data"]["nodes"]
|
||||
for node in nodes:
|
||||
node_id = node["id"]
|
||||
if node_id in tweaks:
|
||||
node_tweaks = tweaks[node_id]
|
||||
template_data = node["data"]["node"]["template"]
|
||||
for tweak_name, tweake_value in node_tweaks.items():
|
||||
if tweak_name in template_data:
|
||||
template_data[tweak_name]["value"] = tweake_value
|
||||
print(
|
||||
f"Something changed in node {node_id} with tweak {tweak_name} and value {tweake_value}"
|
||||
)
|
||||
if isinstance(node, dict) and isinstance(node.get("id"), str):
|
||||
node_id = node["id"]
|
||||
if node_tweaks := tweaks.get(node_id):
|
||||
apply_tweaks(node, node_tweaks)
|
||||
else:
|
||||
logger.warning(
|
||||
"Each node should be a dictionary with an 'id' key of type str"
|
||||
)
|
||||
|
||||
return graph_data
|
||||
|
|
|
|||
|
|
@ -1,32 +1,71 @@
|
|||
import os
|
||||
from typing import List
|
||||
from typing import Optional, List
|
||||
from pathlib import Path
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseSettings, root_validator
|
||||
from langflow.utils.logger import logger
|
||||
|
||||
BASE_COMPONENTS_PATH = Path(__file__).parent / "components"
|
||||
|
||||
|
||||
class Settings(BaseSettings):
|
||||
chains: List[str] = []
|
||||
agents: List[str] = []
|
||||
prompts: List[str] = []
|
||||
llms: List[str] = []
|
||||
tools: List[str] = []
|
||||
memories: List[str] = []
|
||||
embeddings: List[str] = []
|
||||
vectorstores: List[str] = []
|
||||
documentloaders: List[str] = []
|
||||
wrappers: List[str] = []
|
||||
toolkits: List[str] = []
|
||||
textsplitters: List[str] = []
|
||||
utilities: List[str] = []
|
||||
chains: dict = {}
|
||||
agents: dict = {}
|
||||
prompts: dict = {}
|
||||
llms: dict = {}
|
||||
tools: dict = {}
|
||||
memories: dict = {}
|
||||
embeddings: dict = {}
|
||||
vectorstores: dict = {}
|
||||
documentloaders: dict = {}
|
||||
wrappers: dict = {}
|
||||
retrievers: dict = {}
|
||||
toolkits: dict = {}
|
||||
textsplitters: dict = {}
|
||||
utilities: dict = {}
|
||||
output_parsers: dict = {}
|
||||
custom_components: dict = {}
|
||||
|
||||
dev: bool = False
|
||||
database_url: str = "sqlite:///./langflow.db"
|
||||
database_url: Optional[str] = None
|
||||
cache: str = "InMemoryCache"
|
||||
remove_api_keys: bool = False
|
||||
components_path: List[Path]
|
||||
|
||||
@root_validator(pre=True)
|
||||
def set_env_variables(cls, values):
|
||||
if "database_url" not in values:
|
||||
logger.debug(
|
||||
"No database_url provided, trying LANGFLOW_DATABASE_URL env variable"
|
||||
)
|
||||
if langflow_database_url := os.getenv("LANGFLOW_DATABASE_URL"):
|
||||
values["database_url"] = langflow_database_url
|
||||
else:
|
||||
logger.debug("No DATABASE_URL env variable, using sqlite database")
|
||||
values["database_url"] = "sqlite:///./langflow.db"
|
||||
|
||||
if not values.get("components_path"):
|
||||
values["components_path"] = [BASE_COMPONENTS_PATH]
|
||||
logger.debug("No components_path provided, using default components path")
|
||||
elif BASE_COMPONENTS_PATH not in values["components_path"]:
|
||||
values["components_path"].append(BASE_COMPONENTS_PATH)
|
||||
logger.debug("Adding default components path to components_path")
|
||||
|
||||
if os.getenv("LANGFLOW_COMPONENTS_PATH"):
|
||||
logger.debug("Adding LANGFLOW_COMPONENTS_PATH to components_path")
|
||||
langflow_component_path = Path(os.getenv("LANGFLOW_COMPONENTS_PATH"))
|
||||
if (
|
||||
langflow_component_path.exists()
|
||||
and langflow_component_path not in values["components_path"]
|
||||
):
|
||||
values["components_path"].append(langflow_component_path)
|
||||
logger.debug(f"Adding {langflow_component_path} to components_path")
|
||||
return values
|
||||
|
||||
class Config:
|
||||
validate_assignment = True
|
||||
extra = "ignore"
|
||||
env_prefix = "LANGFLOW_"
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
def validate_lists(cls, values):
|
||||
|
|
@ -37,22 +76,35 @@ class Settings(BaseSettings):
|
|||
|
||||
def update_from_yaml(self, file_path: str, dev: bool = False):
|
||||
new_settings = load_settings_from_yaml(file_path)
|
||||
self.chains = new_settings.chains or []
|
||||
self.agents = new_settings.agents or []
|
||||
self.prompts = new_settings.prompts or []
|
||||
self.llms = new_settings.llms or []
|
||||
self.tools = new_settings.tools or []
|
||||
self.memories = new_settings.memories or []
|
||||
self.wrappers = new_settings.wrappers or []
|
||||
self.toolkits = new_settings.toolkits or []
|
||||
self.textsplitters = new_settings.textsplitters or []
|
||||
self.utilities = new_settings.utilities or []
|
||||
self.chains = new_settings.chains or {}
|
||||
self.agents = new_settings.agents or {}
|
||||
self.prompts = new_settings.prompts or {}
|
||||
self.llms = new_settings.llms or {}
|
||||
self.tools = new_settings.tools or {}
|
||||
self.memories = new_settings.memories or {}
|
||||
self.wrappers = new_settings.wrappers or {}
|
||||
self.toolkits = new_settings.toolkits or {}
|
||||
self.textsplitters = new_settings.textsplitters or {}
|
||||
self.utilities = new_settings.utilities or {}
|
||||
self.embeddings = new_settings.embeddings or {}
|
||||
self.vectorstores = new_settings.vectorstores or {}
|
||||
self.documentloaders = new_settings.documentloaders or {}
|
||||
self.retrievers = new_settings.retrievers or {}
|
||||
self.output_parsers = new_settings.output_parsers or {}
|
||||
self.custom_components = new_settings.custom_components or {}
|
||||
self.components_path = new_settings.components_path or []
|
||||
self.dev = dev
|
||||
|
||||
def update_settings(self, **kwargs):
|
||||
for key, value in kwargs.items():
|
||||
if hasattr(self, key):
|
||||
setattr(self, key, value)
|
||||
if isinstance(getattr(self, key), list):
|
||||
if isinstance(value, list):
|
||||
getattr(self, key).extend(value)
|
||||
else:
|
||||
getattr(self, key).append(value)
|
||||
else:
|
||||
setattr(self, key, value)
|
||||
|
||||
|
||||
def save_settings_to_yaml(settings: Settings, file_path: str):
|
||||
|
|
|
|||
|
|
@ -6,21 +6,58 @@ from pydantic import BaseModel
|
|||
|
||||
class TemplateFieldCreator(BaseModel, ABC):
|
||||
field_type: str = "str"
|
||||
"""The type of field this is. Default is a string."""
|
||||
|
||||
required: bool = False
|
||||
"""Specifies if the field is required. Defaults to False."""
|
||||
|
||||
placeholder: str = ""
|
||||
"""A placeholder string for the field. Default is an empty string."""
|
||||
|
||||
is_list: bool = False
|
||||
"""Defines if the field is a list. Default is False."""
|
||||
|
||||
show: bool = True
|
||||
"""Should the field be shown. Defaults to True."""
|
||||
|
||||
multiline: bool = False
|
||||
"""Defines if the field will allow the user to open a text editor. Default is False."""
|
||||
|
||||
value: Any = None
|
||||
"""The value of the field. Default is None."""
|
||||
|
||||
suffixes: list[str] = []
|
||||
fileTypes: list[str] = []
|
||||
"""List of suffixes for a file field. Default is an empty list."""
|
||||
|
||||
file_types: list[str] = []
|
||||
content: Union[str, None] = None
|
||||
"""List of file types associated with the field. Default is an empty list. (duplicate)"""
|
||||
|
||||
file_path: Union[str, None] = None
|
||||
"""The file path of the field if it is a file. Defaults to None."""
|
||||
|
||||
password: bool = False
|
||||
"""Specifies if the field is a password. Defaults to False."""
|
||||
|
||||
options: list[str] = []
|
||||
"""List of options for the field. Only used when is_list=True. Default is an empty list."""
|
||||
|
||||
name: str = ""
|
||||
"""Name of the field. Default is an empty string."""
|
||||
|
||||
display_name: Optional[str] = None
|
||||
"""Display name of the field. Defaults to None."""
|
||||
|
||||
advanced: bool = False
|
||||
"""Specifies if the field will an advanced parameter (hidden). Defaults to False."""
|
||||
|
||||
input_types: list[str] = []
|
||||
"""List of input types for the handle when the field has more than one type. Default is an empty list."""
|
||||
|
||||
dynamic: bool = False
|
||||
"""Specifies if the field is dynamic. Defaults to False."""
|
||||
|
||||
info: Optional[str] = ""
|
||||
"""Additional information about the field to be shown in the tooltip. Defaults to an empty string."""
|
||||
|
||||
def to_dict(self):
|
||||
result = self.dict()
|
||||
|
|
@ -35,7 +72,7 @@ class TemplateFieldCreator(BaseModel, ABC):
|
|||
result["fileTypes"] = result.pop("file_types")
|
||||
|
||||
if self.field_type == "file":
|
||||
result["content"] = self.content
|
||||
result["file_path"] = self.file_path
|
||||
return result
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -9,6 +9,7 @@ from langflow.template.frontend_node import (
|
|||
vectorstores,
|
||||
documentloaders,
|
||||
textsplitters,
|
||||
custom_components,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -22,4 +23,5 @@ __all__ = [
|
|||
"vectorstores",
|
||||
"documentloaders",
|
||||
"textsplitters",
|
||||
"custom_components",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -13,6 +13,16 @@ NON_CHAT_AGENTS = {
|
|||
}
|
||||
|
||||
|
||||
class AgentFrontendNode(FrontendNode):
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
if field.name in ["suffix", "prefix"]:
|
||||
field.show = True
|
||||
if field.name == "Tools" and name == "ZeroShotAgent":
|
||||
field.field_type = "BaseTool"
|
||||
field.is_list = True
|
||||
|
||||
|
||||
class SQLAgentNode(FrontendNode):
|
||||
name: str = "SQLAgent"
|
||||
template: Template = Template(
|
||||
|
|
@ -135,7 +145,7 @@ class CSVAgentNode(FrontendNode):
|
|||
name="path",
|
||||
value="",
|
||||
suffixes=[".csv"],
|
||||
fileTypes=["csv"],
|
||||
file_types=["csv"],
|
||||
),
|
||||
TemplateField(
|
||||
field_type="BaseLanguageModel",
|
||||
|
|
|
|||
|
|
@ -1,28 +1,92 @@
|
|||
from collections import defaultdict
|
||||
import re
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS
|
||||
from langflow.template.frontend_node.formatter import field_formatters
|
||||
from langflow.template.frontend_node.constants import (
|
||||
CLASSES_TO_REMOVE,
|
||||
FORCE_SHOW_FIELDS,
|
||||
)
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.template.base import Template
|
||||
from langflow.utils import constants
|
||||
|
||||
|
||||
class FieldFormatters(BaseModel):
|
||||
formatters = {
|
||||
"openai_api_key": field_formatters.OpenAIAPIKeyFormatter(),
|
||||
}
|
||||
base_formatters = {
|
||||
"kwargs": field_formatters.KwargsFormatter(),
|
||||
"optional": field_formatters.RemoveOptionalFormatter(),
|
||||
"list": field_formatters.ListTypeFormatter(),
|
||||
"dict": field_formatters.DictTypeFormatter(),
|
||||
"union": field_formatters.UnionTypeFormatter(),
|
||||
"multiline": field_formatters.MultilineFieldFormatter(),
|
||||
"show": field_formatters.ShowFieldFormatter(),
|
||||
"password": field_formatters.PasswordFieldFormatter(),
|
||||
"default": field_formatters.DefaultValueFormatter(),
|
||||
"headers": field_formatters.HeadersDefaultValueFormatter(),
|
||||
"dict_code_file": field_formatters.DictCodeFileFormatter(),
|
||||
"model_fields": field_formatters.ModelSpecificFieldFormatter(),
|
||||
}
|
||||
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
for key, formatter in self.base_formatters.items():
|
||||
formatter.format(field, name)
|
||||
|
||||
for key, formatter in self.formatters.items():
|
||||
if key == field.name:
|
||||
formatter.format(field, name)
|
||||
|
||||
|
||||
class FrontendNode(BaseModel):
|
||||
template: Template
|
||||
description: str
|
||||
base_classes: List[str]
|
||||
name: str = ""
|
||||
display_name: str = ""
|
||||
documentation: str = ""
|
||||
custom_fields: defaultdict = defaultdict(list)
|
||||
output_types: List[str] = []
|
||||
field_formatters: FieldFormatters = Field(default_factory=FieldFormatters)
|
||||
beta: bool = False
|
||||
error: Optional[str] = None
|
||||
|
||||
# field formatters is an instance attribute but it is not used in the class
|
||||
# so we need to create a method to get it
|
||||
@staticmethod
|
||||
def get_field_formatters() -> FieldFormatters:
|
||||
return FieldFormatters()
|
||||
|
||||
def set_documentation(self, documentation: str) -> None:
|
||||
"""Sets the documentation of the frontend node."""
|
||||
self.documentation = documentation
|
||||
|
||||
def process_base_classes(self) -> None:
|
||||
"""Removes unwanted base classes from the list of base classes."""
|
||||
self.base_classes = [
|
||||
base_class
|
||||
for base_class in self.base_classes
|
||||
if base_class not in CLASSES_TO_REMOVE
|
||||
]
|
||||
|
||||
def to_dict(self) -> dict:
|
||||
"""Returns a dict representation of the frontend node."""
|
||||
self.process_base_classes()
|
||||
return {
|
||||
self.name: {
|
||||
"template": self.template.to_dict(self.format_field),
|
||||
"description": self.description,
|
||||
"base_classes": self.base_classes,
|
||||
"display_name": self.display_name or self.name,
|
||||
"custom_fields": self.custom_fields,
|
||||
"output_types": self.output_types,
|
||||
"documentation": self.documentation,
|
||||
"beta": self.beta,
|
||||
"error": self.error,
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -35,33 +99,8 @@ class FrontendNode(BaseModel):
|
|||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
"""Formats a given field based on its attributes and value."""
|
||||
SPECIAL_FIELD_HANDLERS = {
|
||||
"allowed_tools": lambda field: "Tool",
|
||||
"max_value_length": lambda field: "int",
|
||||
}
|
||||
|
||||
key = field.name
|
||||
value = field.to_dict()
|
||||
_type = value["type"]
|
||||
|
||||
_type = FrontendNode.remove_optional(_type)
|
||||
_type, is_list = FrontendNode.check_for_list_type(_type)
|
||||
field.is_list = is_list or field.is_list
|
||||
_type = FrontendNode.replace_mapping_with_dict(_type)
|
||||
_type = FrontendNode.handle_union_type(_type)
|
||||
|
||||
field.field_type = FrontendNode.handle_special_field(
|
||||
field, key, _type, SPECIAL_FIELD_HANDLERS
|
||||
)
|
||||
field.field_type = FrontendNode.handle_dict_type(field, _type)
|
||||
field.show = FrontendNode.should_show_field(key, field.required)
|
||||
field.password = FrontendNode.should_be_password(key, field.show)
|
||||
field.multiline = FrontendNode.should_be_multiline(key)
|
||||
|
||||
FrontendNode.replace_default_value(field, value)
|
||||
FrontendNode.handle_specific_field_values(field, key, name)
|
||||
FrontendNode.handle_kwargs_field(field)
|
||||
FrontendNode.handle_api_key_field(field, key)
|
||||
FrontendNode.get_field_formatters().format(field, name)
|
||||
|
||||
@staticmethod
|
||||
def remove_optional(_type: str) -> str:
|
||||
|
|
|
|||
|
|
@ -13,17 +13,46 @@ class ChainFrontendNode(FrontendNode):
|
|||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="BaseChatMemory",
|
||||
required=False,
|
||||
required=True,
|
||||
show=True,
|
||||
name="memory",
|
||||
advanced=False,
|
||||
)
|
||||
)
|
||||
# add return_source_documents
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
name="return_source_documents",
|
||||
advanced=False,
|
||||
value=True,
|
||||
display_name="Return source documents",
|
||||
)
|
||||
)
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
is_list=True,
|
||||
show=True,
|
||||
multiline=False,
|
||||
options=QA_CHAIN_TYPES,
|
||||
value=QA_CHAIN_TYPES[0],
|
||||
name="chain_type",
|
||||
advanced=False,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
|
||||
if "name" == "RetrievalQA" and field.name == "memory":
|
||||
field.show = False
|
||||
field.required = False
|
||||
|
||||
field.advanced = False
|
||||
if "key" in field.name:
|
||||
field.password = False
|
||||
|
|
@ -47,18 +76,24 @@ class ChainFrontendNode(FrontendNode):
|
|||
field.show = True
|
||||
field.advanced = False
|
||||
if field.name == "memory":
|
||||
field.required = False
|
||||
# field.required = False
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
if field.name == "verbose":
|
||||
field.required = False
|
||||
field.show = True
|
||||
field.show = False
|
||||
field.advanced = True
|
||||
if field.name == "llm":
|
||||
field.required = True
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
|
||||
if field.name == "return_source_documents":
|
||||
field.required = False
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
field.value = True
|
||||
|
||||
|
||||
class SeriesCharacterChainNode(FrontendNode):
|
||||
name: str = "SeriesCharacterChain"
|
||||
|
|
|
|||
|
|
@ -8,6 +8,7 @@ FORCE_SHOW_FIELDS = [
|
|||
"headers",
|
||||
"max_value_length",
|
||||
"max_tokens",
|
||||
"google_cse_id",
|
||||
]
|
||||
|
||||
DEFAULT_PROMPT = """
|
||||
|
|
@ -32,3 +33,36 @@ You are a good listener and you can talk about anything.
|
|||
HUMAN_PROMPT = "{input}"
|
||||
|
||||
QA_CHAIN_TYPES = ["stuff", "map_reduce", "map_rerank", "refine"]
|
||||
|
||||
CTRANSFORMERS_DEFAULT_CONFIG = {
|
||||
"top_k": 40,
|
||||
"top_p": 0.95,
|
||||
"temperature": 0.8,
|
||||
"repetition_penalty": 1.1,
|
||||
"last_n_tokens": 64,
|
||||
"seed": -1,
|
||||
"max_new_tokens": 256,
|
||||
"stop": None,
|
||||
"stream": False,
|
||||
"reset": True,
|
||||
"batch_size": 8,
|
||||
"threads": -1,
|
||||
"context_length": -1,
|
||||
"gpu_layers": 0,
|
||||
}
|
||||
|
||||
# This variable is used to tell the user
|
||||
# that it can be changed to use other APIs
|
||||
# like Prem and LocalAI
|
||||
OPENAI_API_BASE_INFO = """
|
||||
The base URL of the OpenAI API. Defaults to https://api.openai.com/v1.
|
||||
|
||||
You can change this to use other APIs like JinaChat, LocalAI and Prem.
|
||||
"""
|
||||
|
||||
|
||||
INPUT_KEY_INFO = """The variable to be used as Chat Input when more than one variable is available."""
|
||||
OUTPUT_KEY_INFO = """The variable to be used as Chat Output (e.g. answer in a ConversationalRetrievalChain)"""
|
||||
|
||||
|
||||
CLASSES_TO_REMOVE = ["Serializable", "BaseModel", "object", "Runnable", "Generic"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,31 @@
|
|||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.template.base import Template
|
||||
from langflow.interface.custom.constants import DEFAULT_CUSTOM_COMPONENT_CODE
|
||||
|
||||
|
||||
class CustomComponentFrontendNode(FrontendNode):
|
||||
name: str = "CustomComponent"
|
||||
display_name: str = "Custom Component"
|
||||
beta: bool = True
|
||||
template: Template = Template(
|
||||
type_name="CustomComponent",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="code",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
value=DEFAULT_CUSTOM_COMPONENT_CODE,
|
||||
name="code",
|
||||
advanced=False,
|
||||
dynamic=True,
|
||||
)
|
||||
],
|
||||
)
|
||||
description: str = "Create any custom component you want!"
|
||||
base_classes: list[str] = []
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
|
@ -1,8 +1,9 @@
|
|||
from typing import Optional
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
|
||||
|
||||
def build_template(
|
||||
def build_file_field(
|
||||
suffixes: list, fileTypes: list, name: str = "file_path"
|
||||
) -> TemplateField:
|
||||
"""Build a template field for a document loader."""
|
||||
|
|
@ -13,45 +14,99 @@ def build_template(
|
|||
name=name,
|
||||
value="",
|
||||
suffixes=suffixes,
|
||||
fileTypes=fileTypes,
|
||||
file_types=fileTypes,
|
||||
)
|
||||
|
||||
|
||||
class DocumentLoaderFrontNode(FrontendNode):
|
||||
def add_extra_base_classes(self) -> None:
|
||||
self.base_classes = ["Document"]
|
||||
self.output_types = ["Document"]
|
||||
|
||||
file_path_templates = {
|
||||
"AirbyteJSONLoader": build_template(suffixes=[".json"], fileTypes=["json"]),
|
||||
"CoNLLULoader": build_template(suffixes=[".csv"], fileTypes=["csv"]),
|
||||
"CSVLoader": build_template(suffixes=[".csv"], fileTypes=["csv"]),
|
||||
"UnstructuredEmailLoader": build_template(suffixes=[".eml"], fileTypes=["eml"]),
|
||||
"EverNoteLoader": build_template(suffixes=[".xml"], fileTypes=["xml"]),
|
||||
"FacebookChatLoader": build_template(suffixes=[".json"], fileTypes=["json"]),
|
||||
"GutenbergLoader": build_template(suffixes=[".txt"], fileTypes=["txt"]),
|
||||
"BSHTMLLoader": build_template(suffixes=[".html"], fileTypes=["html"]),
|
||||
"UnstructuredHTMLLoader": build_template(
|
||||
"AirbyteJSONLoader": build_file_field(suffixes=[".json"], fileTypes=["json"]),
|
||||
"CoNLLULoader": build_file_field(suffixes=[".csv"], fileTypes=["csv"]),
|
||||
"CSVLoader": build_file_field(suffixes=[".csv"], fileTypes=["csv"]),
|
||||
"UnstructuredEmailLoader": build_file_field(
|
||||
suffixes=[".eml"], fileTypes=["eml"]
|
||||
),
|
||||
"SlackDirectoryLoader": build_file_field(suffixes=[".zip"], fileTypes=["zip"]),
|
||||
"EverNoteLoader": build_file_field(suffixes=[".xml"], fileTypes=["xml"]),
|
||||
"FacebookChatLoader": build_file_field(suffixes=[".json"], fileTypes=["json"]),
|
||||
"BSHTMLLoader": build_file_field(suffixes=[".html"], fileTypes=["html"]),
|
||||
"UnstructuredHTMLLoader": build_file_field(
|
||||
suffixes=[".html"], fileTypes=["html"]
|
||||
),
|
||||
"UnstructuredImageLoader": build_template(
|
||||
"UnstructuredImageLoader": build_file_field(
|
||||
suffixes=[".jpg", ".jpeg", ".png", ".gif", ".bmp"],
|
||||
fileTypes=["jpg", "jpeg", "png", "gif", "bmp"],
|
||||
),
|
||||
"UnstructuredMarkdownLoader": build_template(
|
||||
"UnstructuredMarkdownLoader": build_file_field(
|
||||
suffixes=[".md"], fileTypes=["md"]
|
||||
),
|
||||
"PyPDFLoader": build_template(suffixes=[".pdf"], fileTypes=["pdf"]),
|
||||
"UnstructuredPowerPointLoader": build_template(
|
||||
"PyPDFLoader": build_file_field(suffixes=[".pdf"], fileTypes=["pdf"]),
|
||||
"UnstructuredPowerPointLoader": build_file_field(
|
||||
suffixes=[".pptx", ".ppt"], fileTypes=["pptx", "ppt"]
|
||||
),
|
||||
"SRTLoader": build_template(suffixes=[".srt"], fileTypes=["srt"]),
|
||||
"TelegramChatLoader": build_template(suffixes=[".json"], fileTypes=["json"]),
|
||||
"TextLoader": build_template(suffixes=[".txt"], fileTypes=["txt"]),
|
||||
"UnstructuredWordDocumentLoader": build_template(
|
||||
"SRTLoader": build_file_field(suffixes=[".srt"], fileTypes=["srt"]),
|
||||
"TelegramChatLoader": build_file_field(suffixes=[".json"], fileTypes=["json"]),
|
||||
"TextLoader": build_file_field(suffixes=[".txt"], fileTypes=["txt"]),
|
||||
"UnstructuredWordDocumentLoader": build_file_field(
|
||||
suffixes=[".docx", ".doc"], fileTypes=["docx", "doc"]
|
||||
),
|
||||
}
|
||||
|
||||
def add_extra_fields(self) -> None:
|
||||
name = None
|
||||
if self.template.type_name in self.file_path_templates:
|
||||
display_name = "Web Page"
|
||||
if self.template.type_name in {"GitLoader"}:
|
||||
# Add fields repo_path, clone_url, branch and file_filter
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
name="repo_path",
|
||||
value="",
|
||||
display_name="Path to repository",
|
||||
advanced=False,
|
||||
)
|
||||
)
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=False,
|
||||
show=True,
|
||||
name="clone_url",
|
||||
value="",
|
||||
display_name="Clone URL",
|
||||
advanced=False,
|
||||
)
|
||||
)
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
name="branch",
|
||||
value="",
|
||||
display_name="Branch",
|
||||
advanced=False,
|
||||
)
|
||||
)
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=False,
|
||||
show=True,
|
||||
name="file_filter",
|
||||
value="",
|
||||
display_name="File extensions (comma-separated)",
|
||||
advanced=False,
|
||||
)
|
||||
)
|
||||
|
||||
elif self.template.type_name in self.file_path_templates:
|
||||
self.template.add_field(self.file_path_templates[self.template.type_name])
|
||||
elif self.template.type_name in {
|
||||
"WebBaseLoader",
|
||||
|
|
@ -60,20 +115,151 @@ class DocumentLoaderFrontNode(FrontendNode):
|
|||
"HNLoader",
|
||||
"IFixitLoader",
|
||||
"IMSDbLoader",
|
||||
"GutenbergLoader",
|
||||
}:
|
||||
name = "web_path"
|
||||
elif self.template.type_name in {"GutenbergLoader"}:
|
||||
name = "file_path"
|
||||
elif self.template.type_name in {"GitbookLoader"}:
|
||||
name = "web_page"
|
||||
elif self.template.type_name in {"ReadTheDocsLoader"}:
|
||||
elif self.template.type_name in {
|
||||
"DirectoryLoader",
|
||||
"ReadTheDocsLoader",
|
||||
"NotionDirectoryLoader",
|
||||
"PyPDFDirectoryLoader",
|
||||
}:
|
||||
name = "path"
|
||||
display_name = "Local directory"
|
||||
if name:
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
name=name,
|
||||
value="",
|
||||
display_name="Web Page",
|
||||
if self.template.type_name in {"DirectoryLoader"}:
|
||||
for field in build_directory_loader_fields():
|
||||
self.template.add_field(field)
|
||||
else:
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
name=name,
|
||||
value="",
|
||||
display_name=display_name,
|
||||
)
|
||||
)
|
||||
# add a metadata field of type dict
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="code",
|
||||
required=True,
|
||||
show=True,
|
||||
name="metadata",
|
||||
value="{}",
|
||||
display_name="Metadata",
|
||||
multiline=False,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
if field.name == "metadata":
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
field.show = True
|
||||
|
||||
|
||||
def build_directory_loader_fields():
|
||||
# if loader_kwargs is None:
|
||||
# loader_kwargs = {}
|
||||
# self.path = path
|
||||
# self.glob = glob
|
||||
# self.load_hidden = load_hidden
|
||||
# self.loader_cls = loader_cls
|
||||
# self.loader_kwargs = loader_kwargs
|
||||
# self.silent_errors = silent_errors
|
||||
# self.recursive = recursive
|
||||
# self.show_progress = show_progress
|
||||
# self.use_multithreading = use_multithreading
|
||||
# self.max_concurrency = max_concurrency
|
||||
# Based on the above fields, we can build the following fields:
|
||||
# path, glob, load_hidden, silent_errors, recursive, show_progress, use_multithreading, max_concurrency
|
||||
# path
|
||||
path = TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
name="path",
|
||||
value="",
|
||||
display_name="Local directory",
|
||||
advanced=False,
|
||||
)
|
||||
# glob
|
||||
glob = TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
name="glob",
|
||||
value="**/*.txt",
|
||||
display_name="glob",
|
||||
advanced=False,
|
||||
)
|
||||
# load_hidden
|
||||
load_hidden = TemplateField(
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
name="load_hidden",
|
||||
value="False",
|
||||
display_name="Load hidden files",
|
||||
advanced=True,
|
||||
)
|
||||
# silent_errors
|
||||
silent_errors = TemplateField(
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
name="silent_errors",
|
||||
value="False",
|
||||
display_name="Silent errors",
|
||||
advanced=True,
|
||||
)
|
||||
# recursive
|
||||
recursive = TemplateField(
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
name="recursive",
|
||||
value="True",
|
||||
display_name="Recursive",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
# use_multithreading
|
||||
use_multithreading = TemplateField(
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
name="use_multithreading",
|
||||
value="True",
|
||||
display_name="Use multithreading",
|
||||
advanced=True,
|
||||
)
|
||||
# max_concurrency
|
||||
max_concurrency = TemplateField(
|
||||
field_type="int",
|
||||
required=False,
|
||||
show=True,
|
||||
name="max_concurrency",
|
||||
value=10,
|
||||
display_name="Max concurrency",
|
||||
advanced=True,
|
||||
)
|
||||
|
||||
return (
|
||||
path,
|
||||
glob,
|
||||
load_hidden,
|
||||
silent_errors,
|
||||
recursive,
|
||||
use_multithreading,
|
||||
max_concurrency,
|
||||
)
|
||||
|
|
|
|||
|
|
@ -0,0 +1,11 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from typing import Optional
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from pydantic import BaseModel
|
||||
|
||||
|
||||
class FieldFormatter(BaseModel, ABC):
|
||||
@abstractmethod
|
||||
def format(self, field: TemplateField, name: Optional[str]) -> None:
|
||||
pass
|
||||
|
|
@ -0,0 +1,162 @@
|
|||
from typing import Optional
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS
|
||||
from langflow.template.frontend_node.formatter.base import FieldFormatter
|
||||
import re
|
||||
|
||||
from langflow.utils.constants import (
|
||||
ANTHROPIC_MODELS,
|
||||
CHAT_OPENAI_MODELS,
|
||||
OPENAI_MODELS,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIAPIKeyFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
if "api_key" in field.name and "OpenAI" in str(name):
|
||||
field.display_name = "OpenAI API Key"
|
||||
field.required = False
|
||||
if field.value is None:
|
||||
field.value = ""
|
||||
|
||||
|
||||
class ModelSpecificFieldFormatter(FieldFormatter):
|
||||
MODEL_DICT = {
|
||||
"OpenAI": OPENAI_MODELS,
|
||||
"ChatOpenAI": CHAT_OPENAI_MODELS,
|
||||
"Anthropic": ANTHROPIC_MODELS,
|
||||
"ChatAnthropic": ANTHROPIC_MODELS,
|
||||
}
|
||||
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
if name in self.MODEL_DICT and field.name == "model_name":
|
||||
field.options = self.MODEL_DICT[name]
|
||||
field.is_list = True
|
||||
|
||||
|
||||
class KwargsFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
if "kwargs" in field.name.lower():
|
||||
field.advanced = True
|
||||
field.required = False
|
||||
field.show = False
|
||||
|
||||
|
||||
class APIKeyFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
if "api" in field.name.lower() and "key" in field.name.lower():
|
||||
field.required = False
|
||||
field.advanced = False
|
||||
|
||||
field.display_name = field.name.replace("_", " ").title()
|
||||
field.display_name = field.display_name.replace("Api", "API")
|
||||
|
||||
|
||||
class RemoveOptionalFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
_type = field.field_type
|
||||
field.field_type = re.sub(r"Optional\[(.*)\]", r"\1", _type)
|
||||
|
||||
|
||||
class ListTypeFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
_type = field.field_type
|
||||
is_list = "List" in _type or "Sequence" in _type
|
||||
if is_list:
|
||||
_type = re.sub(r"(List|Sequence)\[(.*)\]", r"\2", _type)
|
||||
field.is_list = True
|
||||
field.field_type = _type
|
||||
|
||||
|
||||
class DictTypeFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
_type = field.field_type
|
||||
_type = _type.replace("Mapping", "dict")
|
||||
field.field_type = _type
|
||||
|
||||
|
||||
class UnionTypeFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
_type = field.field_type
|
||||
if "Union" in _type:
|
||||
_type = _type.replace("Union[", "")[:-1]
|
||||
_type = _type.split(",")[0]
|
||||
_type = _type.replace("]", "").replace("[", "")
|
||||
field.field_type = _type
|
||||
|
||||
|
||||
class SpecialFieldFormatter(FieldFormatter):
|
||||
SPECIAL_FIELD_HANDLERS = {
|
||||
"allowed_tools": lambda field: "Tool",
|
||||
"max_value_length": lambda field: "int",
|
||||
}
|
||||
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
handler = self.SPECIAL_FIELD_HANDLERS.get(field.name)
|
||||
field.field_type = handler(field) if handler else field.field_type
|
||||
|
||||
|
||||
class ShowFieldFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
key = field.name
|
||||
required = field.required
|
||||
field.show = (
|
||||
(required and key not in ["input_variables"])
|
||||
or key in FORCE_SHOW_FIELDS
|
||||
or "api" in key
|
||||
or ("key" in key and "input" not in key and "output" not in key)
|
||||
)
|
||||
|
||||
|
||||
class PasswordFieldFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
key = field.name
|
||||
show = field.show
|
||||
if (
|
||||
any(text in key.lower() for text in {"password", "token", "api", "key"})
|
||||
and show
|
||||
):
|
||||
field.password = True
|
||||
|
||||
|
||||
class MultilineFieldFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
key = field.name
|
||||
if key in {
|
||||
"suffix",
|
||||
"prefix",
|
||||
"template",
|
||||
"examples",
|
||||
"code",
|
||||
"headers",
|
||||
"description",
|
||||
}:
|
||||
field.multiline = True
|
||||
|
||||
|
||||
class DefaultValueFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
value = field.to_dict()
|
||||
if "default" in value:
|
||||
field.value = value["default"]
|
||||
|
||||
|
||||
class HeadersDefaultValueFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
key = field.name
|
||||
if key == "headers":
|
||||
field.value = """{'Authorization': 'Bearer <token>'}"""
|
||||
|
||||
|
||||
class DictCodeFileFormatter(FieldFormatter):
|
||||
def format(self, field: TemplateField, name: Optional[str] = None) -> None:
|
||||
key = field.name
|
||||
value = field.to_dict()
|
||||
_type = value["type"]
|
||||
if "dict" in _type.lower():
|
||||
if key == "dict_":
|
||||
field.field_type = "file"
|
||||
field.suffixes = [".json", ".yaml", ".yml"]
|
||||
field.file_types = ["json", "yaml", "yml"]
|
||||
else:
|
||||
field.field_type = "code"
|
||||
|
|
@ -1,10 +1,56 @@
|
|||
import json
|
||||
from typing import Optional
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.constants import CTRANSFORMERS_DEFAULT_CONFIG
|
||||
from langflow.template.frontend_node.constants import OPENAI_API_BASE_INFO
|
||||
|
||||
|
||||
class LLMFrontendNode(FrontendNode):
|
||||
def add_extra_fields(self) -> None:
|
||||
if "VertexAI" in self.template.type_name:
|
||||
# Add credentials field which should of type file.
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="file",
|
||||
required=False,
|
||||
show=True,
|
||||
name="credentials",
|
||||
value="",
|
||||
suffixes=[".json"],
|
||||
file_types=["json"],
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_vertex_field(field: TemplateField, name: str):
|
||||
if "VertexAI" in name:
|
||||
advanced_fields = [
|
||||
"tuned_model_name",
|
||||
"verbose",
|
||||
"top_p",
|
||||
"top_k",
|
||||
"max_output_tokens",
|
||||
]
|
||||
if field.name in advanced_fields:
|
||||
field.advanced = True
|
||||
show_fields = [
|
||||
"tuned_model_name",
|
||||
"verbose",
|
||||
"project",
|
||||
"location",
|
||||
"credentials",
|
||||
"max_output_tokens",
|
||||
"model_name",
|
||||
"temperature",
|
||||
"top_p",
|
||||
"top_k",
|
||||
]
|
||||
|
||||
if field.name in show_fields:
|
||||
field.show = True
|
||||
|
||||
@staticmethod
|
||||
def format_openai_field(field: TemplateField):
|
||||
if "openai" in field.name.lower():
|
||||
|
|
@ -15,6 +61,13 @@ class LLMFrontendNode(FrontendNode):
|
|||
if "key" not in field.name.lower() and "token" not in field.name.lower():
|
||||
field.password = False
|
||||
|
||||
if field.name == "openai_api_base":
|
||||
field.info = OPENAI_API_BASE_INFO
|
||||
|
||||
def add_extra_base_classes(self) -> None:
|
||||
if "BaseLLM" not in self.base_classes:
|
||||
self.base_classes.append("BaseLLM")
|
||||
|
||||
@staticmethod
|
||||
def format_azure_field(field: TemplateField):
|
||||
if field.name == "model_name":
|
||||
|
|
@ -31,6 +84,13 @@ class LLMFrontendNode(FrontendNode):
|
|||
field.show = True
|
||||
field.advanced = not field.required
|
||||
|
||||
@staticmethod
|
||||
def format_ctransformers_field(field: TemplateField):
|
||||
if field.name == "config":
|
||||
field.show = True
|
||||
field.advanced = True
|
||||
field.value = json.dumps(CTRANSFORMERS_DEFAULT_CONFIG, indent=2)
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
display_names_dict = {
|
||||
|
|
@ -38,10 +98,13 @@ class LLMFrontendNode(FrontendNode):
|
|||
}
|
||||
FrontendNode.format_field(field, name)
|
||||
LLMFrontendNode.format_openai_field(field)
|
||||
LLMFrontendNode.format_ctransformers_field(field)
|
||||
if name and "azure" in name.lower():
|
||||
LLMFrontendNode.format_azure_field(field)
|
||||
if name and "llama" in name.lower():
|
||||
LLMFrontendNode.format_llama_field(field)
|
||||
if name and "vertex" in name.lower():
|
||||
LLMFrontendNode.format_vertex_field(field, name)
|
||||
SHOW_FIELDS = ["repo_id"]
|
||||
if field.name in SHOW_FIELDS:
|
||||
field.show = True
|
||||
|
|
@ -77,6 +140,17 @@ class LLMFrontendNode(FrontendNode):
|
|||
"model_file",
|
||||
"model_type",
|
||||
"deployment_name",
|
||||
"credentials",
|
||||
]:
|
||||
field.advanced = False
|
||||
field.show = True
|
||||
if field.name == "credentials":
|
||||
field.field_type = "file"
|
||||
if name == "VertexAI" and field.name not in [
|
||||
"callbacks",
|
||||
"client",
|
||||
"stop",
|
||||
"tags",
|
||||
"cache",
|
||||
]:
|
||||
field.show = True
|
||||
|
|
|
|||
|
|
@ -2,11 +2,24 @@ from typing import Optional
|
|||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.frontend_node.constants import INPUT_KEY_INFO, OUTPUT_KEY_INFO
|
||||
from langflow.template.template.base import Template
|
||||
from langchain.memory.chat_message_histories.postgres import DEFAULT_CONNECTION_STRING
|
||||
from langchain.memory.chat_message_histories.mongodb import (
|
||||
DEFAULT_COLLECTION_NAME,
|
||||
DEFAULT_DBNAME,
|
||||
)
|
||||
|
||||
|
||||
class MemoryFrontendNode(FrontendNode):
|
||||
#! Needs testing
|
||||
def add_extra_fields(self) -> None:
|
||||
# chat history should have another way to add common field?
|
||||
# prevent adding incorect field in ChatMessageHistory
|
||||
base_message_classes = ["BaseEntityStore", "BaseChatMessageHistory"]
|
||||
if any(base_class in self.base_classes for base_class in base_message_classes):
|
||||
return
|
||||
|
||||
# add return_messages field
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
|
|
@ -18,6 +31,28 @@ class MemoryFrontendNode(FrontendNode):
|
|||
value=False,
|
||||
)
|
||||
)
|
||||
# add input_key and output_key str fields
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=False,
|
||||
show=True,
|
||||
name="input_key",
|
||||
advanced=True,
|
||||
value="",
|
||||
)
|
||||
)
|
||||
if self.template.type_name not in {"VectorStoreRetrieverMemory"}:
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=False,
|
||||
show=True,
|
||||
name="output_key",
|
||||
advanced=True,
|
||||
value="",
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
|
|
@ -36,3 +71,123 @@ class MemoryFrontendNode(FrontendNode):
|
|||
field.required = False
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
if field.name in {"input_key", "output_key"}:
|
||||
field.required = False
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
field.value = ""
|
||||
field.info = (
|
||||
INPUT_KEY_INFO if field.name == "input_key" else OUTPUT_KEY_INFO
|
||||
)
|
||||
|
||||
if field.name == "memory_key":
|
||||
field.value = "chat_history"
|
||||
if field.name == "chat_memory":
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
field.required = False
|
||||
if field.name == "url":
|
||||
field.show = True
|
||||
if field.name == "entity_store":
|
||||
field.show = False
|
||||
if name == "ConversationEntityMemory" and field.name == "memory_key":
|
||||
field.show = False
|
||||
field.required = False
|
||||
|
||||
if name == "MotorheadMemory":
|
||||
if field.name == "chat_memory":
|
||||
field.show = False
|
||||
field.required = False
|
||||
elif field.name == "client_id":
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
|
||||
|
||||
class PostgresChatMessageHistoryFrontendNode(MemoryFrontendNode):
|
||||
name: str = "PostgresChatMessageHistory"
|
||||
template: Template = Template(
|
||||
type_name="PostgresChatMessageHistory",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
multiline=False,
|
||||
name="session_id",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
name="connection_string",
|
||||
value=DEFAULT_CONNECTION_STRING,
|
||||
),
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
multiline=False,
|
||||
value="message_store",
|
||||
name="table_name",
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = "Memory store with Postgres"
|
||||
base_classes: list[str] = ["PostgresChatMessageHistory", "BaseChatMessageHistory"]
|
||||
|
||||
|
||||
class MongoDBChatMessageHistoryFrontendNode(MemoryFrontendNode):
|
||||
name: str = "MongoDBChatMessageHistory"
|
||||
template: Template = Template(
|
||||
# langchain/memory/chat_message_histories/mongodb.py
|
||||
# connection_string: str,
|
||||
# session_id: str,
|
||||
# database_name: str = DEFAULT_DBNAME,
|
||||
# collection_name: str = DEFAULT_COLLECTION_NAME,
|
||||
type_name="MongoDBChatMessageHistory",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
multiline=False,
|
||||
name="session_id",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
name="connection_string",
|
||||
value="",
|
||||
info="MongoDB connection string (e.g mongodb://mongo_user:password123@mongo:27017)",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
multiline=False,
|
||||
value=DEFAULT_DBNAME,
|
||||
name="database_name",
|
||||
),
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
multiline=False,
|
||||
value=DEFAULT_COLLECTION_NAME,
|
||||
name="collection_name",
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = "Memory store with MongoDB"
|
||||
base_classes: list[str] = ["MongoDBChatMessageHistory", "BaseChatMessageHistory"]
|
||||
|
|
|
|||
|
|
@ -0,0 +1,10 @@
|
|||
from typing import Optional
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
|
||||
|
||||
class OutputParserFrontendNode(FrontendNode):
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
field.show = True
|
||||
15
src/backend/langflow/template/frontend_node/retrievers.py
Normal file
15
src/backend/langflow/template/frontend_node/retrievers.py
Normal file
|
|
@ -0,0 +1,15 @@
|
|||
from typing import Optional
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
|
||||
|
||||
class RetrieverFrontendNode(FrontendNode):
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
# Define common field attributes
|
||||
field.show = True
|
||||
if field.name == "parser_key":
|
||||
field.display_name = "Parser Key"
|
||||
field.password = False
|
||||
|
|
@ -1,15 +1,21 @@
|
|||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langchain.text_splitter import Language
|
||||
|
||||
|
||||
class TextSplittersFrontendNode(FrontendNode):
|
||||
def add_extra_base_classes(self) -> None:
|
||||
self.base_classes = ["Document"]
|
||||
self.output_types = ["Document"]
|
||||
|
||||
def add_extra_fields(self) -> None:
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="BaseLoader",
|
||||
field_type="Document",
|
||||
required=True,
|
||||
show=True,
|
||||
name="documents",
|
||||
is_list=True,
|
||||
)
|
||||
)
|
||||
name = "separator"
|
||||
|
|
@ -17,12 +23,30 @@ class TextSplittersFrontendNode(FrontendNode):
|
|||
name = "separator"
|
||||
elif self.template.type_name == "RecursiveCharacterTextSplitter":
|
||||
name = "separators"
|
||||
# Add a field for type of separator
|
||||
# which will have Text or any value from the
|
||||
# Language enum
|
||||
options = [x.value for x in Language] + ["Text"]
|
||||
options.sort()
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
name="separator_type",
|
||||
advanced=False,
|
||||
is_list=True,
|
||||
options=options,
|
||||
value="Text",
|
||||
display_name="Separator Type",
|
||||
)
|
||||
)
|
||||
self.template.add_field(
|
||||
TemplateField(
|
||||
field_type="str",
|
||||
required=True,
|
||||
show=True,
|
||||
value=".",
|
||||
value="\\n",
|
||||
name=name,
|
||||
display_name="Separator",
|
||||
)
|
||||
|
|
|
|||
|
|
@ -1,7 +1,9 @@
|
|||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
from langflow.template.template.base import Template
|
||||
from langflow.utils.constants import DEFAULT_PYTHON_FUNCTION
|
||||
from langflow.utils.constants import (
|
||||
DEFAULT_PYTHON_FUNCTION,
|
||||
)
|
||||
|
||||
|
||||
class ToolNode(FrontendNode):
|
||||
|
|
@ -53,7 +55,7 @@ class ToolNode(FrontendNode):
|
|||
],
|
||||
)
|
||||
description: str = "Converts a chain, agent or function into a tool."
|
||||
base_classes: list[str] = ["Tool"]
|
||||
base_classes: list[str] = ["Tool", "BaseTool"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
|
@ -96,10 +98,20 @@ class PythonFunctionToolNode(FrontendNode):
|
|||
name="code",
|
||||
advanced=False,
|
||||
),
|
||||
TemplateField(
|
||||
field_type="bool",
|
||||
required=True,
|
||||
placeholder="",
|
||||
is_list=False,
|
||||
show=True,
|
||||
multiline=False,
|
||||
value=False,
|
||||
name="return_direct",
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = "Python function to be executed."
|
||||
base_classes: list[str] = ["Tool"]
|
||||
base_classes: list[str] = ["BaseTool", "Tool"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
|
|
|||
|
|
@ -1,4 +1,4 @@
|
|||
from typing import Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.template.frontend_node.base import FrontendNode
|
||||
|
|
@ -6,6 +6,19 @@ from langflow.template.frontend_node.base import FrontendNode
|
|||
|
||||
class VectorStoreFrontendNode(FrontendNode):
|
||||
def add_extra_fields(self) -> None:
|
||||
extra_fields: List[TemplateField] = []
|
||||
# Add search_kwargs field
|
||||
extra_field = TemplateField(
|
||||
name="search_kwargs",
|
||||
field_type="code",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
value="{}",
|
||||
)
|
||||
extra_fields.append(extra_field)
|
||||
if self.template.type_name == "Weaviate":
|
||||
extra_field = TemplateField(
|
||||
name="weaviate_url",
|
||||
|
|
@ -17,17 +30,203 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
multiline=False,
|
||||
value="http://localhost:8080",
|
||||
)
|
||||
# Add client_kwargs field
|
||||
extra_field2 = TemplateField(
|
||||
name="client_kwargs",
|
||||
field_type="code",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
value="{}",
|
||||
)
|
||||
extra_fields.extend((extra_field, extra_field2))
|
||||
|
||||
self.template.add_field(extra_field)
|
||||
elif self.template.type_name == "Chroma":
|
||||
# New bool field for persist parameter
|
||||
extra_field = TemplateField(
|
||||
name="persist",
|
||||
field_type="bool",
|
||||
required=False,
|
||||
show=True,
|
||||
advanced=False,
|
||||
value=False,
|
||||
display_name="Persist",
|
||||
)
|
||||
extra_fields.append(extra_field)
|
||||
elif self.template.type_name == "Pinecone":
|
||||
# add pinecone_api_key and pinecone_env
|
||||
extra_field = TemplateField(
|
||||
name="pinecone_api_key",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
password=True,
|
||||
value="",
|
||||
)
|
||||
extra_field2 = TemplateField(
|
||||
name="pinecone_env",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
value="",
|
||||
)
|
||||
extra_fields.extend((extra_field, extra_field2))
|
||||
elif self.template.type_name == "FAISS":
|
||||
extra_field = TemplateField(
|
||||
name="folder_path",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
display_name="Local Path",
|
||||
value="",
|
||||
)
|
||||
extra_field2 = TemplateField(
|
||||
name="index_name",
|
||||
field_type="str",
|
||||
required=False,
|
||||
show=True,
|
||||
advanced=False,
|
||||
value="",
|
||||
display_name="Index Name",
|
||||
)
|
||||
extra_fields.extend((extra_field, extra_field2))
|
||||
elif self.template.type_name == "SupabaseVectorStore":
|
||||
self.display_name = "Supabase"
|
||||
# Add table_name and query_name
|
||||
extra_field = TemplateField(
|
||||
name="table_name",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
value="",
|
||||
)
|
||||
extra_field2 = TemplateField(
|
||||
name="query_name",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
value="",
|
||||
)
|
||||
# Add supabase_url and supabase_service_key
|
||||
extra_field3 = TemplateField(
|
||||
name="supabase_url",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
value="",
|
||||
)
|
||||
extra_field4 = TemplateField(
|
||||
name="supabase_service_key",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
password=True,
|
||||
value="",
|
||||
)
|
||||
extra_fields.extend((extra_field, extra_field2, extra_field3, extra_field4))
|
||||
|
||||
elif self.template.type_name == "MongoDBAtlasVectorSearch":
|
||||
self.display_name = "MongoDB Atlas"
|
||||
|
||||
extra_field = TemplateField(
|
||||
name="mongodb_atlas_cluster_uri",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
display_name="MongoDB Atlas Cluster URI",
|
||||
value="",
|
||||
)
|
||||
extra_field2 = TemplateField(
|
||||
name="collection_name",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
display_name="Collection Name",
|
||||
value="",
|
||||
)
|
||||
extra_field3 = TemplateField(
|
||||
name="db_name",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
display_name="Database Name",
|
||||
value="",
|
||||
)
|
||||
extra_field4 = TemplateField(
|
||||
name="index_name",
|
||||
field_type="str",
|
||||
required=False,
|
||||
placeholder="",
|
||||
show=True,
|
||||
advanced=True,
|
||||
multiline=False,
|
||||
display_name="Index Name",
|
||||
value="",
|
||||
)
|
||||
extra_fields.extend((extra_field, extra_field2, extra_field3, extra_field4))
|
||||
|
||||
if extra_fields:
|
||||
for field in extra_fields:
|
||||
self.template.add_field(field)
|
||||
|
||||
def add_extra_base_classes(self) -> None:
|
||||
self.base_classes.append("BaseRetriever")
|
||||
self.base_classes.extend(("BaseRetriever", "VectorStoreRetriever"))
|
||||
|
||||
@staticmethod
|
||||
def format_field(field: TemplateField, name: Optional[str] = None) -> None:
|
||||
FrontendNode.format_field(field, name)
|
||||
# Define common field attributes
|
||||
basic_fields = ["work_dir", "collection_name", "api_key", "location"]
|
||||
basic_fields = [
|
||||
"work_dir",
|
||||
"collection_name",
|
||||
"api_key",
|
||||
"location",
|
||||
"persist_directory",
|
||||
"persist",
|
||||
"weaviate_url",
|
||||
"index_name",
|
||||
"namespace",
|
||||
"folder_path",
|
||||
"table_name",
|
||||
"query_name",
|
||||
"supabase_url",
|
||||
"supabase_service_key",
|
||||
"mongodb_atlas_cluster_uri",
|
||||
"collection_name",
|
||||
"db_name",
|
||||
]
|
||||
advanced_fields = [
|
||||
"n_dim",
|
||||
"key",
|
||||
|
|
@ -43,17 +242,24 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
"https",
|
||||
"prefer_grpc",
|
||||
"grpc_port",
|
||||
"pinecone_api_key",
|
||||
"pinecone_env",
|
||||
"client_kwargs",
|
||||
"search_kwargs",
|
||||
]
|
||||
|
||||
# Check and set field attributes
|
||||
if field.name == "texts":
|
||||
# if field.name is "texts" it has to be replaced
|
||||
# when instantiating the vectorstores
|
||||
field.name = "documents"
|
||||
field.field_type = "TextSplitter"
|
||||
field.display_name = "Text Splitter"
|
||||
field.required = True
|
||||
|
||||
field.field_type = "Document"
|
||||
field.display_name = "Documents"
|
||||
field.required = False
|
||||
field.show = True
|
||||
field.advanced = False
|
||||
|
||||
field.is_list = True
|
||||
elif "embedding" in field.name:
|
||||
# for backwards compatibility
|
||||
field.name = "embedding"
|
||||
|
|
@ -78,5 +284,6 @@ class VectorStoreFrontendNode(FrontendNode):
|
|||
field.advanced = True
|
||||
if "key" in field.name:
|
||||
field.password = False
|
||||
# TODO: Weaviate requires weaviate_url to be passed as it is not part of
|
||||
# the class or from_texts method. We need the add_extra_fields to fix this
|
||||
|
||||
elif field.name == "text_key":
|
||||
field.show = False
|
||||
|
|
|
|||
|
|
@ -3,6 +3,7 @@ from typing import Callable, Optional, Union
|
|||
from pydantic import BaseModel
|
||||
|
||||
from langflow.template.field.base import TemplateField
|
||||
from langflow.utils.constants import DIRECT_TYPES
|
||||
|
||||
|
||||
class Template(BaseModel):
|
||||
|
|
@ -18,8 +19,15 @@ class Template(BaseModel):
|
|||
for field in self.fields:
|
||||
format_field_func(field, name)
|
||||
|
||||
def sort_fields(self):
|
||||
# first sort alphabetically
|
||||
# then sort fields so that fields that have .field_type in DIRECT_TYPES are first
|
||||
self.fields.sort(key=lambda x: x.name)
|
||||
self.fields.sort(key=lambda x: x.field_type in DIRECT_TYPES, reverse=False)
|
||||
|
||||
def to_dict(self, format_field_func=None):
|
||||
self.process_fields(self.type_name, format_field_func)
|
||||
self.sort_fields()
|
||||
result = {field.name: field.to_dict() for field in self.fields}
|
||||
result["_type"] = self.type_name # type: ignore
|
||||
return result
|
||||
|
|
|
|||
|
|
@ -17,18 +17,29 @@ CHAT_OPENAI_MODELS = [
|
|||
]
|
||||
|
||||
ANTHROPIC_MODELS = [
|
||||
"claude-v1", # largest model, ideal for a wide range of more complex tasks.
|
||||
"claude-v1-100k", # An enhanced version of claude-v1 with a 100,000 token (roughly 75,000 word) context window.
|
||||
"claude-instant-v1", # A smaller model with far lower latency, sampling at roughly 40 words/sec!
|
||||
"claude-instant-v1-100k", # Like claude-instant-v1 with a 100,000 token context window but retains its performance.
|
||||
# largest model, ideal for a wide range of more complex tasks.
|
||||
"claude-v1",
|
||||
# An enhanced version of claude-v1 with a 100,000 token (roughly 75,000 word) context window.
|
||||
"claude-v1-100k",
|
||||
# A smaller model with far lower latency, sampling at roughly 40 words/sec!
|
||||
"claude-instant-v1",
|
||||
# Like claude-instant-v1 with a 100,000 token context window but retains its performance.
|
||||
"claude-instant-v1-100k",
|
||||
# Specific sub-versions of the above models:
|
||||
"claude-v1.3", # Vs claude-v1.2: better instruction-following, code, and non-English dialogue and writing.
|
||||
"claude-v1.3-100k", # An enhanced version of claude-v1.3 with a 100,000 token (roughly 75,000 word) context window.
|
||||
"claude-v1.2", # Vs claude-v1.1: small adv in general helpfulness, instruction following, coding, and other tasks.
|
||||
"claude-v1.0", # An earlier version of claude-v1.
|
||||
"claude-instant-v1.1", # Latest version of claude-instant-v1. Better than claude-instant-v1.0 at most tasks.
|
||||
"claude-instant-v1.1-100k", # Version of claude-instant-v1.1 with a 100K token context window.
|
||||
"claude-instant-v1.0", # An earlier version of claude-instant-v1.
|
||||
# Vs claude-v1.2: better instruction-following, code, and non-English dialogue and writing.
|
||||
"claude-v1.3",
|
||||
# An enhanced version of claude-v1.3 with a 100,000 token (roughly 75,000 word) context window.
|
||||
"claude-v1.3-100k",
|
||||
# Vs claude-v1.1: small adv in general helpfulness, instruction following, coding, and other tasks.
|
||||
"claude-v1.2",
|
||||
# An earlier version of claude-v1.
|
||||
"claude-v1.0",
|
||||
# Latest version of claude-instant-v1. Better than claude-instant-v1.0 at most tasks.
|
||||
"claude-instant-v1.1",
|
||||
# Version of claude-instant-v1.1 with a 100K token context window.
|
||||
"claude-instant-v1.1-100k",
|
||||
# An earlier version of claude-instant-v1.
|
||||
"claude-instant-v1.0",
|
||||
]
|
||||
|
||||
DEFAULT_PYTHON_FUNCTION = """
|
||||
|
|
@ -36,3 +47,5 @@ def python_function(text: str) -> str:
|
|||
\"\"\"This is a default python function that returns the input text\"\"\"
|
||||
return text
|
||||
"""
|
||||
|
||||
DIRECT_TYPES = ["str", "bool", "code", "int", "float", "Any", "prompt"]
|
||||
|
|
|
|||
|
|
@ -6,7 +6,7 @@ from rich.logging import RichHandler
|
|||
logger = logging.getLogger("langflow")
|
||||
|
||||
|
||||
def configure(log_level: str = "INFO", log_file: Path = None): # type: ignore
|
||||
def configure(log_level: str = "DEBUG", log_file: Path = None): # type: ignore
|
||||
log_format = "%(asctime)s - %(levelname)s - %(message)s"
|
||||
log_level_value = getattr(logging, log_level.upper(), logging.INFO)
|
||||
|
||||
|
|
|
|||
2
src/backend/langflow/utils/types.py
Normal file
2
src/backend/langflow/utils/types.py
Normal file
|
|
@ -0,0 +1,2 @@
|
|||
class Prompt:
|
||||
pass
|
||||
|
|
@ -1,13 +1,15 @@
|
|||
import importlib
|
||||
import inspect
|
||||
import re
|
||||
import inspect
|
||||
import importlib
|
||||
from functools import wraps
|
||||
from typing import Dict, Optional
|
||||
from typing import Optional, Dict, Any, Union
|
||||
|
||||
from docstring_parser import parse # type: ignore
|
||||
|
||||
from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS
|
||||
from langflow.utils import constants
|
||||
from langflow.utils.logger import logger
|
||||
from multiprocess import cpu_count # type: ignore
|
||||
|
||||
|
||||
def build_template_from_function(
|
||||
|
|
@ -165,6 +167,7 @@ def build_template_from_method(
|
|||
"required": param.default == param.empty,
|
||||
}
|
||||
for name, param in params.items()
|
||||
if name not in ["self", "kwargs", "args"]
|
||||
},
|
||||
}
|
||||
|
||||
|
|
@ -213,101 +216,6 @@ def get_default_factory(module: str, function: str):
|
|||
return None
|
||||
|
||||
|
||||
def format_dict(d, name: Optional[str] = None):
|
||||
"""
|
||||
Formats a dictionary by removing certain keys and modifying the
|
||||
values of other keys.
|
||||
|
||||
Args:
|
||||
d: the dictionary to format
|
||||
name: the name of the class to format
|
||||
|
||||
Returns:
|
||||
A new dictionary with the desired modifications applied.
|
||||
"""
|
||||
|
||||
# Process remaining keys
|
||||
for key, value in d.items():
|
||||
if key == "_type":
|
||||
continue
|
||||
|
||||
_type = value["type"]
|
||||
|
||||
# Remove 'Optional' wrapper
|
||||
if "Optional" in _type:
|
||||
_type = _type.replace("Optional[", "")[:-1]
|
||||
|
||||
# Check for list type
|
||||
if "List" in _type or "Sequence" in _type or "Set" in _type:
|
||||
_type = _type.replace("List[", "")[:-1]
|
||||
value["list"] = True
|
||||
else:
|
||||
value["list"] = False
|
||||
|
||||
# Replace 'Mapping' with 'dict'
|
||||
if "Mapping" in _type:
|
||||
_type = _type.replace("Mapping", "dict")
|
||||
|
||||
# Change type from str to Tool
|
||||
value["type"] = "Tool" if key in ["allowed_tools"] else _type
|
||||
|
||||
value["type"] = "int" if key in ["max_value_length"] else value["type"]
|
||||
|
||||
# Show or not field
|
||||
value["show"] = bool(
|
||||
(value["required"] and key not in ["input_variables"])
|
||||
or key in FORCE_SHOW_FIELDS
|
||||
or "api_key" in key
|
||||
)
|
||||
|
||||
# Add password field
|
||||
value["password"] = any(
|
||||
text in key.lower() for text in ["password", "token", "api", "key"]
|
||||
)
|
||||
|
||||
# Add multline
|
||||
value["multiline"] = key in [
|
||||
"suffix",
|
||||
"prefix",
|
||||
"template",
|
||||
"examples",
|
||||
"code",
|
||||
"headers",
|
||||
"format_instructions",
|
||||
]
|
||||
|
||||
# Replace dict type with str
|
||||
if "dict" in value["type"].lower():
|
||||
value["type"] = "code"
|
||||
|
||||
if key == "dict_":
|
||||
value["type"] = "file"
|
||||
value["suffixes"] = [".json", ".yaml", ".yml"]
|
||||
value["fileTypes"] = ["json", "yaml", "yml"]
|
||||
|
||||
# Replace default value with actual value
|
||||
if "default" in value:
|
||||
value["value"] = value["default"]
|
||||
value.pop("default")
|
||||
|
||||
if key == "headers":
|
||||
value[
|
||||
"value"
|
||||
] = """{'Authorization':
|
||||
'Bearer <token>'}"""
|
||||
# Add options to openai
|
||||
if name == "OpenAI" and key == "model_name":
|
||||
value["options"] = constants.OPENAI_MODELS
|
||||
value["list"] = True
|
||||
elif name == "ChatOpenAI" and key == "model_name":
|
||||
value["options"] = constants.CHAT_OPENAI_MODELS
|
||||
value["list"] = True
|
||||
elif (name == "Anthropic" or name == "ChatAnthropic") and key == "model_name":
|
||||
value["options"] = constants.ANTHROPIC_MODELS
|
||||
value["list"] = True
|
||||
return d
|
||||
|
||||
|
||||
def update_verbose(d: dict, new_value: bool) -> dict:
|
||||
"""
|
||||
Recursively updates the value of the 'verbose' key in a dictionary.
|
||||
|
|
@ -338,3 +246,219 @@ def sync_to_async(func):
|
|||
return func(*args, **kwargs)
|
||||
|
||||
return async_wrapper
|
||||
|
||||
|
||||
def format_dict(
|
||||
dictionary: Dict[str, Any], class_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Formats a dictionary by removing certain keys and modifying the
|
||||
values of other keys.
|
||||
|
||||
Returns:
|
||||
A new dictionary with the desired modifications applied.
|
||||
"""
|
||||
|
||||
for key, value in dictionary.items():
|
||||
if key == "_type":
|
||||
continue
|
||||
|
||||
_type: Union[str, type] = get_type(value)
|
||||
|
||||
_type = remove_optional_wrapper(_type)
|
||||
_type = check_list_type(_type, value)
|
||||
_type = replace_mapping_with_dict(_type)
|
||||
|
||||
value["type"] = get_formatted_type(key, _type)
|
||||
value["show"] = should_show_field(value, key)
|
||||
value["password"] = is_password_field(key)
|
||||
value["multiline"] = is_multiline_field(key)
|
||||
|
||||
replace_dict_type_with_code(value)
|
||||
|
||||
if key == "dict_":
|
||||
set_dict_file_attributes(value)
|
||||
|
||||
replace_default_value_with_actual(value)
|
||||
|
||||
if key == "headers":
|
||||
set_headers_value(value)
|
||||
|
||||
add_options_to_field(value, class_name, key)
|
||||
|
||||
return dictionary
|
||||
|
||||
|
||||
def get_type(value: Any) -> Union[str, type]:
|
||||
"""
|
||||
Retrieves the type value from the dictionary.
|
||||
|
||||
Returns:
|
||||
The type value.
|
||||
"""
|
||||
_type = value["type"]
|
||||
|
||||
return _type if isinstance(_type, str) else _type.__name__
|
||||
|
||||
|
||||
def remove_optional_wrapper(_type: Union[str, type]) -> str:
|
||||
"""
|
||||
Removes the 'Optional' wrapper from the type string.
|
||||
|
||||
Returns:
|
||||
The type string with the 'Optional' wrapper removed.
|
||||
"""
|
||||
if isinstance(_type, type):
|
||||
_type = str(_type)
|
||||
if "Optional" in _type:
|
||||
_type = _type.replace("Optional[", "")[:-1]
|
||||
|
||||
return _type
|
||||
|
||||
|
||||
def check_list_type(_type: str, value: Dict[str, Any]) -> str:
|
||||
"""
|
||||
Checks if the type is a list type and modifies the value accordingly.
|
||||
|
||||
Returns:
|
||||
The modified type string.
|
||||
"""
|
||||
if any(list_type in _type for list_type in ["List", "Sequence", "Set"]):
|
||||
_type = (
|
||||
_type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
|
||||
)
|
||||
value["list"] = True
|
||||
else:
|
||||
value["list"] = False
|
||||
|
||||
return _type
|
||||
|
||||
|
||||
def replace_mapping_with_dict(_type: str) -> str:
|
||||
"""
|
||||
Replaces 'Mapping' with 'dict' in the type string.
|
||||
|
||||
Returns:
|
||||
The modified type string.
|
||||
"""
|
||||
if "Mapping" in _type:
|
||||
_type = _type.replace("Mapping", "dict")
|
||||
|
||||
return _type
|
||||
|
||||
|
||||
def get_formatted_type(key: str, _type: str) -> str:
|
||||
"""
|
||||
Formats the type value based on the given key.
|
||||
|
||||
Returns:
|
||||
The formatted type value.
|
||||
"""
|
||||
if key == "allowed_tools":
|
||||
return "Tool"
|
||||
|
||||
elif key == "max_value_length":
|
||||
return "int"
|
||||
|
||||
return _type
|
||||
|
||||
|
||||
def should_show_field(value: Dict[str, Any], key: str) -> bool:
|
||||
"""
|
||||
Determines if the field should be shown or not.
|
||||
|
||||
Returns:
|
||||
True if the field should be shown, False otherwise.
|
||||
"""
|
||||
return (
|
||||
(value["required"] and key != "input_variables")
|
||||
or key in FORCE_SHOW_FIELDS
|
||||
or any(text in key.lower() for text in ["password", "token", "api", "key"])
|
||||
)
|
||||
|
||||
|
||||
def is_password_field(key: str) -> bool:
|
||||
"""
|
||||
Determines if the field is a password field.
|
||||
|
||||
Returns:
|
||||
True if the field is a password field, False otherwise.
|
||||
"""
|
||||
return any(text in key.lower() for text in ["password", "token", "api", "key"])
|
||||
|
||||
|
||||
def is_multiline_field(key: str) -> bool:
|
||||
"""
|
||||
Determines if the field is a multiline field.
|
||||
|
||||
Returns:
|
||||
True if the field is a multiline field, False otherwise.
|
||||
"""
|
||||
return key in {
|
||||
"suffix",
|
||||
"prefix",
|
||||
"template",
|
||||
"examples",
|
||||
"code",
|
||||
"headers",
|
||||
"format_instructions",
|
||||
}
|
||||
|
||||
|
||||
def replace_dict_type_with_code(value: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Replaces the type value with 'code' if the type is a dict.
|
||||
"""
|
||||
if "dict" in value["type"].lower():
|
||||
value["type"] = "code"
|
||||
|
||||
|
||||
def set_dict_file_attributes(value: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Sets the file attributes for the 'dict_' key.
|
||||
"""
|
||||
value["type"] = "file"
|
||||
value["suffixes"] = [".json", ".yaml", ".yml"]
|
||||
value["fileTypes"] = ["json", "yaml", "yml"]
|
||||
|
||||
|
||||
def replace_default_value_with_actual(value: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Replaces the default value with the actual value.
|
||||
"""
|
||||
if "default" in value:
|
||||
value["value"] = value["default"]
|
||||
value.pop("default")
|
||||
|
||||
|
||||
def set_headers_value(value: Dict[str, Any]) -> None:
|
||||
"""
|
||||
Sets the value for the 'headers' key.
|
||||
"""
|
||||
value["value"] = """{'Authorization': 'Bearer <token>'}"""
|
||||
|
||||
|
||||
def add_options_to_field(
|
||||
value: Dict[str, Any], class_name: Optional[str], key: str
|
||||
) -> None:
|
||||
"""
|
||||
Adds options to the field based on the class name and key.
|
||||
"""
|
||||
options_map = {
|
||||
"OpenAI": constants.OPENAI_MODELS,
|
||||
"ChatOpenAI": constants.CHAT_OPENAI_MODELS,
|
||||
"Anthropic": constants.ANTHROPIC_MODELS,
|
||||
"ChatAnthropic": constants.ANTHROPIC_MODELS,
|
||||
}
|
||||
|
||||
if class_name in options_map and key == "model_name":
|
||||
value["options"] = options_map[class_name]
|
||||
value["list"] = True
|
||||
value["value"] = options_map[class_name][0]
|
||||
|
||||
|
||||
def get_number_of_workers(workers=None):
|
||||
if workers == -1 or workers is None:
|
||||
workers = (cpu_count() * 2) + 1
|
||||
logger.debug(f"Number of workers: {workers}")
|
||||
return workers
|
||||
|
|
|
|||
|
|
@ -163,9 +163,77 @@ def create_function(code, function_name):
|
|||
return wrapped_function
|
||||
|
||||
|
||||
def create_class(code, class_name):
|
||||
if not hasattr(ast, "TypeIgnore"):
|
||||
|
||||
class TypeIgnore(ast.AST):
|
||||
_fields = ()
|
||||
|
||||
ast.TypeIgnore = TypeIgnore
|
||||
|
||||
module = ast.parse(code)
|
||||
exec_globals = globals().copy()
|
||||
|
||||
for node in module.body:
|
||||
if isinstance(node, ast.Import):
|
||||
for alias in node.names:
|
||||
try:
|
||||
exec_globals[alias.asname or alias.name] = importlib.import_module(
|
||||
alias.name
|
||||
)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Module {alias.name} not found. Please install it and try again."
|
||||
) from e
|
||||
elif isinstance(node, ast.ImportFrom):
|
||||
try:
|
||||
imported_module = importlib.import_module(node.module)
|
||||
for alias in node.names:
|
||||
exec_globals[alias.name] = getattr(imported_module, alias.name)
|
||||
except ModuleNotFoundError as e:
|
||||
raise ModuleNotFoundError(
|
||||
f"Module {node.module} not found. Please install it and try again."
|
||||
) from e
|
||||
|
||||
class_code = next(
|
||||
node
|
||||
for node in module.body
|
||||
if isinstance(node, ast.ClassDef) and node.name == class_name
|
||||
)
|
||||
class_code.parent = None
|
||||
code_obj = compile(
|
||||
ast.Module(body=[class_code], type_ignores=[]), "<string>", "exec"
|
||||
)
|
||||
# This suppresses import errors
|
||||
# with contextlib.suppress(Exception):
|
||||
exec(code_obj, exec_globals, locals())
|
||||
exec_globals[class_name] = locals()[class_name]
|
||||
|
||||
# Return a function that imports necessary modules and creates an instance of the target class
|
||||
def build_my_class(*args, **kwargs):
|
||||
for module_name, module in exec_globals.items():
|
||||
if isinstance(module, type(importlib)):
|
||||
globals()[module_name] = module
|
||||
|
||||
instance = exec_globals[class_name](*args, **kwargs)
|
||||
return instance
|
||||
|
||||
build_my_class.__globals__.update(exec_globals)
|
||||
|
||||
return build_my_class
|
||||
|
||||
|
||||
def extract_function_name(code):
|
||||
module = ast.parse(code)
|
||||
for node in module.body:
|
||||
if isinstance(node, ast.FunctionDef):
|
||||
return node.name
|
||||
raise ValueError("No function definition found in the code string")
|
||||
|
||||
|
||||
def extract_class_name(code):
|
||||
module = ast.parse(code)
|
||||
for node in module.body:
|
||||
if isinstance(node, ast.ClassDef):
|
||||
return node.name
|
||||
raise ValueError("No class definition found in the code string")
|
||||
|
|
|
|||
|
|
@ -5,7 +5,8 @@
|
|||
<meta http-equiv="X-UA-Compatible" content="IE=edge">
|
||||
<meta name="viewport" content="width=device-width, initial-scale=1.0">
|
||||
<link rel="icon" href="/favicon.ico" />
|
||||
<title>LangFlow</title>
|
||||
<script src="/node_modules/ace-builds/src-min-noconflict/ace.js" type="text/javascript"></script>
|
||||
<title>Langflow</title>
|
||||
</head>
|
||||
<body id='body' style="width: 100%; height:100%">
|
||||
<noscript>You need to enable JavaScript to run this app.</noscript>
|
||||
|
|
|
|||
2790
src/frontend/package-lock.json
generated
2790
src/frontend/package-lock.json
generated
File diff suppressed because it is too large
Load diff
|
|
@ -8,19 +8,25 @@
|
|||
"@headlessui/react": "^1.7.10",
|
||||
"@heroicons/react": "^2.0.15",
|
||||
"@mui/material": "^5.11.9",
|
||||
"@radix-ui/react-dropdown-menu": "^2.0.5",
|
||||
"@radix-ui/react-menubar": "^1.0.3",
|
||||
"@radix-ui/react-separator": "^1.0.3",
|
||||
"@radix-ui/react-slot": "^1.0.2",
|
||||
"@radix-ui/react-tabs": "^1.0.4",
|
||||
"@radix-ui/react-accordion": "^1.1.2",
|
||||
"@radix-ui/react-checkbox": "^1.0.4",
|
||||
"@radix-ui/react-dialog": "^1.0.4",
|
||||
"@radix-ui/react-dropdown-menu": "^2.0.5",
|
||||
"@radix-ui/react-icons": "^1.3.0",
|
||||
"@radix-ui/react-label": "^2.0.2",
|
||||
"@radix-ui/react-menubar": "^1.0.3",
|
||||
"@radix-ui/react-popover": "^1.0.6",
|
||||
"@radix-ui/react-progress": "^1.0.3",
|
||||
"@radix-ui/react-separator": "^1.0.3",
|
||||
"@radix-ui/react-slot": "^1.0.2",
|
||||
"@radix-ui/react-switch": "^1.0.3",
|
||||
"@radix-ui/react-tabs": "^1.0.4",
|
||||
"@radix-ui/react-tooltip": "^1.0.6",
|
||||
"@tabler/icons-react": "^2.18.0",
|
||||
"@tailwindcss/forms": "^0.5.3",
|
||||
"@tailwindcss/line-clamp": "^0.4.4",
|
||||
"@types/axios": "^0.14.0",
|
||||
"accordion": "^3.0.2",
|
||||
"ace-builds": "^1.16.0",
|
||||
"add": "^2.0.6",
|
||||
"ansi-to-html": "^0.7.2",
|
||||
|
|
@ -28,6 +34,7 @@
|
|||
"base64-js": "^1.5.1",
|
||||
"class-variance-authority": "^0.6.0",
|
||||
"clsx": "^1.2.1",
|
||||
"dompurify": "^3.0.4",
|
||||
"esbuild": "^0.17.18",
|
||||
"lodash": "^4.17.21",
|
||||
"lucide-react": "^0.233.0",
|
||||
|
|
@ -47,7 +54,7 @@
|
|||
"rehype-mathjax": "^4.0.2",
|
||||
"remark-gfm": "^3.0.1",
|
||||
"remark-math": "^5.1.1",
|
||||
"shadcn-ui": "^0.1.3",
|
||||
"shadcn-ui": "^0.2.2",
|
||||
"short-unique-id": "^4.4.4",
|
||||
"switch": "^0.0.0",
|
||||
"table": "^6.8.1",
|
||||
|
|
@ -98,7 +105,11 @@
|
|||
"@types/uuid": "^9.0.1",
|
||||
"@vitejs/plugin-react-swc": "^3.0.0",
|
||||
"autoprefixer": "^10.4.14",
|
||||
"daisyui": "^3.1.1",
|
||||
"postcss": "^8.4.23",
|
||||
"prettier": "^2.8.8",
|
||||
"prettier-plugin-organize-imports": "^3.2.2",
|
||||
"prettier-plugin-tailwindcss": "^0.3.0",
|
||||
"tailwindcss": "^3.3.2",
|
||||
"typescript": "^5.0.2",
|
||||
"vite": "^4.3.9"
|
||||
|
|
|
|||
3
src/frontend/prettier.config.js
Normal file
3
src/frontend/prettier.config.js
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
module.exports = {
|
||||
plugins: [require("prettier-plugin-tailwindcss")],
|
||||
};
|
||||
|
|
@ -1,20 +1,19 @@
|
|||
import "reactflow/dist/style.css";
|
||||
import { useState, useEffect, useContext } from "react";
|
||||
import "./App.css";
|
||||
import { useLocation } from "react-router-dom";
|
||||
import _ from "lodash";
|
||||
import { useContext, useEffect, useState } from "react";
|
||||
import { useLocation } from "react-router-dom";
|
||||
import "reactflow/dist/style.css";
|
||||
import "./App.css";
|
||||
|
||||
import { ErrorBoundary } from "react-error-boundary";
|
||||
import ErrorAlert from "./alerts/error";
|
||||
import NoticeAlert from "./alerts/notice";
|
||||
import SuccessAlert from "./alerts/success";
|
||||
import CrashErrorComponent from "./components/CrashErrorComponent";
|
||||
import Header from "./components/headerComponent";
|
||||
import { alertContext } from "./contexts/alertContext";
|
||||
import { locationContext } from "./contexts/locationContext";
|
||||
import { ErrorBoundary } from "react-error-boundary";
|
||||
import CrashErrorComponent from "./components/CrashErrorComponent";
|
||||
import { TabsContext } from "./contexts/tabsContext";
|
||||
import { getVersion } from "./controllers/API";
|
||||
import Router from "./routes";
|
||||
import Header from "./components/headerComponent";
|
||||
|
||||
export default function App() {
|
||||
let { setCurrent, setShowSideBar, setIsStackedOpen } =
|
||||
|
|
@ -51,6 +50,13 @@ export default function App() {
|
|||
useEffect(() => {
|
||||
// If there is an error alert open with data, add it to the alertsList
|
||||
if (errorOpen && errorData) {
|
||||
if (
|
||||
alertsList.length > 0 &&
|
||||
JSON.stringify(alertsList[alertsList.length - 1].data) ===
|
||||
JSON.stringify(errorData)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
setErrorOpen(false);
|
||||
setAlertsList((old) => {
|
||||
let newAlertsList = [
|
||||
|
|
@ -62,6 +68,13 @@ export default function App() {
|
|||
}
|
||||
// If there is a notice alert open with data, add it to the alertsList
|
||||
else if (noticeOpen && noticeData) {
|
||||
if (
|
||||
alertsList.length > 0 &&
|
||||
JSON.stringify(alertsList[alertsList.length - 1].data) ===
|
||||
JSON.stringify(noticeData)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
setNoticeOpen(false);
|
||||
setAlertsList((old) => {
|
||||
let newAlertsList = [
|
||||
|
|
@ -73,6 +86,13 @@ export default function App() {
|
|||
}
|
||||
// If there is a success alert open with data, add it to the alertsList
|
||||
else if (successOpen && successData) {
|
||||
if (
|
||||
alertsList.length > 0 &&
|
||||
JSON.stringify(alertsList[alertsList.length - 1].data) ===
|
||||
JSON.stringify(successData)
|
||||
) {
|
||||
return;
|
||||
}
|
||||
setSuccessOpen(false);
|
||||
setAlertsList((old) => {
|
||||
let newAlertsList = [
|
||||
|
|
@ -103,7 +123,7 @@ export default function App() {
|
|||
|
||||
return (
|
||||
//need parent component with width and height
|
||||
<div className="h-full flex flex-col">
|
||||
<div className="flex h-full flex-col">
|
||||
<ErrorBoundary
|
||||
onReset={() => {
|
||||
window.localStorage.removeItem("tabsData");
|
||||
|
|
@ -117,10 +137,7 @@ export default function App() {
|
|||
<Router />
|
||||
</ErrorBoundary>
|
||||
<div></div>
|
||||
<div
|
||||
className="flex flex-col-reverse fixed bottom-5 left-5"
|
||||
style={{ zIndex: 999 }}
|
||||
>
|
||||
<div className="app-div" style={{ zIndex: 999 }}>
|
||||
{alertsList.map((alert) => (
|
||||
<div key={alert.id}>
|
||||
{alert.type === "error" ? (
|
||||
|
|
|
|||
|
|
@ -1,47 +1,54 @@
|
|||
import { cloneDeep } from "lodash";
|
||||
import React, { useContext, useEffect, useRef, useState } from "react";
|
||||
import { Handle, Position, useUpdateNodeInternals } from "reactflow";
|
||||
import {
|
||||
classNames,
|
||||
groupByFamily,
|
||||
isValidConnection,
|
||||
} from "../../../../utils";
|
||||
import { useContext, useEffect, useRef, useState } from "react";
|
||||
import InputComponent from "../../../../components/inputComponent";
|
||||
import InputListComponent from "../../../../components/inputListComponent";
|
||||
import TextAreaComponent from "../../../../components/textAreaComponent";
|
||||
import { typesContext } from "../../../../contexts/typesContext";
|
||||
import { ParameterComponentType } from "../../../../types/components";
|
||||
import FloatComponent from "../../../../components/floatComponent";
|
||||
import Dropdown from "../../../../components/dropdownComponent";
|
||||
import ShadTooltip from "../../../../components/ShadTooltipComponent";
|
||||
import CodeAreaComponent from "../../../../components/codeAreaComponent";
|
||||
import Dropdown from "../../../../components/dropdownComponent";
|
||||
import FloatComponent from "../../../../components/floatComponent";
|
||||
import IconComponent from "../../../../components/genericIconComponent";
|
||||
import InputComponent from "../../../../components/inputComponent";
|
||||
import InputFileComponent from "../../../../components/inputFileComponent";
|
||||
import { TabsContext } from "../../../../contexts/tabsContext";
|
||||
import InputListComponent from "../../../../components/inputListComponent";
|
||||
import IntComponent from "../../../../components/intComponent";
|
||||
import PromptAreaComponent from "../../../../components/promptComponent";
|
||||
import { nodeNames, nodeIcons } from "../../../../utils";
|
||||
import React from "react";
|
||||
import { nodeColors } from "../../../../utils";
|
||||
import ShadTooltip from "../../../../components/ShadTooltipComponent";
|
||||
import { PopUpContext } from "../../../../contexts/popUpContext";
|
||||
import TextAreaComponent from "../../../../components/textAreaComponent";
|
||||
import ToggleShadComponent from "../../../../components/toggleShadComponent";
|
||||
import { TOOLTIP_EMPTY } from "../../../../constants/constants";
|
||||
import { TabsContext } from "../../../../contexts/tabsContext";
|
||||
import { typesContext } from "../../../../contexts/typesContext";
|
||||
import { ParameterComponentType } from "../../../../types/components";
|
||||
import { isValidConnection } from "../../../../utils/reactflowUtils";
|
||||
import {
|
||||
nodeColors,
|
||||
nodeIconsLucide,
|
||||
nodeNames,
|
||||
} from "../../../../utils/styleUtils";
|
||||
import { classNames, groupByFamily } from "../../../../utils/utils";
|
||||
|
||||
export default function ParameterComponent({
|
||||
left,
|
||||
id,
|
||||
data,
|
||||
setData,
|
||||
tooltipTitle,
|
||||
title,
|
||||
color,
|
||||
type,
|
||||
name = "",
|
||||
required = false,
|
||||
optionalHandle = null,
|
||||
info = "",
|
||||
}: ParameterComponentType) {
|
||||
const ref = useRef(null);
|
||||
const refHtml = useRef(null);
|
||||
const infoHtml = useRef(null);
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
const [position, setPosition] = useState(0);
|
||||
const { closePopUp } = useContext(PopUpContext);
|
||||
const { setTabsState, tabId } = useContext(TabsContext);
|
||||
const { setTabsState, tabId, save, flows } = useContext(TabsContext);
|
||||
|
||||
const flow = flows.find((f) => f.id === tabId).data?.nodes ?? null;
|
||||
|
||||
// Update component position
|
||||
useEffect(() => {
|
||||
if (ref.current && ref.current.offsetTop && ref.current.clientHeight) {
|
||||
setPosition(ref.current.offsetTop + ref.current.clientHeight / 2);
|
||||
|
|
@ -53,80 +60,130 @@ export default function ParameterComponent({
|
|||
updateNodeInternals(data.id);
|
||||
}, [data.id, position, updateNodeInternals]);
|
||||
|
||||
const [enabled, setEnabled] = useState(
|
||||
data.node.template[name]?.value ?? false
|
||||
);
|
||||
|
||||
useEffect(() => {}, [closePopUp, data.node.template]);
|
||||
|
||||
const { reactFlowInstance } = useContext(typesContext);
|
||||
let disabled =
|
||||
reactFlowInstance?.getEdges().some((e) => e.targetHandle === id) ?? false;
|
||||
const [myData, setMyData] = useState(useContext(typesContext).data);
|
||||
|
||||
const { data: myData } = useContext(typesContext);
|
||||
|
||||
const handleOnNewValue = (newValue: any) => {
|
||||
data.node.template[name].value = newValue;
|
||||
let newData = cloneDeep(data);
|
||||
newData.node.template[name].value = newValue;
|
||||
setData(newData);
|
||||
// Set state to pending
|
||||
setTabsState((prev) => {
|
||||
return {
|
||||
...prev,
|
||||
[tabId]: {
|
||||
...prev[tabId],
|
||||
isPending: true,
|
||||
formKeysData: prev[tabId].formKeysData,
|
||||
},
|
||||
};
|
||||
});
|
||||
renderTooltips();
|
||||
};
|
||||
|
||||
useEffect(() => {
|
||||
const groupedObj = groupByFamily(myData, tooltipTitle);
|
||||
if (name === "openai_api_base") console.log(info);
|
||||
infoHtml.current = (
|
||||
<div className="h-full w-full break-words">
|
||||
{info.split("\n").map((line, i) => (
|
||||
<p key={i} className="block">
|
||||
{line}
|
||||
</p>
|
||||
))}
|
||||
</div>
|
||||
);
|
||||
}, [info]);
|
||||
|
||||
refHtml.current = groupedObj.map((item, i) => (
|
||||
<span
|
||||
key={i}
|
||||
className={classNames(
|
||||
i > 0 ? "items-center flex mt-3" : "items-center flex"
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className="h-5 w-5"
|
||||
style={{
|
||||
color: nodeColors[item.family],
|
||||
}}
|
||||
>
|
||||
{React.createElement(nodeIcons[item.family])}
|
||||
</div>
|
||||
<span className="ps-2 text-gray-950">
|
||||
{nodeNames[item.family] ?? ""}{" "}
|
||||
<span className={classNames(left ? "hidden" : "")}>
|
||||
{" "}
|
||||
-
|
||||
{item.type.split(", ").length > 2
|
||||
? item.type.split(", ").map((el, i) => (
|
||||
<>
|
||||
<span key={i}>
|
||||
{i == item.type.split(", ").length - 1
|
||||
? el
|
||||
: (el += `, `)}
|
||||
</span>
|
||||
{i % 2 == 0 && i > 0 && <br></br>}
|
||||
</>
|
||||
))
|
||||
: item.type}
|
||||
function renderTooltips() {
|
||||
let groupedObj = groupByFamily(myData, tooltipTitle, left, flow);
|
||||
|
||||
if (groupedObj && groupedObj.length > 0) {
|
||||
refHtml.current = groupedObj.map((item, i) => {
|
||||
const Icon: any =
|
||||
nodeIconsLucide[item.family] ?? nodeIconsLucide["unknown"];
|
||||
|
||||
return (
|
||||
<span
|
||||
key={i}
|
||||
className={classNames(
|
||||
i > 0 ? "mt-2 flex items-center" : "flex items-center"
|
||||
)}
|
||||
>
|
||||
<div
|
||||
className="h-5 w-5"
|
||||
style={{
|
||||
color: nodeColors[item.family],
|
||||
}}
|
||||
>
|
||||
<Icon
|
||||
className="h-5 w-5"
|
||||
strokeWidth={1.5}
|
||||
style={{
|
||||
color: nodeColors[item.family] ?? nodeColors.unknown,
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
<span className="ps-2 text-xs text-foreground">
|
||||
{nodeNames[item.family] ?? "Other"}
|
||||
<span className="text-xs">
|
||||
{" "}
|
||||
{item.type === "" ? "" : " - "}
|
||||
{item.type.split(", ").length > 2
|
||||
? item.type.split(", ").map((el, i) => (
|
||||
<React.Fragment key={el + i}>
|
||||
<span>
|
||||
{i === item.type.split(", ").length - 1
|
||||
? el
|
||||
: (el += `, `)}
|
||||
</span>
|
||||
</React.Fragment>
|
||||
))
|
||||
: item.type}
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
</span>
|
||||
));
|
||||
}, [tooltipTitle]);
|
||||
);
|
||||
});
|
||||
} else {
|
||||
refHtml.current = <span>{TOOLTIP_EMPTY}</span>;
|
||||
}
|
||||
}
|
||||
|
||||
useEffect(() => {
|
||||
renderTooltips();
|
||||
}, [tooltipTitle, flow]);
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={ref}
|
||||
className="w-full flex flex-wrap justify-between items-center bg-muted dark:bg-gray-800 dark:text-white mt-1 px-5 py-2"
|
||||
className="mt-1 flex w-full flex-wrap items-center justify-between bg-muted px-5 py-2"
|
||||
>
|
||||
<>
|
||||
<div className={"text-sm truncate w-full " + (left ? "" : "text-end")}>
|
||||
<div
|
||||
className={
|
||||
"w-full truncate text-sm" +
|
||||
(left ? "" : " text-end") +
|
||||
(info !== "" ? " flex items-center" : "")
|
||||
}
|
||||
>
|
||||
{title}
|
||||
<span className="text-red-600">{required ? " *" : ""}</span>
|
||||
<span className="text-status-red">{required ? " *" : ""}</span>
|
||||
<div className="">
|
||||
{info !== "" && (
|
||||
<ShadTooltip content={infoHtml.current}>
|
||||
{/* put div to avoid bug that does not display tooltip */}
|
||||
<div>
|
||||
<IconComponent
|
||||
name="Info"
|
||||
className="relative bottom-0.5 ml-2 h-3 w-4"
|
||||
/>
|
||||
</div>
|
||||
</ShadTooltip>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
{left &&
|
||||
(type === "str" ||
|
||||
|
|
@ -135,14 +192,15 @@ export default function ParameterComponent({
|
|||
type === "code" ||
|
||||
type === "prompt" ||
|
||||
type === "file" ||
|
||||
type === "int") ? (
|
||||
type === "int") &&
|
||||
!optionalHandle ? (
|
||||
<></>
|
||||
) : (
|
||||
<ShadTooltip
|
||||
styleClasses={"tooltip-fixed-width custom-scroll nowheel"}
|
||||
delayDuration={0}
|
||||
content={refHtml.current}
|
||||
side={left ? "left" : "right"}
|
||||
open={refHtml?.current?.length > 0}
|
||||
>
|
||||
<Handle
|
||||
type={left ? "target" : "source"}
|
||||
|
|
@ -153,7 +211,7 @@ export default function ParameterComponent({
|
|||
}
|
||||
className={classNames(
|
||||
left ? "-ml-0.5 " : "-mr-0.5 ",
|
||||
"w-3 h-3 rounded-full border-2 bg-white dark:bg-gray-800"
|
||||
"h-3 w-3 rounded-full border-2 bg-background"
|
||||
)}
|
||||
style={{
|
||||
borderColor: color,
|
||||
|
|
@ -187,7 +245,6 @@ export default function ParameterComponent({
|
|||
) : (
|
||||
<InputComponent
|
||||
disabled={disabled}
|
||||
disableCopyPaste={true}
|
||||
password={data.node.template[name].password ?? false}
|
||||
value={data.node.template[name].value ?? ""}
|
||||
onChange={handleOnNewValue}
|
||||
|
|
@ -195,13 +252,12 @@ export default function ParameterComponent({
|
|||
)}
|
||||
</div>
|
||||
) : left === true && type === "bool" ? (
|
||||
<div className="mt-2">
|
||||
<div className="mt-2 w-full">
|
||||
<ToggleShadComponent
|
||||
disabled={disabled}
|
||||
enabled={enabled}
|
||||
enabled={data.node.template[name].value ?? false}
|
||||
setEnabled={(t) => {
|
||||
handleOnNewValue(t);
|
||||
setEnabled(t);
|
||||
}}
|
||||
size="large"
|
||||
/>
|
||||
|
|
@ -210,7 +266,6 @@ export default function ParameterComponent({
|
|||
<div className="mt-2 w-full">
|
||||
<FloatComponent
|
||||
disabled={disabled}
|
||||
disableCopyPaste={true}
|
||||
value={data.node.template[name].value ?? ""}
|
||||
onChange={handleOnNewValue}
|
||||
/>
|
||||
|
|
@ -218,7 +273,7 @@ export default function ParameterComponent({
|
|||
) : left === true &&
|
||||
type === "str" &&
|
||||
data.node.template[name].options ? (
|
||||
<div className="w-full">
|
||||
<div className="mt-2 w-full">
|
||||
<Dropdown
|
||||
options={data.node.template[name].options}
|
||||
onSelect={handleOnNewValue}
|
||||
|
|
@ -226,37 +281,53 @@ export default function ParameterComponent({
|
|||
></Dropdown>
|
||||
</div>
|
||||
) : left === true && type === "code" ? (
|
||||
<CodeAreaComponent
|
||||
disabled={disabled}
|
||||
value={data.node.template[name].value ?? ""}
|
||||
onChange={handleOnNewValue}
|
||||
/>
|
||||
<div className="mt-2 w-full">
|
||||
<CodeAreaComponent
|
||||
dynamic={data.node.template[name].dynamic ?? false}
|
||||
setNodeClass={(nodeClass) => {
|
||||
data.node = nodeClass;
|
||||
}}
|
||||
nodeClass={data.node}
|
||||
disabled={disabled}
|
||||
value={data.node.template[name].value ?? ""}
|
||||
onChange={handleOnNewValue}
|
||||
/>
|
||||
</div>
|
||||
) : left === true && type === "file" ? (
|
||||
<InputFileComponent
|
||||
disabled={disabled}
|
||||
value={data.node.template[name].value ?? ""}
|
||||
onChange={handleOnNewValue}
|
||||
fileTypes={data.node.template[name].fileTypes}
|
||||
suffixes={data.node.template[name].suffixes}
|
||||
onFileChange={(t: string) => {
|
||||
data.node.template[name].content = t;
|
||||
}}
|
||||
></InputFileComponent>
|
||||
<div className="mt-2 w-full">
|
||||
<InputFileComponent
|
||||
disabled={disabled}
|
||||
value={data.node.template[name].value ?? ""}
|
||||
onChange={handleOnNewValue}
|
||||
fileTypes={data.node.template[name].fileTypes}
|
||||
suffixes={data.node.template[name].suffixes}
|
||||
onFileChange={(t: string) => {
|
||||
data.node.template[name].file_path = t;
|
||||
save();
|
||||
}}
|
||||
></InputFileComponent>
|
||||
</div>
|
||||
) : left === true && type === "int" ? (
|
||||
<div className="mt-2 w-full">
|
||||
<IntComponent
|
||||
disabled={disabled}
|
||||
disableCopyPaste={true}
|
||||
value={data.node.template[name].value ?? ""}
|
||||
onChange={handleOnNewValue}
|
||||
/>
|
||||
</div>
|
||||
) : left === true && type === "prompt" ? (
|
||||
<PromptAreaComponent
|
||||
disabled={disabled}
|
||||
value={data.node.template[name].value ?? ""}
|
||||
onChange={handleOnNewValue}
|
||||
/>
|
||||
<div className="mt-2 w-full">
|
||||
<PromptAreaComponent
|
||||
field_name={name}
|
||||
setNodeClass={(nodeClass) => {
|
||||
data.node = nodeClass;
|
||||
}}
|
||||
nodeClass={data.node}
|
||||
disabled={disabled}
|
||||
value={data.node.template[name].value ?? ""}
|
||||
onChange={handleOnNewValue}
|
||||
/>
|
||||
</div>
|
||||
) : (
|
||||
<></>
|
||||
)}
|
||||
|
|
|
|||
|
|
@ -1,41 +1,54 @@
|
|||
import { classNames, nodeColors, nodeIcons, toTitleCase } from "../../utils";
|
||||
import ParameterComponent from "./components/parameterComponent";
|
||||
import { typesContext } from "../../contexts/typesContext";
|
||||
import { useContext, useState, useEffect, useRef } from "react";
|
||||
import { NodeDataType } from "../../types/flow";
|
||||
import { alertContext } from "../../contexts/alertContext";
|
||||
import { PopUpContext } from "../../contexts/popUpContext";
|
||||
import NodeModal from "../../modals/NodeModal";
|
||||
import Tooltip from "../../components/TooltipComponent";
|
||||
import { NodeToolbar } from "reactflow";
|
||||
import NodeToolbarComponent from "../../pages/FlowPage/components/nodeToolbarComponent";
|
||||
|
||||
import { cloneDeep } from "lodash";
|
||||
import { useContext, useEffect, useState } from "react";
|
||||
import { NodeToolbar, useUpdateNodeInternals } from "reactflow";
|
||||
import ShadTooltip from "../../components/ShadTooltipComponent";
|
||||
import Tooltip from "../../components/TooltipComponent";
|
||||
import IconComponent from "../../components/genericIconComponent";
|
||||
import { useSSE } from "../../contexts/SSEContext";
|
||||
import { TabsContext } from "../../contexts/tabsContext";
|
||||
import { typesContext } from "../../contexts/typesContext";
|
||||
import NodeToolbarComponent from "../../pages/FlowPage/components/nodeToolbarComponent";
|
||||
import { NodeDataType } from "../../types/flow";
|
||||
import { cleanEdges } from "../../utils/reactflowUtils";
|
||||
import { nodeColors, nodeIconsLucide } from "../../utils/styleUtils";
|
||||
import { classNames, toTitleCase } from "../../utils/utils";
|
||||
import ParameterComponent from "./components/parameterComponent";
|
||||
|
||||
export default function GenericNode({
|
||||
data,
|
||||
data: olddata,
|
||||
selected,
|
||||
}: {
|
||||
data: NodeDataType;
|
||||
selected: boolean;
|
||||
}) {
|
||||
const { setErrorData } = useContext(alertContext);
|
||||
const showError = useRef(true);
|
||||
const { types, deleteNode } = useContext(typesContext);
|
||||
|
||||
const { closePopUp, openPopUp } = useContext(PopUpContext);
|
||||
|
||||
const Icon = nodeIcons[data.type] || nodeIcons[types[data.type]];
|
||||
const [data, setData] = useState(olddata);
|
||||
const { updateFlow, flows, tabId } = useContext(TabsContext);
|
||||
const updateNodeInternals = useUpdateNodeInternals();
|
||||
const { types, deleteNode, reactFlowInstance } = useContext(typesContext);
|
||||
const name = nodeIconsLucide[data.type] ? data.type : types[data.type];
|
||||
const [validationStatus, setValidationStatus] = useState(null);
|
||||
// State for outline color
|
||||
const { sseData, isBuilding } = useSSE();
|
||||
|
||||
// useEffect(() => {
|
||||
// if (reactFlowInstance) {
|
||||
// setParams(Object.values(reactFlowInstance.toObject()));
|
||||
// }
|
||||
// }, [save]);
|
||||
useEffect(() => {
|
||||
olddata.node = data.node;
|
||||
let myFlow = flows.find((flow) => flow.id === tabId);
|
||||
if (reactFlowInstance && myFlow) {
|
||||
let flow = cloneDeep(myFlow);
|
||||
flow.data = reactFlowInstance.toObject();
|
||||
cleanEdges({
|
||||
flow: {
|
||||
edges: flow.data.edges,
|
||||
nodes: flow.data.nodes,
|
||||
},
|
||||
updateEdge: (edge) => {
|
||||
flow.data.edges = edge;
|
||||
reactFlowInstance.setEdges(edge);
|
||||
updateNodeInternals(data.id);
|
||||
},
|
||||
});
|
||||
updateFlow(flow);
|
||||
}
|
||||
}, [data]);
|
||||
|
||||
// New useEffect to watch for changes in sseData and update validation status
|
||||
useEffect(() => {
|
||||
|
|
@ -48,99 +61,91 @@ export default function GenericNode({
|
|||
}
|
||||
}, [sseData, data.id]);
|
||||
|
||||
if (!Icon) {
|
||||
if (showError.current) {
|
||||
setErrorData({
|
||||
title: data.type
|
||||
? `The ${data.type} node could not be rendered, please review your json file`
|
||||
: "There was a node that can't be rendered, please review your json file",
|
||||
});
|
||||
showError.current = false;
|
||||
}
|
||||
deleteNode(data.id);
|
||||
return;
|
||||
}
|
||||
|
||||
useEffect(() => {}, [closePopUp, data.node.template]);
|
||||
|
||||
return (
|
||||
<>
|
||||
<NodeToolbar>
|
||||
<NodeToolbarComponent
|
||||
data={data}
|
||||
openPopUp={openPopUp}
|
||||
setData={setData}
|
||||
deleteNode={deleteNode}
|
||||
></NodeToolbarComponent>
|
||||
</NodeToolbar>
|
||||
|
||||
<div
|
||||
className={classNames(
|
||||
selected ? "border border-ring" : "border dark:border-gray-700",
|
||||
"prompt-node relative flex w-96 flex-col justify-center rounded-lg bg-white dark:bg-gray-900"
|
||||
selected ? "border border-ring" : "border",
|
||||
"generic-node-div"
|
||||
)}
|
||||
>
|
||||
<div className="flex w-full items-center justify-between gap-8 rounded-t-lg border-b bg-muted p-4 dark:border-b-gray-700 dark:bg-gray-800 dark:text-white ">
|
||||
<div className="flex w-full items-center gap-2 truncate text-lg">
|
||||
<Icon
|
||||
className="h-10 w-10 rounded p-1"
|
||||
style={{
|
||||
color: nodeColors[types[data.type]] ?? nodeColors.unknown,
|
||||
}}
|
||||
{data.node.beta && (
|
||||
<div className="beta-badge-wrapper">
|
||||
<div className="beta-badge-content">BETA</div>
|
||||
</div>
|
||||
)}
|
||||
<div className="generic-node-div-title">
|
||||
<div className="generic-node-title-arrangement">
|
||||
<IconComponent
|
||||
name={name}
|
||||
className="generic-node-icon"
|
||||
iconColor={`${nodeColors[types[data.type]]}`}
|
||||
/>
|
||||
<div className="ml-2 truncate">
|
||||
<ShadTooltip delayDuration={1500} content={data.type}>
|
||||
<div className="ml-2 truncate text-gray-800">{data.type}</div>
|
||||
<div className="generic-node-tooltip-div">
|
||||
<ShadTooltip content={data.node.display_name}>
|
||||
<div className="generic-node-tooltip-div text-primary">
|
||||
{data.node.display_name}
|
||||
</div>
|
||||
</ShadTooltip>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
<button
|
||||
className="relative"
|
||||
onClick={(event) => {
|
||||
event.preventDefault();
|
||||
openPopUp(<NodeModal data={data} />);
|
||||
}}
|
||||
></button>
|
||||
</div>
|
||||
<div className="flex gap-3">
|
||||
<div className="round-button-div">
|
||||
<div>
|
||||
<Tooltip
|
||||
title={
|
||||
!validationStatus ? (
|
||||
"Validating..."
|
||||
isBuilding ? (
|
||||
<span>Building...</span>
|
||||
) : !validationStatus ? (
|
||||
<span className="flex">
|
||||
Build{" "}
|
||||
<IconComponent
|
||||
name="Zap"
|
||||
className="mx-0.5 h-5 fill-build-trigger stroke-build-trigger stroke-1"
|
||||
/>{" "}
|
||||
flow to validate status.
|
||||
</span>
|
||||
) : (
|
||||
<div className="max-h-96 overflow-auto">
|
||||
{validationStatus.params ||
|
||||
""
|
||||
.split("\n")
|
||||
.map((line, index) => <div key={index}>{line}</div>)}
|
||||
{typeof validationStatus.params === "string"
|
||||
? validationStatus.params
|
||||
.split("\n")
|
||||
.map((line, index) => <div key={index}>{line}</div>)
|
||||
: ""}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
>
|
||||
<div className="w-5 h-5 relative top-[3px]">
|
||||
<div className="generic-node-status-position">
|
||||
<div
|
||||
className={classNames(
|
||||
validationStatus && validationStatus.valid
|
||||
? "w-4 h-4 rounded-full bg-green-500 opacity-100"
|
||||
: "w-4 h-4 rounded-full bg-gray-500 opacity-0 hidden animate-spin",
|
||||
"absolute w-4 hover:text-gray-500 hover:dark:text-gray-300 transition-all ease-in-out duration-200"
|
||||
? "green-status"
|
||||
: "status-build-animation",
|
||||
"status-div"
|
||||
)}
|
||||
></div>
|
||||
<div
|
||||
className={classNames(
|
||||
validationStatus && !validationStatus.valid
|
||||
? "w-4 h-4 rounded-full bg-red-500 opacity-100"
|
||||
: "w-4 h-4 rounded-full bg-gray-500 opacity-0 hidden animate-spin",
|
||||
"absolute w-4 hover:text-gray-500 hover:dark:text-gray-300 transition-all ease-in-out duration-200"
|
||||
? "red-status"
|
||||
: "status-build-animation",
|
||||
"status-div"
|
||||
)}
|
||||
></div>
|
||||
<div
|
||||
className={classNames(
|
||||
!validationStatus || isBuilding
|
||||
? "w-4 h-4 rounded-full bg-yellow-500 opacity-100"
|
||||
: "w-4 h-4 rounded-full bg-gray-500 opacity-0 hidden animate-spin",
|
||||
"absolute w-4 hover:text-gray-500 hover:dark:text-gray-300 transition-all ease-in-out duration-200"
|
||||
? "yellow-status"
|
||||
: "status-build-animation",
|
||||
"status-div"
|
||||
)}
|
||||
></div>
|
||||
</div>
|
||||
|
|
@ -149,41 +154,30 @@ export default function GenericNode({
|
|||
</div>
|
||||
</div>
|
||||
|
||||
<div className="h-full w-full py-5 text-gray-800">
|
||||
<div className="w-full px-5 pb-3 text-sm text-muted-foreground">
|
||||
{data.node.description}
|
||||
</div>
|
||||
<div className="generic-node-desc">
|
||||
<div className="generic-node-desc-text">{data.node.description}</div>
|
||||
|
||||
<>
|
||||
{Object.keys(data.node.template)
|
||||
.filter((t) => t.charAt(0) !== "_")
|
||||
.map((t: string, idx) => (
|
||||
<div key={idx}>
|
||||
{/* {idx === 0 ? (
|
||||
<div
|
||||
className={classNames(
|
||||
"px-5 py-2 mt-2 dark:text-white text-center",
|
||||
Object.keys(data.node.template).filter(
|
||||
(key) =>
|
||||
!key.startsWith("_") &&
|
||||
data.node.template[key].show &&
|
||||
!data.node.template[key].advanced
|
||||
).length === 0
|
||||
? "hidden"
|
||||
: ""
|
||||
)}
|
||||
>
|
||||
Inputs
|
||||
</div>
|
||||
) : (
|
||||
<></>
|
||||
)} */}
|
||||
{data.node.template[t].show &&
|
||||
!data.node.template[t].advanced ? (
|
||||
<ParameterComponent
|
||||
key={
|
||||
(data.node.template[t].input_types?.join(";") ??
|
||||
data.node.template[t].type) +
|
||||
"|" +
|
||||
t +
|
||||
"|" +
|
||||
data.id
|
||||
}
|
||||
data={data}
|
||||
setData={setData}
|
||||
color={
|
||||
nodeColors[types[data.node.template[t].type]] ??
|
||||
nodeColors[data.node.template[t].type] ??
|
||||
nodeColors.unknown
|
||||
}
|
||||
title={
|
||||
|
|
@ -193,12 +187,24 @@ export default function GenericNode({
|
|||
? toTitleCase(data.node.template[t].name)
|
||||
: toTitleCase(t)
|
||||
}
|
||||
info={data.node.template[t].info}
|
||||
name={t}
|
||||
tooltipTitle={data.node.template[t].type}
|
||||
tooltipTitle={
|
||||
data.node.template[t].input_types?.join("\n") ??
|
||||
data.node.template[t].type
|
||||
}
|
||||
required={data.node.template[t].required}
|
||||
id={data.node.template[t].type + "|" + t + "|" + data.id}
|
||||
id={
|
||||
(data.node.template[t].input_types?.join(";") ??
|
||||
data.node.template[t].type) +
|
||||
"|" +
|
||||
t +
|
||||
"|" +
|
||||
data.id
|
||||
}
|
||||
left={true}
|
||||
type={data.node.template[t].type}
|
||||
optionalHandle={data.node.template[t].input_types}
|
||||
/>
|
||||
) : (
|
||||
<></>
|
||||
|
|
@ -208,19 +214,22 @@ export default function GenericNode({
|
|||
<div
|
||||
className={classNames(
|
||||
Object.keys(data.node.template).length < 1 ? "hidden" : "",
|
||||
"flex w-full justify-center"
|
||||
"flex-max-width justify-center"
|
||||
)}
|
||||
>
|
||||
{" "}
|
||||
</div>
|
||||
{/* <div className="px-5 py-2 mt-2 dark:text-white text-center">
|
||||
Output
|
||||
</div> */}
|
||||
<ParameterComponent
|
||||
key={[data.type, data.id, ...data.node.base_classes].join("|")}
|
||||
data={data}
|
||||
setData={setData}
|
||||
color={nodeColors[types[data.type]] ?? nodeColors.unknown}
|
||||
title={data.type}
|
||||
tooltipTitle={`${data.node.base_classes.join("\n")}`}
|
||||
title={
|
||||
data.node.output_types && data.node.output_types.length > 0
|
||||
? data.node.output_types.join("|")
|
||||
: data.type
|
||||
}
|
||||
tooltipTitle={data.node.base_classes.join("\n")}
|
||||
id={[data.type, data.id, ...data.node.base_classes].join("|")}
|
||||
type={data.node.base_classes.join("|")}
|
||||
left={false}
|
||||
|
|
|
|||
|
|
@ -1,12 +1,7 @@
|
|||
import {
|
||||
XCircleIcon,
|
||||
XMarkIcon,
|
||||
InformationCircleIcon,
|
||||
CheckCircleIcon,
|
||||
} from "@heroicons/react/24/outline";
|
||||
import { Link } from "react-router-dom";
|
||||
import { Transition } from "@headlessui/react";
|
||||
import { useState } from "react";
|
||||
import { Link } from "react-router-dom";
|
||||
import IconComponent from "../../../../components/genericIconComponent";
|
||||
import { SingleAlertComponentType } from "../../../../types/alerts";
|
||||
|
||||
export default function SingleAlert({
|
||||
|
|
@ -30,21 +25,22 @@ export default function SingleAlert({
|
|||
>
|
||||
{type === "error" ? (
|
||||
<div
|
||||
className="flex bg-red-50 dark:bg-red-900 rounded-md p-3 mb-2 mx-2"
|
||||
className="mx-2 mb-2 flex rounded-md bg-error-background p-3"
|
||||
key={dropItem.id}
|
||||
>
|
||||
<div className="flex-shrink-0">
|
||||
<XCircleIcon
|
||||
className="h-5 w-5 text-red-400 dark:text-red-50"
|
||||
<IconComponent
|
||||
name="XCircle"
|
||||
className="h-5 w-5 text-status-red"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
</div>
|
||||
<div className="ml-3">
|
||||
<h3 className="text-sm break-words font-medium text-red-800 dark:text-white/80">
|
||||
<h3 className="break-words text-sm font-medium text-error-foreground">
|
||||
{dropItem.title}
|
||||
</h3>
|
||||
{dropItem.list ? (
|
||||
<div className="mt-2 text-sm text-red-700 dark:text-red-50">
|
||||
<div className="mt-2 text-sm text-error-foreground">
|
||||
<ul className="list-disc space-y-1 pl-5">
|
||||
{dropItem.list.map((item, idx) => (
|
||||
<li className="break-words" key={idx}>
|
||||
|
|
@ -67,34 +63,39 @@ export default function SingleAlert({
|
|||
removeAlert(dropItem.id);
|
||||
}, 500);
|
||||
}}
|
||||
className="inline-flex rounded-md bg-red-50 dark:bg-transparent p-1.5 text-red-500 dark:text-red-50"
|
||||
className="inline-flex rounded-md p-1.5 text-status-red"
|
||||
>
|
||||
<span className="sr-only">Dismiss</span>
|
||||
<XMarkIcon className="h-5 w-5" aria-hidden="true" />
|
||||
<IconComponent
|
||||
name="X"
|
||||
className="h-4 w-4 text-error-foreground"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : type === "notice" ? (
|
||||
<div
|
||||
className="flex rounded-md bg-blue-50 dark:bg-blue-900 p-3 mb-2 mx-2"
|
||||
className="mx-2 mb-2 flex rounded-md bg-info-background p-3"
|
||||
key={dropItem.id}
|
||||
>
|
||||
<div className="flex-shrink-0">
|
||||
<InformationCircleIcon
|
||||
className="h-5 w-5 text-blue-400 dark:text-blue-50"
|
||||
<IconComponent
|
||||
name="Info"
|
||||
className="h-5 w-5 text-status-blue "
|
||||
aria-hidden="true"
|
||||
/>
|
||||
</div>
|
||||
<div className="ml-3 flex-1 md:flex md:justify-between">
|
||||
<p className="text-sm text-blue-700 dark:text-white/80">
|
||||
<p className="text-sm font-medium text-info-foreground">
|
||||
{dropItem.title}
|
||||
</p>
|
||||
<p className="mt-3 text-sm md:mt-0 md:ml-6">
|
||||
<p className="mt-3 text-sm md:ml-6 md:mt-0">
|
||||
{dropItem.link ? (
|
||||
<Link
|
||||
to={dropItem.link}
|
||||
className="whitespace-nowrap font-medium text-blue-700 dark:text-blue-50 dark:hover:text-blue-100 hover:text-ring"
|
||||
className="whitespace-nowrap font-medium text-info-foreground hover:text-accent-foreground"
|
||||
>
|
||||
Details
|
||||
</Link>
|
||||
|
|
@ -113,27 +114,32 @@ export default function SingleAlert({
|
|||
removeAlert(dropItem.id);
|
||||
}, 500);
|
||||
}}
|
||||
className="inline-flex rounded-md bg-blue-50 dark:bg-transparent p-1.5 text-blue-500 dark:text-blue-50"
|
||||
className="inline-flex rounded-md p-1.5 text-info-foreground"
|
||||
>
|
||||
<span className="sr-only">Dismiss</span>
|
||||
<XMarkIcon className="h-5 w-5" aria-hidden="true" />
|
||||
<IconComponent
|
||||
name="X"
|
||||
className="h-4 w-4 text-info-foreground"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
) : (
|
||||
<div
|
||||
className="flex bg-green-50 dark:bg-green-900 p-3 mb-2 mx-2 rounded-md"
|
||||
className="mx-2 mb-2 flex rounded-md bg-success-background p-3"
|
||||
key={dropItem.id}
|
||||
>
|
||||
<div className="flex-shrink-0">
|
||||
<CheckCircleIcon
|
||||
className="h-5 w-5 text-green-400 dark:text-green-50"
|
||||
<IconComponent
|
||||
name="CheckCircle2"
|
||||
className="h-5 w-5 text-status-green"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
</div>
|
||||
<div className="ml-3">
|
||||
<p className="text-sm font-medium text-green-800 dark:bg-white/80">
|
||||
<p className="text-sm font-medium text-success-foreground">
|
||||
{dropItem.title}
|
||||
</p>
|
||||
</div>
|
||||
|
|
@ -147,10 +153,14 @@ export default function SingleAlert({
|
|||
removeAlert(dropItem.id);
|
||||
}, 500);
|
||||
}}
|
||||
className="inline-flex rounded-md bg-green-50 dark:bg-transparent p-1.5 text-green-500 dark:text-green-50"
|
||||
className="inline-flex rounded-md p-1.5 text-status-green"
|
||||
>
|
||||
<span className="sr-only">Dismiss</span>
|
||||
<XMarkIcon className="h-5 w-5" aria-hidden="true" />
|
||||
<IconComponent
|
||||
name="X"
|
||||
className="h-4 w-4 text-success-foreground"
|
||||
aria-hidden="true"
|
||||
/>
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
|
|
|
|||
|
|
@ -1,66 +1,72 @@
|
|||
import { useContext, useEffect, useRef } from "react";
|
||||
import { useContext, useState } from "react";
|
||||
import IconComponent from "../../components/genericIconComponent";
|
||||
import {
|
||||
Popover,
|
||||
PopoverContent,
|
||||
PopoverTrigger,
|
||||
} from "../../components/ui/popover";
|
||||
import { alertContext } from "../../contexts/alertContext";
|
||||
import { XMarkIcon } from "@heroicons/react/24/solid";
|
||||
import { TrashIcon } from "@heroicons/react/24/outline";
|
||||
import SingleAlert from "./components/singleAlertComponent";
|
||||
import { AlertDropdownType } from "../../types/alerts";
|
||||
import { PopUpContext } from "../../contexts/popUpContext";
|
||||
import { useOnClickOutside } from "../hooks/useOnClickOutside";
|
||||
export default function AlertDropdown({}: AlertDropdownType) {
|
||||
const { closePopUp } = useContext(PopUpContext);
|
||||
const componentRef = useRef<HTMLDivElement>(null);
|
||||
|
||||
// Use the custom hook
|
||||
useOnClickOutside(componentRef, () => {
|
||||
closePopUp();
|
||||
});
|
||||
import SingleAlert from "./components/singleAlertComponent";
|
||||
|
||||
export default function AlertDropdown({ children }: AlertDropdownType) {
|
||||
const {
|
||||
notificationList,
|
||||
clearNotificationList,
|
||||
removeFromNotificationList,
|
||||
setNotificationCenter,
|
||||
} = useContext(alertContext);
|
||||
|
||||
const [open, setOpen] = useState(false);
|
||||
|
||||
return (
|
||||
<div
|
||||
ref={componentRef}
|
||||
className="z-10 py-3 pb-4 px-2 rounded-md bg-white dark:bg-gray-800 ring-1 ring-black ring-opacity-5 shadow-lg focus:outline-none overflow-hidden w-[400px] h-[500px] flex flex-col"
|
||||
<Popover
|
||||
open={open}
|
||||
onOpenChange={(k) => {
|
||||
setOpen(k);
|
||||
if (k) setNotificationCenter(false);
|
||||
}}
|
||||
>
|
||||
<div className="flex pl-3 flex-row justify-between text-md font-medium text-gray-800 dark:text-gray-200">
|
||||
Notifications
|
||||
<div className="flex gap-3 pr-3 ">
|
||||
<button
|
||||
className="text-gray-800 hover:text-red-500 dark:text-gray-200 dark:hover:text-red-500"
|
||||
onClick={() => {
|
||||
closePopUp();
|
||||
setTimeout(clearNotificationList, 100);
|
||||
}}
|
||||
>
|
||||
<TrashIcon className="w-[1.1rem] h-[1.1rem]" />
|
||||
</button>
|
||||
<button
|
||||
className="text-gray-800 hover:text-red-500 dark:text-gray-200 dark:hover:text-red-500"
|
||||
onClick={closePopUp}
|
||||
>
|
||||
<XMarkIcon className="h-5 w-5" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-3 flex flex-col overflow-y-scroll w-full h-full scrollbar-hide text-gray-900 dark:text-gray-300">
|
||||
{notificationList.length !== 0 ? (
|
||||
notificationList.map((alertItem, index) => (
|
||||
<SingleAlert
|
||||
key={alertItem.id}
|
||||
dropItem={alertItem}
|
||||
removeAlert={removeFromNotificationList}
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
<div className="h-full w-full pb-16 text-gray-500 dark:text-gray-500 flex justify-center items-center">
|
||||
No new notifications
|
||||
<PopoverTrigger>{children}</PopoverTrigger>
|
||||
<PopoverContent className="flex h-[500px] w-[500px] flex-col">
|
||||
<div className="text-md flex flex-row justify-between pl-3 font-medium text-foreground">
|
||||
Notifications
|
||||
<div className="flex gap-3 pr-3 ">
|
||||
<button
|
||||
className="text-foreground hover:text-status-red"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
setTimeout(clearNotificationList, 100);
|
||||
}}
|
||||
>
|
||||
<IconComponent name="Trash2" className="h-[1.1rem] w-[1.1rem]" />
|
||||
</button>
|
||||
<button
|
||||
className="text-foreground hover:text-status-red"
|
||||
onClick={() => {
|
||||
setOpen(false);
|
||||
}}
|
||||
>
|
||||
<IconComponent name="X" className="h-5 w-5" />
|
||||
</button>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="text-high-foreground mt-3 flex h-full w-full flex-col overflow-y-scroll scrollbar-hide">
|
||||
{notificationList.length !== 0 ? (
|
||||
notificationList.map((alertItem, index) => (
|
||||
<SingleAlert
|
||||
key={alertItem.id}
|
||||
dropItem={alertItem}
|
||||
removeAlert={removeFromNotificationList}
|
||||
/>
|
||||
))
|
||||
) : (
|
||||
<div className="flex h-full w-full items-center justify-center pb-16 text-ring">
|
||||
No new notifications
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</PopoverContent>
|
||||
</Popover>
|
||||
);
|
||||
}
|
||||
|
|
|
|||
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue