diff --git a/package-lock.json b/package-lock.json index 7a0d8f9..099663a 100644 --- a/package-lock.json +++ b/package-lock.json @@ -11,6 +11,38 @@ "packages/*" ] }, + "../../../shieldx": { + "name": "@shieldx/core", + "version": "0.5.0", + "license": "Apache-2.0", + "dependencies": { + "pg": "^8.13.0", + "pgvector": "^0.2.0", + "pino": "^9.6.0", + "zod": "^3.24.0" + }, + "devDependencies": { + "@types/node": "^22.0.0", + "@types/pg": "^8.11.0", + "@vitest/coverage-v8": "^3.0.0", + "eslint": "^9.0.0", + "tsup": "^8.3.0", + "tsx": "^4.19.0", + "typescript": "^5.7.0", + "vitest": "^3.0.0" + }, + "engines": { + "node": ">=20.0.0" + }, + "peerDependencies": { + "next": ">=15.0.0" + }, + "peerDependenciesMeta": { + "next": { + "optional": true + } + } + }, "node_modules/@esbuild/aix-ppc64": { "version": "0.27.7", "resolved": "https://registry.npmjs.org/@esbuild/aix-ppc64/-/aix-ppc64-0.27.7.tgz", @@ -935,6 +967,10 @@ "win32" ] }, + "node_modules/@shieldx/core": { + "resolved": "../../../shieldx", + "link": true + }, "node_modules/@types/estree": { "version": "1.0.8", "resolved": "https://registry.npmjs.org/@types/estree/-/estree-1.0.8.tgz", @@ -3425,6 +3461,7 @@ "@fastify/cors": "^9.0.1", "@fastify/helmet": "^11.1.1", "@fastify/rate-limit": "^9.1.0", + "@shieldx/core": "file:../../../../../shieldx", "ajv": "^8.17.1", "fastify": "^4.28.1", "franc": "^6.2.0", diff --git a/packages/gateway/package.json b/packages/gateway/package.json index 9642941..7008a6b 100644 --- a/packages/gateway/package.json +++ b/packages/gateway/package.json @@ -12,6 +12,7 @@ "@fastify/cors": "^9.0.1", "@fastify/helmet": "^11.1.1", "@fastify/rate-limit": "^9.1.0", + "@shieldx/core": "file:../../../../../shieldx", "ajv": "^8.17.1", "fastify": "^4.28.1", "franc": "^6.2.0", diff --git a/packages/gateway/src/routes/completion.ts b/packages/gateway/src/routes/completion.ts index 9563e91..3c7e739 100644 --- a/packages/gateway/src/routes/completion.ts +++ b/packages/gateway/src/routes/completion.ts @@ -17,6 +17,32 @@ import { validationFailuresTotal, } from '../observability/metrics.js'; import { logger } from '../observability/logger.js'; +import { ShieldX } from '@shieldx/core'; + +// Singleton ShieldX instance — initialized once, sub-millisecond scans +// Disable Ollama-dependent scanners (sentinel, constitutional, embedding, attention) +// to keep gateway scans fast and dependency-free +const shieldx = new ShieldX({ + scanners: { + rules: true, // 547+ rules, 50+ languages + sentinel: false, // Requires Ollama + constitutional: false, // Requires Ollama + embedding: false, // Requires Ollama + embeddingAnomaly: false, + entropy: true, // Zero-cost entropy analysis + yara: false, // Requires YARA binary + attention: false, // Requires Ollama + canary: false, // Not needed in gateway context + indirect: true, // RAG/tool injection detection + selfConsciousness: false, + crossModel: false, + behavioral: true, // Session profiling + unicode: true, // Homoglyph/script detection + tokenizer: true, // I.g.n.o.r.e-style attacks + compressedPayload: true, + }, + logging: { level: 'warn', structured: true, incidentLog: false }, +} as any); // DeepPartial config — merges with defaults const CompletionRequestSchema = z.object({ caller: z.string().min(1).max(100), @@ -38,43 +64,37 @@ type CompletionRequest = z.infer; const SKIP_SHIELDX_CALLERS = new Set(['internal', 'shieldx']); -async function runShieldXScan(input: string, caller: string): Promise<{ passed: boolean; reason?: string }> { - const GATEWAY_URL = `http://localhost:${process.env['PORT'] ?? '3100'}`; +async function runShieldXScan( + input: string, + caller: string, +): Promise<{ passed: boolean; reason?: string; threatLevel?: string; phase?: string; latencyMs?: number }> { try { - const response = await fetch(`${GATEWAY_URL}/v1/completion`, { - method: 'POST', - headers: { 'Content-Type': 'application/json', 'X-Caller-ID': 'internal' }, - body: JSON.stringify({ - caller: 'internal', - task_type: 'shieldx_threat_classification', - input, - options: { return_validation_details: false }, - }), - signal: AbortSignal.timeout(8000), - }); + const result = await shieldx.scanInput(input); - if (!response.ok) return { passed: true }; // Fail open if ShieldX is down + if (result.detected) { + logger.warn({ + caller, + threatLevel: result.threatLevel, + phase: result.killChainPhase, + action: result.action, + latencyMs: result.latencyMs, + ensemble: result.ensemble, + atlasMapping: result.atlasMapping?.techniqueIds?.slice(0, 5), + scannerCount: result.scanResults.length, + }, 'ShieldX threat detected — input blocked'); - const result = await response.json() as { output?: string; status?: string }; - if (result.status !== 'approved' || !result.output) return { passed: true }; - - type ShieldResult = { threat_detected: boolean; threat_type?: string; confidence?: number }; - let parsed: ShieldResult; - try { - parsed = JSON.parse(result.output) as ShieldResult; - } catch { - return { passed: true }; + return { + passed: false, + reason: `Prompt injection detected: ${result.killChainPhase} (${result.threatLevel})`, + threatLevel: result.threatLevel, + phase: result.killChainPhase, + latencyMs: result.latencyMs, + }; } - if (parsed.threat_detected && (parsed.confidence ?? 0) > 0.8) { - logger.warn({ caller, threat_type: parsed.threat_type }, 'ShieldX threat detected'); - return { passed: false, reason: `Threat detected: ${parsed.threat_type ?? 'unknown'}` }; - } - - return { passed: true }; + return { passed: true, latencyMs: result.latencyMs }; } catch (err) { - // ShieldX unavailable — fail open (log but continue) - logger.warn({ err, caller }, 'ShieldX scan failed, continuing without scan'); + logger.error({ err, caller }, 'ShieldX scan error — failing open'); return { passed: true }; } } @@ -102,7 +122,7 @@ export async function completionRoute(fastify: FastifyInstance): Promise { const { caller, input, language, context, options } = body; const returnValidationDetails = options?.return_validation_details ?? false; - // Stage 2: ShieldX scan + // Stage 2: ShieldX scan (real library, 547+ rules, sub-millisecond) if (!SKIP_SHIELDX_CALLERS.has(caller)) { const shieldResult = await runShieldXScan(input, caller); if (!shieldResult.passed) { @@ -111,6 +131,9 @@ export async function completionRoute(fastify: FastifyInstance): Promise { statusCode: 400, error: 'Rejected', message: shieldResult.reason ?? 'Input rejected by security scan', + threat_level: shieldResult.threatLevel, + kill_chain_phase: shieldResult.phase, + shieldx_latency_ms: shieldResult.latencyMs, }); } }