refactor: fix linting

This commit is contained in:
Ibis Prevedello 2023-04-01 18:03:14 -03:00
commit 17ad4954ef
11 changed files with 131 additions and 140 deletions

View file

@ -5,7 +5,7 @@ import os
import tempfile
from pathlib import Path
import dill
import dill # type: ignore
PREFIX = "langflow_cache"

View file

@ -1,10 +1,8 @@
import json
from copy import deepcopy
from typing import Any, Dict, List, Optional, Union
from langflow.graph.base import Node
from langflow.graph.utils import extract_input_variables_from_prompt
from langflow.interface.toolkits.base import toolkits_creator
class AgentNode(Node):

View file

@ -1,18 +1,14 @@
from pathlib import Path
from typing import Any, Optional
from langchain import LLMChain
from langchain.agents import AgentExecutor, ZeroShotAgent
from langchain.agents.agent_toolkits.json.prompt import JSON_PREFIX, JSON_SUFFIX
from langchain.agents.agent_toolkits.json.toolkit import JsonToolkit
from langchain.agents.agent_toolkits.pandas.base import create_pandas_dataframe_agent
from langchain.agents.agent_toolkits.pandas.prompt import PREFIX as PANDAS_PREFIX
from langchain.agents.agent_toolkits.pandas.prompt import SUFFIX as PANDAS_SUFFIX
from langchain.agents.mrkl.prompt import FORMAT_INSTRUCTIONS
from langchain.llms.base import BaseLLM
from langchain.schema import BaseLanguageModel
from langchain.tools.python.tool import PythonAstREPLTool
from pydantic import BaseModel
class JsonAgent(AgentExecutor):
@ -65,12 +61,12 @@ class CSVAgent(AgentExecutor):
pandas_kwargs: Optional[dict] = None,
**kwargs: Any
):
import pandas as pd
import pandas as pd # type: ignore
_kwargs = pandas_kwargs or {}
df = pd.DataFrame.from_dict(path, **_kwargs)
tools = [PythonAstREPLTool(locals={"df": df})]
tools = [PythonAstREPLTool(locals={"df": df})] # type: ignore
prompt = ZeroShotAgent.create_prompt(
tools,
prefix=PANDAS_PREFIX,
@ -81,7 +77,6 @@ class CSVAgent(AgentExecutor):
llm_chain = LLMChain(
llm=llm,
prompt=partial_prompt,
callback_manager=None,
)
tool_names = [tool.name for tool in tools]
agent = ZeroShotAgent(llm_chain=llm_chain, allowed_tools=tool_names, **kwargs)

View file

@ -1,18 +1,11 @@
## LLM
from typing import Any
## LLM
from langchain import llms, requests
from langchain.agents import agent_toolkits
from langchain.chat_models import ChatOpenAI
from langflow.interface.importing.utils import import_class
llm_type_to_cls_dict = llms.type_to_cls_dict
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI
## Memory
# from langchain.memory.buffer_window import ConversationBufferWindowMemory
# from langchain.memory.chat_memory import ChatMessageHistory
# from langchain.memory.combined import CombinedMemory
@ -22,108 +15,7 @@ llm_type_to_cls_dict["openai-chat"] = ChatOpenAI
# from langchain.memory.simple import SimpleMemory
# from langchain.memory.summary import ConversationSummaryMemory
# from langchain.memory.summary_buffer import ConversationSummaryBufferMemory
memory_type_to_cls_dict: dict[str, Any] = {
# "CombinedMemory": CombinedMemory,
# "ConversationBufferWindowMemory": ConversationBufferWindowMemory,
# "ConversationBufferMemory": ConversationBufferMemory,
# "SimpleMemory": SimpleMemory,
# "ConversationSummaryBufferMemory": ConversationSummaryBufferMemory,
# "ConversationKGMemory": ConversationKGMemory,
# "ConversationEntityMemory": ConversationEntityMemory,
# "ConversationSummaryMemory": ConversationSummaryMemory,
# "ChatMessageHistory": ChatMessageHistory,
# "ConversationStringBufferMemory": ConversationStringBufferMemory,
# "ReadOnlySharedMemory": ReadOnlySharedMemory,
}
## Chain
# from langchain.chains.loading import type_to_loader_dict
# from langchain.chains.conversation.base import ConversationChain
# chain_type_to_cls_dict = type_to_loader_dict
# chain_type_to_cls_dict["conversation_chain"] = ConversationChain
toolkit_type_to_loader_dict: dict[str, Any] = {
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
# if toolkit_name is lower case it is a loader
for toolkit_name in agent_toolkits.__all__
if toolkit_name.islower()
}
toolkit_type_to_cls_dict: dict[str, Any] = {
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
# if toolkit_name is not lower case it is a class
for toolkit_name in agent_toolkits.__all__
if not toolkit_name.islower()
}
wrapper_type_to_cls_dict: dict[str, Any] = {
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]
}
## Embeddings
from langchain.embeddings import (
CohereEmbeddings,
FakeEmbeddings,
HuggingFaceEmbeddings,
HuggingFaceHubEmbeddings,
HuggingFaceInstructEmbeddings,
OpenAIEmbeddings,
SelfHostedEmbeddings,
SelfHostedHuggingFaceEmbeddings,
SelfHostedHuggingFaceInstructEmbeddings,
# SagemakerEndpointEmbeddings,
TensorflowHubEmbeddings,
)
embedding_type_to_cls_dict = {
"OpenAIEmbeddings": OpenAIEmbeddings,
"HuggingFaceEmbeddings": HuggingFaceEmbeddings,
"CohereEmbeddings": CohereEmbeddings,
"HuggingFaceHubEmbeddings": HuggingFaceHubEmbeddings,
"TensorflowHubEmbeddings": TensorflowHubEmbeddings,
# "SagemakerEndpointEmbeddings": SagemakerEndpointEmbeddings,
"HuggingFaceInstructEmbeddings": HuggingFaceInstructEmbeddings,
"SelfHostedEmbeddings": SelfHostedEmbeddings,
"SelfHostedHuggingFaceEmbeddings": SelfHostedHuggingFaceEmbeddings,
"SelfHostedHuggingFaceInstructEmbeddings": SelfHostedHuggingFaceInstructEmbeddings,
"FakeEmbeddings": FakeEmbeddings,
}
## Vector Stores
from langchain.vectorstores import (
FAISS,
AtlasDB,
Chroma,
DeepLake,
ElasticVectorSearch,
Milvus,
OpenSearchVectorSearch,
Pinecone,
Qdrant,
VectorStore,
Weaviate,
)
vectorstores_type_to_cls_dict = {
"ElasticVectorSearch": ElasticVectorSearch,
"FAISS": FAISS,
"VectorStore": VectorStore,
"Pinecone": Pinecone,
"Weaviate": Weaviate,
"Qdrant": Qdrant,
"Milvus": Milvus,
"Chroma": Chroma,
"OpenSearchVectorSearch": OpenSearchVectorSearch,
"AtlasDB": AtlasDB,
"DeepLake": DeepLake,
}
## Document Loaders
from langchain.document_loaders import (
AirbyteJSONLoader,
AZLyricsLoader,
@ -173,6 +65,122 @@ from langchain.document_loaders import (
YoutubeLoader,
)
## Embeddings
from langchain.embeddings import (
CohereEmbeddings,
FakeEmbeddings,
HuggingFaceEmbeddings,
HuggingFaceHubEmbeddings,
HuggingFaceInstructEmbeddings,
OpenAIEmbeddings,
SelfHostedEmbeddings,
SelfHostedHuggingFaceEmbeddings,
SelfHostedHuggingFaceInstructEmbeddings,
# SagemakerEndpointEmbeddings,
TensorflowHubEmbeddings,
)
## Vector Stores
from langchain.vectorstores import (
FAISS,
AtlasDB,
Chroma,
DeepLake,
ElasticVectorSearch,
Milvus,
OpenSearchVectorSearch,
Pinecone,
Qdrant,
VectorStore,
Weaviate,
)
## Toolkits
from langflow.interface.importing.utils import import_class
## LLM
llm_type_to_cls_dict = llms.type_to_cls_dict
llm_type_to_cls_dict["openai-chat"] = ChatOpenAI # type: ignore
## Memory
memory_type_to_cls_dict: dict[str, Any] = {
# "CombinedMemory": CombinedMemory,
# "ConversationBufferWindowMemory": ConversationBufferWindowMemory,
# "ConversationBufferMemory": ConversationBufferMemory,
# "SimpleMemory": SimpleMemory,
# "ConversationSummaryBufferMemory": ConversationSummaryBufferMemory,
# "ConversationKGMemory": ConversationKGMemory,
# "ConversationEntityMemory": ConversationEntityMemory,
# "ConversationSummaryMemory": ConversationSummaryMemory,
# "ChatMessageHistory": ChatMessageHistory,
# "ConversationStringBufferMemory": ConversationStringBufferMemory,
# "ReadOnlySharedMemory": ReadOnlySharedMemory,
}
## Chain
# from langchain.chains.loading import type_to_loader_dict
# from langchain.chains.conversation.base import ConversationChain
# chain_type_to_cls_dict = type_to_loader_dict
# chain_type_to_cls_dict["conversation_chain"] = ConversationChain
toolkit_type_to_loader_dict: dict[str, Any] = {
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
# if toolkit_name is lower case it is a loader
for toolkit_name in agent_toolkits.__all__
if toolkit_name.islower()
}
toolkit_type_to_cls_dict: dict[str, Any] = {
toolkit_name: import_class(f"langchain.agents.agent_toolkits.{toolkit_name}")
# if toolkit_name is not lower case it is a class
for toolkit_name in agent_toolkits.__all__
if not toolkit_name.islower()
}
wrapper_type_to_cls_dict: dict[str, Any] = {
wrapper.__name__: wrapper for wrapper in [requests.RequestsWrapper]
}
## Embeddings
embedding_type_to_cls_dict = {
"OpenAIEmbeddings": OpenAIEmbeddings,
"HuggingFaceEmbeddings": HuggingFaceEmbeddings,
"CohereEmbeddings": CohereEmbeddings,
"HuggingFaceHubEmbeddings": HuggingFaceHubEmbeddings,
"TensorflowHubEmbeddings": TensorflowHubEmbeddings,
# "SagemakerEndpointEmbeddings": SagemakerEndpointEmbeddings,
"HuggingFaceInstructEmbeddings": HuggingFaceInstructEmbeddings,
"SelfHostedEmbeddings": SelfHostedEmbeddings,
"SelfHostedHuggingFaceEmbeddings": SelfHostedHuggingFaceEmbeddings,
"SelfHostedHuggingFaceInstructEmbeddings": SelfHostedHuggingFaceInstructEmbeddings,
"FakeEmbeddings": FakeEmbeddings,
}
## Vector Stores
vectorstores_type_to_cls_dict = {
"ElasticVectorSearch": ElasticVectorSearch,
"FAISS": FAISS,
"VectorStore": VectorStore,
"Pinecone": Pinecone,
"Weaviate": Weaviate,
"Qdrant": Qdrant,
"Milvus": Milvus,
"Chroma": Chroma,
"OpenSearchVectorSearch": OpenSearchVectorSearch,
"AtlasDB": AtlasDB,
"DeepLake": DeepLake,
}
## Document Loaders
documentloaders_type_to_cls_dict = {
"UnstructuredFileLoader": UnstructuredFileLoader,
"UnstructuredFileIOLoader": UnstructuredFileIOLoader,

View file

@ -9,7 +9,6 @@ from langchain.chains.base import Chain
from langchain.llms.base import BaseLLM
from langchain.tools import BaseTool
from langflow.interface.agents.custom import CUSTOM_AGENTS
from langflow.interface.tools.util import get_tool_by_name

View file

@ -29,7 +29,7 @@ def instantiate_class(node_type: str, base_type: str, params: Dict) -> Any:
"""Instantiate class from module type and key, and params"""
if node_type in CUSTOM_AGENTS:
if custom_agent := CUSTOM_AGENTS.get(node_type):
return custom_agent.initialize(**params)
return custom_agent.initialize(**params) # type: ignore
class_object = import_by_type(_type=base_type, name=node_type)

View file

@ -2,9 +2,7 @@ from typing import Dict, List
from langchain import requests
from langflow.custom.customs import get_custom_nodes
from langflow.interface.base import LangChainTypeCreator
from langflow.settings import settings
from langflow.utils.util import build_template_from_class

View file

@ -1,12 +1,10 @@
import hashlib
import json
import tempfile
from pathlib import Path
import dill
import pytest
from langflow.cache.utils import PREFIX, compute_hash, load_cache, save_cache
from langflow.interface.run import load_langchain_object, process_graph
from langflow.cache.utils import PREFIX, compute_hash
from langflow.interface.run import load_langchain_object
def get_graph(_type="basic"):

View file

@ -10,7 +10,7 @@ def sample_lang_chain_type_creator() -> LangChainTypeCreator:
class SampleLangChainTypeCreator(LangChainTypeCreator):
type_name: str = "test_type"
def type_to_loader_dict(self) -> Dict:
def type_to_loader_dict(self) -> Dict: # type: ignore
return {"test_type": "TestClass"}
def to_list(self) -> List[str]:

View file

@ -1,8 +1,4 @@
from typing import Dict, List
import pytest
from langflow.interface.agents.base import AgentCreator
from langflow.interface.base import LangChainTypeCreator
from langflow.template.base import FrontendNode, Template, TemplateField
@ -28,16 +24,16 @@ def sample_frontend_node(sample_template: Template) -> FrontendNode:
def test_template_field_defaults(sample_template_field: TemplateField):
assert sample_template_field.field_type == "str"
assert sample_template_field.required == False
assert sample_template_field.required is False
assert sample_template_field.placeholder == ""
assert sample_template_field.is_list == False
assert sample_template_field.show == True
assert sample_template_field.multiline == False
assert sample_template_field.value == None
assert sample_template_field.is_list is False
assert sample_template_field.show is True
assert sample_template_field.multiline is False
assert sample_template_field.value is None
assert sample_template_field.suffixes == []
assert sample_template_field.file_types == []
assert sample_template_field.content == None
assert sample_template_field.password == False
assert sample_template_field.content is None
assert sample_template_field.password is False
assert sample_template_field.name == "test_field"

View file

@ -1,5 +1,4 @@
import importlib
import re
from typing import Dict, List, Optional
import pytest