You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
 
 

102 lines
2.7 KiB

# imports
import os
from dotenv import load_dotenv
from huggingface_hub import login
from datasets import load_dataset, Dataset, DatasetDict
import matplotlib.pyplot as plt
import json
# Environment: load variables from a .env file (if one is found) into the
# process environment, then make sure each API key variable is populated.
load_dotenv()
# Per-variable fallbacks. An already-set environment value always wins; the
# fallback string is only stored when the variable is absent. Replace a
# fallback with a real key here if you are not using a .env file.
_API_KEY_FALLBACKS = {
    'OPENAI_API_KEY': 'your-key-if-not-using-env',
    'ANTHROPIC_API_KEY': 'your-key-if-not-using-env',
    'HF_TOKEN': 'your-key-if-not-using-env',
}
for _env_name, _fallback in _API_KEY_FALLBACKS.items():
    os.environ[_env_name] = os.getenv(_env_name, _fallback)
# Log in to HuggingFace with whatever HF_TOKEN ended up in the environment;
# add_to_git_credential=True also registers the token with git's credential helper.
hf_token = os.environ['HF_TOKEN']
login(hf_token, add_to_git_credential=True)
# One more import - the Item class
# If running this import gives an error saying you must agree to Meta's terms,
# follow the link in the error message and complete Meta's access request.
# Approval from Meta usually arrives within minutes.
# Any problems - message me or email me!
from items import Item
# Load in our dataset
# Location of the training data. Set TRAINING_DATA_PATH to point elsewhere so
# the script is not tied to one machine's home directory; the default is the
# original hard-coded path.
training_data_path = os.getenv(
    'TRAINING_DATA_PATH',
    '/home/ivob/Projects/llm_engineering/project/data/training_data.json',
)
# Open and read the JSON file. JSON text is UTF-8, so state the encoding
# explicitly instead of relying on the platform's locale default.
with open(training_data_path, 'r', encoding='utf-8') as file:
    # Expected to be a sequence of records: later code indexes it with [2]
    # and reads "input" / "result" / "reason" keys from each entry.
    dataset = json.load(file)
# Print the data
print(dataset)
print(f"Number of Situations: {len(dataset):,}")
# Investigate a particular datapoint: print the "input", "result" and
# "reason" fields of the third record (index 2), one field per line.
datapoint = dataset[2]
for key_name in ("input", "result", "reason"):
    print(datapoint[key_name])
# Plot the distribution of results
# Tally how many datapoints are labelled "normal" vs "anomalous"
# (case-insensitively). Any other label, a missing "result" key, or a
# non-string value such as JSON null is left out of the tally.
situation_counts = {"normal": 0, "anomalous": 0}
for entry in dataset:
    raw_result = entry.get("result")
    # Guard before .lower(): a null or non-string result would otherwise
    # raise AttributeError and abort the whole script.
    if not isinstance(raw_result, str):
        continue
    result = raw_result.lower()
    if result in situation_counts:
        situation_counts[result] += 1
# Extract keys and values for the bar chart; dict insertion order keeps the
# labels, counts and the two bar colours below aligned.
labels = list(situation_counts.keys())
counts = list(situation_counts.values())
# Plot the bar chart
plt.figure(figsize=(8, 6))
plt.bar(labels, counts, color=['green', 'red'], alpha=0.7)
# Add labels and title
plt.xlabel("Situation Type", fontsize=12)
plt.ylabel("Count", fontsize=12)
plt.title("Number of Normal vs Anomalous Situations", fontsize=14)
# Annotate bars with counts (categorical bars sit at x = 0, 1, ...)
for i, count in enumerate(counts):
    plt.text(i, count + 0.2, str(count), ha='center', fontsize=10)
# Display the plot
plt.tight_layout()
plt.show()
# So what are the anomalous items??
# Print the "input" of every datapoint whose result is exactly "anomalous".
# dict.get() yields None for a datapoint with no "result" key, so such
# records are simply skipped instead of raising KeyError mid-loop.
for datapoint in dataset:
    result = datapoint.get("result")
    if result == "anomalous":
        print(datapoint['input'])
# Create an Item object for each with a result
# Only datapoints labelled exactly "normal" or "anomalous" become Items. A
# missing "result" key (dict.get -> None) or any other label is skipped rather
# than raising KeyError. Items whose `include` flag is falsy are dropped.
LABELLED_RESULTS = ("normal", "anomalous")
items = []
rejected = 0
for datapoint in dataset:
    result = datapoint.get("result")
    if result not in LABELLED_RESULTS:
        continue
    try:
        item = Item(datapoint, result)
        if item.include:
            items.append(item)
    except ValueError:
        # Best-effort: skip datapoints that Item refuses, but count them so
        # the drop is visible. NOTE(review): assumes Item signals unusable
        # input with ValueError - confirm in items.py.
        rejected += 1
print(f"There are {len(items):,} items")
if rejected:
    print(f"Skipped {rejected:,} datapoints rejected by Item (ValueError)")