Vector stores, embeddings, document loaders, and text splitters (#145)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-04-10 19:16:45 -03:00 committed by GitHub
commit e1ca085a7a
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
35 changed files with 2118 additions and 106 deletions

View file

@ -3,7 +3,7 @@ FROM python:3.10-slim
WORKDIR /app
# Install Poetry
RUN apt-get update && apt-get install gcc g++ curl -y
RUN apt-get update && apt-get install gcc g++ curl build-essential -y
RUN curl -sSL https://install.python-poetry.org | python3 -
# # Add Poetry to PATH
ENV PATH="${PATH}:/root/.local/bin"

1647
poetry.lock generated

File diff suppressed because it is too large Load diff

View file

@ -34,6 +34,7 @@ openai = "^0.27.2"
types-pyyaml = "^6.0.12.8"
dill = "^0.3.6"
pandas = "^1.5.3"
chromadb = "^0.3.21"
huggingface-hub = "^0.13.3"
rich = "^13.3.3"
llama-cpp-python = "0.1.23"

View file

@ -3,7 +3,7 @@ from typing import Any, Dict
from fastapi import APIRouter, HTTPException
from langflow.interface.run import process_graph
from langflow.interface.run import process_graph_cached
from langflow.interface.types import build_langchain_types_dict
# build router
@ -19,7 +19,7 @@ def get_all():
@router.post("/predict")
def get_load(data: Dict[str, Any]):
try:
return process_graph(data)
return process_graph_cached(data)
except Exception as e:
# Log stack trace
logger.exception(e)

View file

@ -1,12 +1,41 @@
import contextlib
import functools
import hashlib
import json
import os
import tempfile
from collections import OrderedDict
from pathlib import Path
import dill # type: ignore
def memoize_dict(maxsize=128):
    """Decorator factory: memoize a function whose first positional argument
    is a JSON-serializable dict.

    The cache key combines the wrapped function's name, a content hash of
    the dict (via ``compute_dict_hash``) and the keyword arguments, so two
    calls with equal-by-content dicts share one cache entry.  At most
    ``maxsize`` results are kept; the least recently *used* entry is evicted
    first (the original implementation never refreshed recency on a hit,
    which made eviction FIFO rather than LRU).

    The decorated function gains a ``clear_cache()`` attribute that empties
    the cache.
    """
    cache = OrderedDict()

    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            hashed = compute_dict_hash(args[0])
            # NOTE: kwargs values must be hashable for frozenset() to work.
            key = (func.__name__, hashed, frozenset(kwargs.items()))
            if key in cache:
                # Refresh recency so eviction is true LRU, not FIFO.
                cache.move_to_end(key)
                result = cache[key]
            else:
                result = func(*args, **kwargs)
                cache[key] = result
                if len(cache) > maxsize:
                    cache.popitem(last=False)  # evict least recently used
            return result

        def clear_cache():
            cache.clear()

        wrapper.clear_cache = clear_cache
        return wrapper

    return decorator
PREFIX = "langflow_cache"
@ -24,6 +53,13 @@ def clear_old_cache_files(max_cache_size: int = 3):
os.remove(cache_file)
def compute_dict_hash(graph_data):
    """Return a SHA-256 hex digest identifying *graph_data* by content.

    The graph dict is first stripped of volatile fields by ``filter_json``
    and serialized with sorted keys, so equal-by-content graphs always
    produce the same digest.
    """
    canonical = json.dumps(filter_json(graph_data), sort_keys=True)
    return hashlib.sha256(canonical.encode("utf-8")).hexdigest()
def filter_json(json_data):
filtered_data = json_data.copy()
@ -48,13 +84,6 @@ def filter_json(json_data):
return filtered_data
def compute_hash(graph_data):
graph_data = filter_json(graph_data)
cleaned_graph_json = json.dumps(graph_data, sort_keys=True)
return hashlib.sha256(cleaned_graph_json.encode("utf-8")).hexdigest()
def save_cache(hash_val: str, chat_data, clean_old_cache_files: bool):
cache_path = Path(tempfile.gettempdir()) / f"{PREFIX}_{hash_val}.dill"
with cache_path.open("wb") as cache_file:

View file

@ -12,6 +12,8 @@ agents:
- JsonAgent
- CSVAgent
- initialize_agent
- VectorStoreAgent
- VectorStoreRouterAgent
prompts:
- PromptTemplate
@ -44,16 +46,25 @@ wrappers:
toolkits:
- OpenAPIToolkit
- JsonToolkit
- VectorStoreInfo
- VectorStoreRouterToolkit
memories:
- ConversationBufferMemory
- ConversationSummaryMemory
- ConversationKGMemory
embeddings: []
embeddings:
- OpenAIEmbeddings
vectorstores: []
vectorstores:
- Chroma
documentloaders: []
documentloaders:
- TextLoader
- WebBaseLoader
textsplitters:
- CharacterTextSplitter
dev: false

View file

@ -8,6 +8,8 @@ CUSTOM_NODES = {
"JsonAgent": nodes.JsonAgentNode(),
"CSVAgent": nodes.CSVAgentNode(),
"initialize_agent": nodes.InitializeAgentNode(),
"VectorStoreAgent": nodes.VectorStoreAgentNode(),
"VectorStoreRouterAgent": nodes.VectorStoreRouterAgentNode(),
},
}

View file

@ -160,6 +160,7 @@ class Node:
result = result.run # type: ignore
elif hasattr(result, "get_function"):
result = result.get_function() # type: ignore
self.params[key] = result
elif isinstance(value, list) and all(
isinstance(node, Node) for node in value
@ -189,6 +190,15 @@ class Node:
def build(self, force: bool = False) -> Any:
    """Build the underlying object if needed and return it.

    #! Deepcopy is breaking for vectorstores — vector-store related objects
    are therefore returned by reference; everything else is returned as a
    defensive deep copy.
    """
    if force or not self._built:
        self._build()

    no_copy_base_types = {
        "vectorstores",
        "VectorStoreRouterAgent",
        "VectorStoreAgent",
        "VectorStoreInfo",
    }
    no_copy_node_types = {"VectorStoreInfo", "VectorStoreRouterToolkit"}
    if self.base_type in no_copy_base_types or self.node_type in no_copy_node_types:
        return self._built_object
    return deepcopy(self._built_object)
def add_edge(self, edge: "Edge") -> None:

View file

@ -4,22 +4,30 @@ from langflow.graph.base import Edge, Node
from langflow.graph.nodes import (
AgentNode,
ChainNode,
DocumentLoaderNode,
EmbeddingNode,
FileToolNode,
LLMNode,
MemoryNode,
PromptNode,
TextSplitterNode,
ToolkitNode,
ToolNode,
VectorStoreNode,
WrapperNode,
)
from langflow.interface.agents.base import agent_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.documentLoaders.base import documentloader_creator
from langflow.interface.embeddings.base import embedding_creator
from langflow.interface.llms.base import llm_creator
from langflow.interface.memories.base import memory_creator
from langflow.interface.prompts.base import prompt_creator
from langflow.interface.textSplitters.base import textsplitter_creator
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.tools.base import tool_creator
from langflow.interface.tools.constants import FILE_TOOLS
from langflow.interface.vectorStore.base import vectorstore_creator
from langflow.interface.wrappers.base import wrapper_creator
from langflow.utils import payload
@ -117,6 +125,10 @@ class Graph:
**{t: WrapperNode for t in wrapper_creator.to_list()},
**{t: LLMNode for t in llm_creator.to_list()},
**{t: MemoryNode for t in memory_creator.to_list()},
**{t: EmbeddingNode for t in embedding_creator.to_list()},
**{t: VectorStoreNode for t in vectorstore_creator.to_list()},
**{t: DocumentLoaderNode for t in documentloader_creator.to_list()},
**{t: TextSplitterNode for t in textsplitter_creator.to_list()},
}
if node_type in FILE_TOOLS:

View file

@ -32,6 +32,10 @@ class AgentNode(Node):
chain_node.build(tools=self.tools)
self._build()
#! Cannot deepcopy VectorStore
if self.node_type in ["VectorStoreAgent", "VectorStoreRouterAgent"]:
return self._built_object
return deepcopy(self._built_object)
@ -39,11 +43,6 @@ class ToolNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="tools")
def build(self, force: bool = False) -> Any:
if not self._built or force:
self._build()
return deepcopy(self._built_object)
class PromptNode(Node):
def __init__(self, data: Dict):
@ -109,32 +108,16 @@ class LLMNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="llms")
def build(self, force: bool = False) -> Any:
if not self._built or force:
self._build()
return deepcopy(self._built_object)
class ToolkitNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="toolkits")
def build(self, force: bool = False) -> Any:
if not self._built or force:
self._build()
return deepcopy(self._built_object)
class FileToolNode(ToolNode):
def __init__(self, data: Dict):
super().__init__(data)
def build(self, force: bool = False) -> Any:
if not self._built or force:
self._build()
return deepcopy(self._built_object)
class WrapperNode(Node):
def __init__(self, data: Dict):
@ -148,11 +131,26 @@ class WrapperNode(Node):
return deepcopy(self._built_object)
class DocumentLoaderNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="documentloaders")
class EmbeddingNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="embeddings")
class VectorStoreNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="vectorstores")
class MemoryNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="memory")
def build(self, force: bool = False) -> Any:
if not self._built or force:
self._build()
return deepcopy(self._built_object)
class TextSplitterNode(Node):
def __init__(self, data: Dict):
super().__init__(data, base_type="textsplitters")

View file

@ -38,6 +38,9 @@ def load_file(file_name, file_content, accepted_types) -> Any:
# Load the csv content
csv_reader = csv.DictReader(io.StringIO(decoded_string))
return list(csv_reader)
elif suffix == "txt":
# Return the text content
return decoded_string
else:
raise ValueError(f"File {file_name} is not accepted")

View file

@ -2,10 +2,21 @@ from typing import Any, List, Optional
from langchain import LLMChain
from langchain.agents import AgentExecutor, Tool, ZeroShotAgent, initialize_agent
from langchain.agents.agent_toolkits import (
VectorStoreInfo,
VectorStoreRouterToolkit,
VectorStoreToolkit,
)
from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFIX
from langchain.agents.agent_toolkits.pandas.prompt import SUFFIX as PANDAS_SUFFIX
from langchain.agents.agent_toolkits.vectorstore.prompt import (
PREFIX as VECTORSTORE_PREFIX,
)
from langchain.agents.agent_toolkits.vectorstore.prompt import (
ROUTER_PREFIX as VECTORSTORE_ROUTER_PREFIX,
)
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.llms.base import BaseLLM
from langchain.memory.chat_memory import BaseChatMemory
@ -97,6 +108,83 @@ class CSVAgent(AgentExecutor):
return super().run(*args, **kwargs)
class VectorStoreAgent(AgentExecutor):
    """Agent executor backed by a single vector store."""

    @staticmethod
    def function_name():
        return "VectorStoreAgent"

    @classmethod
    def initialize(cls, *args, **kwargs):
        return cls.from_toolkit_and_llm(*args, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    @classmethod
    def from_toolkit_and_llm(
        cls, llm: BaseLLM, vectorstoreinfo: VectorStoreInfo, **kwargs: Any
    ):
        """Construct a vectorstore agent from an LLM and tools.

        NOTE: returns a plain ``AgentExecutor`` (with ``verbose=True``),
        not an instance of ``cls``.
        """
        toolkit = VectorStoreToolkit(vectorstore_info=vectorstoreinfo, llm=llm)
        store_tools = toolkit.get_tools()
        zero_shot = ZeroShotAgent(
            llm_chain=LLMChain(
                llm=llm,
                prompt=ZeroShotAgent.create_prompt(
                    store_tools, prefix=VECTORSTORE_PREFIX
                ),
            ),
            allowed_tools=[tool.name for tool in store_tools],
            **kwargs,
        )
        return AgentExecutor.from_agent_and_tools(
            agent=zero_shot, tools=store_tools, verbose=True
        )

    def run(self, *args, **kwargs):
        return super().run(*args, **kwargs)
class VectorStoreRouterAgent(AgentExecutor):
    """Agent executor that routes queries across multiple vector stores."""

    @staticmethod
    def function_name():
        return "VectorStoreRouterAgent"

    @classmethod
    def initialize(cls, *args, **kwargs):
        return cls.from_toolkit_and_llm(*args, **kwargs)

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

    # NOTE(review): the annotation below uses BaseLanguageModel, which is not
    # among the imports visible in this hunk — confirm it is imported at the
    # top of the module.
    @classmethod
    def from_toolkit_and_llm(
        cls,
        llm: BaseLanguageModel,
        vectorstoreroutertoolkit: VectorStoreRouterToolkit,
        **kwargs: Any
    ):
        """Construct a vector store router agent from an LLM and tools.

        NOTE: returns a plain ``AgentExecutor`` (with ``verbose=True``),
        not an instance of ``cls``.
        """
        router_tools = vectorstoreroutertoolkit.get_tools()
        zero_shot = ZeroShotAgent(
            llm_chain=LLMChain(
                llm=llm,
                prompt=ZeroShotAgent.create_prompt(
                    router_tools, prefix=VECTORSTORE_ROUTER_PREFIX
                ),
            ),
            allowed_tools=[tool.name for tool in router_tools],
            **kwargs,
        )
        return AgentExecutor.from_agent_and_tools(
            agent=zero_shot, tools=router_tools, verbose=True
        )

    def run(self, *args, **kwargs):
        return super().run(*args, **kwargs)
class InitializeAgent(AgentExecutor):
"""Implementation of initialize_agent function"""
@ -128,4 +216,6 @@ CUSTOM_AGENTS = {
"JsonAgent": JsonAgent,
"CSVAgent": CSVAgent,
"initialize_agent": InitializeAgent,
"VectorStoreAgent": VectorStoreAgent,
"VectorStoreRouterAgent": VectorStoreRouterAgent,
}

View file

@ -1,6 +1,6 @@
import inspect
from typing import Any
## LLM
from langchain import (
chains,
document_loaders,
@ -8,6 +8,7 @@ from langchain import (
llms,
memory,
requests,
text_splitter,
vectorstores,
)
from langchain.agents import agent_toolkits
@ -15,16 +16,17 @@ from langchain.chat_models import ChatOpenAI
from langflow.interface.importing.utils import import_class
## LLM
## LLMs
llm_type_to_cls_dict = llms.type_to_cls_dict
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Chain
## Chains
chain_type_to_cls_dict: dict[str, Any] = {
chain_name: import_class(f"langchain.chains.{chain_name}")
for chain_name in chains.__all__
}
## Toolkits
toolkit_type_to_loader_dict: dict[str, Any] = {
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
# if toolkit_name is lower case it is a loader
@ -69,3 +71,8 @@ documentloaders_type_to_cls_dict: dict[str, Any] = {
)
for documentloader_name in document_loaders.__all__
}
## Text Splitters
# Map class name -> class for every class found in langchain.text_splitter.
# NOTE(review): inspect.getmembers(..., inspect.isclass) also picks up classes
# merely *imported* into that module (e.g. abstract bases) — confirm intended.
textsplitter_type_to_cls_dict: dict[str, Any] = dict(
    inspect.getmembers(text_splitter, inspect.isclass)
)

View file

@ -2,21 +2,54 @@ from typing import Dict, List, Optional
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.custom_lists import documentloaders_type_to_cls_dict
from langflow.interface.documentLoaders.custom import CUSTOM_DOCUMENTLOADERS
from langflow.settings import settings
from langflow.utils.util import build_template_from_class
class DocumentLoaderCreator(LangChainTypeCreator):
    """Creator exposing langchain document loaders plus langflow's custom
    replacements to the frontend."""

    type_name: str = "documentloaders"

    @property
    def type_to_loader_dict(self) -> Dict:
        """Map loader name -> loader class.

        Works on a copy so repeated accesses do not mutate the shared
        module-level ``documentloaders_type_to_cls_dict`` (the original code
        popped from and wrote into the shared dict, permanently rewiring it
        on first access).
        """
        types = documentloaders_type_to_cls_dict.copy()
        # Drop langchain's TextLoader: langflow re-implements it under the
        # same name in CUSTOM_DOCUMENTLOADERS.
        types.pop("TextLoader", None)
        types.update(CUSTOM_DOCUMENTLOADERS)
        return types

    def get_signature(self, name: str) -> Optional[Dict]:
        """Get the signature of a document loader.

        Raises:
            ValueError: if *name* is not a known document loader.
        """
        try:
            # Build from the merged dict so custom loaders (e.g. the
            # re-implemented TextLoader) resolve to the custom class.
            signature = build_template_from_class(name, self.type_to_loader_dict)
            if name == "TextLoader":
                # Frontend uploads a .txt file; the decoded content is passed
                # to the loader under the "path" field name.
                signature["template"]["file"] = {
                    "type": "file",
                    "required": True,
                    "show": True,
                    "name": "path",
                    "value": "",
                    "suffixes": [".txt"],
                    "fileTypes": ["txt"],
                }
            elif name == "WebBaseLoader":
                signature["template"]["web_path"] = {
                    "type": "str",
                    "required": True,
                    "show": True,
                    "name": "web_path",
                    "value": "",
                    "display_name": "Web Path",
                }
            return signature
        except ValueError as exc:
            # Fixed typo in the original message ("Documment Loader").
            raise ValueError(f"Document Loader {name} not found") from exc

View file

@ -0,0 +1,22 @@
"""Load text files."""
from typing import List
from langchain.docstore.document import Document
from langchain.document_loaders.base import BaseLoader
class TextLoader(BaseLoader):
    """Load Text files.

    NOTE(review): despite the parameter name, ``file`` appears to carry the
    already-decoded text *content*, not a path — the frontend decodes the
    uploaded .txt before building the graph. Confirm against the caller.
    """

    def __init__(self, file: str):
        """Initialize with file path."""
        self.file = file

    def load(self) -> List[Document]:
        """Load from file path."""
        document = Document(page_content=self.file, metadata={"source": "loaded"})
        return [document]
# Loaders that override/extend langchain's, keyed by the name exposed to the UI.
CUSTOM_DOCUMENTLOADERS = {
    "TextLoader": TextLoader,
}

View file

@ -10,6 +10,7 @@ from langchain.chat_models.base import BaseChatModel
from langchain.llms.base import BaseLLM
from langchain.tools import BaseTool
from langflow.interface.documentLoaders.custom import CUSTOM_DOCUMENTLOADERS
from langflow.interface.tools.util import get_tool_by_name
@ -40,6 +41,10 @@ def import_by_type(_type: str, name: str) -> Any:
"toolkits": import_toolkit,
"wrappers": import_wrapper,
"memory": import_memory,
"embeddings": import_embedding,
"vectorstores": import_vectorstore,
"documentloaders": import_documentloader,
"textsplitters": import_textsplitter,
}
if _type == "llms":
key = "chat" if "chat" in name.lower() else "llm"
@ -113,3 +118,26 @@ def import_chain(chain: str) -> Type[Chain]:
if chain in CUSTOM_CHAINS:
return CUSTOM_CHAINS[chain]
return import_class(f"langchain.chains.{chain}")
def import_embedding(embedding: str) -> Any:
    """Resolve an embedding class by name from ``langchain.embeddings``."""
    dotted_path = f"langchain.embeddings.{embedding}"
    return import_class(dotted_path)
def import_vectorstore(vectorstore: str) -> Any:
    """Resolve a vectorstore class by name from ``langchain.vectorstores``."""
    dotted_path = f"langchain.vectorstores.{vectorstore}"
    return import_class(dotted_path)
def import_documentloader(documentloader: str) -> Any:
    """Resolve a document loader by name, preferring langflow's custom ones."""
    try:
        return CUSTOM_DOCUMENTLOADERS[documentloader]
    except KeyError:
        return import_class(f"langchain.document_loaders.{documentloader}")
def import_textsplitter(textsplitter: str) -> Any:
    """Resolve a text splitter class by name from ``langchain.text_splitter``."""
    dotted_path = f"langchain.text_splitter.{textsplitter}"
    return import_class(dotted_path)

View file

@ -1,10 +1,14 @@
from langflow.interface.agents.base import agent_creator
from langflow.interface.chains.base import chain_creator
from langflow.interface.documentLoaders.base import documentloader_creator
from langflow.interface.embeddings.base import embedding_creator
from langflow.interface.llms.base import llm_creator
from langflow.interface.memories.base import memory_creator
from langflow.interface.prompts.base import prompt_creator
from langflow.interface.textSplitters.base import textsplitter_creator
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.tools.base import tool_creator
from langflow.interface.vectorStore.base import vectorstore_creator
from langflow.interface.wrappers.base import wrapper_creator
@ -18,6 +22,10 @@ def get_type_dict():
"memory": memory_creator.to_list(),
"toolkits": toolkits_creator.to_list(),
"wrappers": wrapper_creator.to_list(),
"documentLoaders": documentloader_creator.to_list(),
"vectorStore": vectorstore_creator.to_list(),
"embeddings": embedding_creator.to_list(),
"textSplitters": textsplitter_creator.to_list(),
}

View file

@ -57,6 +57,17 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
if toolkits_creator.has_create_function(node_type):
return load_toolkits_executor(node_type, loaded_toolkit, params)
return loaded_toolkit
elif base_type == "embeddings":
params.pop("model")
return class_object(**params)
elif base_type == "vectorstores":
return class_object.from_documents(**params)
elif base_type == "documentloaders":
return class_object(**params).load()
elif base_type == "textsplitters":
documents = params.pop("documents")
text_splitter = class_object(**params)
return text_splitter.split_documents(documents)
else:
return class_object(**params)

View file

@ -2,7 +2,7 @@ import contextlib
import io
from typing import Any, Dict
from langflow.cache.utils import compute_hash, load_cache, save_cache
from langflow.cache.utils import compute_dict_hash, load_cache, memoize_dict
from langflow.graph.graph import Graph
from langflow.interface import loading
from langflow.utils.logger import logger
@ -12,7 +12,7 @@ def load_langchain_object(data_graph, is_first_message=False):
"""
Load langchain object from cache if it exists, otherwise build it.
"""
computed_hash = compute_hash(data_graph)
computed_hash = compute_dict_hash(data_graph)
if is_first_message:
langchain_object = build_langchain_object(data_graph)
else:
@ -22,6 +22,32 @@ def load_langchain_object(data_graph, is_first_message=False):
return computed_hash, langchain_object
def load_or_build_langchain_object(data_graph, is_first_message=False):
    """
    Return a (possibly memoized) langchain object for *data_graph*.

    On the first message of a conversation the memoized cache is cleared so
    the object is rebuilt from scratch.
    """
    if is_first_message:
        build_langchain_object_with_caching.clear_cache()
    return build_langchain_object_with_caching(data_graph)
@memoize_dict(maxsize=1)
def build_langchain_object_with_caching(data_graph):
    """
    Build a langchain object from *data_graph*, memoized on the graph dict
    (only the single most recent graph is cached).
    """
    logger.debug("Building langchain object")
    graph = Graph(data_graph["nodes"], data_graph["edges"])
    return graph.build()
def build_langchain_object(data_graph):
"""
Build langchain object from data_graph.
@ -67,11 +93,35 @@ def process_graph(data_graph: Dict[str, Any]):
# We have to save it here because if the
# memory is updated we need to keep the new values
logger.debug("Saving langchain object to cache")
save_cache(computed_hash, langchain_object, is_first_message)
# save_cache(computed_hash, langchain_object, is_first_message)
logger.debug("Saved langchain object to cache")
return {"result": str(result), "thought": thought.strip()}
def process_graph_cached(data_graph: Dict[str, Any]):
    """
    Run a chat message through the (memoized) langchain object built from the
    graph, returning the result and the agent's intermediate "thought".
    """
    message = data_graph.pop("message", "")
    # An empty chat history marks the start of a conversation.
    first_message = len(data_graph.get("chatHistory", [])) == 0
    langchain_object = load_or_build_langchain_object(data_graph, first_message)
    logger.debug("Loaded langchain object")

    if langchain_object is None:
        # Raise user facing error
        raise ValueError(
            "There was an error loading the langchain_object. Please, check all the nodes and try again."
        )

    logger.debug("Generating result and thought")
    result, thought = get_result_and_thought_using_graph(langchain_object, message)
    logger.debug("Generated result and thought")
    return {"result": str(result), "thought": thought.strip()}
def get_memory_key(langchain_object):
"""
Given a LangChain object, this function retrieves the current memory key from the object's memory attribute.

View file

@ -0,0 +1,3 @@
from langflow.interface.textSplitters.base import TextSplitterCreator
__all__ = ["TextSplitterCreator"]

View file

@ -0,0 +1,40 @@
from typing import Dict, List, Optional
from langflow.interface.base import LangChainTypeCreator
from langflow.interface.custom_lists import textsplitter_type_to_cls_dict
from langflow.settings import settings
from langflow.utils.util import build_template_from_class
class TextSplitterCreator(LangChainTypeCreator):
    """Creator exposing langchain text splitters to the frontend."""

    type_name: str = "textsplitters"

    @property
    def type_to_loader_dict(self) -> Dict:
        return textsplitter_type_to_cls_dict

    def get_signature(self, name: str) -> Optional[Dict]:
        """Get the signature of a text splitter."""
        try:
            signature = build_template_from_class(name, textsplitter_type_to_cls_dict)
            # Splitters consume the documents produced by a loader node.
            signature["template"]["documents"] = {
                "type": "BaseLoader",
                "required": True,
                "show": True,
                "name": "documents",
            }
            return signature
        except ValueError as exc:
            raise ValueError(f"Text Splitter {name} not found") from exc

    def to_list(self) -> List[str]:
        """Names of splitters enabled in settings (all of them in dev mode)."""
        allowed = set(settings.textsplitters)
        return [
            splitter.__name__
            for splitter in self.type_to_loader_dict.values()
            if settings.dev or splitter.__name__ in allowed
        ]


textsplitter_creator = TextSplitterCreator()

View file

@ -7,7 +7,7 @@ from langchain.agents.load_tools import (
)
from langchain.tools.json.tool import JsonSpec
from langflow.interface.custom.types import PythonFunction
from langflow.interface.tools.custom import PythonFunction
FILE_TOOLS = {"JsonSpec": JsonSpec}
CUSTOM_TOOLS = {"Tool": Tool, "PythonFunction": PythonFunction}

View file

@ -5,6 +5,7 @@ from langflow.interface.embeddings.base import embedding_creator
from langflow.interface.llms.base import llm_creator
from langflow.interface.memories.base import memory_creator
from langflow.interface.prompts.base import prompt_creator
from langflow.interface.textSplitters.base import textsplitter_creator
from langflow.interface.toolkits.base import toolkits_creator
from langflow.interface.tools.base import tool_creator
from langflow.interface.vectorStore.base import vectorstore_creator
@ -40,6 +41,7 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
embedding_creator,
vectorstore_creator,
documentloader_creator,
textsplitter_creator,
]
all_types = {}

View file

@ -7,7 +7,7 @@ from langflow.utils.util import build_template_from_class
class VectorstoreCreator(LangChainTypeCreator):
type_name: str = "vectorstore"
type_name: str = "vectorstores"
@property
def type_to_loader_dict(self) -> Dict:
@ -16,7 +16,27 @@ class VectorstoreCreator(LangChainTypeCreator):
def get_signature(self, name: str) -> Optional[Dict]:
"""Get the signature of an embedding."""
try:
return build_template_from_class(name, vectorstores_type_to_cls_dict)
signature = build_template_from_class(name, vectorstores_type_to_cls_dict)
# TODO: Use FrontendNode class to build the signature
signature["template"] = {
"documents": {
"type": "TextSplitter",
"required": True,
"show": True,
"name": "documents",
"display_name": "Text Splitter",
},
"embedding": {
"type": "Embeddings",
"required": True,
"show": True,
"name": "embedding",
"display_name": "Embedding",
},
}
return signature
except ValueError as exc:
raise ValueError(f"Vector Store {name} not found") from exc

View file

@ -17,6 +17,7 @@ class Settings(BaseSettings):
documentloaders: List[str] = []
wrappers: List[str] = []
toolkits: List[str] = []
textsplitters: List[str] = []
dev: bool = False
class Config:
@ -40,6 +41,7 @@ class Settings(BaseSettings):
self.memories = new_settings.memories or []
self.wrappers = new_settings.wrappers or []
self.toolkits = new_settings.toolkits or []
self.textsplitters = new_settings.textsplitters or []
self.dev = new_settings.dev or False

View file

@ -248,6 +248,62 @@ class CSVAgentNode(FrontendNode):
return super().to_dict()
# Frontend template for the VectorStoreAgent node: a VectorStoreInfo input
# plus an LLM, producing an AgentExecutor.
class VectorStoreAgentNode(FrontendNode):
    name: str = "VectorStoreAgent"
    template: Template = Template(
        type_name="vectorstore_agent",
        fields=[
            TemplateField(
                field_type="VectorStoreInfo",
                required=True,
                show=True,
                name="vectorstoreinfo",
                display_name="Vector Store Info",
            ),
            TemplateField(
                field_type="BaseLanguageModel",
                required=True,
                show=True,
                name="llm",
                display_name="LLM",
            ),
        ],
    )
    description: str = """Construct an agent from a Vector Store."""
    base_classes: list[str] = ["AgentExecutor"]

    def to_dict(self):
        # Default serialization; override kept to mirror sibling agent nodes.
        return super().to_dict()
# Frontend template for the VectorStoreRouterAgent node: a router toolkit
# (spanning multiple vector stores) plus an LLM, producing an AgentExecutor.
class VectorStoreRouterAgentNode(FrontendNode):
    name: str = "VectorStoreRouterAgent"
    template: Template = Template(
        type_name="vectorstorerouter_agent",
        fields=[
            TemplateField(
                field_type="VectorStoreRouterToolkit",
                required=True,
                show=True,
                name="vectorstoreroutertoolkit",
                display_name="Vector Store Router Toolkit",
            ),
            TemplateField(
                field_type="BaseLanguageModel",
                required=True,
                show=True,
                name="llm",
                display_name="LLM",
            ),
        ],
    )
    description: str = """Construct an agent from a Vector Store Router."""
    base_classes: list[str] = ["AgentExecutor"]

    def to_dict(self):
        # Default serialization; override kept to mirror sibling agent nodes.
        return super().to_dict()
class PromptFrontendNode(FrontendNode):
@staticmethod
def format_field(field: TemplateField, name: Optional[str] = None) -> None:

View file

@ -9,7 +9,11 @@ import {
ComputerDesktopIcon,
Bars3CenterLeftIcon,
GiftIcon,
PaperClipIcon,
QuestionMarkCircleIcon,
FingerPrintIcon,
ScissorsIcon,
CircleStackIcon
} from "@heroicons/react/24/outline";
import { Connection, Edge, Node, ReactFlowInstance } from "reactflow";
import { FlowType } from "./types/flow";
@ -71,11 +75,14 @@ export const nodeColors: {[char: string]: string} = {
chains: "#FE7500",
agents: "#903BBE",
tools: "#FF3434",
memories: "#FF9135",
memories: "#F5B85A",
advanced: "#000000",
chat: "#454173",
thought:"#272541",
docloaders:"#FF9135",
embeddings:"#42BAA7",
documentloaders:"#7AAE42",
vectorstores: "#AA8742",
textsplitters: "#B47CB5",
toolkits:"#DB2C2C",
wrappers:"#E6277A",
unknown:"#9CA3AF"
@ -90,9 +97,12 @@ export const nodeNames:{[char: string]: string} = {
memories: "Memories",
advanced: "Advanced",
chat: "Chat",
docloaders:"Document Loader",
embeddings: "Embeddings",
documentloaders: "Document Loaders",
vectorstores: "Vector Stores",
toolkits:"Toolkits",
wrappers:"Wrappers",
textsplitters: "Text Splitters",
unknown:"Unknown"
};
@ -105,8 +115,11 @@ export const nodeIcons:{[char: string]: React.ForwardRefExoticComponent<React.SV
tools: WrenchIcon,
advanced: ComputerDesktopIcon,
chat: Bars3CenterLeftIcon,
docloaders:Bars3CenterLeftIcon,
embeddings:FingerPrintIcon,
documentloaders:PaperClipIcon,
vectorstores: CircleStackIcon,
toolkits:WrenchScrewdriverIcon,
textsplitters:ScissorsIcon,
wrappers:GiftIcon,
unknown:QuestionMarkCircleIcon
};

View file

@ -1,6 +1,6 @@
# Test this:
import pytest
from langflow.interface.custom.types import PythonFunction
from langflow.interface.tools.custom import PythonFunction
from langflow.utils import constants

View file

@ -3,6 +3,7 @@ from typing import Type, Union
import pytest
from langchain.agents import AgentExecutor
from langchain.llms.fake import FakeListLLM
from langflow.graph import Edge, Graph, Node
from langflow.graph.nodes import (
AgentNode,
@ -14,9 +15,8 @@ from langflow.graph.nodes import (
ToolNode,
WrapperNode,
)
from langflow.utils.payload import build_json, get_root_node
from langflow.interface.run import get_result_and_thought_using_graph
from langchain.llms.fake import FakeListLLM
from langflow.utils.payload import build_json, get_root_node
# Test cases for the graph module

View file

@ -0,0 +1,12 @@
from fastapi.testclient import TestClient
from langflow.settings import settings
# check that all vector stores enabled in settings.vectorstores
# are present in json_response["vectorstores"]
def test_vectorstores_settings(client: TestClient):
    """The /all endpoint must expose exactly the vector stores enabled in settings."""
    response = client.get("/all")
    assert response.status_code == 200
    payload = response.json()
    assert set(payload["vectorstores"].keys()) == set(settings.vectorstores)