diff --git a/src/backend/langflow/api/v1/base.py b/src/backend/langflow/api/v1/base.py index d595210bb..1a4936a2f 100644 --- a/src/backend/langflow/api/v1/base.py +++ b/src/backend/langflow/api/v1/base.py @@ -1,3 +1,4 @@ +from langflow.template.frontend_node.base import FrontendNode from pydantic import BaseModel, validator from langflow.interface.utils import extract_input_variables_from_prompt @@ -12,8 +13,9 @@ class Code(BaseModel): code: str -class Prompt(BaseModel): +class ValidatePromptRequest(BaseModel): template: str + frontend_node: FrontendNode # Build ValidationResponse class for {"imports": {"errors": []}, "function": {"errors": []}} @@ -32,6 +34,7 @@ class CodeValidationResponse(BaseModel): class PromptValidationResponse(BaseModel): input_variables: list + frontend_node: FrontendNode INVALID_CHARACTERS = { @@ -66,7 +69,7 @@ def validate_prompt(template: str): # if len(input_variables) > 1: # # If there's more than one input variable - return PromptValidationResponse(input_variables=input_variables) + return input_variables def check_input_variables(input_variables: list): diff --git a/src/backend/langflow/api/v1/endpoints.py b/src/backend/langflow/api/v1/endpoints.py index f18e3056d..2d93aebb1 100644 --- a/src/backend/langflow/api/v1/endpoints.py +++ b/src/backend/langflow/api/v1/endpoints.py @@ -11,7 +11,10 @@ from langflow.api.v1.schemas import ( UploadFileResponse, ) -from langflow.interface.types import build_langchain_types_dict +from langflow.interface.types import ( + build_langchain_types_dict, + build_langchain_template_custom_component, +) from langflow.database.base import get_session from sqlmodel import Session @@ -81,3 +84,45 @@ def get_version(): from langflow import __version__ return {"version": __version__} + + +# @router.post("/custom_component", response_model=CustomComponentResponse, status_code=200) +@router.post("/custom_component", status_code=200) +def custom_component( + code: CustomComponentCode, + session: Session = Depends(get_session), +): + code_test = """ +from langflow.interface.chains.base import ChainCreator +from langflow.interface.tools.base import ToolCreator + + +class MyPythonClass(): + def __init__(self, title: str, author: str, year_published: int): + self.title = title + self.author = author + self.year_published = year_published + + def get_details(self): + return f"Title: {self.title}, Author: {self.author}, Year Published: {self.year_published}" + + def update_year_published(self, new_year: int): + self.year_published = new_year + print(f"The year of publication has been updated to {new_year}.") + + def build(self, name: str, id: int, other: str) -> ChainCreator: + return ChainCreator() +""" + + extractor = ClassCodeExtractor(code_test) + data = extractor.extract_class_info() + is_valid_class_template(data) + + ( + function_args, + function_return_type, + ) = extractor.get_entrypoint_function_args_and_return_type() + + return build_langchain_template_custom_component( + code_test, function_args, function_return_type + ) diff --git a/src/backend/langflow/api/v1/validate.py b/src/backend/langflow/api/v1/validate.py index 959273a00..3046ad9ba 100644 --- a/src/backend/langflow/api/v1/validate.py +++ b/src/backend/langflow/api/v1/validate.py @@ -3,10 +3,11 @@ from fastapi import APIRouter, HTTPException from langflow.api.v1.base import ( Code, CodeValidationResponse, - Prompt, + ValidatePromptRequest, PromptValidationResponse, validate_prompt, ) +from langflow.template.field.base import TemplateField from langflow.utils.logger import logger from langflow.utils.validate import validate_code @@ -27,9 +28,24 @@ def post_validate_code(code: Code): @router.post("/prompt", status_code=200, response_model=PromptValidationResponse) -def post_validate_prompt(prompt: Prompt): +def post_validate_prompt(prompt: ValidatePromptRequest): try: - return validate_prompt(prompt.template) + input_variables = validate_prompt(prompt.template) + for variable in input_variables: + try: + template_field = TemplateField( + name=variable, field_type="str", show=True, advanced=False + ) + prompt.frontend_node.template.fields.append(template_field) + prompt.frontend_node.custom_fields.append(variable) + except Exception as exc: + logger.exception(exc) + raise HTTPException(status_code=500, detail=str(exc)) from exc + + return PromptValidationResponse( + input_variables=input_variables, + frontend_node=prompt.frontend_node, + ) except Exception as e: logger.exception(e) raise HTTPException(status_code=500, detail=str(e)) from e diff --git a/src/backend/langflow/main.py b/src/backend/langflow/main.py index 2a1293f2e..12d70bfb4 100644 --- a/src/backend/langflow/main.py +++ b/src/backend/langflow/main.py @@ -5,6 +5,239 @@ from langflow.api import router from langflow.database.base import create_db_and_tables from langflow.interface.utils import setup_llm_caching +template_node = { + "template": { + "code": { + "required": True, + "placeholder": "", + "show": True, + "multiline": True, + "value": "\ndef my_user_python_function(text: str) -> str:\n \"\"\"This is a default python function that returns the input text\"\"\"\n return text.upper()\n", + "password": False, + "name": "code", + "advanced": False, + "type": "code", + "list": False + }, + "lc_kwargs": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "password": False, + "name": "lc_kwargs", + "advanced": True, + "type": "code", + "list": False + }, + "verbose": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "value": False, + "password": False, + "name": "verbose", + "advanced": False, + "type": "bool", + "list": False + }, + "callbacks": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "password": False, + "name": "callbacks", + "advanced": False, + "type": "langchain.callbacks.base.BaseCallbackHandler", + "list": True + }, + "tags": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "password": False, + "name": "tags", + "advanced": False, + "type": "str", + "list": True + }, + "client": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "password": False, + "name": "client", + "advanced": False, + "type": "Any", + "list": False + }, + "model_name": { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": "gpt-3.5-turbo", + "password": False, + "options": [ + "gpt-3.5-turbo-0613", + "gpt-3.5-turbo", + "gpt-3.5-turbo-16k-0613", + "gpt-3.5-turbo-16k", + "gpt-4-0613", + "gpt-4-32k-0613", + "gpt-4", + "gpt-4-32k" + ], + "name": "model_name", + "advanced": False, + "type": "str", + "list": True + }, + "temperature": { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": 0.7, + "password": False, + "name": "temperature", + "advanced": False, + "type": "float", + "list": False + }, + "model_kwargs": { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "model_kwargs", + "advanced": True, + "type": "code", + "list": False + }, + "openai_api_key": { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "value": "", + "password": True, + "name": "openai_api_key", + "display_name": "OpenAI API Key", + "advanced": False, + "type": "str", + "list": False + }, + "openai_api_base": { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": False, + "name": "openai_api_base", + "display_name": "OpenAI API Base", + "advanced": False, + "type": "str", + "list": False + }, + "openai_organization": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "password": False, + "name": "openai_organization", + "display_name": "OpenAI Organization", + "advanced": False, + "type": "str", + "list": False + }, + "openai_proxy": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "password": False, + "name": "openai_proxy", + "display_name": "OpenAI Proxy", + "advanced": False, + "type": "str", + "list": False + }, + "request_timeout": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "password": False, + "name": "request_timeout", + "advanced": False, + "type": "float", + "list": False + }, + "max_retries": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "value": 6, + "password": False, + "name": "max_retries", + "advanced": False, + "type": "int", + "list": False + }, + "streaming": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "value": False, + "password": False, + "name": "streaming", + "advanced": False, + "type": "bool", + "list": False + }, + "n": { + "required": False, + "placeholder": "", + "show": False, + "multiline": False, + "value": 1, + "password": False, + "name": "n", + "advanced": False, + "type": "int", + "list": False + }, + "max_tokens": { + "required": False, + "placeholder": "", + "show": True, + "multiline": False, + "password": True, + "name": "max_tokens", + "advanced": False, + "type": "int", + "list": False + }, + "_type": "ChatOpenAI" + }, + "base_classes": [ + "BaseChatModel", + "Serializable", + "BaseLanguageModel", + "ChatOpenAI" + ], + "description": "Wrapper around OpenAI Chat large language models." +} + def create_app(): """Create the FastAPI app and include the router.""" @@ -19,6 +252,10 @@ def create_app(): def get_health(): return {"status": "OK"} + @app.get("/dynamic_node") + def get_dynamic_nome(): + return template_node + app.add_middleware( CORSMiddleware, allow_origins=origins, diff --git a/src/backend/langflow/template/frontend_node/base.py b/src/backend/langflow/template/frontend_node/base.py index 4801da086..2663164e0 100644 --- a/src/backend/langflow/template/frontend_node/base.py +++ b/src/backend/langflow/template/frontend_node/base.py @@ -15,6 +15,7 @@ class FrontendNode(BaseModel): base_classes: List[str] name: str = "" display_name: str = "" + custom_fields: List[str] = [] def to_dict(self) -> dict: return { @@ -23,6 +24,7 @@ class FrontendNode(BaseModel): "description": self.description, "base_classes": self.base_classes, "display_name": self.display_name or self.name, + "custom_fields": self.custom_fields, }, } diff --git a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx index 4347b88ca..9b7590213 100644 --- a/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx +++ b/src/frontend/src/CustomNodes/GenericNode/components/parameterComponent/index.tsx @@ -229,6 +229,10 @@ export default function ParameterComponent({ ) : left === true && type === "code" ? ( { + data.node = nodeClass; + }} + nodeClass={data.node} disabled={disabled} value={data.node.template[name].value ?? ""} onChange={handleOnNewValue} @@ -256,6 +260,7 @@ export default function ParameterComponent({ ) : left === true && type === "prompt" ? ( {}, [closePopUp, data.node.template]); - + console.log({data}) return ( <> diff --git a/src/frontend/src/components/codeAreaComponent/index.tsx b/src/frontend/src/components/codeAreaComponent/index.tsx index 1730904e2..6b9cc7725 100644 --- a/src/frontend/src/components/codeAreaComponent/index.tsx +++ b/src/frontend/src/components/codeAreaComponent/index.tsx @@ -2,7 +2,7 @@ import { useContext, useEffect, useState } from "react"; import { PopUpContext } from "../../contexts/popUpContext"; import CodeAreaModal from "../../modals/codeAreaModal"; import TextAreaModal from "../../modals/textAreaModal"; -import { TextAreaComponentType } from "../../types/components"; +import { CodeAreaComponentType, TextAreaComponentType } from "../../types/components"; import { INPUT_STYLE } from "../../constants"; import { ExternalLink } from "lucide-react"; @@ -11,7 +11,9 @@ export default function CodeAreaComponent({ onChange, disabled, editNode = false, -}: TextAreaComponentType) { + nodeClass, + setNodeClass, +}: CodeAreaComponentType) { const [myValue, setMyValue] = useState(value); const { openPopUp } = useContext(PopUpContext); useEffect(() => { @@ -37,6 +39,8 @@ export default function CodeAreaComponent({ openPopUp( { setMyValue(t); onChange(t); @@ -59,7 +63,9 @@ export default function CodeAreaComponent({ onClick={() => { openPopUp( { setMyValue(t); onChange(t); diff --git a/src/frontend/src/components/promptComponent/index.tsx b/src/frontend/src/components/promptComponent/index.tsx index 9752c0c10..abc16c8db 100644 --- a/src/frontend/src/components/promptComponent/index.tsx +++ b/src/frontend/src/components/promptComponent/index.tsx @@ -7,6 +7,7 @@ import { INPUT_STYLE } from "../../constants"; import { ExternalLink } from "lucide-react"; export default function PromptAreaComponent({ + nodeClass, value, onChange, disabled, @@ -44,6 +45,7 @@ export default function PromptAreaComponent({ setMyValue(t); onChange(t); }} + nodeClass={nodeClass} /> ); }} diff --git a/src/frontend/src/controllers/API/index.ts b/src/frontend/src/controllers/API/index.ts index 2651a0058..5c5c43a4d 100644 --- a/src/frontend/src/controllers/API/index.ts +++ b/src/frontend/src/controllers/API/index.ts @@ -4,6 +4,7 @@ import { errorsTypeAPI, InitTypeAPI, UploadFileTypeAPI, + APIClassType, } from "./../../types/api/index"; import { APIObjectType, sendAllProps } from "../../types/api/index"; import axios, { AxiosResponse } from "axios"; @@ -56,9 +57,13 @@ export async function postValidateCode( * @returns {Promise>} A promise that resolves to an AxiosResponse containing the validation results. */ export async function checkPrompt( - template: string + template: string, + frontend_node: APIClassType ): Promise> { - return await axios.post("/api/v1/validate/prompt", { template }); + return await axios.post("/api/v1/validate/prompt", { + template: template, + frontend_node: frontend_node, + }); } /** @@ -331,3 +336,10 @@ export async function uploadFile( formData.append("file", file); return await axios.post(`/api/v1/upload/${id}`, formData); } + +export async function postCustomComponent( + code: string, + apiClass: APIClassType +): Promise> { + return await axios.post(`/api/v1/custom_component`, { code }); +} diff --git a/src/frontend/src/modals/codeAreaModal/index.tsx b/src/frontend/src/modals/codeAreaModal/index.tsx index f6e4b8e1c..3c2da7eea 100644 --- a/src/frontend/src/modals/codeAreaModal/index.tsx +++ b/src/frontend/src/modals/codeAreaModal/index.tsx @@ -1,4 +1,4 @@ -import { Fragment, useContext, useRef, useState } from "react"; +import { useContext, useRef, useState } from "react"; import { PopUpContext } from "../../contexts/popUpContext"; import AceEditor from "react-ace"; import "ace-builds/src-noconflict/mode-python"; @@ -7,9 +7,8 @@ import "ace-builds/src-noconflict/theme-twilight"; import "ace-builds/src-noconflict/ext-language_tools"; // import "ace-builds/webpack-resolver"; import { darkContext } from "../../contexts/darkContext"; -import { postValidateCode } from "../../controllers/API"; +import { postCustomComponent, postValidateCode } from "../../controllers/API"; import { alertContext } from "../../contexts/alertContext"; -import { TabsContext } from "../../contexts/tabsContext"; import { Dialog, DialogContent, @@ -22,16 +21,22 @@ import { import { Button } from "../../components/ui/button"; import { CODE_PROMPT_DIALOG_SUBTITLE } from "../../constants"; import { TerminalSquare } from "lucide-react"; +import { APIClassType } from "../../types/api"; export default function CodeAreaModal({ value, setValue, + nodeClass, + setNodeClass, }: { setValue: (value: string) => void; value: string; + nodeClass: APIClassType; + setNodeClass: (Class: APIClassType) => void; }) { const [open, setOpen] = useState(true); const [code, setCode] = useState(value); + const [loading, setLoading] = useState(false); const { dark } = useContext(darkContext); const { setErrorData, setSuccessData } = useContext(alertContext); const { closePopUp } = useContext(PopUpContext); @@ -44,6 +49,55 @@ export default function CodeAreaModal({ }, 300); } } + + function handleClick() { + setLoading(true); + postValidateCode(code) + .then((apiReturn) => { + setLoading(false); + if (apiReturn.data) { + let importsErrors = apiReturn.data.imports.errors; + let funcErrors = apiReturn.data.function.errors; + if (funcErrors.length === 0 && importsErrors.length === 0) { + setSuccessData({ + title: "Code is ready to run", + }); + // setValue(code); + } else { + if (funcErrors.length !== 0) { + setErrorData({ + title: "There is an error in your function", + list: funcErrors, + }); + } + if (importsErrors.length !== 0) { + setErrorData({ + title: "There is an error in your imports", + list: importsErrors, + }); + } + } + } else { + setErrorData({ + title: "Something went wrong, please try again", + }); + } + }) + .catch((_) => { + setLoading(false); + setErrorData({ + title: "There is something wrong with this code, please review it", + }); + }); + postCustomComponent(code, nodeClass).then((apiReturn) => { + const data = apiReturn.data; + if (data) { + setNodeClass(data); + setModalOpen(false); + } + }); + } + return ( @@ -78,49 +132,8 @@ export default function CodeAreaModal({ - diff --git a/src/frontend/src/modals/genericModal/index.tsx b/src/frontend/src/modals/genericModal/index.tsx index 6f89e24a6..8f7dcefc8 100644 --- a/src/frontend/src/modals/genericModal/index.tsx +++ b/src/frontend/src/modals/genericModal/index.tsx @@ -1,9 +1,8 @@ -import { Fragment, useContext, useRef, useState } from "react"; +import { useContext, useRef, useState } from "react"; import { PopUpContext } from "../../contexts/popUpContext"; import { darkContext } from "../../contexts/darkContext"; import { checkPrompt } from "../../controllers/API"; import { alertContext } from "../../contexts/alertContext"; -import { TypeModal } from "../../utils"; import { Dialog, DialogContent, @@ -17,6 +16,7 @@ import { Button } from "../../components/ui/button"; import { Textarea } from "../../components/ui/textarea"; import { PROMPT_DIALOG_SUBTITLE, TEXT_DIALOG_SUBTITLE } from "../../constants"; import { FileText } from "lucide-react"; +import { APIClassType } from "../../types/api"; export default function GenericModal({ value, @@ -24,12 +24,14 @@ export default function GenericModal({ buttonText, modalTitle, type, + nodeClass, }: { setValue: (value: string) => void; value: string; buttonText: string; modalTitle: string; type: number; + nodeClass: APIClassType; }) { const [myButtonText] = useState(buttonText); const [myModalTitle] = useState(modalTitle); @@ -97,9 +99,13 @@ export default function GenericModal({ setModalOpen(false); break; case 2: - checkPrompt(myValue) + checkPrompt(myValue, nodeClass) .then((apiReturn) => { if (apiReturn.data) { + if (apiReturn.data) { + setNodeClass(data); + setModalOpen(false); + } let inputVariables = apiReturn.data.input_variables; if (inputVariables.length === 0) { setErrorData({ diff --git a/src/frontend/src/types/components/index.ts b/src/frontend/src/types/components/index.ts index 630526193..c1cc3756c 100644 --- a/src/frontend/src/types/components/index.ts +++ b/src/frontend/src/types/components/index.ts @@ -1,12 +1,7 @@ -import { - ComponentType, - ForwardRefExoticComponent, - ReactElement, - ReactNode, - SVGProps, -} from "react"; +import { ReactElement, ReactNode } from "react"; import { NodeDataType } from "../flow/index"; import { typesContextType } from "../typesContext"; +import { APIClassType } from "../api"; export type InputComponentType = { value: string; disabled?: boolean; @@ -50,12 +45,22 @@ export type InputListComponentType = { }; export type TextAreaComponentType = { + nodeClass?: APIClassType; disabled: boolean; onChange: (value: string[] | string) => void; value: string; editNode?: boolean; }; +export type CodeAreaComponentType = { + disabled: boolean; + onChange: (value: string[] | string) => void; + value: string; + editNode?: boolean; + nodeClass: APIClassType; + setNodeClass: (value: APIClassType) => void; +}; + export type FileComponentType = { disabled: boolean; onChange: (value: string[] | string) => void;