fix: password setting improvements

This commit is contained in:
Gabriel Almeida 2023-04-06 14:24:00 -03:00
commit 26524e3b6d
3 changed files with 31 additions and 30 deletions

View file

@ -19,7 +19,7 @@ class BaseCustomChain(ConversationChain):
template: Optional[str]
ai_prefix_key: Optional[str]
ai_prefix_value: Optional[str]
"""Field to use as the ai_prefix. It needs to be set and has to be in the template"""
@root_validator(pre=False)
@ -27,13 +27,13 @@ class BaseCustomChain(ConversationChain):
format_dict = {}
input_variables = extract_input_variables_from_prompt(values["template"])
if values.get("ai_prefix_key", None) is None:
values["ai_prefix_key"] = values["memory"].ai_prefix
if values.get("ai_prefix_value", None) is None:
values["ai_prefix_value"] = values["memory"].ai_prefix
for key in input_variables:
new_value = values.get(key, f"{{{key}}}")
format_dict[key] = new_value
if key == values.get("ai_prefix_key", None):
if key == values.get("ai_prefix_value", None):
values["memory"].ai_prefix = new_value
values["template"] = values["template"].format(**format_dict)
@ -62,7 +62,7 @@ Current conversation:
Human: {input}
{character}:"""
memory: BaseMemory = Field(default_factory=ConversationBufferMemory)
ai_prefix_key: Optional[str] = "character"
ai_prefix_value: Optional[str] = "character"
"""Default memory store."""

View file

@ -179,12 +179,13 @@ class FrontendNode(BaseModel):
(field.required and key not in ["input_variables"])
or key in FORCE_SHOW_FIELDS
or "api" in key
or "key" in key
or ("key" in key and "input" not in key and "output" not in key)
)
# Add password field
field.password = any(
text in key.lower() for text in {"password", "token", "api", "key"}
field.password = (
any(text in key.lower() for text in {"password", "token", "api", "key"})
and field.show
)
# Add multline

View file

@ -49,7 +49,7 @@ def test_conversation_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "input",
"password": True,
"password": False,
"name": "input_key",
"type": "str",
"list": False,
@ -60,7 +60,7 @@ def test_conversation_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "response",
"password": True,
"password": False,
"name": "output_key",
"type": "str",
"list": False,
@ -121,7 +121,7 @@ def test_llm_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "text",
"password": True,
"password": False,
"name": "output_key",
"type": "str",
"list": False,
@ -175,7 +175,7 @@ def test_llm_checker_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "query",
"password": True,
"password": False,
"name": "input_key",
"type": "str",
"list": False,
@ -186,7 +186,7 @@ def test_llm_checker_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "result",
"password": True,
"password": False,
"name": "output_key",
"type": "str",
"list": False,
@ -247,7 +247,7 @@ def test_llm_math_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "question",
"password": True,
"password": False,
"name": "input_key",
"type": "str",
"list": False,
@ -258,7 +258,7 @@ def test_llm_math_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "answer",
"password": True,
"password": False,
"name": "output_key",
"type": "str",
"list": False,
@ -334,7 +334,7 @@ def test_series_character_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "input",
"password": True,
"password": False,
"name": "input_key",
"type": "str",
"list": False,
@ -345,7 +345,7 @@ def test_series_character_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "response",
"password": True,
"password": False,
"name": "output_key",
"type": "str",
"list": False,
@ -361,14 +361,14 @@ def test_series_character_chain(client: TestClient):
"type": "str",
"list": False,
}
assert template["ai_prefix_key"] == {
assert template["ai_prefix_value"] == {
"required": False,
"placeholder": "",
"show": False,
"multiline": False,
"value": "character",
"password": True,
"name": "ai_prefix_key",
"password": False,
"name": "ai_prefix_value",
"type": "str",
"list": False,
}
@ -485,7 +485,7 @@ def test_mid_journey_prompt_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "response",
"password": True,
"password": False,
"name": "output_key",
"type": "str",
"list": False,
@ -496,7 +496,7 @@ def test_mid_journey_prompt_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "input",
"password": True,
"password": False,
"name": "input_key",
"type": "str",
"list": False,
@ -512,13 +512,13 @@ def test_mid_journey_prompt_chain(client: TestClient):
"type": "str",
"list": False,
}
assert template["ai_prefix_key"] == {
assert template["ai_prefix_value"] == {
"required": False,
"placeholder": "",
"show": False,
"multiline": False,
"password": True,
"name": "ai_prefix_key",
"password": False,
"name": "ai_prefix_value",
"type": "str",
"list": False,
}
@ -613,7 +613,7 @@ def test_time_travel_guide_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "response",
"password": True,
"password": False,
"name": "output_key",
"type": "str",
"list": False,
@ -625,7 +625,7 @@ def test_time_travel_guide_chain(client: TestClient):
"show": False,
"multiline": False,
"value": "input",
"password": True,
"password": False,
"name": "input_key",
"type": "str",
"list": False,
@ -642,13 +642,13 @@ def test_time_travel_guide_chain(client: TestClient):
"type": "str",
"list": False,
}
assert template["ai_prefix_key"] == {
assert template["ai_prefix_value"] == {
"required": False,
"placeholder": "",
"show": False,
"multiline": False,
"password": True,
"name": "ai_prefix_key",
"password": False,
"name": "ai_prefix_value",
"type": "str",
"list": False,
}