Refactor code for extracting input variables from prompt

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-01 11:11:18 -03:00
commit d6fe701c7f
2 changed files with 130 additions and 61 deletions

View file

@ -68,8 +68,6 @@ INVALID_CHARACTERS = {
")",
"[",
"]",
"{",
"}",
}
INVALID_NAMES = {
@ -88,73 +86,110 @@ def validate_prompt(template: str):
# Check if there are invalid characters in the input_variables
input_variables = check_input_variables(input_variables)
if any(var in INVALID_NAMES for var in input_variables):
raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ")
raise ValueError(
f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. "
)
try:
PromptTemplate(template=template, input_variables=input_variables)
except Exception as exc:
raise ValueError(str(exc)) from exc
raise ValueError(f"Invalid prompt: {exc}") from exc
return input_variables
def check_input_variables(input_variables: list):
def is_json_like(var):
if var.startswith("{{") and var.endswith("}}"):
# If it is a double brance variable
# we don't want to validate any of its content
return True
# the above doesn't work on all cases because the json string can be multiline
# or indented which can add \n or spaces at the start or end of the string
# test_case_3 new_var == '\n{{\n "test": "hello",\n "text": "world"\n}}\n'
# what we can do is to remove the \n and spaces from the start and end of the string
# and then check if the string starts with {{ and ends with }}
var = var.strip()
var = var.replace("\n", "")
var = var.replace(" ", "")
# Now it should be a valid json string
return var.startswith("{{") and var.endswith("}}")
def fix_variable(var, invalid_chars, wrong_variables):
if not var:
return var, invalid_chars, wrong_variables
new_var = var
# Handle variables starting with a number
if var[0].isdigit():
invalid_chars.append(var[0])
new_var, invalid_chars, wrong_variables = fix_variable(
var[1:], invalid_chars, wrong_variables
)
# Temporarily replace {{ and }} to avoid treating them as invalid
new_var = new_var.replace("{{", "ᴛᴇᴍᴘᴏᴘᴇɴ").replace("}}", "ᴛᴇᴍᴘʟsᴇ")
# Remove invalid characters
for char in new_var:
if char in INVALID_CHARACTERS:
invalid_chars.append(char)
new_var = new_var.replace(char, "")
if var not in wrong_variables: # Avoid duplicating entries
wrong_variables.append(var)
# Restore {{ and }}
new_var = new_var.replace("ᴛᴇᴍᴘᴏᴘᴇɴ", "{{").replace("ᴛᴇᴍᴘʟsᴇ", "}}")
return new_var, invalid_chars, wrong_variables
def check_variable(var, invalid_chars, wrong_variables, empty_variables):
if any(char in invalid_chars for char in var):
wrong_variables.append(var)
elif var == "":
empty_variables.append(var)
return wrong_variables, empty_variables
def check_for_errors(
input_variables, fixed_variables, wrong_variables, empty_variables
):
if any(var for var in input_variables if var not in fixed_variables):
error_message = (
f"Error: Input variables contain invalid characters or formats. \n"
f"Invalid variables: {', '.join(wrong_variables)}.\n"
f"Empty variables: {', '.join(empty_variables)}. \n"
f"Fixed variables: {', '.join(fixed_variables)}."
)
raise ValueError(error_message)
def check_input_variables(input_variables):
invalid_chars = []
fixed_variables = []
wrong_variables = []
empty_variables = []
for variable in input_variables:
new_var = variable
variables_to_check = []
# if variable is empty, then we should add that to the wrong variables
if not variable:
empty_variables.append(variable)
for var in input_variables:
# First, let's check if the variable is a JSON string
# because if it is, it won't be considered a variable
# and we don't need to validate it
if is_json_like(var):
continue
# if variable starts with a number we should add that to the invalid chars
# and wrong variables
if variable[0].isdigit():
invalid_chars.append(variable[0])
new_var = new_var.replace(variable[0], "")
wrong_variables.append(variable)
else:
for char in INVALID_CHARACTERS:
if char in variable:
invalid_chars.append(char)
new_var = new_var.replace(char, "")
wrong_variables.append(variable)
fixed_variables.append(new_var)
# If any of the input_variables is not in the fixed_variables, then it means that
# there are invalid characters in the input_variables
if any(var not in fixed_variables for var in input_variables):
error_message = build_error_message(
input_variables,
invalid_chars,
wrong_variables,
fixed_variables,
empty_variables,
new_var, wrong_variables, empty_variables = fix_variable(
var, invalid_chars, wrong_variables
)
raise ValueError(error_message)
return input_variables
wrong_variables, empty_variables = check_variable(
var, INVALID_CHARACTERS, wrong_variables, empty_variables
)
fixed_variables.append(new_var)
variables_to_check.append(var)
check_for_errors(
variables_to_check, fixed_variables, wrong_variables, empty_variables
)
def build_error_message(input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables):
input_variables_str = ", ".join([f"'{var}'" for var in input_variables])
error_string = f"Invalid input variables: {input_variables_str}. "
if wrong_variables and invalid_chars:
# fix the wrong variables replacing invalid chars and find them in the fixed variables
error_string_vars = "You can fix them by replacing the invalid characters: "
wvars = wrong_variables.copy()
for i, wrong_var in enumerate(wvars):
for char in invalid_chars:
wrong_var = wrong_var.replace(char, "")
if wrong_var in fixed_variables:
error_string_vars += f"'{wrong_variables[i]}' -> '{wrong_var}'"
error_string += error_string_vars
elif empty_variables:
error_string += f" There are {len(empty_variables)} empty variable{'s' if len(empty_variables) > 1 else ''}."
elif len(set(fixed_variables)) != len(fixed_variables):
error_string += "There are duplicate variables."
return error_string
return fixed_variables

View file

@ -1,14 +1,14 @@
import base64
import json
import os
from io import BytesIO
import re
from io import BytesIO
import yaml
from langchain.base_language import BaseLanguageModel
from PIL.Image import Image
from loguru import logger
from PIL.Image import Image
from langflow.services.chat.config import ChatConfig
from langflow.services.deps import get_settings_service
@ -43,7 +43,9 @@ def try_setting_streaming_options(langchain_object):
llm = None
if hasattr(langchain_object, "llm"):
llm = langchain_object.llm
elif hasattr(langchain_object, "llm_chain") and hasattr(langchain_object.llm_chain, "llm"):
elif hasattr(langchain_object, "llm_chain") and hasattr(
langchain_object.llm_chain, "llm"
):
llm = langchain_object.llm_chain.llm
if isinstance(llm, BaseLanguageModel):
@ -56,8 +58,37 @@ def try_setting_streaming_options(langchain_object):
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
"""Extract input variables from prompt."""
return re.findall(r"{(.*?)}", prompt)
variables = []
remaining_text = prompt
# Pattern to match single {var} and double {{var}} braces.
pattern = r"\{\{(.*?)\}\}|\{([^{}]+)\}"
while True:
match = re.search(pattern, remaining_text)
if not match:
break
# Extract the variable name from either the single or double brace match
if match.group(1): # Match found in double braces
variable_name = (
"{{" + match.group(1) + "}}"
) # Re-add single braces for JSON strings
else: # Match found in single braces
variable_name = match.group(2)
if variable_name is not None:
# This means there is a match
# but there is nothing inside the braces
variables.append(variable_name)
# Remove the matched text from the remaining_text
start, end = match.span()
remaining_text = remaining_text[:start] + remaining_text[end:]
# Proceed to the next match until no more matches are found
# No need to compare remaining "{}" instances because we are re-adding braces for JSON compatibility
return variables
def setup_llm_caching():
@ -73,11 +104,14 @@ def setup_llm_caching():
def set_langchain_cache(settings):
from langchain.globals import set_llm_cache
from langflow.interface.importing.utils import import_class
if cache_type := os.getenv("LANGFLOW_LANGCHAIN_CACHE"):
try:
cache_class = import_class(f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}")
cache_class = import_class(
f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}"
)
logger.debug(f"Setting up LLM caching with {cache_class.__name__}")
set_llm_cache(cache_class())