Refactor code for extracting input variables from prompt
This commit is contained in:
parent
e0989c0f53
commit
d6fe701c7f
2 changed files with 130 additions and 61 deletions
|
|
@ -68,8 +68,6 @@ INVALID_CHARACTERS = {
|
|||
")",
|
||||
"[",
|
||||
"]",
|
||||
"{",
|
||||
"}",
|
||||
}
|
||||
|
||||
INVALID_NAMES = {
|
||||
|
|
@ -88,73 +86,110 @@ def validate_prompt(template: str):
|
|||
# Check if there are invalid characters in the input_variables
|
||||
input_variables = check_input_variables(input_variables)
|
||||
if any(var in INVALID_NAMES for var in input_variables):
|
||||
raise ValueError(f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. ")
|
||||
raise ValueError(
|
||||
f"Invalid input variables. None of the variables can be named {', '.join(input_variables)}. "
|
||||
)
|
||||
|
||||
try:
|
||||
PromptTemplate(template=template, input_variables=input_variables)
|
||||
except Exception as exc:
|
||||
raise ValueError(str(exc)) from exc
|
||||
raise ValueError(f"Invalid prompt: {exc}") from exc
|
||||
|
||||
return input_variables
|
||||
|
||||
|
||||
def check_input_variables(input_variables: list):
|
||||
def is_json_like(var):
|
||||
if var.startswith("{{") and var.endswith("}}"):
|
||||
# If it is a double brance variable
|
||||
# we don't want to validate any of its content
|
||||
return True
|
||||
# the above doesn't work on all cases because the json string can be multiline
|
||||
# or indented which can add \n or spaces at the start or end of the string
|
||||
# test_case_3 new_var == '\n{{\n "test": "hello",\n "text": "world"\n}}\n'
|
||||
# what we can do is to remove the \n and spaces from the start and end of the string
|
||||
# and then check if the string starts with {{ and ends with }}
|
||||
var = var.strip()
|
||||
var = var.replace("\n", "")
|
||||
var = var.replace(" ", "")
|
||||
# Now it should be a valid json string
|
||||
return var.startswith("{{") and var.endswith("}}")
|
||||
|
||||
|
||||
def fix_variable(var, invalid_chars, wrong_variables):
|
||||
if not var:
|
||||
return var, invalid_chars, wrong_variables
|
||||
new_var = var
|
||||
|
||||
# Handle variables starting with a number
|
||||
if var[0].isdigit():
|
||||
invalid_chars.append(var[0])
|
||||
new_var, invalid_chars, wrong_variables = fix_variable(
|
||||
var[1:], invalid_chars, wrong_variables
|
||||
)
|
||||
|
||||
# Temporarily replace {{ and }} to avoid treating them as invalid
|
||||
new_var = new_var.replace("{{", "ᴛᴇᴍᴘᴏᴘᴇɴ").replace("}}", "ᴛᴇᴍᴘᴄʟᴏsᴇ")
|
||||
|
||||
# Remove invalid characters
|
||||
for char in new_var:
|
||||
if char in INVALID_CHARACTERS:
|
||||
invalid_chars.append(char)
|
||||
new_var = new_var.replace(char, "")
|
||||
if var not in wrong_variables: # Avoid duplicating entries
|
||||
wrong_variables.append(var)
|
||||
|
||||
# Restore {{ and }}
|
||||
new_var = new_var.replace("ᴛᴇᴍᴘᴏᴘᴇɴ", "{{").replace("ᴛᴇᴍᴘᴄʟᴏsᴇ", "}}")
|
||||
|
||||
return new_var, invalid_chars, wrong_variables
|
||||
|
||||
|
||||
def check_variable(var, invalid_chars, wrong_variables, empty_variables):
|
||||
if any(char in invalid_chars for char in var):
|
||||
wrong_variables.append(var)
|
||||
elif var == "":
|
||||
empty_variables.append(var)
|
||||
return wrong_variables, empty_variables
|
||||
|
||||
|
||||
def check_for_errors(
|
||||
input_variables, fixed_variables, wrong_variables, empty_variables
|
||||
):
|
||||
if any(var for var in input_variables if var not in fixed_variables):
|
||||
error_message = (
|
||||
f"Error: Input variables contain invalid characters or formats. \n"
|
||||
f"Invalid variables: {', '.join(wrong_variables)}.\n"
|
||||
f"Empty variables: {', '.join(empty_variables)}. \n"
|
||||
f"Fixed variables: {', '.join(fixed_variables)}."
|
||||
)
|
||||
raise ValueError(error_message)
|
||||
|
||||
|
||||
def check_input_variables(input_variables):
|
||||
invalid_chars = []
|
||||
fixed_variables = []
|
||||
wrong_variables = []
|
||||
empty_variables = []
|
||||
for variable in input_variables:
|
||||
new_var = variable
|
||||
variables_to_check = []
|
||||
|
||||
# if variable is empty, then we should add that to the wrong variables
|
||||
if not variable:
|
||||
empty_variables.append(variable)
|
||||
for var in input_variables:
|
||||
# First, let's check if the variable is a JSON string
|
||||
# because if it is, it won't be considered a variable
|
||||
# and we don't need to validate it
|
||||
if is_json_like(var):
|
||||
continue
|
||||
|
||||
# if variable starts with a number we should add that to the invalid chars
|
||||
# and wrong variables
|
||||
if variable[0].isdigit():
|
||||
invalid_chars.append(variable[0])
|
||||
new_var = new_var.replace(variable[0], "")
|
||||
wrong_variables.append(variable)
|
||||
else:
|
||||
for char in INVALID_CHARACTERS:
|
||||
if char in variable:
|
||||
invalid_chars.append(char)
|
||||
new_var = new_var.replace(char, "")
|
||||
wrong_variables.append(variable)
|
||||
fixed_variables.append(new_var)
|
||||
# If any of the input_variables is not in the fixed_variables, then it means that
|
||||
# there are invalid characters in the input_variables
|
||||
|
||||
if any(var not in fixed_variables for var in input_variables):
|
||||
error_message = build_error_message(
|
||||
input_variables,
|
||||
invalid_chars,
|
||||
wrong_variables,
|
||||
fixed_variables,
|
||||
empty_variables,
|
||||
new_var, wrong_variables, empty_variables = fix_variable(
|
||||
var, invalid_chars, wrong_variables
|
||||
)
|
||||
raise ValueError(error_message)
|
||||
return input_variables
|
||||
wrong_variables, empty_variables = check_variable(
|
||||
var, INVALID_CHARACTERS, wrong_variables, empty_variables
|
||||
)
|
||||
fixed_variables.append(new_var)
|
||||
variables_to_check.append(var)
|
||||
|
||||
check_for_errors(
|
||||
variables_to_check, fixed_variables, wrong_variables, empty_variables
|
||||
)
|
||||
|
||||
def build_error_message(input_variables, invalid_chars, wrong_variables, fixed_variables, empty_variables):
|
||||
input_variables_str = ", ".join([f"'{var}'" for var in input_variables])
|
||||
error_string = f"Invalid input variables: {input_variables_str}. "
|
||||
|
||||
if wrong_variables and invalid_chars:
|
||||
# fix the wrong variables replacing invalid chars and find them in the fixed variables
|
||||
error_string_vars = "You can fix them by replacing the invalid characters: "
|
||||
wvars = wrong_variables.copy()
|
||||
for i, wrong_var in enumerate(wvars):
|
||||
for char in invalid_chars:
|
||||
wrong_var = wrong_var.replace(char, "")
|
||||
if wrong_var in fixed_variables:
|
||||
error_string_vars += f"'{wrong_variables[i]}' -> '{wrong_var}'"
|
||||
error_string += error_string_vars
|
||||
elif empty_variables:
|
||||
error_string += f" There are {len(empty_variables)} empty variable{'s' if len(empty_variables) > 1 else ''}."
|
||||
elif len(set(fixed_variables)) != len(fixed_variables):
|
||||
error_string += "There are duplicate variables."
|
||||
return error_string
|
||||
return fixed_variables
|
||||
|
|
|
|||
|
|
@ -1,14 +1,14 @@
|
|||
import base64
|
||||
import json
|
||||
import os
|
||||
from io import BytesIO
|
||||
import re
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
import yaml
|
||||
from langchain.base_language import BaseLanguageModel
|
||||
from PIL.Image import Image
|
||||
from loguru import logger
|
||||
from PIL.Image import Image
|
||||
|
||||
from langflow.services.chat.config import ChatConfig
|
||||
from langflow.services.deps import get_settings_service
|
||||
|
||||
|
|
@ -43,7 +43,9 @@ def try_setting_streaming_options(langchain_object):
|
|||
llm = None
|
||||
if hasattr(langchain_object, "llm"):
|
||||
llm = langchain_object.llm
|
||||
elif hasattr(langchain_object, "llm_chain") and hasattr(langchain_object.llm_chain, "llm"):
|
||||
elif hasattr(langchain_object, "llm_chain") and hasattr(
|
||||
langchain_object.llm_chain, "llm"
|
||||
):
|
||||
llm = langchain_object.llm_chain.llm
|
||||
|
||||
if isinstance(llm, BaseLanguageModel):
|
||||
|
|
@ -56,8 +58,37 @@ def try_setting_streaming_options(langchain_object):
|
|||
|
||||
|
||||
def extract_input_variables_from_prompt(prompt: str) -> list[str]:
|
||||
"""Extract input variables from prompt."""
|
||||
return re.findall(r"{(.*?)}", prompt)
|
||||
variables = []
|
||||
remaining_text = prompt
|
||||
|
||||
# Pattern to match single {var} and double {{var}} braces.
|
||||
pattern = r"\{\{(.*?)\}\}|\{([^{}]+)\}"
|
||||
|
||||
while True:
|
||||
match = re.search(pattern, remaining_text)
|
||||
if not match:
|
||||
break
|
||||
|
||||
# Extract the variable name from either the single or double brace match
|
||||
if match.group(1): # Match found in double braces
|
||||
variable_name = (
|
||||
"{{" + match.group(1) + "}}"
|
||||
) # Re-add single braces for JSON strings
|
||||
else: # Match found in single braces
|
||||
variable_name = match.group(2)
|
||||
if variable_name is not None:
|
||||
# This means there is a match
|
||||
# but there is nothing inside the braces
|
||||
variables.append(variable_name)
|
||||
|
||||
# Remove the matched text from the remaining_text
|
||||
start, end = match.span()
|
||||
remaining_text = remaining_text[:start] + remaining_text[end:]
|
||||
|
||||
# Proceed to the next match until no more matches are found
|
||||
# No need to compare remaining "{}" instances because we are re-adding braces for JSON compatibility
|
||||
|
||||
return variables
|
||||
|
||||
|
||||
def setup_llm_caching():
|
||||
|
|
@ -73,11 +104,14 @@ def setup_llm_caching():
|
|||
|
||||
def set_langchain_cache(settings):
|
||||
from langchain.globals import set_llm_cache
|
||||
|
||||
from langflow.interface.importing.utils import import_class
|
||||
|
||||
if cache_type := os.getenv("LANGFLOW_LANGCHAIN_CACHE"):
|
||||
try:
|
||||
cache_class = import_class(f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}")
|
||||
cache_class = import_class(
|
||||
f"langchain.cache.{cache_type or settings.LANGCHAIN_CACHE}"
|
||||
)
|
||||
|
||||
logger.debug(f"Setting up LLM caching with {cache_class.__name__}")
|
||||
set_llm_cache(cache_class())
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue