From 7d483eb7c954f7c228525f7e12fb99c8539778bf Mon Sep 17 00:00:00 2001 From: gustavoschaedler Date: Wed, 28 Jun 2023 16:32:48 +0100 Subject: [PATCH] Add ClassCodeExtractor and related functions to extract and handle information from a class. Also add the utility functions is_valid_class_template, get_entrypoint_function_args_and_return_type, find_class_type, and build_langchain_template_custom_component. --- .../langflow/api/extract_info_from_class.py | 101 ++++++++++++++++++ src/backend/langflow/api/v1/endpoints.py | 27 +++++ src/backend/langflow/api/v1/schemas.py | 7 +- src/backend/langflow/interface/types.py | 43 ++++++++ 4 files changed, 177 insertions(+), 1 deletion(-) create mode 100644 src/backend/langflow/api/extract_info_from_class.py diff --git a/src/backend/langflow/api/extract_info_from_class.py b/src/backend/langflow/api/extract_info_from_class.py new file mode 100644 index 000000000..b1923068b --- /dev/null +++ b/src/backend/langflow/api/extract_info_from_class.py @@ -0,0 +1,101 @@ +import ast + + +class ClassCodeExtractor: + def __init__(self, code): + self.code = code + self.function_entrypoint_name = "build" + self.data = { + "imports": [], + "class": { + "inherited_classes": "", + "name": "", + "init": "" + }, + "functions": [] + } + + def _handle_import(self, node): + for alias in node.names: + module_name = getattr(node, 'module', None) + self.data['imports'].append( + f"{module_name}.{alias.name}" if module_name else alias.name) + + def _handle_class(self, node): + self.data['class'].update({ + 'name': node.name, + 'inherited_classes': [ast.unparse(base) for base in node.bases] + }) + + for inner_node in node.body: + if isinstance(inner_node, ast.FunctionDef): + self._handle_function(inner_node) + + def _handle_function(self, node): + function_name = node.name + function_args_str = ast.unparse(node.args) + function_args = function_args_str.split( + ", ") if function_args_str else [] + + return_type = ast.unparse(node.returns) if node.returns else "None" + + function_data = { + "name": function_name, + "arguments": function_args, + "return_type": return_type + } + + if function_name == "__init__": + self.data['class']['init'] = function_args_str.split( + ", ") if function_args_str else [] + else: + self.data["functions"].append(function_data) + + def extract_class_info(self): + module = ast.parse(self.code) + + for node in module.body: + if isinstance(node, (ast.Import, ast.ImportFrom)): + self._handle_import(node) + elif isinstance(node, ast.ClassDef): + self._handle_class(node) + + return self.data + + def get_entrypoint_function_args_and_return_type(self): + data = self.extract_class_info() + functions = data.get("functions", []) + + build_function = next( + (f for f in functions if f["name"] == + self.function_entrypoint_name), None + ) + + funtion_args = build_function.get("arguments", None) + return_type = build_function.get("return_type", None) + + return funtion_args, return_type + + +def is_valid_class_template(code: dict): + function_entrypoint_name = "build" + return_type_valid_list = ["ConversationChain", "Tool"] + + class_name = code.get("class", {}).get("name", None) + if not class_name: # this will also check for None, empty string, etc. + return False + + functions = code.get("functions", []) + # use a generator and next to find if a function matching the criteria exists + build_function = next( + (f for f in functions if f["name"] == function_entrypoint_name), None + ) + + if not build_function: + return False + + # Check if the return type of the build function is valid + if build_function.get("return_type") not in return_type_valid_list: + return False + + return True diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index 1114412c5..48a7a6261 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -6,14 +6,22 @@ from langflow.utils.logger import logger from fastapi import APIRouter, Depends, HTTPException, UploadFile +from langflow.api.extract_info_from_class import ( + ClassCodeExtractor, + is_valid_class_template +) + from langflow.api.v1.schemas import ( ProcessResponse, UploadFileResponse, + CustomComponentCode, ) from langflow.interface.types import ( build_langchain_types_dict, + build_langchain_template_custom_component ) + from langflow.database.base import get_session from sqlmodel import Session @@ -83,3 +91,22 @@ def get_version(): from langflow import __version__ return {"version": __version__} + + +# @router.post("/custom_component", response_model=CustomComponentResponse, status_code=200) +@router.post("/custom_component", status_code=200) +def custom_component( + raw_code: CustomComponentCode, + session: Session = Depends(get_session), +): + extractor = ClassCodeExtractor(raw_code.code) + data = extractor.extract_class_info() + valid = is_valid_class_template(data) + + function_args, function_return_type = extractor.get_entrypoint_function_args_and_return_type() + + return build_langchain_template_custom_component( + raw_code.code, + function_args, + function_return_type + ) diff --git a/src/backend/langflow/api/v1/schemas.py b/src/backend/langflow/api/v1/schemas.py index ed5bf8b3b..6448f07bb 100644 --- a/src/backend/langflow/api/v1/schemas.py +++ b/src/backend/langflow/api/v1/schemas.py @@ -58,7 +58,8 @@ class ChatResponse(ChatMessage): @validator("type") def validate_message_type(cls, v): if v not in ["start", "stream", "end", "error", "info", "file"]: - raise ValueError("type must be start, stream, end, error, info, or file") + raise ValueError( + "type must be start, stream, end, error, info, or file") return v @@ -106,3 +107,7 @@ class StreamData(BaseModel): def __str__(self) -> str: return f"event: {self.event}\ndata: {json.dumps(self.data)}\n\n" + + +class CustomComponentCode(BaseModel): + code: str diff --git a/src/backend/langflow/interface/types.py b/src/backend/langflow/interface/types.py index 085537756..17c1562e4 100644 --- a/src/backend/langflow/interface/types.py +++ b/src/backend/langflow/interface/types.py @@ -52,3 +52,46 @@ def build_langchain_types_dict(): # sourcery skip: dict-assign-update-to-union if created_types[creator.type_name].values(): all_types.update(created_types) return all_types + + +def find_class_type(class_name, classes_dict): + return next( + ( + {"type": class_type, "class": class_name} + for class_type, class_list in classes_dict.items() + if class_name in class_list + ), + {"error": "class not found"}, + ) + + +def build_langchain_template_custom_component(raw_code, function_args, function_return_type): + type_list = get_type_list() + type_and_class = find_class_type("Tool", type_list) + + # Field with the Python code to allow update + code_field = { + "code": { + "required": True, + "placeholder": "", + "show": True, + "multiline": True, + "value": raw_code, + "password": False, + "name": "code", + "advanced": False, + "type": "code", + "list": False + } + } + + # TODO: Add extra fields + + # TODO: Build template result + template = chain_creator.to_dict()['chains']['ConversationChain'] + + template.get('template')['code'] = code_field.get('code') + + return template + # return globals()['tool_creator'].to_dict()[type_and_class['type']][type_and_class['class']] + # return chain_creator.to_dict()['chains']['ConversationChain']