Commit Diff


commit - 54a207a967ad9b9c38546fcaf4e126eb7528c57e
commit + fec63d5bb3d424e94440249e671b16318d808a88
blob - d08381cf5587e49ee7bc6c9496cc3a55c6152760
blob + bb70c4cb175fc77a2d6d986616be653570fc5402
--- README.md
+++ README.md
@@ -20,13 +20,15 @@ you query and contents. 
     - `pip install -r requirements.txt`
   - Index the newly added files:
     >   It may be required to create a `chroma` directory first!
-    - `python ./rag_indexer.py -db chroma static/files`
-    > The indexer will scan for pdf and text documents, parse them
+    - `python ./rag_indexer.py --db chroma static/files`
+    > The indexer will scan for PDF, text, and markdown documents, parse their content,
+    > split it into chunks
     > and will add it in batches to the chroma db. Then you are ready!
 - **Start the bot backend**
   - `python ./rag_interface.py`
-- **Query your documents**
-  - https://localhost:5555/
+  - **Query your documents**
+    - http://localhost:5000/
+    > To use a different port or interface, pass `--port` / `--host` or set `RAG_FLASK_PORT` / `RAG_FLASK_HOST`.
 
 ----
 # LICENSE (MIT)
blob - 83352fe5e971b4612dd21262cc36f42e4effea83
blob + 8bc629689b74222116b67b957ad6d4888403d74b
--- configuration.py
+++ configuration.py
@@ -3,6 +3,11 @@ import os
 from langchain_ollama import OllamaEmbeddings
 
 # default paths
+LOG_LEVEL: str = os.getenv("RAG_LOG_LEVEL", "INFO")
+# Flask server configuration
+FLASK_PORT = int(os.getenv('RAG_FLASK_PORT', 5000))
+FLASK_HOST = os.getenv('RAG_FLASK_HOST', None)
+
 DB_PATH: str = os.getenv("RAG_DB_PATH", "chroma")
 FILE_PATH: str = os.getenv("RAG_STATIC_FILES", "static/files")
 
blob - ef43eb0adc92bafd8fe4bdd000effa4dfc3e5f92
blob + 76d575a85594861b3a85c958677ff3000daf5f8c
--- rag_backend.py
+++ rag_backend.py
@@ -3,8 +3,8 @@ import logging
 from pathlib import Path
 
 from chromadb import Settings
-from langchain_chroma import Chroma
 from langchain.prompts import ChatPromptTemplate
+from langchain_chroma import Chroma
 from langchain_community.document_loaders import TextLoader, PyPDFDirectoryLoader
 from langchain_core.documents import Document
 from langchain_core.messages import SystemMessage
@@ -12,7 +12,8 @@ from langchain_core.prompts import HumanMessagePromptT
 from langchain_ollama import OllamaLLM
 from langchain_text_splitters import RecursiveCharacterTextSplitter
 
-from configuration import embeddings, DB_PATH, HUMAN_TEMPLATE, OLLAMA_MODEL, OLLAMA_URL, SYSTEM_PROMPT
+from configuration import embeddings, DB_PATH, HUMAN_TEMPLATE, OLLAMA_MODEL, OLLAMA_URL, SYSTEM_PROMPT, FILE_PATH, \
+    LOG_LEVEL
 
 
 class RagBackend:
@@ -78,6 +79,14 @@ class RagBackend:
         return documents
 
     @staticmethod
+    def load_markdown_documents(path: str) -> list[Document]:
+        items = Path(path).glob("**/[!.]*.md")
+        documents: list[Document] = []
+        for item in items:
+            documents += TextLoader(item).load()
+        return documents
+
+    @staticmethod
     def split_documents(documents: list[Document]) -> list[Document]:
         text_splitter = RecursiveCharacterTextSplitter(
             chunk_size=800,
@@ -123,18 +132,18 @@ class RagBackend:

     @staticmethod
     def calculate_chunk_ids(chunks: list[Document]) -> list[Document]:
-        # This will create IDs like "source.ext:6:2"
-        # Page Source : Page Number : Chunk Index
+        # chunk ids look like '<source>:<page>:<chunk-index>'; <source> is relative to FILE_PATH when inside it

         last_page_id = None
         current_chunk_index = 0

         for chunk in chunks:
-            source = Path(chunk.metadata.get("source")).name
+            source = str(src.relative_to(FILE_PATH) if (src := Path(chunk.metadata.get("source"))).is_relative_to(FILE_PATH) else src)
             page = chunk.metadata.get("page")
             current_page_id = f"{source}:{'1' if page is None else page}"
             logging.debug(f"indexing page ID: {current_page_id}")

-            # If the page ID is the same as the last one, increment the index.
+            # If the page ID is the same as the last one, increase the index
             if current_page_id == last_page_id:
                 current_chunk_index += 1
             else:
@@ -157,9 +166,11 @@ if __name__ == "__main__":
     )

     parser = argparse.ArgumentParser()
+    parser.add_argument("--loglevel", default=LOG_LEVEL, type=str.upper, help="logging level (DEBUG, INFO, WARNING, ...)")
     parser.add_argument("--db", default=DB_PATH, help="path to the database")
     parser.add_argument("query_text", type=str, help="The query text.")
     args = parser.parse_args()
+    logging.getLogger().setLevel(args.loglevel)

     rag_backend = RagBackend(args.db)
     response, sources = rag_backend.query(args.query_text, "")
blob - 1df8c9a7c238359be87915a3bece49b191f1f4a7
blob + f273cb66f0e79988ed15700ba5f8d5683b7e99da
--- rag_indexer.py
+++ rag_indexer.py
@@ -3,7 +3,7 @@ import logging
 import shutil
 import sys
 
-from configuration import DB_PATH
+from configuration import DB_PATH, LOG_LEVEL
 from rag_backend import RagBackend
 
 if __name__ == "__main__":
@@ -13,10 +13,12 @@ if __name__ == "__main__":
     )

     parser = argparse.ArgumentParser()
+    parser.add_argument("--loglevel", default=LOG_LEVEL, type=str.upper, help="logging level (DEBUG, INFO, WARNING, ...)")
     parser.add_argument("--db", default=DB_PATH, help="path to the database")
     parser.add_argument("--reset", action="store_true", help="reset the database")
-    parser.add_argument("sources", nargs="*", help="source directories (only pdf/txt")
+    parser.add_argument("sources", nargs="*", help="source directories (pdf/txt/md files)")
     args = parser.parse_args()
+    logging.getLogger().setLevel(args.loglevel)

     if not len(args.sources):
         logging.error("no source directories specified")
@@ -33,10 +35,21 @@ if __name__ == "__main__":
     indexer = RagBackend(args.db)
     for source in args.sources:
         logging.info(f"searching {source}")
+
+        # Process PDF documents
         pdf_docs = indexer.load_pdf_documents(source)
         pdf_docs = indexer.split_documents(pdf_docs)
         indexer.add_to_index(pdf_docs)
+        logging.info(f"added {len(pdf_docs)} PDF document chunks to index")
+
+        # Process text documents
         text_docs = indexer.load_text_documents(source)
         text_docs = indexer.split_documents(text_docs)
         indexer.add_to_index(text_docs)
-        logging.info(f"added {len(text_docs)} text documents to index")
+        logging.info(f"added {len(text_docs)} text document chunks to index")
+
+        # Process Markdown documents
+        md_docs = indexer.load_markdown_documents(source)
+        md_docs = indexer.split_documents(md_docs)
+        indexer.add_to_index(md_docs)
+        logging.info(f"added {len(md_docs)} markdown document chunks to index")
blob - 2681329d51ff47e981427b95421382f242729275
blob + 0586062ff0e8ce47d0c58925f0322093f8832f55
--- rag_interface.py
+++ rag_interface.py
@@ -1,23 +1,26 @@
+import argparse
 import logging
 
-from flask import Flask, request, jsonify, render_template
+from flask import Flask, request, jsonify, render_template, send_from_directory
 
+from configuration import FLASK_PORT, FLASK_HOST, FILE_PATH, LOG_LEVEL
 from rag_backend import RagBackend
 
 app = Flask(__name__)
 rag = RagBackend()
 
-logging.basicConfig(
-    level=logging.INFO,
-    format='%(asctime)s %(levelname)s %(message)s'
-)
 
-
 @app.route('/')
 def home():
     return render_template("page.html")
 
 
+@app.route('/static/files/<path:filename>')
+def serve_files(filename):
+    """Serve files from the configured files directory"""
+    return send_from_directory(FILE_PATH, filename)
+
+
 # MCP endpoint
 @app.route('/mcp', methods=['POST'])
 def handle_mcp():
@@ -49,4 +52,16 @@ def handle_mcp():


 if __name__ == '__main__':
-    app.run(host="0.0.0.0", port="5555", debug=False)
+    logging.basicConfig(
+        level=logging.INFO,
+        format='%(asctime)s %(levelname)s %(message)s'
+    )
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--loglevel", default=LOG_LEVEL, type=str.upper, help="logging level (DEBUG, INFO, WARNING, ...)")
+    parser.add_argument("--host", default=FLASK_HOST, help="host to bind to")
+    parser.add_argument("--port", type=int, default=FLASK_PORT, help="port to listen on")
+    args = parser.parse_args()
+    logging.getLogger().setLevel(args.loglevel)
+
+    app.run(host=args.host, port=args.port, debug=False)