diff --git a/src/backend/langflow/api/extract_info_from_class.py b/src/backend/langflow/api/extract_info_from_class.py index 4dfb24e03..0c930e09c 100644 --- a/src/backend/langflow/api/extract_info_from_class.py +++ b/src/backend/langflow/api/extract_info_from_class.py @@ -4,6 +4,7 @@ import ast class ClassCodeExtractor: def __init__(self, code): self.code = code + self.function_entrypoint_name = "build" self.data = { "imports": [], "class": { @@ -61,10 +62,40 @@ class ClassCodeExtractor: return self.data + def get_entrypoint_function_args_and_return_type(self): + data = self.extract_class_info() + functions = data.get("functions", []) -def is_valid_class_template(code: dict) -> bool: - class_name_ok = code["class"]["name"] == "PythonFunction" - function_run_exists = len( - [f for f in code["functions"] if f["name"] == "run"]) == 1 + build_function = next( + (f for f in functions if f["name"] == + self.function_entrypoint_name), None + ) - return (class_name_ok and function_run_exists) + funtion_args = build_function.get("arguments", None) + return_type = build_function.get("return_type", None) + + return funtion_args, return_type + + +def is_valid_class_template(code: dict): + function_entrypoint_name = "build" + return_type_valid_list = ["ChainCreator", "ToolCreator"] + + class_name = code.get("class", {}).get("name", None) + if not class_name: # this will also check for None, empty string, etc. + return False + + functions = code.get("functions", []) + # use a generator and next to find if a function matching the criteria exists + build_function = next( + (f for f in functions if f["name"] == function_entrypoint_name), None + ) + + if not build_function: + return False + + # Check if the return type of the build function is valid + if build_function.get("return_type") not in return_type_valid_list: + return False + + return True diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 4d3ac9f5e..73ffc63a1 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -1,18 +1,22 @@ from langflow.database.models.flow import Flow from langflow.processing.process import process_graph_cached, process_tweaks from langflow.utils.logger import logger - +from langflow.api.extract_info_from_class import ( + ClassCodeExtractor, + is_valid_class_template +) from fastapi import APIRouter, Depends, HTTPException from langflow.api.v1.schemas import ( PredictRequest, PredictResponse, - CustomComponentResponse, + CustomComponentCode, + CustomComponentResponse ) from langflow.interface.types import ( build_langchain_types_dict, - build_langchain_types_dict_by_creator + build_langchain_template_custom_component ) from langflow.database.base import get_session from sqlmodel import Session @@ -70,5 +74,40 @@ def get_version(): # @router.post("/custom_component", response_model=CustomComponentResponse, status_code=200) @router.post("/custom_component", status_code=200) -def custom_component(code: dict): - return build_langchain_types_dict_by_creator("a") +def custom_component( + code: CustomComponentCode, + session: Session = Depends(get_session), +): + code_test = """ +from langflow.interface.chains.base import ChainCreator +from langflow.interface.tools.base import ToolCreator + + +class MyPythonClass(): + def __init__(self, title: str, author: str, year_published: int): + self.title = title + self.author = author + self.year_published = year_published + + def get_details(self): + return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}" + + def update_year_published(self, new_year: int): + self.year_published = new_year + print(f"The year of publication has been updated to {new_year}.") + + def build(self, name: str, id: int, other: str) -> ChainCreator: + return ChainCreator() +""" + + extractor = ClassCodeExtractor(code_test) + data = extractor.extract_class_info() + valid = is_valid_class_template(data) + + function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type() + + return build_langchain_template_custom_component( + code_test, + function_args, + function_return_type + ) diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index 145ac9365..c8b6e2856 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -116,3 +116,7 @@ class StreamData(BaseModel): class CustomComponentResponse(BaseModel): model: str = "" step: str = "" + + +class CustomComponentCode(BaseModel): + code: str diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index ac1b59e6f..17c1562e4 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -54,26 +54,44 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union return all_types -# sourcery skip: dict-assign-update-to-union -def build_langchain_types_dict_by_creator(creator: str): - """Build a dictionary of all langchain types""" +def find_class_type(class_name, classes_dict): + return next( + ( + {"type": class_type, "class": class_name} + for class_type, class_list in classes_dict.items() + if class_name in class_list + ), + {"error": "class not found"}, + ) - all_types = {} - creators = [ - chain_creator, - agent_creator, - prompt_creator, - llm_creator, - memory_creator, - tool_creator, - toolkits_creator, - wrapper_creator, - embedding_creator, - vectorstore_creator, - documentloader_creator, - textsplitter_creator, - utility_creator, - ] +def build_langchain_template_custom_component(raw_code, function_args, function_return_type): + type_list = get_type_list() + type_and_class = find_class_type("Tool", type_list) - return chain_creator.to_dict()['chains']['ConversationChain'] + # Field with the Python code to allow update + code_field = { + "code": { + "required": True, + "placeholder": "", + "show": True, + "multiline": True, + "value": raw_code, + "password": False, + "name": "code", + "advanced": False, + "type": "code", + "list": False + } + } + + # TODO: Add extra fields + + # TODO: Build template result + template = chain_creator.to_dict()['chains']['ConversationChain'] + + template.get('template')['code'] = code_field.get('code') + + return template + # return globals()['tool_creator'].to_dict()[type_and_class['type']][type_and_class['class']] + # return chain_creator.to_dict()['chains']['ConversationChain']