diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 7104a2002..a9868d15e 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -9,10 +9,12 @@ from langflow.utils.logger import logger from fastapi import APIRouter, Depends, HTTPException, UploadFile -from langflow.api.extract_info_from_class import ( - ClassCodeExtractor, - is_valid_class_template -) +# from langflow.api.extract_info_from_class import ( +# ClassCodeExtractor, +# is_valid_class_template +# ) + +from langflow.interface.tools.custom import CustomComponent from langflow.api.v1.schemas import ( ProcessResponse, @@ -101,9 +103,9 @@ def get_version(): async def custom_component( raw_code: CustomComponentCode, ): - extractor = ClassCodeExtractor(raw_code.code) + extractor = CustomComponent(code=raw_code.code) data = extractor.extract_class_info() - valid = is_valid_class_template(data) + valid = extractor.is_valid_class_template(data) function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type() diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index a6ecb75f7..9f6e76be5 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -27,7 +27,7 @@ CUSTOM_NODES = { "TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(), "MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(), "load_qa_chain": frontend_node.chains.CombineDocsChainNode(), - }, + } } diff --git a/src/backend/langflow/graph/graph/constants.py b/src/backend/langflow/graph/graph/constants.py index 7775c50ff..d20ea4053 100644 --- a/src/backend/langflow/graph/graph/constants.py +++ b/src/backend/langflow/graph/graph/constants.py @@ -32,4 +32,5 @@ VERTEX_TYPE_MAP: Dict[str, Type[Vertex]] = { **{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()}, **{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()}, **{t: types.OutputParserVertex for t in output_parser_creator.to_list()}, + **{t: types.CustomComponentVertex for t in tool_creator.to_list()}, } diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index af2081217..ace606163 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -184,7 +184,8 @@ class PromptVertex(Vertex): if "prompt" not in self.params and "messages" not in self.params: for param in prompt_params: prompt_text = self.params[param] - variables = extract_input_variables_from_prompt(prompt_text) + variables = extract_input_variables_from_prompt( + prompt_text) self.params["input_variables"].extend(variables) self.params["input_variables"] = list( set(self.params["input_variables"]) @@ -199,3 +200,8 @@ class PromptVertex(Vertex): class OutputParserVertex(Vertex): def __init__(self, data: Dict): super().__init__(data, base_type="output_parsers") + + +class CustomComponentVertex(Vertex): + def __init__(self, data: Dict): + super().__init__(data, base_type="tools") diff --git a/src/backend/langflow/interface/tools/custom.py b/src/backend/langflow/interface/tools/custom.py index 14de8a205..4b0f6f1ad 100644 --- a/src/backend/langflow/interface/tools/custom.py +++ b/src/backend/langflow/interface/tools/custom.py @@ -52,7 +52,7 @@ class PythonFunction(Function): code: str -class CustomComponent(BaseModel): +class CustomComponent_old(BaseModel): code: str function: Optional[Callable] = None imports: Optional[str] = None @@ -78,8 +78,9 @@ class CustomComponent(BaseModel): return validate.create_function(self.code, function_name) -class CustomComponent1(BaseModel): +class CustomComponent(BaseModel): code: str + function: Optional[Callable] = None function_entrypoint_name = "build" return_type_valid_list = [ "ConversationChain",