From badc8362c80ca33d2b3d93dda6a73b3bfb35a214 Mon Sep 17 00:00:00 2001
From: "A.J. Shulman"
Date: Thu, 19 Sep 2024 12:19:04 -0400
Subject: added python files to server

---
 src/server/chunker/pdf_chunker.py   | 744 ++++++++++++++++++++++++++++++++++++
 src/server/chunker/requirements.txt |  15 +
 2 files changed, 759 insertions(+)
 create mode 100644 src/server/chunker/pdf_chunker.py
 create mode 100644 src/server/chunker/requirements.txt

diff --git a/src/server/chunker/pdf_chunker.py b/src/server/chunker/pdf_chunker.py
new file mode 100644
index 000000000..c9f6737e7
--- /dev/null
+++ b/src/server/chunker/pdf_chunker.py
@@ -0,0 +1,744 @@
+import asyncio
+import concurrent
+import sys
+
+from tqdm.asyncio import tqdm_asyncio  # Progress bar for async tasks
+import PIL
+from anthropic import Anthropic  # For language model API
+from packaging.version import parse  # Version checking
+import pytesseract  # OCR library for text extraction from images
+import re
+import dotenv  # For environment variable loading
+from lxml import etree  # XML parsing
+from tqdm import tqdm  # Progress bar for non-async tasks
+import fitz  # PyMuPDF, PDF processing library
+from PIL import Image, ImageDraw  # Image processing
+from typing import List, Dict, Any, TypedDict  # Typing for function annotations
+from ultralyticsplus import YOLO  # Object detection model (YOLO)
+import base64
+import io
+import json
+import os
+import uuid  # For generating unique IDs
+from enum import Enum  # Enums for types like document type and purpose
+import cohere  # Embedding client
+import numpy as np
+from PyPDF2 import PdfReader  # PDF text extraction
+from openai import OpenAI  # OpenAI client for text completion
+from sklearn.cluster import KMeans  # Clustering for summarization
+
+dotenv.load_dotenv()  # Load environment variables
+
+# Fix for newer versions of PIL
+if parse(PIL.__version__) >= parse('10.0.0'):
+    Image.LINEAR = Image.BILINEAR
+
+# Global dictionary to track progress of document processing jobs
+current_progress = {}
+
+
+def update_progress(job_id, step, progress_value):
+    """
+    Output the progress in JSON format to stdout for the Node.js process to capture.
+    """
+    progress_data = {
+        "job_id": job_id,
+        "step": step,
+        "progress": progress_value
+    }
+    print(json.dumps(progress_data))  # Output progress to stdout
+    sys.stdout.flush()  # Ensure it's sent immediately
+
+
+def get_current_progress():
+    """
+    Return the current progress of all jobs.
+ """ + return current_progress + + +class ElementExtractor: + def __init__(self, output_folder: str): + self.output_folder = output_folder + self.model = YOLO('keremberke/yolov8m-table-extraction') + self.model.overrides['conf'] = 0.25 + self.model.overrides['iou'] = 0.45 + self.padding = 5 + + async def extract_elements(self, page, padding: int = 20) -> List[Dict[str, Any]]: + tasks = [ + asyncio.create_task(self.extract_tables(page.image, page.page_num)), + asyncio.create_task(self.extract_images(page.page, page.image, page.page_num)) + ] + results = await asyncio.gather(*tasks) + return [item for sublist in results for item in sublist] + + async def extract_tables(self, img: Image.Image, page_num: int) -> List[Dict[str, Any]]: + results = self.model.predict(img, verbose=False) + tables = [] + + for idx, box in enumerate(results[0].boxes): + x1, y1, x2, y2 = map(int, box.xyxy[0]) + + # Draw a red rectangle on the full page image around the table + page_with_outline = img.copy() + draw = ImageDraw.Draw(page_with_outline) + draw.rectangle( + [max(0, x1 + self.padding), max(0, y1 + self.padding), min(page_with_outline.width, x2 + self.padding), + min(page_with_outline.height, y2 + self.padding)], outline="red", width=2) # Draw red outline + + # Save the full page with the red outline + table_filename = f"table_page{page_num + 1}_{idx + 1}.png" + table_path = os.path.join(self.output_folder, table_filename) + page_with_outline.save(table_path) + + # Convert the full-page image with red outline to base64 + base64_data = self.image_to_base64(page_with_outline) + + tables.append({ + 'metadata': { + "type": "table", + "location": [x1 / img.width, y1 / img.height, x2 / img.width, y2 / img.height], + "file_path": table_path, + "start_page": page_num, + "end_page": page_num, + "base64_data": base64_data, + } + }) + + return tables + + async def extract_images(self, page: fitz.Page, img: Image.Image, page_num: int) -> List[Dict[str, Any]]: + images = [] + image_list = page.get_images(full=True) + + if not image_list: + return images + + for img_index, img_info in enumerate(image_list): + xref = img_info[0] + #try: + base_image = page.parent.extract_image(xref) + image_bytes = base_image["image"] + image = Image.open(io.BytesIO(image_bytes)) + width_ratio = img.width / page.rect.width + height_ratio = img.height / page.rect.height + + # Get image coordinates or default to page rectangle + rect_list = page.get_image_rects(xref) + if rect_list: + rect = rect_list[0] + x1, y1, x2, y2 = rect + else: + rect = page.rect + x1, y1, x2, y2 = rect + + # Draw a red rectangle on the full page image around the embedded image + page_with_outline = img.copy() + draw = ImageDraw.Draw(page_with_outline) + draw.rectangle([x1 * width_ratio, y1 * height_ratio, x2 * width_ratio, y2 * height_ratio], + outline="red", width=2) # Draw red outline + + # Save the full page with the red outline + image_filename = f"image_page{page_num + 1}_{img_index + 1}.png" + image_path = os.path.join(self.output_folder, image_filename) + page_with_outline.save(image_path) + + # Convert the full-page image with red outline to base64 + base64_data = self.image_to_base64(page_with_outline) + + images.append({ + 'metadata': { + "type": "image", + "location": [x1 / page.rect.width, y1 / page.rect.height, x2 / page.rect.width, + y2 / page.rect.height], + "file_path": image_path, + "start_page": page_num, + "end_page": page_num, + "base64_data": base64_data, + } + }) + + #except Exception as e: + # print(f"Error processing image on page {page_num 
+ 1}, image {img_index + 1}: {str(e)}") + return images + + @staticmethod + def image_to_base64(image: Image.Image) -> str: + buffered = io.BytesIO() + image.save(buffered, format="PNG") + return base64.b64encode(buffered.getvalue()).decode('utf-8') + + +class ChunkMetaData(TypedDict): + """ + A TypedDict that defines the metadata structure for chunks of text and visual elements. + """ + text: str + type: str + original_document: str + file_path: str + doc_id: str + location: str + start_page: int + end_page: int + base64_data: str + + +class Chunk(TypedDict): + """ + A TypedDict that defines the structure for a document chunk, including metadata and embeddings. + """ + id: str + values: List[float] + metadata: ChunkMetaData + + +class Page: + """ + A class that represents a single PDF page, handling its image representation and element masking. + """ + + def __init__(self, page: fitz.Page, page_num: int): + self.page = page + self.page_num = page_num + # Get high-resolution image of the page (for table/image extraction) + self.pix = page.get_pixmap(matrix=fitz.Matrix(2, 2)) + self.image = Image.frombytes("RGB", [self.pix.width, self.pix.height], self.pix.samples) + self.masked_image = self.image.copy() # Image with masked elements (tables/images) + self.draw = ImageDraw.Draw(self.masked_image) + self.elements = [] # List to store extracted elements + + def add_element(self, element): + """ + Adds a detected element (table/image) to the page and masks its location on the page image. + """ + self.elements.append(element) + # Mask the element on the page image by drawing a white rectangle over its location + x1, y1, x2, y2 = [coord * self.image.width if i % 2 == 0 else coord * self.image.height + for i, coord in enumerate(element['metadata']['location'])] + self.draw.rectangle([x1, y1, x2, y2], fill="white") + + +class PDFChunker: + """ + The main class responsible for chunking PDF files into text and visual elements (tables/images). + """ + + def __init__(self, output_folder: str = "output", image_batch_size: int = 5) -> None: + self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) # Initialize the Anthropic API client + self.output_folder = output_folder + self.image_batch_size = image_batch_size # Batch size for image processing + self.element_extractor = ElementExtractor(output_folder) # Initialize the element extractor + + async def chunk_pdf(self, file_data: bytes, file_name: str, doc_id: str, job_id: str) -> List[Dict[str, Any]]: + """ + Processes a PDF file, extracting text and visual elements, and returning structured chunks. 
+ """ + with fitz.open(stream=file_data, filetype="pdf") as pdf_document: + num_pages = len(pdf_document) # Get the total number of pages in the PDF + pages = [Page(pdf_document[i], i) for i in tqdm(range(num_pages), desc="Initializing Pages")] + + update_progress(job_id, "Extracting tables and images...", 0) + await self.extract_and_mask_elements(pages, job_id) + + update_progress(job_id, "Processing tables and images...", 0) + await self.process_visual_elements(pages, self.image_batch_size, job_id) + + update_progress(job_id, "Extracting text...", 0) + page_texts = await self.extract_text_from_masked_pages(pages, job_id) + + update_progress(job_id, "Processing text...", 0) + text_chunks = self.chunk_text_with_metadata(page_texts, max_words=1000, job_id=job_id) + + # Combine text and visual elements into a unified structure (chunks) + chunks = self.combine_chunks(text_chunks, [elem for page in pages for elem in page.elements], file_name, + doc_id) + + return chunks + + async def extract_and_mask_elements(self, pages: List[Page], job_id: str): + """ + Extract visual elements (tables and images) from each page and mask them on the page. + """ + total_pages = len(pages) + tasks = [] + + for i, page in enumerate(pages): + tasks.append(asyncio.create_task(self.element_extractor.extract_elements(page))) + progress = ((i + 1) / total_pages) * 100 + update_progress(job_id, "Extracting tables and images...", progress) + + # Gather all extraction results + results = await asyncio.gather(*tasks) + + # Mask the detected elements on the page images + for page, elements in zip(pages, results): + for element in elements: + page.add_element(element) + + async def process_visual_elements(self, pages: List[Page], image_batch_size: int, job_id: str) -> List[ + Dict[str, Any]]: + """ + Process extracted visual elements in batches, generating summaries or descriptions. + """ + pre_elements = [element for page in pages for element in page.elements] # Flatten list of elements + processed_elements = [] + total_batches = (len(pre_elements) // image_batch_size) + 1 + + loop = asyncio.get_event_loop() + with concurrent.futures.ThreadPoolExecutor() as executor: + # Process elements in batches + for i in tqdm(range(0, len(pre_elements), image_batch_size), desc="Processing Visual Elements"): + batch = pre_elements[i:i + image_batch_size] + # Run image summarization in a separate thread + summaries = await loop.run_in_executor( + executor, self.batch_summarize_images, + {j + 1: element.get('metadata').get('base64_data') for j, element in enumerate(batch)} + ) + + # Append generated summaries to the elements + for j, elem in enumerate(batch, start=1): + if j in summaries: + elem['metadata']['text'] = re.sub(r'^(Image|Table):\s*', '', summaries[j]) + processed_elements.append(elem) + + progress = ((i // image_batch_size) + 1) / total_batches * 100 + update_progress(job_id, "Processing tables and images...", progress) + + return processed_elements + + async def extract_text_from_masked_pages(self, pages: List[Page], job_id: str) -> Dict[int, str]: + """ + Extract text from masked page images (where tables and images have been masked out). 
+ """ + total_pages = len(pages) + tasks = [] + + for i, page in enumerate(pages): + tasks.append(asyncio.create_task(self.extract_text(page.masked_image, page.page_num))) + progress = ((i + 1) / total_pages) * 100 + update_progress(job_id, "Extracting text...", progress) + + # Return extracted text from each page + return dict(await asyncio.gather(*tasks)) + + @staticmethod + async def extract_text(image: Image.Image, page_num: int) -> (int, str): + """ + Perform OCR on the provided image to extract text. + """ + result = pytesseract.image_to_string(image) + return page_num + 1, result.strip() # Return the page number and extracted text + + def chunk_text_with_metadata(self, page_texts: Dict[int, str], max_words: int, job_id: str) -> List[Dict[str, Any]]: + """ + Break the extracted text into smaller chunks with metadata (e.g., page numbers). + """ + chunks = [] + current_chunk = "" + current_start_page = 0 + total_words = 0 + + def add_chunk(chunk_text, start_page, end_page): + # Add a chunk of text with metadata + chunks.append({ + "text": chunk_text.strip(), + "start_page": start_page, + "end_page": end_page + }) + + total_pages = len(page_texts) + for i, (page_num, text) in enumerate(tqdm(page_texts.items(), desc="Chunking Text")): + sentences = self.split_into_sentences(text) + for sentence in sentences: + word_count = len(sentence.split()) + # If adding this sentence exceeds max_words, create a new chunk + if total_words + word_count > max_words: + add_chunk(current_chunk, current_start_page, page_num) + current_chunk = sentence + " " + current_start_page = page_num + total_words = word_count + else: + current_chunk += sentence + " " + total_words += word_count + current_chunk += "\n\n" + + progress = ((i + 1) / total_pages) * 100 + update_progress(job_id, "Processing text...", progress) + + # Add the last chunk if there is leftover text + if current_chunk.strip(): + add_chunk(current_chunk, current_start_page, page_num) + + return chunks + + @staticmethod + def split_into_sentences(text): + """ + Split the text into sentences using regular expressions. + """ + return re.split(r'(?<=[.!?])\s+', text) + + @staticmethod + def combine_chunks(text_chunks: List[Dict[str, Any]], visual_elements: List[Dict[str, Any]], pdf_path: str, + doc_id: str) -> List[Chunk]: + """ + Combine text and visual chunks into a unified list. 
+        """
+        combined_chunks = []
+        # Add text chunks
+        for text_chunk in text_chunks:
+            chunk_metadata: ChunkMetaData = {
+                "text": text_chunk["text"],
+                "type": "text",
+                "original_document": pdf_path,
+                "file_path": "",
+                "location": "",
+                "start_page": text_chunk["start_page"],
+                "end_page": text_chunk["end_page"],
+                "base64_data": "",
+                "doc_id": doc_id,
+            }
+            chunk_dict: Chunk = {
+                "id": str(uuid.uuid4()),
+                "values": [],
+                "metadata": chunk_metadata,
+            }
+            combined_chunks.append(chunk_dict)
+
+        # Add visual chunks (tables/images)
+        for elem in visual_elements:
+            visual_chunk_metadata: ChunkMetaData = {
+                "type": elem['metadata']['type'],
+                "file_path": elem['metadata']['file_path'],
+                "text": elem['metadata'].get('text', ''),
+                "start_page": elem['metadata']['start_page'],
+                "end_page": elem['metadata']['end_page'],
+                "location": str(elem['metadata']['location']),
+                "base64_data": elem['metadata']['base64_data'],
+                "doc_id": doc_id,
+                "original_document": pdf_path,
+            }
+            visual_chunk_dict: Chunk = {
+                "id": str(uuid.uuid4()),
+                "values": [],
+                "metadata": visual_chunk_metadata,
+            }
+            combined_chunks.append(visual_chunk_dict)
+
+        return combined_chunks
+
+    def batch_summarize_images(self, images: Dict[int, str]) -> Dict[int, str]:
+        """
+        Summarize images or tables by generating descriptive text.
+        """
+        # Prompt for the AI model to summarize images and tables
+        prompt = f"""
+
+    You are tasked with summarizing a series of {len(images)} images and tables for use in a RAG (Retrieval-Augmented Generation) system.
+    Your goal is to create concise, informative summaries that capture the essential content of each image or table.
+    These summaries will be used for embedding, so they should be descriptive and relevant. The image or table will be outlined in red on an image of the full page that it is on. Where necessary, use the context of the full page to help with the summary but don't summarize other content on the page.
+
+
+
+    Identify whether it's an image or a table.
+    Examine its content carefully.
+
+    Write a detailed summary that captures the main points or visual elements:
+
+    After summarizing what the table is about, include the column headers, a detailed summary of the data, and any notable data trends.
+    Describe the main subjects, actions, or notable features.
+
+
+    Focus on writing summaries that would make it easy to retrieve the content if compared to a user query using vector similarity search.
+    Keep summaries concise and include important words that may help with retrieval (but do not include numbers and numerical data).
+
+ """ + content = [] + for number, img in images.items(): + content.append({"type": "text", "text": f"\nImage {number}:\n"}) + content.append({"type": "image", "source": {"type": "base64", "media_type": "image/png", "data": img}}) + + messages = [ + {"role": "user", "content": content} + ] + + try: + response = self.client.messages.create( + model='claude-3-5-sonnet-20240620', + system=prompt, + max_tokens=400 * len(images), # Increased token limit for more detailed summaries + messages=messages, + temperature=0, + extra_headers={"anthropic-beta": "max-tokens-3-5-sonnet-2024-07-15"} + ) + + # Parse the response + text = response.content[0].text + #print(text) + # Attempt to parse and fix the XML if necessary + parser = etree.XMLParser(recover=True) + root = etree.fromstring(text, parser=parser) + # Check if there were errors corrected + # if parser.error_log: + # #print("XML Parsing Errors:") + # for error in parser.error_log: + # #print(error) + # Extract summaries + summaries = {} + for summary in root.findall('summary'): + number = int(summary.get('number')) + content = summary.text.strip() if summary.text else "" + if content: # Only include non-empty summaries + summaries[number] = content + + return summaries + + except Exception: + #print(f"Error in batch_summarize_images: {str(e)}") + #print("Returning placeholder summaries") + return {number: "Error: No summary available" for number in images} + + +class DocumentType(Enum): + PDF = "pdf" + CSV = "csv" + TXT = "txt" + HTML = "html" + + +class FileTypeNotSupportedException(Exception): + """ + Exception raised for unsupported file types. + """ + + def __init__(self, file_extension: str): + self.file_extension = file_extension + self.message = f"File type '{file_extension}' is not supported." + super().__init__(self.message) + + +class Document: + """ + Represents a document being processed, such as a PDF, handling chunking and embedding. + """ + + def __init__(self, file_data: bytes, file_name: str, job_id: str): + self.file_data = file_data + self.file_name = file_name + self.job_id = job_id + self.type = self._get_document_type(file_name) + self.doc_id = job_id # Use job_id as document ID + self.chunks = [] + self.num_pages = 0 + self.summary = "" + + self._process() # Start processing the document + + def _process(self): + """ + Process the document: chunk it, embed chunks, and generate a summary. + """ + pdf_chunker = PDFChunker(output_folder="output") + self.chunks = asyncio.run(pdf_chunker.chunk_pdf(self.file_data, self.file_name, self.doc_id, self.job_id)) + + self.num_pages = self._get_pdf_pages() # Get the number of pages + self._embed_chunks() # Embed the text chunks + self.summary = self._generate_summary() # Generate a summary + + def _get_document_type(self, file_name: str) -> DocumentType: + """ + Determine the document type based on its file extension. + """ + _, extension = os.path.splitext(file_name) + extension = extension.lower().lstrip('.') + try: + return DocumentType(extension) + except ValueError: + raise FileTypeNotSupportedException(extension) + + def _get_pdf_pages(self) -> int: + """ + Get the total number of pages in the PDF. + """ + pdf_file = io.BytesIO(self.file_data) + pdf_reader = PdfReader(pdf_file) + return len(pdf_reader.pages) + + def _embed_chunks(self) -> None: + """ + Embed the text chunks using the Cohere API. 
+ """ + co = cohere.Client(os.getenv("COHERE_API_KEY")) + batch_size = 90 + chunks_len = len(self.chunks) + for i in tqdm(range(0, chunks_len, batch_size), desc="Embedding Chunks"): + batch = self.chunks[i: min(i + batch_size, chunks_len)] + texts = [chunk['metadata']['text'] for chunk in batch] + #try: + chunk_embs_batch = co.embed( + texts=texts, + model="embed-english-v3.0", + input_type="search_document" + ) + for j, emb in enumerate(chunk_embs_batch.embeddings): + self.chunks[i + j]['values'] = emb + #except Exception as e: + #print(f"Error embedding batch for {self.file_name}: {str(e)}") + + def _generate_summary(self) -> str: + """ + Generate a summary of the document using KMeans clustering and a language model. + """ + num_clusters = min(10, len(self.chunks)) + kmeans = KMeans(n_clusters=num_clusters, random_state=42) + doc_chunks = [chunk['values'] for chunk in self.chunks if 'values' in chunk] + cluster_labels = kmeans.fit_predict(doc_chunks) + + # Select representative chunks from each cluster + selected_chunks = [] + for i in range(num_clusters): + cluster_chunks = [chunk for chunk, label in zip(self.chunks, cluster_labels) if label == i] + cluster_embs = [emb for emb, label in zip(doc_chunks, cluster_labels) if label == i] + centroid = kmeans.cluster_centers_[i] + distances = [np.linalg.norm(np.array(emb) - centroid) for emb in cluster_embs] + closest_chunk = cluster_chunks[np.argmin(distances)] + selected_chunks.append(closest_chunk) + + # Combine selected chunks into a summary + combined_text = "\n\n".join([chunk['metadata']['text'] for chunk in selected_chunks]) + + client = OpenAI() # Call OpenAI API for text generation (summarization) + completion = client.chat.completions.create( + model="gpt-3.5-turbo", + messages=[ + {"role": "system", + "content": "You are an AI assistant tasked with summarizing a document. You are provided with important chunks from the document and provide a summary, as best you can, of what the document will contain overall. Be concise and brief with your response."}, + {"role": "user", "content": f"""Please provide a comprehensive summary of what you think the document from which these chunks were sampled would be. + Ensure the summary captures the main ideas and key points from all provided chunks. Be concise and brief and only provide the summary in paragraph form. + + Sample text chunks: + ``` + {combined_text} + ``` + ********** + Summary: + """} + ], + max_tokens=300 + ) + return completion.choices[0].message.content.strip() + + def to_json(self) -> str: + """ + Return the document's data in JSON format. + """ + return json.dumps({ + "file_name": self.file_name, + "num_pages": self.num_pages, + "summary": self.summary, + "chunks": self.chunks, + "type": self.type.value, + "doc_id": self.doc_id + }, indent=2) + + +def process_document(file_data, file_name, job_id): + """ + Top-level function to process a document and return the JSON output. + """ + new_document = Document(file_data, file_name, job_id) + return new_document.to_json() + + +def print_progress(job_id, step, progress_value): + """ + Output the progress in JSON format to stdout for the Node.js process to capture. + """ + progress_data = { + "job_id": job_id, + "step": step, + "progress": progress_value + } + print(json.dumps(progress_data)) # Output progress to stdout + sys.stdout.flush() # Ensure it's sent immediately + + +def main(): + """ + Main entry point for the script, called with arguments from Node.js. 
+    """
+    if len(sys.argv) != 4:
+        print(json.dumps({"error": "Invalid arguments"}))
+        return
+
+    job_id = sys.argv[1]
+    file_name = sys.argv[2]
+    file_data = sys.argv[3]
+
+    try:
+        # Decode the base64 file data
+        file_bytes = base64.b64decode(file_data)
+
+        # Process the document
+        document_result = process_document(file_bytes, file_name, job_id)
+
+        # Output the final result as JSON
+        print(document_result)
+        sys.stdout.flush()
+
+    except Exception as e:
+        # If any error occurs, print the error to stdout for Node.js to capture
+        print(json.dumps({"error": str(e)}))
+        sys.stdout.flush()
+
+
+if __name__ == "__main__":
+    main()
diff --git a/src/server/chunker/requirements.txt b/src/server/chunker/requirements.txt
new file mode 100644
index 000000000..20bd486e5
--- /dev/null
+++ b/src/server/chunker/requirements.txt
@@ -0,0 +1,15 @@
+anthropic==0.34.0
+cohere==5.8.0
+python-dotenv==1.0.1
+pymupdf==1.22.2
+lxml==5.3.0
+layoutparser==0.3.4
+numpy==1.26.4
+openai==1.40.6
+Pillow==10.4.0
+pytesseract==0.3.10
+PyPDF2==3.0.1
+scikit-learn==1.5.1
+tqdm==4.66.5
+ultralyticsplus==0.0.28
+easyocr==1.7.0
\ No newline at end of file
--
cgit v1.2.3-70-g09d2


From 2d61b3b0d00c239f05615c691ffbf4b98f3054e9 Mon Sep 17 00:00:00 2001
From: "A.J. Shulman"
Date: Thu, 19 Sep 2024 12:36:18 -0400
Subject: Working now with Python script

---
 src/server/ApiManagers/AssistantManager.ts | 47 ++++++++++++++++++++----------
 src/server/chunker/pdf_chunker.py          | 40 +++++++++----------------
 2 files changed, 44 insertions(+), 43 deletions(-)

diff --git a/src/server/ApiManagers/AssistantManager.ts b/src/server/ApiManagers/AssistantManager.ts
index dfe5d747b..224d47d3b 100644
--- a/src/server/ApiManagers/AssistantManager.ts
+++ b/src/server/ApiManagers/AssistantManager.ts
@@ -291,7 +291,10 @@ export default class AssistantManager extends ApiManager {
             if (jobProgress[jobId]) {
                 res.json(jobProgress[jobId]);
             } else {
-                res.status(404).send({ error: 'Job not found' });
+                res.json({
+                    step: 'Processing Document...',
+                    progress: '0',
+                });
             }
         },
     });
@@ -452,43 +455,55 @@ function spawnPythonProcess(jobId: string, file_name: string, file_data: string)
     ]);
 
     let pythonOutput = ''; // Accumulate stdout data
+    let stderrOutput = ''; // For stderr logs and progress
 
-    // Handle stdout data (progress and final results)
+    // Handle stdout data (final result in JSON format)
     pythonProcess.stdout.on('data', data => {
-        pythonOutput += data.toString(); // Accumulate data
+        pythonOutput += data.toString(); // Accumulate data from stdout
+    });
 
-        const lines = pythonOutput.split('\n'); // Handle multi-line JSON
+    // Handle stderr (progress logs or errors)
+    pythonProcess.stderr.on('data', data => {
+        stderrOutput += data.toString();
+        const lines = stderrOutput.split('\n');
         lines.forEach(line => {
             if (line.trim()) {
                 try {
-                    const parsedOutput = JSON.parse(line); // Parse each line of JSON
+                    // Progress and warnings are printed as JSON to stderr
+                    const parsedOutput = JSON.parse(line);
 
+                    // Handle progress updates
                     if (parsedOutput.job_id && parsedOutput.progress !== undefined) {
                         jobProgress[parsedOutput.job_id] = {
                             step: parsedOutput.step,
                             progress: parsedOutput.progress,
                         };
-                    } else if (parsedOutput.chunks) {
-                        jobResults[parsedOutput.job_id] = parsedOutput;
-                        jobProgress[parsedOutput.job_id] = { step: 'Complete', progress: 100 };
+                    } else if (parsedOutput.progress !== undefined) {
+                        jobProgress[jobId] = {
+                            step: parsedOutput.step,
+                            progress: parsedOutput.progress,
+                        };
                     }
                 } catch (err) {
-                    console.error('Error parsing
Python output:', err); + console.error('Progress log from Python:', line); } } }); }); - // Handle stderr (error logging) - pythonProcess.stderr.on('data', data => { - console.error(`Python script error: ${data}`); - }); - // Handle process exit pythonProcess.on('close', code => { - if (code !== 0) { + if (code === 0) { + // Parse final JSON output (stdout) + try { + const finalResult = JSON.parse(pythonOutput); // Parse JSON from stdout + jobResults[jobId] = finalResult; + jobProgress[jobId] = { step: 'Complete', progress: 100 }; + } catch (err) { + console.error('Error parsing final JSON result:', err); + } + } else { console.error(`Python process exited with code ${code}`); - console.error(`Command: python3 ${path.join(__dirname, '../chunker/pdf_chunker.py')} ${jobId} ${file_name}`); jobResults[jobId] = { error: 'Python process failed' }; } }); diff --git a/src/server/chunker/pdf_chunker.py b/src/server/chunker/pdf_chunker.py index c9f6737e7..12e71c29d 100644 --- a/src/server/chunker/pdf_chunker.py +++ b/src/server/chunker/pdf_chunker.py @@ -26,6 +26,12 @@ import numpy as np from PyPDF2 import PdfReader # PDF text extraction from openai import OpenAI # OpenAI client for text completion from sklearn.cluster import KMeans # Clustering for summarization +import warnings + +# Silence specific warnings +warnings.filterwarnings('ignore', message="Valid config keys have changed") +warnings.filterwarnings('ignore', message="torch.load") + dotenv.load_dotenv() # Load environment variables @@ -36,7 +42,6 @@ if parse(PIL.__version__) >= parse('10.0.0'): # Global dictionary to track progress of document processing jobs current_progress = {} - def update_progress(job_id, step, progress_value): """ Output the progress in JSON format to stdout for the Node.js process to capture. @@ -46,15 +51,8 @@ def update_progress(job_id, step, progress_value): "step": step, "progress": progress_value } - print(json.dumps(progress_data)) # Output progress to stdout - sys.stdout.flush() # Ensure it's sent immediately - - -def get_current_progress(): - """ - Return the current progress of all jobs. - """ - return current_progress + print(json.dumps(progress_data), file=sys.stderr) # Use stderr for progress logs + sys.stderr.flush() # Ensure it's sent immediately class ElementExtractor: @@ -698,25 +696,13 @@ def process_document(file_data, file_name, job_id): return new_document.to_json() -def print_progress(job_id, step, progress_value): - """ - Output the progress in JSON format to stdout for the Node.js process to capture. - """ - progress_data = { - "job_id": job_id, - "step": step, - "progress": progress_value - } - print(json.dumps(progress_data)) # Output progress to stdout - sys.stdout.flush() # Ensure it's sent immediately - def main(): """ Main entry point for the script, called with arguments from Node.js. 
""" if len(sys.argv) != 4: - print(json.dumps({"error": "Invalid arguments"})) + print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr) return job_id = sys.argv[1] @@ -730,14 +716,14 @@ def main(): # Process the document document_result = process_document(file_bytes, file_name, job_id) - # Output the final result as JSON + # Output the final result as JSON to stdout print(document_result) sys.stdout.flush() except Exception as e: - # If any error occurs, print the error to stdout for Node.js to capture - print(json.dumps({"error": str(e)})) - sys.stdout.flush() + # Print errors to stderr so they don't interfere with JSON output + print(json.dumps({"error": str(e)}), file=sys.stderr) + sys.stderr.flush() if __name__ == "__main__": -- cgit v1.2.3-70-g09d2 From b08befda6d7ec07a0e6653ccf5040474886dcd44 Mon Sep 17 00:00:00 2001 From: "A.J. Shulman" Date: Mon, 23 Sep 2024 08:55:37 -0400 Subject: added comments to pdf chunker --- src/server/chunker/pdf_chunker.py | 317 ++++++++++++++++++++++++++------------ 1 file changed, 215 insertions(+), 102 deletions(-) (limited to 'src/server/chunker') diff --git a/src/server/chunker/pdf_chunker.py b/src/server/chunker/pdf_chunker.py index 12e71c29d..4fe3b9dbf 100644 --- a/src/server/chunker/pdf_chunker.py +++ b/src/server/chunker/pdf_chunker.py @@ -32,7 +32,6 @@ import warnings warnings.filterwarnings('ignore', message="Valid config keys have changed") warnings.filterwarnings('ignore', message="torch.load") - dotenv.load_dotenv() # Load environment variables # Fix for newer versions of PIL @@ -45,6 +44,10 @@ current_progress = {} def update_progress(job_id, step, progress_value): """ Output the progress in JSON format to stdout for the Node.js process to capture. + + :param job_id: The unique identifier for the processing job. + :param step: The current step of the job. + :param progress_value: The percentage of completion for the current step. """ progress_data = { "job_id": job_id, @@ -56,27 +59,50 @@ def update_progress(job_id, step, progress_value): class ElementExtractor: + """ + A class that uses a YOLO model to extract tables and images from a PDF page. + """ + def __init__(self, output_folder: str): + """ + Initializes the ElementExtractor with the output folder for saving images and the YOLO model. + + :param output_folder: Path to the folder where extracted elements will be saved. + """ self.output_folder = output_folder - self.model = YOLO('keremberke/yolov8m-table-extraction') - self.model.overrides['conf'] = 0.25 - self.model.overrides['iou'] = 0.45 - self.padding = 5 + self.model = YOLO('keremberke/yolov8m-table-extraction') # Load YOLO model for table extraction + self.model.overrides['conf'] = 0.25 # Set confidence threshold for detection + self.model.overrides['iou'] = 0.45 # Set Intersection over Union (IoU) threshold + self.padding = 5 # Padding around detected elements async def extract_elements(self, page, padding: int = 20) -> List[Dict[str, Any]]: + """ + Asynchronously extract tables and images from a PDF page. + + :param page: A Page object representing a PDF page. + :param padding: Padding around the extracted elements. + :return: A list of dictionaries containing the extracted elements. 
+ """ tasks = [ - asyncio.create_task(self.extract_tables(page.image, page.page_num)), - asyncio.create_task(self.extract_images(page.page, page.image, page.page_num)) + asyncio.create_task(self.extract_tables(page.image, page.page_num)), # Extract tables from the page + asyncio.create_task(self.extract_images(page.page, page.image, page.page_num)) # Extract images from the page ] - results = await asyncio.gather(*tasks) - return [item for sublist in results for item in sublist] + results = await asyncio.gather(*tasks) # Wait for both tasks to complete + return [item for sublist in results for item in sublist] # Flatten and return results async def extract_tables(self, img: Image.Image, page_num: int) -> List[Dict[str, Any]]: - results = self.model.predict(img, verbose=False) + """ + Asynchronously extract tables from a given page image using the YOLO model. + + :param img: The image of the PDF page. + :param page_num: The current page number. + :return: A list of dictionaries with metadata about the detected tables. + """ + results = self.model.predict(img, verbose=False) # Predict table locations using YOLO tables = [] for idx, box in enumerate(results[0].boxes): - x1, y1, x2, y2 = map(int, box.xyxy[0]) + x1, y1, x2, y2 = map(int, box.xyxy[0]) # Extract bounding box coordinates # Draw a red rectangle on the full page image around the table page_with_outline = img.copy() @@ -107,20 +133,27 @@ class ElementExtractor: return tables async def extract_images(self, page: fitz.Page, img: Image.Image, page_num: int) -> List[Dict[str, Any]]: + """ + Asynchronously extract embedded images from a PDF page. + + :param page: A fitz.Page object representing the PDF page. + :param img: The image of the PDF page. + :param page_num: The current page number. + :return: A list of dictionaries with metadata about the detected images. + """ images = [] - image_list = page.get_images(full=True) + image_list = page.get_images(full=True) # Get a list of images on the page if not image_list: return images for img_index, img_info in enumerate(image_list): - xref = img_info[0] - #try: - base_image = page.parent.extract_image(xref) + xref = img_info[0] # XREF of the image in the PDF + base_image = page.parent.extract_image(xref) # Extract the image by its XREF image_bytes = base_image["image"] - image = Image.open(io.BytesIO(image_bytes)) - width_ratio = img.width / page.rect.width - height_ratio = img.height / page.rect.height + image = Image.open(io.BytesIO(image_bytes)) # Convert bytes to PIL image + width_ratio = img.width / page.rect.width # Scale factor for width + height_ratio = img.height / page.rect.height # Scale factor for height # Get image coordinates or default to page rectangle rect_list = page.get_image_rects(xref) @@ -157,15 +190,19 @@ class ElementExtractor: } }) - #except Exception as e: - # print(f"Error processing image on page {page_num + 1}, image {img_index + 1}: {str(e)}") return images @staticmethod def image_to_base64(image: Image.Image) -> str: + """ + Convert a PIL image to a base64-encoded string. + + :param image: The PIL image to be converted. + :return: The base64-encoded string of the image. 
+ """ buffered = io.BytesIO() - image.save(buffered, format="PNG") - return base64.b64encode(buffered.getvalue()).decode('utf-8') + image.save(buffered, format="PNG") # Save image as PNG to an in-memory buffer + return base64.b64encode(buffered.getvalue()).decode('utf-8') # Convert to base64 and return class ChunkMetaData(TypedDict): @@ -198,6 +235,12 @@ class Page: """ def __init__(self, page: fitz.Page, page_num: int): + """ + Initializes the Page with its page number and the image representation of the page. + + :param page: A fitz.Page object representing the PDF page. + :param page_num: The number of the page in the PDF. + """ self.page = page self.page_num = page_num # Get high-resolution image of the page (for table/image extraction) @@ -210,12 +253,14 @@ class Page: def add_element(self, element): """ Adds a detected element (table/image) to the page and masks its location on the page image. + + :param element: A dictionary containing metadata about the detected element. """ self.elements.append(element) # Mask the element on the page image by drawing a white rectangle over its location x1, y1, x2, y2 = [coord * self.image.width if i % 2 == 0 else coord * self.image.height for i, coord in enumerate(element['metadata']['location'])] - self.draw.rectangle([x1, y1, x2, y2], fill="white") + self.draw.rectangle([x1, y1, x2, y2], fill="white") # Draw a white rectangle to mask the element class PDFChunker: @@ -224,6 +269,12 @@ class PDFChunker: """ def __init__(self, output_folder: str = "output", image_batch_size: int = 5) -> None: + """ + Initializes the PDFChunker with an output folder and an element extractor for visual elements. + + :param output_folder: Folder to store the output files (extracted tables/images). + :param image_batch_size: The batch size for processing visual elements. + """ self.client = Anthropic(api_key=os.getenv("ANTHROPIC_API_KEY")) # Initialize the Anthropic API client self.output_folder = output_folder self.image_batch_size = image_batch_size # Batch size for image processing @@ -232,22 +283,28 @@ class PDFChunker: async def chunk_pdf(self, file_data: bytes, file_name: str, doc_id: str, job_id: str) -> List[Dict[str, Any]]: """ Processes a PDF file, extracting text and visual elements, and returning structured chunks. + + :param file_data: The binary data of the PDF file. + :param file_name: The name of the PDF file. + :param doc_id: The unique document ID for this job. + :param job_id: The unique job ID for the processing task. + :return: A list of structured chunks containing text and visual elements. 
""" with fitz.open(stream=file_data, filetype="pdf") as pdf_document: num_pages = len(pdf_document) # Get the total number of pages in the PDF - pages = [Page(pdf_document[i], i) for i in tqdm(range(num_pages), desc="Initializing Pages")] + pages = [Page(pdf_document[i], i) for i in tqdm(range(num_pages), desc="Initializing Pages")] # Initialize each page update_progress(job_id, "Extracting tables and images...", 0) - await self.extract_and_mask_elements(pages, job_id) + await self.extract_and_mask_elements(pages, job_id) # Extract and mask elements (tables/images) update_progress(job_id, "Processing tables and images...", 0) - await self.process_visual_elements(pages, self.image_batch_size, job_id) + await self.process_visual_elements(pages, self.image_batch_size, job_id) # Process visual elements update_progress(job_id, "Extracting text...", 0) - page_texts = await self.extract_text_from_masked_pages(pages, job_id) + page_texts = await self.extract_text_from_masked_pages(pages, job_id) # Extract text from masked pages update_progress(job_id, "Processing text...", 0) - text_chunks = self.chunk_text_with_metadata(page_texts, max_words=1000, job_id=job_id) + text_chunks = self.chunk_text_with_metadata(page_texts, max_words=1000, job_id=job_id) # Chunk text into smaller parts # Combine text and visual elements into a unified structure (chunks) chunks = self.combine_chunks(text_chunks, [elem for page in pages for elem in page.elements], file_name, @@ -258,13 +315,16 @@ class PDFChunker: async def extract_and_mask_elements(self, pages: List[Page], job_id: str): """ Extract visual elements (tables and images) from each page and mask them on the page. + + :param pages: A list of Page objects representing the PDF pages. + :param job_id: The unique job ID for the processing task. """ total_pages = len(pages) tasks = [] for i, page in enumerate(pages): - tasks.append(asyncio.create_task(self.element_extractor.extract_elements(page))) - progress = ((i + 1) / total_pages) * 100 + tasks.append(asyncio.create_task(self.element_extractor.extract_elements(page))) # Extract elements asynchronously + progress = ((i + 1) / total_pages) * 100 # Calculate progress update_progress(job_id, "Extracting tables and images...", progress) # Gather all extraction results @@ -273,16 +333,20 @@ class PDFChunker: # Mask the detected elements on the page images for page, elements in zip(pages, results): for element in elements: - page.add_element(element) + page.add_element(element) # Mask each extracted element on the page - async def process_visual_elements(self, pages: List[Page], image_batch_size: int, job_id: str) -> List[ - Dict[str, Any]]: + async def process_visual_elements(self, pages: List[Page], image_batch_size: int, job_id: str) -> List[Dict[str, Any]]: """ Process extracted visual elements in batches, generating summaries or descriptions. + + :param pages: A list of Page objects representing the PDF pages. + :param image_batch_size: The batch size for processing visual elements. + :param job_id: The unique job ID for the processing task. + :return: A list of processed elements with metadata and generated summaries. 
""" pre_elements = [element for page in pages for element in page.elements] # Flatten list of elements processed_elements = [] - total_batches = (len(pre_elements) // image_batch_size) + 1 + total_batches = (len(pre_elements) // image_batch_size) + 1 # Calculate total number of batches loop = asyncio.get_event_loop() with concurrent.futures.ThreadPoolExecutor() as executor: @@ -301,7 +365,7 @@ class PDFChunker: elem['metadata']['text'] = re.sub(r'^(Image|Table):\s*', '', summaries[j]) processed_elements.append(elem) - progress = ((i // image_batch_size) + 1) / total_batches * 100 + progress = ((i // image_batch_size) + 1) / total_batches * 100 # Calculate progress update_progress(job_id, "Processing tables and images...", progress) return processed_elements @@ -309,13 +373,17 @@ class PDFChunker: async def extract_text_from_masked_pages(self, pages: List[Page], job_id: str) -> Dict[int, str]: """ Extract text from masked page images (where tables and images have been masked out). + + :param pages: A list of Page objects representing the PDF pages. + :param job_id: The unique job ID for the processing task. + :return: A dictionary mapping page numbers to extracted text. """ total_pages = len(pages) tasks = [] for i, page in enumerate(pages): - tasks.append(asyncio.create_task(self.extract_text(page.masked_image, page.page_num))) - progress = ((i + 1) / total_pages) * 100 + tasks.append(asyncio.create_task(self.extract_text(page.masked_image, page.page_num))) # Perform OCR on each page + progress = ((i + 1) / total_pages) * 100 # Calculate progress update_progress(job_id, "Extracting text...", progress) # Return extracted text from each page @@ -325,13 +393,22 @@ class PDFChunker: async def extract_text(image: Image.Image, page_num: int) -> (int, str): """ Perform OCR on the provided image to extract text. + + :param image: The PIL image of the page. + :param page_num: The current page number. + :return: A tuple containing the page number and the extracted text. """ - result = pytesseract.image_to_string(image) + result = pytesseract.image_to_string(image) # Extract text using Tesseract OCR return page_num + 1, result.strip() # Return the page number and extracted text def chunk_text_with_metadata(self, page_texts: Dict[int, str], max_words: int, job_id: str) -> List[Dict[str, Any]]: """ Break the extracted text into smaller chunks with metadata (e.g., page numbers). + + :param page_texts: A dictionary mapping page numbers to extracted text. + :param max_words: The maximum number of words allowed in a chunk. + :param job_id: The unique job ID for the processing task. + :return: A list of dictionaries containing text chunks with metadata. """ chunks = [] current_chunk = "" @@ -362,7 +439,7 @@ class PDFChunker: total_words += word_count current_chunk += "\n\n" - progress = ((i + 1) / total_pages) * 100 + progress = ((i + 1) / total_pages) * 100 # Calculate progress update_progress(job_id, "Processing text...", progress) # Add the last chunk if there is leftover text @@ -375,6 +452,9 @@ class PDFChunker: def split_into_sentences(text): """ Split the text into sentences using regular expressions. + + :param text: The raw text to be split into sentences. + :return: A list of sentences. """ return re.split(r'(?<=[.!?])\s+', text) @@ -383,6 +463,12 @@ class PDFChunker: doc_id: str) -> List[Chunk]: """ Combine text and visual chunks into a unified list. + + :param text_chunks: A list of dictionaries containing text chunks with metadata. 
+ :param visual_elements: A list of dictionaries containing visual elements (tables/images) with metadata. + :param pdf_path: The path to the original PDF file. + :param doc_id: The unique document ID for this job. + :return: A list of Chunk objects representing the combined data. """ combined_chunks = [] # Add text chunks @@ -399,7 +485,7 @@ class PDFChunker: "doc_id": doc_id, } chunk_dict: Chunk = { - "id": str(uuid.uuid4()), + "id": str(uuid.uuid4()), # Generate a unique ID for the chunk "values": [], "metadata": chunk_metadata, } @@ -419,7 +505,7 @@ class PDFChunker: "original_document": pdf_path, } visual_chunk_dict: Chunk = { - "id": str(uuid.uuid4()), + "id": str(uuid.uuid4()), # Generate a unique ID for the visual chunk "values": [], "metadata": visual_chunk_metadata, } @@ -430,6 +516,9 @@ class PDFChunker: def batch_summarize_images(self, images: Dict[int, str]) -> Dict[int, str]: """ Summarize images or tables by generating descriptive text. + + :param images: A dictionary mapping image numbers to base64-encoded image data. + :return: A dictionary mapping image numbers to their generated summaries. """ # Prompt for the AI model to summarize images and tables prompt = f""" @@ -544,118 +633,136 @@ class PDFChunker: #print("Returning placeholder summaries") return {number: "Error: No summary available" for number in images} - class DocumentType(Enum): - PDF = "pdf" - CSV = "csv" - TXT = "txt" - HTML = "html" + """ + Enum representing different types of documents that can be processed. + """ + PDF = "pdf" # PDF file type + CSV = "csv" # CSV file type + TXT = "txt" # Plain text file type + HTML = "html" # HTML file type class FileTypeNotSupportedException(Exception): """ - Exception raised for unsupported file types. + Exception raised when a file type is unsupported during document processing. """ def __init__(self, file_extension: str): + """ + Initialize the exception with the unsupported file extension. + + :param file_extension: The file extension that triggered the exception. + """ self.file_extension = file_extension self.message = f"File type '{file_extension}' is not supported." - super().__init__(self.message) + super().__init__(self.message) # Call the parent class constructor with the message class Document: """ - Represents a document being processed, such as a PDF, handling chunking and embedding. + Represents a document being processed, such as a PDF, handling chunking, embedding, and summarization. """ def __init__(self, file_data: bytes, file_name: str, job_id: str): + """ + Initialize the Document with file data, file name, and job ID. + + :param file_data: The binary data of the file being processed. + :param file_name: The name of the file being processed. + :param job_id: The job ID associated with this document processing task. + """ self.file_data = file_data self.file_name = file_name self.job_id = job_id - self.type = self._get_document_type(file_name) - self.doc_id = job_id # Use job_id as document ID - self.chunks = [] - self.num_pages = 0 - self.summary = "" + self.type = self._get_document_type(file_name) # Determine the document type (PDF, CSV, etc.) + self.doc_id = job_id # Use the job ID as the document ID + self.chunks = [] # List to hold text and visual chunks + self.num_pages = 0 # Number of pages in the document (if applicable) + self.summary = "" # The generated summary for the document self._process() # Start processing the document def _process(self): """ - Process the document: chunk it, embed chunks, and generate a summary. 
+ Process the document: extract chunks, embed them, and generate a summary. """ - pdf_chunker = PDFChunker(output_folder="output") - self.chunks = asyncio.run(pdf_chunker.chunk_pdf(self.file_data, self.file_name, self.doc_id, self.job_id)) + pdf_chunker = PDFChunker(output_folder="output") # Initialize the PDF chunker + self.chunks = asyncio.run(pdf_chunker.chunk_pdf(self.file_data, self.file_name, self.doc_id, self.job_id)) # Extract chunks - self.num_pages = self._get_pdf_pages() # Get the number of pages - self._embed_chunks() # Embed the text chunks - self.summary = self._generate_summary() # Generate a summary + self.num_pages = self._get_pdf_pages() # Get the number of pages in the document + self._embed_chunks() # Embed the text chunks into embeddings + self.summary = self._generate_summary() # Generate a summary for the document def _get_document_type(self, file_name: str) -> DocumentType: """ Determine the document type based on its file extension. + + :param file_name: The name of the file being processed. + :return: The DocumentType enum value corresponding to the file extension. """ - _, extension = os.path.splitext(file_name) - extension = extension.lower().lstrip('.') + _, extension = os.path.splitext(file_name) # Split the file name to get the extension + extension = extension.lower().lstrip('.') # Convert to lowercase and remove leading period try: - return DocumentType(extension) + return DocumentType(extension) # Try to match the extension to a DocumentType except ValueError: - raise FileTypeNotSupportedException(extension) + raise FileTypeNotSupportedException(extension) # Raise exception if file type is unsupported def _get_pdf_pages(self) -> int: """ - Get the total number of pages in the PDF. + Get the total number of pages in the PDF document. + + :return: The number of pages in the PDF. """ - pdf_file = io.BytesIO(self.file_data) - pdf_reader = PdfReader(pdf_file) - return len(pdf_reader.pages) + pdf_file = io.BytesIO(self.file_data) # Convert the file data to an in-memory binary stream + pdf_reader = PdfReader(pdf_file) # Initialize PDF reader + return len(pdf_reader.pages) # Return the number of pages in the PDF def _embed_chunks(self) -> None: """ Embed the text chunks using the Cohere API. """ - co = cohere.Client(os.getenv("COHERE_API_KEY")) - batch_size = 90 - chunks_len = len(self.chunks) + co = cohere.Client(os.getenv("COHERE_API_KEY")) # Initialize Cohere client with API key + batch_size = 90 # Batch size for embedding + chunks_len = len(self.chunks) # Total number of chunks to embed for i in tqdm(range(0, chunks_len, batch_size), desc="Embedding Chunks"): - batch = self.chunks[i: min(i + batch_size, chunks_len)] - texts = [chunk['metadata']['text'] for chunk in batch] - #try: + batch = self.chunks[i: min(i + batch_size, chunks_len)] # Get batch of chunks + texts = [chunk['metadata']['text'] for chunk in batch] # Extract text from each chunk chunk_embs_batch = co.embed( texts=texts, - model="embed-english-v3.0", - input_type="search_document" + model="embed-english-v3.0", # Use Cohere's embedding model + input_type="search_document" # Specify input type ) for j, emb in enumerate(chunk_embs_batch.embeddings): - self.chunks[i + j]['values'] = emb - #except Exception as e: - #print(f"Error embedding batch for {self.file_name}: {str(e)}") + self.chunks[i + j]['values'] = emb # Store the embeddings in the corresponding chunks def _generate_summary(self) -> str: """ Generate a summary of the document using KMeans clustering and a language model. 
+ + :return: The generated summary of the document. """ - num_clusters = min(10, len(self.chunks)) - kmeans = KMeans(n_clusters=num_clusters, random_state=42) - doc_chunks = [chunk['values'] for chunk in self.chunks if 'values' in chunk] - cluster_labels = kmeans.fit_predict(doc_chunks) + num_clusters = min(10, len(self.chunks)) # Set number of clusters for KMeans, capped at 10 + kmeans = KMeans(n_clusters=num_clusters, random_state=42) # Initialize KMeans with 10 clusters + doc_chunks = [chunk['values'] for chunk in self.chunks if 'values' in chunk] # Extract embeddings + cluster_labels = kmeans.fit_predict(doc_chunks) # Assign each chunk to a cluster # Select representative chunks from each cluster selected_chunks = [] for i in range(num_clusters): - cluster_chunks = [chunk for chunk, label in zip(self.chunks, cluster_labels) if label == i] - cluster_embs = [emb for emb, label in zip(doc_chunks, cluster_labels) if label == i] - centroid = kmeans.cluster_centers_[i] - distances = [np.linalg.norm(np.array(emb) - centroid) for emb in cluster_embs] - closest_chunk = cluster_chunks[np.argmin(distances)] + cluster_chunks = [chunk for chunk, label in zip(self.chunks, cluster_labels) if label == i] # Get all chunks in this cluster + cluster_embs = [emb for emb, label in zip(doc_chunks, cluster_labels) if label == i] # Get embeddings for this cluster + centroid = kmeans.cluster_centers_[i] # Get the centroid of the cluster + distances = [np.linalg.norm(np.array(emb) - centroid) for emb in cluster_embs] # Compute distance to centroid + closest_chunk = cluster_chunks[np.argmin(distances)] # Select chunk closest to the centroid selected_chunks.append(closest_chunk) # Combine selected chunks into a summary - combined_text = "\n\n".join([chunk['metadata']['text'] for chunk in selected_chunks]) + combined_text = "\n\n".join([chunk['metadata']['text'] for chunk in selected_chunks]) # Concatenate chunk texts - client = OpenAI() # Call OpenAI API for text generation (summarization) + client = OpenAI() # Initialize OpenAI client for text generation completion = client.chat.completions.create( - model="gpt-3.5-turbo", + model="gpt-3.5-turbo", # Specify the language model messages=[ {"role": "system", "content": "You are an AI assistant tasked with summarizing a document. You are provided with important chunks from the document and provide a summary, as best you can, of what the document will contain overall. Be concise and brief with your response."}, @@ -670,13 +777,15 @@ class Document: Summary: """} ], - max_tokens=300 + max_tokens=300 # Set max tokens for the summary ) - return completion.choices[0].message.content.strip() + return completion.choices[0].message.content.strip() # Return the generated summary def to_json(self) -> str: """ Return the document's data in JSON format. + + :return: JSON string representing the document's metadata, chunks, and summary. """ return json.dumps({ "file_name": self.file_name, @@ -685,16 +794,20 @@ class Document: "chunks": self.chunks, "type": self.type.value, "doc_id": self.doc_id - }, indent=2) + }, indent=2) # Convert the document's attributes to JSON format def process_document(file_data, file_name, job_id): """ Top-level function to process a document and return the JSON output. - """ - new_document = Document(file_data, file_name, job_id) - return new_document.to_json() + :param file_data: The binary data of the file being processed. + :param file_name: The name of the file being processed. + :param job_id: The job ID for this document processing task. 
+ :return: The processed document's data in JSON format. + """ + new_document = Document(file_data, file_name, job_id) # Create a new Document object + return new_document.to_json() # Return the document's JSON data def main(): @@ -702,12 +815,12 @@ def main(): Main entry point for the script, called with arguments from Node.js. """ if len(sys.argv) != 4: - print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr) + print(json.dumps({"error": "Invalid arguments"}), file=sys.stderr) # Print error if incorrect number of arguments return - job_id = sys.argv[1] - file_name = sys.argv[2] - file_data = sys.argv[3] + job_id = sys.argv[1] # Get the job ID from command-line arguments + file_name = sys.argv[2] # Get the file name from command-line arguments + file_data = sys.argv[3] # Get the base64-encoded file data from command-line arguments try: # Decode the base64 file data @@ -727,4 +840,4 @@ def main(): if __name__ == "__main__": - main() + main() # Execute the main function when the script is run -- cgit v1.2.3-70-g09d2
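
For readers who want to exercise the chunker outside of Dash, the sketch below drives process_document() directly. It is a minimal smoke test, not part of the patches above: it assumes the packages in requirements.txt (plus a local Tesseract install for pytesseract) are available and that ANTHROPIC_API_KEY, COHERE_API_KEY, and OPENAI_API_KEY are set in the environment; "sample.pdf" is a hypothetical input file.

    # Minimal smoke test for pdf_chunker.process_document (hypothetical usage).
    import json
    import uuid

    from pdf_chunker import process_document

    with open("sample.pdf", "rb") as f:  # hypothetical input PDF
        file_bytes = f.read()

    # process_document returns the Document's JSON string (file_name, num_pages,
    # summary, chunks, type, doc_id); any unique string works as the job ID.
    result = json.loads(process_document(file_bytes, "sample.pdf", str(uuid.uuid4())))
    print(f"{result['file_name']}: {result['num_pages']} pages, {len(result['chunks'])} chunks")
    print(result["summary"])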
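
The second commit's stdout/stderr split defines a small protocol: progress updates and error payloads are emitted line-by-line as JSON on stderr, while the single final document JSON arrives on stdout, which is how AssistantManager.ts demultiplexes the two streams. Below is a rough Python equivalent of that consumer; the file name and script path are hypothetical, and subprocess.run with capture_output is used so neither pipe can fill up and stall the child (the Node code achieves the same with event callbacks).

    # Hypothetical driver mirroring the argv contract of pdf_chunker.main():
    # argv = [job_id, file_name, base64_file_data]
    import base64
    import json
    import subprocess
    import sys
    import uuid

    with open("report.pdf", "rb") as f:  # hypothetical input PDF
        encoded = base64.b64encode(f.read()).decode("utf-8")

    proc = subprocess.run(
        [sys.executable, "src/server/chunker/pdf_chunker.py",
         str(uuid.uuid4()), "report.pdf", encoded],
        capture_output=True, text=True,
    )

    # stderr carries one JSON object per line: progress updates and errors.
    for line in proc.stderr.splitlines():
        if not line.strip():
            continue
        try:
            message = json.loads(line)
        except json.JSONDecodeError:
            continue  # library warnings may also land on stderr
        if "progress" in message:
            print(f"{message.get('step')}: {message['progress']:.0f}%")

    document = json.loads(proc.stdout)  # the final result from process_document
    print(document["summary"])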