diff options
| author | A.J. Shulman <Shulman.aj@gmail.com> | 2024-11-06 22:23:03 -0500 |
|---|---|---|
| committer | A.J. Shulman <Shulman.aj@gmail.com> | 2024-11-06 22:23:03 -0500 |
| commit | 5d4e19ad5961e42b90f7bfc920ea80da6edc5089 (patch) | |
| tree | 5d6d7e86130a25e034114100de90d25a68c3494d /src/client/views/nodes/chatbot/agentsystem/Agent.ts | |
| parent | 09d7d63d1f248a0bf1d36e4da804cbde5e12e209 (diff) | |
Enhance assistant security with structured validation and input sanitization
- Prompt enhancements:
- Enforce strict response structure validation by requiring <stage>, <thought>, <action>, and <answer> tags in responses.
- Add self-validation instruction in <final_instruction> for assistant to check response structure before outputting.
- Instruct assistant to ignore XML-like syntax from user input, treating any <stage>, <action>, etc., as plain text.
- Code changes:
- Implement `validateAssistantResponse` function to enforce required response structure (e.g., ensuring <stage> element).
- Add input sanitization using `lodash.escape` to treat user inputs as plain text, preventing XML or HTML injection.
- Configure XML parser to ignore external entities and avoid interpreting embedded XML-like syntax.
- Introduce fallback error handling in parsing and validation to prevent assistant crashes on malformed or unexpected input.
- Log response errors with detailed messages to aid debugging and improve system resilience.
- Enhance input validation for tools by adding parameter checks, handling malformed data gracefully, and logging safety errors.
Diffstat (limited to 'src/client/views/nodes/chatbot/agentsystem/Agent.ts')
| -rw-r--r-- | src/client/views/nodes/chatbot/agentsystem/Agent.ts | 136 |
1 files changed, 130 insertions, 6 deletions
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts index 870abbc47..750bbbf4f 100644 --- a/src/client/views/nodes/chatbot/agentsystem/Agent.ts +++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts @@ -2,6 +2,7 @@ import dotenv from 'dotenv'; import { XMLBuilder, XMLParser } from 'fast-xml-parser'; import OpenAI from 'openai'; import { ChatCompletionMessageParam } from 'openai/resources'; +import { escape } from 'lodash'; // Imported escape from lodash import { AnswerParser } from '../response_parsers/AnswerParser'; import { StreamedAnswerParser } from '../response_parsers/StreamedAnswerParser'; import { CalculateTool } from '../tools/CalculateTool'; @@ -90,9 +91,10 @@ export class Agent { */ async askAgent(question: string, onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void, maxTurns: number = 30): Promise<AssistantMessage> { console.log(`Starting query: ${question}`); + const sanitizedQuestion = escape(question); // Sanitized user input - // Push user's question to message history - this.messages.push({ role: 'user', content: question }); + // Push sanitized user's question to message history + this.messages.push({ role: 'user', content: sanitizedQuestion }); // Retrieve chat history and generate system prompt const chatHistory = this._history(); @@ -100,14 +102,20 @@ export class Agent { // Initialize intermediate messages this.interMessages = [{ role: 'system', content: systemPrompt }]; - this.interMessages.push({ role: 'user', content: `<stage number="1" role="user"><query>${question}</query></stage>` }); + + this.interMessages.push({ + role: 'user', + content: this.constructUserPrompt(1, 'user', `<query>${sanitizedQuestion}</query>`), + }); // Setup XML parser and builder const parser = new XMLParser({ ignoreAttributes: false, attributeNamePrefix: '@_', textNodeName: '_text', - isArray: (name /* , jpath, isLeafNode, isAttribute */) => ['query', 'url'].indexOf(name) !== -1, + isArray: name => ['query', 'url'].indexOf(name) !== -1, + processEntities: false, // Disable processing of entities + stopNodes: ['*.entity'], // Do not process any entities }); const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' }); @@ -128,8 +136,11 @@ export class Agent { try { // Parse XML result from the assistant parsedResult = parser.parse(result); + + // Validate the structure of the parsedResult + this.validateAssistantResponse(parsedResult); } catch (error) { - throw new Error(`Error parsing response: ${error}`); + throw new Error(`Error parsing or validating response: ${error}`); } // Extract the stage from the parsed result @@ -162,7 +173,10 @@ export class Agent { } else { // Handle error in case of an invalid action console.log('Error: No valid action'); - this.interMessages.push({ role: 'user', content: `<stage number="${i + 1}" role="system-error-reporter">No valid action, try again.</stage>` }); + this.interMessages.push({ + role: 'user', + content: `<stage number="${i + 1}" role="system-error-reporter">No valid action, try again.</stage>`, + }); break; } } else if (key === 'action_input') { @@ -198,6 +212,10 @@ export class Agent { throw new Error('Reached maximum turns. Ending query.'); } + private constructUserPrompt(stageNumber: number, role: string, content: string): string { + return `<stage number="${stageNumber}" role="${role}">${content}</stage>`; + } + /** * Executes a step in the conversation, processing the assistant's response and parsing it in real-time. * @param onProcessingUpdate Callback for processing updates. @@ -211,6 +229,7 @@ export class Agent { messages: this.interMessages as ChatCompletionMessageParam[], temperature: 0, stream: true, + stop: ['</stage>'], }); let fullResponse: string = ''; @@ -268,6 +287,111 @@ export class Agent { } /** + * Validates the assistant's response to ensure it conforms to the expected XML structure. + * @param response The parsed XML response from the assistant. + * @throws An error if the response does not meet the expected structure. + */ + private validateAssistantResponse(response: any) { + if (!response.stage) { + throw new Error('Response does not contain a <stage> element'); + } + + // Validate that the stage has the required attributes + const stage = response.stage; + if (!stage['@_number'] || !stage['@_role']) { + throw new Error('Stage element must have "number" and "role" attributes'); + } + + // Extract the role of the stage to determine expected content + const role = stage['@_role']; + + // Depending on the role, validate the presence of required elements + if (role === 'assistant') { + // Assistant's response should contain either 'thought', 'action', 'action_input', or 'answer' + if (!('thought' in stage || 'action' in stage || 'action_input' in stage || 'answer' in stage)) { + throw new Error('Assistant stage must contain a thought, action, action_input, or answer element'); + } + + // If 'thought' is present, validate it + if ('thought' in stage) { + if (typeof stage.thought !== 'string' || stage.thought.trim() === '') { + throw new Error('Thought must be a non-empty string'); + } + } + + // If 'action' is present, validate it + if ('action' in stage) { + if (typeof stage.action !== 'string' || stage.action.trim() === '') { + throw new Error('Action must be a non-empty string'); + } + + // Optional: Check if the action is among allowed actions + const allowedActions = Object.keys(this.tools); + if (!allowedActions.includes(stage.action)) { + throw new Error(`Action "${stage.action}" is not a valid tool`); + } + } + + // If 'action_input' is present, validate its structure + if ('action_input' in stage) { + const actionInput = stage.action_input; + + if (!('action_input_description' in actionInput) || typeof actionInput.action_input_description !== 'string') { + throw new Error('action_input must contain an action_input_description string'); + } + + if (!('inputs' in actionInput)) { + throw new Error('action_input must contain an inputs object'); + } + + // Further validation of inputs can be done here based on the expected parameters of the action + } + + // If 'answer' is present, validate its structure + if ('answer' in stage) { + const answer = stage.answer; + + // Ensure answer contains at least one of the required elements + if (!('grounded_text' in answer || 'normal_text' in answer)) { + throw new Error('Answer must contain grounded_text or normal_text'); + } + + // Validate follow_up_questions + if (!('follow_up_questions' in answer)) { + throw new Error('Answer must contain follow_up_questions'); + } + + // Validate loop_summary + if (!('loop_summary' in answer)) { + throw new Error('Answer must contain a loop_summary'); + } + + // Additional validation for citations, grounded_text, etc., can be added here + } + } else if (role === 'user') { + // User's stage should contain 'query' or 'observation' + if (!('query' in stage || 'observation' in stage)) { + throw new Error('User stage must contain a query or observation element'); + } + + // Validate 'query' if present + if ('query' in stage && typeof stage.query !== 'string') { + throw new Error('Query must be a string'); + } + + // Validate 'observation' if present + if ('observation' in stage) { + // Ensure observation has the correct structure + // This can be expanded based on how observations are structured + } + } else { + throw new Error(`Unknown role "${role}" in stage`); + } + + // Add any additional validation rules as necessary + } + + /** * Helper function to check if a string can be parsed as an array of the expected type. * @param input The input string to check. * @param expectedType The expected type of the array elements ('string', 'number', or 'boolean'). |
