perf: improve /logs concurrent access (#2633)

* perf: improve /logs concurrent access

* perf: improve /logs concurrent access

* fix

* [autofix.ci] apply automated fixes

---------

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
This commit is contained in:
Nicolò Boschi 2024-07-11 17:06:34 +02:00 committed by GitHub
commit ab0ab5f306
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 125 additions and 101 deletions

View file

@ -1,5 +1,7 @@
import asyncio
import json
from typing import List, Any
from fastapi import APIRouter, Query, HTTPException, Request
from fastapi.responses import JSONResponse, StreamingResponse
from http import HTTPStatus
@ -9,26 +11,38 @@ log_router = APIRouter(tags=["Log"])
async def event_generator(request: Request):
# latest_timestamp = time.time()
global log_buffer
last_read_item = None
current_not_sent = 0
while not await request.is_disconnected():
to_write: List[Any] = []
with log_buffer.get_write_lock():
if last_read_item is None:
last_read_item = log_buffer.buffer[len(log_buffer.buffer) - 1]
else:
found_last = False
for item in log_buffer.buffer:
if found_last:
to_write.append(item)
last_read_item = item
continue
if item is last_read_item:
found_last = True
continue
last_line = log_buffer.get_last_n(1)
latest_timestamp, _ = last_line.popitem()
while True:
if await request.is_disconnected():
break
new_logs = log_buffer.get_after_timestamp(timestamp=latest_timestamp, lines=100)
if new_logs:
temp_ts = 0.0
for ts, msg in new_logs.items():
if ts > latest_timestamp:
yield f"{json.dumps({ts:msg})}\n"
temp_ts = ts
# for the next query iteration
latest_timestamp = temp_ts
# in case the last item is nomore in the buffer
if not found_last:
for item in log_buffer.buffer:
to_write.append(item)
last_read_item = item
if to_write:
for ts, msg in to_write:
yield f"{json.dumps({ts:msg})}\n\n"
else:
yield ": keepalive\n\n"
current_not_sent += 1
if current_not_sent == 5:
current_not_sent = 0
yield "keepalive\n\n"
await asyncio.sleep(1)
@ -54,9 +68,9 @@ async def stream_logs(
@log_router.get("/logs")
async def logs(
lines_before: int = Query(1, ge=1, description="The number of logs before the timestamp or the last log"),
lines_after: int = Query(0, ge=1, description="The number of logs after the timestamp"),
timestamp: float = Query(0, description="The timestamp to start streaming logs from"),
lines_before: int = Query(0, description="The number of logs before the timestamp or the last log"),
lines_after: int = Query(0, description="The number of logs after the timestamp"),
timestamp: int = Query(0, description="The timestamp to start getting logs from"),
):
global log_buffer
if log_buffer.enabled() is False:
@ -64,23 +78,26 @@ async def logs(
status_code=HTTPStatus.NOT_IMPLEMENTED,
detail="Log retrieval is disabled",
)
logs = dict()
if lines_after > 0 and timestamp == 0:
if lines_after > 0 and lines_before > 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Timestamp is required when requesting logs after the timestamp",
detail="Cannot request logs before and after the timestamp",
)
if lines_after > 0 and timestamp > 0:
logs = log_buffer.get_after_timestamp(timestamp=timestamp, lines=lines_after)
return JSONResponse(content=logs)
if timestamp == 0:
if lines_before > 0:
logs = log_buffer.get_last_n(lines_before)
if timestamp <= 0:
if lines_after > 0:
raise HTTPException(
status_code=HTTPStatus.BAD_REQUEST,
detail="Timestamp is required when requesting logs after the timestamp",
)
if lines_before <= 0:
content = log_buffer.get_last_n(10)
else:
content = log_buffer.get_last_n(lines_before)
else:
if lines_before > 0:
logs = log_buffer.get_before_timestamp(timestamp=timestamp, lines=lines_before)
return JSONResponse(content=logs)
content = log_buffer.get_before_timestamp(timestamp=timestamp, lines=lines_before)
elif lines_after > 0:
content = log_buffer.get_after_timestamp(timestamp=timestamp, lines=lines_after)
else:
content = log_buffer.get_before_timestamp(timestamp=timestamp, lines=10)
return JSONResponse(content=content)

View file

@ -3,8 +3,7 @@ import logging
import os
import sys
from pathlib import Path
from collections import OrderedDict
from itertools import islice
from collections import deque
from threading import Lock, Semaphore
from typing import Optional
@ -31,71 +30,81 @@ class SizedLogBuffer:
if env_buffer_size.isdigit():
self.max = int(env_buffer_size)
self.buffer: OrderedDict[float, str] = OrderedDict()
self.buffer: deque = deque()
self._max_readers = max_readers
self._wlock = Lock()
self._rsemaphore = Semaphore(max_readers)
def get_write_lock(self) -> Lock:
return self._wlock
def write(self, message: str):
record = json.loads(message)
log_entry = record["text"]
epoch = record["record"]["time"]["timestamp"]
# wait until all reader semaphore are released
while self._rsemaphore._value != self._max_readers:
continue
epoch = int(record["record"]["time"]["timestamp"] * 1000)
with self._wlock:
if len(self.buffer) >= self.max:
# remove the oldest log entry if the buffer is full
self.buffer.popitem(last=False)
self.buffer[epoch] = log_entry
for _ in range(len(self.buffer) - self.max + 1):
self.buffer.popleft()
self.buffer.append((epoch, log_entry))
def __len__(self):
return len(self.buffer)
def get_after_timestamp(self, timestamp: float, lines: int = 5) -> dict[float, str]:
def get_after_timestamp(self, timestamp: int, lines: int = 5) -> dict[int, str]:
rc = dict()
# wait until no write
while self._wlock.locked():
continue
self._rsemaphore.acquire()
for ts, msg in self.buffer.items():
if lines == 0:
break
if ts >= timestamp and lines > 0:
try:
with self._wlock:
for ts, msg in self.buffer:
if lines == 0:
break
if ts >= timestamp and lines > 0:
rc[ts] = msg
lines -= 1
finally:
self._rsemaphore.release()
return rc
def get_before_timestamp(self, timestamp: int, lines: int = 5) -> dict[int, str]:
self._rsemaphore.acquire()
try:
with self._wlock:
as_list = list(self.buffer)
i = 0
max_index = -1
for ts, msg in as_list:
if ts >= timestamp:
max_index = i
break
i += 1
if max_index == -1:
return self.get_last_n(lines)
rc = {}
i = 0
start_from = max(max_index - lines, 0)
for ts, msg in as_list:
if start_from <= i < max_index:
rc[ts] = msg
i += 1
return rc
finally:
self._rsemaphore.release()
def get_last_n(self, last_idx: int) -> dict[int, str]:
self._rsemaphore.acquire()
try:
with self._wlock:
as_list = list(self.buffer)
rc = {}
for ts, msg in as_list[-last_idx:]:
rc[ts] = msg
lines -= 1
self._rsemaphore.release()
return rc
def get_before_timestamp(self, timestamp: float, lines: int = 5) -> dict[float, str]:
rc = dict()
# wait until no write
while self._wlock.locked():
continue
self._rsemaphore.acquire()
for ts, msg in reversed(self.buffer.items()):
if lines == 0:
break
if ts < timestamp and lines > 0:
rc[ts] = msg
lines -= 1
self._rsemaphore.release()
return rc
def get_last_n(self, last_idx: int) -> dict[float, str]:
# wait until no write
while self._wlock.locked():
continue
self._rsemaphore.acquire()
rc = dict(islice(reversed(self.buffer.items()), last_idx))
self._rsemaphore.release()
return rc
return rc
finally:
self._rsemaphore.release()
def enabled(self) -> bool:
return self.max > 0

View file

@ -1,9 +1,8 @@
import pytest
import os
import json
from collections import OrderedDict
from unittest.mock import patch
from langflow.utils.logger import SizedLogBuffer # Replace 'your_module' with the actual module name
from langflow.utils.logger import SizedLogBuffer
@pytest.fixture
@ -15,7 +14,6 @@ def test_init_default():
buffer = SizedLogBuffer()
assert buffer.max == 0
assert buffer._max_readers == 20
assert isinstance(buffer.buffer, OrderedDict)
def test_init_with_env_variable():
@ -25,12 +23,12 @@ def test_init_with_env_variable():
def test_write(sized_log_buffer):
message = json.dumps({"text": "Test log", "record": {"time": {"timestamp": 1625097600}}})
message = json.dumps({"text": "Test log", "record": {"time": {"timestamp": 1625097600.1244334}}})
sized_log_buffer.max = 1 # Set max size to 1 for testing
sized_log_buffer.write(message)
assert len(sized_log_buffer.buffer) == 1
assert 1625097600 in sized_log_buffer.buffer
assert sized_log_buffer.buffer[1625097600] == "Test log"
assert 1625097600124 == sized_log_buffer.buffer[0][0]
assert "Test log" == sized_log_buffer.buffer[0][1]
def test_write_overflow(sized_log_buffer):
@ -40,8 +38,8 @@ def test_write_overflow(sized_log_buffer):
sized_log_buffer.write(message)
assert len(sized_log_buffer.buffer) == 2
assert 1625097601 in sized_log_buffer.buffer
assert 1625097602 in sized_log_buffer.buffer
assert 1625097601000 == sized_log_buffer.buffer[0][0]
assert 1625097602000 == sized_log_buffer.buffer[1][0]
def test_len(sized_log_buffer):
@ -59,10 +57,10 @@ def test_get_after_timestamp(sized_log_buffer):
for message in messages:
sized_log_buffer.write(message)
result = sized_log_buffer.get_after_timestamp(1625097602, lines=2)
result = sized_log_buffer.get_after_timestamp(1625097602000, lines=2)
assert len(result) == 2
assert 1625097603 in result
assert 1625097602 in result
assert 1625097603000 in result
assert 1625097602000 in result
def test_get_before_timestamp(sized_log_buffer):
@ -71,10 +69,10 @@ def test_get_before_timestamp(sized_log_buffer):
for message in messages:
sized_log_buffer.write(message)
result = sized_log_buffer.get_before_timestamp(1625097603, lines=2)
result = sized_log_buffer.get_before_timestamp(1625097603000, lines=2)
assert len(result) == 2
assert 1625097601 in result
assert 1625097602 in result
assert 1625097601000 in result
assert 1625097602000 in result
def test_get_last_n(sized_log_buffer):
@ -85,9 +83,9 @@ def test_get_last_n(sized_log_buffer):
result = sized_log_buffer.get_last_n(3)
assert len(result) == 3
assert 1625097602 in result
assert 1625097603 in result
assert 1625097604 in result
assert 1625097602000 in result
assert 1625097603000 in result
assert 1625097604000 in result
def test_enabled(sized_log_buffer):