{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# Synthetic Dataset Generator\n",
        "\n",
        "Loads a 4-bit quantized Llama 3.1 8B Instruct model and serves a Gradio form that generates instruction/response examples for a topic, guided by three seed examples."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "kU2JrcPlhwd9"
      },
      "outputs": [],
      "source": [
        "# %pip (rather than !pip) installs into the environment of the running kernel.\n",
        "%pip install -q requests torch bitsandbytes transformers sentencepiece accelerate gradio"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "lAMIVT4iwNg0"
      },
      "source": [
        "**Imports**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "-Apd7-p-hyLk"
      },
      "outputs": [],
      "source": [
        "import json\n",
        "import os\n",
        "import requests\n",
        "from google.colab import drive\n",
        "from huggingface_hub import login\n",
        "from google.colab import userdata\n",
        "from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig\n",
        "import torch\n",
        "import gradio as gr\n",
        "\n",
        "# The Hugging Face token is read from Colab's user secrets instead of being hardcoded.\n",
        "hf_token = userdata.get('HF_TOKEN')\n",
        "login(hf_token, add_to_git_credential=True)"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "xa0qYqZrwQ66"
      },
      "source": [
        "**Model**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "z5enGmuKjtJu"
      },
      "outputs": [],
      "source": [
        "model_name = \"meta-llama/Meta-Llama-3.1-8B-Instruct\"\n",
        "\n",
        "# 4-bit NF4 quantization with double quantization; the compute dtype is bfloat16.\n",
        "quant_config = BitsAndBytesConfig(\n",
        "    load_in_4bit=True,\n",
        "    bnb_4bit_use_double_quant=True,\n",
        "    bnb_4bit_compute_dtype=torch.bfloat16,\n",
        "    bnb_4bit_quant_type=\"nf4\"\n",
        ")\n",
        "\n",
        "# device_map=\"auto\" lets the library decide where the weights are placed.\n",
        "model = AutoModelForCausalLM.from_pretrained(\n",
        "    model_name,\n",
        "    device_map=\"auto\",\n",
        "    quantization_config=quant_config\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "y1hUSmWlwSbp"
      },
      "source": [
        "**Tokenizer**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "WjxNWW6bvdgj"
      },
      "outputs": [],
      "source": [
        "tokenizer = AutoTokenizer.from_pretrained(model_name)\n",
        "# The EOS token doubles as the padding token.\n",
        "tokenizer.pad_token = tokenizer.eos_token"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "1pg2U-B3wbIK"
      },
      "source": [
        "**Functions**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "ZvljDKdji8iV"
      },
      "outputs": [],
      "source": [
        "MAX_NEW_TOKENS = 2000  # default cap on tokens generated per request\n",
        "\n",
        "\n",
        "def generate_dataset(topic, number_of_data, inst1, resp1, inst2, resp2, inst3, resp3,\n",
        "                     max_new_tokens=MAX_NEW_TOKENS):\n",
        "    \"\"\"Ask the loaded model for a synthetic instruction/response dataset.\n",
        "\n",
        "    Args:\n",
        "        topic: Subject the generated examples should cover.\n",
        "        number_of_data: How many examples to request from the model.\n",
        "        inst1, resp1, inst2, resp2, inst3, resp3: Three seed instruction/response\n",
        "            pairs shown to the model as multi-shot examples.\n",
        "        max_new_tokens: Cap on the number of newly generated tokens.\n",
        "\n",
        "    Returns:\n",
        "        The model's reply decoded to text, without the prompt.\n",
        "\n",
        "    Raises:\n",
        "        ValueError: If ``topic`` is blank or ``number_of_data`` is not a positive whole number.\n",
        "    \"\"\"\n",
        "    # Validate up front: a cleared form field would otherwise end up in the prompt as \"None\".\n",
        "    if not topic or not str(topic).strip():\n",
        "        raise ValueError(\"Topic must not be empty.\")\n",
        "    try:\n",
        "        number_of_data = int(number_of_data)\n",
        "    except (TypeError, ValueError):\n",
        "        raise ValueError(\"Number of examples must be a whole number.\") from None\n",
        "    if number_of_data < 1:\n",
        "        raise ValueError(\"Number of examples must be at least 1.\")\n",
        "\n",
        "    # Convert user inputs into multi-shot examples\n",
        "    multi_shot_examples = [\n",
        "        {\"instruction\": inst1, \"response\": resp1},\n",
        "        {\"instruction\": inst2, \"response\": resp2},\n",
        "        {\"instruction\": inst3, \"response\": resp3}\n",
        "    ]\n",
        "    # Serialize as real JSON (not a Python repr) so the examples match the requested output format.\n",
        "    examples_json = json.dumps(multi_shot_examples, ensure_ascii=False, indent=2)\n",
        "\n",
        "    # System prompt\n",
        "    system_prompt = (\n",
        "        \"You are a helpful assistant whose main purpose is to generate datasets.\\n\"\n",
        "        f\"Topic: {topic}\\n\"\n",
        "        \"Return the dataset in JSON format. Use examples with simple, fun, and easy-to-understand instructions for kids.\\n\"\n",
        "        f\"Follow the style and format of these examples: {examples_json}\\n\"\n",
        "        f\"Return {number_of_data} examples each time.\\n\"\n",
        "        \"Do not repeat the provided examples.\\n\"\n",
        "    )\n",
        "\n",
        "    # Example Messages\n",
        "    messages = [\n",
        "        {\"role\": \"system\", \"content\": system_prompt},\n",
        "        {\"role\": \"user\", \"content\": f\"Please generate my dataset for {topic}\"}\n",
        "    ]\n",
        "\n",
        "    # Tokenize Input\n",
        "    # add_generation_prompt appends the assistant header so the model answers instead of\n",
        "    # continuing the user turn; return_dict also provides the attention mask.\n",
        "    inputs = tokenizer.apply_chat_template(\n",
        "        messages,\n",
        "        add_generation_prompt=True,\n",
        "        return_tensors=\"pt\",\n",
        "        return_dict=True\n",
        "    ).to(model.device)\n",
        "    # skip_prompt keeps the prompt out of the text streamed to the cell output.\n",
        "    streamer = TextStreamer(tokenizer, skip_prompt=True)\n",
        "\n",
        "    # Generate Output\n",
        "    outputs = model.generate(\n",
        "        **inputs,\n",
        "        max_new_tokens=max_new_tokens,\n",
        "        streamer=streamer,\n",
        "        pad_token_id=tokenizer.pad_token_id\n",
        "    )\n",
        "\n",
        "    # Decode and Return\n",
        "    # generate() returns prompt + completion in one sequence; drop the prompt tokens.\n",
        "    prompt_length = inputs[\"input_ids\"].shape[-1]\n",
        "    return tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)\n",
        "\n",
        "\n",
        "def gradio_interface(topic, number_of_data, inst1, resp1, inst2, resp2, inst3, resp3):\n",
        "    \"\"\"Gradio callback: pass the form values to generate_dataset and show input errors in the UI.\"\"\"\n",
        "    try:\n",
        "        return generate_dataset(topic, number_of_data, inst1, resp1, inst2, resp2, inst3, resp3)\n",
        "    except ValueError as exc:\n",
        "        # gr.Error displays the message to the user instead of a generic failure.\n",
        "        raise gr.Error(str(exc)) from exc"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "_WDZ5dvRwmng"
      },
      "source": [
        "**Default Values**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "JAdfqYXnvEDE"
      },
      "outputs": [],
      "source": [
        "# Values pre-filled in the Gradio form.\n",
        "default_topic = \"Talking to a (5-8) years old and teaching them manners.\"\n",
        "default_number_of_data = 10\n",
        "default_multi_shot_examples = [\n",
        "    {\n",
        "        \"instruction\": \"Why do I have to say please when I want something?\",\n",
        "        \"response\": \"Because it’s like magic! It shows you’re nice, and people want to help you more.\"\n",
        "    },\n",
        "    {\n",
        "        \"instruction\": \"What should I say if someone gives me a toy?\",\n",
        "        \"response\": \"You say, 'Thank you!' because it makes them happy you liked it.\"\n",
        "    },\n",
        "    {\n",
        "        \"instruction\": \"why should I listen to my parents?\",\n",
        "        \"response\": \"Because parents want the best for you and they love you the most.\"\n",
        "    }\n",
        "]"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "JwZtD032wuK8"
      },
      "source": [
        "**Init gradio**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "xy2RP5T-vxXg"
      },
      "outputs": [],
      "source": [
        "# The input order must match the parameter order of gradio_interface.\n",
        "gr_interface = gr.Interface(\n",
        "    fn=gradio_interface,\n",
        "    inputs=[\n",
        "        gr.Textbox(label=\"Topic\", value=default_topic, lines=2),\n",
        "        gr.Number(label=\"Number of Examples\", value=default_number_of_data, precision=0),\n",
        "        gr.Textbox(label=\"Instruction 1\", value=default_multi_shot_examples[0][\"instruction\"]),\n",
        "        gr.Textbox(label=\"Response 1\", value=default_multi_shot_examples[0][\"response\"]),\n",
        "        gr.Textbox(label=\"Instruction 2\", value=default_multi_shot_examples[1][\"instruction\"]),\n",
        "        gr.Textbox(label=\"Response 2\", value=default_multi_shot_examples[1][\"response\"]),\n",
        "        gr.Textbox(label=\"Instruction 3\", value=default_multi_shot_examples[2][\"instruction\"]),\n",
        "        gr.Textbox(label=\"Response 3\", value=default_multi_shot_examples[2][\"response\"]),\n",
        "    ],\n",
        "    outputs=gr.Textbox(label=\"Generated Dataset\")\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HZx-mm9Uw3Ph"
      },
      "source": [
        "**Run the app**"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "bfGs5ip8mndg"
      },
      "outputs": [],
      "source": [
        "gr_interface.launch()"
      ]
    }
  ],
  "metadata": {
    "accelerator": "GPU",
    "colab": {
      "gpuType": "T4",
      "provenance": []
    },
    "kernelspec": {
      "display_name": "Python 3 (ipykernel)",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.11.11"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}