Refactor ClassCodeExtractor to extract the entrypoint function arguments and return type\nAdd validation of correct format for custom_component code\nAdd function to build a template for custom_component with its code as a field value.

This commit is contained in:
gustavoschaedler 2023-06-27 19:25:44 +01:00
commit 5d430f9364
4 changed files with 122 additions and 30 deletions

View file

@ -4,6 +4,7 @@ import ast
class ClassCodeExtractor:
def __init__(self, code):
self.code = code
self.function_entrypoint_name = "build"
self.data = {
"imports": [],
"class": {
@ -61,10 +62,40 @@ class ClassCodeExtractor:
return self.data
def get_entrypoint_function_args_and_return_type(self):
data = self.extract_class_info()
functions = data.get("functions", [])
def is_valid_class_template(code: dict) -> bool:
class_name_ok = code["class"]["name"] == "PythonFunction"
function_run_exists = len(
[f for f in code["functions"] if f["name"] == "run"]) == 1
build_function = next(
(f for f in functions if f["name"] ==
self.function_entrypoint_name), None
)
return (class_name_ok and function_run_exists)
funtion_args = build_function.get("arguments", None)
return_type = build_function.get("return_type", None)
return funtion_args, return_type
def is_valid_class_template(code: dict):
function_entrypoint_name = "build"
return_type_valid_list = ["ChainCreator", "ToolCreator"]
class_name = code.get("class", {}).get("name", None)
if not class_name: # this will also check for None, empty string, etc.
return False
functions = code.get("functions", [])
# use a generator and next to find if a function matching the criteria exists
build_function = next(
(f for f in functions if f["name"] == function_entrypoint_name), None
)
if not build_function:
return False
# Check if the return type of the build function is valid
if build_function.get("return_type") not in return_type_valid_list:
return False
return True

View file

@ -1,18 +1,22 @@
from langflow.database.models.flow import Flow
from langflow.processing.process import process_graph_cached, process_tweaks
from langflow.utils.logger import logger
from langflow.api.extract_info_from_class import (
ClassCodeExtractor,
is_valid_class_template
)
from fastapi import APIRouter, Depends, HTTPException
from langflow.api.v1.schemas import (
PredictRequest,
PredictResponse,
CustomComponentResponse,
CustomComponentCode,
CustomComponentResponse
)
from langflow.interface.types import (
build_langchain_types_dict,
build_langchain_types_dict_by_creator
build_langchain_template_custom_component
)
from langflow.database.base import get_session
from sqlmodel import Session
@ -70,5 +74,40 @@ def get_version():
# @router.post("/custom_component", response_model=CustomComponentResponse, status_code=200)
@router.post("/custom_component", status_code=200)
def custom_component(code: dict):
return build_langchain_types_dict_by_creator("a")
def custom_component(
code: CustomComponentCode,
session: Session = Depends(get_session),
):
code_test = """
from langflow.interface.chains.base import ChainCreator
from langflow.interface.tools.base import ToolCreator
class MyPythonClass():
def __init__(self, title: str, author: str, year_published: int):
self.title = title
self.author = author
self.year_published = year_published
def get_details(self):
return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}"
def update_year_published(self, new_year: int):
self.year_published = new_year
print(f"The year of publication has been updated to {new_year}.")
def build(self, name: str, id: int, other: str) -> ChainCreator:
return ChainCreator()
"""
extractor = ClassCodeExtractor(code_test)
data = extractor.extract_class_info()
valid = is_valid_class_template(data)
function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type()
return build_langchain_template_custom_component(
code_test,
function_args,
function_return_type
)

View file

@ -116,3 +116,7 @@ class StreamData(BaseModel):
class CustomComponentResponse(BaseModel):
model: str = ""
step: str = ""
class CustomComponentCode(BaseModel):
code: str

View file

@ -54,26 +54,44 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
return all_types
# sourcery skip: dict-assign-update-to-union
def build_langchain_types_dict_by_creator(creator: str):
"""Build a dictionary of all langchain types"""
def find_class_type(class_name, classes_dict):
return next(
(
{"type": class_type, "class": class_name}
for class_type, class_list in classes_dict.items()
if class_name in class_list
),
{"error": "class not found"},
)
all_types = {}
creators = [
chain_creator,
agent_creator,
prompt_creator,
llm_creator,
memory_creator,
tool_creator,
toolkits_creator,
wrapper_creator,
embedding_creator,
vectorstore_creator,
documentloader_creator,
textsplitter_creator,
utility_creator,
]
def build_langchain_template_custom_component(raw_code, function_args, function_return_type):
type_list = get_type_list()
type_and_class = find_class_type("Tool", type_list)
return chain_creator.to_dict()['chains']['ConversationChain']
# Field with the Python code to allow update
code_field = {
"code": {
"required": True,
"placeholder": "",
"show": True,
"multiline": True,
"value": raw_code,
"password": False,
"name": "code",
"advanced": False,
"type": "code",
"list": False
}
}
# TODO: Add extra fields
# TODO: Build template result
template = chain_creator.to_dict()['chains']['ConversationChain']
template.get('template')['code'] = code_field.get('code')
return template
# return globals()['tool_creator'].to_dict()[type_and_class['type']][type_and_class['class']]
# return chain_creator.to_dict()['chains']['ConversationChain']