Update schema and utils modules
This commit is contained in:
parent
cff2563ee1
commit
33cca64f32
2 changed files with 49 additions and 19 deletions
|
|
@ -13,7 +13,7 @@ class Record(BaseModel):
|
|||
"""
|
||||
|
||||
data: dict = {}
|
||||
_default_value = None
|
||||
_default_value: str = ""
|
||||
|
||||
@classmethod
|
||||
def from_document(cls, document: Document) -> "Record":
|
||||
|
|
@ -63,7 +63,9 @@ class Record(BaseModel):
|
|||
return self.data.get(key, self._default_value)
|
||||
except KeyError:
|
||||
# Fallback to default behavior to raise AttributeError for undefined attributes
|
||||
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'")
|
||||
raise AttributeError(
|
||||
f"'{type(self).__name__}' object has no attribute '{key}'"
|
||||
)
|
||||
|
||||
def __setattr__(self, key, value):
|
||||
"""
|
||||
|
|
|
|||
|
|
@ -15,8 +15,12 @@ def remove_ansi_escape_codes(text):
|
|||
return re.sub(r"\x1b\[[0-9;]*[a-zA-Z]", "", text)
|
||||
|
||||
|
||||
def build_template_from_function(name: str, type_to_loader_dict: Dict, add_function: bool = False):
|
||||
classes = [item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()]
|
||||
def build_template_from_function(
|
||||
name: str, type_to_loader_dict: Dict, add_function: bool = False
|
||||
):
|
||||
classes = [
|
||||
item.__annotations__["return"].__name__ for item in type_to_loader_dict.values()
|
||||
]
|
||||
|
||||
# Raise error if name is not in chains
|
||||
if name not in classes:
|
||||
|
|
@ -37,8 +41,10 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict, add_funct
|
|||
for name_, value_ in value.__repr_args__():
|
||||
if name_ == "default_factory":
|
||||
try:
|
||||
variables[class_field_items]["default"] = get_default_factory(
|
||||
module=_class.__base__.__module__, function=value_
|
||||
variables[class_field_items]["default"] = (
|
||||
get_default_factory(
|
||||
module=_class.__base__.__module__, function=value_
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
variables[class_field_items]["default"] = None
|
||||
|
|
@ -46,7 +52,9 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict, add_funct
|
|||
variables[class_field_items][name_] = value_
|
||||
|
||||
variables[class_field_items]["placeholder"] = (
|
||||
docs.params[class_field_items] if class_field_items in docs.params else ""
|
||||
docs.params[class_field_items]
|
||||
if class_field_items in docs.params
|
||||
else ""
|
||||
)
|
||||
# Adding function to base classes to allow
|
||||
# the output to be a function
|
||||
|
|
@ -61,7 +69,9 @@ def build_template_from_function(name: str, type_to_loader_dict: Dict, add_funct
|
|||
}
|
||||
|
||||
|
||||
def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: bool = False):
|
||||
def build_template_from_class(
|
||||
name: str, type_to_cls_dict: Dict, add_function: bool = False
|
||||
):
|
||||
classes = [item.__name__ for item in type_to_cls_dict.values()]
|
||||
|
||||
# Raise error if name is not in chains
|
||||
|
|
@ -85,9 +95,11 @@ def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: b
|
|||
for name_, value_ in value.__repr_args__():
|
||||
if name_ == "default_factory":
|
||||
try:
|
||||
variables[class_field_items]["default"] = get_default_factory(
|
||||
module=_class.__base__.__module__,
|
||||
function=value_,
|
||||
variables[class_field_items]["default"] = (
|
||||
get_default_factory(
|
||||
module=_class.__base__.__module__,
|
||||
function=value_,
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
variables[class_field_items]["default"] = None
|
||||
|
|
@ -95,7 +107,9 @@ def build_template_from_class(name: str, type_to_cls_dict: Dict, add_function: b
|
|||
variables[class_field_items][name_] = value_
|
||||
|
||||
variables[class_field_items]["placeholder"] = (
|
||||
docs.params[class_field_items] if class_field_items in docs.params else ""
|
||||
docs.params[class_field_items]
|
||||
if class_field_items in docs.params
|
||||
else ""
|
||||
)
|
||||
base_classes = get_base_classes(_class)
|
||||
# Adding function to base classes to allow
|
||||
|
|
@ -127,7 +141,9 @@ def build_template_from_method(
|
|||
|
||||
# Check if the method exists in this class
|
||||
if not hasattr(_class, method_name):
|
||||
raise ValueError(f"Method {method_name} not found in class {class_name}")
|
||||
raise ValueError(
|
||||
f"Method {method_name} not found in class {class_name}"
|
||||
)
|
||||
|
||||
# Get the method
|
||||
method = getattr(_class, method_name)
|
||||
|
|
@ -146,8 +162,14 @@ def build_template_from_method(
|
|||
"_type": _type,
|
||||
**{
|
||||
name: {
|
||||
"default": (param.default if param.default != param.empty else None),
|
||||
"type": (param.annotation if param.annotation != param.empty else None),
|
||||
"default": (
|
||||
param.default if param.default != param.empty else None
|
||||
),
|
||||
"type": (
|
||||
param.annotation
|
||||
if param.annotation != param.empty
|
||||
else None
|
||||
),
|
||||
"required": param.default == param.empty,
|
||||
}
|
||||
for name, param in params.items()
|
||||
|
|
@ -234,7 +256,9 @@ def sync_to_async(func):
|
|||
return async_wrapper
|
||||
|
||||
|
||||
def format_dict(dictionary: Dict[str, Any], class_name: Optional[str] = None) -> Dict[str, Any]:
|
||||
def format_dict(
|
||||
dictionary: Dict[str, Any], class_name: Optional[str] = None
|
||||
) -> Dict[str, Any]:
|
||||
"""
|
||||
Formats a dictionary by removing certain keys and modifying the
|
||||
values of other keys.
|
||||
|
|
@ -320,7 +344,9 @@ def check_list_type(_type: str, value: Dict[str, Any]) -> str:
|
|||
The modified type string.
|
||||
"""
|
||||
if any(list_type in _type for list_type in ["List", "Sequence", "Set"]):
|
||||
_type = _type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
|
||||
_type = (
|
||||
_type.replace("List[", "").replace("Sequence[", "").replace("Set[", "")[:-1]
|
||||
)
|
||||
value["list"] = True
|
||||
else:
|
||||
value["list"] = False
|
||||
|
|
@ -423,7 +449,9 @@ def set_headers_value(value: Dict[str, Any]) -> None:
|
|||
value["value"] = """{"Authorization": "Bearer <token>"}"""
|
||||
|
||||
|
||||
def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key: str) -> None:
|
||||
def add_options_to_field(
|
||||
value: Dict[str, Any], class_name: Optional[str], key: str
|
||||
) -> None:
|
||||
"""
|
||||
Adds options to the field based on the class name and key.
|
||||
"""
|
||||
|
|
@ -442,7 +470,7 @@ def add_options_to_field(value: Dict[str, Any], class_name: Optional[str], key:
|
|||
|
||||
def build_loader_repr_from_records(records: List[Record]) -> str:
|
||||
if records:
|
||||
avg_length = sum(len(doc.page_content) for doc in records) / len(records)
|
||||
avg_length = sum(len(doc.text) for doc in records) / len(records)
|
||||
return f"""{len(records)} records
|
||||
\nAvg. Record Length (characters): {int(avg_length)}
|
||||
Records: {records[:3]}..."""
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue