##############################################################################
# MIT License
#
# Copyright (c) 2021 - 2025 Advanced Micro Devices, Inc. All Rights Reserved.
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT.  IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN
# THE SOFTWARE.

##############################################################################

import csv
from dataclasses import dataclass
from pathlib import Path
from typing import Any, Optional, Union

import pandas as pd

from utils import schema
from utils.logger import console_debug, console_error, console_warning
from utils.parser import apply_filters, eval_metric
from utils.specs import MachineSpecs

################################################
# Global vars
################################################
# XMIN is the x coordinate where calc_ceilings starts every bandwidth line;
# XMAX mirrors the default right edge (1000) of the compute roofs.
XMIN = 0.01
XMAX = 1000

# Font attributes returned by get_font()
FONT_SIZE = 16
FONT_COLOR = "black"
FONT_WEIGHT = "bold"

# Datatypes shared by every architecture listed in SUPPORTED_DATATYPES.
_COMMON_DATATYPES = ("FP16", "BF16", "FP32", "FP64", "I8", "I32", "I64")

# SUPPORTED_DATATYPES table is based on datatype support in rocm-amdgpu-bench
# repository. Indicates which datatypes per gpu arch can be generated by the
# roofline binary. Every architecture gets its own list object.
SUPPORTED_DATATYPES: dict[str, list[str]] = {
    # Unsupported: F4, F6, F8
    "gfx90a": [*_COMMON_DATATYPES],
    # Unsupported: F4, F6
    "gfx940": ["FP8", *_COMMON_DATATYPES],
    # Unsupported: F4, F6
    "gfx941": ["FP8", *_COMMON_DATATYPES],
    # Unsupported: F4, F6
    "gfx942": ["FP8", *_COMMON_DATATYPES],
    # Unsupported: (none)
    "gfx950": ["FP4", "FP6", "FP8", *_COMMON_DATATYPES],
}

# Datatypes for which calc_ceilings reads a "<dtype>Flops"/"<dtype>Ops" peak
PEAK_OPS_DATATYPES = ["FP16", "FP32", "FP64", "I8", "I32", "I64"]
# Datatypes for which calc_ceilings reads an "MFMA..." peak
MFMA_DATATYPES = ["FP4", "FP6", "FP8", "FP16", "BF16", "FP32", "FP64", "I8"]
# Memory levels whose "<level>Bw" benchmark columns are read when mem_level is "ALL"
CACHE_HIERARCHY = ["HBM", "L2", "L1", "LDS"]

# Maximum number of kernels/dispatches reported by calc_ai_profile
TOP_N = 10


################################################
# Data structures
################################################
@dataclass
class AI_Data:
    """Counter-derived totals for one kernel (or one dispatch).

    Instances are created positionally by calc_ai_profile, so the field order
    below is part of the contract. When calc_ai_profile runs with
    sort_type == "kernels" every counter field holds the per-call average of
    the kernel's dispatches; with sort_type == "dispatches" the fields hold
    the raw values of a single dispatch.
    """

    # Kernel name taken from the "Kernel_Name" column
    KernelName: str
    # Number of dispatch rows folded into this record
    numCalls: float

    # VALU + MFMA floating-point operation count
    total_flops: float
    # VALU-only floating-point operation count (F16/F32/F64 ADD/MUL/FMA/TRANS)
    valu_flops: float
    # MFMA operation counts per precision (SQ_INSTS_VALU_MFMA_MOPS_* x 512)
    mfma_flops_f6f4: float
    mfma_flops_f8: float
    mfma_flops_f16: float
    mfma_flops_bf16: float
    mfma_flops_f32: float
    mfma_flops_f64: float
    mfma_iops_i8: float
    # Data volume moved at each memory level, derived from request counters
    # scaled by request size (presumably bytes -- TODO confirm units)
    lds_data: float
    L1cache_data: float
    L2cache_data: float
    hbm_data: float

    # Sum of (End_Timestamp - Start_Timestamp) over the folded dispatches
    totalDuration: float
    # totalDuration divided by the number of calls in "kernels" mode
    avgDuration: float


@dataclass
class PlotPoints:
    """Application points overlaid on a roofline plot.

    Each ``ai_*`` attribute is a two-element list ``[x_values, y_values]``;
    ``kernelNames`` holds the labels collected alongside them.
    """

    ai_l1: list[list[float]]
    ai_l2: list[list[float]]
    ai_hbm: list[list[float]]
    kernelNames: list[str]

    @classmethod
    def empty(cls) -> "PlotPoints":
        """Return a PlotPoints whose series and name list are all empty."""

        def blank_series() -> list[list[float]]:
            # A fresh [x, y] pair per call so no two series share list objects.
            return [[], []]

        return cls(
            ai_l1=blank_series(),
            ai_l2=blank_series(),
            ai_hbm=blank_series(),
            kernelNames=[],
        )


@dataclass
class GraphPoints:
    """Ceiling lines of a roofline graph, one entry per memory/compute roof.

    Every attribute is a list; ``empty()`` fills each with three ``None``
    placeholders.
    """

    hbm: list[Union[list[float], float, None]]
    l2: list[Union[list[float], float, None]]
    l1: list[Union[list[float], float, None]]
    lds: list[Union[list[float], float, None]]
    valu: list[Union[list[float], float, None]]
    mfma: list[Union[list[float], float, None]]

    @classmethod
    def empty(cls) -> "GraphPoints":
        """Return a GraphPoints where every roof is ``[None, None, None]``."""
        roof_names = ("hbm", "l2", "l1", "lds", "valu", "mfma")
        # The comprehension builds a distinct placeholder list for each roof.
        return cls(**{roof: [None, None, None] for roof in roof_names})


################################################
# Helper functions
################################################
def get_font() -> dict[str, Union[int, str]]:
    """Return the font attributes built from the module's FONT_* constants.

    The mapping has the keys ``size``, ``color``, ``weight`` and ``family``;
    the family is always ``"serif"``.
    """
    font: dict[str, Union[int, str]] = dict(
        size=FONT_SIZE,
        color=FONT_COLOR,
        weight=FONT_WEIGHT,
    )
    font["family"] = "serif"
    return font


def get_color(category: str) -> str:
    """Return the colour name associated with an arithmetic-intensity series.

    Args:
        category: One of ``"ai_l1"``, ``"ai_l2"`` or ``"ai_hbm"``.

    Raises:
        RuntimeError: If ``category`` is not one of the known series names.
    """
    palette = {"ai_l1": "green", "ai_l2": "blue", "ai_hbm": "red"}

    colour = palette.get(category)
    if colour is None:
        raise RuntimeError(f"Invalid category passed to get_color(): {category}")
    return colour


# -------------------------------------------------------------------------------------
#                           Plot BW at each cache level
# -------------------------------------------------------------------------------------
def calc_ceilings(
    roofline_parameters: dict[str, Any],
    dtype: str,
    benchmark_data: dict[str, list[str]],
    ai_data: Optional[dict] = None,
) -> dict[str, list[Union[list[float], float, None]]]:
    """Compute the ceilings (peak performance lines) of an empirical roofline.

    Args:
        roofline_parameters: Must provide ``"mem_level"`` (``"ALL"`` or an
            iterable of memory-level names) and ``"device_id"`` (index into
            each benchmark column).
        dtype: Datatype name such as ``"FP32"`` or ``"I8"``.
        benchmark_data: Column name -> list of per-device values read from
            roofline.csv.
        ai_data: Optional application points; when given, the largest x value
            found under ``ai_l1``/``ai_l2``/``ai_hbm`` (times 1.2) replaces the
            default right edge ``XMAX`` of the compute roofs.

    Returns:
        A dict keyed by ``hbm``, ``l2``, ``l1``, ``lds``, ``valu`` and
        ``mfma``. Each populated entry is
        ``[[x_start, x_end], [y_start, y_end], peak]``; entries without data
        stay ``[]``. If the peak-ops column for ``dtype`` cannot be read, the
        ``GraphPoints.empty()`` placeholder dict is returned instead.
    """
    # Failures that all mean "this benchmark value is unusable": the column is
    # absent, the device row is absent, or the cell is not a number.
    lookup_errors = (KeyError, IndexError, ValueError)

    # Right edge of the compute roofs: stretch past the largest application
    # point when there is one, otherwise fall back to the module default.
    dynamic_xmax = XMAX
    if ai_data:
        max_ai = 0
        for ai_key in ("ai_l1", "ai_l2", "ai_hbm"):
            if ai_key in ai_data and ai_data[ai_key][0]:
                max_ai = max(max_ai, max(ai_data[ai_key][0]))
        if max_ai > 0:
            dynamic_xmax = max_ai * 1.2

    # TODO: This is where filtering by memory level will need to occur for standalone
    graph_points: dict[str, list[Union[list[float], float, None]]] = {
        "hbm": [],
        "l2": [],
        "l1": [],
        "lds": [],
        "valu": [],
        "mfma": [],
    }

    cache_hierarchy = (
        CACHE_HIERARCHY
        if roofline_parameters["mem_level"] == "ALL"
        else roofline_parameters["mem_level"]
    )

    # -1 marks "not computed"; the values are carried over between cache levels.
    x1 = y1 = x2 = y2 = -1
    x1_mfma = y1_mfma = x2_mfma = y2_mfma = -1

    ops_flops = "Ops" if dtype.startswith("I") else "Flops"

    peak_ops = 0.0
    if dtype in PEAK_OPS_DATATYPES:
        try:
            peak_ops = float(
                benchmark_data[f"{dtype}{ops_flops}"][roofline_parameters["device_id"]]
            )
        except lookup_errors:
            console_warning(
                f"Missing benchmark data for {dtype}{ops_flops} in benchmark_results. "
                "Skipping peak operations calculation. This may indicate incomplete or "
                "corrupted benchmark data."
            )
            return GraphPoints.empty().__dict__

    # Bound before the loop so the MFMA roof below is well defined even when
    # no cache level could be processed (empty hierarchy or every bandwidth
    # column missing).
    peak_mfma = 0.0

    for cache_level in cache_hierarchy:
        # Plot BW line
        curr_bw = f"{cache_level}Bw"
        try:
            peak_bw = float(benchmark_data[curr_bw][roofline_parameters["device_id"]])
        except lookup_errors:
            console_warning(
                f"Missing benchmark data for {curr_bw} in benchmark_results. "
                f"Skipping {cache_level} cache level. This may indicate incomplete or "
                "corrupted benchmark data."
            )
            continue

        x1 = float(XMIN)
        y1 = float(XMIN) * peak_bw

        if dtype in PEAK_OPS_DATATYPES:
            # Ridge point: where the bandwidth line meets the VALU roof.
            x2 = peak_ops / peak_bw
            y2 = peak_ops  # noqa

            # Plot MFMA lines (NOTE: Assuming MI200 soc)
            x1_mfma = peak_ops / peak_bw
            y1_mfma = peak_ops

        if dtype in MFMA_DATATYPES:
            # NOTE(review): dtype[2:] maps "BF16" to "F16", so BF16 reads the
            # MFMAF16Flops column -- confirm this is intended.
            target_precision = dtype if dtype.startswith("I") else f"F{dtype[2:]}"

            try:
                peak_mfma = float(
                    benchmark_data[f"MFMA{target_precision}{ops_flops}"][
                        roofline_parameters["device_id"]
                    ]
                )
                x2_mfma = peak_mfma / peak_bw
                y2_mfma = peak_mfma
            except lookup_errors:
                console_warning(
                    f"Missing benchmark data for "
                    f"MFMA{target_precision}{ops_flops} in benchmark_results. "
                    f"Skipping MFMA calculations for {cache_level} cache level. "
                    "This may indicate incomplete or corrupted benchmark data."
                )

        # The bandwidth line ends at whichever compute roof is higher.
        if y2_mfma > y1_mfma:  # peak_mfma
            peak_x = x2_mfma
            peak_y = y2_mfma
        else:  # peakVALU
            peak_x = x1_mfma
            peak_y = y1_mfma

        cache_key = cache_level.lower()
        graph_points[cache_key].extend([[x1, peak_x], [y1, peak_y], peak_bw])

    # ----------------------------------------------------------------------------------
    #                                     Plot computing roof
    # ----------------------------------------------------------------------------------
    if dtype in PEAK_OPS_DATATYPES:
        # Plot FMA roof: starts at the last ridge point, clamped to the right edge.
        x0 = x2 if x2 < dynamic_xmax else dynamic_xmax

        graph_points["valu"].extend([
            [x0, dynamic_xmax],
            [peak_ops, peak_ops],
            peak_ops,
        ])

    # Plot MFMA roof
    if dtype in MFMA_DATATYPES:
        x0_mfma = x2_mfma if x2_mfma < dynamic_xmax else dynamic_xmax

        graph_points["mfma"].extend([
            [x0_mfma, dynamic_xmax],
            [peak_mfma, peak_mfma],
            peak_mfma,
        ])

    return graph_points


# -------------------------------------------------------------------------------------
#                              Overlay application performance
# -------------------------------------------------------------------------------------
# Calculate relevant metrics for ai calculation
def calc_ai_analyze(
    workload: schema.Workload,
    mspec: MachineSpecs,
    sort_type: str,
    config: dict[str, Any],
    arch_config: schema.ArchConfig,
) -> dict[str, Union[list[list[float]], list[str]]]:
    """Evaluate the roofline metric tables per kernel and collect plot points.

    For every kernel to process, tables 401 and 402 of ``arch_config`` are
    copied, evaluated with ``eval_metric`` against that kernel's PMC rows, and
    the "AI HBM", "AI L2", "AI L1" and "Performance (GFLOPs)" rows of table
    402 are read back.

    Side effects:
        ``workload.roofline_metrics`` is reset and then filled with one entry
        per evaluated kernel (``name``, ``ai_table``, ``calc_table``).

    Returns:
        The ``__dict__`` of a PlotPoints instance (keys ``ai_l1``, ``ai_l2``,
        ``ai_hbm``, ``kernelNames``).

    ``mspec`` and ``sort_type`` are accepted but not used by this function.
    """
    console_debug("calc_ai_analyze", "Starting calc_ai analysis using Roofline yamls")

    def usable_or_zero(raw: Any) -> Any:
        # Empty, falsy and "N/A" cells count as "no value".
        return raw if raw and raw not in ("", "N/A") else 0

    points = PlotPoints.empty()

    workload.roofline_metrics = {}
    pmc_filtered = apply_filters(workload, workload.path, is_gui=False, debug=False)

    top_table_id = 1
    kernel_ids: list[int] = []

    # An explicit kernel filter wins; otherwise take every row of the top table.
    if workload.filter_kernel_ids:
        kernel_ids = workload.filter_kernel_ids
    elif top_table_id in workload.dfs:
        kernel_ids = workload.dfs[top_table_id].index.tolist()
        console_debug("roofline", f"Found {len(kernel_ids)} kernels to process")

    if not kernel_ids:
        console_warning("No kernels found to process for roofline")
        return points.__dict__

    for kernel_id in kernel_ids:
        # Kernel names can only be resolved through the top-kernel table.
        if top_table_id not in workload.dfs:
            continue
        top_kernels = workload.dfs[top_table_id]
        if kernel_id not in top_kernels.index:
            continue
        kernel_name = top_kernels.loc[kernel_id, "Kernel_Name"]

        console_debug("roofline", f"Processing kernel {kernel_id}: {kernel_name[:50]}")

        # Restrict the PMC data to the rows of this kernel
        name_matches = pmc_filtered["pmc_perf"]["Kernel_Name"] == kernel_name
        per_kernel_pmc = pmc_filtered[name_matches]

        if per_kernel_pmc.empty:
            console_debug("roofline", f"No PMC data for kernel {kernel_id}")
            continue

        # Work on private copies so each kernel starts from unevaluated tables.
        metric_tables: dict[int, pd.DataFrame] = {}
        metric_table_types: dict[int, str] = {}
        for table_id in (401, 402):
            if table_id in arch_config.dfs:
                metric_tables[table_id] = arch_config.dfs[table_id].copy()
                metric_table_types[table_id] = arch_config.dfs_type[table_id]

        # Evaluate the metrics for this single kernel only
        eval_metric(
            metric_tables,
            metric_table_types,
            workload.sys_info.iloc[0],
            workload.roofline_peaks,
            {"pmc_perf": per_kernel_pmc["pmc_perf"]},
            debug=False,
            config=config,
        )

        # Metric label -> extracted value (0 until a matching row is seen)
        extracted: dict[str, Any] = {
            "AI HBM": 0,
            "AI L2": 0,
            "AI L1": 0,
            "Performance (GFLOPs)": 0,
        }
        if 402 in metric_tables:
            for _, row in metric_tables[402].iterrows():
                label = row.get("Metric", "")
                raw_value = row.get("Value", 0)
                for wanted in extracted:
                    if label == wanted:
                        extracted[wanted] = usable_or_zero(raw_value)
                        break

        ai_hbm = extracted["AI HBM"]
        ai_l2 = extracted["AI L2"]
        ai_l1 = extracted["AI L1"]
        performance = extracted["Performance (GFLOPs)"]

        summary = (
            f"Kernel {kernel_id}: AI_HBM={ai_hbm:.2f}, AI_L2={ai_l2:.2f}, "
            f"AI_L1={ai_l1:.2f}, Performance={performance:.2e} GFLOP/s"
        )
        console_debug("roofline", summary)

        if performance > 0:
            # NOTE(review): a series only receives a point when its intensity
            # is positive, while kernelNames receives one label per kernel, so
            # the series can be shorter than kernelNames -- confirm consumers
            # do not pair them by index.
            for intensity, series in (
                (ai_hbm, points.ai_hbm),
                (ai_l2, points.ai_l2),
                (ai_l1, points.ai_l1),
            ):
                if intensity > 0:
                    series[0].append(intensity)
                    series[1].append(performance)

            points.kernelNames.append(f"K{kernel_id}")
            console_debug("roofline", f"Added kernel {kernel_id} to plot points")
        else:
            console_debug(
                "roofline", f"Skipping kernel {kernel_id} - no performance data"
            )

        # Keep the evaluated tables for display
        workload.roofline_metrics[kernel_id] = {
            "name": kernel_name,
            "ai_table": metric_tables.get(401, pd.DataFrame()),
            "calc_table": metric_tables.get(402, pd.DataFrame()),
        }

    console_debug("roofline", f"Generated {len(points.kernelNames)} plot points")
    return points.__dict__


def _row_valu_flops(df: pd.DataFrame, idx: int) -> float:
    """Return the VALU FLOP count of dispatch row ``idx``.

    For each of F16/F32/F64 the ADD, MUL, TRANS and twice the FMA instruction
    counters are summed and scaled by 64 (presumably the wavefront width --
    TODO confirm). Raises KeyError if any of the counter columns is absent.
    """
    return sum(
        64
        * (
            df[f"SQ_INSTS_VALU_ADD_{precision}"][idx]
            + df[f"SQ_INSTS_VALU_MUL_{precision}"][idx]
            + (2 * df[f"SQ_INSTS_VALU_FMA_{precision}"][idx])
            + df[f"SQ_INSTS_VALU_TRANS_{precision}"][idx]
        )
        for precision in ("F16", "F32", "F64")
    )


def _row_hbm_data(df: pd.DataFrame, idx: int, gpu_series: str) -> float:
    """Return the HBM traffic of dispatch row ``idx`` for the given GPU series.

    Request counters are weighted by their request size (32/64/128); the
    counter set differs between "MI200", "MI350" and every other series.
    Raises KeyError if a required counter column is absent.
    """
    if gpu_series == "MI200":
        return (
            (df["TCC_EA_RDREQ_32B_sum"][idx] * 32)
            + ((df["TCC_EA_RDREQ_sum"][idx] - df["TCC_EA_RDREQ_32B_sum"][idx]) * 64)
            + (df["TCC_EA_WRREQ_64B_sum"][idx] * 64)
            + ((df["TCC_EA_WRREQ_sum"][idx] - df["TCC_EA_WRREQ_64B_sum"][idx]) * 32)
        )
    if gpu_series == "MI350":
        # Use TCC_EA0_RDREQ_128B_sum TCC_EA0_RDREQ_64B_sum to calculate hbm_data
        return (
            (df["TCC_EA0_RDREQ_128B_sum"][idx] * 128)
            + (df["TCC_EA0_RDREQ_64B_sum"][idx] * 64)
            + (df["TCC_EA0_RDREQ_32B_sum"][idx] * 32)
            + ((df["TCC_EA0_WRREQ_sum"][idx] - df["TCC_EA0_WRREQ_64B_sum"][idx]) * 32)
            + (df["TCC_EA0_WRREQ_64B_sum"][idx] * 64)
        )
    # Use TCC_BUBBLE_sum to calculate hbm_data
    return (
        (df["TCC_BUBBLE_sum"][idx] * 128)
        + (df["TCC_EA0_RDREQ_32B_sum"][idx] * 32)
        + (
            (
                df["TCC_EA0_RDREQ_sum"][idx]
                - df["TCC_BUBBLE_sum"][idx]
                - df["TCC_EA0_RDREQ_32B_sum"][idx]
            )
            * 64
        )
        + ((df["TCC_EA0_WRREQ_sum"][idx] - df["TCC_EA0_WRREQ_64B_sum"][idx]) * 32)
        + (df["TCC_EA0_WRREQ_64B_sum"][idx] * 64)
    )


def _accumulate_dispatch(
    df: pd.DataFrame,
    idx: int,
    mspec: "MachineSpecs",
    supported_dt: list[str],
    totals: dict[str, float],
    kernel_name: str,
) -> None:
    """Add the counters of dispatch row ``idx`` to ``totals`` in place.

    Each counter group is best effort: when a column it needs is missing, the
    group is skipped with a debug message and the other groups still run.
    """

    def mfma_ops(suffix: str) -> float:
        return df[f"SQ_INSTS_VALU_MFMA_MOPS_{suffix}"][idx] * 512

    def report_skip(what: str, err: KeyError) -> None:
        console_debug(
            "roofline",
            f"{kernel_name[:35]}: Skipped {what} at index {idx} due to {err}",
        )

    has_f8 = "FP8" in supported_dt
    has_f6f4 = ("FP4" in supported_dt) or ("FP6" in supported_dt)

    try:
        totals["total_flops"] += (
            _row_valu_flops(df, idx)
            + mfma_ops("F16")
            + mfma_ops("BF16")
            + mfma_ops("F32")
            + mfma_ops("F64")
        )
        if has_f8:
            totals["total_flops"] += mfma_ops("F8")
        if has_f6f4:
            totals["total_flops"] += mfma_ops("F6F4")
    except KeyError as e:
        report_skip("total_flops", e)

    try:
        totals["valu_flops"] += _row_valu_flops(df, idx)
    except KeyError as e:
        report_skip("valu_flops", e)

    try:
        # Sequential on purpose: a missing column only drops the precisions
        # that come after it.
        if has_f8:
            totals["mfma_flops_f8"] += mfma_ops("F8")
        if has_f6f4:
            totals["mfma_flops_f6f4"] += mfma_ops("F6F4")
        totals["mfma_flops_f16"] += mfma_ops("F16")
        totals["mfma_flops_bf16"] += mfma_ops("BF16")
        totals["mfma_flops_f32"] += mfma_ops("F32")
        totals["mfma_flops_f64"] += mfma_ops("F64")
        totals["mfma_iops_i8"] += mfma_ops("I8")
    except KeyError as e:
        report_skip("mfma ops", e)

    try:
        totals["lds_data"] += (
            (df["SQ_LDS_IDX_ACTIVE"][idx] - df["SQ_LDS_BANK_CONFLICT"][idx])
            * 4
            * (mspec.lds_banks_per_cu)
        )
    except KeyError as e:
        report_skip("lds_data", e)

    try:
        totals["L1cache_data"] += df["TCP_TOTAL_CACHE_ACCESSES_sum"][idx] * 64
    except KeyError as e:
        report_skip("L1cache_data", e)

    try:
        totals["L2cache_data"] += (
            df["TCP_TCC_WRITE_REQ_sum"][idx] * 64
            + df["TCP_TCC_ATOMIC_WITH_RET_REQ_sum"][idx] * 64
            + df["TCP_TCC_ATOMIC_WITHOUT_RET_REQ_sum"][idx] * 64
            + df["TCP_TCC_READ_REQ_sum"][idx] * 64
        )
    except KeyError as e:
        report_skip("L2cache_data", e)

    try:
        totals["hbm_data"] += _row_hbm_data(df, idx, mspec.gpu_series)
    except KeyError as e:
        report_skip("hbm_data", e)


def calc_ai_profile(
    mspec: MachineSpecs,
    sort_type: str,
    ret_df: dict[str, pd.DataFrame],
    iteration_multiplexing: str,
) -> dict[str, Union[list[list[float]], list[str]]]:
    """Given counter data, calculate arithmetic intensity for each kernel
    in the application. Leverage hard-coded equations to calculate AI values.

    Used during profiling stage to generate roofline HTML, since Roofline yamls
    are not available in the profiling stage.

    Args:
        mspec: Machine description; ``gpu_arch``, ``gpu_series`` and
            ``lds_banks_per_cu`` are read.
        sort_type: ``"kernels"`` averages all dispatches of a kernel into one
            record; ``"dispatches"`` produces one record per dispatch. Any
            other value produces no records.
        ret_df: Must contain the ``"pmc_perf"`` counter table.
        iteration_multiplexing: Accepted for interface compatibility; not used
            by this function.

    Returns:
        A dict with ``ai_l1``, ``ai_l2`` and ``ai_hbm`` (each ``[x, y]`` with
        x = arithmetic intensity and y = total_flops / avgDuration) and
        ``kernelNames``, covering the TOP_N records with the largest total
        duration. All lists are index-aligned; a zero is stored when the
        denominator is zero.

    Dispatch rows containing any NaN are excluded from the totals. Such a row
    still closes its kernel group, so the valid dispatches before it are
    reported under their own kernel instead of leaking into the next one.
    """
    console_debug(
        "calc_ai_profile: Starting legacy roofline calculation (from roofline_calc)"
    )
    df = ret_df["pmc_perf"]
    # Sort by top kernels or top dispatches?
    df = df.sort_values(by=["Kernel_Name"]).reset_index(drop=True)

    # Unknown architectures get no optional (FP8/FP6/FP4) counters instead of
    # failing on a membership test against None.
    supported_dt = SUPPORTED_DATATYPES.get(mspec.gpu_arch, [])

    counter_fields = (
        "total_flops",
        "valu_flops",
        "mfma_flops_f6f4",
        "mfma_flops_f8",
        "mfma_flops_f16",
        "mfma_flops_bf16",
        "mfma_flops_f32",
        "mfma_flops_f64",
        "mfma_iops_i8",
        "lds_data",
        "L1cache_data",
        "L2cache_data",
        "hbm_data",
    )
    totals: dict[str, float] = dict.fromkeys(counter_fields, 0.0)
    calls = 0.0
    duration = 0.0

    my_list: list[AI_Data] = []
    num_rows = df.shape[0]

    for idx in df.index:
        kernel_name = df["Kernel_Name"][idx]
        at_end = idx + 1 == num_rows
        next_kernel_name = df["Kernel_Name"][idx + 1] if not at_end else ""
        # The table is sorted by name, so a name change ends the kernel group.
        group_done = at_end or (kernel_name != next_kernel_name)

        # Leave this dispatch out of the totals if any counter value is n/a
        row_usable = not df.iloc[idx].isna().any()
        if row_usable:
            _accumulate_dispatch(df, idx, mspec, supported_dt, totals, kernel_name)
            duration += df["End_Timestamp"][idx] - df["Start_Timestamp"][idx]
            calls += 1

        record = None
        if sort_type == "kernels":
            # calls > 0 guards groups whose dispatches were all unusable.
            if group_done and calls > 0:
                record = AI_Data(
                    KernelName=kernel_name,
                    numCalls=calls,
                    totalDuration=duration,
                    avgDuration=duration / calls,
                    **{name: value / calls for name, value in totals.items()},
                )
        elif sort_type == "dispatches" and row_usable:
            record = AI_Data(
                KernelName=kernel_name,
                numCalls=calls,
                totalDuration=duration,
                avgDuration=duration,
                **totals,
            )

        if record is not None:
            my_list.append(record)
            if sort_type == "kernels":
                console_debug(
                    f"Just added {kernel_name} to AI_Data. # of calls: {calls}"
                )
            # Start the next kernel/dispatch from a clean slate.
            totals = dict.fromkeys(counter_fields, 0.0)
            calls = 0.0
            duration = 0.0

    my_list.sort(key=lambda entry: entry.totalDuration, reverse=True)

    intensities: dict[str, list[float]] = {"ai_l1": [], "ai_l2": [], "ai_hbm": []}
    curr_perf: list[float] = []
    kernel_names: list[str] = []

    # Create list of top N intensities
    for kernel_data in my_list[:TOP_N]:
        if kernel_data.total_flops == 0:
            console_debug(
                f"No flops counted for {kernel_data.KernelName}, "
                "arithmetic intensities will not display on plots."
            )

        kernel_names.append(kernel_data.KernelName)

        # Calculate arithmetic intensities (0 when no traffic was recorded)
        for ai_type, traffic in (
            ("ai_l1", kernel_data.L1cache_data),
            ("ai_l2", kernel_data.L2cache_data),
            ("ai_hbm", kernel_data.hbm_data),
        ):
            intensities[ai_type].append(
                kernel_data.total_flops / traffic if traffic else 0
            )
        curr_perf.append(
            kernel_data.total_flops / kernel_data.avgDuration
            if kernel_data.avgDuration
            else 0
        )

    # Create intensity points for plotting; every series gets its own copy of
    # the performance values.
    intensity_points: dict[str, Union[list[list[float]], list[str]]] = {
        ai_type: [values, list(curr_perf)] for ai_type, values in intensities.items()
    }

    # Add kernel names
    intensity_points["kernelNames"] = kernel_names
    return intensity_points


def validate_roofline_csv(workload_dir: Union[str, Path, list]) -> tuple[bool, str]:
    """
    Validate roofline.csv exists and has consistent structure.

    Args:
        workload_dir: Directory containing roofline.csv, or a list whose first
            element is that directory (or a list/tuple starting with it).

    Returns:
        tuple: (is_valid, error_message)
               is_valid=True if CSV is valid, False otherwise
               error_message contains description if invalid

    A missing directory (``None`` or an empty list) is reported as invalid
    rather than raising.
    """
    # Normalize workload_dir to a single base directory
    base_dir = workload_dir
    if isinstance(base_dir, list):
        base_dir = base_dir[0] if base_dir else None
        if isinstance(base_dir, (list, tuple)):
            base_dir = base_dir[0] if base_dir else None
    if base_dir is None:
        return False, "No workload directory provided; cannot locate roofline.csv"

    benchmark_results = Path(base_dir) / "roofline.csv"

    # Check if file exists
    if not benchmark_results.exists():
        return False, f"Benchmark results file not found: {benchmark_results}"

    # Validate CSV structure: every row must be as wide as the header row.
    row_total = 0
    expected_width = 0
    try:
        # newline="" lets the csv module do its own newline handling.
        with open(benchmark_results, newline="") as csvfile:
            for row_total, row in enumerate(csv.reader(csvfile, delimiter=","), 1):
                if row_total == 1:
                    expected_width = len(row)
                    # The first column is the device id; at least one data
                    # column must follow it.
                    if expected_width <= 1:
                        return (
                            False,
                            "Empty or invalid header row in benchmark_results",
                        )
                elif len(row) != expected_width:
                    return (
                        False,
                        f"Inconsistent row length in benchmark_results at "
                        f"row {row_total}. "
                        f"Expected {expected_width} columns, "
                        f"found {len(row)}. "
                        "Roofline data appears corrupted or incomplete.",
                    )
    except (OSError, ValueError, csv.Error) as e:
        # ValueError also covers undecodable bytes (UnicodeDecodeError).
        return False, f"Failed to read benchmark_results: {e}"

    if row_total < 2:
        return (
            False,
            f"Insufficient data in benchmark_results. "
            f"Found {row_total} rows (need at least 2)."
            f" Roofline data appears corrupted or incomplete.",
        )

    return True, ""


def construct_roof(
    roofline_parameters: dict[str, Any], dtype: str, ai_data: Optional[dict] = None
) -> dict[str, list[Union[list[float], float, None]]]:
    """Load roofline.csv for a workload and compute the roofline ceilings.

    Args:
        roofline_parameters: Provides ``"workload_dir"`` (a directory, or a
            list whose first element is the directory or a list/tuple starting
            with it), ``"mem_level"`` and the values calc_ceilings reads.
        dtype: Datatype name such as ``"FP32"`` or ``"I8"``.
        ai_data: Optional application points, forwarded to calc_ceilings.

    Returns:
        The result of calc_ceilings. If the workload directory is missing or
        roofline.csv cannot be read, an error is logged and the
        ``GraphPoints.empty()`` placeholder dict is returned.
    """
    # Normalize workload_dir to extract base directory
    base_dir = roofline_parameters.get("workload_dir")
    if isinstance(base_dir, list):
        base_dir = base_dir[0] if base_dir else None
        if isinstance(base_dir, (list, tuple)):
            base_dir = base_dir[0] if base_dir else None
    if base_dir is None:
        console_error(
            "roofline",
            "No workload directory provided; cannot locate roofline.csv",
            exit=False,
        )
        return GraphPoints.empty().__dict__

    benchmark_results = Path(base_dir) / "roofline.csv"

    # Initialize benchmark data dictionary from roofline.csv:
    # column header -> list of values, one per data row
    benchmark_data: dict[str, list[str]] = {}
    headers: list[str] = []

    try:
        # newline="" lets the csv module do its own newline handling.
        with open(benchmark_results, newline="") as csvfile:
            for row_number, row in enumerate(csv.reader(csvfile, delimiter=",")):
                row.pop(0)  # Remove first column (Device ID)
                if row_number == 0:
                    headers = row
                    for header in headers:
                        benchmark_data[header] = []
                else:
                    for column, header in enumerate(headers):
                        benchmark_data[header].append(row[column])
    except (OSError, ValueError, csv.Error, IndexError) as e:
        # OSError: file missing/unreadable; ValueError: undecodable bytes;
        # IndexError: a blank row or a row shorter than the header.
        console_error(
            "roofline",
            f"Failed to read benchmark results from {base_dir}: {e}",
            exit=False,
        )
        return GraphPoints.empty().__dict__

    # ------------------
    #  Validate benchmark data completeness
    # ------------------
    ops_flops = "Ops" if dtype.startswith("I") else "Flops"
    expected_columns: list[str] = []

    if dtype in PEAK_OPS_DATATYPES:
        expected_columns.append(f"{dtype}{ops_flops}")

    cache_hierarchy = (
        CACHE_HIERARCHY
        if roofline_parameters["mem_level"] == "ALL"
        else roofline_parameters["mem_level"]
    )
    expected_columns.extend(f"{cache_level}Bw" for cache_level in cache_hierarchy)

    if dtype in MFMA_DATATYPES:
        target_precision = dtype if dtype.startswith("I") else f"F{dtype[2:]}"
        expected_columns.append(f"MFMA{target_precision}{ops_flops}")

    # Missing columns only trigger a warning; calc_ceilings skips them itself.
    missing_columns = [col for col in expected_columns if col not in benchmark_data]
    if missing_columns:
        console_warning(
            f"Missing expected columns in roofline.csv for datatype {dtype}: "
            f"{', '.join(missing_columns)}. "
            "The roofline plot may be incomplete. Consider regenerating "
            "benchmark data or cleaning the directory and re-running the analysis."
        )

    # ------------------
    #  Generate Roofline
    # ------------------
    return calc_ceilings(roofline_parameters, dtype, benchmark_data, ai_data)
