style: formatting and linting

This commit is contained in:
Gabriel Almeida 2023-03-28 21:44:41 -03:00
commit 7ec6847ae4
7 changed files with 30 additions and 31 deletions

View file

@ -1,5 +1,4 @@
from pydantic import BaseModel, validator
import json
class Code(BaseModel):

View file

@ -1,8 +1,7 @@
from typing import Any, Dict
from fastapi import APIRouter, HTTPException, Response
from fastapi import APIRouter, HTTPException
from langflow.api.base import Code, ValidationResponse
from langflow.interface.custom_types import PythonFunction
from langflow.interface.run import process_graph
from langflow.interface.types import build_langchain_types_dict

View file

@ -240,11 +240,25 @@ class Edge:
)
class ToolNode(Node):
def __init__(self, data: Dict):
super().__init__(data)
def build(self, force: bool = False) -> Any:
if not self._built or force:
self._build()
return deepcopy(self._built_object)
class PromptNode(Node):
def __init__(self, data: Dict):
super().__init__(data)
def build(self, tools: Optional[List[Node]] = None, force: bool = False) -> Any:
def build(
self,
force: bool = False,
tools: Optional[List[Node]] | Optional[List[ToolNode]] = None,
) -> Any:
if not self._built or force:
# Check if it is a ZeroShotPrompt and needs a tool
if self.node_type == "ZeroShotPrompt":
@ -259,27 +273,21 @@ class PromptNode(Node):
return deepcopy(self._built_object)
class ToolNode(Node):
def __init__(self, data: Dict):
super().__init__(data)
def build(self, force: bool = False) -> Any:
if not self._built or force:
self._build()
return deepcopy(self._built_object)
class ChainNode(Node):
def __init__(self, data: Dict):
super().__init__(data)
def build(self, tools: Optional[List[Node]] = None, force: bool = False) -> Any:
def build(
self,
force: bool = False,
tools: Optional[List[Node]] | Optional[List[ToolNode]] = None,
) -> Any:
if not self._built or force:
# Check if the chain requires a PromptNode
for key, value in self.params.items():
if isinstance(value, PromptNode):
# Build the PromptNode, passing the tools if available
self.params[key] = value.build(tools=tools or [], force=force)
self.params[key] = value.build(tools=tools, force=force)
self._build()
return deepcopy(self._built_object)
@ -351,13 +359,13 @@ class Graph:
return edges
def _build_nodes(self) -> List[Node]:
nodes = []
nodes: List[Node] = []
for node in self._nodes:
node_data = node["data"]
node_type = node_data["type"]
node_lc_type = node_data["node"]["template"]["_type"]
node_type: str = node_data["type"] # type: ignore
node_lc_type: str = node_data["node"]["template"]["_type"] # type: ignore
if node_type in ["ZeroShotPrompt", "PromptTemplate"]:
if node_type in {"ZeroShotPrompt", "PromptTemplate"}:
nodes.append(PromptNode(node))
elif "agent" in node_type.lower():
nodes.append(AgentNode(node))

View file

@ -3,7 +3,6 @@ import importlib
import inspect
import re
from typing import Dict, Optional
import types
from langchain.agents.load_tools import (
_BASE_TOOLS,

View file

@ -101,9 +101,8 @@ def execute_function(code, function_name, *args, **kwargs):
)
try:
exec(code_obj, exec_globals, locals())
except Exception as e:
# handle execution error here
pass
except Exception as exc:
raise ValueError("Function string does not contain a function") from exc
# Add the function to the exec_globals dictionary
exec_globals[function_name] = locals()[function_name]
@ -145,7 +144,7 @@ def create_function(code, function_name):
)
try:
exec(code_obj, exec_globals, locals())
except Exception as e:
except Exception:
pass
exec_globals[function_name] = locals()[function_name]

View file

@ -1,9 +1,6 @@
# Test this:
from typing import Optional, Callable
from langflow.interface.custom_types import PythonFunction
from langflow.utils import util, constants
from pydantic import BaseModel, ValidationError, validator
from langchain.agents import tool
from langflow.utils import constants
import pytest

View file

@ -1,7 +1,5 @@
from langflow.interface.listing import CUSTOM_TOOLS
from langflow.utils.constants import DEFAULT_PYTHON_FUNCTION
from fastapi.testclient import TestClient
import pytest
def test_get_all(client: TestClient):