Source code for dsci_524_ezplot.plot_heatmap

import matplotlib.pyplot as plt
import seaborn as sns
import pandas as pd
import numpy as np

def plot_heatmap(data, title=None, cmap="viridis", xlabel=None, ylabel=None):
    """
    Create a heatmap using a pandas DataFrame or a 2D NumPy array.

    Parameters
    ----------
    data : pandas.DataFrame or numpy.ndarray
        Input data for the heatmap. Can be a pandas DataFrame or a 2D NumPy
        array. The data must be numeric.
    title : str, optional
        Title of the heatmap. Default is None.
    cmap : str, optional
        Colormap for the heatmap. Defaults to 'viridis'.
    xlabel : str, optional
        Label for the x-axis. Defaults to None, in which case no label is set.
    ylabel : str, optional
        Label for the y-axis. Defaults to None, in which case no label is set.

    Returns
    -------
    tuple
        - matplotlib.figure.Figure
            The figure object containing the heatmap.
        - matplotlib.axes.Axes
            The axes object containing the heatmap elements.

    Raises
    ------
    TypeError
        If the input data is not a pandas DataFrame or a NumPy array, or if
        it contains non-numeric data.
    ValueError
        If the input data is empty, or if a NumPy array is not 2-dimensional.

    Notes
    -----
    Missing values in a DataFrame are replaced by 0 on a copy before
    plotting; the caller's DataFrame is not modified. NaN values in a NumPy
    array are passed to seaborn unchanged.

    Examples
    --------
    Using a pandas DataFrame:

    >>> import pandas as pd
    >>> import numpy as np
    >>> df = pd.DataFrame(np.random.rand(5, 5), columns=['A', 'B', 'C', 'D', 'E'])
    >>> fig, ax = plot_heatmap(df, title="Sample Heatmap", xlabel="Columns", ylabel="Rows")

    Using a NumPy array:

    >>> arr = np.random.rand(5, 5)
    >>> fig, ax = plot_heatmap(arr, title="Heatmap from NumPy array")
    """
    # Validate input type
    if not isinstance(data, (pd.DataFrame, np.ndarray)):
        raise TypeError("Input data must be a pandas DataFrame or a numpy array.")

    is_frame = isinstance(data, pd.DataFrame)

    # All validation happens before a figure is created, so a rejected input
    # never leaves an orphaned figure behind in pyplot's global state.
    if is_frame:
        if data.empty:
            raise ValueError("DataFrame must not be empty.")
        if not all(pd.api.types.is_numeric_dtype(dtype) for dtype in data.dtypes):
            raise TypeError("All columns in the DataFrame must contain numeric data.")
    else:
        if data.size == 0:
            raise ValueError("NumPy array must not be empty.")
        if not np.issubdtype(data.dtype, np.number):
            raise TypeError("NumPy array must contain numeric data.")
        # A heatmap needs a matrix; fail with a clear message here instead of
        # an obscure error from deep inside the plotting call.
        if data.ndim != 2:
            raise ValueError(
                f"NumPy array must be 2-dimensional, got {data.ndim} dimension(s)."
            )

    # Handle NaN values (DataFrame only); fillna returns a new frame, so the
    # caller's object is left untouched.
    if is_frame and data.isnull().values.any():
        data = data.fillna(0)

    # Plot the heatmap
    fig, ax = plt.subplots()
    sns.heatmap(data, cmap=cmap, ax=ax)
    ax.set_title(title)
    if xlabel:
        ax.set_xlabel(xlabel)
    if ylabel:
        ax.set_ylabel(ylabel)

    return fig, ax