From 895bc202a910dee37074aa1fb628fb7cb2cba7ce Mon Sep 17 00:00:00 2001 From: gustavoschaedler Date: Fri, 30 Jun 2023 19:32:09 +0100 Subject: [PATCH] Refactor custom component code structure in `endpoints.py` and `customs.py` - Refactored the import statements for custom components in `endpoints.py` to use the newly created `CustomComponent` class instead of the previous implementation. - Removed unnecessary import statements for custom components in `endpoints.py`. - Added support for a new `CustomComponentVertex` type in the `VERTEX_TYPE_MAP` dictionary in `constants.py`. - Modified the `PromptVertex` class in `types.py` to handle input variables from prompt text more efficiently. - Added a new `CustomComponentVertex` class in `types.py` for custom component vertices. - Renamed the `CustomComponent` class in `custom.py` to `CustomComponent_old`. - Created a new `CustomComponent` class in `custom.py` to replace the previous implementation. --- src/backend/langflow/api/v1/endpoints.py | 14 ++++++++------ src/backend/langflow/custom/customs.py | 2 +- src/backend/langflow/graph/graph/constants.py | 1 + src/backend/langflow/graph/vertex/types.py | 8 +++++++- src/backend/langflow/interface/tools/custom.py | 5 +++-- 5 files changed, 20 insertions(+), 10 deletions(-) diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 7104a2002..a9868d15e 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -9,10 +9,12 @@ from langflow.utils.logger import logger from fastapi import APIRouter, Depends, HTTPException, UploadFile -from langflow.api.extract_info_from_class import ( - ClassCodeExtractor, - is_valid_class_template -) +# from langflow.api.extract_info_from_class import ( +# ClassCodeExtractor, +# is_valid_class_template +# ) + +from langflow.interface.tools.custom import CustomComponent from langflow.api.v1.schemas import ( ProcessResponse, @@ -101,9 +103,9 @@ def get_version(): async def custom_component( raw_code: CustomComponentCode, ): - extractor = ClassCodeExtractor(raw_code.code) + extractor = CustomComponent(code=raw_code.code) data = extractor.extract_class_info() - valid = is_valid_class_template(data) + valid = extractor.is_valid_class_template(data) function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type() diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index a6ecb75f7..9f6e76be5 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -27,7 +27,7 @@ CUSTOM_NODES = { "TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(), "MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(), "load_qa_chain": frontend_node.chains.CombineDocsChainNode(), - }, + } } diff --git a/src/backend/langflow/graph/graph/constants.py b/src/backend/langflow/graph/graph/constants.py index 7775c50ff..d20ea4053 100644 --- a/src/backend/langflow/graph/graph/constants.py +++ b/src/backend/langflow/graph/graph/constants.py @@ -32,4 +32,5 @@ VERTEX_TYPE_MAP: Dict[str, Type[Vertex]] = { **{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()}, **{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()}, **{t: types.OutputParserVertex for t in output_parser_creator.to_list()}, + **{t: types.CustomComponentVertex for t in tool_creator.to_list()}, } diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index af2081217..ace606163 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -184,7 +184,8 @@ class PromptVertex(Vertex): if "prompt" not in self.params and "messages" not in self.params: for param in prompt_params: prompt_text = self.params[param] - variables = extract_input_variables_from_prompt(prompt_text) + variables = extract_input_variables_from_prompt( + prompt_text) self.params["input_variables"].extend(variables) self.params["input_variables"] = list( set(self.params["input_variables"]) @@ -199,3 +200,8 @@ class PromptVertex(Vertex): class OutputParserVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="output_parsers") + + +class CustomComponentVertex(Vertex): + def __init__(self, data: Dict): + super().__init__(data, base_type="tools") diff --git a/src/backend/langflow/interface/tools/custom.py b/src/backend/langflow/interface/tools/custom.py index 14de8a205..4b0f6f1ad 100644 --- a/src/backend/langflow/interface/tools/custom.py +++ b/src/backend/langflow/interface/tools/custom.py @@ -52,7 +52,7 @@ class PythonFunction(Function): code: str -class CustomComponent(BaseModel): +class CustomComponent_old(BaseModel): code: str function: Optional[Callable] = None imports: Optional[str] = None @@ -78,8 +78,9 @@ class CustomComponent(BaseModel): return validate.create_function(self.code, function_name) -class CustomComponent1(BaseModel): +class CustomComponent(BaseModel): code: str + function: Optional[Callable] = None function_entrypoint_name = "build" return_type_valid_list = [ "ConversationChain",