From 90569d73c8e498cccf16149737dcb9db6b169afb Mon Sep 17 00:00:00 2001 From: Gabriel Luiz Freitas Almeida Date: Fri, 1 Mar 2024 09:19:00 -0300 Subject: [PATCH] Fix import statements and formatting issues --- .../documentloaders/GatherRecords.py | 32 ++++++++++--- .../langflow/components/io/base/chat.py | 4 +- .../components/utilities/SharedState.py | 38 +++++++++++++++ .../components/vectorstores/ChromaSearch.py | 3 +- .../components/vectorstores/FAISSSearch.py | 4 +- .../vectorstores/SupabaseVectorStoreSearch.py | 4 +- .../components/vectorstores/VectaraSearch.py | 4 +- .../components/vectorstores/WeaviateSearch.py | 4 +- .../components/vectorstores/base/model.py | 3 +- .../components/vectorstores/pgvectorSearch.py | 4 +- src/backend/langflow/graph/vertex/types.py | 47 ++++++++++++++----- src/backend/langflow/memory.py | 3 +- .../langflow/services/monitor/schema.py | 12 +++-- 13 files changed, 132 insertions(+), 30 deletions(-) create mode 100644 src/backend/langflow/components/utilities/SharedState.py diff --git a/src/backend/langflow/components/documentloaders/GatherRecords.py b/src/backend/langflow/components/documentloaders/GatherRecords.py index 51a14dd31..ac298c092 100644 --- a/src/backend/langflow/components/documentloaders/GatherRecords.py +++ b/src/backend/langflow/components/documentloaders/GatherRecords.py @@ -70,11 +70,17 @@ class GatherRecordsComponent(CustomComponent): glob = "**/*" if recursive else "*" paths = walk_level(path_obj, depth) if depth else path_obj.glob(glob) - file_paths = [Text(p) for p in paths if p.is_file() and match_types(p) and is_not_hidden(p)] + file_paths = [ + Text(p) + for p in paths + if p.is_file() and match_types(p) and is_not_hidden(p) + ] return file_paths - def parse_file_to_record(self, file_path: str, silent_errors: bool) -> Optional[Record]: + def parse_file_to_record( + self, file_path: str, silent_errors: bool + ) -> Optional[Record]: # Use the partition function to load the file from unstructured.partition.auto import partition # type: ignore @@ -100,9 +106,14 @@ class GatherRecordsComponent(CustomComponent): use_multithreading: bool, ) -> List[Optional[Record]]: if use_multithreading: - records = self.parallel_load_records(file_paths, silent_errors, max_concurrency) + records = self.parallel_load_records( + file_paths, silent_errors, max_concurrency + ) else: - records = [self.parse_file_to_record(file_path, silent_errors) for file_path in file_paths] + records = [ + self.parse_file_to_record(file_path, silent_errors) + for file_path in file_paths + ] records = list(filter(None, records)) return records @@ -131,13 +142,20 @@ class GatherRecordsComponent(CustomComponent): if types is None: types = [] resolved_path = self.resolve_path(path) - file_paths = self.retrieve_file_paths(resolved_path, types, load_hidden, recursive, depth) + file_paths = self.retrieve_file_paths( + resolved_path, types, load_hidden, recursive, depth + ) loaded_records = [] if use_multithreading: - loaded_records = self.parallel_load_records(file_paths, silent_errors, max_concurrency) + loaded_records = self.parallel_load_records( + file_paths, silent_errors, max_concurrency + ) else: - loaded_records = [self.parse_file_to_record(file_path, silent_errors) for file_path in file_paths] + loaded_records = [ + self.parse_file_to_record(file_path, silent_errors) + for file_path in file_paths + ] loaded_records = list(filter(None, loaded_records)) self.status = loaded_records return loaded_records diff --git a/src/backend/langflow/components/io/base/chat.py b/src/backend/langflow/components/io/base/chat.py index db15223ee..8695dbcbc 100644 --- a/src/backend/langflow/components/io/base/chat.py +++ b/src/backend/langflow/components/io/base/chat.py @@ -45,7 +45,9 @@ class ChatComponent(CustomComponent): return [] if not session_id or not sender or not sender_name: - raise ValueError("All of session_id, sender, and sender_name must be provided.") + raise ValueError( + "All of session_id, sender, and sender_name must be provided." + ) if isinstance(message, Record): record = message record.data.update( diff --git a/src/backend/langflow/components/utilities/SharedState.py b/src/backend/langflow/components/utilities/SharedState.py new file mode 100644 index 000000000..7d29da9bb --- /dev/null +++ b/src/backend/langflow/components/utilities/SharedState.py @@ -0,0 +1,38 @@ +from typing import Union + +from langflow import CustomComponent +from langflow.field_typing import Text +from langflow.schema import Record + + +class SharedState(CustomComponent): + display_name = "Shared State" + description = "A component to share state between components." + + def build_config(self): + return { + "name": {"display_name": "Name", "info": "The name of the state."}, + "record": {"display_name": "Record", "info": "The record to store."}, + "append": { + "display_name": "Append", + "info": "If True, the record will be appended to the state.", + }, + } + + def build( + self, name: str, record: Union[Text, Record], append: bool = False + ) -> Record: + if append: + self.append_state(name, record) + else: + self.update_state(name, record) + + state = self.get_state(name) + if not isinstance(state, Record): + if isinstance(state, str): + state = Record(text=state) + elif isinstance(state, dict): + state = Record(data=state) + else: + state = Record(text=str(state)) + return state diff --git a/src/backend/langflow/components/vectorstores/ChromaSearch.py b/src/backend/langflow/components/vectorstores/ChromaSearch.py index 8584d7a96..3a6d283b3 100644 --- a/src/backend/langflow/components/vectorstores/ChromaSearch.py +++ b/src/backend/langflow/components/vectorstores/ChromaSearch.py @@ -93,7 +93,8 @@ class ChromaSearchComponent(LCVectorStoreComponent): if chroma_server_host is not None: chroma_settings = chromadb.config.Settings( - chroma_server_cors_allow_origins=chroma_server_cors_allow_origins or None, + chroma_server_cors_allow_origins=chroma_server_cors_allow_origins + or None, chroma_server_host=chroma_server_host, chroma_server_port=chroma_server_port or None, chroma_server_grpc_port=chroma_server_grpc_port or None, diff --git a/src/backend/langflow/components/vectorstores/FAISSSearch.py b/src/backend/langflow/components/vectorstores/FAISSSearch.py index 2b7b5e633..f6ddf4f7a 100644 --- a/src/backend/langflow/components/vectorstores/FAISSSearch.py +++ b/src/backend/langflow/components/vectorstores/FAISSSearch.py @@ -34,7 +34,9 @@ class FAISSSearchComponent(LCVectorStoreComponent): if not folder_path: raise ValueError("Folder path is required to save the FAISS index.") path = self.resolve_path(folder_path) - vector_store = FAISS.load_local(folder_path=Text(path), embeddings=embedding, index_name=index_name) + vector_store = FAISS.load_local( + folder_path=Text(path), embeddings=embedding, index_name=index_name + ) if not vector_store: raise ValueError("Failed to load the FAISS index.") diff --git a/src/backend/langflow/components/vectorstores/SupabaseVectorStoreSearch.py b/src/backend/langflow/components/vectorstores/SupabaseVectorStoreSearch.py index ca8113c56..5fd4dbd18 100644 --- a/src/backend/langflow/components/vectorstores/SupabaseVectorStoreSearch.py +++ b/src/backend/langflow/components/vectorstores/SupabaseVectorStoreSearch.py @@ -38,7 +38,9 @@ class SupabaseSearchComponent(LCVectorStoreComponent): supabase_url: str = "", table_name: str = "", ) -> List[Record]: - supabase: Client = create_client(supabase_url, supabase_key=supabase_service_key) + supabase: Client = create_client( + supabase_url, supabase_key=supabase_service_key + ) vector_store = SupabaseVectorStore( client=supabase, embedding=embedding, diff --git a/src/backend/langflow/components/vectorstores/VectaraSearch.py b/src/backend/langflow/components/vectorstores/VectaraSearch.py index c676cb538..ae2d442be 100644 --- a/src/backend/langflow/components/vectorstores/VectaraSearch.py +++ b/src/backend/langflow/components/vectorstores/VectaraSearch.py @@ -11,7 +11,9 @@ from langflow.schema import Record class VectaraSearchComponent(VectaraComponent, LCVectorStoreComponent): display_name: str = "Vectara Search" description: str = "Search a Vectara Vector Store for similar documents." - documentation = "https://python.langchain.com/docs/integrations/vectorstores/vectara" + documentation = ( + "https://python.langchain.com/docs/integrations/vectorstores/vectara" + ) beta = True icon = "Vectara" diff --git a/src/backend/langflow/components/vectorstores/WeaviateSearch.py b/src/backend/langflow/components/vectorstores/WeaviateSearch.py index 6d33755a8..6eee202c9 100644 --- a/src/backend/langflow/components/vectorstores/WeaviateSearch.py +++ b/src/backend/langflow/components/vectorstores/WeaviateSearch.py @@ -11,7 +11,9 @@ from langflow.schema import Record class WeaviateSearchVectorStore(WeaviateVectorStoreComponent, LCVectorStoreComponent): display_name: str = "Weaviate Search" description: str = "Search a Weaviate Vector Store for similar documents." - documentation = "https://python.langchain.com/docs/integrations/vectorstores/weaviate" + documentation = ( + "https://python.langchain.com/docs/integrations/vectorstores/weaviate" + ) beta = True icon = "Weaviate" diff --git a/src/backend/langflow/components/vectorstores/base/model.py b/src/backend/langflow/components/vectorstores/base/model.py index 2bc766b8b..6c2c7d453 100644 --- a/src/backend/langflow/components/vectorstores/base/model.py +++ b/src/backend/langflow/components/vectorstores/base/model.py @@ -6,7 +6,8 @@ from langchain_core.vectorstores import VectorStore from langflow import CustomComponent from langflow.field_typing import Text -from langflow.schema import Record, docs_to_records +from langflow.helpers.record import docs_to_records +from langflow.schema import Record class LCVectorStoreComponent(CustomComponent): diff --git a/src/backend/langflow/components/vectorstores/pgvectorSearch.py b/src/backend/langflow/components/vectorstores/pgvectorSearch.py index 04666fe74..f40e5ed26 100644 --- a/src/backend/langflow/components/vectorstores/pgvectorSearch.py +++ b/src/backend/langflow/components/vectorstores/pgvectorSearch.py @@ -15,7 +15,9 @@ class PGVectorSearchComponent(PGVectorComponent, LCVectorStoreComponent): display_name: str = "PGVector Search" description: str = "Search a PGVector Store for similar documents." - documentation = "https://python.langchain.com/docs/integrations/vectorstores/pgvector" + documentation = ( + "https://python.langchain.com/docs/integrations/vectorstores/pgvector" + ) def build_config(self): """ diff --git a/src/backend/langflow/graph/vertex/types.py b/src/backend/langflow/graph/vertex/types.py index c4f33df40..f49bbdf0a 100644 --- a/src/backend/langflow/graph/vertex/types.py +++ b/src/backend/langflow/graph/vertex/types.py @@ -123,9 +123,11 @@ class DocumentLoaderVertex(StatefulVertex): # show how many documents are in the list? if not isinstance(self._built_object, UnbuiltObject): - avg_length = sum(len(doc.page_content) for doc in self._built_object if hasattr(doc, "page_content")) / len( - self._built_object - ) + avg_length = sum( + len(doc.page_content) + for doc in self._built_object + if hasattr(doc, "page_content") + ) / len(self._built_object) return f"""{self.display_name}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} Documents: {self._built_object[:3]}...""" @@ -198,7 +200,9 @@ class TextSplitterVertex(StatefulVertex): # show how many documents are in the list? if not isinstance(self._built_object, UnbuiltObject): - avg_length = sum(len(doc.page_content) for doc in self._built_object) / len(self._built_object) + avg_length = sum(len(doc.page_content) for doc in self._built_object) / len( + self._built_object + ) return f"""{self.vertex_type}({len(self._built_object)} documents) \nAvg. Document Length (characters): {int(avg_length)} \nDocuments: {self._built_object[:3]}...""" @@ -245,18 +249,27 @@ class PromptVertex(StatelessVertex): user_id = kwargs.get("user_id", None) tools = kwargs.get("tools", []) if not self._built or force: - if "input_variables" not in self.params or self.params["input_variables"] is None: + if ( + "input_variables" not in self.params + or self.params["input_variables"] is None + ): self.params["input_variables"] = [] # Check if it is a ZeroShotPrompt and needs a tool if "ShotPrompt" in self.vertex_type: - tools = [tool_node.build(user_id=user_id) for tool_node in tools] if tools is not None else [] + tools = ( + [tool_node.build(user_id=user_id) for tool_node in tools] + if tools is not None + else [] + ) # flatten the list of tools if it is a list of lists # first check if it is a list if tools and isinstance(tools, list) and isinstance(tools[0], list): tools = flatten_list(tools) self.params["tools"] = tools prompt_params = [ - key for key, value in self.params.items() if isinstance(value, str) and key != "format_instructions" + key + for key, value in self.params.items() + if isinstance(value, str) and key != "format_instructions" ] else: prompt_params = ["template"] @@ -266,14 +279,20 @@ class PromptVertex(StatelessVertex): prompt_text = self.params[param] variables = extract_input_variables_from_prompt(prompt_text) self.params["input_variables"].extend(variables) - self.params["input_variables"] = list(set(self.params["input_variables"])) + self.params["input_variables"] = list( + set(self.params["input_variables"]) + ) elif isinstance(self.params, dict): self.params.pop("input_variables", None) await self._build(user_id=user_id) def _built_object_repr(self): - if not self.artifacts or self._built_object is None or not hasattr(self._built_object, "format"): + if ( + not self.artifacts + or self._built_object is None + or not hasattr(self._built_object, "format") + ): return super()._built_object_repr() elif isinstance(self._built_object, UnbuiltObject): return super()._built_object_repr() @@ -285,7 +304,9 @@ class PromptVertex(StatelessVertex): # so the prompt format doesn't break artifacts.pop("handle_keys", None) try: - if not hasattr(self._built_object, "template") and hasattr(self._built_object, "prompt"): + if not hasattr(self._built_object, "template") and hasattr( + self._built_object, "prompt" + ): template = self._built_object.prompt.template else: template = self._built_object.template @@ -293,7 +314,11 @@ class PromptVertex(StatelessVertex): if value: replace_key = "{" + key + "}" template = template.replace(replace_key, value) - return template if isinstance(template, str) else f"{self.vertex_type}({template})" + return ( + template + if isinstance(template, str) + else f"{self.vertex_type}({template})" + ) except KeyError: return str(self._built_object) diff --git a/src/backend/langflow/memory.py b/src/backend/langflow/memory.py index dbf46b24d..c8f73f25e 100644 --- a/src/backend/langflow/memory.py +++ b/src/backend/langflow/memory.py @@ -1,9 +1,10 @@ from typing import Optional, Union +from loguru import logger + from langflow.schema import Record from langflow.services.deps import get_monitor_service from langflow.services.monitor.schema import MessageModel -from loguru import logger def get_messages( diff --git a/src/backend/langflow/services/monitor/schema.py b/src/backend/langflow/services/monitor/schema.py index d4293ecbf..2c1e34cd5 100644 --- a/src/backend/langflow/services/monitor/schema.py +++ b/src/backend/langflow/services/monitor/schema.py @@ -10,7 +10,9 @@ if TYPE_CHECKING: class TransactionModel(BaseModel): id: Optional[int] = Field(default=None, alias="id") - timestamp: Optional[datetime] = Field(default_factory=datetime.now, alias="timestamp") + timestamp: Optional[datetime] = Field( + default_factory=datetime.now, alias="timestamp" + ) source: str target: str target_args: dict @@ -51,8 +53,12 @@ class MessageModel(BaseModel): @classmethod def from_record(cls, record: "Record"): # first check if the record has all the required fields - if not record.data or ("sender" not in record.data and "sender_name" not in record.data): - raise ValueError("The record does not have the required fields 'sender' and 'sender_name' in the data.") + if not record.data or ( + "sender" not in record.data and "sender_name" not in record.data + ): + raise ValueError( + "The record does not have the required fields 'sender' and 'sender_name' in the data." + ) return cls( sender=record.data["sender"], sender_name=record.data["sender_name"],