Fix mark_branch function and refactor build_and_cache_graph_from_db function (#1833)
* Fix mark_branch function in Graph class to properly handle visited vertices * Refactor build_and_cache_graph_from_db function in utils.py to simplify code and remove unnecessary parameter * Fix retrieval of graph from cache in retrieve_vertices_order function * Fix possible_id type annotation in get_id_from_search_string function * Refactor APIRequest class to improve variable naming and remove unnecessary underscore prefix in headers parameter * Refactor buildVertices function in buildUtils.ts to improve code readability and remove unnecessary variable assignment * Fix API endpoints in test_endpoints.py to use correct HTTP methods
This commit is contained in:
parent
2dc4ffd99f
commit
92d39a6500
7 changed files with 31 additions and 38 deletions
|
|
@ -205,17 +205,12 @@ async def build_and_cache_graph_from_db(
|
|||
flow_id: str,
|
||||
session: Session,
|
||||
chat_service: "ChatService",
|
||||
graph: Optional[Graph] = None,
|
||||
):
|
||||
"""Build and cache the graph."""
|
||||
flow: Optional[Flow] = session.get(Flow, flow_id)
|
||||
if not flow or not flow.data:
|
||||
raise ValueError("Invalid flow ID")
|
||||
other_graph = Graph.from_payload(flow.data, flow_id)
|
||||
if graph is None:
|
||||
graph = other_graph
|
||||
else:
|
||||
graph = graph.update(other_graph)
|
||||
graph = Graph.from_payload(flow.data, flow_id)
|
||||
await chat_service.set_cache(flow_id, graph)
|
||||
return graph
|
||||
|
||||
|
|
|
|||
|
|
@ -79,13 +79,8 @@ async def retrieve_vertices_order(
|
|||
"""
|
||||
try:
|
||||
# First, we need to check if the flow_id is in the cache
|
||||
graph = None
|
||||
if not data:
|
||||
if cache := await chat_service.get_cache(flow_id):
|
||||
graph = cache.get("result")
|
||||
graph = await build_and_cache_graph_from_db(
|
||||
flow_id=flow_id, session=session, chat_service=chat_service, graph=graph
|
||||
)
|
||||
graph = await build_and_cache_graph_from_db(flow_id=flow_id, session=session, chat_service=chat_service)
|
||||
else:
|
||||
graph = await build_and_cache_graph_from_data(
|
||||
flow_id=flow_id, graph_data=data.model_dump(), chat_service=chat_service
|
||||
|
|
|
|||
|
|
@ -98,9 +98,9 @@ class APIRequest(CustomComponent):
|
|||
timeout: int = 5,
|
||||
) -> List[Record]:
|
||||
if headers is None:
|
||||
headers = {}
|
||||
headers_dict = {}
|
||||
else:
|
||||
headers = headers.data
|
||||
headers_dict = headers.data
|
||||
|
||||
bodies = []
|
||||
if body:
|
||||
|
|
@ -114,7 +114,7 @@ class APIRequest(CustomComponent):
|
|||
bodies += [None] * (len(urls) - len(bodies)) # type: ignore
|
||||
async with httpx.AsyncClient() as client:
|
||||
results = await asyncio.gather(
|
||||
*[self.make_request(client, method, u, headers, rec, timeout) for u, rec in zip(urls, bodies)]
|
||||
*[self.make_request(client, method, u, headers_dict, rec, timeout) for u, rec in zip(urls, bodies)]
|
||||
)
|
||||
self.status = results
|
||||
return results
|
||||
|
|
|
|||
|
|
@ -445,9 +445,16 @@ class Graph:
|
|||
vertex = self.get_vertex(vertex_id)
|
||||
vertex.set_state(state)
|
||||
|
||||
def mark_branch(self, vertex_id: str, state: str):
|
||||
def mark_branch(self, vertex_id: str, state: str, visited: Optional[set] = None):
|
||||
"""Marks a branch of the graph."""
|
||||
if visited is None:
|
||||
visited = set()
|
||||
visited.add(vertex_id)
|
||||
if vertex_id in visited:
|
||||
return
|
||||
|
||||
self.mark_vertex(vertex_id, state)
|
||||
|
||||
for child_id in self.parent_child_map[vertex_id]:
|
||||
self.mark_branch(child_id, state)
|
||||
|
||||
|
|
|
|||
|
|
@ -59,7 +59,7 @@ def get_id_from_search_string(search_string: str) -> Optional[str]:
|
|||
Returns:
|
||||
Optional[str]: The extracted ID, or None if no ID is found.
|
||||
"""
|
||||
possible_id = search_string
|
||||
possible_id: Optional[str] = search_string
|
||||
if "www.langflow.store/store/" in search_string:
|
||||
possible_id = search_string.split("/")[-1]
|
||||
|
||||
|
|
|
|||
|
|
@ -116,33 +116,29 @@ export async function buildVertices({
|
|||
nodes,
|
||||
edges,
|
||||
}: BuildVerticesParams) {
|
||||
let verticesBuild = useFlowStore.getState().verticesBuild;
|
||||
// if startNodeId and stopNodeId are provided
|
||||
// something is wrong
|
||||
if (startNodeId && stopNodeId) {
|
||||
return;
|
||||
}
|
||||
let verticesOrderResponse = await updateVerticesOrder(
|
||||
flowId,
|
||||
startNodeId,
|
||||
stopNodeId,
|
||||
nodes,
|
||||
edges
|
||||
);
|
||||
if (onValidateNodes) {
|
||||
try {
|
||||
onValidateNodes(verticesOrderResponse.verticesToRun);
|
||||
} catch (e) {
|
||||
useFlowStore.getState().setIsBuilding(false);
|
||||
|
||||
if (!verticesBuild || startNodeId || stopNodeId) {
|
||||
let verticesOrderResponse = await updateVerticesOrder(
|
||||
flowId,
|
||||
startNodeId,
|
||||
stopNodeId,
|
||||
nodes,
|
||||
edges
|
||||
);
|
||||
if (onValidateNodes) {
|
||||
try {
|
||||
onValidateNodes(verticesOrderResponse.verticesToRun);
|
||||
} catch (e) {
|
||||
useFlowStore.getState().setIsBuilding(false);
|
||||
|
||||
return;
|
||||
}
|
||||
return;
|
||||
}
|
||||
if (onGetOrderSuccess) onGetOrderSuccess();
|
||||
verticesBuild = useFlowStore.getState().verticesBuild;
|
||||
}
|
||||
if (onGetOrderSuccess) onGetOrderSuccess();
|
||||
let verticesBuild = useFlowStore.getState().verticesBuild;
|
||||
|
||||
const verticesIds = verticesBuild?.verticesIds!;
|
||||
const verticesLayers = verticesBuild?.verticesLayers!;
|
||||
|
|
|
|||
|
|
@ -393,13 +393,13 @@ def test_various_prompts(client, prompt, expected_input_variables):
|
|||
|
||||
|
||||
def test_get_vertices_flow_not_found(client, logged_in_headers):
|
||||
response = client.get("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers)
|
||||
response = client.post("/api/v1/build/nonexistent_id/vertices", headers=logged_in_headers)
|
||||
assert response.status_code == 500 # Or whatever status code you've set for invalid ID
|
||||
|
||||
|
||||
def test_get_vertices(client, added_flow_with_prompt_and_history, logged_in_headers):
|
||||
flow_id = added_flow_with_prompt_and_history["id"]
|
||||
response = client.get(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
|
||||
response = client.post(f"/api/v1/build/{flow_id}/vertices", headers=logged_in_headers)
|
||||
assert response.status_code == 200
|
||||
assert "ids" in response.json()
|
||||
# The response should contain the list in this order
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue