feat: add csv agent
This commit is contained in:
parent
9ef1078786
commit
7052508fb0
9 changed files with 160 additions and 7 deletions
|
|
@ -7,6 +7,7 @@ chains:
|
|||
agents:
|
||||
- ZeroShotAgent
|
||||
- JsonAgent
|
||||
- CSVAgent
|
||||
|
||||
prompts:
|
||||
- PromptTemplate
|
||||
|
|
|
|||
|
|
@ -3,7 +3,7 @@ from langflow.template import nodes
|
|||
CUSTOM_NODES = {
|
||||
"prompts": {**nodes.ZeroShotPromptNode().to_dict()},
|
||||
"tools": {**nodes.PythonFunctionNode().to_dict(), **nodes.ToolNode().to_dict()},
|
||||
"agents": {**nodes.JsonAgentNode().to_dict()},
|
||||
"agents": {**nodes.JsonAgentNode().to_dict(), **nodes.CSVAgentNode().to_dict()},
|
||||
}
|
||||
|
||||
|
||||
|
|
|
|||
|
|
@ -8,7 +8,7 @@ from copy import deepcopy
|
|||
from typing import Any, Dict, List
|
||||
from langflow.graph.constants import DIRECT_TYPES
|
||||
|
||||
from langflow.graph.utils import load_dict
|
||||
from langflow.graph.utils import load_file
|
||||
from langflow.interface import loading
|
||||
from langflow.interface.listing import ALL_TYPES_DICT
|
||||
|
||||
|
|
@ -90,7 +90,7 @@ class Node:
|
|||
type_to_load = value.get("suffixes")
|
||||
file_name = value.get("value")
|
||||
content = value.get("content")
|
||||
loaded_dict = load_dict(file_name, content, type_to_load)
|
||||
loaded_dict = load_file(file_name, content, type_to_load)
|
||||
params[key] = loaded_dict
|
||||
|
||||
# We should check if the type is in something not
|
||||
|
|
|
|||
|
|
@ -1,11 +1,13 @@
|
|||
import base64
|
||||
import json
|
||||
from typing import Dict
|
||||
from typing import Any
|
||||
import re
|
||||
import yaml
|
||||
import csv
|
||||
import io
|
||||
|
||||
|
||||
def load_dict(file_name, file_content, accepted_types) -> Dict:
|
||||
def load_file(file_name, file_content, accepted_types) -> Any:
|
||||
"""Load a file from a string."""
|
||||
# Check if the file is accepted
|
||||
if not any(file_name.endswith(suffix) for suffix in accepted_types):
|
||||
|
|
@ -24,6 +26,10 @@ def load_dict(file_name, file_content, accepted_types) -> Dict:
|
|||
elif suffix in ["yaml", "yml"]:
|
||||
# Return the yaml content
|
||||
return yaml.safe_load(decoded_string)
|
||||
elif suffix == "csv":
|
||||
# Load the csv content
|
||||
csv_reader = csv.DictReader(io.StringIO(decoded_string))
|
||||
return list(csv_reader)
|
||||
else:
|
||||
raise ValueError(f"File {file_name} is not accepted")
|
||||
|
||||
|
|
|
|||
|
|
@ -7,6 +7,14 @@ from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
|
|||
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
|
||||
from langchain.schema import BaseLanguageModel
|
||||
from pydantic import BaseModel
|
||||
from langchain.llms.base import BaseLLM
|
||||
from typing import Any, Optional
|
||||
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
|
||||
from pathlib import Path
|
||||
|
||||
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFIX
|
||||
from langchain.agents.agent_toolkits.pandas.prompt import SUFFIX as PANDAS_SUFFIX
|
||||
from langchain.tools.python.tool import PythonAstREPLTool
|
||||
|
||||
|
||||
class JsonAgent(AgentExecutor):
|
||||
|
|
@ -41,6 +49,52 @@ class JsonAgent(AgentExecutor):
|
|||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
class CSVAgent(AgentExecutor):
|
||||
"""CSV agent"""
|
||||
|
||||
@classmethod
|
||||
def initialize(cls, *args, **kwargs):
|
||||
return cls.from_toolkit_and_llm(*args, **kwargs)
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
@classmethod
|
||||
def from_toolkit_and_llm(
|
||||
cls,
|
||||
path: dict,
|
||||
llm: BaseLanguageModel,
|
||||
pandas_kwargs: Optional[dict] = None,
|
||||
**kwargs: Any
|
||||
):
|
||||
import pandas as pd
|
||||
|
||||
_kwargs = pandas_kwargs or {}
|
||||
df = pd.DataFrame.from_dict(path, **_kwargs)
|
||||
|
||||
tools = [PythonAstREPLTool(locals={"df": df})]
|
||||
prompt = ZeroShotAgent.create_prompt(
|
||||
tools,
|
||||
prefix=PANDAS_PREFIX,
|
||||
suffix=PANDAS_SUFFIX,
|
||||
input_variables=["df", "input", "agent_scratchpad"],
|
||||
)
|
||||
partial_prompt = prompt.partial(df=str(df.head()))
|
||||
llm_chain = LLMChain(
|
||||
llm=llm,
|
||||
prompt=partial_prompt,
|
||||
callback_manager=None,
|
||||
)
|
||||
tool_names = [tool.name for tool in tools]
|
||||
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)
|
||||
|
||||
return cls.from_agent_and_tools(agent=agent, tools=tools, verbose=True)
|
||||
|
||||
def run(self, *args, **kwargs):
|
||||
return super().run(*args, **kwargs)
|
||||
|
||||
|
||||
CUSTOM_AGENTS = {
|
||||
"JsonAgent": JsonAgent,
|
||||
"CSVAgent": CSVAgent,
|
||||
}
|
||||
|
|
|
|||
|
|
@ -13,6 +13,7 @@ class TemplateFieldCreator(BaseModel, ABC):
|
|||
multiline: bool = False
|
||||
value: Any = None
|
||||
suffixes: list[str] = []
|
||||
fileTypes: list[str] = []
|
||||
file_types: list[str] = []
|
||||
content: Union[str, None] = None
|
||||
password: bool = False
|
||||
|
|
|
|||
|
|
@ -139,3 +139,32 @@ class JsonAgentNode(FrontendNode):
|
|||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
||||
|
||||
class CSVAgentNode(FrontendNode):
|
||||
name: str = "CSVAgent"
|
||||
template: Template = Template(
|
||||
type_name="csv_agent",
|
||||
fields=[
|
||||
TemplateField(
|
||||
field_type="file",
|
||||
required=True,
|
||||
show=True,
|
||||
name="path",
|
||||
value="",
|
||||
suffixes=[".csv"],
|
||||
fileTypes=["csv"],
|
||||
),
|
||||
TemplateField(
|
||||
field_type="BaseLanguageModel",
|
||||
required=True,
|
||||
show=True,
|
||||
name="llm",
|
||||
),
|
||||
],
|
||||
)
|
||||
description: str = """Construct a json agent from a CSV and tools."""
|
||||
base_classes: list[str] = ["AgentExecutor"]
|
||||
|
||||
def to_dict(self):
|
||||
return super().to_dict()
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue