From the Udemy course on LLM engineering.
https://www.udemy.com/course/llm-engineering-master-ai-and-large-language-models
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
37 lines
1.2 KiB
37 lines
1.2 KiB
# imports |
|
|
|
import os |
|
import re |
|
from typing import List |
|
from sentence_transformers import SentenceTransformer |
|
import joblib |
|
from agents.agent import Agent |
|
|
|
|
|
|
|
class RandomForestAgent(Agent):
    """
    Agent that estimates the price of an item from its text description.

    The description is turned into a vector with a SentenceTransformer
    encoder, and that vector is passed to a model loaded with joblib from
    'random_forest_model.pkl'.
    """

    name = "Random Forest Agent"
    color = Agent.MAGENTA

    def __init__(self):
        """
        Initialize this object by loading in the saved model weights
        and the SentenceTransformer vector encoding model
        """
        self.log("Random Forest Agent is initializing")
        self.vectorizer = SentenceTransformer('sentence-transformers/all-MiniLM-L6-v2')
        # The path is relative, so it is resolved against the current working
        # directory of the process, not the location of this file.
        # NOTE(review): joblib.load unpickles the file; only load model files
        # that come from a trusted source.
        self.model = joblib.load('random_forest_model.pkl')
        self.log("Random Forest Agent is ready")

    def price(self, description: str) -> float:
        """
        Use a Random Forest model to estimate the price of the described item

        :param description: the product to be estimated
        :return: the price as a float; negative predictions are clamped to 0.0
        """
        self.log("Random Forest Agent is starting a prediction")
        # The encoder and the model are both called with a batch of one item,
        # so the single prediction is taken from index 0.
        vector = self.vectorizer.encode([description])
        prediction = self.model.predict(vector)[0]
        # Convert to a native float and clamp with a float literal so the
        # return value is always a float, including in the clamped case
        # (max(0, ...) would return the int 0 there).
        result = max(0.0, float(prediction))
        self.log(f"Random Forest Agent completed - predicting ${result:.2f}")
        return result