{
  "nbformat": 4,
  "nbformat_minor": 0,
  "metadata": {
    "colab": {
      "provenance": [],
      "gpuType": "T4",
      "authorship_tag": "ABX9TyPtAT7Yq5xd4vDcJEZtg69J"
    },
    "kernelspec": {
      "name": "python3",
      "display_name": "Python 3"
    },
    "language_info": {
      "name": "python"
    },
    "accelerator": "GPU"
  },
  "cells": [
    {
      "cell_type": "code",
      "source": [
        "# Install the latest development version of transformers straight from GitHub.\n",
        "# Run this cell first, since the installation requires a runtime restart before the imports below.\n",
        "# %pip (instead of !pip) installs into the environment of the running kernel.\n",
        "\n",
        "%pip install git+https://github.com/huggingface/transformers.git"
      ],
      "metadata": {
        "id": "6gGKXU5RXORf"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# imports\n",
        "\n",
        "import torch\n",
        "from google.colab import userdata\n",
        "from huggingface_hub import login\n",
        "from transformers import AutoProcessor, AutoModelForImageTextToText\n",
        "from google.colab import files"
      ],
      "metadata": {
        "id": "yCRrF4aiXPPo"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Authenticate with the Hugging Face Hub.\n",
        "# The access token is fetched through google.colab's userdata under the name 'HF_TOKEN',\n",
        "# so it never has to be written into the notebook itself.\n",
        "\n",
        "hf_token = userdata.get('HF_TOKEN')\n",
        "login(token=hf_token, add_to_git_credential=True)"
      ],
      "metadata": {
        "id": "AAlOQuCbXcrv"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "_RRVc2j2Vun-"
      },
      "outputs": [],
      "source": [
        "# Starts an input prompt for uploading local files.\n",
        "# The printout lists the keys of the uploaded files and looks something like dict_keys(['note2.jpg']).\n",
        "\n",
        "uploaded = files.upload()\n",
        "uploaded_names = uploaded.keys()\n",
        "print(uploaded_names)"
      ]
    },
    {
      "cell_type": "code",
      "source": [
        "# The dictionary of uploaded files (see the previous cell) uses the file name as key;\n",
        "# the value stored under that key contains the image in binary format.\n",
        "# image_path is taken from the upload itself instead of being hard-coded, so this cell\n",
        "# does not fail with a KeyError when the uploaded file is not called \"note2.jpg\".\n",
        "# open(image_path, \"wb\") creates a new file with that name on the server and opens it in binary write mode (\"wb\");\n",
        "# f.write(image) writes the binary image to it, so image_path then points to a file containing the image.\n",
        "\n",
        "if not uploaded:\n",
        "  raise ValueError(\"No file was uploaded - re-run the upload cell and select an image.\")\n",
        "\n",
        "image_path = next(iter(uploaded)) # first uploaded file; set a specific key from the previous printout if several files were uploaded\n",
        "image = uploaded[image_path]\n",
        "with open(image_path, \"wb\") as f:\n",
        "  f.write(image)"
      ],
      "metadata": {
        "id": "V_UAuSSkXBKh"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Model setup, taken from the Hugging Face model instructions.\n",
        "# Use the GPU when CUDA is available, otherwise fall back to the CPU.\n",
        "if torch.cuda.is_available():\n",
        "  device = \"cuda\"\n",
        "else:\n",
        "  device = \"cpu\"\n",
        "\n",
        "# The model and its processor are both loaded from the same checkpoint name.\n",
        "checkpoint = \"stepfun-ai/GOT-OCR-2.0-hf\"\n",
        "model = AutoModelForImageTextToText.from_pretrained(checkpoint, device_map=device)\n",
        "processor = AutoProcessor.from_pretrained(checkpoint)"
      ],
      "metadata": {
        "id": "AiFP-mQtXrpV"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Run the model on the image, following the Hugging Face documentation for this model:\n",
        "# https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf\n",
        "\n",
        "image = image_path\n",
        "inputs = processor(image, return_tensors=\"pt\")\n",
        "inputs = inputs.to(device)\n",
        "\n",
        "# Sampling is disabled; \"<|im_end|>\" is passed as stop string and at most 4096 new tokens are requested.\n",
        "generation_options = dict(\n",
        "    do_sample=False,\n",
        "    tokenizer=processor.tokenizer,\n",
        "    stop_strings=\"<|im_end|>\",\n",
        "    max_new_tokens=4096,\n",
        ")\n",
        "ocr = model.generate(**inputs, **generation_options)"
      ],
      "metadata": {
        "id": "7Adr8HB_YNf5"
      },
      "execution_count": null,
      "outputs": []
    },
    {
      "cell_type": "code",
      "source": [
        "# Print the recognized text. This reads my handwriting pretty well, and it runs very quickly on the free T4 GPU server here.\n",
        "# The first inputs[\"input_ids\"].shape[1] positions of ocr[0] are skipped (presumably the prompt tokens); only the rest is decoded.\n",
        "\n",
        "prompt_length = inputs[\"input_ids\"].shape[1]\n",
        "generated_ids = ocr[0, prompt_length:]\n",
        "print(processor.decode(generated_ids, skip_special_tokens=True))"
      ],
      "metadata": {
        "id": "nRsRUIIuYdJ9"
      },
      "execution_count": null,
      "outputs": []
    }
  ]
}