Add ClassCodeExtractor and related functions to extract and handle information from a class. Also add the utility functions is_valid_class_template, get_entrypoint_function_args_and_return_type, find_class_type, and build_langchain_template_custom_component.

This commit is contained in:
gustavoschaedler 2023-06-28 16:32:48 +01:00
commit 7d483eb7c9
4 changed files with 177 additions and 1 deletions

View file

@ -0,0 +1,101 @@
import ast
class ClassCodeExtractor:
def __init__(self, code):
self.code = code
self.function_entrypoint_name = "build"
self.data = {
"imports": [],
"class": {
"inherited_classes": "",
"name": "",
"init": ""
},
"functions": []
}
def _handle_import(self, node):
for alias in node.names:
module_name = getattr(node, 'module', None)
self.data['imports'].append(
f"{module_name}.{alias.name}" if module_name else alias.name)
def _handle_class(self, node):
self.data['class'].update({
'name': node.name,
'inherited_classes': [ast.unparse(base) for base in node.bases]
})
for inner_node in node.body:
if isinstance(inner_node, ast.FunctionDef):
self._handle_function(inner_node)
def _handle_function(self, node):
function_name = node.name
function_args_str = ast.unparse(node.args)
function_args = function_args_str.split(
", ") if function_args_str else []
return_type = ast.unparse(node.returns) if node.returns else "None"
function_data = {
"name": function_name,
"arguments": function_args,
"return_type": return_type
}
if function_name == "__init__":
self.data['class']['init'] = function_args_str.split(
", ") if function_args_str else []
else:
self.data["functions"].append(function_data)
def extract_class_info(self):
module = ast.parse(self.code)
for node in module.body:
if isinstance(node, (ast.Import, ast.ImportFrom)):
self._handle_import(node)
elif isinstance(node, ast.ClassDef):
self._handle_class(node)
return self.data
def get_entrypoint_function_args_and_return_type(self):
data = self.extract_class_info()
functions = data.get("functions", [])
build_function = next(
(f for f in functions if f["name"] ==
self.function_entrypoint_name), None
)
funtion_args = build_function.get("arguments", None)
return_type = build_function.get("return_type", None)
return funtion_args, return_type
def is_valid_class_template(code: dict):
function_entrypoint_name = "build"
return_type_valid_list = ["ConversationChain", "Tool"]
class_name = code.get("class", {}).get("name", None)
if not class_name: # this will also check for None, empty string, etc.
return False
functions = code.get("functions", [])
# use a generator and next to find if a function matching the criteria exists
build_function = next(
(f for f in functions if f["name"] == function_entrypoint_name), None
)
if not build_function:
return False
# Check if the return type of the build function is valid
if build_function.get("return_type") not in return_type_valid_list:
return False
return True

View file

@ -6,14 +6,22 @@ from langflow.utils.logger import logger
from fastapi import APIRouter, Depends, HTTPException, UploadFile
from langflow.api.extract_info_from_class import (
ClassCodeExtractor,
is_valid_class_template
)
from langflow.api.v1.schemas import (
ProcessResponse,
UploadFileResponse,
CustomComponentCode,
)
from langflow.interface.types import (
build_langchain_types_dict,
build_langchain_template_custom_component
)
from langflow.database.base import get_session
from sqlmodel import Session
@ -83,3 +91,22 @@ def get_version():
from langflow import __version__
return {"version": __version__}
# @router.post("/custom_component", response_model=CustomComponentResponse, status_code=200)
@router.post("/custom_component", status_code=200)
def custom_component(
raw_code: CustomComponentCode,
session: Session = Depends(get_session),
):
extractor = ClassCodeExtractor(raw_code.code)
data = extractor.extract_class_info()
valid = is_valid_class_template(data)
function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type()
return build_langchain_template_custom_component(
raw_code.code,
function_args,
function_return_type
)

View file

@ -58,7 +58,8 @@ class ChatResponse(ChatMessage):
@validator("type")
def validate_message_type(cls, v):
if v not in ["start", "stream", "end", "error", "info", "file"]:
raise ValueError("type must be start, stream, end, error, info, or file")
raise ValueError(
"type must be start, stream, end, error, info, or file")
return v
@ -106,3 +107,7 @@ class StreamData(BaseModel):
def __str__(self) -> str:
return f"event: {self.event}\ndata: {json.dumps(self.data)}\n\n"
class CustomComponentCode(BaseModel):
code: str

View file

@ -52,3 +52,46 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union
if created_types[creator.type_name].values():
all_types.update(created_types)
return all_types
def find_class_type(class_name, classes_dict):
return next(
(
{"type": class_type, "class": class_name}
for class_type, class_list in classes_dict.items()
if class_name in class_list
),
{"error": "class not found"},
)
def build_langchain_template_custom_component(raw_code, function_args, function_return_type):
type_list = get_type_list()
type_and_class = find_class_type("Tool", type_list)
# Field with the Python code to allow update
code_field = {
"code": {
"required": True,
"placeholder": "",
"show": True,
"multiline": True,
"value": raw_code,
"password": False,
"name": "code",
"advanced": False,
"type": "code",
"list": False
}
}
# TODO: Add extra fields
# TODO: Build template result
template = chain_creator.to_dict()['chains']['ConversationChain']
template.get('template')['code'] = code_field.get('code')
return template
# return globals()['tool_creator'].to_dict()[type_and_class['type']][type_and_class['class']]
# return chain_creator.to_dict()['chains']['ConversationChain']