custom prompt experiment

This commit is contained in:
Gabriel Almeida 2023-04-01 18:41:27 -03:00
commit 12c21b0185
7 changed files with 28 additions and 55 deletions

View file

@ -1,7 +1,9 @@
from langflow.template import nodes
CUSTOM_NODES = {
"prompts": {**nodes.ZeroShotPromptNode().to_dict()},
"prompts": {
**nodes.ZeroShotPromptNode().to_dict(),
},
"tools": {**nodes.PythonFunctionNode().to_dict(), **nodes.ToolNode().to_dict()},
"agents": {
**nodes.JsonAgentNode().to_dict(),

View file

@ -3,6 +3,7 @@ from typing import Any, Dict, Optional, Union
from pydantic import BaseModel
from langflow.template.constants import FORCE_SHOW_FIELDS
from langflow.utils import constants
@ -68,17 +69,7 @@ class TemplateFieldCreator(BaseModel, ABC):
# Show or not field
self.show = bool(
(self.required and key not in ["input_variables"])
or key
in [
"allowed_tools",
"memory",
"prefix",
"examples",
"temperature",
"model_name",
"headers",
"max_value_length",
]
or key in FORCE_SHOW_FIELDS
or "api_key" in key
)

View file

@ -49,6 +49,16 @@ class ZeroShotPromptNode(FrontendNode):
return super().to_dict()
class PromptTemplateNode(FrontendNode):
name: str = "PromptTemplate"
template: Template
description: str
base_classes: list[str] = ["BasePromptTemplate"]
def to_dict(self):
return super().to_dict()
class PythonFunctionNode(FrontendNode):
name: str = "PythonFunction"
template: Template = Template(

View file

@ -3,6 +3,7 @@ import inspect
import re
from typing import Dict, Optional
from langflow.template.constants import FORCE_SHOW_FIELDS
from langflow.utils import constants
@ -284,17 +285,7 @@ def format_dict(d, name: Optional[str] = None):
# Show or not field
value["show"] = bool(
(value["required"] and key not in ["input_variables"])
or key
in [
"allowed_tools",
"memory",
"prefix",
"examples",
"temperature",
"model_name",
"headers",
"max_value_length",
]
or key in FORCE_SHOW_FIELDS
or "api_key" in key
)

View file

@ -2,6 +2,7 @@ from pathlib import Path
import pytest
from fastapi.testclient import TestClient
from langflow.graph.graph import Graph
def pytest_configure():

View file

@ -340,43 +340,21 @@ def test_build_params(basic_graph):
assert isinstance(llm_node.params["model_name"], str)
def test_build(basic_graph, complex_graph):
def test_build(basic_graph, complex_graph, openapi_graph):
"""Test Node's build method"""
# def build(self):
# # The params dict is used to build the module
# # it contains values and keys that point to nodes which
# # have their own params dict
# # When build is called, we iterate through the params dict
# # and if the value is a node, we call build on that node
# # and use the output of that build as the value for the param
# # if the value is not a node, then we use the value as the param
# # and continue
# # Another aspect is that the node_type is the class that we need to import
# # and instantiate with these built params
assert_agent_was_built(basic_graph)
assert_agent_was_built(complex_graph)
assert_agent_was_built(openapi_graph)
# # Build each node in the params dict
# for key, value in self.params.items():
# if isinstance(value, Node):
# self.params[key] = value.build()
# # Get the class from LANGCHAIN_TYPES_DICT
# # and instantiate it with the params
# # and return the instance
# return LANGCHAIN_TYPES_DICT[self.node_type](**self.params)
assert isinstance(basic_graph, Graph)
def assert_agent_was_built(graph):
"""Assert that the agent was built"""
assert isinstance(graph, Graph)
# Now we test the build method
# Build the Agent
agent = basic_graph.build()
result = graph.build()
# The agent should be a AgentExecutor
assert isinstance(agent, AgentExecutor)
# Now we test the complex example
assert isinstance(complex_graph, Graph)
# Now we test the build method
agent = complex_graph.build()
# The agent should be a AgentExecutor
assert isinstance(agent, AgentExecutor)
assert isinstance(result, AgentExecutor)
def test_agent_node_build(basic_graph):
@ -384,7 +362,6 @@ def test_agent_node_build(basic_graph):
assert agent_node is not None
built_object = agent_node.build()
assert built_object is not None
# Add any further assertions specific to the AgentNode's build() method
def test_tool_node_build(basic_graph):

View file

@ -1,4 +1,5 @@
import importlib
import re
from typing import Dict, List, Optional
import pytest