feat: add csv agent

This commit is contained in:
Ibis Prevedello 2023-04-01 16:17:31 -03:00
commit 7052508fb0
9 changed files with 160 additions and 7 deletions

View file

@ -7,6 +7,7 @@ chains:
agents:
- ZeroShotAgent
- JsonAgent
- CSVAgent
prompts:
- PromptTemplate

View file

@ -3,7 +3,7 @@ from langflow.template import nodes
CUSTOM_NODES = {
"prompts": {**nodes.ZeroShotPromptNode().to_dict()},
"tools": {**nodes.PythonFunctionNode().to_dict(), **nodes.ToolNode().to_dict()},
"agents": {**nodes.JsonAgentNode().to_dict()},
"agents": {**nodes.JsonAgentNode().to_dict(), **nodes.CSVAgentNode().to_dict()},
}

View file

@ -8,7 +8,7 @@ from copy import deepcopy
from typing import Any, Dict, List
from langflow.graph.constants import DIRECT_TYPES
from langflow.graph.utils import load_dict
from langflow.graph.utils import load_file
from langflow.interface import loading
from langflow.interface.listing import ALL_TYPES_DICT
@ -90,7 +90,7 @@ class Node:
type_to_load = value.get("suffixes")
file_name = value.get("value")
content = value.get("content")
loaded_dict = load_dict(file_name, content, type_to_load)
loaded_dict = load_file(file_name, content, type_to_load)
params[key] = loaded_dict
# We should check if the type is in something not

View file

@ -1,11 +1,13 @@
import base64
import json
from typing import Dict
from typing import Any
import re
import yaml
import csv
import io
def load_dict(file_name, file_content, accepted_types) -> Dict:
def load_file(file_name, file_content, accepted_types) -> Any:
"""Load a file from a string."""
# Check if the file is accepted
if not any(file_name.endswith(suffix) for suffix in accepted_types):
@ -24,6 +26,10 @@ def load_dict(file_name, file_content, accepted_types) -> Dict:
elif suffix in ["yaml", "yml"]:
# Return the yaml content
return yaml.safe_load(decoded_string)
elif suffix == "csv":
# Load the csv content
csv_reader = csv.DictReader(io.StringIO(decoded_string))
return list(csv_reader)
else:
raise ValueError(f"File {file_name} is not accepted")

View file

@ -7,6 +7,14 @@ from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.schema import BaseLanguageModel
from pydantic import BaseModel
from langchain.llms.base import BaseLLM
from typing import Any, Optional
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from pathlib import Path
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFIX
from langchain.agents.agent_toolkits.pandas.prompt import SUFFIX as PANDAS_SUFFIX
from langchain.tools.python.tool import PythonAstREPLTool
class JsonAgent(AgentExecutor):
@ -41,6 +49,52 @@ class JsonAgent(AgentExecutor):
return super().run(*args, **kwargs)
class CSVAgent(AgentExecutor):
"""CSV agent"""
@classmethod
def initialize(cls, *args, **kwargs):
return cls.from_toolkit_and_llm(*args, **kwargs)
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
@classmethod
def from_toolkit_and_llm(
cls,
path: dict,
llm: BaseLanguageModel,
pandas_kwargs: Optional[dict] = None,
**kwargs: Any
):
import pandas as pd
_kwargs = pandas_kwargs or {}
df = pd.DataFrame.from_dict(path, **_kwargs)
tools = [PythonAstREPLTool(locals={"df": df})]
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=PANDAS_PREFIX,
suffix=PANDAS_SUFFIX,
input_variables=["df", "input", "agent_scratchpad"],
)
partial_prompt = prompt.partial(df=str(df.head()))
llm_chain = LLMChain(
llm=llm,
prompt=partial_prompt,
callback_manager=None,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
def run(self, *args, **kwargs):
return super().run(*args, **kwargs)
CUSTOM_AGENTS = {
"JsonAgent": JsonAgent,
"CSVAgent": CSVAgent,
}

View file

@ -13,6 +13,7 @@ class TemplateFieldCreator(BaseModel, ABC):
multiline: bool = False
value: Any = None
suffixes: list[str] = []
fileTypes: list[str] = []
file_types: list[str] = []
content: Union[str, None] = None
password: bool = False

View file

@ -139,3 +139,32 @@ class JsonAgentNode(FrontendNode):
def to_dict(self):
return super().to_dict()
class CSVAgentNode(FrontendNode):
name: str = "CSVAgent"
template: Template = Template(
type_name="csv_agent",
fields=[
TemplateField(
field_type="file",
required=True,
show=True,
name="path",
value="",
suffixes=[".csv"],
fileTypes=["csv"],
),
TemplateField(
field_type="BaseLanguageModel",
required=True,
show=True,
name="llm",
),
],
)
description: str = """Construct a json agent from a CSV and tools."""
base_classes: list[str] = ["AgentExecutor"]
def to_dict(self):
return super().to_dict()