{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "kU2JrcPlhwd9"
   },
   "outputs": [],
   "source": [
    "!pip install -q requests torch bitsandbytes transformers sentencepiece accelerate gradio"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "lAMIVT4iwNg0"
   },
   "source": [
    "**Imports**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "-Apd7-p-hyLk"
   },
   "outputs": [],
   "source": [
    "import os\n",
    "import requests\n",
    "from google.colab import drive\n",
    "from huggingface_hub import login\n",
    "from google.colab import userdata\n",
    "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n",
    "import torch\n",
    "import gradio as gr\n",
    "\n",
    "hf_token = userdata.get('HF_TOKEN')\n",
    "login(hf_token, add_to_git_credential=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xa0qYqZrwQ66"
   },
   "source": [
    "**Model**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "z5enGmuKjtJu"
   },
   "outputs": [],
   "source": [
    "model_name = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
    "quant_config = BitsAndBytesConfig(\n",
    "    load_in_4bit=True,\n",
    "    bnb_4bit_use_double_quant=True,\n",
    "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
    "    bnb_4bit_quant_type=\"nf4\"\n",
    ")\n",
    "\n",
    "model = AutoModelForCausalLM.from_pretrained(\n",
    "  model_name,\n",
    "  device_map=\"auto\",\n",
    "  quantization_config=quant_config\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "y1hUSmWlwSbp"
   },
   "source": [
    "**Tokenizer**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "WjxNWW6bvdgj"
   },
   "outputs": [],
   "source": [
    "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
    "tokenizer.pad_token = tokenizer.eos_token"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "1pg2U-B3wbIK"
   },
   "source": [
    "**Functions**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ZvljDKdji8iV"
   },
   "outputs": [],
   "source": [
    "def generate_dataset(topic, number_of_data, inst1, resp1, inst2, resp2, inst3, resp3):\n",
    "    # Convert user inputs into multi-shot examples\n",
    "    multi_shot_examples = [\n",
    "        {\"instruction\": inst1, \"response\": resp1},\n",
    "        {\"instruction\": inst2, \"response\": resp2},\n",
    "        {\"instruction\": inst3, \"response\": resp3}\n",
    "    ]\n",
    "\n",
    "    # System prompt\n",
    "    system_prompt = f\"\"\"\n",
    "    You are a helpful assistant whose main purpose is to generate datasets.\n",
    "    Topic: {topic}\n",
    "    Return the dataset in JSON format. Use examples with simple, fun, and easy-to-understand instructions for kids.\n",
    "    Include the following examples: {multi_shot_examples}\n",
    "    Return {number_of_data} examples each time.\n",
    "    Do not repeat the provided examples.\n",
    "    \"\"\"\n",
    "\n",
    "    # Example Messages\n",
    "    messages = [\n",
    "        {\"role\": \"system\", \"content\": system_prompt},\n",
    "        {\"role\": \"user\", \"content\": f\"Please generate my dataset for {topic}\"}\n",
    "    ]\n",
    "\n",
    "    # Tokenize Input\n",
    "    inputs = tokenizer.apply_chat_template(messages, return_tensors=\"pt\").to(\"cuda\")\n",
    "    streamer = TextStreamer(tokenizer)\n",
    "\n",
    "    # Generate Output\n",
    "    outputs = model.generate(inputs, max_new_tokens=2000, streamer=streamer)\n",
    "\n",
    "    # Decode and Return\n",
    "    return tokenizer.decode(outputs[0], skip_special_tokens=True)\n",
    "\n",
    "\n",
    "def gradio_interface(topic, number_of_data, inst1, resp1, inst2, resp2, inst3, resp3):\n",
    "    return generate_dataset(topic, number_of_data, inst1, resp1, inst2, resp2, inst3, resp3)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "_WDZ5dvRwmng"
   },
   "source": [
    "**Default Values**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "JAdfqYXnvEDE"
   },
   "outputs": [],
   "source": [
    "default_topic = \"Talking to a (5-8) years old and teaching them manners.\"\n",
    "default_number_of_data = 10\n",
    "default_multi_shot_examples = [\n",
    "    {\n",
    "        \"instruction\": \"Why do I have to say please when I want something?\",\n",
    "        \"response\": \"Because it’s like magic! It shows you’re nice, and people want to help you more.\"\n",
    "    },\n",
    "    {\n",
    "        \"instruction\": \"What should I say if someone gives me a toy?\",\n",
    "        \"response\": \"You say, 'Thank you!' because it makes them happy you liked it.\"\n",
    "    },\n",
    "    {\n",
    "        \"instruction\": \"why should I listen to my parents?\",\n",
    "        \"response\": \"Because parents want the best for you and they love you the most.\"\n",
    "    }\n",
    "]"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "JwZtD032wuK8"
   },
   "source": [
    "**Init gradio**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "xy2RP5T-vxXg"
   },
   "outputs": [],
   "source": [
    "gr_interface = gr.Interface(\n",
    "    fn=gradio_interface,\n",
    "    inputs=[\n",
    "        gr.Textbox(label=\"Topic\", value=default_topic, lines=2),\n",
    "        gr.Number(label=\"Number of Examples\", value=default_number_of_data, precision=0),\n",
    "        gr.Textbox(label=\"Instruction 1\", value=default_multi_shot_examples[0][\"instruction\"]),\n",
    "        gr.Textbox(label=\"Response 1\", value=default_multi_shot_examples[0][\"response\"]),\n",
    "        gr.Textbox(label=\"Instruction 2\", value=default_multi_shot_examples[1][\"instruction\"]),\n",
    "        gr.Textbox(label=\"Response 2\", value=default_multi_shot_examples[1][\"response\"]),\n",
    "        gr.Textbox(label=\"Instruction 3\", value=default_multi_shot_examples[2][\"instruction\"]),\n",
    "        gr.Textbox(label=\"Response 3\", value=default_multi_shot_examples[2][\"response\"]),\n",
    "    ],\n",
    "    outputs=gr.Textbox(label=\"Generated Dataset\")\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "HZx-mm9Uw3Ph"
   },
   "source": [
    "**Run the app**"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "bfGs5ip8mndg"
   },
   "outputs": [],
   "source": [
    "gr_interface.launch()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Cveqx392x7Mm"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "GPU",
  "colab": {
   "gpuType": "T4",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.11"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}