diff --git a/week4/community-contributions/unit-tests-generator.ipynb b/week4/community-contributions/unit-tests-generator.ipynb new file mode 100644 index 0000000..4825544 --- /dev/null +++ b/week4/community-contributions/unit-tests-generator.ipynb @@ -0,0 +1,432 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Requirements\n", + "\n", + "1. Install pytest and pytest-cov library\n" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!pipenv install pytest pytest-cov" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "# imports\n", + "import re\n", + "import os\n", + "import sys\n", + "import textwrap\n", + "from dotenv import load_dotenv\n", + "from openai import OpenAI\n", + "import anthropic\n", + "import gradio as gr\n", + "from pathlib import Path\n", + "import subprocess\n", + "from IPython.display import Markdown" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Initialization\n", + "\n", + "load_dotenv()\n", + "\n", + "openai_api_key = os.getenv('OPENAI_API_KEY')\n", + "os.environ['ANTHROPIC_API_KEY'] = os.getenv('ANTHROPIC_API_KEY', 'your-key-if-not-using-env')\n", + "if openai_api_key:\n", + " print(f\"OpenAI API Key exists and begins {openai_api_key[:8]}\")\n", + "else:\n", + " print(\"OpenAI API Key not set\")\n", + " \n", + "OPENAI_MODEL = \"gpt-4o-mini\"\n", + "CLAUDE_MODEL = \"claude-3-5-sonnet-20240620\"\n", + "openai = OpenAI()\n", + "claude = anthropic.Anthropic()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [], + "source": [ + "OLLAMA_API = \"http://localhost:11434/api/chat\"\n", + "HEADERS = {\"Content-Type\": \"application/json\"}\n", + "OLLAMA_MODEL = \"llama3.2\"" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Code execution" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [], + "source": [ + "\n", + "def extract_code(text):\n", + " # Regular expression to find text between ``python and ``\n", + " match = re.search(r\"```python(.*?)```\", text, re.DOTALL)\n", + "\n", + " if match:\n", + " code = match.group(0).strip() # Extract and strip extra spaces\n", + " else:\n", + " code = \"\"\n", + " print(\"No matching substring found.\")\n", + "\n", + " return code.replace(\"```python\\n\", \"\").replace(\"```\", \"\")\n", + "\n", + "\n", + "def execute_coverage_report(python_interpreter=sys.executable):\n", + " if not python_interpreter:\n", + " raise EnvironmentError(\"Python interpreter not found in the specified virtual environment.\")\n", + " # test_code_path = Path(\"tests\")\n", + " # command = [\"pytest\", \"-cov\",\"--capture=no\"]\n", + " command = [\"coverage\", \"run\", \"-m\", \"pytest\"]\n", + " # command =[\"pytest\", \"--cov=your_package\", \"--cov-report=term-missing\"]\n", + "\n", + " try:\n", + " result = subprocess.run(command, check=True, capture_output=True, text=True)\n", + " print(\"Tests ran successfully!\")\n", + " print(result.stdout)\n", + " return result.stdout\n", + " except subprocess.CalledProcessError as e:\n", + " print(\"Some tests failed!\")\n", + " print(\"Output:\\n\", e.stdout)\n", + " print(\"Errors:\\n\", e.stderr)\n", + " # Extracting failed test information\n", + " failed_tests = []\n", + " for line in e.stdout.splitlines():\n", + " if \"FAILED\" in line and \"::\" in line:\n", + " failed_tests.append(line.strip())\n", + " if failed_tests:\n", + " print(\"Failed Tests:\")\n", + " for test in failed_tests:\n", + " print(test)\n", + " return failed_tests\n", + "\n", + "def save_unit_tests(code):\n", + "\n", + " match = re.search(r\"def\\s+(\\w+)\\(\", code, re.DOTALL)\n", + "\n", + " if match:\n", + " function_name = match.group(1).strip() # Extract and strip extra spaces\n", + " else:\n", + " function_name = \"\"\n", + " print(\"No matching substring found.\")\n", + "\n", + " test_code_path = Path(\"tests\")\n", + " (test_code_path / f\"test_{function_name}.py\").write_text(extract_code(code))\n", + " Path(\"tests\", \"test_code.py\").unlink()\n", + " \n", + "\n", + "def execute_tests_in_venv(code_to_test, tests, python_interpreter=sys.executable):\n", + " \"\"\"\n", + " Execute the given Python code string within the specified virtual environment.\n", + " \n", + " Args:\n", + " - code_str: str, the Python code to execute.\n", + " - venv_dir: str, the directory path to the virtual environment created by pipenv.\n", + " \"\"\"\n", + " \n", + " if not python_interpreter:\n", + " raise EnvironmentError(\"Python interpreter not found in the specified virtual environment.\")\n", + "\n", + " # Prepare the command to execute the code\n", + " code_str = textwrap.dedent(code_to_test) + \"\\n\" + extract_code(tests)\n", + " test_code_path = Path(\"tests\")\n", + " test_code_path.mkdir(parents=True, exist_ok=True)\n", + " (test_code_path / f\"test_code.py\").write_text(code_str)\n", + " command = [\"pytest\", str(test_code_path)]\n", + "\n", + " try:\n", + " result = subprocess.run(command, check=True, capture_output=True, text=True)\n", + " print(\"Tests ran successfully!\")\n", + " print(result.stderr)\n", + " return result.stdout\n", + " except subprocess.CalledProcessError as e:\n", + " print(\"Some tests failed!\")\n", + " print(\"Output:\\n\", e.stdout)\n", + " print(\"Errors:\\n\", e.stderr)\n", + " # Extracting failed test information\n", + " failed_tests = []\n", + " for line in e.stdout.splitlines():\n", + " if \"FAILED\" in line and \"::\" in line:\n", + " failed_tests.append(line.strip())\n", + " if failed_tests:\n", + " print(\"Failed Tests:\")\n", + " for test in failed_tests:\n", + " print(test)\n", + " return e.stderr\n", + " " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prompts and calls to the models" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [], + "source": [ + "system_message = \"\"\"You are a helpful assistant which helps developers to write unit test cases for their code.\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [], + "source": [ + "def get_user_prompt(code):\n", + "\n", + " user_prompt = \"Write for a python code the unit test cases.\"\n", + " user_prompt += \"Return unit tests cases using pytest library, do not create any custom imports; do not explain your work other than a few comments.\"\n", + " user_prompt += \"Do not insert the function to be tested in the output before the tests. Validate both the case where the function is executed successfully and where it is expected to fail.\"\n", + " user_prompt += code\n", + "\n", + " return user_prompt" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "def stream_gpt(code):\n", + "\n", + " user_prompt = get_user_prompt(code)\n", + " stream = openai.chat.completions.create(\n", + " model=OPENAI_MODEL,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_message},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": user_prompt,\n", + " },\n", + " ],\n", + " stream=True,\n", + " )\n", + "\n", + " response = \"\"\n", + " for chunk in stream:\n", + " response += chunk.choices[0].delta.content or \"\"\n", + " yield response\n", + " \n", + " return response\n", + "\n", + "def stream_ollama(code):\n", + "\n", + " user_prompt = get_user_prompt(code)\n", + " ollama_via_openai = OpenAI(base_url='http://localhost:11434/v1', api_key='ollama')\n", + " stream = ollama_via_openai.chat.completions.create(\n", + " model=OLLAMA_MODEL,\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": system_message},\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": user_prompt,\n", + " },\n", + " ],\n", + " stream=True,\n", + " )\n", + "\n", + " response = \"\"\n", + " for chunk in stream:\n", + " response += chunk.choices[0].delta.content or \"\"\n", + " yield response\n", + " \n", + " return response\n", + "\n", + "\n", + "def stream_claude(code):\n", + " user_prompt = get_user_prompt(code)\n", + " result = claude.messages.stream(\n", + " model=CLAUDE_MODEL,\n", + " max_tokens=2000,\n", + " system=system_message,\n", + " messages=[\n", + " {\n", + " \"role\": \"user\",\n", + " \"content\": user_prompt,\n", + " }\n", + " ],\n", + " )\n", + " reply = \"\"\n", + " with result as stream:\n", + " for text in stream.text_stream:\n", + " reply += text\n", + " yield reply\n", + " print(text, end=\"\", flush=True)\n", + " return reply" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Code examples to test the inteface" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "function_to_test = \"\"\"\n", + " def lengthOfLongestSubstring(s):\n", + " max_length = 0\n", + " substring = \"\"\n", + " start_idx = 0\n", + " while start_idx < len(s):\n", + " string = s[start_idx:]\n", + " for i, x in enumerate(string):\n", + " substring += x\n", + " if len(substring) == len(set((list(substring)))):\n", + " \n", + " if len(set((list(substring)))) > max_length:\n", + " \n", + " max_length = len(substring)\n", + "\n", + " start_idx += 1\n", + " substring = \"\"\n", + " \n", + " \n", + " return max_length\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [], + "source": [ + "test_code = \"\"\"```python\n", + "import pytest\n", + "\n", + "# Unit tests using pytest\n", + "def test_lengthOfLongestSubstring():\n", + " assert lengthOfLongestSubstring(\"abcabcbb\") == 3 # Case with repeating characters\n", + " assert lengthOfLongestSubstring(\"bbbbb\") == 1 # Case with all same characters\n", + " assert lengthOfLongestSubstring(\"pwwkew\") == 3 # Case with mixed characters\n", + " assert lengthOfLongestSubstring(\"\") == 0 # Empty string case\n", + " assert lengthOfLongestSubstring(\"abcdef\") == 6 # All unique characters\n", + " assert lengthOfLongestSubstring(\"abca\") == 3 # Case with pattern and repeat\n", + " assert lengthOfLongestSubstring(\"dvdf\") == 3 # Case with repeated characters separated\n", + " assert lengthOfLongestSubstring(\"a\") == 1 # Case with single character\n", + " assert lengthOfLongestSubstring(\"au\") == 2 # Case with unique two characters\n", + "```\"\"\"" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "def optimize(code, model):\n", + " if model == \"GPT\":\n", + " result = stream_gpt(code)\n", + " elif model == \"Claude\":\n", + " result = stream_claude(code)\n", + " elif model == \"Ollama\":\n", + " result = stream_ollama(code)\n", + " else:\n", + " raise ValueError(\"Unknown model\")\n", + " for stream_so_far in result:\n", + " yield stream_so_far\n", + " return result" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Gradio interface" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "with gr.Blocks() as ui:\n", + " gr.Markdown(\"## Write unit tests for Python code\")\n", + " with gr.Row():\n", + " with gr.Column(scale=1, min_width=300):\n", + " python = gr.Textbox(label=\"Python code:\", value=function_to_test, lines=10)\n", + " model = gr.Dropdown([\"GPT\", \"Claude\", \"Ollama\"], label=\"Select model\", value=\"GPT\")\n", + " unit_tests = gr.Button(\"Write unit tests\")\n", + " with gr.Column(scale=1, min_width=300):\n", + " unit_tests_out = gr.TextArea(label=\"Unit tests\", value=test_code, elem_classes=[\"python\"])\n", + " unit_tests_run = gr.Button(\"Run unit tests\")\n", + " coverage_run = gr.Button(\"Coverage report\")\n", + " save_test_run = gr.Button(\"Save unit tests\")\n", + " with gr.Row():\n", + " \n", + " python_out = gr.TextArea(label=\"Unit tests result\", elem_classes=[\"python\"])\n", + " coverage_out = gr.TextArea(label=\"Coverage report\", elem_classes=[\"python\"])\n", + " \n", + "\n", + " unit_tests.click(optimize, inputs=[python, model], outputs=[unit_tests_out])\n", + " unit_tests_run.click(execute_tests_in_venv, inputs=[python, unit_tests_out], outputs=[python_out])\n", + " coverage_run.click(execute_coverage_report, outputs=[coverage_out])\n", + " save_test_run.click(save_unit_tests, inputs=[unit_tests_out])\n", + "\n", + "\n", + "ui.launch(inbrowser=True)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "llm_engineering-yg2xCEUG", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.8" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}