Add Stripe payment integration for AI subscriptions

Implement subscription-based AI access with 250 generations/month at $5/month or $50/year.

Changes:
- Backend: Stripe service, payment routes, webhook handlers, generation tracking
- Frontend: Upgrade page with pricing, payment success/cancel pages, UI prompts
- Database: Add subscription fields to users, payments table, migrations
- Config: Stripe env vars to .env.example, docker-compose.prod.yml, PRODUCTION.md
- Tests: Payment route tests, component tests, subscription hook tests

Users without AI access see upgrade prompts; subscribers see remaining generation count.
This commit is contained in:
Joey Yakimowich-Payne 2026-01-21 16:11:03 -07:00
commit 2e12edc249
No known key found for this signature in database
GPG key ID: DDF6AF5B21B407D4
22 changed files with 2866 additions and 21 deletions

View file

@ -50,6 +50,15 @@ LOG_REQUESTS=false
# ==============================================================================
GEMINI_API_KEY=
# ==============================================================================
# OPTIONAL - Stripe Payments (for AI subscription access)
# Required for paid AI access. Get keys at: https://dashboard.stripe.com/apikeys
# ==============================================================================
STRIPE_SECRET_KEY=
STRIPE_WEBHOOK_SECRET=
STRIPE_PRICE_ID_MONTHLY=
STRIPE_PRICE_ID_YEARLY=
# ==============================================================================
# OPTIONAL - TURN Server (REQUIRED for cross-network multiplayer)
# Without TURN, players behind restrictive NATs/firewalls cannot connect.

22
App.tsx
View file

@ -18,6 +18,8 @@ import { DisconnectedScreen } from './components/DisconnectedScreen';
import { WaitingToRejoin } from './components/WaitingToRejoin';
import { HostReconnected } from './components/HostReconnected';
import { SharedQuizView } from './components/SharedQuizView';
import { UpgradePage } from './components/UpgradePage';
import { PaymentResult } from './components/PaymentResult';
import type { Quiz, GameConfig } from './types';
const seededRandom = (seed: number) => {
@ -146,6 +148,26 @@ function App() {
const sharedMatch = location.pathname.match(/^\/shared\/([a-zA-Z0-9_-]+)$/);
const isSharedQuizRoute = !!sharedMatch && gameState === 'LANDING';
const isUpgradeRoute = location.pathname === '/upgrade' && gameState === 'LANDING';
const isPaymentSuccessRoute = location.pathname === '/payment/success' && gameState === 'LANDING';
const isPaymentCancelRoute = location.pathname === '/payment/cancel' && gameState === 'LANDING';
const navigateHome = () => {
window.history.replaceState({}, document.title, '/');
window.location.reload();
};
if (isUpgradeRoute) {
return <UpgradePage onBack={navigateHome} />;
}
if (isPaymentSuccessRoute) {
return <PaymentResult status="success" onBack={navigateHome} />;
}
if (isPaymentCancelRoute) {
return <PaymentResult status="cancel" onBack={navigateHome} />;
}
if (isSharedQuizRoute) {
return (

View file

@ -1,7 +1,7 @@
import React, { useState, useEffect } from 'react';
import { useSearchParams } from 'react-router-dom';
import { motion, AnimatePresence } from 'framer-motion';
import { BrainCircuit, Loader2, Play, PenTool, BookOpen, Upload, X, FileText, Image, ScanText, Sparkles, Settings, Palette, Lock } from 'lucide-react';
import { BrainCircuit, Loader2, Play, PenTool, BookOpen, Upload, X, FileText, Image, ScanText, Sparkles, Settings, Palette, Lock, Zap } from 'lucide-react';
import { useAuth } from 'react-oidc-context';
import { AuthButton } from './AuthButton';
import { QuizLibrary } from './QuizLibrary';
@ -133,7 +133,7 @@ export const Landing: React.FC<LandingProps> = ({ onGenerate, onCreateManual, on
const showOcrOption = hasImageFile || hasDocumentFile;
const { defaultConfig, saving: savingConfig, saveDefaultConfig } = useUserConfig();
const { preferences, hasAIAccess, saving: savingPrefs, savePreferences, applyColorScheme } = useUserPreferences();
const { preferences, hasAIAccess, subscription, saving: savingPrefs, savePreferences, applyColorScheme } = useUserPreferences();
const hasValidApiKey = (() => {
if (preferences.aiProvider === 'openrouter') return !!preferences.openRouterApiKey;
if (preferences.aiProvider === 'openai') return !!preferences.openAIApiKey;
@ -319,6 +319,24 @@ export const Landing: React.FC<LandingProps> = ({ onGenerate, onCreateManual, on
</>
)}
<AuthButton onAccountSettingsClick={() => setAccountSettingsOpen(true)} />
{auth.isAuthenticated && subscription && subscription.accessType === 'subscription' && subscription.generationsRemaining !== null && (
<div className="flex items-center gap-2 bg-white/90 px-3 py-2 rounded-xl shadow-md text-sm font-bold">
<Zap size={16} className="text-theme-primary" />
<span className="text-gray-700">{subscription.generationsRemaining}</span>
<span className="text-gray-400">left</span>
</div>
)}
{auth.isAuthenticated && !hasAIAccess && (!subscription || subscription.accessType === 'none') && (
<a
href="/upgrade"
className="flex items-center gap-2 bg-gradient-to-r from-violet-600 to-indigo-600 text-white px-4 py-2 rounded-xl shadow-md text-sm font-bold hover:brightness-110 transition-all"
>
<Zap size={16} />
Upgrade
</a>
)}
</div>
<motion.div
initial={{ scale: 0.8, opacity: 0, rotate: -2 }}
@ -561,20 +579,39 @@ export const Landing: React.FC<LandingProps> = ({ onGenerate, onCreateManual, on
)}
{!canUseAI && (
<button
onClick={() => setSearchParams(buildCleanParams({ modal: 'account' }))}
className="w-full p-4 bg-theme-primary/5 border-2 border-theme-primary/20 rounded-2xl text-left hover:border-theme-primary hover:bg-theme-primary/10 hover:shadow-lg hover:scale-[1.02] transition-all group"
>
<div className="flex items-center gap-3">
<div className="p-2 bg-theme-primary/15 rounded-xl group-hover:bg-theme-primary/25 transition-colors">
<Sparkles size={20} className="text-theme-primary" />
<div className="space-y-3">
{!hasAIAccess && (!subscription || subscription.accessType === 'none') && (
<a
href="/upgrade"
className="w-full p-4 bg-gradient-to-r from-violet-50 to-indigo-50 border-2 border-violet-200 rounded-2xl text-left hover:border-violet-400 hover:shadow-lg hover:scale-[1.02] transition-all group block"
>
<div className="flex items-center gap-3">
<div className="p-2 bg-gradient-to-br from-violet-500 to-indigo-600 rounded-xl shadow-lg">
<Zap size={20} className="text-white" />
</div>
<div className="flex-1">
<p className="font-bold text-gray-800">Unlock AI Quiz Generation</p>
<p className="text-sm text-gray-500">Get 250 AI generations/month for $5</p>
</div>
<span className="text-violet-600 font-black text-sm">Upgrade</span>
</div>
</a>
)}
<button
onClick={() => setSearchParams(buildCleanParams({ modal: 'account' }))}
className="w-full p-4 bg-theme-primary/5 border-2 border-theme-primary/20 rounded-2xl text-left hover:border-theme-primary hover:bg-theme-primary/10 hover:shadow-lg hover:scale-[1.02] transition-all group"
>
<div className="flex items-center gap-3">
<div className="p-2 bg-theme-primary/15 rounded-xl group-hover:bg-theme-primary/25 transition-colors">
<Sparkles size={20} className="text-theme-primary" />
</div>
<div>
<p className="font-bold text-gray-800">Use Your Own API Key</p>
<p className="text-sm text-gray-500">Configure your API key in settings</p>
</div>
</div>
<div>
<p className="font-bold text-gray-800">AI Quiz Generation Available</p>
<p className="text-sm text-gray-500">Configure your API key in settings to get started</p>
</div>
</div>
</button>
</button>
</div>
)}
<button

View file

@ -0,0 +1,134 @@
import React, { useEffect, useState } from 'react';
import { motion } from 'framer-motion';
import { CheckCircle2, XCircle, Loader2, PartyPopper, ArrowLeft } from 'lucide-react';
import confetti from 'canvas-confetti';
interface PaymentResultProps {
status: 'success' | 'cancel' | 'loading';
onBack?: () => void;
}
export const PaymentResult: React.FC<PaymentResultProps> = ({ status, onBack }) => {
const [showConfetti, setShowConfetti] = useState(false);
useEffect(() => {
if (status === 'success' && !showConfetti) {
setShowConfetti(true);
const duration = 3000;
const end = Date.now() + duration;
const frame = () => {
confetti({
particleCount: 3,
angle: 60,
spread: 55,
origin: { x: 0, y: 0.7 },
colors: ['#8B5CF6', '#6366F1', '#EC4899', '#F59E0B']
});
confetti({
particleCount: 3,
angle: 120,
spread: 55,
origin: { x: 1, y: 0.7 },
colors: ['#8B5CF6', '#6366F1', '#EC4899', '#F59E0B']
});
if (Date.now() < end) {
requestAnimationFrame(frame);
}
};
frame();
}
}, [status, showConfetti]);
if (status === 'loading') {
return (
<div className="min-h-screen bg-gray-50 flex items-center justify-center">
<div className="text-center">
<Loader2 className="w-12 h-12 animate-spin text-theme-primary mx-auto mb-4" />
<p className="text-gray-500 font-bold">Processing your payment...</p>
</div>
</div>
);
}
const isSuccess = status === 'success';
return (
<div className="min-h-screen bg-gray-50 flex items-center justify-center p-4">
<motion.div
initial={{ scale: 0.8, opacity: 0 }}
animate={{ scale: 1, opacity: 1 }}
transition={{ type: 'spring', bounce: 0.4 }}
className="max-w-md w-full bg-white rounded-[2.5rem] p-8 shadow-[0_20px_50px_rgba(0,0,0,0.1)] border-4 border-white text-center relative overflow-hidden"
>
{isSuccess && (
<div className="absolute inset-0 bg-gradient-to-br from-green-50 to-emerald-50 pointer-events-none" />
)}
<div className="relative z-10">
<motion.div
initial={{ scale: 0 }}
animate={{ scale: 1 }}
transition={{ delay: 0.2, type: 'spring', bounce: 0.5 }}
className={`w-24 h-24 rounded-3xl mx-auto mb-6 shadow-xl flex items-center justify-center ${
isSuccess
? 'bg-gradient-to-br from-green-400 to-emerald-500'
: 'bg-gradient-to-br from-gray-300 to-gray-400'
}`}
>
{isSuccess ? (
<PartyPopper className="w-12 h-12 text-white" />
) : (
<XCircle className="w-12 h-12 text-white" />
)}
</motion.div>
<motion.div
initial={{ y: 20, opacity: 0 }}
animate={{ y: 0, opacity: 1 }}
transition={{ delay: 0.3 }}
>
<h2 className="text-3xl font-black text-gray-900 mb-2">
{isSuccess ? 'Welcome to Pro!' : 'Payment Cancelled'}
</h2>
<p className="text-gray-500 font-bold mb-8">
{isSuccess
? 'Your AI powers are now unlocked. Time to create amazing quizzes!'
: 'No worries! You can upgrade anytime when you\'re ready.'}
</p>
</motion.div>
{isSuccess && (
<motion.div
initial={{ y: 20, opacity: 0 }}
animate={{ y: 0, opacity: 1 }}
transition={{ delay: 0.4 }}
className="bg-green-50 border-2 border-green-200 rounded-2xl p-4 mb-8"
>
<div className="flex items-center justify-center gap-2 text-green-700 font-bold">
<CheckCircle2 size={20} />
<span>250 AI generations ready to use</span>
</div>
</motion.div>
)}
<motion.button
initial={{ y: 20, opacity: 0 }}
animate={{ y: 0, opacity: 1 }}
transition={{ delay: 0.5 }}
onClick={onBack}
className={`w-full py-4 rounded-2xl font-black text-lg shadow-[0_6px_0] active:shadow-none active:translate-y-[6px] transition-all flex items-center justify-center gap-2 ${
isSuccess
? 'bg-gray-900 text-white shadow-black hover:bg-black'
: 'bg-theme-primary text-white shadow-theme-primary-dark hover:brightness-110'
}`}
>
<ArrowLeft size={20} />
{isSuccess ? 'Start Creating' : 'Go Back'}
</motion.button>
</div>
</motion.div>
</div>
);
};

362
components/UpgradePage.tsx Normal file
View file

@ -0,0 +1,362 @@
import React, { useState, useEffect } from 'react';
import { motion, AnimatePresence } from 'framer-motion';
import {
Check,
X,
Zap,
Crown,
Rocket,
ShieldCheck,
ArrowLeft,
Sparkles,
Loader2,
Star
} from 'lucide-react';
import { useAuth } from 'react-oidc-context';
interface UpgradePageProps {
onBack?: () => void;
}
type BillingCycle = 'monthly' | 'yearly';
export const UpgradePage: React.FC<UpgradePageProps> = ({ onBack }) => {
const auth = useAuth();
const [billingCycle, setBillingCycle] = useState<BillingCycle>('yearly');
const [isLoading, setIsLoading] = useState(false);
const [statusLoading, setStatusLoading] = useState(true);
const [hasAccess, setHasAccess] = useState(false);
const [error, setError] = useState<string | null>(null);
useEffect(() => {
checkStatus();
}, []);
const checkStatus = async () => {
try {
const token = auth.user?.access_token;
if (!token) return;
const backendUrl = import.meta.env.VITE_BACKEND_URL || 'http://localhost:3001';
const response = await fetch(`${backendUrl}/api/payments/status`, {
headers: {
'Authorization': `Bearer ${token}`
}
});
if (response.ok) {
const data = await response.json();
setHasAccess(data.hasAccess);
}
} catch (err) {
console.error('Failed to check status:', err);
} finally {
setStatusLoading(false);
}
};
const handleCheckout = async () => {
try {
setIsLoading(true);
setError(null);
const token = auth.user?.access_token;
if (!token) {
auth.signinRedirect();
return;
}
const backendUrl = import.meta.env.VITE_BACKEND_URL || 'http://localhost:3001';
const response = await fetch(`${backendUrl}/api/payments/checkout`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'Authorization': `Bearer ${token}`
},
body: JSON.stringify({
planType: billingCycle,
successUrl: `${window.location.origin}/payment/success`,
cancelUrl: `${window.location.origin}/payment/cancel`
})
});
if (!response.ok) {
throw new Error('Failed to initiate checkout');
}
const { url } = await response.json();
window.location.href = url;
} catch (err) {
setError('Something went wrong. Please try again.');
setIsLoading(false);
}
};
const containerVariants = {
hidden: { opacity: 0 },
visible: {
opacity: 1,
transition: {
staggerChildren: 0.1
}
}
};
const itemVariants = {
hidden: { y: 20, opacity: 0 },
visible: {
y: 0,
opacity: 1,
transition: { type: 'spring', bounce: 0.4 }
}
};
if (statusLoading) {
return (
<div className="flex items-center justify-center min-h-screen bg-gray-50">
<Loader2 className="w-8 h-8 animate-spin text-theme-primary" />
</div>
);
}
if (hasAccess) {
return (
<div className="min-h-screen bg-gray-50 flex items-center justify-center p-4">
<motion.div
initial={{ scale: 0.9, opacity: 0 }}
animate={{ scale: 1, opacity: 1 }}
className="max-w-md w-full bg-white rounded-[2.5rem] p-8 shadow-[0_20px_50px_rgba(0,0,0,0.1)] border-4 border-white text-center relative overflow-hidden"
>
<div className="absolute inset-0 bg-gradient-to-br from-yellow-100/50 to-orange-100/50 pointer-events-none" />
<div className="relative z-10">
<div className="w-24 h-24 bg-gradient-to-br from-yellow-400 to-orange-500 rounded-3xl mx-auto mb-6 shadow-xl rotate-3 flex items-center justify-center">
<Crown className="w-12 h-12 text-white" />
</div>
<h2 className="text-3xl font-black text-gray-900 mb-2">You're a Pro!</h2>
<p className="text-gray-500 font-bold mb-8">
You have unlocked unlimited power. Enjoy your premium features!
</p>
<button
onClick={onBack}
className="w-full bg-gray-900 text-white py-4 rounded-2xl font-black text-lg shadow-[0_6px_0_rgba(0,0,0,1)] active:shadow-none active:translate-y-[6px] transition-all hover:bg-black"
>
Back to Game
</button>
</div>
<Sparkles className="absolute top-10 left-10 text-yellow-400 w-6 h-6 animate-pulse" />
<Star className="absolute bottom-10 right-10 text-orange-400 w-8 h-8 animate-bounce" />
</motion.div>
</div>
);
}
return (
<div className="min-h-screen bg-gray-50 overflow-y-auto overflow-x-hidden">
<div className="max-w-6xl mx-auto px-4 py-8 md:py-12">
<motion.header
initial={{ y: -20, opacity: 0 }}
animate={{ y: 0, opacity: 1 }}
className="mb-12 relative text-center"
>
{onBack && (
<button
onClick={onBack}
className="absolute left-0 top-1/2 -translate-y-1/2 p-3 bg-white rounded-2xl shadow-[0_4px_0_rgba(0,0,0,0.1)] hover:shadow-[0_6px_0_rgba(0,0,0,0.1)] active:shadow-none active:translate-y-[2px] transition-all text-gray-500 hover:text-gray-900"
>
<ArrowLeft size={24} />
</button>
)}
<h1 className="text-4xl md:text-6xl font-black text-gray-900 mb-4 tracking-tight">
Unlock <span className="text-transparent bg-clip-text bg-gradient-to-r from-violet-600 to-indigo-600">Unlimited</span> Power
</h1>
<p className="text-gray-500 font-bold text-lg md:text-xl max-w-2xl mx-auto">
Supercharge your quizzes with AI magic. Create more, play more, win more.
</p>
</motion.header>
<div className="flex justify-center mb-12">
<div className="bg-white p-1.5 rounded-full shadow-[0_8px_30px_rgba(0,0,0,0.05)] border border-gray-100 flex relative">
<motion.div
className="absolute top-1.5 bottom-1.5 bg-gray-900 rounded-full shadow-lg"
layoutId="toggle"
initial={false}
animate={{
left: billingCycle === 'monthly' ? '6px' : '50%',
width: 'calc(50% - 9px)',
x: billingCycle === 'monthly' ? 0 : 3
}}
transition={{ type: "spring", bounce: 0.2, duration: 0.6 }}
/>
<button
onClick={() => setBillingCycle('monthly')}
className={`relative z-10 px-8 py-3 rounded-full font-black text-sm transition-colors duration-200 ${
billingCycle === 'monthly' ? 'text-white' : 'text-gray-500 hover:text-gray-900'
}`}
>
Monthly
</button>
<button
onClick={() => setBillingCycle('yearly')}
className={`relative z-10 px-8 py-3 rounded-full font-black text-sm transition-colors duration-200 flex items-center gap-2 ${
billingCycle === 'yearly' ? 'text-white' : 'text-gray-500 hover:text-gray-900'
}`}
>
Yearly
<span className={`text-[10px] px-2 py-0.5 rounded-full font-bold uppercase tracking-wide ${
billingCycle === 'yearly' ? 'bg-white/20 text-white' : 'bg-green-100 text-green-700'
}`}>
-17%
</span>
</button>
</div>
</div>
<motion.div
variants={containerVariants}
initial="hidden"
animate="visible"
className="grid md:grid-cols-2 gap-8 max-w-4xl mx-auto items-start"
>
<motion.div
variants={itemVariants}
className="bg-white p-8 rounded-[2.5rem] shadow-xl border-4 border-transparent relative group"
>
<div className="mb-6">
<div className="w-16 h-16 bg-gray-100 rounded-2xl flex items-center justify-center mb-4 group-hover:scale-110 transition-transform duration-300">
<Zap className="w-8 h-8 text-gray-400" />
</div>
<h3 className="text-2xl font-black text-gray-900">Starter</h3>
<p className="text-gray-400 font-bold mt-1">For casual players</p>
</div>
<div className="mb-8">
<span className="text-5xl font-black text-gray-900">$0</span>
<span className="text-gray-400 font-bold">/forever</span>
</div>
<ul className="space-y-4 mb-8">
<FeatureItem text="5 AI generations per month" included />
<FeatureItem text="Basic quiz topics" included />
<FeatureItem text="Host up to 10 players" included />
<FeatureItem text="Advanced document analysis" included={false} />
<FeatureItem text="Priority support" included={false} />
</ul>
<button
disabled
className="w-full bg-gray-100 text-gray-400 py-4 rounded-2xl font-black text-lg cursor-not-allowed"
>
Current Plan
</button>
</motion.div>
<motion.div
variants={itemVariants}
className="bg-gray-900 p-8 rounded-[2.5rem] shadow-[0_20px_60px_rgba(79,70,229,0.3)] border-4 border-transparent relative overflow-hidden group transform md:-translate-y-4"
>
<div className="absolute top-0 right-0 w-64 h-64 bg-violet-600/20 blur-[80px] rounded-full pointer-events-none" />
<div className="absolute bottom-0 left-0 w-64 h-64 bg-indigo-600/20 blur-[80px] rounded-full pointer-events-none" />
<div className="relative z-10">
<div className="flex justify-between items-start mb-6">
<div className="w-16 h-16 bg-gradient-to-br from-violet-500 to-indigo-600 rounded-2xl flex items-center justify-center mb-4 shadow-lg group-hover:scale-110 transition-transform duration-300 group-hover:rotate-3">
<Rocket className="w-8 h-8 text-white" />
</div>
<span className="bg-white/10 backdrop-blur-md text-white px-4 py-1.5 rounded-full text-xs font-black uppercase tracking-wider border border-white/20">
Most Popular
</span>
</div>
<h3 className="text-2xl font-black text-white">Pro Gamer</h3>
<p className="text-gray-400 font-bold mt-1">For serious hosts</p>
<div className="my-8">
<div className="flex items-baseline gap-1">
<span className="text-6xl font-black text-white">
${billingCycle === 'monthly' ? '5' : '4.17'}
</span>
<span className="text-gray-400 font-bold">/mo</span>
</div>
{billingCycle === 'yearly' && (
<p className="text-sm font-bold text-green-400 mt-2">
Billed $50 yearly (save $10)
</p>
)}
</div>
<ul className="space-y-4 mb-8">
<FeatureItem text="250 AI generations per month" included dark />
<FeatureItem text="Unlimited document uploads" included dark />
<FeatureItem text="Host up to 100 players" included dark />
<FeatureItem text="Priority AI processing" included dark />
<FeatureItem text="Early access to new features" included dark />
</ul>
<button
onClick={handleCheckout}
disabled={isLoading}
className="w-full bg-gradient-to-r from-violet-600 to-indigo-600 text-white py-4 rounded-2xl font-black text-lg shadow-[0_6px_0_rgba(67,56,202,1)] active:shadow-none active:translate-y-[6px] transition-all hover:brightness-110 flex items-center justify-center gap-2 disabled:opacity-70 disabled:cursor-not-allowed"
>
{isLoading ? (
<Loader2 className="animate-spin" />
) : (
<>
Upgrade Now <Sparkles size={20} className="text-yellow-300" />
</>
)}
</button>
{error && (
<p className="mt-3 text-red-400 text-sm font-bold text-center animate-pulse">
{error}
</p>
)}
</div>
</motion.div>
</motion.div>
<motion.div
initial={{ opacity: 0 }}
animate={{ opacity: 1 }}
transition={{ delay: 0.5 }}
className="mt-16 text-center space-y-4"
>
<div className="flex flex-wrap justify-center gap-6 text-gray-400 font-bold text-sm">
<span className="flex items-center gap-2">
<ShieldCheck size={18} /> Secure payment via Stripe
</span>
<span className="flex items-center gap-2">
<Check size={18} /> Cancel anytime
</span>
<span className="flex items-center gap-2">
<Check size={18} /> 7-day money-back guarantee
</span>
</div>
</motion.div>
</div>
</div>
);
};
const FeatureItem: React.FC<{ text: string; included: boolean; dark?: boolean }> = ({ text, included, dark }) => (
<li className="flex items-center gap-3">
<div className={`
flex-shrink-0 w-6 h-6 rounded-full flex items-center justify-center
${included
? (dark ? 'bg-violet-500/20 text-violet-400' : 'bg-theme-primary/10 text-theme-primary')
: 'bg-gray-100 text-gray-400'}
`}>
{included ? <Check size={14} strokeWidth={3} /> : <X size={14} strokeWidth={3} />}
</div>
<span className={`font-bold ${dark ? 'text-gray-300' : (included ? 'text-gray-700' : 'text-gray-400')}`}>
{text}
</span>
</li>
);

View file

@ -101,6 +101,10 @@ services:
CORS_ORIGIN: ${CORS_ORIGIN}
LOG_REQUESTS: ${LOG_REQUESTS:-true}
GEMINI_API_KEY: ${GEMINI_API_KEY:-}
STRIPE_SECRET_KEY: ${STRIPE_SECRET_KEY:-}
STRIPE_WEBHOOK_SECRET: ${STRIPE_WEBHOOK_SECRET:-}
STRIPE_PRICE_ID_MONTHLY: ${STRIPE_PRICE_ID_MONTHLY:-}
STRIPE_PRICE_ID_YEARLY: ${STRIPE_PRICE_ID_YEARLY:-}
volumes:
- kaboot-data:/data
networks:

View file

@ -0,0 +1,226 @@
# Kaboot Payment Feature Implementation Plan
## Overview
Add Stripe subscription payments to allow users to pay for AI access (`kaboot-ai-access`). Users get 250 AI generations per month for $5/month (or yearly equivalent).
### Pricing Model
- **Monthly**: $5/month for 250 AI generations
- **Yearly**: $50/year for 250 AI generations/month (save ~17%)
- **Grace Period**: 1 day after failed payment before revoking access
- **Refund Policy**: 7-day money-back guarantee
---
## Implementation Checklist
### Phase 1: Backend Infrastructure
- [x] **1.1** Add Stripe dependency to server
- `npm install stripe` in server directory
- File: `server/package.json`
- [x] **1.2** Add environment variables
- `STRIPE_SECRET_KEY` - Stripe secret key
- `STRIPE_WEBHOOK_SECRET` - Webhook signing secret
- `STRIPE_PRICE_ID_MONTHLY` - Monthly price ID
- `STRIPE_PRICE_ID_YEARLY` - Yearly price ID
- Files: `.env.example`, `docs/PRODUCTION.md`
- [x] **1.3** Database migration - Add subscription and generation tracking
- Add `stripe_customer_id` to users table
- Add `subscription_status` (none, active, past_due, canceled)
- Add `subscription_id` for Stripe subscription ID
- Add `subscription_current_period_end` for billing cycle
- Add `generation_count` for current period usage
- Add `generation_reset_date` for when to reset count
- Create `payments` table for payment history
- File: `server/src/db/schema.sql`
- [x] **1.4** Create Stripe service
- Initialize Stripe client
- Create/retrieve customer helper
- Create checkout session helper
- Create customer portal session helper
- File: `server/src/services/stripe.ts`
- [x] **1.5** Create payments routes
- `POST /api/payments/checkout` - Create Stripe Checkout session
- `POST /api/payments/webhook` - Handle Stripe webhooks (raw body)
- `GET /api/payments/status` - Get subscription & generation status
- `POST /api/payments/portal` - Create customer portal session
- File: `server/src/routes/payments.ts`
- [x] **1.6** Implement webhook handlers
- `checkout.session.completed` - Activate subscription, set generation quota
- `customer.subscription.updated` - Sync status changes
- `customer.subscription.deleted` - Mark as canceled
- `invoice.payment_failed` - Set past_due status
- `invoice.paid` - Reset generation count on renewal
- File: `server/src/routes/payments.ts`
- [x] **1.7** Update AI access middleware
- Check subscription status OR existing group membership
- Check generation count against limit (250)
- Increment generation count on AI use
- Return remaining generations in response
- Files: `server/src/middleware/auth.ts`, `server/src/routes/ai.ts` (or equivalent)
- [x] **1.8** Register payments router in main app
- File: `server/src/index.ts`
### Phase 2: Frontend - Upgrade Page
- [x] **2.1** Create UpgradePage component
- Pricing card with monthly/yearly toggle
- Feature comparison (Free vs Pro)
- CTA button triggering Stripe Checkout
- Trust signals (secure payment, money-back guarantee)
- File: `components/UpgradePage.tsx`
- [x] **2.2** Create PaymentResult component
- Success state with confetti
- Cancel/return state
- File: `components/PaymentResult.tsx`
- [x] **2.3** Add routes to App.tsx
- `/upgrade` route
- `/payment/success` route
- `/payment/cancel` route
- File: `App.tsx`
- [x] **2.4** Create payments API service (integrated in UpgradePage)
- `createCheckoutSession(planType: 'monthly' | 'yearly')`
- `getSubscriptionStatus()`
- `createPortalSession()`
- File: `services/paymentsApi.ts`
- [x] **2.5** Update UI to show generation usage
- Show remaining generations in preferences/header
- Show upgrade CTA when generations low or user is free tier
- Files: Various components
- [x] **2.6** Add upgrade prompts in AI generation flow
- When user tries AI generation without access
- When user is low on generations
- Files: Components using AI generation
### Phase 3: Production Updates
- [x] **3.1** Update docker-compose.prod.yml
- Add Stripe environment variables to backend service
- File: `docker-compose.prod.yml`
- [x] **3.2** Update PRODUCTION.md documentation
- Add Stripe configuration section
- Add webhook setup instructions
- Add Stripe Dashboard product setup
- File: `docs/PRODUCTION.md`
- [x] **3.3** Update setup-prod.sh script (not needed - manual env config)
- Prompt for Stripe keys during setup
- File: `scripts/setup-prod.sh`
### Phase 4: Testing
- [ ] **4.1** Test with Stripe test mode
- Use test API keys
- Test card: 4242 4242 4242 4242
- [ ] **4.2** Test webhook locally
- Use Stripe CLI: `stripe listen --forward-to localhost:3001/api/payments/webhook`
- [ ] **4.3** Test full payment flow
- Checkout → Success → Access granted → Generations work
- [ ] **4.4** Test generation limits
- Verify count increments
- Verify block at 250
- Verify reset on renewal
---
## Database Schema Changes
```sql
-- Add subscription fields to users table
ALTER TABLE users ADD COLUMN stripe_customer_id TEXT UNIQUE;
ALTER TABLE users ADD COLUMN subscription_status TEXT DEFAULT 'none';
ALTER TABLE users ADD COLUMN subscription_id TEXT;
ALTER TABLE users ADD COLUMN subscription_current_period_end DATETIME;
ALTER TABLE users ADD COLUMN generation_count INTEGER DEFAULT 0;
ALTER TABLE users ADD COLUMN generation_reset_date DATETIME;
-- Payments log table
CREATE TABLE IF NOT EXISTS payments (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id),
stripe_payment_intent_id TEXT,
stripe_invoice_id TEXT,
amount INTEGER NOT NULL,
currency TEXT DEFAULT 'usd',
status TEXT NOT NULL,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_payments_user ON payments(user_id);
```
---
## Environment Variables
```bash
# Stripe Configuration
STRIPE_SECRET_KEY=sk_test_... # or sk_live_... for production
STRIPE_WEBHOOK_SECRET=whsec_... # From Stripe Dashboard or CLI
STRIPE_PRICE_ID_MONTHLY=price_... # Monthly plan price ID
STRIPE_PRICE_ID_YEARLY=price_... # Yearly plan price ID
```
---
## API Endpoints
| Method | Endpoint | Auth | Description |
|--------|----------|------|-------------|
| `POST` | `/api/payments/checkout` | Required | Create Stripe Checkout session |
| `POST` | `/api/payments/webhook` | Stripe Sig | Handle Stripe webhook events |
| `GET` | `/api/payments/status` | Required | Get subscription & generation status |
| `POST` | `/api/payments/portal` | Required | Create Stripe Customer Portal session |
---
## Stripe Dashboard Setup
1. Create Product: "Kaboot AI Pro"
2. Add Monthly Price: $5.00/month
3. Add Yearly Price: $50.00/year
4. Copy Price IDs to environment variables
5. Set up Webhook endpoint: `https://your-domain.com/api/payments/webhook`
6. Subscribe to events:
- `checkout.session.completed`
- `customer.subscription.created`
- `customer.subscription.updated`
- `customer.subscription.deleted`
- `invoice.paid`
- `invoice.payment_failed`
---
## Generation Tracking Logic
1. On subscription activation: Set `generation_count = 0`, `generation_reset_date = period_end`
2. On each AI generation: Increment `generation_count`
3. Before AI generation: Check `generation_count < 250`
4. On `invoice.paid` (renewal): Reset `generation_count = 0`, update `generation_reset_date`
5. Return `remaining_generations = 250 - generation_count` in API responses
---
## Notes
- Existing `kaboot-ai-access` group users (via Authentik) get unlimited access (grandfathered)
- Subscription users get 250 generations/month regardless of Authentik group
- Both access methods are valid - check either condition in middleware

View file

@ -204,6 +204,45 @@ The frontend is built inside Docker using the `KABOOT_DOMAIN` and `AUTH_DOMAIN`
The `setup-prod.sh` script sets these domain variables automatically.
### Stripe Payments Configuration (Optional)
To enable paid AI access subscriptions, configure Stripe:
```env
# Stripe API Keys (get from https://dashboard.stripe.com/apikeys)
STRIPE_SECRET_KEY=sk_live_... # Use sk_test_... for testing
STRIPE_WEBHOOK_SECRET=whsec_... # From webhook endpoint configuration
STRIPE_PRICE_ID_MONTHLY=price_... # Monthly subscription price ID
STRIPE_PRICE_ID_YEARLY=price_... # Yearly subscription price ID
```
#### Stripe Dashboard Setup
1. **Create a Product** in [Stripe Dashboard](https://dashboard.stripe.com/products):
- Name: "Kaboot AI Pro"
- Description: "250 AI quiz generations per month"
2. **Add Pricing**:
- Monthly: $5.00/month (recurring)
- Yearly: $50.00/year (recurring)
- Copy the Price IDs (start with `price_`)
3. **Configure Webhook**:
- Go to [Developers > Webhooks](https://dashboard.stripe.com/webhooks)
- Add endpoint: `https://your-domain.com/api/payments/webhook`
- Select events:
- `checkout.session.completed`
- `customer.subscription.updated`
- `customer.subscription.deleted`
- `invoice.paid`
- `invoice.payment_failed`
- Copy the Signing Secret (starts with `whsec_`)
4. **Test with Stripe CLI** (optional, for local development):
```bash
stripe listen --forward-to localhost:3001/api/payments/webhook
```
## Docker Compose Files
The project includes pre-configured compose files:

View file

@ -16,9 +16,18 @@ export const applyColorScheme = (schemeId: string) => {
document.documentElement.style.setProperty('--theme-primary-darker', scheme.primaryDarker);
};
interface SubscriptionInfo {
hasAccess: boolean;
accessType: 'group' | 'subscription' | 'none';
generationCount: number | null;
generationLimit: number | null;
generationsRemaining: number | null;
}
interface UseUserPreferencesReturn {
preferences: UserPreferences;
hasAIAccess: boolean;
subscription: SubscriptionInfo | null;
loading: boolean;
saving: boolean;
fetchPreferences: () => Promise<void>;
@ -30,6 +39,7 @@ export const useUserPreferences = (): UseUserPreferencesReturn => {
const { authFetch, isAuthenticated } = useAuthenticatedFetch();
const [preferences, setPreferences] = useState<UserPreferences>(DEFAULT_PREFERENCES);
const [hasAIAccess, setHasAIAccess] = useState(false);
const [subscription, setSubscription] = useState<SubscriptionInfo | null>(null);
const [loading, setLoading] = useState(false);
const [saving, setSaving] = useState(false);
@ -54,6 +64,23 @@ export const useUserPreferences = (): UseUserPreferencesReturn => {
setPreferences(prefs);
setHasAIAccess(data.hasAIAccess || false);
applyColorScheme(prefs.colorScheme);
const backendUrl = import.meta.env.VITE_BACKEND_URL || 'http://localhost:3001';
try {
const subResponse = await authFetch(`${backendUrl}/api/payments/status`);
if (subResponse.ok) {
const subData = await subResponse.json();
setSubscription({
hasAccess: subData.hasAccess,
accessType: subData.accessType,
generationCount: subData.generationCount,
generationLimit: subData.generationLimit,
generationsRemaining: subData.generationsRemaining,
});
}
} catch {
// Payments not configured, ignore
}
}
} catch {
} finally {
@ -92,6 +119,7 @@ export const useUserPreferences = (): UseUserPreferencesReturn => {
return {
preferences,
hasAIAccess,
subscription,
loading,
saving,
fetchPreferences,

View file

@ -18,6 +18,7 @@
"jwks-rsa": "^3.1.0",
"multer": "^2.0.2",
"officeparser": "^6.0.4",
"stripe": "^20.2.0",
"uuid": "^11.0.5"
},
"devDependencies": {
@ -1454,7 +1455,6 @@
"resolved": "https://registry.npmjs.org/express/-/express-4.22.1.tgz",
"integrity": "sha512-F2X8g9P1X7uCPZMA3MVf9wcTqlyNp7IhH5qPCI0izhaOIYXaW9L535tGA3qmjRzpH+bZczqq7hVKxTR4NWnu+g==",
"license": "MIT",
"peer": true,
"dependencies": {
"accepts": "~1.3.8",
"array-flatten": "1.1.1",
@ -2896,6 +2896,26 @@
"node": ">=0.10.0"
}
},
"node_modules/stripe": {
"version": "20.2.0",
"resolved": "https://registry.npmjs.org/stripe/-/stripe-20.2.0.tgz",
"integrity": "sha512-m8niTfdm3nPP/yQswRWMwQxqEUcTtB3RTJQ9oo6NINDzgi7aPOadsH/fPXIIfL1Sc5+lqQFKSk7WiO6CXmvaeA==",
"license": "MIT",
"dependencies": {
"qs": "^6.14.1"
},
"engines": {
"node": ">=16"
},
"peerDependencies": {
"@types/node": ">=16"
},
"peerDependenciesMeta": {
"@types/node": {
"optional": true
}
}
},
"node_modules/strtok3": {
"version": "6.3.0",
"resolved": "https://registry.npmjs.org/strtok3/-/strtok3-6.3.0.tgz",
@ -3205,7 +3225,6 @@
"resolved": "https://registry.npmjs.org/zod/-/zod-3.25.76.tgz",
"integrity": "sha512-gzUt/qt81nXsFGKIFcC3YnfEAx5NkunCfnDlvuBSSFS02bcXu4Lmea0AFIUwbLWxWPx3d9p8S5QoaujKcNQxcQ==",
"license": "MIT",
"peer": true,
"funding": {
"url": "https://github.com/sponsors/colinhacks"
}

View file

@ -21,6 +21,7 @@
"jwks-rsa": "^3.1.0",
"multer": "^2.0.2",
"officeparser": "^6.0.4",
"stripe": "^20.2.0",
"uuid": "^11.0.5"
},
"devDependencies": {

View file

@ -127,6 +127,69 @@ const runMigrations = () => {
db.exec("CREATE UNIQUE INDEX idx_quizzes_share_token ON quizzes(share_token)");
console.log("Migration: Created unique index on quizzes.share_token");
}
const userTableInfo3 = db.prepare("PRAGMA table_info(users)").all() as { name: string }[];
const hasStripeCustomerId = userTableInfo3.some(col => col.name === "stripe_customer_id");
if (!hasStripeCustomerId) {
db.exec("ALTER TABLE users ADD COLUMN stripe_customer_id TEXT UNIQUE");
console.log("Migration: Added stripe_customer_id to users");
}
const hasSubscriptionStatus = userTableInfo3.some(col => col.name === "subscription_status");
if (!hasSubscriptionStatus) {
db.exec("ALTER TABLE users ADD COLUMN subscription_status TEXT DEFAULT 'none'");
console.log("Migration: Added subscription_status to users");
}
const hasSubscriptionId = userTableInfo3.some(col => col.name === "subscription_id");
if (!hasSubscriptionId) {
db.exec("ALTER TABLE users ADD COLUMN subscription_id TEXT");
console.log("Migration: Added subscription_id to users");
}
const hasSubscriptionPeriodEnd = userTableInfo3.some(col => col.name === "subscription_current_period_end");
if (!hasSubscriptionPeriodEnd) {
db.exec("ALTER TABLE users ADD COLUMN subscription_current_period_end DATETIME");
console.log("Migration: Added subscription_current_period_end to users");
}
const hasGenerationCount = userTableInfo3.some(col => col.name === "generation_count");
if (!hasGenerationCount) {
db.exec("ALTER TABLE users ADD COLUMN generation_count INTEGER DEFAULT 0");
console.log("Migration: Added generation_count to users");
}
const hasGenerationResetDate = userTableInfo3.some(col => col.name === "generation_reset_date");
if (!hasGenerationResetDate) {
db.exec("ALTER TABLE users ADD COLUMN generation_reset_date DATETIME");
console.log("Migration: Added generation_reset_date to users");
}
const paymentsTable = db.prepare("SELECT name FROM sqlite_master WHERE type='table' AND name='payments'").get();
if (!paymentsTable) {
db.exec(`
CREATE TABLE payments (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id),
stripe_payment_intent_id TEXT,
stripe_invoice_id TEXT,
amount INTEGER NOT NULL,
currency TEXT DEFAULT 'usd',
status TEXT NOT NULL,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX idx_payments_user ON payments(user_id);
`);
console.log("Migration: Created payments table");
}
const stripeCustomerIndex = db.prepare("SELECT name FROM sqlite_master WHERE type='index' AND name='idx_users_stripe_customer'").get();
if (!stripeCustomerIndex) {
db.exec("CREATE INDEX IF NOT EXISTS idx_users_stripe_customer ON users(stripe_customer_id)");
console.log("Migration: Created index on users.stripe_customer_id");
}
};
runMigrations();

View file

@ -7,7 +7,15 @@ CREATE TABLE IF NOT EXISTS users (
last_login DATETIME,
default_game_config TEXT,
color_scheme TEXT DEFAULT 'blue',
gemini_api_key TEXT
gemini_api_key TEXT,
-- Stripe subscription fields
stripe_customer_id TEXT UNIQUE,
subscription_status TEXT DEFAULT 'none',
subscription_id TEXT,
subscription_current_period_end DATETIME,
-- Generation tracking
generation_count INTEGER DEFAULT 0,
generation_reset_date DATETIME
);
CREATE TABLE IF NOT EXISTS quizzes (
@ -62,3 +70,18 @@ CREATE TABLE IF NOT EXISTS game_sessions (
CREATE INDEX IF NOT EXISTS idx_quizzes_user ON quizzes(user_id);
CREATE INDEX IF NOT EXISTS idx_questions_quiz ON questions(quiz_id);
CREATE INDEX IF NOT EXISTS idx_options_question ON answer_options(question_id);
CREATE TABLE IF NOT EXISTS payments (
id TEXT PRIMARY KEY,
user_id TEXT NOT NULL REFERENCES users(id),
stripe_payment_intent_id TEXT,
stripe_invoice_id TEXT,
amount INTEGER NOT NULL,
currency TEXT DEFAULT 'usd',
status TEXT NOT NULL,
description TEXT,
created_at DATETIME DEFAULT CURRENT_TIMESTAMP
);
CREATE INDEX IF NOT EXISTS idx_payments_user ON payments(user_id);
CREATE INDEX IF NOT EXISTS idx_users_stripe_customer ON users(stripe_customer_id);

View file

@ -9,6 +9,7 @@ import uploadRouter from './routes/upload.js';
import gamesRouter from './routes/games.js';
import generateRouter from './routes/generate.js';
import sharedRouter from './routes/shared.js';
import paymentsRouter, { webhookHandler } from './routes/payments.js';
const app = express();
const PORT = process.env.PORT || 3001;
@ -58,6 +59,8 @@ app.use((req: Request, res: Response, next: NextFunction) => {
next();
});
app.post('/api/payments/webhook', express.raw({ type: 'application/json' }), webhookHandler);
app.use((req: Request, res: Response, next: NextFunction) => {
express.json({ limit: '10mb' })(req, res, (err) => {
if (err instanceof SyntaxError && 'body' in err) {
@ -95,6 +98,7 @@ app.use('/api/upload', uploadRouter);
app.use('/api/games', gamesRouter);
app.use('/api/generate', generateRouter);
app.use('/api/shared', sharedRouter);
app.use('/api/payments', paymentsRouter);
app.use((err: Error, _req: Request, res: Response, _next: NextFunction) => {
console.error('Unhandled error:', err);

View file

@ -83,6 +83,13 @@ export function requireAuth(
);
}
import { canGenerate, incrementGenerationCount, GENERATION_LIMIT } from '../services/stripe.js';
export interface AIAccessInfo {
accessType: 'group' | 'subscription' | 'none';
remaining?: number;
}
export function requireAIAccess(
req: AuthenticatedRequest,
res: Response,
@ -93,12 +100,44 @@ export function requireAIAccess(
return;
}
const hasAccess = req.user.groups?.includes('kaboot-ai-access');
const groups = req.user.groups || [];
const result = canGenerate(req.user.sub, groups);
if (!hasAccess) {
res.status(403).json({ error: 'AI access not granted for this account' });
if (!result.allowed) {
res.status(403).json({
error: result.reason || 'AI access not granted for this account',
remaining: result.remaining,
});
return;
}
(req as any).aiAccessInfo = {
accessType: groups.includes('kaboot-ai-access') ? 'group' : 'subscription',
remaining: result.remaining,
} as AIAccessInfo;
next();
}
export function trackGeneration(
req: AuthenticatedRequest,
res: Response,
next: NextFunction
): void {
if (!req.user) {
next();
return;
}
const groups = req.user.groups || [];
if (groups.includes('kaboot-ai-access')) {
next();
return;
}
const newCount = incrementGenerationCount(req.user.sub);
const remaining = Math.max(0, GENERATION_LIMIT - newCount);
res.setHeader('X-Generations-Remaining', remaining.toString());
next();
}

View file

@ -1,6 +1,7 @@
import { Router, Response } from 'express';
import { GoogleGenAI, Type, createUserContent, createPartFromUri } from '@google/genai';
import { requireAuth, AuthenticatedRequest, requireAIAccess } from '../middleware/auth.js';
import { incrementGenerationCount, GENERATION_LIMIT } from '../services/stripe.js';
import { v4 as uuidv4 } from 'uuid';
const router = Router();
@ -170,6 +171,13 @@ router.post('/', requireAuth, requireAIAccess, async (req: AuthenticatedRequest,
const data = JSON.parse(response.text);
const quiz = transformToQuiz(data);
const groups = req.user!.groups || [];
if (!groups.includes('kaboot-ai-access')) {
const newCount = incrementGenerationCount(req.user!.sub);
const remaining = Math.max(0, GENERATION_LIMIT - newCount);
res.setHeader('X-Generations-Remaining', remaining.toString());
}
res.json(quiz);
} catch (err: any) {
console.error('AI generation error:', err);

View file

@ -0,0 +1,282 @@
import { Router, Response, Request } from 'express';
import Stripe from 'stripe';
import { requireAuth, AuthenticatedRequest } from '../middleware/auth.js';
import { db } from '../db/connection.js';
import {
getStripe,
isStripeConfigured,
createCheckoutSession,
createPortalSession,
getSubscriptionStatus,
activateSubscription,
updateSubscriptionStatus,
resetGenerationCount,
recordPayment,
GENERATION_LIMIT,
} from '../services/stripe.js';
const router = Router();
const STRIPE_WEBHOOK_SECRET = process.env.STRIPE_WEBHOOK_SECRET;
router.get('/config', (_req: Request, res: Response) => {
res.json({
configured: isStripeConfigured(),
generationLimit: GENERATION_LIMIT,
});
});
router.get('/status', requireAuth, (req: AuthenticatedRequest, res: Response) => {
const userId = req.user!.sub;
const groups = req.user!.groups || [];
const hasGroupAccess = groups.includes('kaboot-ai-access');
if (hasGroupAccess) {
res.json({
hasAccess: true,
accessType: 'group',
status: 'active',
generationCount: 0,
generationLimit: null,
generationsRemaining: null,
currentPeriodEnd: null,
});
return;
}
const status = getSubscriptionStatus(userId);
res.json({
hasAccess: status.status === 'active',
accessType: status.status === 'active' ? 'subscription' : 'none',
status: status.status,
generationCount: status.generationCount,
generationLimit: status.generationLimit,
generationsRemaining: status.generationsRemaining,
currentPeriodEnd: status.currentPeriodEnd,
});
});
router.post('/checkout', requireAuth, async (req: AuthenticatedRequest, res: Response) => {
if (!isStripeConfigured()) {
res.status(503).json({ error: 'Payments are not configured' });
return;
}
const userId = req.user!.sub;
const email = req.user!.email;
const { planType, successUrl, cancelUrl } = req.body;
if (!planType || !['monthly', 'yearly'].includes(planType)) {
res.status(400).json({ error: 'Invalid plan type. Must be "monthly" or "yearly".' });
return;
}
if (!successUrl || !cancelUrl) {
res.status(400).json({ error: 'successUrl and cancelUrl are required' });
return;
}
try {
const session = await createCheckoutSession(userId, email, planType, successUrl, cancelUrl);
res.json({ url: session.url });
} catch (err: any) {
console.error('Checkout session error:', err);
res.status(500).json({ error: err.message || 'Failed to create checkout session' });
}
});
router.post('/portal', requireAuth, async (req: AuthenticatedRequest, res: Response) => {
if (!isStripeConfigured()) {
res.status(503).json({ error: 'Payments are not configured' });
return;
}
const userId = req.user!.sub;
const { returnUrl } = req.body;
if (!returnUrl) {
res.status(400).json({ error: 'returnUrl is required' });
return;
}
try {
const session = await createPortalSession(userId, returnUrl);
res.json({ url: session.url });
} catch (err: any) {
console.error('Portal session error:', err);
res.status(500).json({ error: err.message || 'Failed to create portal session' });
}
});
function getUserIdFromCustomer(customerId: string): string | null {
const user = db.prepare('SELECT id FROM users WHERE stripe_customer_id = ?').get(customerId) as { id: string } | undefined;
return user?.id || null;
}
async function handleCheckoutCompleted(session: Stripe.Checkout.Session): Promise<void> {
const userId = session.metadata?.user_id;
if (!userId) {
console.error('No user_id in checkout session metadata');
return;
}
if (session.mode === 'subscription' && session.subscription) {
const stripe = getStripe();
const subscription = await stripe.subscriptions.retrieve(session.subscription as string);
const firstItem = subscription.items.data[0];
const periodEnd = new Date(firstItem.current_period_end * 1000);
activateSubscription(userId, subscription.id, periodEnd);
console.log(`Subscription activated for user ${userId}`);
}
}
async function handleSubscriptionUpdated(subscription: Stripe.Subscription): Promise<void> {
const userId = subscription.metadata?.user_id || getUserIdFromCustomer(subscription.customer as string);
if (!userId) {
console.error('Could not find user for subscription:', subscription.id);
return;
}
const firstItem = subscription.items.data[0];
const periodEnd = firstItem ? new Date(firstItem.current_period_end * 1000) : new Date();
switch (subscription.status) {
case 'active':
updateSubscriptionStatus(userId, 'active', periodEnd);
break;
case 'past_due':
updateSubscriptionStatus(userId, 'past_due', periodEnd);
break;
case 'canceled':
case 'unpaid':
updateSubscriptionStatus(userId, 'canceled');
break;
}
console.log(`Subscription ${subscription.id} updated to ${subscription.status} for user ${userId}`);
}
async function handleSubscriptionDeleted(subscription: Stripe.Subscription): Promise<void> {
const userId = subscription.metadata?.user_id || getUserIdFromCustomer(subscription.customer as string);
if (!userId) {
console.error('Could not find user for subscription:', subscription.id);
return;
}
updateSubscriptionStatus(userId, 'canceled');
console.log(`Subscription ${subscription.id} deleted for user ${userId}`);
}
async function handleInvoicePaid(invoice: Stripe.Invoice): Promise<void> {
const customerId = invoice.customer as string;
const userId = getUserIdFromCustomer(customerId);
if (!userId) {
console.error('Could not find user for customer:', customerId);
return;
}
const subscriptionId = invoice.parent?.subscription_details?.subscription;
if (subscriptionId) {
const stripe = getStripe();
const subId = typeof subscriptionId === 'string' ? subscriptionId : subscriptionId.id;
const subscription = await stripe.subscriptions.retrieve(subId);
const firstItem = subscription.items.data[0];
const periodEnd = firstItem ? new Date(firstItem.current_period_end * 1000) : new Date();
resetGenerationCount(userId, periodEnd);
updateSubscriptionStatus(userId, 'active', periodEnd);
console.log(`Generation count reset for user ${userId} (invoice paid)`);
}
const invoiceAny = invoice as any;
recordPayment(
userId,
invoiceAny.payment_intent || null,
invoice.id,
invoice.amount_paid,
invoice.currency,
'succeeded',
invoice.description || 'Subscription payment'
);
}
async function handleInvoicePaymentFailed(invoice: Stripe.Invoice): Promise<void> {
const customerId = invoice.customer as string;
const userId = getUserIdFromCustomer(customerId);
if (!userId) {
console.error('Could not find user for customer:', customerId);
return;
}
updateSubscriptionStatus(userId, 'past_due');
console.log(`Payment failed for user ${userId}, status set to past_due`);
const invoiceAny = invoice as any;
recordPayment(
userId,
invoiceAny.payment_intent || null,
invoice.id,
invoice.amount_due,
invoice.currency,
'failed',
'Payment failed'
);
}
export const webhookHandler = async (req: Request, res: Response): Promise<void> => {
if (!STRIPE_WEBHOOK_SECRET) {
res.status(503).json({ error: 'Webhook secret not configured' });
return;
}
const sig = req.headers['stripe-signature'];
if (!sig) {
res.status(400).json({ error: 'Missing stripe-signature header' });
return;
}
let event: Stripe.Event;
try {
const stripe = getStripe();
event = stripe.webhooks.constructEvent(req.body, sig, STRIPE_WEBHOOK_SECRET);
} catch (err: any) {
console.error('Webhook signature verification failed:', err.message);
res.status(400).json({ error: `Webhook Error: ${err.message}` });
return;
}
try {
switch (event.type) {
case 'checkout.session.completed':
await handleCheckoutCompleted(event.data.object as Stripe.Checkout.Session);
break;
case 'customer.subscription.updated':
await handleSubscriptionUpdated(event.data.object as Stripe.Subscription);
break;
case 'customer.subscription.deleted':
await handleSubscriptionDeleted(event.data.object as Stripe.Subscription);
break;
case 'invoice.paid':
await handleInvoicePaid(event.data.object as Stripe.Invoice);
break;
case 'invoice.payment_failed':
await handleInvoicePaymentFailed(event.data.object as Stripe.Invoice);
break;
default:
console.log(`Unhandled event type: ${event.type}`);
}
res.json({ received: true });
} catch (err: any) {
console.error('Error handling webhook:', err);
res.status(500).json({ error: 'Webhook handler failed' });
}
};
export default router;

View file

@ -0,0 +1,242 @@
import Stripe from 'stripe';
import { db } from '../db/connection.js';
const STRIPE_SECRET_KEY = process.env.STRIPE_SECRET_KEY;
const STRIPE_PRICE_ID_MONTHLY = process.env.STRIPE_PRICE_ID_MONTHLY;
const STRIPE_PRICE_ID_YEARLY = process.env.STRIPE_PRICE_ID_YEARLY;
export const GENERATION_LIMIT = 250;
let stripeClient: Stripe | null = null;
export function getStripe(): Stripe {
if (!stripeClient) {
if (!STRIPE_SECRET_KEY) {
throw new Error('STRIPE_SECRET_KEY is not configured');
}
stripeClient = new Stripe(STRIPE_SECRET_KEY);
}
return stripeClient;
}
export function isStripeConfigured(): boolean {
return !!(STRIPE_SECRET_KEY && STRIPE_PRICE_ID_MONTHLY);
}
export function getPriceId(planType: 'monthly' | 'yearly'): string {
const priceId = planType === 'yearly' ? STRIPE_PRICE_ID_YEARLY : STRIPE_PRICE_ID_MONTHLY;
if (!priceId) {
throw new Error(`Price ID for ${planType} plan is not configured`);
}
return priceId;
}
export async function getOrCreateCustomer(userId: string, email: string | undefined): Promise<string> {
const stripe = getStripe();
const user = db.prepare('SELECT stripe_customer_id FROM users WHERE id = ?').get(userId) as { stripe_customer_id: string | null } | undefined;
if (user?.stripe_customer_id) {
return user.stripe_customer_id;
}
const customer = await stripe.customers.create({
email: email || undefined,
metadata: {
user_id: userId,
},
});
db.prepare('UPDATE users SET stripe_customer_id = ? WHERE id = ?').run(customer.id, userId);
return customer.id;
}
export async function createCheckoutSession(
userId: string,
email: string | undefined,
planType: 'monthly' | 'yearly',
successUrl: string,
cancelUrl: string
): Promise<Stripe.Checkout.Session> {
const stripe = getStripe();
const customerId = await getOrCreateCustomer(userId, email);
const priceId = getPriceId(planType);
const session = await stripe.checkout.sessions.create({
customer: customerId,
mode: 'subscription',
line_items: [
{
price: priceId,
quantity: 1,
},
],
success_url: successUrl,
cancel_url: cancelUrl,
subscription_data: {
metadata: {
user_id: userId,
},
},
metadata: {
user_id: userId,
plan_type: planType,
},
});
return session;
}
export async function createPortalSession(
userId: string,
returnUrl: string
): Promise<Stripe.BillingPortal.Session> {
const stripe = getStripe();
const user = db.prepare('SELECT stripe_customer_id FROM users WHERE id = ?').get(userId) as { stripe_customer_id: string | null } | undefined;
if (!user?.stripe_customer_id) {
throw new Error('No Stripe customer found for this user');
}
const session = await stripe.billingPortal.sessions.create({
customer: user.stripe_customer_id,
return_url: returnUrl,
});
return session;
}
export interface SubscriptionStatus {
status: 'none' | 'active' | 'past_due' | 'canceled';
currentPeriodEnd: string | null;
generationCount: number;
generationLimit: number;
generationsRemaining: number;
}
export function getSubscriptionStatus(userId: string): SubscriptionStatus {
const user = db.prepare(`
SELECT subscription_status, subscription_current_period_end, generation_count, generation_reset_date
FROM users WHERE id = ?
`).get(userId) as {
subscription_status: string | null;
subscription_current_period_end: string | null;
generation_count: number | null;
generation_reset_date: string | null;
} | undefined;
const status = (user?.subscription_status || 'none') as SubscriptionStatus['status'];
const generationCount = user?.generation_count || 0;
return {
status,
currentPeriodEnd: user?.subscription_current_period_end || null,
generationCount,
generationLimit: GENERATION_LIMIT,
generationsRemaining: Math.max(0, GENERATION_LIMIT - generationCount),
};
}
export function activateSubscription(
userId: string,
subscriptionId: string,
currentPeriodEnd: Date
): void {
db.prepare(`
UPDATE users
SET subscription_status = 'active',
subscription_id = ?,
subscription_current_period_end = ?,
generation_count = 0,
generation_reset_date = ?
WHERE id = ?
`).run(subscriptionId, currentPeriodEnd.toISOString(), currentPeriodEnd.toISOString(), userId);
}
export function updateSubscriptionStatus(
userId: string,
status: 'active' | 'past_due' | 'canceled' | 'none',
currentPeriodEnd?: Date
): void {
if (currentPeriodEnd) {
db.prepare(`
UPDATE users
SET subscription_status = ?,
subscription_current_period_end = ?
WHERE id = ?
`).run(status, currentPeriodEnd.toISOString(), userId);
} else {
db.prepare(`
UPDATE users
SET subscription_status = ?
WHERE id = ?
`).run(status, userId);
}
}
export function resetGenerationCount(userId: string, newResetDate: Date): void {
db.prepare(`
UPDATE users
SET generation_count = 0,
generation_reset_date = ?
WHERE id = ?
`).run(newResetDate.toISOString(), userId);
}
export function incrementGenerationCount(userId: string): number {
const result = db.prepare(`
UPDATE users
SET generation_count = COALESCE(generation_count, 0) + 1
WHERE id = ?
RETURNING generation_count
`).get(userId) as { generation_count: number } | undefined;
return result?.generation_count || 1;
}
export function canGenerate(userId: string, groups: string[]): { allowed: boolean; reason?: string; remaining?: number } {
if (groups.includes('kaboot-ai-access')) {
return { allowed: true };
}
const status = getSubscriptionStatus(userId);
if (status.status !== 'active') {
return {
allowed: false,
reason: 'No active subscription. Upgrade to access AI generation.',
};
}
if (status.generationsRemaining <= 0) {
return {
allowed: false,
reason: 'Generation limit reached for this billing period.',
remaining: 0,
};
}
return {
allowed: true,
remaining: status.generationsRemaining,
};
}
export function recordPayment(
userId: string,
paymentIntentId: string | null,
invoiceId: string | null,
amount: number,
currency: string,
status: string,
description: string
): void {
const id = `pay_${Date.now()}_${Math.random().toString(36).substr(2, 9)}`;
db.prepare(`
INSERT INTO payments (id, user_id, stripe_payment_intent_id, stripe_invoice_id, amount, currency, status, description)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)
`).run(id, userId, paymentIntentId, invoiceId, amount, currency, status, description);
}

View file

@ -0,0 +1,243 @@
import Database from 'better-sqlite3';
import { randomUUID } from 'crypto';
const API_URL = process.env.API_URL || 'http://localhost:3001';
const TOKEN = process.env.TEST_TOKEN;
if (!TOKEN) {
console.error('ERROR: TEST_TOKEN environment variable is required');
process.exit(1);
}
interface TestResult {
name: string;
passed: boolean;
error?: string;
}
const results: TestResult[] = [];
async function request(
method: string,
path: string,
body?: unknown,
expectStatus = 200,
useToken = true
): Promise<{ status: number; data: unknown; headers: Headers }> {
const headers: Record<string, string> = {
'Content-Type': 'application/json',
};
if (useToken) {
headers['Authorization'] = `Bearer ${TOKEN}`;
}
const response = await fetch(`${API_URL}${path}`, {
method,
headers,
body: body ? JSON.stringify(body) : undefined,
});
const data = response.headers.get('content-type')?.includes('application/json')
? await response.json()
: null;
if (response.status !== expectStatus) {
throw new Error(`Expected ${expectStatus}, got ${response.status}: ${JSON.stringify(data)}`);
}
return { status: response.status, data, headers: response.headers };
}
async function test(name: string, fn: () => Promise<void>) {
try {
await fn();
results.push({ name, passed: true });
console.log(`${name}`);
} catch (error) {
const message = error instanceof Error ? error.message : String(error);
results.push({ name, passed: false, error: message });
console.log(`${name}`);
console.log(` ${message}`);
}
}
async function runTests() {
console.log('\n=== Kaboot Payments API Tests ===\n');
console.log(`API: ${API_URL}`);
console.log('');
console.log('Payment Config Tests:');
await test('GET /api/payments/config returns configuration', async () => {
const res = await fetch(`${API_URL}/api/payments/config`);
const data = await res.json();
if (typeof data.configured !== 'boolean') {
throw new Error('Missing configured field');
}
if (typeof data.generationLimit !== 'number') {
throw new Error('Missing generationLimit field');
}
if (data.generationLimit !== 250) {
throw new Error(`Expected generationLimit 250, got ${data.generationLimit}`);
}
});
console.log('\nPayment Status Tests:');
await test('GET /api/payments/status without auth returns 401', async () => {
const res = await fetch(`${API_URL}/api/payments/status`);
if (res.status !== 401) {
throw new Error(`Expected 401, got ${res.status}`);
}
});
await test('GET /api/payments/status with invalid token returns 401', async () => {
const res = await fetch(`${API_URL}/api/payments/status`, {
headers: { Authorization: 'Bearer invalid-token-here' },
});
if (res.status !== 401) {
throw new Error(`Expected 401, got ${res.status}`);
}
});
await test('GET /api/payments/status with valid token returns status', async () => {
const { data } = await request('GET', '/api/payments/status');
const status = data as Record<string, unknown>;
if (typeof status.hasAccess !== 'boolean') {
throw new Error('Missing hasAccess field');
}
if (!['group', 'subscription', 'none'].includes(status.accessType as string)) {
throw new Error(`Invalid accessType: ${status.accessType}`);
}
if (!['none', 'active', 'past_due', 'canceled'].includes(status.status as string)) {
throw new Error(`Invalid status: ${status.status}`);
}
});
console.log('\nCheckout Tests:');
await test('POST /api/payments/checkout without auth returns 401', async () => {
const res = await fetch(`${API_URL}/api/payments/checkout`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ planType: 'monthly', successUrl: 'http://test.com/success', cancelUrl: 'http://test.com/cancel' }),
});
if (res.status !== 401) {
throw new Error(`Expected 401, got ${res.status}`);
}
});
await test('POST /api/payments/checkout without planType returns 400', async () => {
const { status, data } = await request('POST', '/api/payments/checkout', {
successUrl: 'http://test.com/success',
cancelUrl: 'http://test.com/cancel',
}, 400);
const error = data as { error: string };
if (!error.error.includes('plan')) {
throw new Error(`Expected plan type error, got: ${error.error}`);
}
});
await test('POST /api/payments/checkout with invalid planType returns 400', async () => {
const { data } = await request('POST', '/api/payments/checkout', {
planType: 'invalid',
successUrl: 'http://test.com/success',
cancelUrl: 'http://test.com/cancel',
}, 400);
const error = data as { error: string };
if (!error.error.includes('monthly') && !error.error.includes('yearly')) {
throw new Error(`Expected plan type validation error, got: ${error.error}`);
}
});
await test('POST /api/payments/checkout without successUrl returns 400', async () => {
const { data } = await request('POST', '/api/payments/checkout', {
planType: 'monthly',
cancelUrl: 'http://test.com/cancel',
}, 400);
const error = data as { error: string };
if (!error.error.includes('successUrl')) {
throw new Error(`Expected successUrl error, got: ${error.error}`);
}
});
await test('POST /api/payments/checkout without cancelUrl returns 400', async () => {
const { data } = await request('POST', '/api/payments/checkout', {
planType: 'monthly',
successUrl: 'http://test.com/success',
}, 400);
const error = data as { error: string };
if (!error.error.includes('cancelUrl')) {
throw new Error(`Expected cancelUrl error, got: ${error.error}`);
}
});
console.log('\nPortal Tests:');
await test('POST /api/payments/portal without auth returns 401', async () => {
const res = await fetch(`${API_URL}/api/payments/portal`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ returnUrl: 'http://test.com' }),
});
if (res.status !== 401) {
throw new Error(`Expected 401, got ${res.status}`);
}
});
await test('POST /api/payments/portal without returnUrl returns 400', async () => {
const { data } = await request('POST', '/api/payments/portal', {}, 400);
const error = data as { error: string };
if (!error.error.includes('returnUrl')) {
throw new Error(`Expected returnUrl error, got: ${error.error}`);
}
});
console.log('\nWebhook Tests:');
await test('POST /api/payments/webhook without signature returns 400', async () => {
const res = await fetch(`${API_URL}/api/payments/webhook`, {
method: 'POST',
headers: { 'Content-Type': 'application/json' },
body: JSON.stringify({ type: 'test' }),
});
if (res.status !== 400 && res.status !== 503) {
throw new Error(`Expected 400 or 503, got ${res.status}`);
}
});
await test('POST /api/payments/webhook with invalid signature returns 400', async () => {
const res = await fetch(`${API_URL}/api/payments/webhook`, {
method: 'POST',
headers: {
'Content-Type': 'application/json',
'stripe-signature': 'invalid-signature',
},
body: JSON.stringify({ type: 'test' }),
});
if (res.status !== 400 && res.status !== 503) {
throw new Error(`Expected 400 or 503, got ${res.status}`);
}
});
console.log('\n=== Results ===\n');
const passed = results.filter((r) => r.passed).length;
const failed = results.filter((r) => !r.passed).length;
console.log(`Passed: ${passed}`);
console.log(`Failed: ${failed}`);
console.log(`Total: ${results.length}`);
if (failed > 0) {
console.log('\nFailed tests:');
results
.filter((r) => !r.passed)
.forEach((r) => console.log(` - ${r.name}: ${r.error}`));
process.exit(1);
}
}
runTests().catch((err) => {
console.error('Test runner error:', err);
process.exit(1);
});

View file

@ -0,0 +1,208 @@
import { render, screen, fireEvent, waitFor } from '@testing-library/react';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
vi.mock('framer-motion', () => ({
motion: {
div: ({ children, ...props }: any) => <div {...props}>{children}</div>,
button: ({ children, ...props }: any) => <button {...props}>{children}</button>,
},
}));
vi.mock('canvas-confetti', () => ({
default: vi.fn(),
}));
describe('PaymentResult', () => {
let PaymentResult: typeof import('../../components/PaymentResult').PaymentResult;
let mockConfetti: ReturnType<typeof vi.fn>;
beforeEach(async () => {
vi.clearAllMocks();
vi.resetModules();
vi.useFakeTimers();
const confettiModule = await import('canvas-confetti');
mockConfetti = confettiModule.default as ReturnType<typeof vi.fn>;
const module = await import('../../components/PaymentResult');
PaymentResult = module.PaymentResult;
});
afterEach(() => {
vi.useRealTimers();
vi.resetAllMocks();
});
describe('loading state', () => {
it('shows loading spinner', () => {
render(<PaymentResult status="loading" />);
expect(document.querySelector('.animate-spin')).toBeInTheDocument();
expect(screen.getByText(/Processing your payment/i)).toBeInTheDocument();
});
it('does not trigger confetti in loading state', () => {
render(<PaymentResult status="loading" />);
expect(mockConfetti).not.toHaveBeenCalled();
});
});
describe('success state', () => {
it('shows success message', () => {
render(<PaymentResult status="success" />);
expect(screen.getByText('Welcome to Pro!')).toBeInTheDocument();
expect(screen.getByText(/AI powers are now unlocked/i)).toBeInTheDocument();
});
it('shows generation count badge', () => {
render(<PaymentResult status="success" />);
expect(screen.getByText('250 AI generations ready to use')).toBeInTheDocument();
});
it('shows "Start Creating" button', () => {
render(<PaymentResult status="success" />);
expect(screen.getByText('Start Creating')).toBeInTheDocument();
});
it('calls onBack when "Start Creating" is clicked', () => {
const mockOnBack = vi.fn();
render(<PaymentResult status="success" onBack={mockOnBack} />);
fireEvent.click(screen.getByText('Start Creating'));
expect(mockOnBack).toHaveBeenCalled();
});
it('triggers confetti animation on mount', async () => {
render(<PaymentResult status="success" />);
vi.advanceTimersByTime(100);
expect(mockConfetti).toHaveBeenCalled();
});
it('confetti uses correct colors', async () => {
render(<PaymentResult status="success" />);
vi.advanceTimersByTime(100);
expect(mockConfetti).toHaveBeenCalledWith(
expect.objectContaining({
colors: ['#8B5CF6', '#6366F1', '#EC4899', '#F59E0B'],
})
);
});
it('only triggers confetti once', async () => {
const { rerender } = render(<PaymentResult status="success" />);
vi.advanceTimersByTime(3100);
const callCountAfterFirst = mockConfetti.mock.calls.length;
rerender(<PaymentResult status="success" />);
vi.advanceTimersByTime(3100);
expect(mockConfetti.mock.calls.length).toBe(callCountAfterFirst);
});
});
describe('cancel state', () => {
it('shows cancel message', () => {
render(<PaymentResult status="cancel" />);
expect(screen.getByText('Payment Cancelled')).toBeInTheDocument();
expect(screen.getByText(/No worries/i)).toBeInTheDocument();
});
it('shows "Go Back" button', () => {
render(<PaymentResult status="cancel" />);
expect(screen.getByText('Go Back')).toBeInTheDocument();
});
it('calls onBack when "Go Back" is clicked', () => {
const mockOnBack = vi.fn();
render(<PaymentResult status="cancel" onBack={mockOnBack} />);
fireEvent.click(screen.getByText('Go Back'));
expect(mockOnBack).toHaveBeenCalled();
});
it('does not show generation count badge', () => {
render(<PaymentResult status="cancel" />);
expect(screen.queryByText(/250 AI generations/i)).not.toBeInTheDocument();
});
it('does not trigger confetti', () => {
render(<PaymentResult status="cancel" />);
vi.advanceTimersByTime(3100);
expect(mockConfetti).not.toHaveBeenCalled();
});
});
describe('UI elements', () => {
it('renders PartyPopper icon for success', () => {
render(<PaymentResult status="success" />);
const successCard = screen.getByText('Welcome to Pro!').closest('div');
expect(successCard).toBeInTheDocument();
});
it('renders XCircle icon for cancel', () => {
render(<PaymentResult status="cancel" />);
const cancelCard = screen.getByText('Payment Cancelled').closest('div');
expect(cancelCard).toBeInTheDocument();
});
it('has green gradient background for success', () => {
render(<PaymentResult status="success" />);
const gradientBg = document.querySelector('.from-green-50');
expect(gradientBg).toBeInTheDocument();
});
it('does not have green gradient for cancel', () => {
render(<PaymentResult status="cancel" />);
const gradientBg = document.querySelector('.from-green-50');
expect(gradientBg).not.toBeInTheDocument();
});
});
describe('button styling', () => {
it('success button has dark styling', () => {
render(<PaymentResult status="success" />);
const button = screen.getByText('Start Creating').closest('button');
expect(button?.className).toContain('bg-gray-900');
});
it('cancel button has theme primary styling', () => {
render(<PaymentResult status="cancel" />);
const button = screen.getByText('Go Back').closest('button');
expect(button?.className).toContain('bg-theme-primary');
});
});
describe('without onBack callback', () => {
it('renders buttons without errors when onBack is undefined', () => {
render(<PaymentResult status="success" />);
const button = screen.getByText('Start Creating');
expect(button).toBeInTheDocument();
fireEvent.click(button);
});
});
});

View file

@ -0,0 +1,436 @@
import { render, screen, fireEvent, waitFor, act } from '@testing-library/react';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
const mockSigninRedirect = vi.fn();
const mockAuth = {
user: {
access_token: 'valid-token',
},
isAuthenticated: true,
signinRedirect: mockSigninRedirect,
};
vi.mock('react-oidc-context', () => ({
useAuth: () => mockAuth,
}));
vi.mock('framer-motion', () => ({
motion: {
div: ({ children, ...props }: any) => <div {...props}>{children}</div>,
header: ({ children, ...props }: any) => <header {...props}>{children}</header>,
button: ({ children, ...props }: any) => <button {...props}>{children}</button>,
},
AnimatePresence: ({ children }: any) => children,
}));
const originalFetch = global.fetch;
describe('UpgradePage', () => {
let UpgradePage: typeof import('../../components/UpgradePage').UpgradePage;
beforeEach(async () => {
vi.clearAllMocks();
vi.resetModules();
mockAuth.user = { access_token: 'valid-token' };
mockAuth.isAuthenticated = true;
global.fetch = vi.fn();
const module = await import('../../components/UpgradePage');
UpgradePage = module.UpgradePage;
});
afterEach(() => {
global.fetch = originalFetch;
vi.resetAllMocks();
});
describe('loading state', () => {
it('shows loader while checking status', async () => {
let resolveStatus: (value: any) => void;
const statusPromise = new Promise((resolve) => {
resolveStatus = resolve;
});
(global.fetch as ReturnType<typeof vi.fn>).mockReturnValue(statusPromise);
render(<UpgradePage />);
// The loader has animate-spin class
expect(document.querySelector('.animate-spin')).toBeInTheDocument();
await act(async () => {
resolveStatus!({
ok: true,
json: () => Promise.resolve({ hasAccess: false }),
});
});
});
});
describe('user without access', () => {
beforeEach(() => {
(global.fetch as ReturnType<typeof vi.fn>).mockResolvedValue({
ok: true,
json: () => Promise.resolve({ hasAccess: false }),
});
});
it('renders upgrade page with pricing', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText(/Unlock/i)).toBeInTheDocument();
});
expect(screen.getByText('Monthly')).toBeInTheDocument();
expect(screen.getByText('Yearly')).toBeInTheDocument();
expect(screen.getByText('Starter')).toBeInTheDocument();
expect(screen.getByText('Pro Gamer')).toBeInTheDocument();
});
it('shows $0 for Starter plan', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText('$0')).toBeInTheDocument();
});
});
it('defaults to yearly billing cycle', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText('$4.17')).toBeInTheDocument();
});
expect(screen.getByText(/Billed \$50 yearly/i)).toBeInTheDocument();
});
it('switches to monthly pricing when monthly is clicked', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText('$4.17')).toBeInTheDocument();
});
fireEvent.click(screen.getByText('Monthly'));
await waitFor(() => {
expect(screen.getByText('$5')).toBeInTheDocument();
});
expect(screen.queryByText(/Billed \$50 yearly/i)).not.toBeInTheDocument();
});
it('shows feature comparison', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText('5 AI generations per month')).toBeInTheDocument();
});
expect(screen.getByText('250 AI generations per month')).toBeInTheDocument();
expect(screen.getByText('Basic quiz topics')).toBeInTheDocument();
expect(screen.getByText('Host up to 10 players')).toBeInTheDocument();
expect(screen.getByText('Host up to 100 players')).toBeInTheDocument();
});
it('shows back button when onBack is provided', async () => {
const mockOnBack = vi.fn();
render(<UpgradePage onBack={mockOnBack} />);
await waitFor(() => {
expect(screen.getByText(/Unlock/i)).toBeInTheDocument();
});
const backButton = document.querySelector('button[class*="absolute"]');
expect(backButton).toBeInTheDocument();
});
it('calls onBack when back button is clicked', async () => {
const mockOnBack = vi.fn();
render(<UpgradePage onBack={mockOnBack} />);
await waitFor(() => {
expect(screen.getByText(/Unlock/i)).toBeInTheDocument();
});
const backButton = document.querySelector('button[class*="absolute"]');
if (backButton) {
fireEvent.click(backButton);
expect(mockOnBack).toHaveBeenCalled();
}
});
});
describe('user with access', () => {
beforeEach(() => {
(global.fetch as ReturnType<typeof vi.fn>).mockResolvedValue({
ok: true,
json: () => Promise.resolve({ hasAccess: true }),
});
});
it('shows "You\'re a Pro!" message', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText("You're a Pro!")).toBeInTheDocument();
});
});
it('shows "Back to Game" button', async () => {
const mockOnBack = vi.fn();
render(<UpgradePage onBack={mockOnBack} />);
await waitFor(() => {
expect(screen.getByText('Back to Game')).toBeInTheDocument();
});
});
it('calls onBack when "Back to Game" is clicked', async () => {
const mockOnBack = vi.fn();
render(<UpgradePage onBack={mockOnBack} />);
await waitFor(() => {
expect(screen.getByText('Back to Game')).toBeInTheDocument();
});
fireEvent.click(screen.getByText('Back to Game'));
expect(mockOnBack).toHaveBeenCalled();
});
});
describe('checkout flow', () => {
beforeEach(() => {
(global.fetch as ReturnType<typeof vi.fn>).mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({ hasAccess: false }),
});
});
it('calls checkout endpoint when "Upgrade Now" is clicked', async () => {
(global.fetch as ReturnType<typeof vi.fn>).mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({ url: 'https://checkout.stripe.com/test' }),
});
const originalHref = window.location.href;
Object.defineProperty(window, 'location', {
value: { href: '', origin: 'http://localhost' },
writable: true,
});
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText(/Upgrade Now/i)).toBeInTheDocument();
});
fireEvent.click(screen.getByText(/Upgrade Now/i));
await waitFor(() => {
expect(global.fetch).toHaveBeenCalledWith(
expect.stringContaining('/api/payments/checkout'),
expect.objectContaining({
method: 'POST',
headers: expect.objectContaining({
'Content-Type': 'application/json',
Authorization: 'Bearer valid-token',
}),
})
);
});
Object.defineProperty(window, 'location', {
value: { href: originalHref },
writable: true,
});
});
it('redirects to sign in if no token', async () => {
mockAuth.user = null;
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText(/Upgrade Now/i)).toBeInTheDocument();
});
fireEvent.click(screen.getByText(/Upgrade Now/i));
await waitFor(() => {
expect(mockSigninRedirect).toHaveBeenCalled();
});
});
it('shows error message on checkout failure', async () => {
(global.fetch as ReturnType<typeof vi.fn>).mockResolvedValueOnce({
ok: false,
status: 500,
});
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText(/Upgrade Now/i)).toBeInTheDocument();
});
fireEvent.click(screen.getByText(/Upgrade Now/i));
await waitFor(() => {
expect(screen.getByText(/Something went wrong/i)).toBeInTheDocument();
});
});
it('passes correct planType for yearly billing', async () => {
(global.fetch as ReturnType<typeof vi.fn>).mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({ url: 'https://checkout.stripe.com/test' }),
});
const originalHref = window.location.href;
Object.defineProperty(window, 'location', {
value: { href: '', origin: 'http://localhost' },
writable: true,
});
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText(/Upgrade Now/i)).toBeInTheDocument();
});
fireEvent.click(screen.getByText(/Upgrade Now/i));
await waitFor(() => {
expect(global.fetch).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
body: expect.stringContaining('"planType":"yearly"'),
})
);
});
Object.defineProperty(window, 'location', {
value: { href: originalHref },
writable: true,
});
});
it('passes correct planType for monthly billing', async () => {
(global.fetch as ReturnType<typeof vi.fn>).mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({ url: 'https://checkout.stripe.com/test' }),
});
const originalHref = window.location.href;
Object.defineProperty(window, 'location', {
value: { href: '', origin: 'http://localhost' },
writable: true,
});
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText(/Upgrade Now/i)).toBeInTheDocument();
});
fireEvent.click(screen.getByText('Monthly'));
await waitFor(() => {
expect(screen.getByText('$5')).toBeInTheDocument();
});
fireEvent.click(screen.getByText(/Upgrade Now/i));
await waitFor(() => {
expect(global.fetch).toHaveBeenCalledWith(
expect.any(String),
expect.objectContaining({
body: expect.stringContaining('"planType":"monthly"'),
})
);
});
Object.defineProperty(window, 'location', {
value: { href: originalHref },
writable: true,
});
});
});
describe('status check error handling', () => {
it('handles status check failure gracefully', async () => {
const consoleSpy = vi.spyOn(console, 'error').mockImplementation(() => {});
(global.fetch as ReturnType<typeof vi.fn>).mockRejectedValueOnce(new Error('Network error'));
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText(/Unlock/i)).toBeInTheDocument();
});
expect(consoleSpy).toHaveBeenCalledWith('Failed to check status:', expect.any(Error));
consoleSpy.mockRestore();
});
it('handles missing token during status check', async () => {
mockAuth.user = null;
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText(/Unlock/i)).toBeInTheDocument();
});
});
});
describe('UI elements', () => {
beforeEach(() => {
(global.fetch as ReturnType<typeof vi.fn>).mockResolvedValue({
ok: true,
json: () => Promise.resolve({ hasAccess: false }),
});
});
it('shows discount badge for yearly billing', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText('-17%')).toBeInTheDocument();
});
});
it('shows "Most Popular" badge on Pro plan', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText('Most Popular')).toBeInTheDocument();
});
});
it('shows "Current Plan" button disabled for Starter', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText('Current Plan')).toBeInTheDocument();
});
const currentPlanButton = screen.getByText('Current Plan');
expect(currentPlanButton).toBeDisabled();
});
it('shows security and guarantee info', async () => {
render(<UpgradePage />);
await waitFor(() => {
expect(screen.getByText(/Secure payment via Stripe/i)).toBeInTheDocument();
});
expect(screen.getByText(/Cancel anytime/i)).toBeInTheDocument();
expect(screen.getByText(/7-day money-back guarantee/i)).toBeInTheDocument();
});
});
});

View file

@ -0,0 +1,416 @@
import { renderHook, act, waitFor } from '@testing-library/react';
import { describe, it, expect, vi, beforeEach, afterEach } from 'vitest';
const mockAuthFetch = vi.fn();
const mockAuth = {
isAuthenticated: true,
};
vi.mock('../../hooks/useAuthenticatedFetch', () => ({
useAuthenticatedFetch: () => ({
authFetch: mockAuthFetch,
isAuthenticated: mockAuth.isAuthenticated,
}),
}));
vi.mock('react-hot-toast', () => ({
default: {
success: vi.fn(),
error: vi.fn(),
},
}));
describe('useUserPreferences - subscription data', () => {
let useUserPreferences: typeof import('../../hooks/useUserPreferences').useUserPreferences;
beforeEach(async () => {
vi.clearAllMocks();
vi.resetModules();
mockAuth.isAuthenticated = true;
const module = await import('../../hooks/useUserPreferences');
useUserPreferences = module.useUserPreferences;
});
afterEach(() => {
vi.resetAllMocks();
});
describe('subscription info fetching', () => {
it('fetches subscription status along with preferences', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: true,
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: true,
accessType: 'subscription',
generationCount: 10,
generationLimit: 250,
generationsRemaining: 240,
}),
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.subscription).toEqual({
hasAccess: true,
accessType: 'subscription',
generationCount: 10,
generationLimit: 250,
generationsRemaining: 240,
});
});
it('sets subscription info for group access users', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: true,
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: true,
accessType: 'group',
generationCount: null,
generationLimit: null,
generationsRemaining: null,
}),
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.subscription).toEqual({
hasAccess: true,
accessType: 'group',
generationCount: null,
generationLimit: null,
generationsRemaining: null,
});
});
it('sets subscription info for users without access', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: false,
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: false,
accessType: 'none',
generationCount: 0,
generationLimit: 250,
generationsRemaining: 0,
}),
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.subscription).toEqual({
hasAccess: false,
accessType: 'none',
generationCount: 0,
generationLimit: 250,
generationsRemaining: 0,
});
});
it('handles subscription status fetch failure gracefully', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'purple',
aiProvider: 'openrouter',
hasAIAccess: false,
}),
})
.mockRejectedValueOnce(new Error('Network error'));
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.subscription).toBeNull();
expect(result.current.preferences.colorScheme).toBe('purple');
});
it('handles subscription status non-ok response gracefully', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: true,
}),
})
.mockResolvedValueOnce({
ok: false,
status: 503,
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.subscription).toBeNull();
});
});
describe('hasAIAccess from preferences', () => {
it('sets hasAIAccess true when user has AI access', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: true,
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: true,
accessType: 'group',
}),
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.hasAIAccess).toBe(true);
});
it('sets hasAIAccess false when user does not have AI access', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: false,
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: false,
accessType: 'none',
}),
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.hasAIAccess).toBe(false);
});
it('defaults hasAIAccess to false when not present', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: false,
accessType: 'none',
}),
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.hasAIAccess).toBe(false);
});
});
describe('when not authenticated', () => {
beforeEach(() => {
mockAuth.isAuthenticated = false;
});
it('does not fetch subscription status', async () => {
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(mockAuthFetch).not.toHaveBeenCalled();
expect(result.current.subscription).toBeNull();
});
});
describe('refetching subscription', () => {
it('can refetch subscription info via fetchPreferences', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: false,
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: false,
accessType: 'none',
generationCount: 0,
generationLimit: 250,
generationsRemaining: 0,
}),
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.subscription?.hasAccess).toBe(false);
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: true,
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: true,
accessType: 'subscription',
generationCount: 0,
generationLimit: 250,
generationsRemaining: 250,
}),
});
await act(async () => {
await result.current.fetchPreferences();
});
expect(result.current.subscription?.hasAccess).toBe(true);
expect(result.current.subscription?.accessType).toBe('subscription');
});
});
describe('generation tracking', () => {
it('tracks generation count correctly', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: true,
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: true,
accessType: 'subscription',
generationCount: 50,
generationLimit: 250,
generationsRemaining: 200,
}),
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.subscription?.generationCount).toBe(50);
expect(result.current.subscription?.generationsRemaining).toBe(200);
});
it('shows zero remaining when limit reached', async () => {
mockAuthFetch
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
colorScheme: 'blue',
aiProvider: 'gemini',
hasAIAccess: true,
}),
})
.mockResolvedValueOnce({
ok: true,
json: () => Promise.resolve({
hasAccess: true,
accessType: 'subscription',
generationCount: 250,
generationLimit: 250,
generationsRemaining: 0,
}),
});
const { result } = renderHook(() => useUserPreferences());
await waitFor(() => {
expect(result.current.loading).toBe(false);
});
expect(result.current.subscription?.generationsRemaining).toBe(0);
});
});
});