{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "T-6b4FqreeIl",
        "collapsed": true
      },
      "outputs": [],
      "source": [
        "# Install dependencies. %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -q requests torch bitsandbytes transformers sentencepiece accelerate openai gradio"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "#imports\n",
        "\n",
        "import time\n",
        "from io import StringIO\n",
        "import torch\n",
        "import numpy as np\n",
        "import pandas as pd\n",
        "import random\n",
        "from openai import OpenAI\n",
        "from sqlalchemy import create_engine\n",
        "from google.colab import drive, userdata\n",
        "import gradio as gr\n",
        "from huggingface_hub import login\n",
        "from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig"
      ],
      "metadata": {
        "id": "JXpWOzKve7kr"
      },
      "execution_count": 3,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Model Constants\n",
        "# Model identifier passed to AutoTokenizer.from_pretrained and AutoModelForCausalLM.from_pretrained below.\n",
        "LLAMA = \"meta-llama/Meta-Llama-3.1-8B-Instruct\""
      ],
      "metadata": {
        "id": "rcv0lCS5GRPX"
      },
      "execution_count": 4,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Authentication\n",
        "# Both credentials are read from Colab secrets through `userdata`, not from environment variables.\n",
        "\n",
        "hf_token = userdata.get(\"HF_TOKEN\")\n",
        "openai_api_key = userdata.get(\"OPENAI_API_KEY\")\n",
        "\n",
        "# Report exactly which secret(s) came back empty so the user knows what to add.\n",
        "missing_secrets = [\n",
        "    secret_name\n",
        "    for secret_name, secret_value in ((\"HF_TOKEN\", hf_token), (\"OPENAI_API_KEY\", openai_api_key))\n",
        "    if not secret_value\n",
        "]\n",
        "if missing_secrets:\n",
        "    raise ValueError(\n",
        "        f\"Missing Colab secret(s): {', '.join(missing_secrets)}. \"\n",
        "        \"Add them as Colab secrets and grant this notebook access to them.\"\n",
        "    )\n",
        "\n",
        "login(hf_token, add_to_git_credential=True)\n",
        "# This client instance is bound to the name `openai`; generate_ev_driver calls openai.chat.completions.create.\n",
        "openai = OpenAI(api_key=openai_api_key)"
      ],
      "metadata": {
        "id": "3XS-s_CwFSQU"
      },
      "execution_count": 5,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Tokenizer Setup\n",
        "# Loads the tokenizer for the checkpoint named by LLAMA.\n",
        "\n",
        "tokenizer = AutoTokenizer.from_pretrained(LLAMA)\n",
        "# Reuse the end-of-sequence token as the padding token.\n",
        "tokenizer.pad_token = tokenizer.eos_token"
      ],
      "metadata": {
        "id": "oRdmdzXoF_f9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Model Quantization for Performance Optimization\n",
        "# Settings object handed to from_pretrained (as quantization_config) when the model is loaded.\n",
        "# NOTE(review): the notebook metadata requests a T4 GPU; confirm bfloat16 compute performs acceptably there.\n",
        "\n",
        "quant_config = BitsAndBytesConfig(\n",
        "    # Load the weights in 4-bit precision.\n",
        "    load_in_4bit=True,\n",
        "    # Enable the double-quantization option.\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    # Data type used for computation.\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
        "    # 4-bit quantization scheme.\n",
        "    bnb_4bit_quant_type=\"nf4\"\n",
        ")"
      ],
      "metadata": {
        "id": "kRN0t2yrGmAe"
      },
      "execution_count": 7,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Load the quantized model\n",
        "\n",
        "# NOTE(review): `device` is not referenced by any other cell visible in this notebook.\n",
        "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
        "# Loads the LLAMA checkpoint with the settings in quant_config; device_map=\"auto\" leaves weight placement to the library.\n",
        "model = AutoModelForCausalLM.from_pretrained(LLAMA, device_map=\"auto\", quantization_config=quant_config)"
      ],
      "metadata": {
        "id": "fYPyudKHGuE9"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def generate_ev_driver(num_records, address_type):\n",
        "    \"\"\"Request synthetic driver records from the OpenAI chat API and return them as CSV text.\n",
        "\n",
        "    Args:\n",
        "        num_records: Number of rows to request. Coerced to ``int`` so a float such\n",
        "            as ``20.0`` does not break ``random.sample``.\n",
        "        address_type: Key selecting the geographic scope of the addresses and phone\n",
        "            numbers (\"international\", \"us_only\", \"us_international\", \"americas\"\n",
        "            or \"europe\"). Any other value yields a prompt without a geographic\n",
        "            constraint.\n",
        "\n",
        "    Returns:\n",
        "        Whatever ``clean_and_extract_csv`` returns for the model response.\n",
        "    \"\"\"\n",
        "    num_records = int(num_records)\n",
        "\n",
        "    # Header row stated verbatim in the prompt so the reply starts with a predictable line.\n",
        "    csv_header = \"driverid,first_name,last_name,email,phone_number,address,city,state,zip_code,country\"\n",
        "\n",
        "    # Geographic scope phrase for each supported address type.\n",
        "    address_scopes = {\n",
        "        \"international\": \"international\",\n",
        "        \"us_only\": \"U.S.-only\",\n",
        "        \"us_international\": \"a mix of U.S. and international\",\n",
        "        \"americas\": \"a mix of U.S., Canada, Central America, and South America\",\n",
        "        \"europe\": \"Europe-only\",\n",
        "    }\n",
        "\n",
        "    scope = address_scopes.get(address_type)\n",
        "    if scope:\n",
        "        address_prompt = (\n",
        "            f\"Generate {num_records} rows of synthetic personal data \"\n",
        "            f\"with {scope} addresses and phone numbers.\"\n",
        "        )\n",
        "    else:\n",
        "        # Unknown or empty type: still ask for the requested number of rows.\n",
        "        address_prompt = f\"Generate {num_records} rows of synthetic personal data.\"\n",
        "\n",
        "    # Draw unique driver IDs locally and hand them to the model in the prompt.\n",
        "    driver_ids = random.sample(range(1, 1000001), num_records)\n",
        "\n",
        "    user_prompt = f\"\"\"\n",
        "    {address_prompt}\n",
        "    Each row should include:\n",
        "    - driverid (unique from the provided list: {driver_ids})\n",
        "    - first_name (string)\n",
        "    - last_name (string)\n",
        "    - email (string)\n",
        "    - phone_number (string)\n",
        "    - address (string)\n",
        "    - city (string)\n",
        "    - state (string)\n",
        "    - zip_code (string)\n",
        "    - country (string)\n",
        "\n",
        "    Ensure the CSV format is valid and comma separated, and that the first line is exactly this header row:\n",
        "    {csv_header}\n",
        "    \"\"\"\n",
        "\n",
        "    response = openai.chat.completions.create(\n",
        "        model=\"gpt-4o-mini\",\n",
        "        messages=[\n",
        "            {\"role\": \"system\", \"content\": \"You are a helpful assistant that generates structured CSV data.\"},\n",
        "            {\"role\": \"user\", \"content\": user_prompt}\n",
        "        ]\n",
        "    )\n",
        "\n",
        "    # Strip any surrounding prose or code fences and keep only the CSV payload.\n",
        "    return clean_and_extract_csv(response)"
      ],
      "metadata": {
        "id": "9q9ccNr8fMyg"
      },
      "execution_count": 12,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def clean_and_extract_csv(response):\n",
        "    # Clean up the response and remove the last occurrence of the code block formatting\n",
        "    csv_data = response.choices[0].message.content.strip()\n",
        "    csv_data = csv_data.rsplit(\"```\", 1)[0].strip()\n",
        "\n",
        "    # Define header and split the content to extract the data\n",
        "    header = \"driverid,first_name,last_name,email,phone_number,address,city,state,zip_code,country\"\n",
        "    _, *content = csv_data.split(header, 1)\n",
        "\n",
        "    # Return the cleaned CSV data along with the header\n",
        "    return header + content[0].split(\"\\n\\n\")[0] if content else csv_data"
      ],
      "metadata": {
        "id": "So1aGRNJBUyv"
      },
      "execution_count": 13,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "def update_dataset(num_records, address_type):\n",
        "    response = generate_ev_driver(num_records, address_type)\n",
        "\n",
        "    # Convert response to DataFrame\n",
        "    try:\n",
        "        df = pd.read_csv(StringIO(response))\n",
        "    except Exception as e:\n",
        "        return pd.DataFrame(), f\"Error parsing dataset: {str(e)}\"\n",
        "\n",
        "    return df, response"
      ],
      "metadata": {
        "id": "T0KxUm2yYtuQ"
      },
      "execution_count": 14,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Function to handle address type selection\n",
        "def check_address_selection(selected_type):\n",
        "    \"\"\"Map the current radio selection to a status message and a button update.\n",
        "\n",
        "    Returns a 2-tuple: an HTML status string and a ``gr.update`` for the button.\n",
        "    A falsy selection disables the button and tags it ``yellow_btn``; any other\n",
        "    selection enables it and tags it ``blue_btn``.\n",
        "    \"\"\"\n",
        "    has_selection = bool(selected_type)\n",
        "\n",
        "    if has_selection:\n",
        "        status_html = \"<span style='color:green;'>Ready to generate dataset.</span>\"\n",
        "        button_class = \"blue_btn\"\n",
        "    else:\n",
        "        status_html = \"<span style='color:red;'>⚠️ Address type is required. Please select one.</span>\"\n",
        "        button_class = \"yellow_btn\"\n",
        "\n",
        "    return status_html, gr.update(interactive=has_selection, elem_classes=button_class)\n"
      ],
      "metadata": {
        "id": "z5pFDbnTz-fP"
      },
      "execution_count": 15,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Gradio UI\n",
        "\n",
        "# Button colours for the two classes that check_address_selection toggles via elem_classes.\n",
        "# !important keeps the component's default button styling from overriding them.\n",
        "CUSTOM_CSS = \"\"\"\n",
        ".blue_btn {\n",
        "    background-color: blue !important;\n",
        "    color: white !important;\n",
        "}\n",
        ".yellow_btn {\n",
        "    background-color: yellow !important;\n",
        "    color: black !important;\n",
        "}\n",
        "\"\"\"\n",
        "\n",
        "# Pass the CSS at construction time instead of assigning app.css afterwards.\n",
        "with gr.Blocks(css=CUSTOM_CSS) as app:\n",
        "    gr.Markdown(\"## Dynamic CSV Dataset Viewer\")\n",
        "\n",
        "    num_records_slider = gr.Slider(minimum=5, maximum=50, step=5, value=20, label=\"Number of Records\")\n",
        "\n",
        "    with gr.Row(equal_height=True):\n",
        "        # value=None means \"nothing selected\"; an empty string is not one of the choices.\n",
        "        address_type_radio = gr.Radio(\n",
        "            [\"us_only\", \"international\", \"us_international\", \"americas\", \"europe\"],\n",
        "            value=None,\n",
        "            label=\"Address and Phone Type\",\n",
        "            info=\"Select the type of addresses and phone numbers\"\n",
        "        )\n",
        "        status_text = gr.Markdown(\n",
        "            \"<span style='color:red;'>⚠️ Please select an address type above to proceed.</span>\",\n",
        "            elem_id=\"status_text\"\n",
        "        )\n",
        "\n",
        "    # Start disabled (and yellow) to match the initial status text; enabled once a type is selected.\n",
        "    generate_btn = gr.Button(\n",
        "        \"Generate Data\",\n",
        "        interactive=False,\n",
        "        elem_id=\"generate_btn\",\n",
        "        elem_classes=\"yellow_btn\"\n",
        "    )\n",
        "\n",
        "    response_text = gr.Textbox(value=\"\", label=\"Generated Driver List CSV\", interactive=False)\n",
        "    dataframe_output = gr.Dataframe(value=pd.DataFrame(), label=\"Generated Driver List Dataset\")\n",
        "\n",
        "    # Update status text and button style dynamically\n",
        "    address_type_radio.change(fn=check_address_selection, inputs=[address_type_radio], outputs=[status_text, generate_btn])\n",
        "\n",
        "    generate_btn.click(update_dataset, inputs=[num_records_slider, address_type_radio], outputs=[dataframe_output, response_text])\n",
        "\n",
        "app.launch(share=True)  # Ensure sharing is enabled in Colab"
      ],
      "metadata": {
        "id": "z3K6PfAiL2ZA"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}