feat: add challenge and red-blue competitions across API and web
This commit is contained in:
parent
f5161d9add
commit
8fd3c4bb64
77 changed files with 5355 additions and 24 deletions
|
|
@ -129,6 +129,10 @@ from .workspace import (
|
|||
workspace,
|
||||
)
|
||||
|
||||
# Import custom challenge controllers
|
||||
from . import challenges as challenges
|
||||
from . import red_blue_challenges as red_blue_challenges
|
||||
|
||||
api.add_namespace(console_ns)
|
||||
|
||||
__all__ = [
|
||||
|
|
@ -204,4 +208,6 @@ __all__ = [
|
|||
"workflow_run",
|
||||
"workflow_statistic",
|
||||
"workspace",
|
||||
"challenges",
|
||||
"red_blue_challenges",
|
||||
]
|
||||
|
|
|
|||
|
|
@ -12,7 +12,6 @@ from controllers.console.app.wraps import get_app_model
|
|||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
cloud_edition_billing_resource_check,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.ops.ops_trace_manager import OpsTraceManager
|
||||
|
|
@ -53,7 +52,6 @@ class AppListApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
"""Get app list"""
|
||||
|
||||
|
|
@ -166,7 +164,6 @@ class AppApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@get_app_model
|
||||
@marshal_with(app_detail_fields_with_site)
|
||||
def get(self, app_model):
|
||||
|
|
|
|||
154
api/controllers/console/challenges.py
Normal file
154
api/controllers/console/challenges.py
Normal file
|
|
@ -0,0 +1,154 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import console_ns as api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from libs.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user
|
||||
from models.challenge import Challenge
|
||||
|
||||
|
||||
@api.route("/challenges")
class ChallengeListCreateApi(Resource):
    """Console API: list the current workspace's challenges and create new ones."""

    @api.doc("list_challenges")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return all challenges belonging to the caller's tenant, newest first."""
        tenant_id = current_user.current_tenant_id
        if not tenant_id:
            # No active workspace selected; return an empty list to avoid leaking data.
            return {"result": "success", "data": []}
        rows = (
            db.session.query(Challenge)
            .filter(Challenge.tenant_id == tenant_id)
            .order_by(Challenge.created_at.desc())
            .all()
        )
        return {
            "result": "success",
            "data": [
                {
                    "id": r.id,
                    "name": r.name,
                    "description": r.description,
                    "goal": r.goal,
                    "is_active": r.is_active,
                    "success_type": r.success_type,
                    "success_pattern": r.success_pattern,
                    "scoring_strategy": r.scoring_strategy,
                    "app_id": r.app_id,
                    "workflow_id": r.workflow_id,
                }
                for r in rows
            ],
        }

    @api.doc("create_challenge")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Create a challenge in the caller's workspace.

        SECURITY: the tenant is always taken from the authenticated session,
        never from the request body, so a client cannot create challenges in
        another workspace by supplying a foreign ``tenant_id``.
        """
        parser = reqparse.RequestParser()
        # "tenant_id" is still accepted for backward compatibility but ignored.
        parser.add_argument("tenant_id", type=str, required=False, location="json")
        parser.add_argument("app_id", type=str, required=True, location="json")
        parser.add_argument("workflow_id", type=str, required=False, location="json")
        parser.add_argument("name", type=str, required=True, location="json")
        parser.add_argument("description", type=str, required=False, location="json")
        parser.add_argument("goal", type=str, required=False, location="json")
        parser.add_argument("success_type", type=str, required=False, location="json")
        parser.add_argument("success_pattern", type=str, required=False, location="json")
        parser.add_argument("scoring_strategy", type=str, required=False, location="json")
        parser.add_argument("is_active", type=bool, required=False, location="json")
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id
        if not tenant_id:
            # Cannot create a challenge without an active workspace.
            return {"result": "forbidden"}, 403

        c = Challenge()
        c.tenant_id = tenant_id
        c.app_id = args["app_id"]
        # Convert empty string to None for the UUID column.
        workflow_id = args.get("workflow_id")
        c.workflow_id = workflow_id if workflow_id else None
        c.name = args["name"]
        c.description = args.get("description")
        c.goal = args.get("goal")
        if args.get("success_type"):
            c.success_type = args["success_type"]
        c.success_pattern = args.get("success_pattern")
        if args.get("scoring_strategy"):
            c.scoring_strategy = args["scoring_strategy"]
        if args.get("is_active") is not None:
            c.is_active = args["is_active"]
        db.session.add(c)
        db.session.commit()
        return {"result": "success", "data": {"id": c.id}}, 201
|
||||
|
||||
|
||||
@api.route("/challenges/<uuid:challenge_id>")
class ChallengeDetailApi(Resource):
    """Console API: read, update and delete a single challenge, scoped to the caller's workspace."""

    @staticmethod
    def _get_scoped_challenge(challenge_id):
        """Fetch the challenge only if it belongs to the caller's tenant.

        Filtering by tenant prevents cross-workspace access (IDOR) via a
        guessed or leaked challenge id; without it any logged-in user could
        read/modify/delete any tenant's challenges.
        """
        tenant_id = current_user.current_tenant_id
        if not tenant_id:
            return None
        return (
            db.session.query(Challenge)
            .filter(Challenge.id == str(challenge_id), Challenge.tenant_id == tenant_id)
            .first()
        )

    @api.doc("get_challenge")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, challenge_id):
        """Return the full detail of one challenge, or 404 if absent / other tenant."""
        c = self._get_scoped_challenge(challenge_id)
        if not c:
            return {"result": "not_found"}, 404
        return {
            "result": "success",
            "data": {
                "id": c.id,
                "name": c.name,
                "description": c.description,
                "goal": c.goal,
                "is_active": c.is_active,
                "success_type": c.success_type,
                "success_pattern": c.success_pattern,
                "scoring_strategy": c.scoring_strategy,
                "app_id": c.app_id,
                "workflow_id": c.workflow_id,
            },
        }

    @api.doc("update_challenge")
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, challenge_id):
        """Partially update name/description/goal/is_active of a challenge."""
        c = self._get_scoped_challenge(challenge_id)
        if not c:
            return {"result": "not_found"}, 404
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=False, location="json")
        parser.add_argument("description", type=str, required=False, location="json")
        parser.add_argument("goal", type=str, required=False, location="json")
        parser.add_argument("is_active", type=bool, required=False, location="json")
        args = parser.parse_args()
        if args.get("name"):
            c.name = args["name"]
        if args.get("description") is not None:
            c.description = args["description"]
        if args.get("goal") is not None:
            c.goal = args["goal"]
        if args.get("is_active") is not None:
            c.is_active = bool(args["is_active"])
        db.session.commit()
        return {"result": "success"}

    @api.doc("delete_challenge")
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, challenge_id):
        """Delete a challenge. Returns 204 with an empty body on success."""
        c = self._get_scoped_challenge(challenge_id)
        if not c:
            return {"result": "not_found"}, 404
        db.session.delete(c)
        db.session.commit()
        # HTTP 204 must not carry a message body.
        return "", 204
|
||||
|
||||
|
||||
|
|
@ -5,7 +5,7 @@ from flask_restx import Resource, marshal_with, reqparse
|
|||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, enterprise_license_required, setup_required
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.dataset_fields import dataset_metadata_fields
|
||||
from libs.login import login_required
|
||||
from services.dataset_service import DatasetService
|
||||
|
|
@ -21,7 +21,6 @@ class DatasetMetadataCreateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def post(self, dataset_id):
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -42,7 +41,6 @@ class DatasetMetadataCreateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
|
@ -56,7 +54,6 @@ class DatasetMetadataApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@marshal_with(dataset_metadata_fields)
|
||||
def patch(self, dataset_id, metadata_id):
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
@ -77,7 +74,6 @@ class DatasetMetadataApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, dataset_id, metadata_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
metadata_id_str = str(metadata_id)
|
||||
|
|
@ -95,7 +91,6 @@ class DatasetMetadataBuiltInFieldApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
built_in_fields = MetadataService.get_built_in_fields()
|
||||
return {"fields": built_in_fields}, 200
|
||||
|
|
@ -106,7 +101,6 @@ class DatasetMetadataBuiltInFieldActionApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id, action: Literal["enable", "disable"]):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
|
@ -126,7 +120,6 @@ class DocumentMetadataEditApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, dataset_id):
|
||||
dataset_id_str = str(dataset_id)
|
||||
dataset = DatasetService.get_dataset(dataset_id_str)
|
||||
|
|
|
|||
|
|
@ -7,7 +7,6 @@ from sqlalchemy.orm import Session
|
|||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
knowledge_pipeline_publish_enabled,
|
||||
setup_required,
|
||||
)
|
||||
|
|
@ -37,7 +36,6 @@ class PipelineTemplateListApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
language = request.args.get("language", default="en-US", type=str)
|
||||
|
|
@ -51,7 +49,6 @@ class PipelineTemplateDetailApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self, template_id: str):
|
||||
type = request.args.get("type", default="built-in", type=str)
|
||||
rag_pipeline_service = RagPipelineService()
|
||||
|
|
@ -64,7 +61,6 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def patch(self, template_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument(
|
||||
|
|
@ -95,7 +91,6 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def delete(self, template_id: str):
|
||||
RagPipelineService.delete_customized_pipeline_template(template_id)
|
||||
return 200
|
||||
|
|
@ -103,7 +98,6 @@ class CustomizedPipelineTemplateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def post(self, template_id: str):
|
||||
with Session(db.engine) as session:
|
||||
template = (
|
||||
|
|
@ -120,7 +114,6 @@ class PublishCustomizedPipelineTemplateApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
@knowledge_pipeline_publish_enabled
|
||||
def post(self, pipeline_id: str):
|
||||
parser = reqparse.RequestParser()
|
||||
|
|
|
|||
145
api/controllers/console/red_blue_challenges.py
Normal file
145
api/controllers/console/red_blue_challenges.py
Normal file
|
|
@ -0,0 +1,145 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from flask_restx import Resource, reqparse
|
||||
|
||||
from controllers.console import console_ns as api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
setup_required,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_user, login_required
|
||||
from models.red_blue import RedBlueChallenge, TeamPairing
|
||||
|
||||
|
||||
@api.route("/red-blue-challenges")
class RedBlueListCreateApi(Resource):
    """Console API: list the workspace's red/blue competitions and create new ones."""

    @api.doc("list_red_blue_challenges")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """Return all red/blue challenges of the caller's tenant, newest first."""
        tenant_id = current_user.current_tenant_id
        if not tenant_id:
            # No active workspace selected; return an empty list to avoid leaking data.
            return {"result": "success", "data": []}
        rows = (
            db.session.query(RedBlueChallenge)
            .filter(RedBlueChallenge.tenant_id == tenant_id)
            .order_by(RedBlueChallenge.created_at.desc())
            .all()
        )
        return {
            "result": "success",
            "data": [
                {
                    "id": r.id,
                    "name": r.name,
                    "description": r.description,
                    "is_active": r.is_active,
                }
                for r in rows
            ],
        }

    @api.doc("create_red_blue_challenge")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Create a red/blue challenge in the caller's workspace.

        SECURITY: the tenant is always derived from the authenticated session.
        The body's ``tenant_id`` is still accepted (now optional) for backward
        compatibility but ignored, so a client cannot write into another
        workspace by spoofing it.
        """
        parser = reqparse.RequestParser()
        # Was required=True; the value is now server-derived, so keep the
        # argument optional to stay compatible with old clients that send it.
        parser.add_argument("tenant_id", type=str, required=False, location="json")
        parser.add_argument("app_id", type=str, required=True, location="json")
        parser.add_argument("name", type=str, required=True, location="json")
        parser.add_argument("description", type=str, required=False, location="json")
        parser.add_argument("judge_suite", type=dict, required=True, location="json")
        args = parser.parse_args()

        tenant_id = current_user.current_tenant_id
        if not tenant_id:
            return {"result": "forbidden"}, 403

        c = RedBlueChallenge()
        c.tenant_id = tenant_id
        c.app_id = args["app_id"]
        c.name = args["name"]
        c.description = args.get("description")
        c.judge_suite = args["judge_suite"]
        db.session.add(c)
        db.session.commit()
        return {"result": "success", "data": {"id": c.id}}, 201
|
||||
|
||||
|
||||
@api.route("/red-blue-challenges/<uuid:challenge_id>")
class RedBlueDetailApi(Resource):
    """Console API: read, update and delete a single red/blue challenge (tenant-scoped)."""

    @staticmethod
    def _get_scoped_challenge(challenge_id):
        """Fetch the red/blue challenge only if it belongs to the caller's tenant.

        Prevents cross-workspace access (IDOR) via a guessed/leaked id.
        """
        tenant_id = current_user.current_tenant_id
        if not tenant_id:
            return None
        return (
            db.session.query(RedBlueChallenge)
            .filter(
                RedBlueChallenge.id == str(challenge_id),
                RedBlueChallenge.tenant_id == tenant_id,
            )
            .first()
        )

    @api.doc("get_red_blue_challenge")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, challenge_id):
        """Return one red/blue challenge, or 404 if absent / other tenant."""
        c = self._get_scoped_challenge(challenge_id)
        if not c:
            return {"result": "not_found"}, 404
        return {
            "result": "success",
            "data": {"id": c.id, "name": c.name, "description": c.description, "is_active": c.is_active},
        }

    @api.doc("update_red_blue_challenge")
    @setup_required
    @login_required
    @account_initialization_required
    def patch(self, challenge_id):
        """Partially update name/description/is_active of a red/blue challenge."""
        c = self._get_scoped_challenge(challenge_id)
        if not c:
            return {"result": "not_found"}, 404
        parser = reqparse.RequestParser()
        parser.add_argument("name", type=str, required=False, location="json")
        parser.add_argument("description", type=str, required=False, location="json")
        parser.add_argument("is_active", type=bool, required=False, location="json")
        args = parser.parse_args()
        if args.get("name"):
            c.name = args["name"]
        if args.get("description") is not None:
            c.description = args["description"]
        if args.get("is_active") is not None:
            c.is_active = bool(args["is_active"])
        db.session.commit()
        return {"result": "success"}

    @api.doc("delete_red_blue_challenge")
    @setup_required
    @login_required
    @account_initialization_required
    def delete(self, challenge_id):
        """Delete a red/blue challenge. Returns 204 with an empty body."""
        c = self._get_scoped_challenge(challenge_id)
        if not c:
            return {"result": "not_found"}, 404
        db.session.delete(c)
        db.session.commit()
        # HTTP 204 must not carry a message body.
        return "", 204
|
||||
|
||||
|
||||
@api.route("/red-blue-challenges/<uuid:challenge_id>/pairings")
class RedBluePairingsApi(Resource):
    """Console API: recent team pairings (scored rounds) of a red/blue challenge."""

    @api.doc("list_red_blue_pairings")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, challenge_id):
        """Return up to 100 most recent pairings for the given challenge.

        The parent challenge is checked against the caller's tenant first so a
        user cannot enumerate another workspace's pairing data (IDOR).
        """
        tenant_id = current_user.current_tenant_id
        challenge = db.session.get(RedBlueChallenge, str(challenge_id))
        if not challenge or not tenant_id or challenge.tenant_id != tenant_id:
            return {"result": "not_found"}, 404
        rows = (
            db.session.query(TeamPairing)
            .filter(TeamPairing.red_blue_challenge_id == str(challenge_id))
            .order_by(TeamPairing.created_at.desc())
            .limit(100)
            .all()
        )
        return {
            "result": "success",
            "data": [
                {
                    "id": r.id,
                    "red_points": r.red_points,
                    "blue_points": r.blue_points,
                    "judge_rating": r.judge_rating,
                    "created_at": r.created_at.isoformat() if hasattr(r.created_at, "isoformat") else None,
                }
                for r in rows
            ],
        }
|
||||
|
||||
|
|
@ -29,7 +29,6 @@ from controllers.console.wraps import (
|
|||
account_initialization_required,
|
||||
cloud_edition_billing_enabled,
|
||||
enable_change_email,
|
||||
enterprise_license_required,
|
||||
only_edition_cloud,
|
||||
setup_required,
|
||||
)
|
||||
|
|
@ -102,7 +101,6 @@ class AccountProfileApi(Resource):
|
|||
@login_required
|
||||
@account_initialization_required
|
||||
@marshal_with(account_fields)
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("Invalid user account")
|
||||
|
|
|
|||
|
|
@ -13,7 +13,6 @@ from configs import dify_config
|
|||
from controllers.console import api
|
||||
from controllers.console.wraps import (
|
||||
account_initialization_required,
|
||||
enterprise_license_required,
|
||||
setup_required,
|
||||
)
|
||||
from core.mcp.auth.auth_flow import auth, handle_callback
|
||||
|
|
@ -667,7 +666,6 @@ class ToolLabelsApi(Resource):
|
|||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@enterprise_license_required
|
||||
def get(self):
|
||||
return jsonable_encoder(ToolLabelsService.list_tool_labels())
|
||||
|
||||
|
|
|
|||
|
|
@ -24,12 +24,15 @@ from . import (
|
|||
files,
|
||||
forgot_password,
|
||||
login,
|
||||
register,
|
||||
message,
|
||||
passport,
|
||||
remote_files,
|
||||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
challenges,
|
||||
red_blue_challenges,
|
||||
)
|
||||
|
||||
api.add_namespace(web_ns)
|
||||
|
|
@ -45,6 +48,7 @@ __all__ = [
|
|||
"files",
|
||||
"forgot_password",
|
||||
"login",
|
||||
"register",
|
||||
"message",
|
||||
"passport",
|
||||
"remote_files",
|
||||
|
|
@ -52,4 +56,6 @@ __all__ = [
|
|||
"site",
|
||||
"web_ns",
|
||||
"workflow",
|
||||
"challenges",
|
||||
"red_blue_challenges",
|
||||
]
|
||||
|
|
|
|||
121
api/controllers/web/challenges.py
Normal file
121
api/controllers/web/challenges.py
Normal file
|
|
@ -0,0 +1,121 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from sqlalchemy import select
|
||||
|
||||
from models.challenge import Challenge, ChallengeAttempt
|
||||
from models.model import App, Site
|
||||
|
||||
|
||||
@web_ns.route("/challenges")
class ChallengeListApi(Resource):
    """Public web API: list active challenges with app mode and site code for linking."""

    @staticmethod
    def _site_code_for_app(app_id):
        """Return the access code of the app's published site, or None.

        Uses ``.scalars().first()`` instead of ``scalar_one_or_none()`` so an
        app that happens to have more than one "normal" site does not raise
        MultipleResultsFound and turn the whole listing into a 500.
        """
        site = (
            db.session.execute(select(Site).where(Site.app_id == app_id, Site.status == "normal"))
            .scalars()
            .first()
        )
        return site.code if site else None

    def get(self):
        """Return all active challenges, newest first."""
        q = (
            db.session.query(Challenge)
            .filter(Challenge.is_active.is_(True))
            .order_by(Challenge.created_at.desc())
        )
        items = []
        for c in q.all():
            app = db.session.get(App, c.app_id) if c.app_id else None
            items.append({
                "id": c.id,
                "name": c.name,
                "description": c.description,
                "goal": c.goal,
                "app_id": c.app_id,
                "workflow_id": c.workflow_id,
                "app_mode": app.mode if app else None,
                "app_site_code": self._site_code_for_app(c.app_id) if c.app_id else None,
            })
        return {"result": "success", "data": items}
|
||||
|
||||
|
||||
@web_ns.route("/challenges/<uuid:challenge_id>")
class ChallengeDetailApi(Resource):
    """Public web API: detail of a single challenge, including app/site linkage."""

    def get(self, challenge_id):
        """Return one challenge by id, or 404 if it does not exist."""
        c = db.session.get(Challenge, str(challenge_id))
        if not c:
            return {"result": "not_found"}, 404

        app = db.session.get(App, c.app_id) if c.app_id else None
        site_code = None
        if c.app_id:
            # .scalars().first() instead of scalar_one_or_none(): an app with
            # more than one "normal" site would otherwise raise
            # MultipleResultsFound and 500 this endpoint.
            site = (
                db.session.execute(select(Site).where(Site.app_id == c.app_id, Site.status == "normal"))
                .scalars()
                .first()
            )
            site_code = site.code if site else None
        data = {
            "id": c.id,
            "name": c.name,
            "description": c.description,
            "goal": c.goal,
            "is_active": c.is_active,
            "app_id": c.app_id,
            "workflow_id": c.workflow_id,
            "app_mode": app.mode if app else None,
            "app_site_code": site_code,
        }
        return {"result": "success", "data": data}
|
||||
|
||||
|
||||
@web_ns.route("/challenges/<uuid:challenge_id>/leaderboard")
class ChallengeLeaderboardApi(Resource):
    """Public web API: top successful attempts, ranked by the challenge's scoring strategy."""

    # Number of leaderboard rows returned.
    _LIMIT = 20

    def get(self, challenge_id):
        """Return up to ``_LIMIT`` ranked successful attempts for the challenge."""
        challenge = db.session.get(Challenge, str(challenge_id))
        if not challenge:
            return {"result": "not_found"}, 404

        strategy = challenge.scoring_strategy or 'highest_rating'

        # Ordering per strategy; ties are always broken by earliest attempt.
        # Unknown strategies fall back to 'highest_rating'.
        orderings = {
            # Earliest successful attempt wins.
            'first': (ChallengeAttempt.created_at.asc(),),
            # Lowest elapsed_ms wins.
            'fastest': (
                ChallengeAttempt.elapsed_ms.asc().nullslast(),
                ChallengeAttempt.created_at.asc(),
            ),
            # Lowest tokens_total wins.
            'fewest_tokens': (
                ChallengeAttempt.tokens_total.asc().nullslast(),
                ChallengeAttempt.created_at.asc(),
            ),
            # Highest judge_rating wins.
            'highest_rating': (
                ChallengeAttempt.judge_rating.desc().nullslast(),
                ChallengeAttempt.created_at.asc(),
            ),
            # Custom score field (computed by plugin).
            'custom': (
                ChallengeAttempt.score.desc().nullslast(),
                ChallengeAttempt.created_at.asc(),
            ),
        }
        order_clauses = orderings.get(strategy, orderings['highest_rating'])

        attempts = (
            db.session.query(ChallengeAttempt)
            .filter(
                ChallengeAttempt.challenge_id == str(challenge_id),
                ChallengeAttempt.succeeded.is_(True),
            )
            .order_by(*order_clauses)
            .limit(self._LIMIT)
            .all()
        )

        data = [
            {
                "attempt_id": attempt.id,
                "account_id": attempt.account_id,
                "end_user_id": attempt.end_user_id,
                "score": attempt.score,
                "judge_rating": attempt.judge_rating,
                "tokens_total": attempt.tokens_total,
                "elapsed_ms": attempt.elapsed_ms,
                "created_at": attempt.created_at.isoformat() if hasattr(attempt.created_at, "isoformat") else None,
            }
            for attempt in attempts
        ]
        return {"result": "success", "data": data}
|
||||
|
||||
|
||||
85
api/controllers/web/red_blue_challenges.py
Normal file
85
api/controllers/web/red_blue_challenges.py
Normal file
|
|
@ -0,0 +1,85 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from models.red_blue import RedBlueChallenge, TeamPairing
|
||||
from services.red_blue_service import RedBlueService
|
||||
|
||||
|
||||
@web_ns.route("/red-blue-challenges")
class RedBlueListApi(Resource):
    """Public web API: list all currently active red/blue competitions."""

    def get(self):
        """Return id/name/description of every active red/blue challenge."""
        active_challenges = (
            db.session.query(RedBlueChallenge)
            .filter(RedBlueChallenge.is_active.is_(True))
            .all()
        )
        payload = []
        for challenge in active_challenges:
            payload.append({
                "id": challenge.id,
                "name": challenge.name,
                "description": challenge.description,
            })
        return {"result": "success", "data": payload}
|
||||
|
||||
|
||||
@web_ns.route("/red-blue-challenges/<uuid:challenge_id>")
class RedBlueDetailApi(Resource):
    """Public web API: basic detail of a single red/blue competition."""

    def get(self, challenge_id):
        """Return id/name/description of the challenge, or 404 if not found."""
        challenge = db.session.get(RedBlueChallenge, str(challenge_id))
        if challenge is None:
            return {"result": "not_found"}, 404
        return {
            "result": "success",
            "data": {
                "id": challenge.id,
                "name": challenge.name,
                "description": challenge.description,
            },
        }
|
||||
|
||||
|
||||
@web_ns.route("/red-blue-challenges/<uuid:challenge_id>/submit")
class RedBlueSubmitApi(Resource):
    """Public web API: submit a red- or blue-team prompt to a competition."""

    def post(self, challenge_id):
        """Validate the submission and hand it to RedBlueService.

        Returns 400 for a missing/invalid team or empty prompt, 404 when the
        challenge does not exist, 201 with the submission id on success.
        """
        # silent=True: malformed JSON yields None (-> empty payload -> our
        # structured 400) instead of raising a framework parse error.
        payload = request.get_json(force=True, silent=True) or {}
        team = payload.get("team")
        prompt = payload.get("prompt")
        if team not in ("red", "blue") or not prompt:
            return {"result": "bad_request"}, 400
        c = db.session.get(RedBlueChallenge, str(challenge_id))
        if not c:
            return {"result": "not_found"}, 404
        # NOTE(review): this endpoint is unauthenticated and records the
        # submission with no account/end-user attribution — confirm intended.
        sub = RedBlueService.submit_prompt(
            challenge_id=str(challenge_id),
            tenant_id=c.tenant_id,
            team=team,
            prompt=prompt,
            account_id=None,
            end_user_id=None,
        )
        return {"result": "success", "data": {"id": sub.id}}, 201
|
||||
|
||||
|
||||
@web_ns.route("/red-blue-challenges/<uuid:challenge_id>/leaderboard")
class RedBlueLeaderboardApi(Resource):
    """Public web API: aggregate red vs blue point totals for a competition."""

    def get(self, challenge_id):
        """Return total and relative points for each team.

        Both sums are computed in a single aggregate query (one DB round-trip
        instead of the two separate queries used previously); results are
        identical.
        """
        red, blue = (
            db.session.query(
                db.func.coalesce(db.func.sum(TeamPairing.red_points), 0.0),
                db.func.coalesce(db.func.sum(TeamPairing.blue_points), 0.0),
            )
            .filter(TeamPairing.red_blue_challenge_id == str(challenge_id))
            .one()
        )
        red = float(red or 0.0)
        blue = float(blue or 0.0)
        total = red + blue
        data = {
            "red_points": red,
            "blue_points": blue,
            # Guard against division by zero when no pairings exist yet.
            "red_ratio": (red / total) if total else 0.0,
            "blue_ratio": (blue / total) if total else 0.0,
        }
        return {"result": "success", "data": data}
|
||||
|
||||
30
api/controllers/web/register.py
Normal file
30
api/controllers/web/register.py
Normal file
|
|
@ -0,0 +1,30 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.web import web_ns
|
||||
from extensions.ext_database import db
|
||||
from services.account_service import RegisterService
|
||||
|
||||
|
||||
@web_ns.route('/register')
class WebRegisterApi(Resource):
    """Public web API: self-service account registration (no workspace created)."""

    def post(self):
        """Register a new account from email/name/password.

        Returns 400 when email or password is missing, 201 with the new
        account id on success.
        """
        # silent=True: malformed JSON becomes an empty payload and a clean
        # structured 400 below, instead of an unhandled parse error.
        payload = request.get_json(force=True, silent=True) or {}
        email = payload.get('email')
        name = payload.get('name') or 'Player'
        password = payload.get('password')
        if not email or not password:
            return {'result': 'bad_request'}, 400
        account = RegisterService.register(
            email=email,
            name=name,
            password=password,
            is_setup=False,
            create_workspace_required=False,
        )
        db.session.commit()
        return {'result': 'success', 'data': {'account_id': account.id}}, 201
|
||||
|
||||
|
||||
|
|
@ -58,6 +58,9 @@ class NodeType(StrEnum):
|
|||
DOCUMENT_EXTRACTOR = "document-extractor"
|
||||
LIST_OPERATOR = "list-operator"
|
||||
AGENT = "agent"
|
||||
CHALLENGE_EVALUATOR = "challenge-evaluator"
|
||||
JUDGING_LLM = "judging-llm"
|
||||
TEAM_CHALLENGE = "team-challenge"
|
||||
|
||||
|
||||
class NodeExecutionType(StrEnum):
|
||||
|
|
|
|||
3
api/core/workflow/nodes/challenge_evaluator/__init__.py
Normal file
3
api/core/workflow/nodes/challenge_evaluator/__init__.py
Normal file
|
|
@ -0,0 +1,3 @@
|
|||
from .node import ChallengeEvaluatorNode
|
||||
|
||||
__all__ = ['ChallengeEvaluatorNode']
|
||||
258
api/core/workflow/nodes/challenge_evaluator/node.py
Normal file
258
api/core/workflow/nodes/challenge_evaluator/node.py
Normal file
|
|
@ -0,0 +1,258 @@
|
|||
# pyright: reportImplicitRelativeImport=none
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.variables.segments import Segment
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from extensions.ext_database import db
|
||||
from models.challenge import Challenge
|
||||
from services.challenge_scorer_service import ChallengeScorerService
|
||||
from services.challenge_service import ChallengeService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChallengeEvaluatorNode(Node):
|
||||
node_type = NodeType.CHALLENGE_EVALUATOR
|
||||
execution_type = NodeExecutionType.EXECUTABLE
|
||||
|
||||
_node_data: BaseNodeData
|
||||
|
||||
def init_node_data(self, data: Mapping[str, Any]):
|
||||
# Using BaseNodeData to carry title/desc; node data is accessed directly
|
||||
self._node_data = BaseNodeData.model_validate(data)
|
||||
self._config: dict[str, Any] = data
|
||||
|
||||
def _get_error_strategy(self) -> ErrorStrategy | None:
|
||||
return getattr(self._node_data, 'error_strategy', None)
|
||||
|
||||
def _get_retry_config(self) -> RetryConfig:
|
||||
return getattr(self._node_data, 'retry_config', RetryConfig())
|
||||
|
||||
def _get_title(self) -> str:
|
||||
return getattr(self._node_data, 'title', 'Challenge Evaluator')
|
||||
|
||||
def _get_description(self) -> str | None:
|
||||
return getattr(self._node_data, 'desc', None)
|
||||
|
||||
def _get_default_value_dict(self) -> dict[str, Any]:
|
||||
return getattr(self._node_data, 'default_value_dict', {})
|
||||
|
||||
def get_base_node_data(self) -> BaseNodeData:
|
||||
return self._node_data
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> NodeRunResult:
    """Evaluate a challenge outcome and optionally persist the attempt.

    Two evaluation modes, chosen by ``self._config['evaluation_mode']``:

    * ``'llm-judge'`` — read judge_passed / judge_rating / judge_feedback
      from an upstream Judging LLM node via the variable pool (the first
      element of the response selector is the upstream node id).
    * ``'rules'`` (default) — resolve the response text from the selector
      and delegate to ``ChallengeService.evaluate_outcome``.

    When ``challenge_id`` is configured, the attempt is recorded best-effort
    (never raising into the workflow).  Always returns a SUCCEEDED result
    whose outputs contain challenge_succeeded, judge_rating, judge_feedback
    and message, matching the frontend's getOutputVars contract.
    """
    # Resolve response text from selector in config.inputs.response (frontend schema)
    output_text = ''
    source_selector = None
    inputs_cfg = self._config.get('inputs') or {}
    if isinstance(inputs_cfg, dict):
        source_selector = inputs_cfg.get('response')
    # fallback to older key if any
    source_selector = source_selector or self._config.get('value_selector')

    # Check evaluation mode from config
    evaluation_mode = self._config.get('evaluation_mode', 'rules')

    logger.info("ChallengeEvaluator - evaluation_mode: %s, source_selector: %s", evaluation_mode, source_selector)

    # Initialize judge variables
    is_judge_input = False
    judge_passed = False
    judge_rating = 0
    judge_feedback_from_input = ''
    output_text = ''

    def _segment_to_value(segment: Segment | None) -> Any:
        # Convert a variable-pool Segment into a plain Python value,
        # falling back to .value (or the segment itself) if to_object fails.
        if segment is None:
            return None
        if hasattr(segment, "to_object"):
            try:
                return segment.to_object()
            except Exception:  # pragma: no cover - defensive
                pass
        return getattr(segment, "value", segment)

    # If evaluation_mode is 'llm-judge', try to read from upstream Judging LLM node
    if evaluation_mode == 'llm-judge' and source_selector and len(source_selector) >= 1:
        try:
            node_id = source_selector[0]
            # Retrieve judge outputs as Segments and convert to primitive values
            passed_segment = self.graph_runtime_state.variable_pool.get([node_id, 'judge_passed'])
            rating_segment = self.graph_runtime_state.variable_pool.get([node_id, 'judge_rating'])
            feedback_segment = self.graph_runtime_state.variable_pool.get([node_id, 'judge_feedback'])

            potential_judge_passed = _segment_to_value(passed_segment)
            potential_judge_rating = _segment_to_value(rating_segment)
            potential_judge_feedback = _segment_to_value(feedback_segment)

            logger.info(
                "ChallengeEvaluator - Reading judge outputs: passed=%s, rating=%s, feedback=%s",
                potential_judge_passed,
                potential_judge_rating,
                potential_judge_feedback,
            )

            # If judge_passed exists, we successfully read from a Judging LLM node
            if potential_judge_passed is not None:
                is_judge_input = True
                judge_passed = bool(potential_judge_passed)
                judge_rating = int(potential_judge_rating or 0)
                judge_feedback_from_input = str(potential_judge_feedback or '')
                logger.info(
                    "ChallengeEvaluator - Judge input successfully read! passed=%s, rating=%s, feedback=%s",
                    judge_passed,
                    judge_rating,
                    judge_feedback_from_input,
                )
        except Exception as e:
            logger.error("ChallengeEvaluator - Error reading judge outputs: %s", e, exc_info=True)
            is_judge_input = False

    # If not using judge input, get text output for rules-based evaluation
    if not is_judge_input and source_selector:
        try:
            segment = self.graph_runtime_state.variable_pool.get(source_selector)
            if segment is None:
                output_text = ''
            elif hasattr(segment, 'text'):
                output_text = segment.text
            else:
                output_text = str(_segment_to_value(segment) or '')
        except Exception:
            output_text = ''

    # Evaluate based on mode
    if is_judge_input:
        ok = judge_passed
        details = {
            'mode': 'llm-judge',
            'rating': judge_rating,
            'feedback': judge_feedback_from_input,
        }
    else:
        # Rules-based evaluation (only if not using judge input)
        ok, details = ChallengeService.evaluate_outcome(output_text, self._config)

    # optional persistence if config carries challenge_id
    challenge_id = self._config.get('challenge_id')
    if challenge_id:
        try:
            # Calculate elapsed time in milliseconds
            # NOTE(review): assumes graph_runtime_state.start_at is epoch
            # seconds comparable to time.time() — confirm.
            elapsed_ms = int((time.time() - self.graph_runtime_state.start_at) * 1000)

            # Get total tokens used in the workflow so far
            tokens_total = self.graph_runtime_state.total_tokens

            # Extract judge_rating from details if available (for highest_rating strategy)
            # NOTE(review): judge_rating is rebound here for persistence; the
            # earlier judge-input value is only carried via `details`.
            judge_rating = None
            judge_feedback = None
            if isinstance(details, dict):
                judge_rating = details.get('rating')
                judge_feedback = details.get('feedback')

            # Load challenge to check scoring strategy
            challenge = db.session.get(Challenge, str(challenge_id))

            # Score field is reserved for custom scoring plugins.
            # For built-in strategies (first, fastest, fewest_tokens, highest_rating),
            # the leaderboard sorts by specific columns (created_at, elapsed_ms, tokens_total, judge_rating).
            score = None

            # If custom scoring is configured, compute score using plugin
            if challenge and challenge.scoring_strategy == 'custom':
                try:
                    # Metrics handed to the scoring plugin.
                    metrics = {
                        'succeeded': ok,
                        'tokens_total': tokens_total,
                        'elapsed_ms': elapsed_ms,
                        'rating': judge_rating,
                        'created_at': int(time.time() * 1000),
                    }

                    # Execution context for the plugin (no end-user binding here).
                    ctx = {
                        'tenant_id': self.tenant_id,
                        'app_id': self.app_id,
                        'workflow_id': self.workflow_id,
                        'challenge_id': str(challenge_id),
                        'end_user_id': None,
                        'timeout_ms': 5000,
                    }

                    result = ChallengeScorerService.score_with_plugin(
                        scorer_plugin_id=challenge.scoring_plugin_id,
                        scorer_entrypoint=challenge.scoring_entrypoint,
                        metrics=metrics,
                        config=challenge.scoring_config or {},
                        ctx=ctx,
                    )

                    score = result.get('score')
                    logger.info(
                        "Custom scorer computed score: %s (details: %s)",
                        score,
                        result.get('details'),
                    )
                except Exception as e:
                    logger.error("Custom scorer failed: %s", e, exc_info=True)
                    # Continue with score=None on error

            ChallengeService.record_attempt(
                tenant_id=self.tenant_id,
                challenge_id=challenge_id,
                end_user_id=None,
                account_id=None,
                workflow_run_id=None,
                succeeded=ok,
                score=score,
                judge_rating=judge_rating,
                judge_feedback=judge_feedback,
                tokens_total=tokens_total,
                elapsed_ms=elapsed_ms,
                session=db.session,
            )
        except Exception:
            # do not crash the workflow if recording fails
            pass

    # Always provide all output variables to match frontend getOutputVars
    outputs: dict[str, Any] = {
        'challenge_succeeded': ok,
        'judge_rating': 0,
        'judge_feedback': '',
        'message': '',
    }

    # Override with actual values if evaluator provides them
    if isinstance(details, dict):
        logger.debug("ChallengeEvaluator - details: %s", details)
        if 'rating' in details:
            outputs['judge_rating'] = details.get('rating')
        if 'feedback' in details:
            outputs['judge_feedback'] = details.get('feedback')
        if 'message' in details:
            outputs['message'] = details.get('message')
        # If no explicit message, create one from evaluation details
        if not outputs['message']:
            if ok:
                outputs['message'] = f"Success: {details.get('mode', 'evaluation')} matched"
            else:
                outputs['message'] = f"Failed: {details.get('mode', 'evaluation')} did not match"

    return NodeRunResult(
        status=WorkflowNodeExecutionStatus.SUCCEEDED,
        outputs=outputs,
    )
|
||||
|
||||
188
api/core/workflow/nodes/judging_llm/node.py
Normal file
188
api/core/workflow/nodes/judging_llm/node.py
Normal file
|
|
@ -0,0 +1,188 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from services.challenge_service import ChallengeService
|
||||
|
||||
|
||||
class JudgingLLMNode(Node):
    """Workflow node that asks an LLM to judge a response against a rubric.

    Reads ``inputs.goal`` / ``inputs.response`` selectors from the node
    config, invokes the configured judge model synchronously, and parses a
    JSON verdict of the shape
    ``{"passed": bool, "rating": 0-10, "feedback": str}``.
    Falls back to rules-based evaluation (success_type/success_pattern) when
    the LLM output cannot be parsed, and keeps default "failed" outputs on
    any error — the node itself always SUCCEEDs.
    """

    node_type = NodeType.JUDGING_LLM
    execution_type = NodeExecutionType.EXECUTABLE

    _node_data: BaseNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate shared base fields and keep the raw config for direct access."""
        self._node_data = BaseNodeData.model_validate(data)
        # Access data directly from node_data, not from a 'config' key
        self._config: dict[str, Any] = data

    def _get_error_strategy(self) -> ErrorStrategy | None:
        """Error strategy configured on the node, or None when unset."""
        return getattr(self._node_data, 'error_strategy', None)

    def _get_retry_config(self) -> RetryConfig:
        """Retry configuration, defaulting to an empty RetryConfig."""
        return getattr(self._node_data, 'retry_config', RetryConfig())

    def _get_title(self) -> str:
        """Human-readable node title, defaulting to 'Judging LLM'."""
        return getattr(self._node_data, 'title', 'Judging LLM')

    def _get_description(self) -> str | None:
        """Optional free-text description from the validated node data."""
        return getattr(self._node_data, 'desc', None)

    def _get_default_value_dict(self) -> dict[str, Any]:
        """Mapping of default output values; empty when not configured."""
        return getattr(self._node_data, 'default_value_dict', {})

    def get_base_node_data(self) -> BaseNodeData:
        """Expose the validated base node data."""
        return self._node_data

    @classmethod
    def version(cls) -> str:
        """Schema version of this node implementation."""
        return "1"

    def _run(self) -> NodeRunResult:
        """Run the judge; the verdict is carried in outputs, never raised.

        Fix vs. the original: the logger is created up-front (it was created
        mid-function, so an early failure path could NameError on ``logger``),
        the dead ``_ = variable_pool.get(...)`` pre-reads whose results were
        discarded are removed, and a redundant ``pass`` after ``logger.error``
        is dropped.  Observable behavior is unchanged.
        """
        # Function-local because this block owns no module-level scope;
        # created first so every path below can log.
        import logging
        logger = logging.getLogger(__name__)

        # Extract input selectors (FE-compatible keys).
        inputs_cfg = self._config.get('inputs') or {}
        goal_selector = None
        response_selector = None
        if isinstance(inputs_cfg, dict):
            goal_selector = inputs_cfg.get('goal')
            response_selector = inputs_cfg.get('response')

        # Defaults: "not passed" until the judge says otherwise.
        outputs = {
            'judge_passed': False,
            'judge_rating': 0,
            'judge_feedback': '',
        }

        # If model config and rubric provided, invoke LLM synchronously to judge
        judge_model = self._config.get('judge_model') or {}
        rubric = self._config.get('rubric_prompt_template') or ''
        provider = (judge_model or {}).get('provider')
        model_name = (judge_model or {}).get('name')
        completion_params = (judge_model or {}).get('completion_params') or {}

        def _segment_to_text(seg: Any) -> str:
            # Best-effort conversion of a variable-pool segment to plain text.
            try:
                # Many variable types expose .text
                if hasattr(seg, 'text'):
                    return str(seg.text)
                if isinstance(seg, (dict, list)):
                    return json.dumps(seg, ensure_ascii=False)
                return str(seg)
            except Exception:
                return ''

        logger.info(
            "JudgingLLM check - provider: %s, model: %s, rubric_len: %s, response_selector: %s",
            provider,
            model_name,
            len(rubric) if rubric else 0,
            response_selector,
        )

        if provider and model_name and rubric and response_selector:
            logger.info("JudgingLLM: All conditions met, invoking LLM...")
            try:
                goal_val = self.graph_runtime_state.variable_pool.get(goal_selector) if goal_selector else None
                response_val = self.graph_runtime_state.variable_pool.get(response_selector)
                goal_text = _segment_to_text(goal_val)
                response_text = _segment_to_text(response_val)
                json_template = '{"passed": boolean, "rating": number (0-10), "feedback": string}'

                prompt_body = (
                    f"Goal:\n{goal_text}\n\n"
                    f"Response:\n{response_text}\n\n"
                    f"Return JSON with rating 0-10: {json_template}"
                )

                # Rubric is the system prompt; goal/response go in the user turn.
                prompt_messages = [
                    SystemPromptMessage(content=rubric),
                    UserPromptMessage(content=prompt_body),
                ]

                model_instance = ModelManager().get_model_instance(
                    tenant_id=self.tenant_id,
                    model_type=ModelType.LLM,
                    provider=provider,
                    model=model_name,
                )
                result: LLMResult = model_instance.invoke_llm(
                    prompt_messages=prompt_messages,
                    model_parameters=completion_params,
                    stop=[],
                    stream=False,
                    user=self.user_id,
                )  # type: ignore
                # Extract text from result (string, content-part list, or other)
                text_out = ''
                content = getattr(result.message, 'content', '')
                if isinstance(content, str):
                    text_out = content
                elif isinstance(content, list):
                    for item in content:
                        if getattr(item, 'type', None) == PromptMessageContentType.TEXT:
                            text_out += str(getattr(item, 'data', ''))
                else:
                    text_out = str(content)

                # Parse last JSON object in output
                # (greedy regex spans first '{' to last '}', so json.loads may
                # fail on chatter between objects — verdict stays None then)
                verdict: dict[str, Any] | None = None
                try:
                    matches = re.findall(r"\{[\s\S]*\}", text_out)
                    if matches:
                        verdict = json.loads(matches[-1])
                except Exception:
                    verdict = None

                if isinstance(verdict, dict):
                    outputs['judge_passed'] = bool(verdict.get('passed'))
                    outputs['judge_rating'] = int(verdict.get('rating') or 0)
                    outputs['judge_feedback'] = str(verdict.get('feedback') or '')
                    outputs['judge_raw'] = json.dumps(verdict)
                else:
                    # Fallback to simple rules if configured
                    success_type = self._config.get('success_type')
                    success_pattern = self._config.get('success_pattern')
                    if success_type and success_pattern:
                        ok, _ = ChallengeService.evaluate_outcome(response_text, {
                            'success_type': success_type,
                            'success_pattern': success_pattern,
                        })
                        outputs['judge_passed'] = ok
                        outputs['judge_rating'] = 10 if ok else 0
                        outputs['judge_feedback'] = 'passed by rules' if ok else 'failed by rules'
            except Exception as e:
                # keep default outputs on error
                logger.error("JudgingLLM error: %s", e, exc_info=True)
        else:
            logger.warning("JudgingLLM skipped - missing required fields")
        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=outputs)
|
||||
|
||||
|
|
@ -24,6 +24,9 @@ from core.workflow.nodes.tool import ToolNode
|
|||
from core.workflow.nodes.variable_aggregator import VariableAggregatorNode
|
||||
from core.workflow.nodes.variable_assigner.v1 import VariableAssignerNode as VariableAssignerNodeV1
|
||||
from core.workflow.nodes.variable_assigner.v2 import VariableAssignerNode as VariableAssignerNodeV2
|
||||
from core.workflow.nodes.challenge_evaluator.node import ChallengeEvaluatorNode
|
||||
from core.workflow.nodes.judging_llm.node import JudgingLLMNode
|
||||
from core.workflow.nodes.team_challenge.node import TeamChallengeNode
|
||||
|
||||
LATEST_VERSION = "latest"
|
||||
|
||||
|
|
@ -142,4 +145,16 @@ NODE_TYPE_CLASSES_MAPPING: Mapping[NodeType, Mapping[str, type[Node]]] = {
|
|||
LATEST_VERSION: KnowledgeIndexNode,
|
||||
"1": KnowledgeIndexNode,
|
||||
},
|
||||
NodeType.CHALLENGE_EVALUATOR: {
|
||||
LATEST_VERSION: ChallengeEvaluatorNode,
|
||||
"1": ChallengeEvaluatorNode,
|
||||
},
|
||||
NodeType.JUDGING_LLM: {
|
||||
LATEST_VERSION: JudgingLLMNode,
|
||||
"1": JudgingLLMNode,
|
||||
},
|
||||
NodeType.TEAM_CHALLENGE: {
|
||||
LATEST_VERSION: TeamChallengeNode,
|
||||
"1": TeamChallengeNode,
|
||||
},
|
||||
}
|
||||
|
|
|
|||
68
api/core/workflow/nodes/team_challenge/node.py
Normal file
68
api/core/workflow/nodes/team_challenge/node.py
Normal file
|
|
@ -0,0 +1,68 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.entities import BaseNodeData, RetryConfig
|
||||
from core.workflow.nodes.base.node import Node
|
||||
|
||||
|
||||
class TeamChallengeNode(Node):
    """Placeholder node for red/blue team challenges.

    Resolves the ``inputs.team_choice`` selector (matching the frontend
    schema) and emits a fixed set of output variables; all judging and
    scoring fields are zeroed pending full wiring.
    """

    node_type = NodeType.TEAM_CHALLENGE
    execution_type = NodeExecutionType.EXECUTABLE

    _node_data: BaseNodeData

    def init_node_data(self, data: Mapping[str, Any]):
        """Validate shared base fields and retain the raw node config."""
        self._node_data = BaseNodeData.model_validate(data)
        self._config: dict[str, Any] = data

    def _get_error_strategy(self) -> ErrorStrategy | None:
        """Error strategy configured on the node, or None when unset."""
        strategy = getattr(self._node_data, 'error_strategy', None)
        return strategy

    def _get_retry_config(self) -> RetryConfig:
        """Retry configuration, defaulting to an empty RetryConfig."""
        return getattr(self._node_data, 'retry_config', RetryConfig())

    def _get_title(self) -> str:
        """Human-readable node title, defaulting to 'Team Challenge'."""
        fallback = 'Team Challenge'
        return getattr(self._node_data, 'title', fallback)

    def _get_description(self) -> str | None:
        """Optional free-text description from the validated node data."""
        description = getattr(self._node_data, 'desc', None)
        return description

    def _get_default_value_dict(self) -> dict[str, Any]:
        """Mapping of default output values; empty when not configured."""
        defaults = getattr(self._node_data, 'default_value_dict', {})
        return defaults

    def get_base_node_data(self) -> BaseNodeData:
        """Expose the validated base node data."""
        return self._node_data

    @classmethod
    def version(cls) -> str:
        """Schema version of this node implementation."""
        return "1"

    def _run(self) -> NodeRunResult:
        """Resolve the team choice and return placeholder outputs."""
        # Read inputs.team_choice for consistency with FE
        chosen_team = ''
        cfg_inputs = self._config.get('inputs') or {}
        if isinstance(cfg_inputs, dict):
            selector = cfg_inputs.get('team_choice')
            if selector:
                try:
                    raw = self.graph_runtime_state.variable_pool.get_value_by_selector(selector)
                    chosen_team = str(raw or '')
                except Exception:
                    chosen_team = ''

        placeholder_outputs = {
            'team': chosen_team,
            'judge_passed': False,
            'judge_rating': 0,
            'judge_feedback': '',
            'categories': {},
            'team_points': 0.0,
            'total_points': 0.0,
        }
        return NodeRunResult(status=WorkflowNodeExecutionStatus.SUCCEEDED, outputs=placeholder_outputs)
|
||||
|
||||
|
||||
|
|
@ -0,0 +1,170 @@
|
|||
"""add challenge & red/blue tables
|
||||
|
||||
Revision ID: 183e2d30fb4e
|
||||
Revises: 68519ad5cd18
|
||||
Create Date: 2025-09-30 08:22:31.223257
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '183e2d30fb4e'
|
||||
down_revision = '68519ad5cd18'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
    """Create the challenge / red-blue competition tables.

    Adds challenge_attempts, challenges, red_blue_challenges, team_pairings
    and team_submissions — each with tenant-id plus parent-id indexes — and
    drops providers.credential_status.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    # One row per workflow run that evaluated a challenge.
    op.create_table('challenge_attempts',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('challenge_id', models.types.StringUUID(), nullable=False),
    sa.Column('end_user_id', models.types.StringUUID(), nullable=True),
    sa.Column('account_id', models.types.StringUUID(), nullable=True),
    sa.Column('workflow_run_id', models.types.StringUUID(), nullable=True),
    sa.Column('succeeded', sa.Boolean(), server_default=sa.text('false'), nullable=False),
    sa.Column('score', sa.Float(), nullable=True),
    sa.Column('judge_rating', sa.Integer(), nullable=True),
    sa.Column('judge_feedback', sa.Text(), nullable=True),
    sa.Column('judge_output_raw', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    sa.Column('tokens_total', sa.Integer(), nullable=True),
    sa.Column('elapsed_ms', sa.Integer(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='challenge_attempts_pkey')
    )
    with op.batch_alter_table('challenge_attempts', schema=None) as batch_op:
        batch_op.create_index('challenge_attempts_challenge_id_idx', ['challenge_id'], unique=False)
        batch_op.create_index('challenge_attempts_tenant_id_idx', ['tenant_id'], unique=False)

    # Challenge definitions: goal, success criteria and scoring strategy.
    op.create_table('challenges',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('app_id', models.types.StringUUID(), nullable=False),
    sa.Column('workflow_id', models.types.StringUUID(), nullable=True),
    sa.Column('name', sa.Text(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('goal', sa.Text(), nullable=True),
    sa.Column('success_type', sa.String(length=64), server_default=sa.text("'regex'"), nullable=False),
    sa.Column('success_pattern', sa.Text(), nullable=True),
    sa.Column('secret_ref', sa.Text(), nullable=True),
    sa.Column('evaluator_type', sa.String(length=32), server_default=sa.text("'rules'"), nullable=False),
    sa.Column('evaluator_plugin_id', sa.Text(), nullable=True),
    sa.Column('evaluator_entrypoint', sa.Text(), nullable=True),
    sa.Column('evaluator_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    sa.Column('scoring_strategy', sa.String(length=64), server_default=sa.text("'first'"), nullable=False),
    sa.Column('scoring_plugin_id', sa.Text(), nullable=True),
    sa.Column('scoring_entrypoint', sa.Text(), nullable=True),
    sa.Column('scoring_config', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False),
    sa.Column('created_by', models.types.StringUUID(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('updated_by', models.types.StringUUID(), nullable=True),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='challenges_pkey')
    )
    with op.batch_alter_table('challenges', schema=None) as batch_op:
        batch_op.create_index('challenges_app_id_idx', ['app_id'], unique=False)
        batch_op.create_index('challenges_tenant_id_idx', ['tenant_id'], unique=False)

    # Red/blue competition definitions (judge suite, selection policies, theme).
    op.create_table('red_blue_challenges',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('app_id', models.types.StringUUID(), nullable=False),
    sa.Column('workflow_id', models.types.StringUUID(), nullable=True),
    sa.Column('name', sa.Text(), nullable=False),
    sa.Column('description', sa.Text(), nullable=True),
    sa.Column('judge_suite', postgresql.JSONB(astext_type=sa.Text()), nullable=False),
    sa.Column('defense_selection_policy', sa.String(length=64), server_default=sa.text("'latest_best'"), nullable=False),
    sa.Column('attack_selection_policy', sa.String(length=64), server_default=sa.text("'latest_best'"), nullable=False),
    sa.Column('scoring_strategy', sa.String(length=64), server_default=sa.text("'red_blue_ratio'"), nullable=False),
    sa.Column('theme', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    sa.Column('instructions_md', sa.Text(), nullable=True),
    sa.Column('is_active', sa.Boolean(), server_default=sa.text('true'), nullable=False),
    sa.Column('created_by', models.types.StringUUID(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.Column('updated_by', models.types.StringUUID(), nullable=True),
    sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='red_blue_challenges_pkey')
    )
    with op.batch_alter_table('red_blue_challenges', schema=None) as batch_op:
        batch_op.create_index('red_blue_challenges_app_id_idx', ['app_id'], unique=False)
        batch_op.create_index('red_blue_challenges_tenant_id_idx', ['tenant_id'], unique=False)

    # One attack-vs-defense pairing with judge verdict and point split.
    op.create_table('team_pairings',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('red_blue_challenge_id', models.types.StringUUID(), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('attack_submission_id', models.types.StringUUID(), nullable=True),
    sa.Column('defense_submission_id', models.types.StringUUID(), nullable=True),
    sa.Column('judge_output_raw', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    sa.Column('categories', postgresql.JSONB(astext_type=sa.Text()), nullable=True),
    sa.Column('judge_rating', sa.Integer(), nullable=True),
    sa.Column('judge_feedback', sa.Text(), nullable=True),
    sa.Column('red_points', sa.Float(), server_default=sa.text('0'), nullable=False),
    sa.Column('blue_points', sa.Float(), server_default=sa.text('0'), nullable=False),
    sa.Column('tokens_total', sa.Integer(), nullable=True),
    sa.Column('elapsed_ms', sa.Integer(), nullable=True),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='team_pairings_pkey')
    )
    with op.batch_alter_table('team_pairings', schema=None) as batch_op:
        batch_op.create_index('team_pairings_challenge_id_idx', ['red_blue_challenge_id'], unique=False)
        batch_op.create_index('team_pairings_tenant_id_idx', ['tenant_id'], unique=False)

    # A team member's submitted prompt for one side ('team' is red or blue).
    op.create_table('team_submissions',
    sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
    sa.Column('red_blue_challenge_id', models.types.StringUUID(), nullable=False),
    sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
    sa.Column('account_id', models.types.StringUUID(), nullable=True),
    sa.Column('end_user_id', models.types.StringUUID(), nullable=True),
    sa.Column('team', sa.String(length=16), nullable=False),
    sa.Column('prompt', sa.Text(), nullable=False),
    sa.Column('active', sa.Boolean(), server_default=sa.text('true'), nullable=False),
    sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
    sa.PrimaryKeyConstraint('id', name='team_submissions_pkey')
    )
    with op.batch_alter_table('team_submissions', schema=None) as batch_op:
        batch_op.create_index('team_submissions_challenge_id_idx', ['red_blue_challenge_id'], unique=False)
        batch_op.create_index('team_submissions_tenant_id_idx', ['tenant_id'], unique=False)

    # NOTE(review): dropping providers.credential_status looks unrelated to
    # the challenge feature (likely autogenerate picking up drift) — confirm
    # this drop was intended in this revision before deploying.
    with op.batch_alter_table('providers', schema=None) as batch_op:
        batch_op.drop_column('credential_status')

    # ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
    """Reverse upgrade(): restore providers.credential_status and drop the
    challenge / red-blue tables (indexes first, then tables, in reverse
    creation order)."""
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('providers', schema=None) as batch_op:
        batch_op.add_column(sa.Column('credential_status', sa.VARCHAR(length=20), server_default=sa.text("'active'::character varying"), autoincrement=False, nullable=True))

    with op.batch_alter_table('team_submissions', schema=None) as batch_op:
        batch_op.drop_index('team_submissions_tenant_id_idx')
        batch_op.drop_index('team_submissions_challenge_id_idx')

    op.drop_table('team_submissions')
    with op.batch_alter_table('team_pairings', schema=None) as batch_op:
        batch_op.drop_index('team_pairings_tenant_id_idx')
        batch_op.drop_index('team_pairings_challenge_id_idx')

    op.drop_table('team_pairings')
    with op.batch_alter_table('red_blue_challenges', schema=None) as batch_op:
        batch_op.drop_index('red_blue_challenges_tenant_id_idx')
        batch_op.drop_index('red_blue_challenges_app_id_idx')

    op.drop_table('red_blue_challenges')
    with op.batch_alter_table('challenges', schema=None) as batch_op:
        batch_op.drop_index('challenges_tenant_id_idx')
        batch_op.drop_index('challenges_app_id_idx')

    op.drop_table('challenges')
    with op.batch_alter_table('challenge_attempts', schema=None) as batch_op:
        batch_op.drop_index('challenge_attempts_tenant_id_idx')
        batch_op.drop_index('challenge_attempts_challenge_id_idx')

    op.drop_table('challenge_attempts')
    # ### end Alembic commands ###
|
||||
|
|
@ -91,6 +91,8 @@ from .workflow import (
|
|||
WorkflowRun,
|
||||
WorkflowType,
|
||||
)
|
||||
from .challenge import Challenge, ChallengeAttempt
|
||||
from .red_blue import RedBlueChallenge, TeamSubmission, TeamPairing
|
||||
|
||||
__all__ = [
|
||||
"APIBasedExtension",
|
||||
|
|
@ -181,4 +183,9 @@ __all__ = [
|
|||
"WorkflowRunTriggeredFrom",
|
||||
"WorkflowToolProvider",
|
||||
"WorkflowType",
|
||||
"Challenge",
|
||||
"ChallengeAttempt",
|
||||
"RedBlueChallenge",
|
||||
"TeamSubmission",
|
||||
"TeamPairing",
|
||||
]
|
||||
|
|
|
|||
91
api/models/challenge.py
Normal file
91
api/models/challenge.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class Challenge(Base):
    """A challenge definition attached to an app (and optionally a workflow).

    Carries the display text (name/description/goal), the built-in success
    rule (success_type + success_pattern), optional pluggable evaluator and
    scorer references, and audit columns.
    """

    __tablename__ = "challenges"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="challenges_pkey"),
        sa.Index("challenges_tenant_id_idx", "tenant_id"),
        sa.Index("challenges_app_id_idx", "app_id"),
    )

    # Identity / ownership
    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    # Display fields
    name: Mapped[str] = mapped_column(sa.Text, nullable=False)
    description: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    goal: Mapped[str | None] = mapped_column(sa.Text, nullable=True)

    # Built-in rules check: match success_pattern per success_type ('regex' default)
    success_type: Mapped[str] = mapped_column(sa.String(64), nullable=False, server_default=sa.text("'regex'"))
    success_pattern: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    secret_ref: Mapped[str | None] = mapped_column(sa.Text, nullable=True)

    # Pluggable evaluator ('rules' default); plugin id/entrypoint/config when used
    evaluator_type: Mapped[str] = mapped_column(sa.String(32), nullable=False, server_default=sa.text("'rules'"))
    evaluator_plugin_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    evaluator_entrypoint: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    evaluator_config = mapped_column(JSONB, nullable=True)

    # Leaderboard scoring: built-in strategy ('first' default) or a custom plugin
    scoring_strategy: Mapped[str] = mapped_column(sa.String(64), nullable=False, server_default=sa.text("'first'"))
    scoring_plugin_id: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    scoring_entrypoint: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    scoring_config = mapped_column(JSONB, nullable=True)

    # Soft enable/disable flag
    is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))

    # Audit columns
    created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=sa.func.current_timestamp(),
    )
    updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=sa.func.current_timestamp(),
    )
|
||||
|
||||
|
||||
class ChallengeAttempt(Base):
    """A single attempt at a challenge, recording outcome, judge result, and cost."""

    __tablename__ = "challenge_attempts"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="challenge_attempts_pkey"),
        sa.Index("challenge_attempts_tenant_id_idx", "tenant_id"),
        sa.Index("challenge_attempts_challenge_id_idx", "challenge_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    # Owning tenant and the challenge this attempt belongs to.
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    challenge_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    # Actor: a web end user and/or a console account; both optional.
    end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    # Workflow run that produced the attempt output, when available.
    workflow_run_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    # Outcome: pass/fail plus an optional numeric score for ranking.
    succeeded: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    score: Mapped[float | None] = mapped_column(sa.Float, nullable=True)

    # Judge result; the raw payload is kept for auditing.
    judge_rating: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
    judge_feedback: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    judge_output_raw = mapped_column(JSONB, nullable=True)

    # Resource usage metrics for the attempt.
    tokens_total: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
    elapsed_ms: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)

    # NOTE(review): naive sa.DateTime; presumably DB-server time — confirm project convention.
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=sa.func.current_timestamp(),
    )
||||
114
api/models/red_blue.py
Normal file
114
api/models/red_blue.py
Normal file
|
|
@ -0,0 +1,114 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from datetime import datetime
|
||||
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects.postgresql import JSONB
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from .base import Base
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class RedBlueChallenge(Base):
    """A red-vs-blue competition definition scoped to a tenant and app."""

    __tablename__ = "red_blue_challenges"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="red_blue_challenges_pkey"),
        sa.Index("red_blue_challenges_tenant_id_idx", "tenant_id"),
        sa.Index("red_blue_challenges_app_id_idx", "app_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    # Owning tenant / app; the backing workflow is optional.
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    name: Mapped[str] = mapped_column(sa.Text, nullable=False)
    description: Mapped[str | None] = mapped_column(sa.Text, nullable=True)

    # Judge configuration payload (JSON); required for every challenge.
    judge_suite = mapped_column(JSONB, nullable=False)
    # Policies for picking which submission represents each side in a pairing.
    defense_selection_policy: Mapped[str] = mapped_column(
        sa.String(64), nullable=False, server_default=sa.text("'latest_best'")
    )
    attack_selection_policy: Mapped[str] = mapped_column(
        sa.String(64), nullable=False, server_default=sa.text("'latest_best'")
    )
    scoring_strategy: Mapped[str] = mapped_column(
        sa.String(64), nullable=False, server_default=sa.text("'red_blue_ratio'")
    )

    # Presentation: optional theme JSON and markdown instructions.
    theme = mapped_column(JSONB, nullable=True)
    instructions_md: Mapped[str | None] = mapped_column(sa.Text, nullable=True)

    is_active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))

    # Audit columns. NOTE(review): updated_at has no onupdate here; presumably
    # callers set it explicitly on edit — confirm.
    created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=sa.func.current_timestamp(),
    )
    updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=sa.func.current_timestamp(),
    )
|
||||
class TeamSubmission(Base):
    """A prompt submitted by a red or blue team participant for a competition."""

    __tablename__ = "team_submissions"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="team_submissions_pkey"),
        sa.Index("team_submissions_challenge_id_idx", "red_blue_challenge_id"),
        sa.Index("team_submissions_tenant_id_idx", "tenant_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    red_blue_challenge_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    # Submitter: console account and/or web end user; both optional.
    account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    team: Mapped[str] = mapped_column(sa.String(16), nullable=False)  # 'red' | 'blue'
    prompt: Mapped[str] = mapped_column(sa.Text, nullable=False)
    # Inactive submissions are excluded from counterparty selection.
    active: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("true"))

    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=sa.func.current_timestamp(),
    )
|
||||
class TeamPairing(Base):
    """The judged outcome of pairing one attack submission against one defense."""

    __tablename__ = "team_pairings"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="team_pairings_pkey"),
        sa.Index("team_pairings_challenge_id_idx", "red_blue_challenge_id"),
        sa.Index("team_pairings_tenant_id_idx", "tenant_id"),
    )

    id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
    red_blue_challenge_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)

    # The two sides of the pairing; either may be missing.
    attack_submission_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    defense_submission_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    # Judge verdict: raw payload, per-category breakdown, rating, and feedback.
    judge_output_raw = mapped_column(JSONB, nullable=True)
    categories = mapped_column(JSONB, nullable=True)
    judge_rating: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
    judge_feedback: Mapped[str | None] = mapped_column(sa.Text, nullable=True)
    # Points awarded to each team for this pairing (default 0).
    red_points: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))
    blue_points: Mapped[float] = mapped_column(sa.Float, nullable=False, server_default=sa.text("0"))

    # Resource usage for the judged run.
    tokens_total: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)
    elapsed_ms: Mapped[int | None] = mapped_column(sa.Integer, nullable=True)

    created_at: Mapped[datetime] = mapped_column(
        sa.DateTime,
        nullable=False,
        server_default=sa.func.current_timestamp(),
    )
||||
|
|
@ -223,3 +223,6 @@ vdb = [
|
|||
"xinference-client~=1.2.2",
|
||||
"mo-vector~=0.1.13",
|
||||
]
|
||||
[tool.pyright]
|
||||
typeCheckingMode = "basic"
|
||||
reportImplicitRelativeImport = "none"
|
||||
|
|
|
|||
|
|
@ -8,7 +8,8 @@
|
|||
"extensions",
|
||||
"core/app/app_config/easy_ui_based_app/dataset"
|
||||
],
|
||||
"typeCheckingMode": "strict",
|
||||
"typeCheckingMode": "basic",
|
||||
"reportImplicitRelativeImport": "none",
|
||||
"allowedUntypedLibraries": [
|
||||
"flask_restx",
|
||||
"flask_login",
|
||||
|
|
|
|||
1
api/run.sh
Normal file
1
api/run.sh
Normal file
|
|
@ -0,0 +1 @@
|
|||
# Start the API in development mode: Flask debug server on all interfaces, port 5001.
uv run --dev flask --app app run --host 0.0.0.0 --port 5001 --debug
|
||||
56
api/services/challenge_scorer_protocol.py
Normal file
56
api/services/challenge_scorer_protocol.py
Normal file
|
|
@ -0,0 +1,56 @@
|
|||
"""
|
||||
Challenge scorer protocol and type definitions.
|
||||
|
||||
Defines the interface for custom scoring plugins that compute
|
||||
numeric scores from attempt metrics for leaderboard ranking.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol, TypedDict
|
||||
|
||||
|
||||
class ScoringContext(TypedDict, total=False):
    """Context provided to scorer plugins.

    All keys are optional (``total=False``); plugins should read them
    with ``.get()`` and tolerate missing values.
    """

    tenant_id: str  # owning tenant
    app_id: str  # application hosting the challenge
    workflow_id: str  # backing workflow, when known
    challenge_id: str  # challenge being scored
    end_user_id: str | None  # end user who made the attempt, if any
    timeout_ms: int  # execution budget hint for the scorer, in milliseconds
|
||||
class AttemptMetrics(TypedDict, total=False):
    """Metrics from a challenge attempt, as consumed by scorer plugins.

    All keys are optional (``total=False``); scorers should use ``.get()``.
    """

    succeeded: bool  # whether the attempt passed the challenge
    tokens_total: int | None  # total tokens consumed, if tracked
    elapsed_ms: int | None  # wall-clock duration in milliseconds
    rating: int | None  # judge rating, if a judge ran
    created_at: int | None  # epoch ms
|
||||
class ScoringResult(TypedDict, total=False):
    """Result returned by scorer plugin.

    ``score`` is required in practice: the scorer service rejects any
    result that is not a dict containing a ``score`` key.
    """

    score: float  # computed numeric score used for ranking
    details: dict[str, Any] | None  # optional per-component breakdown
|
||||
class ScorerProtocol(Protocol):
    """Protocol that all scorer plugins must implement.

    Structural typing: any object exposing a matching ``score`` method
    satisfies this protocol; no inheritance is required.
    """

    def score(self, metrics: AttemptMetrics, config: dict[str, Any], ctx: ScoringContext) -> ScoringResult:
        """
        Compute a numeric score from attempt metrics.

        Args:
            metrics: Attempt metrics (tokens, time, rating, etc.)
            config: Plugin-specific configuration (from challenge.scoring_config)
            ctx: Context with tenant_id, app_id, etc.

        Returns:
            ScoringResult with computed score and optional details
        """
        ...
112
api/services/challenge_scorer_service.py
Normal file
112
api/services/challenge_scorer_service.py
Normal file
|
|
@ -0,0 +1,112 @@
|
|||
"""
|
||||
Challenge scorer service.
|
||||
|
||||
Loads and invokes custom scorer plugins to compute scores from attempt metrics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from services.challenge_scorer_protocol import AttemptMetrics, ScoringContext, ScoringResult
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ChallengeScorerService:
    """Service for loading and invoking custom scorer plugins.

    Plugins are instantiated once per ``plugin_id:entrypoint`` pair and
    cached process-wide in ``_plugin_cache``.
    """

    # Cache of instantiated scorer objects keyed by "plugin_id:entrypoint".
    _plugin_cache: dict[str, Any] = {}

    @classmethod
    def score_with_plugin(
        cls,
        *,
        scorer_plugin_id: str | None,
        scorer_entrypoint: str | None,
        metrics: AttemptMetrics,
        config: dict[str, Any] | None,
        ctx: ScoringContext,
    ) -> ScoringResult:
        """
        Compute score using a custom scorer plugin.

        Args:
            scorer_plugin_id: Plugin identifier (e.g., 'builtin.weighted_scorer')
            scorer_entrypoint: Entrypoint path (e.g., 'services.scorers.weighted:WeightedScorer')
            metrics: Attempt metrics to score
            config: Plugin-specific configuration
            ctx: Scoring context

        Returns:
            ScoringResult with computed score

        Raises:
            ValueError: If plugin cannot be loaded or scoring fails
        """
        if not scorer_plugin_id or not scorer_entrypoint:
            raise ValueError("scorer_plugin_id and scorer_entrypoint are required for custom scoring")

        # Load plugin (cached after the first call).
        scorer = cls._load_plugin(scorer_plugin_id, scorer_entrypoint)
        if not scorer:
            raise ValueError(f"Failed to load scorer plugin: {scorer_plugin_id}:{scorer_entrypoint}")

        # TODO: enforce ctx["timeout_ms"] (default 5000) with a watchdog
        # (threading.Timer / signal.alarm); the plugin currently runs unbounded.
        try:
            result = scorer.score(metrics, config or {}, ctx)
            # Validate the plugin's contract before trusting the result.
            if not isinstance(result, dict) or "score" not in result:
                raise ValueError("Scorer must return a dict with 'score' key")
            return result
        except Exception as e:
            # logger.exception records the traceback; lazy %-args avoid building
            # the message when the level is filtered.
            logger.exception("Scorer plugin %s failed", scorer_plugin_id)
            # Chain the original cause so callers can inspect it.
            raise ValueError(f"Scorer plugin execution failed: {e}") from e

    @classmethod
    def _load_plugin(cls, plugin_id: str, entrypoint: str) -> Any:
        """
        Load (and cache) a scorer plugin by entrypoint.

        Args:
            plugin_id: Plugin identifier for caching
            entrypoint: Python path like 'pkg.module:ClassName'

        Returns:
            Scorer instance or None if loading fails
        """
        cache_key = f"{plugin_id}:{entrypoint}"
        if cache_key in cls._plugin_cache:
            return cls._plugin_cache[cache_key]

        try:
            # Parse entrypoint: 'pkg.module:ClassName'
            if ":" not in entrypoint:
                raise ValueError(f"Invalid entrypoint format: {entrypoint}. Expected 'module:ClassName'")

            module_path, class_name = entrypoint.split(":", 1)

            # Dynamic import
            import importlib

            module = importlib.import_module(module_path)
            scorer_class = getattr(module, class_name)

            # Instantiate and cache.
            scorer = scorer_class()
            cls._plugin_cache[cache_key] = scorer
            logger.info("Loaded scorer plugin: %s from %s", plugin_id, entrypoint)
            return scorer

        except Exception:
            # Best-effort loader: callers treat None as "plugin unavailable".
            logger.exception("Failed to load scorer plugin %s:%s", plugin_id, entrypoint)
            return None

    @classmethod
    def clear_cache(cls) -> None:
        """Clear the plugin cache (useful for testing)."""
        cls._plugin_cache.clear()
||||
64
api/services/challenge_service.py
Normal file
64
api/services/challenge_service.py
Normal file
|
|
@ -0,0 +1,64 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import re
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.challenge import Challenge, ChallengeAttempt
|
||||
|
||||
|
||||
class ChallengeService:
    """Challenge outcome evaluation and attempt persistence."""

    @staticmethod
    def evaluate_outcome(output_text: str, cfg: Mapping[str, Any]) -> tuple[bool, dict[str, Any]]:
        """Decide whether *output_text* satisfies the challenge success criteria.

        Supports 'regex' (case-insensitive, multiline) and 'contains'
        (case-insensitive substring) success types; anything else, or a
        missing pattern, counts as failure.
        """
        mode = cfg.get("success_type", "regex")
        pattern = cfg.get("success_pattern")
        if pattern:
            if mode == "regex":
                try:
                    matched = re.search(pattern, output_text, flags=re.IGNORECASE | re.MULTILINE) is not None
                except re.error as e:
                    # A malformed pattern is reported, never raised to the caller.
                    return False, {"mode": "regex", "error": f"invalid_regex: {e}"}
                return matched, {"mode": "regex", "matched": matched}
            if mode == "contains":
                return (pattern.lower() in output_text.lower()), {"mode": "contains"}
        return False, {"mode": mode, "info": "no_pattern_or_unsupported"}

    @staticmethod
    def record_attempt(
        *,
        tenant_id: str,
        challenge_id: str,
        end_user_id: str | None,
        account_id: str | None,
        workflow_run_id: str | None,
        succeeded: bool,
        score: float | None = None,
        judge_rating: int | None = None,
        judge_feedback: str | None = None,
        judge_output_raw: dict[str, Any] | None = None,
        tokens_total: int | None = None,
        elapsed_ms: int | None = None,
        session: Session | None = None,
    ) -> ChallengeAttempt:
        """Persist one challenge attempt and return the committed row."""
        sess = session or db.session
        attempt = ChallengeAttempt()
        values = {
            "tenant_id": tenant_id,
            "challenge_id": challenge_id,
            "end_user_id": end_user_id,
            "account_id": account_id,
            "workflow_run_id": workflow_run_id,
            "succeeded": succeeded,
            "score": score,
            "judge_rating": judge_rating,
            "judge_feedback": judge_feedback,
            "judge_output_raw": judge_output_raw,
            "tokens_total": tokens_total,
            "elapsed_ms": elapsed_ms,
        }
        for attr, value in values.items():
            setattr(attempt, attr, value)
        sess.add(attempt)
        sess.commit()
        return attempt
||||
91
api/services/red_blue_service.py
Normal file
91
api/services/red_blue_service.py
Normal file
|
|
@ -0,0 +1,91 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.red_blue import RedBlueChallenge, TeamPairing, TeamSubmission
|
||||
|
||||
|
||||
class RedBlueService:
    """Persistence helpers for red/blue competition submissions and pairings."""

    @staticmethod
    def submit_prompt(
        *,
        challenge_id: str,
        tenant_id: str,
        team: str,
        prompt: str,
        account_id: str | None,
        end_user_id: str | None,
        session: Session | None = None,
    ) -> TeamSubmission:
        """Persist a team prompt submission and return the committed row."""
        sess = session or db.session
        submission = TeamSubmission()
        submission.red_blue_challenge_id = challenge_id
        submission.tenant_id = tenant_id
        submission.team = team
        submission.prompt = prompt
        submission.account_id = account_id
        submission.end_user_id = end_user_id
        sess.add(submission)
        sess.commit()
        return submission

    @staticmethod
    def select_counterparty_submission(
        *,
        challenge: RedBlueChallenge,
        team: str,
        session: Session | None = None,
    ) -> TeamSubmission | None:
        """Return the most recent active submission from the opposing team, if any."""
        sess = session or db.session
        counterparty = "blue" if team == "red" else "red"
        # Simplest policy: latest active from opposite team
        candidates = sess.query(TeamSubmission).filter(
            TeamSubmission.red_blue_challenge_id == challenge.id,
            TeamSubmission.team == counterparty,
            TeamSubmission.active.is_(True),
        )
        return candidates.order_by(TeamSubmission.created_at.desc()).first()

    @staticmethod
    def record_pairing(
        *,
        challenge_id: str,
        tenant_id: str,
        attack_submission_id: str | None,
        defense_submission_id: str | None,
        judge_output_raw: dict[str, Any] | None,
        categories: dict[str, Any] | None,
        judge_rating: int | None,
        judge_feedback: str | None,
        red_points: float,
        blue_points: float,
        tokens_total: int | None,
        elapsed_ms: int | None,
        session: Session | None = None,
    ) -> TeamPairing:
        """Persist the judged result of one attack/defense pairing."""
        sess = session or db.session
        pairing = TeamPairing()
        values = {
            "red_blue_challenge_id": challenge_id,
            "tenant_id": tenant_id,
            "attack_submission_id": attack_submission_id,
            "defense_submission_id": defense_submission_id,
            "judge_output_raw": judge_output_raw,
            "categories": categories,
            "judge_rating": judge_rating,
            "judge_feedback": judge_feedback,
            "red_points": red_points,
            "blue_points": blue_points,
            "tokens_total": tokens_total,
            "elapsed_ms": elapsed_ms,
        }
        for attr, value in values.items():
            setattr(pairing, attr, value)
        sess.add(pairing)
        sess.commit()
        return pairing
||||
144
api/services/scorers/README.md
Normal file
144
api/services/scorers/README.md
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
# Custom Scorer Plugins
|
||||
|
||||
This directory contains custom scorer plugins for challenge leaderboards.
|
||||
|
||||
## Overview
|
||||
|
||||
Scorers compute numeric scores from challenge attempt metrics (tokens, time, rating, success) for ranking on leaderboards when `scoring_strategy = 'custom'`.
|
||||
|
||||
## Built-in Scorers
|
||||
|
||||
### WeightedScorer
|
||||
|
||||
**Entrypoint:** `services.scorers.weighted:WeightedScorer`
|
||||
|
||||
Computes a weighted score combining multiple metrics with configurable bonuses and penalties.
|
||||
|
||||
**Formula:**
|
||||
```
|
||||
score = success_bonus
|
||||
+ (rating × rating_weight)
|
||||
- (elapsed_seconds × time_penalty)
|
||||
- (tokens × token_penalty)
|
||||
```
|
||||
|
||||
**Configuration:**
|
||||
- `success_bonus` (float, default: 100): Base points for successful attempts
|
||||
- `rating_weight` (float, default: 10): Multiplier for judge rating (0-10)
|
||||
- `time_penalty` (float, default: 1.0): Penalty per second elapsed
|
||||
- `token_penalty` (float, default: 0.01): Penalty per token used
|
||||
|
||||
**Example Configuration:**
|
||||
```json
|
||||
{
|
||||
"success_bonus": 100.0,
|
||||
"rating_weight": 10.0,
|
||||
"time_penalty": 1.0,
|
||||
"token_penalty": 0.01
|
||||
}
|
||||
```
|
||||
|
||||
**Example Challenge Setup (via API):**
|
||||
```python
|
||||
{
|
||||
"name": "Advanced Prompt Challenge",
|
||||
"scoring_strategy": "custom",
|
||||
"scoring_plugin_id": "builtin.weighted_scorer",
|
||||
"scoring_entrypoint": "services.scorers.weighted:WeightedScorer",
|
||||
"scoring_config": {
|
||||
"success_bonus": 100.0,
|
||||
"rating_weight": 15.0,
|
||||
"time_penalty": 0.5,
|
||||
"token_penalty": 0.02
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
## Creating Custom Scorers
|
||||
|
||||
### 1. Implement the ScorerProtocol
|
||||
|
||||
Create a new file in this directory (e.g., `custom.py`):
|
||||
|
||||
```python
|
||||
from typing import Any
|
||||
from services.challenge_scorer_protocol import AttemptMetrics, ScoringContext, ScoringResult
|
||||
|
||||
class MyCustomScorer:
|
||||
def score(self, metrics: AttemptMetrics, config: dict[str, Any], ctx: ScoringContext) -> ScoringResult:
|
||||
# Access metrics
|
||||
succeeded = metrics.get('succeeded', False)
|
||||
tokens = metrics.get('tokens_total', 0)
|
||||
elapsed_ms = metrics.get('elapsed_ms', 0)
|
||||
rating = metrics.get('rating', 0)
|
||||
|
||||
# Access configuration
|
||||
multiplier = config.get('multiplier', 1.0)
|
||||
|
||||
# Compute score
|
||||
score = (rating * multiplier) if succeeded else 0.0
|
||||
|
||||
return {
|
||||
'score': score,
|
||||
'details': { # optional
|
||||
'multiplier_used': multiplier
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
### 2. Register in Challenge
|
||||
|
||||
Set the challenge's scoring fields:
|
||||
|
||||
```python
|
||||
challenge.scoring_strategy = 'custom'
|
||||
challenge.scoring_plugin_id = 'my_custom_scorer'
|
||||
challenge.scoring_entrypoint = 'services.scorers.custom:MyCustomScorer'
|
||||
challenge.scoring_config = {
|
||||
'multiplier': 2.0
|
||||
}
|
||||
```
|
||||
|
||||
### 3. Testing
|
||||
|
||||
Create tests in `api/tests/unit_tests/services/` following the pattern in `test_challenge_scorer_service.py`.
|
||||
|
||||
## Protocol Reference
|
||||
|
||||
### Input Types
|
||||
|
||||
**AttemptMetrics:**
|
||||
- `succeeded` (bool): Whether the challenge was passed
|
||||
- `tokens_total` (int | None): Total tokens used
|
||||
- `elapsed_ms` (int | None): Time taken in milliseconds
|
||||
- `rating` (int | None): Judge rating (0-10)
|
||||
- `created_at` (int | None): Timestamp in epoch milliseconds
|
||||
|
||||
**ScoringContext:**
|
||||
- `tenant_id` (str): Tenant identifier
|
||||
- `app_id` (str): Application identifier
|
||||
- `workflow_id` (str): Workflow identifier
|
||||
- `challenge_id` (str): Challenge identifier
|
||||
- `end_user_id` (str | None): End user identifier (if available)
|
||||
- `timeout_ms` (int): Maximum execution time
|
||||
|
||||
### Output Type
|
||||
|
||||
**ScoringResult:**
|
||||
- `score` (float, required): Computed numeric score
|
||||
- `details` (dict[str, Any] | None, optional): Additional scoring details
|
||||
|
||||
## Error Handling
|
||||
|
||||
- Scorers must return a dict with a `score` key
- Exceptions are caught, logged, and re-raised as `ValueError`; callers may then record the attempt with `score=None`
- A timeout budget (`timeout_ms`, default: 5000 ms) is passed in the scoring context; enforcement is not yet implemented
- Scorers should never return negative scores; use `max(score, 0.0)` to clamp
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Keep it simple**: Scoring should be fast and deterministic
|
||||
2. **Validate config**: Check configuration values and provide defaults
|
||||
3. **Clamp scores**: Ensure scores are non-negative
|
||||
4. **Document formula**: Clearly explain how your scorer works
|
||||
5. **Test edge cases**: Test with missing metrics, zeros, nulls
|
||||
1
api/services/scorers/__init__.py
Normal file
1
api/services/scorers/__init__.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
"""Built-in scorer plugins."""
|
||||
66
api/services/scorers/weighted.py
Normal file
66
api/services/scorers/weighted.py
Normal file
|
|
@ -0,0 +1,66 @@
|
|||
"""
|
||||
Weighted scorer plugin.
|
||||
|
||||
Computes a weighted score based on success bonus, rating, elapsed time, and token usage.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from services.challenge_scorer_protocol import AttemptMetrics, ScoringContext, ScoringResult
|
||||
|
||||
|
||||
class WeightedScorer:
    """
    Example weighted scorer that combines multiple metrics.

    Configuration options:
        - success_bonus (float): Base points for successful attempt (default: 100)
        - rating_weight (float): Multiplier for judge rating (default: 10)
        - time_penalty (float): Penalty per second elapsed (default: 1.0)
        - token_penalty (float): Penalty per token used (default: 0.01)

    Formula:
        score = success_bonus
              + (rating * rating_weight)
              - (elapsed_seconds * time_penalty)
              - (tokens * token_penalty)
    """

    def score(self, metrics: AttemptMetrics, config: dict[str, Any], ctx: ScoringContext) -> ScoringResult:
        """Compute the weighted score; the result is clamped to be non-negative."""
        # Success bonus applies only when the attempt passed.
        bonus = config.get("success_bonus", 100.0) if metrics.get("succeeded") else 0.0

        # Positive contribution from the judge rating.
        rating_component = (metrics.get("rating") or 0) * config.get("rating_weight", 10.0)

        # Cost components: elapsed time (in seconds) and token usage.
        elapsed_seconds = (metrics.get("elapsed_ms") or 0) / 1000.0
        time_cost = elapsed_seconds * config.get("time_penalty", 1.0)
        token_cost = (metrics.get("tokens_total") or 0) * config.get("token_penalty", 0.01)

        # Combine and clamp so the score is never negative.
        total = max(bonus + rating_component - time_cost - token_cost, 0.0)

        return {
            "score": total,
            "details": {
                "base": bonus,
                "rating_contribution": rating_component,
                "time_penalty": time_cost,
                "token_penalty": token_cost,
            },
        }
|
|
@ -0,0 +1,14 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
|
||||
class TestWebChallenges:
    def test_list_and_detail(self, test_client_with_containers: FlaskClient):
        """The challenge list endpoint responds with a success envelope."""
        response = test_client_with_containers.get("/api/web/challenges")
        assert response.status_code == 200
        payload = response.get_json()
        assert payload["result"] == "success"
||||
|
|
@ -0,0 +1,13 @@
|
|||
from __future__ import annotations
|
||||
|
||||
from flask.testing import FlaskClient
|
||||
|
||||
|
||||
class TestWebRedBlueChallenges:
    def test_list(self, test_client_with_containers: FlaskClient):
        """The red/blue challenge list endpoint responds with a success envelope."""
        response = test_client_with_containers.get("/api/web/red-blue-challenges")
        assert response.status_code == 200
        payload = response.get_json()
        assert payload["result"] == "success"
|
||||
144
api/tests/unit_tests/services/test_challenge_scorer_service.py
Normal file
144
api/tests/unit_tests/services/test_challenge_scorer_service.py
Normal file
|
|
@ -0,0 +1,144 @@
|
|||
"""Tests for ChallengeScorerService."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from services.challenge_scorer_service import ChallengeScorerService
|
||||
|
||||
|
||||
class TestChallengeScorerService:
    """Test custom scorer plugin loading and execution."""

    # Shared plugin identity and configuration used across the weighted-scorer tests.
    WEIGHTED_PLUGIN = "builtin.weighted_scorer"
    WEIGHTED_ENTRYPOINT = "services.scorers.weighted:WeightedScorer"
    WEIGHTED_CONFIG = {
        "success_bonus": 100.0,
        "rating_weight": 10.0,
        "time_penalty": 1.0,
        "token_penalty": 0.01,
    }

    def _score(self, metrics, ctx):
        """Invoke the weighted scorer with the shared plugin id and config."""
        return ChallengeScorerService.score_with_plugin(
            scorer_plugin_id=self.WEIGHTED_PLUGIN,
            scorer_entrypoint=self.WEIGHTED_ENTRYPOINT,
            metrics=metrics,
            config=dict(self.WEIGHTED_CONFIG),
            ctx=ctx,
        )

    def test_weighted_scorer_success(self):
        """Test WeightedScorer with successful attempt."""
        result = self._score(
            {
                "succeeded": True,
                "tokens_total": 1000,
                "elapsed_ms": 5000,  # 5 seconds
                "rating": 8,
                "created_at": 1730000000000,
            },
            {
                "tenant_id": "test-tenant",
                "app_id": "test-app",
                "workflow_id": "test-workflow",
                "challenge_id": "test-challenge",
                "timeout_ms": 5000,
            },
        )
        # Expected: 100 (success) + 80 (8*10 rating) - 5 (5s*1.0) - 10 (1000*0.01) = 165
        assert result["score"] == 165.0
        assert "details" in result
        details = result["details"]
        assert details["base"] == 100.0
        assert details["rating_contribution"] == 80.0
        assert details["time_penalty"] == 5.0
        assert details["token_penalty"] == 10.0

    def test_weighted_scorer_failure(self):
        """Test WeightedScorer with failed attempt."""
        result = self._score(
            {
                "succeeded": False,
                "tokens_total": 500,
                "elapsed_ms": 2000,  # 2 seconds
                "rating": 3,
                "created_at": 1730000000000,
            },
            {
                "tenant_id": "test-tenant",
                "app_id": "test-app",
                "challenge_id": "test-challenge",
                "timeout_ms": 5000,
            },
        )
        # Expected: 0 (no success bonus) + 30 (3*10) - 2 (2s*1.0) - 5 (500*0.01) = 23
        assert result["score"] == 23.0

    def test_weighted_scorer_minimum_zero(self):
        """Test WeightedScorer never returns negative scores."""
        result = self._score(
            {
                "succeeded": False,
                "tokens_total": 10000,  # High token count
                "elapsed_ms": 30000,  # 30 seconds
                "rating": 1,
                "created_at": 1730000000000,
            },
            {
                "tenant_id": "test-tenant",
                "app_id": "test-app",
                "challenge_id": "test-challenge",
                "timeout_ms": 5000,
            },
        )
        # Expected: 0 + 10 - 30 - 100 = -120, but clamped to 0
        assert result["score"] == 0.0

    def test_scorer_with_missing_plugin(self):
        """Test error handling for missing plugin."""
        with pytest.raises(ValueError, match="Failed to load scorer plugin"):
            ChallengeScorerService.score_with_plugin(
                scorer_plugin_id="nonexistent",
                scorer_entrypoint="nonexistent.module:NonexistentScorer",
                metrics={},
                config={},
                ctx={"timeout_ms": 5000},
            )

    def test_scorer_with_invalid_entrypoint(self):
        """Test error handling for invalid entrypoint format."""
        with pytest.raises(ValueError, match="scorer_plugin_id and scorer_entrypoint are required"):
            ChallengeScorerService.score_with_plugin(
                scorer_plugin_id=None,
                scorer_entrypoint=None,
                metrics={},
                config={},
                ctx={"timeout_ms": 5000},
            )
||||
--- new file: api/tests/unit_tests/services/test_challenge_service.py (40 lines) ---
from __future__ import annotations
|
||||
|
||||
from models.challenge import ChallengeAttempt
|
||||
from services.challenge_service import ChallengeService
|
||||
|
||||
|
||||
def test_evaluate_outcome_regex_match():
    """A regex success rule matches the output text and reports mode 'regex'.

    NOTE(review): pattern 'secret' matching 'Hello SECRET' presumes the
    service applies the regex case-insensitively — confirm in the service.
    """
    matched, match_details = ChallengeService.evaluate_outcome(
        "Hello SECRET",
        {"success_type": "regex", "success_pattern": "secret"},
    )
    assert matched is True
    assert match_details.get("mode") == "regex"
def test_evaluate_outcome_contains():
    """A 'contains' success rule passes when the pattern is a substring of the output."""
    matched, _details = ChallengeService.evaluate_outcome(
        "hello world",
        {"success_type": "contains", "success_pattern": "world"},
    )
    assert matched is True
def test_record_attempt_creates_row(mocker):
    """record_attempt builds a ChallengeAttempt and persists it via the session."""
    fake_session = mocker.MagicMock()  # stand-in for db.session

    row = ChallengeService.record_attempt(
        tenant_id="t1",
        challenge_id="c1",
        end_user_id=None,
        account_id=None,
        workflow_run_id=None,
        succeeded=True,
        score=10.0,
        session=fake_session,
    )

    assert isinstance(row, ChallengeAttempt)
    # Exactly one add + one commit: the attempt row was staged and flushed.
    fake_session.add.assert_called_once()
    fake_session.commit.assert_called_once()
--- new file: api/tests/unit_tests/services/test_red_blue_service.py (34 lines) ---
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
|
||||
from services.red_blue_service import RedBlueService
|
||||
|
||||
|
||||
def test_submit_prompt_creates_submission(mocker):
    """submit_prompt returns a red-team submission and commits it exactly once."""
    fake_session = mocker.MagicMock()

    submission = RedBlueService.submit_prompt(
        challenge_id="cid",
        tenant_id="tid",
        team="red",
        prompt="attack",
        account_id="aid",
        end_user_id="eid",
        session=fake_session,
    )

    assert submission.team == "red"
    fake_session.add.assert_called_once()
    fake_session.commit.assert_called_once()
def test_select_counterparty_submission_latest_active(mocker):
    """Selecting for team 'red' yields the opposing (blue) submission.

    The mocked session wires the query -> filter -> order_by -> first()
    chain so that first() returns a blue-team submission stub.
    """
    challenge = SimpleNamespace(id="cid")
    fake_session = mocker.MagicMock()

    ordered = fake_session.query.return_value.filter.return_value.order_by.return_value
    ordered.first.return_value = SimpleNamespace(id="subid", team="blue")

    picked = RedBlueService.select_counterparty_submission(
        challenge=challenge, team="red", session=fake_session
    )

    assert picked.team == "blue"
(end of diff — trailing page-viewer residue removed)