From 7052508fb02a985eb4b620ada119e42b6439b587 Mon Sep 17 00:00:00 2001 From: Ibis Prevedello Date: Sat, 1 Apr 2023 16:17:31 -0300 Subject: [PATCH] feat: add csv agent --- poetry.lock | 65 ++++++++++++++++++- pyproject.toml | 1 + src/backend/langflow/config.yaml | 1 + src/backend/langflow/custom/customs.py | 2 +- src/backend/langflow/graph/base.py | 4 +- src/backend/langflow/graph/utils.py | 10 ++- .../langflow/interface/agents/custom.py | 54 +++++++++++++++ src/backend/langflow/template/base.py | 1 + src/backend/langflow/template/nodes.py | 29 +++++++++ 9 files changed, 160 insertions(+), 7 deletions(-) diff --git a/poetry.lock b/poetry.lock index 9462c63cf..9043bccec 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1524,6 +1524,55 @@ files = [ {file = "packaging-23.0.tar.gz", hash = "sha256:b6ad297f8907de0fa2fe1ccbd26fdaf387f5f47c7275fedf8cce89f99446cf97"}, ] +[[package]] +name = "pandas" +version = "1.5.3" +description = "Powerful data structures for data analysis, time series, and statistics" +category = "main" +optional = false +python-versions = ">=3.8" +files = [ + {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:3749077d86e3a2f0ed51367f30bf5b82e131cc0f14260c4d3e499186fccc4406"}, + {file = "pandas-1.5.3-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:972d8a45395f2a2d26733eb8d0f629b2f90bebe8e8eddbb8829b180c09639572"}, + {file = "pandas-1.5.3-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:50869a35cbb0f2e0cd5ec04b191e7b12ed688874bd05dd777c19b28cbea90996"}, + {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c3ac844a0fe00bfaeb2c9b51ab1424e5c8744f89860b138434a363b1f620f354"}, + {file = "pandas-1.5.3-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7a0a56cef15fd1586726dace5616db75ebcfec9179a3a55e78f72c5639fa2a23"}, + {file = "pandas-1.5.3-cp310-cp310-win_amd64.whl", hash = "sha256:478ff646ca42b20376e4ed3fa2e8d7341e8a63105586efe54fa2508ee087f328"}, + {file = "pandas-1.5.3-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:6973549c01ca91ec96199e940495219c887ea815b2083722821f1d7abfa2b4dc"}, + {file = "pandas-1.5.3-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c39a8da13cede5adcd3be1182883aea1c925476f4e84b2807a46e2775306305d"}, + {file = "pandas-1.5.3-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:f76d097d12c82a535fda9dfe5e8dd4127952b45fea9b0276cb30cca5ea313fbc"}, + {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:e474390e60ed609cec869b0da796ad94f420bb057d86784191eefc62b65819ae"}, + {file = "pandas-1.5.3-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5f2b952406a1588ad4cad5b3f55f520e82e902388a6d5a4a91baa8d38d23c7f6"}, + {file = "pandas-1.5.3-cp311-cp311-win_amd64.whl", hash = "sha256:bc4c368f42b551bf72fac35c5128963a171b40dce866fb066540eeaf46faa003"}, + {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:14e45300521902689a81f3f41386dc86f19b8ba8dd5ac5a3c7010ef8d2932813"}, + {file = "pandas-1.5.3-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:9842b6f4b8479e41968eced654487258ed81df7d1c9b7b870ceea24ed9459b31"}, + {file = "pandas-1.5.3-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:26d9c71772c7afb9d5046e6e9cf42d83dd147b5cf5bcb9d97252077118543792"}, + {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5fbcb19d6fceb9e946b3e23258757c7b225ba450990d9ed63ccceeb8cae609f7"}, + {file = "pandas-1.5.3-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:565fa34a5434d38e9d250af3c12ff931abaf88050551d9fbcdfafca50d62babf"}, + {file = "pandas-1.5.3-cp38-cp38-win32.whl", hash = "sha256:87bd9c03da1ac870a6d2c8902a0e1fd4267ca00f13bc494c9e5a9020920e1d51"}, + {file = "pandas-1.5.3-cp38-cp38-win_amd64.whl", hash = "sha256:41179ce559943d83a9b4bbacb736b04c928b095b5f25dd2b7389eda08f46f373"}, + {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:c74a62747864ed568f5a82a49a23a8d7fe171d0c69038b38cedf0976831296fa"}, + {file = "pandas-1.5.3-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:c4c00e0b0597c8e4f59e8d461f797e5d70b4d025880516a8261b2817c47759ee"}, + {file = "pandas-1.5.3-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:a50d9a4336a9621cab7b8eb3fb11adb82de58f9b91d84c2cd526576b881a0c5a"}, + {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:dd05f7783b3274aa206a1af06f0ceed3f9b412cf665b7247eacd83be41cf7bf0"}, + {file = "pandas-1.5.3-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9f69c4029613de47816b1bb30ff5ac778686688751a5e9c99ad8c7031f6508e5"}, + {file = "pandas-1.5.3-cp39-cp39-win32.whl", hash = "sha256:7cec0bee9f294e5de5bbfc14d0573f65526071029d036b753ee6507d2a21480a"}, + {file = "pandas-1.5.3-cp39-cp39-win_amd64.whl", hash = "sha256:dfd681c5dc216037e0b0a2c821f5ed99ba9f03ebcf119c7dac0e9a7b960b9ec9"}, + {file = "pandas-1.5.3.tar.gz", hash = "sha256:74a3fd7e5a7ec052f183273dc7b0acd3a863edf7520f5d3a1765c04ffdb3b0b1"}, +] + +[package.dependencies] +numpy = [ + {version = ">=1.20.3", markers = "python_version < \"3.10\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, +] +python-dateutil = ">=2.8.1" +pytz = ">=2020.1" + +[package.extras] +test = ["hypothesis (>=5.5.3)", "pytest (>=6.0)", "pytest-xdist (>=1.31)"] + [[package]] name = "parso" version = "0.8.3" @@ -1853,7 +1902,7 @@ testing = ["argcomplete", "hypothesis (>=3.56)", "mock", "nose", "pygments (>=2. name = "python-dateutil" version = "2.8.2" description = "Extensions to the standard Python datetime module" -category = "dev" +category = "main" optional = false python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" files = [ @@ -1864,6 +1913,18 @@ files = [ [package.dependencies] six = ">=1.5" +[[package]] +name = "pytz" +version = "2023.3" +description = "World timezone definitions, modern and historical" +category = "main" +optional = false +python-versions = "*" +files = [ + {file = "pytz-2023.3-py2.py3-none-any.whl", hash = "sha256:a151b3abb88eda1d4e34a9814df37de2a80e301e68ba0fd856fb9b46bfbbbffb"}, + {file = "pytz-2023.3.tar.gz", hash = "sha256:1d8ce29db189191fb55338ee6d0387d82ab59f3d00eac103412d64e0ebd0c588"}, +] + [[package]] name = "pywin32" version = "306" @@ -2636,4 +2697,4 @@ testing = ["big-O", "flake8 (<5)", "jaraco.functools", "jaraco.itertools", "more [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "afeeaa3c4d0aee2a52be1ccfeff2c47cca9f6af446b5cf4e422fcbb214eec762" +content-hash = "99d1b3923d427a2bdce635e88aca5f9dd2af850a31b94d290cbba52c8cf533f8" diff --git a/pyproject.toml b/pyproject.toml index 3f202b843..4e415e71f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -32,6 +32,7 @@ gunicorn = "^20.1.0" langchain = "^0.0.127" openai = "^0.27.2" types-pyyaml = "^6.0.12.8" +pandas = "^1.5.3" [tool.poetry.group.dev.dependencies] black = "^23.1.0" diff --git a/src/backend/langflow/config.yaml b/src/backend/langflow/config.yaml index b8c2be6da..dede4cee6 100644 --- a/src/backend/langflow/config.yaml +++ b/src/backend/langflow/config.yaml @@ -7,6 +7,7 @@ chains: agents: - ZeroShotAgent - JsonAgent + - CSVAgent prompts: - PromptTemplate diff --git a/src/backend/langflow/custom/customs.py b/src/backend/langflow/custom/customs.py index fa14fb2e5..6a70732a0 100644 --- a/src/backend/langflow/custom/customs.py +++ b/src/backend/langflow/custom/customs.py @@ -3,7 +3,7 @@ from langflow.template import nodes CUSTOM_NODES = { "prompts": {**nodes.ZeroShotPromptNode().to_dict()}, "tools": {**nodes.PythonFunctionNode().to_dict(), **nodes.ToolNode().to_dict()}, - "agents": {**nodes.JsonAgentNode().to_dict()}, + "agents": {**nodes.JsonAgentNode().to_dict(), **nodes.CSVAgentNode().to_dict()}, } diff --git a/src/backend/langflow/graph/base.py b/src/backend/langflow/graph/base.py index 0c4cf8705..f4b41bbfc 100644 --- a/src/backend/langflow/graph/base.py +++ b/src/backend/langflow/graph/base.py @@ -8,7 +8,7 @@ from copy import deepcopy from typing import Any, Dict, List from langflow.graph.constants import DIRECT_TYPES -from langflow.graph.utils import load_dict +from langflow.graph.utils import load_file from langflow.interface import loading from langflow.interface.listing import ALL_TYPES_DICT @@ -90,7 +90,7 @@ class Node: type_to_load = value.get("suffixes") file_name = value.get("value") content = value.get("content") - loaded_dict = load_dict(file_name, content, type_to_load) + loaded_dict = load_file(file_name, content, type_to_load) params[key] = loaded_dict # We should check if the type is in something not diff --git a/src/backend/langflow/graph/utils.py b/src/backend/langflow/graph/utils.py index 70f3a3145..ca728390d 100644 --- a/src/backend/langflow/graph/utils.py +++ b/src/backend/langflow/graph/utils.py @@ -1,11 +1,13 @@ import base64 import json -from typing import Dict +from typing import Any import re import yaml +import csv +import io -def load_dict(file_name, file_content, accepted_types) -> Dict: +def load_file(file_name, file_content, accepted_types) -> Any: """Load a file from a string.""" # Check if the file is accepted if not any(file_name.endswith(suffix) for suffix in accepted_types): @@ -24,6 +26,10 @@ def load_dict(file_name, file_content, accepted_types) -> Dict: elif suffix in ["yaml", "yml"]: # Return the yaml content return yaml.safe_load(decoded_string) + elif suffix == "csv": + # Load the csv content + csv_reader = csv.DictReader(io.StringIO(decoded_string)) + return list(csv_reader) else: raise ValueError(f"File {file_name} is not accepted") diff --git a/src/backend/langflow/interface/agents/custom.py b/src/backend/langflow/interface/agents/custom.py index cc998fc12..c056b7c72 100644 --- a/src/backend/langflow/interface/agents/custom.py +++ b/src/backend/langflow/interface/agents/custom.py @@ -7,6 +7,14 @@ from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS from langchain.schema import BaseLanguageModel from pydantic import BaseModel +from langchain.llms.base import BaseLLM +from typing import Any, Optional +from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent +from pathlib import Path + +from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFIX +from langchain.agents.agent_toolkits.pandas.prompt import SUFFIX as PANDAS_SUFFIX +from langchain.tools.python.tool import PythonAstREPLTool class JsonAgent(AgentExecutor): @@ -41,6 +49,52 @@ class JsonAgent(AgentExecutor): return super().run(*args, **kwargs) +class CSVAgent(AgentExecutor): + """CSV agent""" + + @classmethod + def initialize(cls, *args, **kwargs): + return cls.from_toolkit_and_llm(*args, **kwargs) + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + @classmethod + def from_toolkit_and_llm( + cls, + path: dict, + llm: BaseLanguageModel, + pandas_kwargs: Optional[dict] = None, + **kwargs: Any + ): + import pandas as pd + + _kwargs = pandas_kwargs or {} + df = pd.DataFrame.from_dict(path, **_kwargs) + + tools = [PythonAstREPLTool(locals={"df": df})] + prompt = ZeroShotAgent.create_prompt( + tools, + prefix=PANDAS_PREFIX, + suffix=PANDAS_SUFFIX, + input_variables=["df", "input", "agent_scratchpad"], + ) + partial_prompt = prompt.partial(df=str(df.head())) + llm_chain = LLMChain( + llm=llm, + prompt=partial_prompt, + callback_manager=None, + ) + tool_names = [tool.name for tool in tools] + agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs) + + return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True) + + def run(self, *args, **kwargs): + return super().run(*args, **kwargs) + + CUSTOM_AGENTS = { "JsonAgent": JsonAgent, + "CSVAgent": CSVAgent, } diff --git a/src/backend/langflow/template/base.py b/src/backend/langflow/template/base.py index bcd6ed162..7329df04d 100644 --- a/src/backend/langflow/template/base.py +++ b/src/backend/langflow/template/base.py @@ -13,6 +13,7 @@ class TemplateFieldCreator(BaseModel, ABC): multiline: bool = False value: Any = None suffixes: list[str] = [] + fileTypes: list[str] = [] file_types: list[str] = [] content: Union[str, None] = None password: bool = False diff --git a/src/backend/langflow/template/nodes.py b/src/backend/langflow/template/nodes.py index fae298f2d..96219e000 100644 --- a/src/backend/langflow/template/nodes.py +++ b/src/backend/langflow/template/nodes.py @@ -139,3 +139,32 @@ class JsonAgentNode(FrontendNode): def to_dict(self): return super().to_dict() + + +class CSVAgentNode(FrontendNode): + name: str = "CSVAgent" + template: Template = Template( + type_name="csv_agent", + fields=[ + TemplateField( + field_type="file", + required=True, + show=True, + name="path", + value="", + suffixes=[".csv"], + fileTypes=["csv"], + ), + TemplateField( + field_type="BaseLanguageModel", + required=True, + show=True, + name="llm", + ), + ], + ) + description: str = """Construct a json agent from a CSV and tools.""" + base_classes: list[str] = ["AgentExecutor"] + + def to_dict(self): + return super().to_dict()