Refactor custom component code structure in endpoints.py and customs.py

- Refactored the import statements for custom components in `endpoints.py` to use the newly created `CustomComponent` class instead of the previous implementation.
- Removed unnecessary import statements for custom components in `endpoints.py`.
- Added support for a new `CustomComponentVertex` type in the `VERTEX_TYPE_MAP` dictionary in `constants.py`.
- Modified the `PromptVertex` class in `types.py` to handle input variables from prompt text more efficiently.
- Added a new `CustomComponentVertex` class in `types.py` for custom component vertices.
- Renamed the `CustomComponent` class in `custom.py` to `CustomComponent_old`.
- Created a new `CustomComponent` class in `custom.py` to replace the previous implementation.
This commit is contained in:
gustavoschaedler 2023-06-30 19:32:09 +01:00
commit 895bc202a9
5 changed files with 20 additions and 10 deletions

View file

@ -9,10 +9,12 @@ from langflow.utils.logger import logger
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from langflow.api.extract_info_from_class import (
ClassCodeExtractor,
is_valid_class_template
)
# from langflow.api.extract_info_from_class import (
# ClassCodeExtractor,
# is_valid_class_template
# )
from langflow.interface.tools.custom import CustomComponent
from langflow.api.v1.schemas import (
ProcessResponse,
@ -101,9 +103,9 @@ def get_version():
async def custom_component(
raw_code: CustomComponentCode,
):
extractor = ClassCodeExtractor(raw_code.code)
extractor = CustomComponent(code=raw_code.code)
data = extractor.extract_class_info()
valid = is_valid_class_template(data)
valid = extractor.is_valid_class_template(data)
function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type()

View file

@ -27,7 +27,7 @@ CUSTOM_NODES = {
"TimeTravelGuideChain": frontend_node.chains.TimeTravelGuideChainNode(),
"MidJourneyPromptChain": frontend_node.chains.MidJourneyPromptChainNode(),
"load_qa_chain": frontend_node.chains.CombineDocsChainNode(),
},
}
}

View file

@ -32,4 +32,5 @@ VERTEX_TYPE_MAP: Dict[str, Type[Vertex]] = {
**{t: types.DocumentLoaderVertex for t in documentloader_creator.to_list()},
**{t: types.TextSplitterVertex for t in textsplitter_creator.to_list()},
**{t: types.OutputParserVertex for t in output_parser_creator.to_list()},
**{t: types.CustomComponentVertex for t in tool_creator.to_list()},
}

View file

@ -184,7 +184,8 @@ class PromptVertex(Vertex):
if "prompt" not in self.params and "messages" not in self.params:
for param in prompt_params:
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
variables = extract_input_variables_from_prompt(
prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(
set(self.params["input_variables"])
@ -199,3 +200,8 @@ class PromptVertex(Vertex):
class OutputParserVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="output_parsers")
class CustomComponentVertex(Vertex):
def __init__(self, data: Dict):
super().__init__(data, base_type="tools")

View file

@ -52,7 +52,7 @@ class PythonFunction(Function):
code: str
class CustomComponent(BaseModel):
class CustomComponent_old(BaseModel):
code: str
function: Optional[Callable] = None
imports: Optional[str] = None
@ -78,8 +78,9 @@ class CustomComponent(BaseModel):
return validate.create_function(self.code, function_name)
class CustomComponent1(BaseModel):
class CustomComponent(BaseModel):
code: str
function: Optional[Callable] = None
function_entrypoint_name = "build"
return_type_valid_list = [
"ConversationChain",