Add unescape_string function and use it to unescape values in Vertex class

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-08 10:20:24 -03:00
commit b8639977ae
3 changed files with 56 additions and 17 deletions

View file

@ -30,7 +30,7 @@ from langflow.interface.listing import lazy_load_dict
from langflow.services.deps import get_storage_service
from langflow.utils.constants import DIRECT_TYPES
from langflow.utils.schemas import ChatOutputResponse
from langflow.utils.util import sync_to_async
from langflow.utils.util import sync_to_async, unescape_string
if TYPE_CHECKING:
from langflow.graph.edge.base import ContractEdge
@ -377,9 +377,9 @@ class Vertex:
# val may contain escaped \n, \t, etc.
# so we need to unescape it
if isinstance(val, list):
params[key] = [v.encode().decode("unicode_escape") for v in val]
params[key] = [unescape_string(v) for v in val]
elif isinstance(val, str):
params[key] = val.encode().decode("unicode_escape")
params[key] = unescape_string(val)
elif val is not None and val != "":
params[key] = val

View file

@ -32,6 +32,7 @@ from langflow.interface.utils import load_file_into_dict
from langflow.interface.wrappers.base import wrapper_creator
from langflow.schema.schema import Record
from langflow.utils import validate
from langflow.utils.util import unescape_string
if TYPE_CHECKING:
from langflow import CustomComponent
@ -144,9 +145,13 @@ async def instantiate_based_on_type(
return class_object(**params)
async def instantiate_custom_component(node_type, class_object, params, user_id, vertex):
async def instantiate_custom_component(
node_type, class_object, params, user_id, vertex
):
params_copy = params.copy()
class_object: Type["CustomComponent"] = eval_custom_component_code(params_copy.pop("code"))
class_object: Type["CustomComponent"] = eval_custom_component_code(
params_copy.pop("code")
)
custom_component: "CustomComponent" = class_object(
user_id=user_id,
parameters=params_copy,
@ -222,7 +227,9 @@ def instantiate_memory(node_type, class_object, params):
# I want to catch a specific attribute error that happens
# when the object does not have a cursor attribute
except Exception as exc:
if "object has no attribute 'cursor'" in str(exc) or 'object has no field "conn"' in str(exc):
if "object has no attribute 'cursor'" in str(
exc
) or 'object has no field "conn"' in str(exc):
raise AttributeError(
(
"Failed to build connection to database."
@ -265,7 +272,9 @@ def instantiate_agent(node_type, class_object: Type[agent_module.Agent], params:
if class_method := getattr(class_object, method, None):
agent = class_method(**params)
tools = params.get("tools", [])
return AgentExecutor.from_agent_and_tools(agent=agent, tools=tools, handle_parsing_errors=True)
return AgentExecutor.from_agent_and_tools(
agent=agent, tools=tools, handle_parsing_errors=True
)
return load_agent_executor(class_object, params)
@ -321,7 +330,11 @@ def instantiate_embedding(node_type, class_object, params: Dict):
try:
return class_object(**params)
except ValidationError:
params = {key: value for key, value in params.items() if key in class_object.model_fields}
params = {
key: value
for key, value in params.items()
if key in class_object.model_fields
}
return class_object(**params)
@ -333,7 +346,9 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
if "texts" in params:
params["documents"] = params.pop("texts")
if "documents" in params:
params["documents"] = [doc for doc in params["documents"] if isinstance(doc, Document)]
params["documents"] = [
doc for doc in params["documents"] if isinstance(doc, Document)
]
if initializer := vecstore_initializer.get(class_object.__name__):
vecstore = initializer(class_object, params)
else:
@ -348,7 +363,9 @@ def instantiate_vectorstore(class_object: Type[VectorStore], params: Dict):
return vecstore
def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], params: Dict):
def instantiate_documentloader(
node_type: str, class_object: Type[BaseLoader], params: Dict
):
if "file_filter" in params:
# file_filter will be a string but we need a function
# that will be used to filter the files using file_filter
@ -357,13 +374,17 @@ def instantiate_documentloader(node_type: str, class_object: Type[BaseLoader], p
# in x and if it is, we will return True
file_filter = params.pop("file_filter")
extensions = file_filter.split(",")
params["file_filter"] = lambda x: any(extension.strip() in x for extension in extensions)
params["file_filter"] = lambda x: any(
extension.strip() in x for extension in extensions
)
metadata = params.pop("metadata", None)
if metadata and isinstance(metadata, str):
try:
metadata = orjson.loads(metadata)
except json.JSONDecodeError as exc:
raise ValueError("The metadata you provided is not a valid JSON string.") from exc
raise ValueError(
"The metadata you provided is not a valid JSON string."
) from exc
if node_type == "WebBaseLoader":
if web_path := params.pop("web_path", None):
@ -396,12 +417,19 @@ def instantiate_textsplitter(
"Try changing the chunk_size of the Text Splitter."
) from exc
if ("separator_type" in params and params["separator_type"] == "Text") or "separator_type" not in params:
if (
"separator_type" in params and params["separator_type"] == "Text"
) or "separator_type" not in params:
params.pop("separator_type", None)
# separators might come in as an escaped string like \\n
# so we need to convert it to a string
if "separators" in params:
params["separators"] = params["separators"].encode().decode("unicode-escape")
if isinstance(params["separators"], str):
params["separators"] = unescape_string(params["separators"])
elif isinstance(params["separators"], list):
params["separators"] = [
unescape_string(separator) for separator in params["separators"]
]
text_splitter = class_object(**params)
else:
from langchain.text_splitter import Language
@ -428,7 +456,8 @@ def replace_zero_shot_prompt_with_prompt_template(nodes):
tools = [
tool
for tool in nodes
if tool["type"] != "chatOutputNode" and "Tool" in tool["data"]["node"]["base_classes"]
if tool["type"] != "chatOutputNode"
and "Tool" in tool["data"]["node"]["base_classes"]
]
node["data"] = build_prompt_template(prompt=node["data"], tools=tools)
break
@ -442,7 +471,9 @@ def load_agent_executor(agent_class: type[agent_module.Agent], params, **kwargs)
# agent has hidden args for memory. might need to be support
# memory = params["memory"]
# if allowed_tools is not a list or set, make it a list
if not isinstance(allowed_tools, (list, set)) and isinstance(allowed_tools, BaseTool):
if not isinstance(allowed_tools, (list, set)) and isinstance(
allowed_tools, BaseTool
):
allowed_tools = [allowed_tools]
tool_names = [tool.name for tool in allowed_tools]
# Agent class requires an output_parser but Agent classes
@ -470,7 +501,10 @@ def build_prompt_template(prompt, tools):
format_instructions = prompt["node"]["template"]["format_instructions"]["value"]
tool_strings = "\n".join(
[f"{tool['data']['node']['name']}: {tool['data']['node']['description']}" for tool in tools]
[
f"{tool['data']['node']['name']}: {tool['data']['node']['description']}"
for tool in tools
]
)
tool_names = ", ".join([tool["data"]["node"]["name"] for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)

View file

@ -11,6 +11,11 @@ from langflow.template.frontend_node.constants import FORCE_SHOW_FIELDS
from langflow.utils import constants
def unescape_string(s):
# Replace escaped new line characters with actual new line characters
return s.replace("\\n", "\n")
def remove_ansi_escape_codes(text):
return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text)