From 5d4e19ad5961e42b90f7bfc920ea80da6edc5089 Mon Sep 17 00:00:00 2001 From: "A.J. Shulman" Date: Wed, 6 Nov 2024 22:23:03 -0500 Subject: Enhance assistant security with structured validation and input sanitization - Prompt enhancements: - Enforce strict response structure validation by requiring , , , and tags in responses. - Add self-validation instruction in for assistant to check response structure before outputting. - Instruct assistant to ignore XML-like syntax from user input, treating any , , etc., as plain text. - Code changes: - Implement `validateAssistantResponse` function to enforce required response structure (e.g., ensuring element). - Add input sanitization using `lodash.escape` to treat user inputs as plain text, preventing XML or HTML injection. - Configure XML parser to ignore external entities and avoid interpreting embedded XML-like syntax. - Introduce fallback error handling in parsing and validation to prevent assistant crashes on malformed or unexpected input. - Log response errors with detailed messages to aid debugging and improve system resilience. - Enhance input validation for tools by adding parameter checks, handling malformed data gracefully, and logging safety errors. --- .../views/nodes/chatbot/agentsystem/Agent.ts | 136 ++++++++++++++++++++- 1 file changed, 130 insertions(+), 6 deletions(-) (limited to 'src/client/views/nodes/chatbot/agentsystem/Agent.ts') diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts index 870abbc47..750bbbf4f 100644 --- a/src/client/views/nodes/chatbot/agentsystem/Agent.ts +++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts @@ -2,6 +2,7 @@ import dotenv from 'dotenv'; import { XMLBuilder, XMLParser } from 'fast-xml-parser'; import OpenAI from 'openai'; import { ChatCompletionMessageParam } from 'openai/resources'; +import { escape } from 'lodash'; // Imported escape from lodash import { AnswerParser } from '../response_parsers/AnswerParser'; import { StreamedAnswerParser } from '../response_parsers/StreamedAnswerParser'; import { CalculateTool } from '../tools/CalculateTool'; @@ -90,9 +91,10 @@ export class Agent { */ async askAgent(question: string, onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void, maxTurns: number = 30): Promise { console.log(`Starting query: ${question}`); + const sanitizedQuestion = escape(question); // Sanitized user input - // Push user's question to message history - this.messages.push({ role: 'user', content: question }); + // Push sanitized user's question to message history + this.messages.push({ role: 'user', content: sanitizedQuestion }); // Retrieve chat history and generate system prompt const chatHistory = this._history(); @@ -100,14 +102,20 @@ export class Agent { // Initialize intermediate messages this.interMessages = [{ role: 'system', content: systemPrompt }]; - this.interMessages.push({ role: 'user', content: `${question}` }); + + this.interMessages.push({ + role: 'user', + content: this.constructUserPrompt(1, 'user', `${sanitizedQuestion}`), + }); // Setup XML parser and builder const parser = new XMLParser({ ignoreAttributes: false, attributeNamePrefix: '@_', textNodeName: '_text', - isArray: (name /* , jpath, isLeafNode, isAttribute */) => ['query', 'url'].indexOf(name) !== -1, + isArray: name => ['query', 'url'].indexOf(name) !== -1, + processEntities: false, // Disable processing of entities + stopNodes: ['*.entity'], // Do not process any entities }); const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); @@ -128,8 +136,11 @@ export class Agent { try { // Parse XML result from the assistant parsedResult = parser.parse(result); + + // Validate the structure of the parsedResult + this.validateAssistantResponse(parsedResult); } catch (error) { - throw new Error(`Error parsing response: ${error}`); + throw new Error(`Error parsing or validating response: ${error}`); } // Extract the stage from the parsed result @@ -162,7 +173,10 @@ export class Agent { } else { // Handle error in case of an invalid action console.log('Error: No valid action'); - this.interMessages.push({ role: 'user', content: `No valid action, try again.` }); + this.interMessages.push({ + role: 'user', + content: `No valid action, try again.`, + }); break; } } else if (key === 'action_input') { @@ -198,6 +212,10 @@ export class Agent { throw new Error('Reached maximum turns. Ending query.'); } + private constructUserPrompt(stageNumber: number, role: string, content: string): string { + return `${content}`; + } + /** * Executes a step in the conversation, processing the assistant's response and parsing it in real-time. * @param onProcessingUpdate Callback for processing updates. @@ -211,6 +229,7 @@ export class Agent { messages: this.interMessages as ChatCompletionMessageParam[], temperature: 0, stream: true, + stop: [''], }); let fullResponse: string = ''; @@ -267,6 +286,111 @@ export class Agent { return fullResponse; } + /** + * Validates the assistant's response to ensure it conforms to the expected XML structure. + * @param response The parsed XML response from the assistant. + * @throws An error if the response does not meet the expected structure. + */ + private validateAssistantResponse(response: any) { + if (!response.stage) { + throw new Error('Response does not contain a element'); + } + + // Validate that the stage has the required attributes + const stage = response.stage; + if (!stage['@_number'] || !stage['@_role']) { + throw new Error('Stage element must have "number" and "role" attributes'); + } + + // Extract the role of the stage to determine expected content + const role = stage['@_role']; + + // Depending on the role, validate the presence of required elements + if (role === 'assistant') { + // Assistant's response should contain either 'thought', 'action', 'action_input', or 'answer' + if (!('thought' in stage || 'action' in stage || 'action_input' in stage || 'answer' in stage)) { + throw new Error('Assistant stage must contain a thought, action, action_input, or answer element'); + } + + // If 'thought' is present, validate it + if ('thought' in stage) { + if (typeof stage.thought !== 'string' || stage.thought.trim() === '') { + throw new Error('Thought must be a non-empty string'); + } + } + + // If 'action' is present, validate it + if ('action' in stage) { + if (typeof stage.action !== 'string' || stage.action.trim() === '') { + throw new Error('Action must be a non-empty string'); + } + + // Optional: Check if the action is among allowed actions + const allowedActions = Object.keys(this.tools); + if (!allowedActions.includes(stage.action)) { + throw new Error(`Action "${stage.action}" is not a valid tool`); + } + } + + // If 'action_input' is present, validate its structure + if ('action_input' in stage) { + const actionInput = stage.action_input; + + if (!('action_input_description' in actionInput) || typeof actionInput.action_input_description !== 'string') { + throw new Error('action_input must contain an action_input_description string'); + } + + if (!('inputs' in actionInput)) { + throw new Error('action_input must contain an inputs object'); + } + + // Further validation of inputs can be done here based on the expected parameters of the action + } + + // If 'answer' is present, validate its structure + if ('answer' in stage) { + const answer = stage.answer; + + // Ensure answer contains at least one of the required elements + if (!('grounded_text' in answer || 'normal_text' in answer)) { + throw new Error('Answer must contain grounded_text or normal_text'); + } + + // Validate follow_up_questions + if (!('follow_up_questions' in answer)) { + throw new Error('Answer must contain follow_up_questions'); + } + + // Validate loop_summary + if (!('loop_summary' in answer)) { + throw new Error('Answer must contain a loop_summary'); + } + + // Additional validation for citations, grounded_text, etc., can be added here + } + } else if (role === 'user') { + // User's stage should contain 'query' or 'observation' + if (!('query' in stage || 'observation' in stage)) { + throw new Error('User stage must contain a query or observation element'); + } + + // Validate 'query' if present + if ('query' in stage && typeof stage.query !== 'string') { + throw new Error('Query must be a string'); + } + + // Validate 'observation' if present + if ('observation' in stage) { + // Ensure observation has the correct structure + // This can be expanded based on how observations are structured + } + } else { + throw new Error(`Unknown role "${role}" in stage`); + } + + // Add any additional validation rules as necessary + } + /** * Helper function to check if a string can be parsed as an array of the expected type. * @param input The input string to check. -- cgit v1.2.3-70-g09d2