diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 4acc21383..a928ea586 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -16,11 +16,10 @@ from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.chains.base import chain_creator from langflow.interface.utils import load_file_into_dict from langflow.utils import validate -from langchain.text_splitter import TextSplitter +from langchain.text_splitter import TextSplitter, RecursiveCharacterTextSplitter from langchain.chains.base import Chain from langchain.vectorstores.base import VectorStore from langchain.document_loaders.base import BaseLoader -from langchain.embeddings.base import Embeddings from langchain.prompts.base import BasePromptTemplate @@ -94,7 +93,7 @@ def instantiate_chains(node_type, class_object: Type[Chain], params: Dict): return class_object(**params) -def instantiate_agent(class_object: Type[Chain], params: Dict): +def instantiate_agent(class_object: Type[agent_module.Agent], params: Dict): return load_agent_executor(class_object, params) @@ -134,7 +133,7 @@ def instantiate_toolkit(node_type, class_object: Type[BaseToolkit], params: Dict return loaded_toolkit -def instantiate_embedding(class_object: Type[Embeddings], params: Dict): +def instantiate_embedding(class_object, params: Dict): params.pop("model", None) params.pop("headers", None) try: @@ -193,20 +192,25 @@ def instantiate_documentloader(class_object: Type[BaseLoader], params: Dict): return docs -def instantiate_textsplitter(class_object: Type[TextSplitter], params: Dict): +def instantiate_textsplitter( + class_object: Type[TextSplitter], + params: Dict, +): try: documents = params.pop("documents") - except KeyError as e: + except KeyError as exc: raise ValueError( "The source you provided did not load correctly or was empty." "Try changing the chunk_size of the Text Splitter." - ) from e - if "separator_type" in params and params["separator_type"] == "Text": - text_splitter = class_object(**params) - else: - params["language"] = params.pop("separator_type", None) - params.pop("separators", None) - text_splitter = class_object.from_language(**params) + ) from exc + + if type(class_object) == RecursiveCharacterTextSplitter: + if "separator_type" in params and params["separator_type"] == "Text": + text_splitter = class_object(**params) + else: + params["language"] = params.pop("separator_type", None) + params.pop("separators", None) + text_splitter = class_object.from_language(**params) return text_splitter.split_documents(documents)