Replace recursive HTTP-based ShieldX scan with direct library integration.

- 547+ rules, 50+ languages, sub-millisecond scans
- Enables: rules, entropy, indirect injection, behavioral, unicode, tokenizer, compressed payload detection
- Disables Ollama-dependent scanners for zero external dependency
- Response now includes threat_level, kill_chain_phase, shieldx_latency_ms
350 lines
13 KiB
TypeScript
350 lines
13 KiB
TypeScript
import type { FastifyInstance, FastifyRequest, FastifyReply } from 'fastify';
|
|
import { z } from 'zod';
|
|
import { classifyInput } from '../pipeline/pre-classifier.js';
|
|
import { route } from '../pipeline/router.js';
|
|
import { resolvePrompt } from '../pipeline/prompt-resolver.js';
|
|
import { callOllamaWithFallbackChain } from '../pipeline/llm-client.js';
|
|
import { runPostValidation } from '../pipeline/post-validator.js';
|
|
import { evaluateConfidence } from '../pipeline/confidence-gate.js';
|
|
import { writeAuditLog, writeBanAnalytics, hashText } from '../observability/audit-log.js';
|
|
import { addToReviewQueue } from '../observability/review-queue.js';
|
|
import {
|
|
requestsTotal,
|
|
latencySeconds,
|
|
tokensTotal,
|
|
confidenceScore,
|
|
banlistHitsTotal,
|
|
validationFailuresTotal,
|
|
} from '../observability/metrics.js';
|
|
import { logger } from '../observability/logger.js';
|
|
import { ShieldX } from '@shieldx/core';
|
|
|
|
// Process-wide ShieldX scanner, constructed once at module load and reused by
// every request.
//
// Scanners that need an external service or binary (Ollama, YARA) are switched
// off so gateway scans stay fast and dependency-free; only the in-process
// scanners remain enabled.
const gatewayScannerToggles = {
  rules: true, // 547+ rules, 50+ languages
  sentinel: false, // Requires Ollama
  constitutional: false, // Requires Ollama
  embedding: false, // Requires Ollama
  embeddingAnomaly: false,
  entropy: true, // Zero-cost entropy analysis
  yara: false, // Requires YARA binary
  attention: false, // Requires Ollama
  canary: false, // Not needed in gateway context
  indirect: true, // RAG/tool injection detection
  selfConsciousness: false,
  crossModel: false,
  behavioral: true, // Session profiling
  unicode: true, // Homoglyph/script detection
  tokenizer: true, // I.g.n.o.r.e-style attacks
  compressedPayload: true,
};

// The config is a DeepPartial that ShieldX merges with its defaults; the `any`
// assertion is kept from the original because the partial shape is not
// expressed in the constructor's declared parameter type here.
const shieldx = new ShieldX({
  scanners: gatewayScannerToggles,
  logging: { level: 'warn', structured: true, incidentLog: false },
} as any);
|
|
|
|
/**
 * Optional per-request generation overrides accepted under `options`.
 * Every field may be omitted.
 */
const CompletionOptionsSchema = z.object({
  model: z.string().optional(),
  temperature: z.number().min(0).max(2).optional(),
  max_tokens: z.number().int().positive().max(16384).optional(),
  return_validation_details: z.boolean().optional(),
});

/**
 * Validation schema for the body of `POST /completion`.
 *
 * - `caller`: 1–100 characters, required.
 * - `input`: 1–50,000 characters, required.
 * - `task_type`, `language` (`'de'` | `'en'`), `context` (string-keyed record
 *   of arbitrary values) and `options` are optional.
 */
const CompletionRequestSchema = z.object({
  caller: z.string().min(1).max(100),
  task_type: z.string().optional(),
  input: z.string().min(1).max(50000),
  language: z.enum(['de', 'en']).optional(),
  context: z.record(z.unknown()).optional(),
  options: CompletionOptionsSchema.optional(),
});

/** Parsed request body type, derived from {@link CompletionRequestSchema}. */
type CompletionRequest = z.infer<typeof CompletionRequestSchema>;
|
|
|
|
/** Caller identifiers for which the completion route skips the ShieldX scan. */
const SKIP_SHIELDX_CALLERS = new Set<string>([
  'internal',
  'shieldx',
]);
|
|
|
|
/**
 * Scans request input with the shared ShieldX instance and reports whether the
 * request may proceed.
 *
 * Failure policy:
 * - If the scan call itself throws, the error is logged and the input is
 *   allowed through (deliberate fail-open, unchanged from before).
 * - Once ShieldX has reported a detection, nothing in this function can turn
 *   that into a pass: the fail-open `catch` covers only the scan call, so an
 *   exception while logging the detection propagates to the caller instead of
 *   silently letting the flagged input through.
 *
 * @param input  Raw text to scan.
 * @param caller Caller identifier, used only as log context.
 * @returns `passed: true` (with `latencyMs` when the scan completed) for clean
 *          input or scanner failure; `passed: false` plus `reason`,
 *          `threatLevel`, `phase` and `latencyMs` when a threat was detected.
 */
async function runShieldXScan(
  input: string,
  caller: string,
): Promise<{ passed: boolean; reason?: string; threatLevel?: string; phase?: string; latencyMs?: number }> {
  let result: Awaited<ReturnType<typeof shieldx.scanInput>>;
  try {
    result = await shieldx.scanInput(input);
  } catch (err) {
    logger.error({ err, caller }, 'ShieldX scan error — failing open');
    return { passed: true };
  }

  if (!result.detected) {
    return { passed: true, latencyMs: result.latencyMs };
  }

  logger.warn(
    {
      caller,
      threatLevel: result.threatLevel,
      phase: result.killChainPhase,
      action: result.action,
      latencyMs: result.latencyMs,
      ensemble: result.ensemble,
      // Cap the logged technique ids at five entries.
      atlasMapping: result.atlasMapping?.techniqueIds?.slice(0, 5),
      // Optional access: a result without `scanResults` must not throw here.
      scannerCount: result.scanResults?.length,
    },
    'ShieldX threat detected — input blocked',
  );

  return {
    passed: false,
    reason: `Prompt injection detected: ${result.killChainPhase} (${result.threatLevel})`,
    threatLevel: result.threatLevel,
    phase: result.killChainPhase,
    latencyMs: result.latencyMs,
  };
}
|
|
|
|
/**
 * Prompt-template variable names that default to the raw request `input`, so
 * single-field callers work without knowing a template's specific variable
 * name. Order is preserved when the alias map is built.
 */
const INPUT_ALIAS_KEYS = [
  'source_data', 'ocr_text', 'transcription',
  'ticket_content', 'alert_data', 'incident_data',
  'lldp_data', 'cve_data', 'inventory',
  'anomaly_data', 'flagged_input', 'attack_description',
  'bgp_data', 'health_checks', 'market_data',
  'manuscript_text', 'raw_content', 'content',
  // Additional structured vars with sensible fallbacks
  'peeringdb_data', 'bgp_routes', 'network_context',
  'alert_context', 'affected_inventory',
] as const;

/**
 * Builds the variable bag passed to `resolvePrompt`.
 *
 * Precedence, lowest to highest:
 *   1. every name in `INPUT_ALIAS_KEYS`, set to `input`;
 *   2. the caller's explicit `context` fields;
 *   3. `input` itself and `user_context` (the untouched `context` object),
 *      which always win over same-named context keys.
 */
function buildPromptVariables(input: string, context: Record<string, unknown> | undefined) {
  const aliasDefaults: Record<string, string> = Object.fromEntries(
    INPUT_ALIAS_KEYS.map((key) => [key, input]),
  );

  // NOTE(review): `v as string` is a compile-time assertion only; non-string
  // context values reach the template resolver unconverted.
  const contextVars = context
    ? Object.fromEntries(Object.entries(context).map(([k, v]) => [k, v as string]))
    : {};

  return {
    ...aliasDefaults,
    ...contextVars,
    input,
    user_context: context,
  };
}

/**
 * Registers `POST /completion`.
 *
 * The handler runs these stages in order:
 *   1. validate the body against `CompletionRequestSchema` (400 on failure);
 *   2. ShieldX scan, unless the caller is in `SKIP_SHIELDX_CALLERS` (400 on detection);
 *   3. classify the input when no `task_type` was supplied (falls back to
 *      `'generic_qa'` if the classifier throws);
 *   4. route to a model decision (400 if routing throws);
 *   5. resolve the prompt template;
 *   6. call the LLM through `callOllamaWithFallbackChain` (503 if it throws);
 *   7. post-validate the output;
 *   8. evaluate confidence;
 *   9. record metrics, write the audit log, and enqueue ban analytics / review
 *      items without awaiting them;
 *  10. respond 200; `output` is `null` when the status is `pending_review`.
 *
 * @param fastify Fastify instance the route is attached to.
 */
export async function completionRoute(fastify: FastifyInstance): Promise<void> {
  fastify.post(
    '/completion',
    {
      config: { rateLimit: false }, // Custom rate limiting via caller
    },
    async (request: FastifyRequest, reply: FastifyReply) => {
      const startMs = Date.now();

      // Stage 1: request validation
      let body: CompletionRequest;
      try {
        body = CompletionRequestSchema.parse(request.body);
      } catch (err) {
        return reply.status(400).send({
          statusCode: 400,
          error: 'Bad Request',
          message: err instanceof z.ZodError ? err.errors[0]?.message ?? 'Invalid request' : 'Invalid request body',
        });
      }

      const { caller, input, language, context, options } = body;
      const returnValidationDetails = options?.return_validation_details ?? false;

      // Stage 2: ShieldX scan
      // NOTE(review): `caller` is taken from the request body, and nothing in
      // this handler authenticates it — a client naming itself as one of the
      // skip-listed callers bypasses the scan. Confirm an upstream layer
      // verifies the caller identity.
      if (!SKIP_SHIELDX_CALLERS.has(caller)) {
        const shieldResult = await runShieldXScan(input, caller);
        if (!shieldResult.passed) {
          // The task type has not been determined yet at this point.
          requestsTotal.labels({ caller, task_type: 'unknown', status: 'rejected' }).inc();
          return reply.status(400).send({
            statusCode: 400,
            error: 'Rejected',
            message: shieldResult.reason ?? 'Input rejected by security scan',
            threat_level: shieldResult.threatLevel,
            kill_chain_phase: shieldResult.phase,
            shieldx_latency_ms: shieldResult.latencyMs,
          });
        }
      }

      // Stage 3: Pre-classifier (only when the caller did not name a task type)
      let taskType = body.task_type;
      let classificationResult;
      if (!taskType) {
        try {
          classificationResult = await classifyInput(input);
          taskType = classificationResult.task_type;
        } catch (err) {
          logger.warn({ err }, 'Pre-classifier failed');
          taskType = 'generic_qa';
        }
      }

      // Stage 4: Router
      let decision;
      try {
        decision = route(taskType, caller, {
          model: options?.model,
          temperature: options?.temperature,
          max_tokens: options?.max_tokens,
        });
      } catch (err) {
        return reply.status(400).send({
          statusCode: 400,
          error: 'Routing Error',
          message: err instanceof Error ? err.message : 'Failed to route request',
        });
      }

      // Stage 5: Prompt assembly
      // The task type itself is used for the template lookup; the router's
      // `prompt_template` is only a fallback when no task type is available.
      const resolved = resolvePrompt(
        taskType ?? decision.prompt_template,
        buildPromptVariables(input, context),
        language ?? 'en',
      );

      // Stage 6: LLM call through the fallback chain
      let ollamaResponse;
      try {
        ollamaResponse = await callOllamaWithFallbackChain(
          {
            model: decision.model,
            prompt: resolved.prompt,
            system: resolved.system,
            options: {
              temperature: decision.temperature,
              num_predict: decision.max_tokens,
            },
            format: decision.output_format === 'json' ? 'json' : '',
            stream: false,
          },
          decision.fallback_chain,
          decision.tier,
        );
      } catch (err) {
        const latency = Date.now() - startMs;
        logger.error({ err, caller, taskType }, 'Ollama call failed');
        requestsTotal.labels({ caller, task_type: taskType, status: 'rejected' }).inc();
        latencySeconds.labels({ caller, task_type: taskType, model: decision.model }).observe(latency / 1000);

        return reply.status(503).send({
          statusCode: 503,
          error: 'Service Unavailable',
          message: 'LLM service unavailable, please retry',
        });
      }

      const outputText = ollamaResponse.response;
      // Latency is measured up to the LLM response; validation, audit logging
      // and the remaining stages are not included.
      const latencyMs = Date.now() - startMs;
      // Model that actually answered. The response may omit `model`, in which
      // case the routed model is assumed.
      const modelUsed = ollamaResponse.model ?? decision.model;

      // Stage 7: Post-validation chain
      const validationOutput = await runPostValidation(outputText, {
        validators: decision.validators,
        language,
        output_format: decision.output_format,
        requires_fact_check: decision.requires_fact_check,
        schema: resolved.schema,
      });

      // Stage 8: Confidence gate
      const confidenceResult = evaluateConfidence(validationOutput);

      // Record metrics. Latency is labelled with the answering model; token
      // and confidence series keep the routed model label.
      requestsTotal.labels({ caller, task_type: taskType, status: confidenceResult.status }).inc();
      latencySeconds.labels({ caller, task_type: taskType, model: modelUsed }).observe(latencyMs / 1000);
      tokensTotal.labels({ direction: 'in', model: decision.model }).inc(ollamaResponse.prompt_eval_count ?? 0);
      tokensTotal.labels({ direction: 'out', model: decision.model }).inc(ollamaResponse.eval_count ?? 0);
      confidenceScore.labels({ task_type: taskType, model: decision.model }).observe(confidenceResult.score);

      // Record ban hits in metrics
      for (const violation of validationOutput.ban_violations) {
        banlistHitsTotal.labels({ term: violation.term, language: violation.language, category: violation.category }).inc();
      }

      // Record validation failures
      for (const result of validationOutput.results) {
        if (!result.passed) {
          validationFailuresTotal.labels({ validator: result.validator, task_type: taskType }).inc();
        }
      }

      // Stage 9: Audit log
      const inputHash = hashText(input);
      const outputHash = hashText(outputText);

      const callId = await writeAuditLog({
        caller,
        task_type: taskType,
        model_used: decision.model,
        prompt_id: resolved.prompt_id,
        prompt_version: resolved.prompt_version,
        input_hash: inputHash,
        // Output text is withheld from the audit record while pending review.
        output_text: confidenceResult.status !== 'pending_review' ? outputText : undefined,
        output_hash: outputHash,
        token_count_in: ollamaResponse.prompt_eval_count ?? 0,
        token_count_out: ollamaResponse.eval_count ?? 0,
        latency_ms: latencyMs,
        confidence: confidenceResult.score,
        status: confidenceResult.status,
        validation_log: validationOutput.results,
        ban_hits: validationOutput.ban_violations,
        metadata: {
          classification: classificationResult,
          model_tier: decision.tier,
          // Compare the resolved model so a response without `model` is not
          // misreported as a fallback.
          fallback_used: modelUsed !== decision.model,
        },
      });

      // Write ban analytics (fire-and-forget; requires an audit-log id)
      if (validationOutput.ban_violations.length > 0 && callId) {
        void writeBanAnalytics(callId, validationOutput.ban_violations, caller, taskType);
      }

      // Add to review queue if pending_review (fire-and-forget)
      if (confidenceResult.status === 'pending_review' && callId) {
        void addToReviewQueue({
          callId,
          caller,
          taskType,
          inputText: input,
          outputText,
          confidence: confidenceResult.score,
          validationLog: validationOutput.results,
        });
      }

      // Stage 10: Response
      const responseBody: Record<string, unknown> = {
        id: callId,
        status: confidenceResult.status,
        confidence: Math.round(confidenceResult.score * 100) / 100, // two decimals
        model: decision.model,
        task_type: taskType,
        latency_ms: latencyMs,
        tokens: {
          in: ollamaResponse.prompt_eval_count ?? 0,
          out: ollamaResponse.eval_count ?? 0,
        },
      };

      if (confidenceResult.status !== 'pending_review') {
        responseBody['output'] = outputText;
      } else {
        responseBody['output'] = null;
        responseBody['message'] = 'Output is pending human review due to low confidence';
      }

      if (returnValidationDetails) {
        responseBody['validation'] = validationOutput.results;
        responseBody['confidence_detail'] = {
          base_score: confidenceResult.base_score,
          total_impact: confidenceResult.total_impact,
          final_score: confidenceResult.score,
        };
      }

      return reply.status(200).send(responseBody);
    },
  );
}
|