Refactor code for better readability and maintainability

This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-03-06 18:46:05 -03:00
commit 34467d8f70
2 changed files with 38 additions and 18 deletions

View file

@ -123,9 +123,11 @@ class DocumentLoaderVertex(Vertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(record.text) for record in self._built_object if hasattr(record, "text")) / len(
self._built_object
)
avg_length = sum(
len(record.text)
for record in self._built_object
if hasattr(record, "text")
) / len(self._built_object)
return f"""{self.display_name}({len(self._built_object)} records)
\nAvg. Record Length (characters): {int(avg_length)}
Records: {self._built_object[:3]}..."""
@ -198,7 +200,9 @@ class TextSplitterVertex(Vertex):
# show how many documents are in the list?
if not isinstance(self._built_object, UnbuiltObject):
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object)
avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(
self._built_object
)
return f"""{self.vertex_type}({len(self._built_object)} documents)
\nAvg. Document Length (characters): {int(avg_length)}
\nDocuments: {self._built_object[:3]}..."""
@ -245,18 +249,27 @@ class PromptVertex(Vertex):
user_id = kwargs.get("user_id", None)
tools = kwargs.get("tools", [])
if not self._built or force:
if "input_variables" not in self.params or self.params["input_variables"] is None:
if (
"input_variables" not in self.params
or self.params["input_variables"] is None
):
self.params["input_variables"] = []
# Check if it is a ZeroShotPrompt and needs a tool
if "ShotPrompt" in self.vertex_type:
tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else []
tools = (
[tool_node.build(user_id=user_id) for tool_node in tools]
if tools is not None
else []
)
# flatten the list of tools if it is a list of lists
# first check if it is a list
if tools and isinstance(tools, list) and isinstance(tools[0], list):
tools = flatten_list(tools)
self.params["tools"] = tools
prompt_params = [
key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions"
key
for key, value in self.params.items()
if isinstance(value, str) and key != "format_instructions"
]
else:
prompt_params = ["template"]
@ -266,14 +279,20 @@ class PromptVertex(Vertex):
prompt_text = self.params[param]
variables = extract_input_variables_from_prompt(prompt_text)
self.params["input_variables"].extend(variables)
self.params["input_variables"] = list(set(self.params["input_variables"]))
self.params["input_variables"] = list(
set(self.params["input_variables"])
)
elif isinstance(self.params, dict):
self.params.pop("input_variables", None)
await self._build(user_id=user_id)
def _built_object_repr(self):
if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"):
if (
not self.artifacts
or self._built_object is None
or not hasattr(self._built_object, "format")
):
return super()._built_object_repr()
elif isinstance(self._built_object, UnbuiltObject):
return super()._built_object_repr()
@ -285,7 +304,9 @@ class PromptVertex(Vertex):
# so the prompt format doesn't break
artifacts.pop("handle_keys", None)
try:
if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"):
if not hasattr(self._built_object, "template") and hasattr(
self._built_object, "prompt"
):
template = self._built_object.prompt.template
else:
template = self._built_object.template
@ -293,7 +314,11 @@ class PromptVertex(Vertex):
if value:
replace_key = "{" + key + "}"
template = template.replace(replace_key, value)
return template if isinstance(template, str) else f"{self.vertex_type}({template})"
return (
template
if isinstance(template, str)
else f"{self.vertex_type}({template})"
)
except KeyError:
return str(self._built_object)
@ -372,7 +397,7 @@ class ChatVertex(Vertex):
self.will_stream = stream_url is not None
if artifacts:
self.artifacts = artifacts.model_dump()
self.artifacts = artifacts.model_dump(exclude_none=True)
if isinstance(self._built_object, (AsyncIterator, Iterator)):
if self.params["return_record"]:
self._built_object = Record(text=message, data=self.artifacts)

View file

@ -130,12 +130,7 @@ export default function ParameterComponent({
const shouldUpdate =
data.node?.template[name].refresh &&
data.node!.template[name].value !== newValue;
console.log("shouldUpdate", shouldUpdate);
console.log(
"data.node!.template[name].value",
data.node!.template[name].value
);
console.log("newValue", newValue);
data.node!.template[name].value = newValue; // necessary to enable ctrl+z inside the input
let newTemplate;
if (shouldUpdate) {