Fixes SQLDatabaseChain import and deactivates pickle for local cache (#976)

This commit is contained in:
Gabriel Luiz Freitas Almeida 2023-09-27 13:04:49 -03:00 committed by GitHub
commit 7498b85cf9
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
5 changed files with 20 additions and 9 deletions

View file

@ -185,9 +185,10 @@ class Vertex:
# Load the type in value.get('suffixes') using
# what is inside value.get('content')
# value.get('value') is the file name
file_path = value.get("file_path")
params[key] = file_path
if file_path := value.get("file_path"):
params[key] = file_path
else:
raise ValueError(f"File path not found for {self.vertex_type}")
elif value.get("type") in DIRECT_TYPES and params.get(key) is None:
if value.get("type") == "code":
try:

View file

@ -144,6 +144,8 @@ def import_chain(chain: str) -> Type[Chain]:
if chain in CUSTOM_CHAINS:
return CUSTOM_CHAINS[chain]
if chain == "SQLDatabaseChain":
return import_class("langchain_experimental.sql.SQLDatabaseChain")
return import_class(f"langchain.chains.{chain}")

View file

@ -74,13 +74,17 @@ class InMemoryCache(BaseCacheService, Service):
):
# Move the key to the end to make it recently used
self._cache.move_to_end(key)
unpickled = pickle.loads(item["value"])
return unpickled
# Check if the value is pickled
if isinstance(item["value"], bytes):
value = pickle.loads(item["value"])
else:
value = item["value"]
return value
else:
self.delete(key)
return None
def set(self, key, value):
def set(self, key, value, pickle=False):
"""
Add an item to the cache.
@ -98,8 +102,10 @@ class InMemoryCache(BaseCacheService, Service):
# Remove least recently used item
self._cache.popitem(last=False)
# pickle locally to mimic Redis
pickled = pickle.dumps(value)
self._cache[key] = {"value": pickled, "time": time.time()}
if pickle:
value = pickle.dumps(value)
self._cache[key] = {"value": value, "time": time.time()}
def upsert(self, key, value):
"""

View file

@ -207,5 +207,7 @@ def save_uploaded_file(file: UploadFile, folder_name):
def update_build_status(cache_service, flow_id: str, status: BuildStatus):
cached_flow = cache_service[flow_id]
if cached_flow is None:
raise ValueError(f"Flow {flow_id} not found in cache")
cached_flow["status"] = status
cache_service[flow_id] = cached_flow

View file

@ -171,7 +171,7 @@ class DocumentLoaderFrontNode(FrontendNode):
self.template.add_field(
TemplateField(
field_type="dict",
required=True,
required=False,
show=True,
name="metadata",
value={},