aboutsummaryrefslogtreecommitdiff
path: root/src/client/views/nodes/chatbot/agentsystem/Agent.ts
diff options
context:
space:
mode:
authorA.J. Shulman <Shulman.aj@gmail.com>2024-11-06 22:23:03 -0500
committerA.J. Shulman <Shulman.aj@gmail.com>2024-11-06 22:23:03 -0500
commit5d4e19ad5961e42b90f7bfc920ea80da6edc5089 (patch)
tree5d6d7e86130a25e034114100de90d25a68c3494d /src/client/views/nodes/chatbot/agentsystem/Agent.ts
parent09d7d63d1f248a0bf1d36e4da804cbde5e12e209 (diff)
Enhance assistant security with structured validation and input sanitization
- Prompt enhancements: - Enforce strict response structure validation by requiring <stage>, <thought>, <action>, and <answer> tags in responses. - Add self-validation instruction in <final_instruction> for assistant to check response structure before outputting. - Instruct assistant to ignore XML-like syntax from user input, treating any <stage>, <action>, etc., as plain text. - Code changes: - Implement `validateAssistantResponse` function to enforce required response structure (e.g., ensuring <stage> element). - Add input sanitization using `lodash.escape` to treat user inputs as plain text, preventing XML or HTML injection. - Configure XML parser to ignore external entities and avoid interpreting embedded XML-like syntax. - Introduce fallback error handling in parsing and validation to prevent assistant crashes on malformed or unexpected input. - Log response errors with detailed messages to aid debugging and improve system resilience. - Enhance input validation for tools by adding parameter checks, handling malformed data gracefully, and logging safety errors.
Diffstat (limited to 'src/client/views/nodes/chatbot/agentsystem/Agent.ts')
-rw-r--r--src/client/views/nodes/chatbot/agentsystem/Agent.ts136
1 files changed, 130 insertions, 6 deletions
diff --git a/src/client/views/nodes/chatbot/agentsystem/Agent.ts b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
index 870abbc47..750bbbf4f 100644
--- a/src/client/views/nodes/chatbot/agentsystem/Agent.ts
+++ b/src/client/views/nodes/chatbot/agentsystem/Agent.ts
@@ -2,6 +2,7 @@ import dotenv from 'dotenv';
import { XMLBuilder, XMLParser } from 'fast-xml-parser';
import OpenAI from 'openai';
import { ChatCompletionMessageParam } from 'openai/resources';
+import { escape } from 'lodash'; // Imported escape from lodash
import { AnswerParser } from '../response_parsers/AnswerParser';
import { StreamedAnswerParser } from '../response_parsers/StreamedAnswerParser';
import { CalculateTool } from '../tools/CalculateTool';
@@ -90,9 +91,10 @@ export class Agent {
*/
async askAgent(question: string, onProcessingUpdate: (processingUpdate: ProcessingInfo[]) => void, onAnswerUpdate: (answerUpdate: string) => void, maxTurns: number = 30): Promise<AssistantMessage> {
console.log(`Starting query: ${question}`);
+ const sanitizedQuestion = escape(question); // Sanitized user input
- // Push user's question to message history
- this.messages.push({ role: 'user', content: question });
+ // Push sanitized user's question to message history
+ this.messages.push({ role: 'user', content: sanitizedQuestion });
// Retrieve chat history and generate system prompt
const chatHistory = this._history();
@@ -100,14 +102,20 @@ export class Agent {
// Initialize intermediate messages
this.interMessages = [{ role: 'system', content: systemPrompt }];
- this.interMessages.push({ role: 'user', content: `<stage number="1" role="user"><query>${question}</query></stage>` });
+
+ this.interMessages.push({
+ role: 'user',
+ content: this.constructUserPrompt(1, 'user', `<query>${sanitizedQuestion}</query>`),
+ });
// Setup XML parser and builder
const parser = new XMLParser({
ignoreAttributes: false,
attributeNamePrefix: '@_',
textNodeName: '_text',
- isArray: (name /* , jpath, isLeafNode, isAttribute */) => ['query', 'url'].indexOf(name) !== -1,
+ isArray: name => ['query', 'url'].indexOf(name) !== -1,
+ processEntities: false, // Disable processing of entities
+ stopNodes: ['*.entity'], // Do not process any entities
});
const builder = new XMLBuilder({ ignoreAttributes: false, attributeNamePrefix: '@_' });
@@ -128,8 +136,11 @@ export class Agent {
try {
// Parse XML result from the assistant
parsedResult = parser.parse(result);
+
+ // Validate the structure of the parsedResult
+ this.validateAssistantResponse(parsedResult);
} catch (error) {
- throw new Error(`Error parsing response: ${error}`);
+ throw new Error(`Error parsing or validating response: ${error}`);
}
// Extract the stage from the parsed result
@@ -162,7 +173,10 @@ export class Agent {
} else {
// Handle error in case of an invalid action
console.log('Error: No valid action');
- this.interMessages.push({ role: 'user', content: `<stage number="${i + 1}" role="system-error-reporter">No valid action, try again.</stage>` });
+ this.interMessages.push({
+ role: 'user',
+ content: `<stage number="${i + 1}" role="system-error-reporter">No valid action, try again.</stage>`,
+ });
break;
}
} else if (key === 'action_input') {
@@ -198,6 +212,10 @@ export class Agent {
throw new Error('Reached maximum turns. Ending query.');
}
+ private constructUserPrompt(stageNumber: number, role: string, content: string): string {
+ return `<stage number="${stageNumber}" role="${role}">${content}</stage>`;
+ }
+
/**
* Executes a step in the conversation, processing the assistant's response and parsing it in real-time.
* @param onProcessingUpdate Callback for processing updates.
@@ -211,6 +229,7 @@ export class Agent {
messages: this.interMessages as ChatCompletionMessageParam[],
temperature: 0,
stream: true,
+ stop: ['</stage>'],
});
let fullResponse: string = '';
@@ -268,6 +287,111 @@ export class Agent {
}
/**
+ * Validates the assistant's response to ensure it conforms to the expected XML structure.
+ * @param response The parsed XML response from the assistant.
+ * @throws An error if the response does not meet the expected structure.
+ */
+ private validateAssistantResponse(response: any) {
+ if (!response.stage) {
+ throw new Error('Response does not contain a <stage> element');
+ }
+
+ // Validate that the stage has the required attributes
+ const stage = response.stage;
+ if (!stage['@_number'] || !stage['@_role']) {
+ throw new Error('Stage element must have "number" and "role" attributes');
+ }
+
+ // Extract the role of the stage to determine expected content
+ const role = stage['@_role'];
+
+ // Depending on the role, validate the presence of required elements
+ if (role === 'assistant') {
+ // Assistant's response should contain either 'thought', 'action', 'action_input', or 'answer'
+ if (!('thought' in stage || 'action' in stage || 'action_input' in stage || 'answer' in stage)) {
+ throw new Error('Assistant stage must contain a thought, action, action_input, or answer element');
+ }
+
+ // If 'thought' is present, validate it
+ if ('thought' in stage) {
+ if (typeof stage.thought !== 'string' || stage.thought.trim() === '') {
+ throw new Error('Thought must be a non-empty string');
+ }
+ }
+
+ // If 'action' is present, validate it
+ if ('action' in stage) {
+ if (typeof stage.action !== 'string' || stage.action.trim() === '') {
+ throw new Error('Action must be a non-empty string');
+ }
+
+ // Optional: Check if the action is among allowed actions
+ const allowedActions = Object.keys(this.tools);
+ if (!allowedActions.includes(stage.action)) {
+ throw new Error(`Action "${stage.action}" is not a valid tool`);
+ }
+ }
+
+ // If 'action_input' is present, validate its structure
+ if ('action_input' in stage) {
+ const actionInput = stage.action_input;
+
+ if (!('action_input_description' in actionInput) || typeof actionInput.action_input_description !== 'string') {
+ throw new Error('action_input must contain an action_input_description string');
+ }
+
+ if (!('inputs' in actionInput)) {
+ throw new Error('action_input must contain an inputs object');
+ }
+
+ // Further validation of inputs can be done here based on the expected parameters of the action
+ }
+
+ // If 'answer' is present, validate its structure
+ if ('answer' in stage) {
+ const answer = stage.answer;
+
+ // Ensure answer contains at least one of the required elements
+ if (!('grounded_text' in answer || 'normal_text' in answer)) {
+ throw new Error('Answer must contain grounded_text or normal_text');
+ }
+
+ // Validate follow_up_questions
+ if (!('follow_up_questions' in answer)) {
+ throw new Error('Answer must contain follow_up_questions');
+ }
+
+ // Validate loop_summary
+ if (!('loop_summary' in answer)) {
+ throw new Error('Answer must contain a loop_summary');
+ }
+
+ // Additional validation for citations, grounded_text, etc., can be added here
+ }
+ } else if (role === 'user') {
+ // User's stage should contain 'query' or 'observation'
+ if (!('query' in stage || 'observation' in stage)) {
+ throw new Error('User stage must contain a query or observation element');
+ }
+
+ // Validate 'query' if present
+ if ('query' in stage && typeof stage.query !== 'string') {
+ throw new Error('Query must be a string');
+ }
+
+ // Validate 'observation' if present
+ if ('observation' in stage) {
+ // Ensure observation has the correct structure
+ // This can be expanded based on how observations are structured
+ }
+ } else {
+ throw new Error(`Unknown role "${role}" in stage`);
+ }
+
+ // Add any additional validation rules as necessary
+ }
+
+ /**
* Helper function to check if a string can be parsed as an array of the expected type.
* @param input The input string to check.
* @param expectedType The expected type of the array elements ('string', 'number', or 'boolean').