import json
import os
import matplotlib.pyplot as plt

# The project root is the parent of the directory that holds this script;
# every input/output location below is derived from it.
_SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
ROOT_DIR = os.path.dirname(_SCRIPT_DIR)

OUTPUTS_DIR = os.path.join(ROOT_DIR, "outputs")
PLOTS_DIR = os.path.join(OUTPUTS_DIR, "plots")

# Make sure the chart destination exists; an already-existing directory is fine.
os.makedirs(PLOTS_DIR, exist_ok=True)

# Load KPIs.
# JSON text is UTF-8, so decode explicitly instead of relying on the
# platform's locale encoding (e.g. cp1252 on Windows), which can fail on or
# mis-decode non-ASCII merchant/category names.
with open(os.path.join(OUTPUTS_DIR, "kpis.json"), "r", encoding="utf-8") as f:
    kpis = json.load(f)["kpis"]

# Load categorized transactions. They are iterated later as records with
# "category", "merchant" and "amount" keys.
with open(os.path.join(OUTPUTS_DIR, "categorized.json"), "r", encoding="utf-8") as f:
    categorized = json.load(f)["categorized"]

# -------- FIGURE 1: Income vs Spending --------
# Two bars comparing the aggregate spend and income KPIs, drawn on an
# explicit Figure/Axes pair rather than pyplot's implicit "current" figure.
fig, ax = plt.subplots()
ax.bar(["Total Spend", "Total Income"], [kpis["total_spend"], kpis["total_income"]])
ax.set_title("Income vs Spending")
ax.set_ylabel("Amount ($)")
fig.savefig(os.path.join(PLOTS_DIR, "income_vs_spend.png"))
plt.close(fig)

# -------- FIGURE 2: Top Merchants --------
# Sum the amounts of every non-income transaction per merchant.
merchant_totals = {}

for entry in categorized:
    if entry["category"] != "Income":
        merchant = entry["merchant"]
        merchant_totals[merchant] = merchant_totals.get(merchant, 0) + entry["amount"]

top_merchants = kpis["top_merchants"]
# kpis.json is produced separately, so a listed merchant is not guaranteed to
# have non-income entries here; plot it as 0 rather than abort with KeyError.
amounts = [merchant_totals.get(m, 0) for m in top_merchants]

fig, ax = plt.subplots()
ax.bar(top_merchants, amounts)
ax.set_title("Top Merchant Spend Totals")
ax.set_ylabel("Amount ($)")
# Merchant names tend to be long: rotate the tick labels so they don't
# overlap, and tighten the layout so the rotated labels aren't clipped.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
fig.tight_layout()
fig.savefig(os.path.join(PLOTS_DIR, "top_merchants.png"))
plt.close(fig)

# -------- FIGURE 3: Spending by Category (Bar Chart) --------
# Sum the amounts of every non-income transaction per category.
# category_totals stays a plain dict because the pie chart reuses it.
category_totals = {}

for entry in categorized:
    cat = entry["category"]
    if cat != "Income":
        category_totals[cat] = category_totals.get(cat, 0) + entry["amount"]

fig, ax = plt.subplots()
# Hand matplotlib concrete lists instead of dict views.
ax.bar(list(category_totals), list(category_totals.values()))
ax.set_title("Spending by Category")
ax.set_ylabel("Amount ($)")
# Rotate category names so neighbouring tick labels don't overlap, and
# tighten the layout so the rotated labels aren't clipped in the PNG.
plt.setp(ax.get_xticklabels(), rotation=45, ha="right")
fig.tight_layout()
fig.savefig(os.path.join(PLOTS_DIR, "category_spend.png"))
plt.close(fig)

# -------- FIGURE 4: Spending by Category (Pie Chart) --------
# Same per-category totals as the bar chart, shown as shares of the whole
# with one-decimal percentage labels on each wedge.
category_names = list(category_totals)
category_amounts = [category_totals[name] for name in category_names]

fig, ax = plt.subplots()
ax.pie(category_amounts, labels=category_names, autopct='%1.1f%%')
ax.set_title("Spending Distribution by Category")
fig.savefig(os.path.join(PLOTS_DIR, "category_spend_pie.png"))
plt.close(fig)

# Tell the user where the generated PNG files were written.
print("All charts created successfully in:", PLOTS_DIR, sep=" ")
