Rene Fichtmueller b4593b6582 feat: integrate real @shieldx/core library into gateway pipeline
Replace recursive HTTP-based ShieldX scan with direct library integration.
- 547+ rules, 50+ languages, sub-millisecond scans
- Enables: rules, entropy, indirect injection, behavioral, unicode,
  tokenizer, compressed payload detection
- Disables Ollama-dependent scanners for zero external dependency
- Response now includes threat_level, kill_chain_phase, shieldx_latency_ms
2026-04-07 09:03:02 +02:00

350 lines
13 KiB
TypeScript

import type { FastifyInstance, FastifyRequest, FastifyReply } from 'fastify';
import { z } from 'zod';
import { classifyInput } from '../pipeline/pre-classifier.js';
import { route } from '../pipeline/router.js';
import { resolvePrompt } from '../pipeline/prompt-resolver.js';
import { callOllamaWithFallbackChain } from '../pipeline/llm-client.js';
import { runPostValidation } from '../pipeline/post-validator.js';
import { evaluateConfidence } from '../pipeline/confidence-gate.js';
import { writeAuditLog, writeBanAnalytics, hashText } from '../observability/audit-log.js';
import { addToReviewQueue } from '../observability/review-queue.js';
import {
requestsTotal,
latencySeconds,
tokensTotal,
confidenceScore,
banlistHitsTotal,
validationFailuresTotal,
} from '../observability/metrics.js';
import { logger } from '../observability/logger.js';
import { ShieldX } from '@shieldx/core';
/**
 * Scanner toggles for the gateway's ShieldX instance.
 *
 * Scanners that need Ollama (sentinel, constitutional, embedding, attention)
 * or an external binary (yara) are switched off so gateway scans stay fast
 * and have no external runtime dependency.
 */
const SHIELDX_SCANNERS = {
  rules: true, // 547+ rules, 50+ languages
  sentinel: false, // Requires Ollama
  constitutional: false, // Requires Ollama
  embedding: false, // Requires Ollama
  embeddingAnomaly: false,
  entropy: true, // Zero-cost entropy analysis
  yara: false, // Requires YARA binary
  attention: false, // Requires Ollama
  canary: false, // Not needed in gateway context
  indirect: true, // RAG/tool injection detection
  selfConsciousness: false,
  crossModel: false,
  behavioral: true, // Session profiling
  unicode: true, // Homoglyph/script detection
  tokenizer: true, // I.g.n.o.r.e-style attacks
  compressedPayload: true,
};

/** Process-wide ShieldX instance, constructed once when this module loads. */
const shieldx = new ShieldX({
  scanners: SHIELDX_SCANNERS,
  logging: { level: 'warn', structured: true, incidentLog: false },
  // Partial config (library merges it with its defaults); the constructor's
  // declared type is not partial, hence the cast.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
} as any);
/** Optional per-request generation overrides accepted under `options`. */
const CompletionOptionsSchema = z.object({
  model: z.string().optional(),
  temperature: z.number().min(0).max(2).optional(),
  max_tokens: z.number().int().positive().max(16_384).optional(),
  return_validation_details: z.boolean().optional(),
});

/**
 * Body schema for POST /completion.
 *
 * `caller` and `input` are required; everything else is optional.
 * `context` is an arbitrary string-keyed record whose values are not validated.
 */
const CompletionRequestSchema = z.object({
  caller: z.string().min(1).max(100),
  task_type: z.string().optional(),
  input: z.string().min(1).max(50_000),
  language: z.enum(['de', 'en']).optional(),
  context: z.record(z.unknown()).optional(),
  options: CompletionOptionsSchema.optional(),
});

type CompletionRequest = z.infer<typeof CompletionRequestSchema>;

/**
 * Callers whose input is not passed through the ShieldX scan.
 *
 * NOTE(review): `caller` comes straight from the request body, so any client
 * can claim one of these names to bypass scanning. Confirm an upstream layer
 * authenticates callers before relying on this list.
 */
const SKIP_SHIELDX_CALLERS: ReadonlySet<string> = new Set(['internal', 'shieldx']);
/**
 * Scan untrusted input with the in-process ShieldX instance.
 *
 * Only the `scanInput` call itself may fail open: if the library throws, the
 * error is logged and the input is allowed through. Everything after a
 * successful scan (logging, building the rejection) runs outside the
 * try/catch, so an unexpected result shape can never turn a detected threat
 * into a pass.
 *
 * @param input - User-supplied text to scan.
 * @param caller - Caller identifier, used only as log context.
 * @returns `passed: false` plus threat details when ShieldX reports a
 *   detection; `passed: true` otherwise, including when the scan throws.
 */
async function runShieldXScan(
  input: string,
  caller: string,
): Promise<{ passed: boolean; reason?: string; threatLevel?: string; phase?: string; latencyMs?: number }> {
  let result: Awaited<ReturnType<typeof shieldx.scanInput>>;
  try {
    result = await shieldx.scanInput(input);
  } catch (err) {
    // Deliberate fail-open: a scanner failure must not block all traffic.
    logger.error({ err, caller }, 'ShieldX scan error — failing open');
    return { passed: true };
  }

  if (!result.detected) {
    return { passed: true, latencyMs: result.latencyMs };
  }

  logger.warn({
    caller,
    threatLevel: result.threatLevel,
    phase: result.killChainPhase,
    action: result.action,
    latencyMs: result.latencyMs,
    ensemble: result.ensemble,
    atlasMapping: result.atlasMapping?.techniqueIds?.slice(0, 5),
    // Guarded so a missing array cannot throw while a threat is being reported.
    scannerCount: result.scanResults?.length,
  }, 'ShieldX threat detected — input blocked');

  return {
    passed: false,
    reason: `Prompt injection detected: ${result.killChainPhase ?? 'unknown'} (${result.threatLevel ?? 'unknown'})`,
    threatLevel: result.threatLevel,
    phase: result.killChainPhase,
    latencyMs: result.latencyMs,
  };
}
/**
 * Template variables that fall back to the raw `input` when the caller does
 * not supply them via `context`, so simple one-field callers work without
 * knowing each template's specific variable name.
 */
const INPUT_ALIAS_KEYS = [
  // Common content variable names shared across the prompt templates
  'source_data', 'ocr_text', 'transcription',
  'ticket_content', 'alert_data', 'incident_data',
  'lldp_data', 'cve_data', 'inventory',
  'anomaly_data', 'flagged_input', 'attack_description',
  'bgp_data', 'health_checks', 'market_data',
  'manuscript_text', 'raw_content', 'content',
  // Additional structured vars with sensible fallbacks
  'peeringdb_data', 'bgp_routes', 'network_context',
  'alert_context', 'affected_inventory',
] as const;

/**
 * Convert a caller-supplied `context` value into template-variable text.
 *
 * Strings pass through unchanged. Every other value (number, boolean, null,
 * object, array) is serialized with `JSON.stringify`. Templates therefore never
 * receive a non-string value mislabeled as a string, which would otherwise
 * render as e.g. `[object Object]`.
 */
function contextValueToString(value: unknown): string {
  if (typeof value === 'string') {
    return value;
  }
  // JSON.stringify yields undefined for undefined/functions/symbols; request
  // bodies are parsed JSON, but keep the result a real string regardless.
  return JSON.stringify(value) ?? '';
}

/**
 * Observe a best-effort side task without awaiting it.
 *
 * Any rejection is logged instead of surfacing as an unhandled promise
 * rejection (which terminates the process under Node's default
 * `--unhandled-rejections=throw` mode). Non-promise values are tolerated.
 */
function runInBackground(pending: unknown, message: string, logContext: Record<string, unknown>): void {
  void Promise.resolve(pending).catch((err: unknown) => {
    logger.error({ err, ...logContext }, message);
  });
}

/**
 * Register POST /completion.
 *
 * Pipeline: validate body → ShieldX scan → pre-classify → route → assemble
 * prompt → call Ollama (with fallback chain) → post-validate → confidence
 * gate → metrics + audit log → respond.
 */
export async function completionRoute(fastify: FastifyInstance): Promise<void> {
  fastify.post(
    '/completion',
    {
      config: { rateLimit: false }, // Custom rate limiting via caller
    },
    async (request: FastifyRequest, reply: FastifyReply) => {
      const startMs = Date.now();
      let body: CompletionRequest;
      try {
        body = CompletionRequestSchema.parse(request.body);
      } catch (err) {
        return reply.status(400).send({
          statusCode: 400,
          error: 'Bad Request',
          message: err instanceof z.ZodError ? err.errors[0]?.message ?? 'Invalid request' : 'Invalid request body',
        });
      }
      const { caller, input, language, context, options } = body;
      const returnValidationDetails = options?.return_validation_details ?? false;

      // Stage 2: ShieldX scan (in-process library)
      // NOTE(review): `caller` is self-declared in the request body, so any client
      // can claim a SKIP_SHIELDX_CALLERS name to bypass this scan. Also, only
      // `input` is scanned, while `context` values are interpolated into the
      // prompt below unscanned. Confirm both are acceptable for this deployment.
      if (!SKIP_SHIELDX_CALLERS.has(caller)) {
        const shieldResult = await runShieldXScan(input, caller);
        if (!shieldResult.passed) {
          requestsTotal.labels({ caller, task_type: 'unknown', status: 'rejected' }).inc();
          return reply.status(400).send({
            statusCode: 400,
            error: 'Rejected',
            message: shieldResult.reason ?? 'Input rejected by security scan',
            threat_level: shieldResult.threatLevel,
            kill_chain_phase: shieldResult.phase,
            shieldx_latency_ms: shieldResult.latencyMs,
          });
        }
      }

      // Stage 3: Pre-classifier (only when the caller did not name a task type)
      let taskType = body.task_type;
      let classificationResult;
      if (!taskType) {
        try {
          classificationResult = await classifyInput(input);
          taskType = classificationResult.task_type;
        } catch (err) {
          logger.warn({ err }, 'Pre-classifier failed');
          taskType = 'generic_qa';
        }
      }

      // Stage 4: Router
      let decision;
      try {
        decision = route(taskType, caller, {
          model: options?.model,
          temperature: options?.temperature,
          max_tokens: options?.max_tokens,
        });
      } catch (err) {
        return reply.status(400).send({
          statusCode: 400,
          error: 'Routing Error',
          message: err instanceof Error ? err.message : 'Failed to route request',
        });
      }

      // Stage 5: Prompt assembly
      // Use taskType directly for template lookup (so tip_transceiver_enrich.yaml is used,
      // not the generic_qa fallback from routing). The router only selects the model.
      //
      // Variable resolution strategy:
      // 1. Explicit context fields take priority (callers can pass structured data)
      // 2. `input` is used as fallback for the common content variables so simple
      //    one-field callers work without knowing each template's specific var name.
      const inputAliases: Record<string, string> = Object.fromEntries(
        INPUT_ALIAS_KEYS.map((key) => [key, input]),
      );
      const contextVars: Record<string, string> = context
        ? Object.fromEntries(Object.entries(context).map(([key, value]) => [key, contextValueToString(value)]))
        : {};
      const resolved = resolvePrompt(
        taskType ?? decision.prompt_template,
        {
          ...inputAliases, // low priority: input as fallback for all content vars
          ...contextVars, // medium priority: explicit context fields override aliases
          input, // always available as {{input}}
          user_context: context,
        },
        language ?? 'en',
      );

      // Stage 6: Ollama call with circuit breaker + retry
      let ollamaResponse;
      try {
        ollamaResponse = await callOllamaWithFallbackChain(
          {
            model: decision.model,
            prompt: resolved.prompt,
            system: resolved.system,
            options: {
              temperature: decision.temperature,
              num_predict: decision.max_tokens,
            },
            format: decision.output_format === 'json' ? 'json' : '',
            stream: false,
          },
          decision.fallback_chain,
          decision.tier,
        );
      } catch (err) {
        const latency = Date.now() - startMs;
        logger.error({ err, caller, taskType }, 'Ollama call failed');
        requestsTotal.labels({ caller, task_type: taskType, status: 'rejected' }).inc();
        latencySeconds.labels({ caller, task_type: taskType, model: decision.model }).observe(latency / 1000);
        return reply.status(503).send({
          statusCode: 503,
          error: 'Service Unavailable',
          message: 'LLM service unavailable, please retry',
        });
      }
      const outputText = ollamaResponse.response;
      // Latency covers everything up to the LLM response (post-validation excluded).
      const latencyMs = Date.now() - startMs;
      // The model that actually produced the output. The response may omit it,
      // in which case the routed model is assumed.
      const modelUsed = ollamaResponse.model ?? decision.model;
      const fallbackUsed = modelUsed !== decision.model;

      // Stage 7: Post-validation chain
      const validationOutput = await runPostValidation(outputText, {
        validators: decision.validators,
        language,
        output_format: decision.output_format,
        requires_fact_check: decision.requires_fact_check,
        schema: resolved.schema,
      });

      // Stage 8: Confidence gate
      const confidenceResult = evaluateConfidence(validationOutput);

      // Record metrics (attributed to the model that actually served the request)
      requestsTotal.labels({ caller, task_type: taskType, status: confidenceResult.status }).inc();
      latencySeconds.labels({ caller, task_type: taskType, model: modelUsed }).observe(latencyMs / 1000);
      tokensTotal.labels({ direction: 'in', model: modelUsed }).inc(ollamaResponse.prompt_eval_count ?? 0);
      tokensTotal.labels({ direction: 'out', model: modelUsed }).inc(ollamaResponse.eval_count ?? 0);
      confidenceScore.labels({ task_type: taskType, model: modelUsed }).observe(confidenceResult.score);
      for (const violation of validationOutput.ban_violations) {
        banlistHitsTotal.labels({ term: violation.term, language: violation.language, category: violation.category }).inc();
      }
      for (const result of validationOutput.results) {
        if (!result.passed) {
          validationFailuresTotal.labels({ validator: result.validator, task_type: taskType }).inc();
        }
      }

      // Stage 9: Audit log (output text is withheld while pending review)
      const inputHash = hashText(input);
      const outputHash = hashText(outputText);
      const callId = await writeAuditLog({
        caller,
        task_type: taskType,
        model_used: modelUsed,
        prompt_id: resolved.prompt_id,
        prompt_version: resolved.prompt_version,
        input_hash: inputHash,
        output_text: confidenceResult.status !== 'pending_review' ? outputText : undefined,
        output_hash: outputHash,
        token_count_in: ollamaResponse.prompt_eval_count ?? 0,
        token_count_out: ollamaResponse.eval_count ?? 0,
        latency_ms: latencyMs,
        confidence: confidenceResult.score,
        status: confidenceResult.status,
        validation_log: validationOutput.results,
        ban_hits: validationOutput.ban_violations,
        metadata: {
          classification: classificationResult,
          model_tier: decision.tier,
          fallback_used: fallbackUsed,
        },
      });

      // Best-effort follow-ups; failures are logged, never surfaced to the client.
      if (validationOutput.ban_violations.length > 0 && callId) {
        runInBackground(
          writeBanAnalytics(callId, validationOutput.ban_violations, caller, taskType),
          'Failed to write ban analytics',
          { callId, caller },
        );
      }
      if (confidenceResult.status === 'pending_review' && callId) {
        runInBackground(
          addToReviewQueue({
            callId,
            caller,
            taskType,
            inputText: input,
            outputText,
            confidence: confidenceResult.score,
            validationLog: validationOutput.results,
          }),
          'Failed to add call to review queue',
          { callId, caller },
        );
      }

      // Stage 10: Response. `model` reports the routed model; the audit log
      // records the model that actually served the request.
      const responseBody: Record<string, unknown> = {
        id: callId,
        status: confidenceResult.status,
        confidence: Math.round(confidenceResult.score * 100) / 100,
        model: decision.model,
        task_type: taskType,
        latency_ms: latencyMs,
        tokens: {
          in: ollamaResponse.prompt_eval_count ?? 0,
          out: ollamaResponse.eval_count ?? 0,
        },
      };
      if (confidenceResult.status !== 'pending_review') {
        responseBody['output'] = outputText;
      } else {
        responseBody['output'] = null;
        responseBody['message'] = 'Output is pending human review due to low confidence';
      }
      if (returnValidationDetails) {
        responseBody['validation'] = validationOutput.results;
        responseBody['confidence_detail'] = {
          base_score: confidenceResult.base_score,
          total_impact: confidenceResult.total_impact,
          final_score: confidenceResult.score,
        };
      }
      return reply.status(200).send(responseBody);
    },
  );
}