{ "cells": [ { "cell_type": "markdown", "id": "d13be0fd-db15-4ab1-860a-b00257051339", "metadata": {}, "source": [ "## Gradio UI for Markdown-Based Q&A with Visualization" ] }, { "cell_type": "markdown", "id": "bc63fbdb-66a9-4c10-8dbd-11476b5e2d21", "metadata": {}, "source": [ "This interface enables users to:\n", "- Upload Markdown files for processing\n", "- Visualize similarity between document chunks in 2D and 3D using embeddings\n", "- Ask questions and receive RAG enabled responses\n", "- Mantain conversation context for better question answering\n", "- Clear chat history when required for fresh sessions\n", "- Store and retrieve embeddings using ChromaDB\n", "\n", "Integrates LangChain, ChromaDB, and OpenAI to process, store, and retrieve information efficiently." ] }, { "cell_type": "code", "execution_count": null, "id": "91da28d8-8e29-44b7-a62a-a3a109753727", "metadata": {}, "outputs": [], "source": [ "# imports\n", "\n", "import os\n", "from dotenv import load_dotenv\n", "import gradio as gr" ] }, { "cell_type": "code", "execution_count": null, "id": "e47f670a-e2cb-4700-95d0-e59e440677a1", "metadata": {}, "outputs": [], "source": [ "# imports for langchain, plotly and Chroma\n", "\n", "from langchain.document_loaders import DirectoryLoader, TextLoader\n", "from langchain.text_splitter import CharacterTextSplitter\n", "from langchain.schema import Document\n", "from langchain_openai import OpenAIEmbeddings, ChatOpenAI\n", "from langchain.embeddings import HuggingFaceEmbeddings\n", "from langchain_chroma import Chroma\n", "from langchain.memory import ConversationBufferMemory\n", "from langchain.chains import ConversationalRetrievalChain\n", "import numpy as np\n", "from sklearn.manifold import TSNE\n", "import plotly.graph_objects as go\n", "import plotly.express as px\n", "import matplotlib.pyplot as plt\n", "from random import randint\n", "import shutil" ] }, { "cell_type": "code", "execution_count": null, "id": "362d4976-2553-4ed8-8fbb-49806145cad1", "metadata": {}, "outputs": [], "source": [ "!pip install --upgrade gradio" ] }, { "cell_type": "code", "execution_count": null, "id": "968b6e96-557e-439f-b2f1-942c05168641", "metadata": {}, "outputs": [], "source": [ "MODEL = \"gpt-4o-mini\"\n", "db_name = \"vector_db\"" ] }, { "cell_type": "code", "execution_count": null, "id": "537f66de-6abf-4b34-8e05-6b9a9df8ae82", "metadata": {}, "outputs": [], "source": [ "# Load environment variables in a file called .env\n", "\n", "load_dotenv(override=True)\n", "os.environ['OPENAI_API_KEY'] = os.getenv('OPENAI_API_KEY')" ] }, { "cell_type": "code", "execution_count": null, "id": "246c1c1b-fcfa-4f4c-b99c-024598751361", "metadata": {}, "outputs": [], "source": [ "folder = \"my-knowledge-base/\"\n", "db_name = \"vectorstore_db\"\n", "\n", "def process_files(files):\n", " os.makedirs(folder, exist_ok=True)\n", "\n", " processed_files = []\n", " for file in files:\n", " file_path = os.path.join(folder, os.path.basename(file)) # Get filename\n", " shutil.copy(file, file_path)\n", " processed_files.append(os.path.basename(file))\n", "\n", " # Load documents using LangChain's DirectoryLoader\n", " text_loader_kwargs = {'autodetect_encoding': True}\n", " loader = DirectoryLoader(folder, glob=\"**/*.md\", loader_cls=TextLoader, loader_kwargs=text_loader_kwargs)\n", " folder_docs = loader.load()\n", "\n", " # Assign filenames as metadata\n", " for doc in folder_docs:\n", " filename_md = os.path.basename(doc.metadata[\"source\"])\n", " filename, _ = os.path.splitext(filename_md)\n", " 
doc.metadata[\"filename\"] = filename\n", "\n", " documents = folder_docs \n", "\n", " # Split documents into chunks\n", " text_splitter = CharacterTextSplitter(chunk_size=400, chunk_overlap=200)\n", " chunks = text_splitter.split_documents(documents)\n", "\n", " # Initialize embeddings\n", " embeddings = HuggingFaceEmbeddings(model_name=\"sentence-transformers/all-MiniLM-L6-v2\")\n", "\n", " # Delete previous vectorstore\n", " if os.path.exists(db_name):\n", " Chroma(persist_directory=db_name, embedding_function=embeddings).delete_collection()\n", "\n", " # Store in ChromaDB\n", " vectorstore = Chroma.from_documents(documents=chunks, embedding=embeddings, persist_directory=db_name)\n", "\n", " # Retrieve results\n", " collection = vectorstore._collection\n", " result = collection.get(include=['embeddings', 'documents', 'metadatas'])\n", "\n", " llm = ChatOpenAI(temperature=0.7, model_name=MODEL)\n", " memory = ConversationBufferMemory(memory_key='chat_history', return_messages=True)\n", " retriever = vectorstore.as_retriever(search_kwargs={\"k\": 35})\n", " global conversation_chain\n", " conversation_chain = ConversationalRetrievalChain.from_llm(llm=llm, retriever=retriever, memory=memory)\n", "\n", " processed_text = \"**Processed Files:**\\n\\n\" + \"\\n\".join(f\"- {file}\" for file in processed_files)\n", " return result, processed_text" ] }, { "cell_type": "code", "execution_count": null, "id": "48678d3a-0ab2-4aa4-aa9e-4160c6a9cb24", "metadata": {}, "outputs": [], "source": [ "def random_color():\n", " return f\"rgb({randint(0,255)},{randint(0,255)},{randint(0,255)})\"" ] }, { "cell_type": "code", "execution_count": null, "id": "6caed889-9bb4-42ad-b1c2-da051aefc802", "metadata": {}, "outputs": [], "source": [ "def show_embeddings_2d(result):\n", " vectors = np.array(result['embeddings']) \n", " documents = result['documents']\n", " metadatas = result['metadatas']\n", " filenames = [metadata['filename'] for metadata in metadatas]\n", " filenames_unique = sorted(set(filenames))\n", "\n", " # color assignment\n", " color_map = {name: random_color() for name in filenames_unique}\n", " colors = [color_map[name] for name in filenames]\n", "\n", " tsne = TSNE(n_components=2, random_state=42,perplexity=4)\n", " reduced_vectors = tsne.fit_transform(vectors)\n", "\n", " # Create the 2D scatter plot\n", " fig = go.Figure(data=[go.Scatter(\n", " x=reduced_vectors[:, 0],\n", " y=reduced_vectors[:, 1],\n", " mode='markers',\n", " marker=dict(size=5,color=colors, opacity=0.8),\n", " text=[f\"Type: {t}
<br>Text: {d[:100]}...\" for t, d in zip(filenames, documents)],\n", " hoverinfo='text'\n", " )])\n", "\n", " fig.update_layout(\n", " title='2D Chroma Vector Store Visualization',\n", " xaxis_title='x',\n", " yaxis_title='y',\n", " width=800,\n", " height=600,\n", " margin=dict(r=20, b=10, l=10, t=40)\n", " )\n", "\n", " return fig" ] }, { "cell_type": "code", "execution_count": null, "id": "de993495-c8cd-4313-a6bb-7d27494ecc13", "metadata": {}, "outputs": [], "source": [ "def show_embeddings_3d(result):\n", " vectors = np.array(result['embeddings']) \n", " documents = result['documents']\n", " metadatas = result['metadatas']\n", " filenames = [metadata['filename'] for metadata in metadatas]\n", " filenames_unique = sorted(set(filenames))\n", "\n", " # color assignment\n", " color_map = {name: random_color() for name in filenames_unique}\n", " colors = [color_map[name] for name in filenames]\n", "\n", " tsne = TSNE(n_components=3, random_state=42)\n", " reduced_vectors = tsne.fit_transform(vectors)\n", "\n", " fig = go.Figure(data=[go.Scatter3d(\n", " x=reduced_vectors[:, 0],\n", " y=reduced_vectors[:, 1],\n", " z=reduced_vectors[:, 2],\n", " mode='markers',\n", " marker=dict(size=5, color=colors, opacity=0.8),\n", " text=[f\"Type: {t}<br>
Text: {d[:100]}...\" for t, d in zip(filenames, documents)],\n", " hoverinfo='text'\n", " )])\n", "\n", " fig.update_layout(\n", " title='3D Chroma Vector Store Visualization',\n", " scene=dict(xaxis_title='x', yaxis_title='y', zaxis_title='z'),\n", " width=900,\n", " height=700,\n", " margin=dict(r=20, b=10, l=10, t=40)\n", " )\n", "\n", " return fig" ] }, { "cell_type": "code", "execution_count": null, "id": "7b7bf62b-c559-4e97-8135-48cd8d97a40e", "metadata": {}, "outputs": [], "source": [ "# history defaults to None because the Ask a Question button passes only the question;\n", "# conversation context is kept by the chain's ConversationBufferMemory\n", "def chat(question, history=None):\n", " result = conversation_chain.invoke({\"question\": question})\n", " return result[\"answer\"]\n", "\n", "def visualise_data(result):\n", " fig_2d = show_embeddings_2d(result)\n", " fig_3d = show_embeddings_3d(result)\n", " return fig_2d, fig_3d" ] }, { "cell_type": "code", "execution_count": null, "id": "99217109-fbee-4269-81c7-001e6f768a72", "metadata": {}, "outputs": [], "source": [ "css = \"\"\"\n", ".btn {background-color: #1d53d1;}\n", "\"\"\"" ] }, { "cell_type": "code", "execution_count": null, "id": "e1429ea1-1d9f-4be6-b270-01997864c642", "metadata": {}, "outputs": [], "source": [ "with gr.Blocks(css=css) as ui:\n", " gr.Markdown(\"# Markdown-Based Q&A with Visualization\")\n", " with gr.Row():\n", " file_input = gr.Files(file_types=[\".md\"], label=\"Upload Markdown Files\")\n", " with gr.Column(scale=1):\n", " processed_output = gr.Markdown(\"Progress\")\n", " with gr.Row():\n", " process_btn = gr.Button(\"Process Files\", elem_classes=[\"btn\"])\n", " with gr.Row():\n", " question = gr.Textbox(label=\"Chat\", lines=10)\n", " answer = gr.Markdown(label=\"Response\")\n", " with gr.Row():\n", " question_btn = gr.Button(\"Ask a Question\", elem_classes=[\"btn\"])\n", " clear_btn = gr.Button(\"Clear Output\", elem_classes=[\"btn\"])\n", " with gr.Row():\n", " plot_2d = gr.Plot(label=\"2D Visualization\")\n", " plot_3d = gr.Plot(label=\"3D Visualization\")\n", " with gr.Row():\n", " visualise_btn = gr.Button(\"Visualise Data\", elem_classes=[\"btn\"])\n", "\n", " result = gr.State([])\n", " # Wire up the buttons: process files, ask a question, clear the output, and visualise embeddings\n", " clear_btn.click(fn=lambda: (\"\", \"\"), inputs=[], outputs=[question, answer])\n", " process_btn.click(process_files, inputs=[file_input], outputs=[result, processed_output])\n", " question_btn.click(chat, inputs=[question], outputs=[answer])\n", " visualise_btn.click(visualise_data, inputs=[result], outputs=[plot_2d, plot_3d])\n", "\n", "# Launch Gradio app\n", "ui.launch(inbrowser=True)" ] }, { "cell_type": "code", "execution_count": null, "id": "d3686048-ac29-4df1-b816-e58996913ef1", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3 (ipykernel)", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.11" } }, "nbformat": 4, "nbformat_minor": 5 }