In [None]:
# getting the latest transformers first, since this will require a restart

!pip install git+https://github.com/huggingface/transformers.git

In [None]:
# imports

import torch
from google.colab import userdata
from huggingface_hub import login
from transformers import AutoProcessor, AutoModelForImageTextToText
from google.colab import files

In [None]:
# Log in to the Hugging Face Hub using the token saved in Colab's "Secrets" panel under
# the name HF_TOKEN (the notebook must be granted access to that secret).
# add_to_git_credential is left at its default (False): nothing in this notebook talks to
# the Hub over git, so there is no reason to persist the secret in git's credential store.
# NOTE(review): stepfun-ai/GOT-OCR-2.0-hf is presumably a public repo, in which case this
# login is optional — confirm.

hf_token = userdata.get('HF_TOKEN')
login(token=hf_token)

In [None]:
# Show Colab's file-upload widget and wait for the browser to send the chosen files.
# The result is a dict keyed by filename whose values are the files' raw bytes.
uploaded = files.upload()

# List the uploaded filenames, e.g. dict_keys(['note2.jpg']); the next cell needs one of them.
print(uploaded.keys())

In [None]:
# Choose which uploaded image to OCR and make sure it is on disk under its own filename.
#
# `uploaded` (from the previous cell) maps each uploaded filename to that file's raw bytes.
# open(image_path, "wb") creates the file, or truncates it if it already exists, and the
# bytes are written in binary mode, so afterwards `image_path` names a file holding the image.

image_path = "note2.jpg"  # set this to one of the names printed by the previous cell
if image_path not in uploaded and len(uploaded) == 1:
    # Only one file was uploaded, so there is no ambiguity: use it rather than failing
    # on a stale hard-coded name.
    image_path = next(iter(uploaded))
if image_path not in uploaded:
    # Same exception type as a plain dict lookup, but with a message that says how to fix it.
    raise KeyError(
        f"{image_path!r} was not uploaded; set image_path to one of {sorted(uploaded)}"
    )

image_bytes = uploaded[image_path]
with open(image_path, "wb") as f:
    f.write(image_bytes)

In [None]:
# Load the GOT-OCR 2.0 checkpoint and its processor, as shown in the model's instructions
# on Hugging Face. The model goes on the GPU when one is available, otherwise on the CPU.
checkpoint = "stepfun-ai/GOT-OCR-2.0-hf"

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"

processor = AutoProcessor.from_pretrained(checkpoint)
model = AutoModelForImageTextToText.from_pretrained(checkpoint, device_map=device)

In [None]:
# Run OCR on the uploaded image, following the usage example on
# https://huggingface.co/stepfun-ai/GOT-OCR-2.0-hf
#
# The file path is handed straight to the processor rather than rebinding `image`
# (which previously held the raw bytes) to a string.
inputs = processor(image_path, return_tensors="pt").to(device)

# do_sample=False -> greedy decoding, so the same image always yields the same text.
# stop_strings is matched on decoded text, which is why the tokenizer is passed as well.
# max_new_tokens caps how long the transcription may get.
ocr = model.generate(
    **inputs,
    do_sample=False,
    tokenizer=processor.tokenizer,
    stop_strings="<|im_end|>",
    max_new_tokens=4096,
)

In [None]:
# Print the recognized text. The output of generate() starts with the prompt tokens, so the
# first input_ids.shape[1] positions of the first sequence are dropped before decoding.
# Author's note: this reads my handwriting pretty well, and runs quickly on Colab's free T4 GPU.
prompt_length = inputs["input_ids"].shape[1]
new_token_ids = ocr[0][prompt_length:]
print(processor.decode(new_token_ids, skip_special_tokens=True))