From 0c398fb6c59fc5a03456511b7a7be1ba7caa4807 Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Sun, 25 Jun 2023 19:34:25 -0300 Subject: [PATCH] =?UTF-8?q?=F0=9F=94=A8=20refactor(loading.py):=20add=20ty?= =?UTF-8?q?pe=20hinting=20to=20instantiate=5Fagent=20function=20?= =?UTF-8?q?=F0=9F=90=9B=20fix(loading.py):=20fix=20type=20hinting=20in=20i?= =?UTF-8?q?nstantiate=5Fembedding=20function=20=F0=9F=94=A8=20refactor(loa?= =?UTF-8?q?ding.py):=20add=20type=20hinting=20to=20instantiate=5Ftextsplit?= =?UTF-8?q?ter=20function=20The=20changes=20in=20this=20commit=20add=20typ?= =?UTF-8?q?e=20hinting=20to=20the=20`instantiate=5Fagent`,=20`instantiate?= =?UTF-8?q?=5Fembedding`,=20and=20`instantiate=5Ftextsplitter`=20functions?= =?UTF-8?q?=20to=20improve=20code=20readability=20and=20maintainability.?= =?UTF-8?q?=20The=20`instantiate=5Fembedding`=20function=20had=20a=20bug?= =?UTF-8?q?=20in=20its=20type=20hinting=20which=20has=20been=20fixed.?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .../langflow/interface/initialize/loading.py | 30 +++++++++++-------- 1 file changed, 17 insertions(+), 13 deletions(-) diff --git a/src/backend/langflow/interface/initialize/loading.py b/src/backend/langflow/interface/initialize/loading.py index 4acc21383..a928ea586 100644 --- a/src/backend/langflow/interface/initialize/loading.py +++ b/src/backend/langflow/interface/initialize/loading.py @@ -16,11 +16,10 @@ from langflow.interface.toolkits.base import toolkits_creator from langflow.interface.chains.base import chain_creator from langflow.interface.utils import load_file_into_dict from langflow.utils import validate -from langchain.text_splitter import TextSplitter +from langchain.text_splitter import TextSplitter, RecursiveCharacterTextSplitter from langchain.chains.base import Chain from langchain.vectorstores.base import VectorStore from langchain.document_loaders.base import BaseLoader -from langchain.embeddings.base import Embeddings from langchain.prompts.base import BasePromptTemplate @@ -94,7 +93,7 @@ def instantiate_chains(node_type, class_object: Type[Chain], params: Dict): return class_object(**params) -def instantiate_agent(class_object: Type[Chain], params: Dict): +def instantiate_agent(class_object: Type[agent_module.Agent], params: Dict): return load_agent_executor(class_object, params) @@ -134,7 +133,7 @@ def instantiate_toolkit(node_type, class_object: Type[BaseToolkit], params: Dict return loaded_toolkit -def instantiate_embedding(class_object: Type[Embeddings], params: Dict): +def instantiate_embedding(class_object, params: Dict): params.pop("model", None) params.pop("headers", None) try: @@ -193,20 +192,25 @@ def instantiate_documentloader(class_object: Type[BaseLoader], params: Dict): return docs -def instantiate_textsplitter(class_object: Type[TextSplitter], params: Dict): +def instantiate_textsplitter( + class_object: Type[TextSplitter], + params: Dict, +): try: documents = params.pop("documents") - except KeyError as e: + except KeyError as exc: raise ValueError( "The source you provided did not load correctly or was empty." "Try changing the chunk_size of the Text Splitter." - ) from e - if "separator_type" in params and params["separator_type"] == "Text": - text_splitter = class_object(**params) - else: - params["language"] = params.pop("separator_type", None) - params.pop("separators", None) - text_splitter = class_object.from_language(**params) + ) from exc + + if type(class_object) == RecursiveCharacterTextSplitter: + if "separator_type" in params and params["separator_type"] == "Text": + text_splitter = class_object(**params) + else: + params["language"] = params.pop("separator_type", None) + params.pop("separators", None) + text_splitter = class_object.from_language(**params) return text_splitter.split_documents(documents)