Fix mark_branch function and refactor build_and_cache_graph_from_db function (#1833)

* Fix mark_branch function in Graph class to properly handle visited vertices

* Refactor build_and_cache_graph_from_db function in utils.py to simplify code and remove unnecessary parameter

* Fix retrieval of graph from cache in retrieve_vertices_order function

* Fix possible_id type annotation in get_id_from_search_string function

* Refactor APIRequest class to improve variable naming and remove unnecessary underscore prefix in headers parameter

* Refactor buildVertices function in buildUtils.ts to improve code readability and remove unnecessary variable assignment

* Fix API endpoints in test_endpoints.py to use correct HTTP methods
This commit is contained in:
Gabriel Luiz Freitas Almeida 2024-05-03 18:23:41 -03:00 committed by GitHub
commit 92d39a6500
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 31 additions and 38 deletions

View file

@ -205,17 +205,12 @@ async def build_and_cache_graph_from_db(
flow_id: str,
session: Session,
chat_service: "ChatService",
graph: Optional[Graph] = None,
):
"""Build and cache the graph."""
flow: Optional[Flow] = session.get(Flow, flow_id)
if not flow or not flow.data:
raise ValueError("Invalid flow ID")
other_graph = Graph.from_payload(flow.data, flow_id)
if graph is None:
graph = other_graph
else:
graph = graph.update(other_graph)
graph = Graph.from_payload(flow.data, flow_id)
await chat_service.set_cache(flow_id, graph)
return graph

View file

@ -79,13 +79,8 @@ async def retrieve_vertices_order(
"""
try:
# First, we need to check if the flow_id is in the cache
graph = None
if not data:
if cache := await chat_service.get_cache(flow_id):
graph = cache.get("result")
graph = await build_and_cache_graph_from_db(
flow_id=flow_id, session=session, chat_service=chat_service, graph=graph
)
graph = await build_and_cache_graph_from_db(flow_id=flow_id, session=session, chat_service=chat_service)
else:
graph = await build_and_cache_graph_from_data(
flow_id=flow_id, graph_data=data.model_dump(), chat_service=chat_service

View file

@ -98,9 +98,9 @@ class APIRequest(CustomComponent):
timeout: int = 5,
) -> List[Record]:
if headers is None:
headers = {}
headers_dict = {}
else:
headers = headers.data
headers_dict = headers.data
bodies = []
if body:
@ -114,7 +114,7 @@ class APIRequest(CustomComponent):
bodies += [None] * (len(urls) - len(bodies)) # type: ignore
async with httpx.AsyncClient() as client:
results = await asyncio.gather(
*[self.make_request(client, method, u, headers, rec, timeout) for u, rec in zip(urls, bodies)]
*[self.make_request(client, method, u, headers_dict, rec, timeout) for u, rec in zip(urls, bodies)]
)
self.status = results
return results

View file

@ -445,9 +445,16 @@ class Graph:
vertex = self.get_vertex(vertex_id)
vertex.set_state(state)
def mark_branch(self, vertex_id: str, state: str):
def mark_branch(self, vertex_id: str, state: str, visited: Optional[set] = None):
"""Marks a branch of the graph."""
if visited is None:
visited = set()
visited.add(vertex_id)
if vertex_id in visited:
return
self.mark_vertex(vertex_id, state)
for child_id in self.parent_child_map[vertex_id]:
self.mark_branch(child_id, state)

View file

@ -59,7 +59,7 @@ def get_id_from_search_string(search_string: str) -> Optional[str]:
Returns:
Optional[str]: The extracted ID, or None if no ID is found.
"""
possible_id = search_string
possible_id: Optional[str] = search_string
if "www.langflow.store/store/" in search_string:
possible_id = search_string.split("/")[-1]

View file

@ -116,33 +116,29 @@ export async function buildVertices({
nodes,
edges,
}: BuildVerticesParams) {
let verticesBuild = useFlowStore.getState().verticesBuild;
// if startNodeId and stopNodeId are provided
// something is wrong
if (startNodeId && stopNodeId) {
return;
}
let verticesOrderResponse = await updateVerticesOrder(
flowId,
startNodeId,
stopNodeId,
nodes,
edges
);
if (onValidateNodes) {
try {
onValidateNodes(verticesOrderResponse.verticesToRun);
} catch (e) {
useFlowStore.getState().setIsBuilding(false);
if (!verticesBuild || startNodeId || stopNodeId) {
let verticesOrderResponse = await updateVerticesOrder(
flowId,
startNodeId,
stopNodeId,
nodes,
edges
);
if (onValidateNodes) {
try {
onValidateNodes(verticesOrderResponse.verticesToRun);
} catch (e) {
useFlowStore.getState().setIsBuilding(false);
return;
}
return;
}
if (onGetOrderSuccess) onGetOrderSuccess();
verticesBuild = useFlowStore.getState().verticesBuild;
}
if (onGetOrderSuccess) onGetOrderSuccess();
let verticesBuild = useFlowStore.getState().verticesBuild;
const verticesIds = verticesBuild?.verticesIds!;
const verticesLayers = verticesBuild?.verticesLayers!;

View file

@ -393,13 +393,13 @@ def test_various_prompts(client, prompt, expected_input_variables):
def test_get_vertices_flow_not_found(client, logged_in_headers):
response = client.get("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers)
response = client.post("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers)
assert response.status_code == 500 # Or whatever status code you've set for invalid ID
def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers):
flow_id = added_flow_with_prompt_and_history["id"]
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
response = client.post(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
assert response.status_code == 200
assert "ids" in response.json()
# The response should contain the list in this order