{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Requirements\n",
    "\n",
    "1. Install the pytest and pytest-cov libraries\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "#!pipenv install pytest pytest-cov"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "# imports\n",
    "import re\n",
    "import os\n",
    "import sys\n",
    "import textwrap\n",
    "from dotenv import load_dotenv\n",
    "from openai import OpenAI\n",
    "import anthropic\n",
    "import gradio as gr\n",
    "from pathlib import Path\n",
    "import subprocess\n",
    "from IPython.display import Markdown"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Initialization\n",
    "\n",
    "load_dotenv()\n",
    "\n",
    "# Make sure ANTHROPIC_API_KEY is always set in the environment,\n",
    "# using a placeholder value when it is missing.\n",
    "os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n",
    "\n",
    "# Report whether the OpenAI key was found\n",
    "openai_api_key = os.getenv('OPENAI_API_KEY')\n",
    "print(\n",
    "    f\"OpenAI API Key exists and begins {openai_api_key[:8]}\"\n",
    "    if openai_api_key\n",
    "    else \"OpenAI API Key not set\"\n",
    ")\n",
    "\n",
    "# Model identifiers and API clients used by the streaming helpers\n",
    "OPENAI_MODEL = \"gpt-4o-mini\"\n",
    "CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\"\n",
    "openai = OpenAI()\n",
    "claude = anthropic.Anthropic()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Local Ollama settings. OLLAMA_MODEL is the model name passed by stream_ollama;\n",
    "# OLLAMA_API and HEADERS are not referenced by the other cells visible in this notebook.\n",
    "OLLAMA_API = \"http://localhost:11434/api/chat\"\n",
    "HEADERS = {\"Content-Type\": \"application/json\"}\n",
    "OLLAMA_MODEL = \"llama3.2\""
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Code execution"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "\n",
    "def extract_code(text):\n",
    "    \"\"\"\n",
    "    Return the contents of the first ```python fenced block found in `text`.\n",
    "\n",
    "    Args:\n",
    "    - text: str, a model response that may wrap code in a ```python fence.\n",
    "\n",
    "    Returns:\n",
    "    - str, the code inside the fence with the fence markers removed, or an\n",
    "      empty string (after printing a notice) when no fenced block is found.\n",
    "    \"\"\"\n",
    "    # Capture only the fence body. Whatever trails the ```python marker on the\n",
    "    # opening line (spaces, a carriage return) is skipped together with the\n",
    "    # newline, so the language tag can never leak into the extracted code.\n",
    "    match = re.search(r\"```python[^\\n]*\\n(.*?)```\", text, re.DOTALL)\n",
    "\n",
    "    if not match:\n",
    "        print(\"No matching substring found.\")\n",
    "        return \"\"\n",
    "\n",
    "    return match.group(1)\n",
    "\n",
    "\n",
    "def execute_coverage_report(python_interpreter=sys.executable):\n",
    "    \"\"\"\n",
    "    Run the pytest suite under coverage and return the test output plus the coverage table.\n",
    "\n",
    "    Args:\n",
    "    - python_interpreter: str, path of the Python interpreter whose coverage and\n",
    "      pytest installations are used (defaults to the interpreter running this kernel).\n",
    "\n",
    "    Returns:\n",
    "    - str, pytest's stdout followed by the output of `coverage report -m`.\n",
    "\n",
    "    Raises:\n",
    "    - EnvironmentError: if `python_interpreter` is empty.\n",
    "    \"\"\"\n",
    "    if not python_interpreter:\n",
    "        raise EnvironmentError(\"Python interpreter not found in the specified virtual environment.\")\n",
    "\n",
    "    # Go through `<interpreter> -m ...` so the tools installed next to that\n",
    "    # interpreter are used rather than whatever `coverage` is first on PATH.\n",
    "    run_command = [python_interpreter, \"-m\", \"coverage\", \"run\", \"-m\", \"pytest\"]\n",
    "    report_command = [python_interpreter, \"-m\", \"coverage\", \"report\", \"-m\"]\n",
    "\n",
    "    # No check=True here: a failing test run is still followed by the report step.\n",
    "    run_result = subprocess.run(run_command, capture_output=True, text=True)\n",
    "    if run_result.returncode == 0:\n",
    "        print(\"Tests ran successfully!\")\n",
    "        print(run_result.stdout)\n",
    "    else:\n",
    "        print(\"Some tests failed!\")\n",
    "        print(\"Output:\\n\", run_result.stdout)\n",
    "        print(\"Errors:\\n\", run_result.stderr)\n",
    "\n",
    "    # `coverage run` only records data; the table comes from `coverage report`.\n",
    "    report_result = subprocess.run(report_command, capture_output=True, text=True)\n",
    "    if report_result.returncode != 0:\n",
    "        print(\"Errors:\\n\", report_result.stderr)\n",
    "\n",
    "    return run_result.stdout + \"\\n\" + (report_result.stdout or report_result.stderr)\n",
    "\n",
    "def save_unit_tests(code):\n",
    "    \"\"\"\n",
    "    Save the generated unit tests under tests/ and remove the scratch file tests/test_code.py.\n",
    "\n",
    "    The file is named after the first function defined in `code`.\n",
    "\n",
    "    Args:\n",
    "    - code: str, the model response holding the tests in a ```python fenced block.\n",
    "    \"\"\"\n",
    "    match = re.search(r\"def\\s+(\\w+)\\(\", code)\n",
    "\n",
    "    if match:\n",
    "        function_name = match.group(1)\n",
    "    else:\n",
    "        function_name = \"\"\n",
    "        print(\"No matching substring found.\")\n",
    "\n",
    "    # The first function is normally already called test_<something>; do not\n",
    "    # prepend a second test_ prefix to the file name in that case.\n",
    "    if function_name.startswith(\"test_\"):\n",
    "        file_stem = function_name\n",
    "    else:\n",
    "        file_stem = f\"test_{function_name}\"\n",
    "\n",
    "    test_code_path = Path(\"tests\")\n",
    "    # The directory does not exist yet when tests are saved before being run\n",
    "    test_code_path.mkdir(parents=True, exist_ok=True)\n",
    "    # NOTE(review): only the tests are written, not the function under test -\n",
    "    # confirm the saved file is expected to import that function itself.\n",
    "    (test_code_path / f\"{file_stem}.py\").write_text(extract_code(code))\n",
    "    # Scratch file written by execute_tests_in_venv; it may be absent\n",
    "    (test_code_path / \"test_code.py\").unlink(missing_ok=True)\n",
    "    \n",
    "\n",
    "def execute_tests_in_venv(code_to_test, tests, python_interpreter=sys.executable):\n",
    "    \"\"\"\n",
    "    Write the code under test together with its unit tests to tests/test_code.py and run pytest on it.\n",
    "\n",
    "    Args:\n",
    "    - code_to_test: str, source of the code under test; common leading indentation is removed.\n",
    "    - tests: str, model response holding the unit tests in a ```python fenced block.\n",
    "    - python_interpreter: str, path of the Python interpreter used to launch pytest.\n",
    "\n",
    "    Returns:\n",
    "    - str, pytest's stdout, whether or not the tests passed.\n",
    "\n",
    "    Raises:\n",
    "    - EnvironmentError: if `python_interpreter` is empty.\n",
    "    \"\"\"\n",
    "\n",
    "    if not python_interpreter:\n",
    "        raise EnvironmentError(\"Python interpreter not found in the specified virtual environment.\")\n",
    "\n",
    "    # Put the code under test and the tests in one module so the tests can call it directly\n",
    "    code_str = textwrap.dedent(code_to_test) + \"\\n\" + extract_code(tests)\n",
    "    test_code_path = Path(\"tests\")\n",
    "    test_code_path.mkdir(parents=True, exist_ok=True)\n",
    "    test_file = test_code_path / \"test_code.py\"\n",
    "    test_file.write_text(code_str)\n",
    "\n",
    "    # Launch pytest through the given interpreter and point it at the file just\n",
    "    # written, so previously saved test files in tests/ do not affect the result.\n",
    "    command = [python_interpreter, \"-m\", \"pytest\", str(test_file)]\n",
    "\n",
    "    try:\n",
    "        result = subprocess.run(command, check=True, capture_output=True, text=True)\n",
    "    except subprocess.CalledProcessError as e:\n",
    "        print(\"Some tests failed!\")\n",
    "        print(\"Output:\\n\", e.stdout)\n",
    "        print(\"Errors:\\n\", e.stderr)\n",
    "        # Collect the output lines that name a failed test\n",
    "        failed_tests = [\n",
    "            line.strip()\n",
    "            for line in e.stdout.splitlines()\n",
    "            if \"FAILED\" in line and \"::\" in line\n",
    "        ]\n",
    "        if failed_tests:\n",
    "            print(\"Failed Tests:\")\n",
    "            for test in failed_tests:\n",
    "                print(test)\n",
    "\n",
    "        return e.stdout\n",
    "\n",
    "    print(\"Tests ran successfully!\")\n",
    "    print(result.stdout)\n",
    "    return result.stdout\n",
    "    "
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prompts and calls to the models"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [],
   "source": [
    "# System prompt sent with every request by stream_gpt, stream_ollama and stream_claude\n",
    "system_message = \"\"\"You are a helpful assistant which helps developers to write unit test cases for their code.\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_user_prompt(code):\n",
    "    \"\"\"\n",
    "    Build the user prompt: the test-writing guidelines followed by the code to test.\n",
    "\n",
    "    Args:\n",
    "    - code: str, source code the model should write unit tests for.\n",
    "\n",
    "    Returns:\n",
    "    - str, the guidelines, a \"Code to test:\" separator and then `code`.\n",
    "    \"\"\"\n",
    "\n",
    "    user_prompt = \"\"\"Test include:\n",
    "\n",
    "    - Valid inputs with expected results.\n",
    "    - Inputs that test the boundaries or limits of the function's behavior.\n",
    "    - Invalid inputs or scenarios where the function is expected to raise exceptions.\n",
    "\n",
    "    Structure:\n",
    "\n",
    "    - Begin with all necessary imports. \n",
    "    - Do not create custom imports. \n",
    "    - Do not insert in the response the function for the tests.\n",
    "    - Ensure proper error handling for tests that expect exceptions.\n",
    "    - Clearly name the test functions to indicate their purpose (e.g., test_function_name).\n",
    "\n",
    "    Example Structure:\n",
    "\n",
    "    - Use pytest.raises to validate exceptions.\n",
    "    - Use assertions to verify correct outputs for successful and edge cases.\n",
    "\n",
    "    Documentation:\n",
    "\n",
    "    - Add docstrings explaining what each test verifies.\"\"\"\n",
    "\n",
    "    # Keep the code on its own lines so it cannot run into the last guideline\n",
    "    return user_prompt + \"\\n\\nCode to test:\\n\\n\" + code"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [],
   "source": [
    "def _stream_chat_completion(client, model, code):\n",
    "    \"\"\"\n",
    "    Stream a chat completion from an OpenAI-style client, yielding the text accumulated so far.\n",
    "\n",
    "    Args:\n",
    "    - client: client object exposing `chat.completions.create`.\n",
    "    - model: str, model name passed to the client.\n",
    "    - code: str, source code to write unit tests for.\n",
    "\n",
    "    Returns:\n",
    "    - str, the complete response text (available as the generator's return value).\n",
    "    \"\"\"\n",
    "    user_prompt = get_user_prompt(code)\n",
    "    stream = client.chat.completions.create(\n",
    "        model=model,\n",
    "        messages=[\n",
    "            {\"role\": \"system\", \"content\": system_message},\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ],\n",
    "        stream=True,\n",
    "    )\n",
    "\n",
    "    response = \"\"\n",
    "    for chunk in stream:\n",
    "        # A chunk may arrive with an empty `choices` list; indexing it would raise IndexError\n",
    "        if not chunk.choices:\n",
    "            continue\n",
    "        response += chunk.choices[0].delta.content or \"\"\n",
    "        yield response\n",
    "\n",
    "    return response\n",
    "\n",
    "\n",
    "def stream_gpt(code):\n",
    "    \"\"\"Stream unit tests for `code` from the OpenAI model, yielding the text accumulated so far.\"\"\"\n",
    "    response = yield from _stream_chat_completion(openai, OPENAI_MODEL, code)\n",
    "    return response\n",
    "\n",
    "\n",
    "def stream_ollama(code):\n",
    "    \"\"\"Stream unit tests for `code` from the local Ollama model, yielding the text accumulated so far.\"\"\"\n",
    "    # Reach the local Ollama server through the OpenAI client pointed at its /v1 base URL\n",
    "    ollama_via_openai = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')\n",
    "    response = yield from _stream_chat_completion(ollama_via_openai, OLLAMA_MODEL, code)\n",
    "    return response\n",
    "\n",
    "\n",
    "def stream_claude(code):\n",
    "    \"\"\"\n",
    "    Stream unit tests for `code` from the Claude model, yielding the text accumulated so far.\n",
    "\n",
    "    Args:\n",
    "    - code: str, source code to write unit tests for.\n",
    "\n",
    "    Returns:\n",
    "    - str, the complete response text (available as the generator's return value).\n",
    "    \"\"\"\n",
    "    user_prompt = get_user_prompt(code)\n",
    "    reply = \"\"\n",
    "    # The stream is used as a context manager so it is closed when iteration ends\n",
    "    with claude.messages.stream(\n",
    "        model=CLAUDE_MODEL,\n",
    "        max_tokens=2000,\n",
    "        system=system_message,\n",
    "        messages=[\n",
    "            {\"role\": \"user\", \"content\": user_prompt},\n",
    "        ],\n",
    "    ) as stream:\n",
    "        for text in stream.text_stream:\n",
    "            reply += text\n",
    "            yield reply\n",
    "    return reply"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Code examples to test the interface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default content of the \"Python code:\" textbox. The sample is indented;\n",
    "# execute_tests_in_venv removes that indentation with textwrap.dedent before running it.\n",
    "function_to_test = \"\"\"\n",
    "    def lengthOfLongestSubstring(s):\n",
    "        if not isinstance(s, str):\n",
    "            raise TypeError(\"Input must be a string\")\n",
    "        max_length = 0\n",
    "        substring = \"\"\n",
    "        start_idx = 0\n",
    "        while start_idx < len(s):\n",
    "            string = s[start_idx:]\n",
    "            for i, x in enumerate(string):\n",
    "                substring += x\n",
    "                if len(substring) == len(set((list(substring)))):\n",
    "                    \n",
    "                    if len(set((list(substring)))) > max_length:\n",
    "                        \n",
    "                        max_length = len(substring)\n",
    "\n",
    "            start_idx += 1\n",
    "            substring = \"\"\n",
    "                  \n",
    "                \n",
    "        return max_length\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Default content of the \"Unit tests\" text area: a sample response with the\n",
    "# tests wrapped in a ```python fence, the form extract_code looks for.\n",
    "test_code = \"\"\"```python\n",
    "import pytest\n",
    "\n",
    "# Unit tests using pytest\n",
    "def test_lengthOfLongestSubstring():\n",
    "    assert lengthOfLongestSubstring(\"abcabcbb\") == 3  # Case with repeating characters\n",
    "    assert lengthOfLongestSubstring(\"bbbbb\") == 1    # Case with all same characters\n",
    "    assert lengthOfLongestSubstring(\"pwwkew\") == 3    # Case with mixed characters\n",
    "    assert lengthOfLongestSubstring(\"\") == 0           # Empty string case\n",
    "    assert lengthOfLongestSubstring(\"abcdef\") == 6     # All unique characters\n",
    "    assert lengthOfLongestSubstring(\"abca\") == 3       # Case with pattern and repeat\n",
    "    assert lengthOfLongestSubstring(\"dvdf\") == 3       # Case with repeated characters separated\n",
    "    assert lengthOfLongestSubstring(\"a\") == 1           # Case with single character\n",
    "    assert lengthOfLongestSubstring(\"au\") == 2          # Case with unique two characters\n",
    "```\"\"\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {},
   "outputs": [],
   "source": [
    "def optimize(code, model):\n",
    "    \"\"\"\n",
    "    Stream generated unit tests for `code` from the selected model.\n",
    "\n",
    "    Args:\n",
    "    - code: str, source code to write unit tests for.\n",
    "    - model: str, one of \"GPT\", \"Claude\" or \"Ollama\".\n",
    "\n",
    "    Yields:\n",
    "    - str, the response text accumulated so far.\n",
    "\n",
    "    Returns:\n",
    "    - the return value of the selected streaming helper, once it is exhausted.\n",
    "\n",
    "    Raises:\n",
    "    - ValueError: if `model` is not one of the known names (raised when iteration starts).\n",
    "    \"\"\"\n",
    "    if model == \"GPT\":\n",
    "        result = stream_gpt(code)\n",
    "    elif model == \"Claude\":\n",
    "        result = stream_claude(code)\n",
    "    elif model == \"Ollama\":\n",
    "        result = stream_ollama(code)\n",
    "    else:\n",
    "        raise ValueError(\"Unknown model\")\n",
    "    # Delegate to the helper and pass on its return value rather than the exhausted generator object\n",
    "    return (yield from result)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Gradio interface"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Layout: the first row holds an input column and a generated-tests column;\n",
    "# the second row shows the test results and the coverage report.\n",
    "with gr.Blocks() as ui:\n",
    "    gr.Markdown(\"## Write unit tests for Python code\")\n",
    "    with gr.Row():\n",
    "        with gr.Column(scale=1, min_width=300):\n",
    "            python = gr.Textbox(label=\"Python code:\", value=function_to_test, lines=10)\n",
    "            model = gr.Dropdown([\"GPT\", \"Claude\", \"Ollama\"], label=\"Select model\", value=\"GPT\")\n",
    "            unit_tests = gr.Button(\"Write unit tests\")\n",
    "        with gr.Column(scale=1, min_width=300):\n",
    "            unit_tests_out = gr.TextArea(label=\"Unit tests\", value=test_code, elem_classes=[\"python\"])\n",
    "            unit_tests_run = gr.Button(\"Run unit tests\")\n",
    "            coverage_run = gr.Button(\"Coverage report\")\n",
    "            save_test_run = gr.Button(\"Save unit tests\")\n",
    "    with gr.Row():\n",
    "        \n",
    "        python_out = gr.TextArea(label=\"Unit tests result\", elem_classes=[\"python\"])\n",
    "        coverage_out = gr.TextArea(label=\"Coverage report\", elem_classes=[\"python\"])\n",
    "        \n",
    "\n",
    "    # Wire each button to its handler; optimize is a generator, so it yields partial responses\n",
    "    unit_tests.click(optimize, inputs=[python, model], outputs=[unit_tests_out])\n",
    "    unit_tests_run.click(execute_tests_in_venv, inputs=[python, unit_tests_out], outputs=[python_out])\n",
    "    coverage_run.click(execute_coverage_report, outputs=[coverage_out])\n",
    "    # save_unit_tests has no output component\n",
    "    save_test_run.click(save_unit_tests, inputs=[unit_tests_out])\n",
    "\n",
    "\n",
    "# Start the app (inbrowser=True is Gradio's option for opening it in a browser tab)\n",
    "ui.launch(inbrowser=True)\n",
    "# ui.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "llm_engineering-yg2xCEUG",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.8"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}