Source code for ecgdatakit.plotting.interactive

"""Interactive ECG plots using plotly.

All public functions return ``plotly.graph_objects.Figure``.

Requires: ``pip install ecgdatakit[plotting-interactive]``
"""

from __future__ import annotations

from typing import TYPE_CHECKING

import numpy as np
from numpy.typing import NDArray

from ecgdatakit.models import ECGRecord, Lead, LeadLike
from ecgdatakit.plotting._core import (
    GRID_12LEAD,
    STANDARD_12LEAD,
    _find_lead,
    _grid_shape,
    _resolve_leads,
    ensure_lead,
    lead_color,
    require_plotly,
    time_axis,
)

if TYPE_CHECKING:
    import plotly.graph_objects as go


def _x_data_i(lead, x_axis):
    """Return ``(x_array, xlabel, hover_template)`` based on *x_axis* mode."""
    if x_axis == "samples":
        x = np.arange(1, len(lead.samples) + 1)
        return x, "Sample", "Sample: %{x}<br>Amplitude: %{y:.3f}<extra></extra>"
    x = time_axis(lead)
    return x, "Time (s)", "Time: %{x:.3f}s<br>Amplitude: %{y:.3f}<extra></extra>"



[docs] def iplot_lead( lead: LeadLike, peaks: NDArray[np.intp] | None = None, title: str | None = None, height: int = 300, *, fs: int | None = None, show: bool = True, x_axis: str = "time", ) -> go.Figure: """Interactive single lead with hover showing time/amplitude. Parameters ---------- lead : Lead | NDArray[np.float64] ECG lead or raw signal array to plot. peaks : NDArray | None Optional R-peak indices to mark. title : str | None Figure title. height : int Figure height in pixels. fs : int | None Sample rate in Hz. Required when *lead* is a numpy array. show : bool Display the plot immediately (default ``True``). x_axis : str ``"time"`` for seconds (default) or ``"samples"`` for sample indices. """ lead = ensure_lead(lead, fs=fs) require_plotly() import plotly.graph_objects as go x, xlabel, hover_tpl = _x_data_i(lead, x_axis) fig = go.Figure() fig.add_trace(go.Scatter( x=x, y=lead.samples, mode="lines", name=lead.label, line=dict(color=lead_color(lead.label), width=1), hovertemplate=hover_tpl, )) if peaks is not None and len(peaks) > 0: if x_axis == "samples": peak_hover = "Sample: %{x}<br>Amplitude: %{y:.3f}<extra></extra>" else: peak_hover = "Peak at %{x:.3f}s<br>Amplitude: %{y:.3f}<extra></extra>" fig.add_trace(go.Scatter( x=x[peaks], y=lead.samples[peaks], mode="markers", name="R-peaks", marker=dict(color="red", size=8, symbol="triangle-down"), hovertemplate=peak_hover, )) fig.update_layout( title=title or lead.label, xaxis_title=xlabel, yaxis_title=f"Amplitude ({lead.units})" if lead.units else "Amplitude", height=height, template="plotly_white", xaxis=dict( rangeslider=dict(visible=True), showspikes=True, spikemode="across", spikethickness=1, ), yaxis=dict(showspikes=True, spikemode="across", spikethickness=1), hovermode="x unified", ) if show: fig.show() return fig
[docs] def iplot_leads( leads: list[Lead] | ECGRecord | NDArray[np.float64] | list[NDArray[np.float64]], peaks_dict: dict[str, NDArray[np.intp]] | None = None, title: str | None = None, height: int | None = None, *, fs: int | None = None, show: bool = True, x_axis: str = "time", rows: int | None = None, cols: int | None = None, ) -> go.Figure: """Interactive leads in a grid layout (vertical stack by default). Parameters ---------- leads : list[Lead] | ECGRecord | NDArray | list[NDArray] Leads to plot. Also accepts a 2-D numpy array (n_leads × n_samples) or a list of 1-D numpy arrays. peaks_dict : dict | None ``{label: peaks_array}`` for per-lead peak markers. title : str | None Overall title. height : int | None Figure height (auto-calculated if ``None``). fs : int | None Sample rate in Hz. Required when *leads* is a numpy array. show : bool Display the plot immediately (default ``True``). x_axis : str ``"time"`` for seconds (default) or ``"samples"`` for sample indices. rows : int | None Number of rows in the subplot grid. Derived from *cols* or defaults to one row per lead when neither is given. cols : int | None Number of columns in the subplot grid (default ``1``). """ require_plotly() import plotly.graph_objects as go from plotly.subplots import make_subplots lead_list, _ = _resolve_leads(leads, fs=fs) n = len(lead_list) if n == 0: return go.Figure() r, c = _grid_shape(n, rows, cols) h = height or max(300, 150 * r) subplot_titles = [ld.label for ld in lead_list] + [""] * (r * c - n) fig = make_subplots( rows=r, cols=c, shared_xaxes=True, subplot_titles=subplot_titles, vertical_spacing=max(0.02, 0.12 / r), ) for i, ld in enumerate(lead_list): ri, ci = divmod(i, c) x, _, _ = _x_data_i(ld, x_axis) fig.add_trace( go.Scatter( x=x, y=ld.samples, mode="lines", name=ld.label, line=dict(color=lead_color(ld.label), width=1), hovertemplate="%{y:.3f}<extra></extra>", ), row=ri + 1, col=ci + 1, ) if peaks_dict and ld.label in peaks_dict: pk = peaks_dict[ld.label] fig.add_trace( go.Scatter( x=x[pk], y=ld.samples[pk], mode="markers", name=f"{ld.label} peaks", marker=dict(color="red", size=6, symbol="triangle-down"), showlegend=False, ), row=ri + 1, col=ci + 1, ) xlabel = "Sample" if x_axis == "samples" else "Time (s)" fig.update_layout( title=title or "ECG Leads", height=h, template="plotly_white", hovermode="x unified", showlegend=True, ) # X-axis label on bottom row only for ci in range(1, c + 1): fig.update_xaxes(title_text=xlabel, row=r, col=ci) if show: fig.show() return fig
[docs] def iplot_12lead( leads: list[Lead] | ECGRecord | NDArray[np.float64] | list[NDArray[np.float64]], record: ECGRecord | None = None, title: str | None = None, height: int | None = None, *, fs: int | None = None, show: bool = True, x_axis: str = "time", rows: int | None = None, cols: int | None = None, ) -> go.Figure: """Interactive 12-lead plot with standard lead names. Unlike :func:`iplot_leads`, this function assigns the standard 12-lead names (I, II, III, aVR, …, V6) when the input contains unnamed leads. The full signal is plotted without cropping. Parameters ---------- leads : list[Lead] | ECGRecord | NDArray | list[NDArray] Leads (or record) to plot. Also accepts a 2-D numpy array (n_leads × n_samples) or a list of 1-D numpy arrays. record : ECGRecord | None Optional record for header annotations. title : str | None Overall figure title. height : int | None Figure height in pixels (auto-calculated if ``None``). fs : int | None Sample rate in Hz. Required when *leads* is a numpy array. show : bool Display the plot immediately (default ``True``). x_axis : str ``"time"`` for seconds (default) or ``"samples"`` for sample indices. rows : int | None Number of rows in the subplot grid. Derived from *cols* or defaults to one row per lead when neither is given. cols : int | None Number of columns in the subplot grid (default ``1``). """ require_plotly() import plotly.graph_objects as go from plotly.subplots import make_subplots lead_list, rec = _resolve_leads(leads, fs=fs) if record is not None: rec = record n = len(lead_list) if n == 0: return go.Figure() # Assign standard 12-lead names when leads are unnamed for i, ld in enumerate(lead_list): if i < len(STANDARD_12LEAD) and ld.label.startswith("Lead "): ld.label = STANDARD_12LEAD[i] r, c = _grid_shape(n, rows, cols) h = height or max(300, 150 * r) subplot_titles = [ld.label for ld in lead_list] + [""] * (r * c - n) fig = make_subplots( rows=r, cols=c, shared_xaxes=True, subplot_titles=subplot_titles, vertical_spacing=max(0.02, 0.12 / r), ) for i, ld in enumerate(lead_list): ri, ci = divmod(i, c) x, _, _ = _x_data_i(ld, x_axis) fig.add_trace( go.Scatter( x=x, y=ld.samples, mode="lines", name=ld.label, line=dict(color=lead_color(ld.label), width=1), showlegend=False, hovertemplate=f"{ld.label}<br>%{{x:.3f}}: %{{y:.3f}}<extra></extra>", ), row=ri + 1, col=ci + 1, ) if rec is not None: header_text = _build_header_text(rec) fig.add_annotation( text=header_text, xref="paper", yref="paper", x=0, y=1.08, showarrow=False, font=dict(size=10, family="monospace"), align="left", ) xlabel = "Sample" if x_axis == "samples" else "Time (s)" layout_opts: dict = dict( height=h, hovermode="closest", margin=dict(t=80 if rec or title else 40), template="plotly_white", ) if title: layout_opts["title_text"] = title fig.update_layout(**layout_opts) for ci in range(1, c + 1): fig.update_xaxes(title_text=xlabel, row=r, col=ci) if show: fig.show() return fig
def _build_header_text(record: ECGRecord) -> str: """Build compact header text for plotly annotations.""" parts = [] p = record.patient name = f"{p.first_name} {p.last_name}".strip() if name: parts.append(f"<b>{name}</b>") if p.patient_id: parts.append(f"ID: {p.patient_id}") if p.age is not None: parts.append(f"Age: {p.age}") if p.sex: parts.append(f"Sex: {p.sex}") m = record.measurements meas = [] if m.heart_rate is not None: meas.append(f"HR:{m.heart_rate}") if m.pr_interval is not None: meas.append(f"PR:{m.pr_interval}") if m.qrs_duration is not None: meas.append(f"QRS:{m.qrs_duration}") if m.qtc_bazett is not None: meas.append(f"QTc:{m.qtc_bazett}") if meas: parts.append(" | ".join(meas)) return " \u2014 ".join(parts) if parts else "ECG Report"
[docs] def iplot_peaks( lead: LeadLike, peaks: NDArray[np.intp] | None = None, title: str | None = None, height: int = 350, *, fs: int | None = None, show: bool = True, x_axis: str = "time", ) -> go.Figure: """Interactive lead with R-peak markers. Hover on peaks shows peak index, RR interval, and instantaneous HR. Parameters ---------- lead : Lead | NDArray[np.float64] ECG lead or raw signal array to plot. peaks : NDArray | None R-peak indices. Auto-detected if ``None``. fs : int | None Sample rate in Hz. Required when *lead* is a numpy array. show : bool Display the plot immediately (default ``True``). x_axis : str ``"time"`` for seconds (default) or ``"samples"`` for sample indices. """ from ecgdatakit.processing.peaks import detect_r_peaks lead = ensure_lead(lead, fs=fs) require_plotly() import plotly.graph_objects as go if peaks is None: peaks = detect_r_peaks(lead) x, xlabel, hover_tpl = _x_data_i(lead, x_axis) fig = go.Figure() fig.add_trace(go.Scatter( x=x, y=lead.samples, mode="lines", name=lead.label, line=dict(color=lead_color(lead.label), width=1), hovertemplate=hover_tpl, )) if len(peaks) > 0: hover_texts = [] for i, p in enumerate(peaks): parts = [f"Peak #{i}", f"Pos: {x[p]:.3f}" if x_axis == "time" else f"Sample: {x[p]}"] if i > 0: rr = (peaks[i] - peaks[i - 1]) / lead.sampling_rate * 1000 hr = 60_000.0 / rr if rr > 0 else 0 parts.append(f"RR: {rr:.0f}ms") parts.append(f"HR: {hr:.0f}bpm") hover_texts.append("<br>".join(parts)) fig.add_trace(go.Scatter( x=x[peaks], y=lead.samples[peaks], mode="markers", name="R-peaks", marker=dict(color="red", size=9, symbol="triangle-down"), hovertext=hover_texts, hoverinfo="text", )) fig.update_layout( title=title or f"{lead.label} \u2014 R-peaks", xaxis_title=xlabel, yaxis_title=f"Amplitude ({lead.units})" if lead.units else "Amplitude", height=height, template="plotly_white", xaxis=dict(rangeslider=dict(visible=True)), hovermode="closest", ) if show: fig.show() return fig
[docs] def iplot_spectrum( lead: LeadLike, method: str = "welch", height: int = 400, *, fs: int | None = None, show: bool = True, ) -> go.Figure: """Interactive spectrum with frequency band highlighting. Parameters ---------- lead : Lead | NDArray[np.float64] ECG lead or raw signal array. method : str ``"welch"`` for PSD or ``"fft"`` for magnitude spectrum. fs : int | None Sample rate in Hz. Required when *lead* is a numpy array. show : bool Display the plot immediately (default ``True``). """ from ecgdatakit.processing.transforms import fft as ecg_fft from ecgdatakit.processing.transforms import power_spectrum lead = ensure_lead(lead, fs=fs) require_plotly() import plotly.graph_objects as go fig = go.Figure() if method == "welch": freqs, power = power_spectrum(lead) power_db = 10 * np.log10(np.maximum(power, 1e-20)) fig.add_trace(go.Scatter( x=freqs, y=power_db, mode="lines", name="PSD", line=dict(color=lead_color(lead.label), width=1), hovertemplate="Freq: %{x:.2f} Hz<br>Power: %{y:.1f} dB/Hz<extra></extra>", )) y_label = "Power (dB/Hz)" title_suffix = "PSD (Welch)" else: freqs, mags = ecg_fft(lead) fig.add_trace(go.Scatter( x=freqs, y=mags, mode="lines", name="FFT", line=dict(color=lead_color(lead.label), width=1), hovertemplate="Freq: %{x:.2f} Hz<br>Magnitude: %{y:.4f}<extra></extra>", )) y_label = "Magnitude" title_suffix = "FFT Magnitude" nyquist = lead.sampling_rate / 2 fig.add_vrect(x0=0.05, x1=150, fillcolor="green", opacity=0.05, line_width=0, annotation_text="ECG band") fig.update_layout( title=f"{lead.label} \u2014 {title_suffix}", xaxis_title="Frequency (Hz)", yaxis_title=y_label, height=height, template="plotly_white", xaxis=dict(range=[0, min(nyquist, 250)]), hovermode="x unified", ) if show: fig.show() return fig
[docs] def iplot_rr_tachogram( rr_ms: NDArray[np.float64], height: int = 300, *, show: bool = True, ) -> go.Figure: """Interactive RR interval tachogram. Parameters ---------- rr_ms : NDArray RR intervals in milliseconds. show : bool Display the plot immediately (default ``True``). """ require_plotly() import plotly.graph_objects as go beats = np.arange(len(rr_ms)) mean_rr = float(rr_ms.mean()) std_rr = float(rr_ms.std()) fig = go.Figure() fig.add_trace(go.Scatter( x=beats, y=rr_ms, mode="lines+markers", name="RR intervals", line=dict(color="#1f77b4", width=1), marker=dict(size=3), hovertemplate="Beat #%{x}<br>RR: %{y:.0f} ms<extra></extra>", )) fig.add_hline(y=mean_rr, line_dash="dash", line_color="red", annotation_text=f"Mean: {mean_rr:.0f} ms") fig.add_hline(y=mean_rr + std_rr, line_dash="dot", line_color="orange", annotation_text=f"+SD: {std_rr:.0f} ms") fig.add_hline(y=mean_rr - std_rr, line_dash="dot", line_color="orange") fig.update_layout( title="RR Tachogram", xaxis_title="Beat number", yaxis_title="RR interval (ms)", height=height, template="plotly_white", hovermode="x unified", ) if show: fig.show() return fig
[docs] def iplot_poincare( rr_ms: NDArray[np.float64], height: int = 500, *, show: bool = True, ) -> go.Figure: """Interactive Poincar\u00e9 plot with SD1/SD2 ellipse. Parameters ---------- rr_ms : NDArray RR intervals in milliseconds. show : bool Display the plot immediately (default ``True``). """ require_plotly() import plotly.graph_objects as go fig = go.Figure() if len(rr_ms) < 2: fig.add_annotation(text="Need >= 2 RR intervals", xref="paper", yref="paper", x=0.5, y=0.5, showarrow=False) return fig x = rr_ms[:-1] y = rr_ms[1:] hover = [f"Beat {i}\u2192{i+1}<br>RR(n): {x[i]:.0f} ms<br>RR(n+1): {y[i]:.0f} ms" for i in range(len(x))] fig.add_trace(go.Scatter( x=x, y=y, mode="markers", name="RR pairs", marker=dict(color="#1f77b4", size=5, opacity=0.5), hovertext=hover, hoverinfo="text", )) lo, hi = min(x.min(), y.min()), max(x.max(), y.max()) margin = (hi - lo) * 0.1 fig.add_trace(go.Scatter( x=[lo - margin, hi + margin], y=[lo - margin, hi + margin], mode="lines", name="Identity", line=dict(color="black", width=0.5, dash="dash"), showlegend=False, )) sd1 = float(np.std(y - x, ddof=1) / np.sqrt(2)) sd2 = float(np.std(y + x, ddof=1) / np.sqrt(2)) cx, cy = float(x.mean()), float(y.mean()) theta = np.linspace(0, 2 * np.pi, 100) cos45, sin45 = np.cos(np.pi / 4), np.sin(np.pi / 4) ex = sd2 * np.cos(theta) ey = sd1 * np.sin(theta) rx = cx + ex * cos45 - ey * sin45 ry = cy + ex * sin45 + ey * cos45 fig.add_trace(go.Scatter( x=rx, y=ry, mode="lines", name=f"SD1={sd1:.1f}, SD2={sd2:.1f}", line=dict(color="red", width=1.5, dash="dash"), )) fig.update_layout( title="Poincar\u00e9 Plot", xaxis_title="RR(n) (ms)", yaxis_title="RR(n+1) (ms)", height=height, template="plotly_white", yaxis=dict(scaleanchor="x", scaleratio=1), hovermode="closest", ) if show: fig.show() return fig
[docs] def iplot_report( record: ECGRecord, height: int = 1200, *, show: bool = True, x_axis: str = "time", ) -> go.Figure: """Interactive full ECG report. Parameters ---------- record : ECGRecord Full ECG record. show : bool Display the plot immediately (default ``True``). x_axis : str ``"time"`` for seconds (default) or ``"samples"`` for sample indices. """ require_plotly() import plotly.graph_objects as go from plotly.subplots import make_subplots leads = record.leads n_leads = min(len(leads), 12) total_rows = n_leads + 1 subplot_titles = [ld.label for ld in leads[:n_leads]] + ["II Rhythm Strip"] fig = make_subplots( rows=total_rows, cols=1, shared_xaxes=False, subplot_titles=subplot_titles, vertical_spacing=0.015, ) for i, ld in enumerate(leads[:n_leads], start=1): x, _, _ = _x_data_i(ld, x_axis) max_s = int(10.0 * ld.sampling_rate) sl = slice(0, min(max_s, len(ld.samples))) fig.add_trace( go.Scatter( x=x[sl], y=ld.samples[sl], mode="lines", name=ld.label, line=dict(color=lead_color(ld.label), width=1), showlegend=False, hovertemplate=f"{ld.label}<br>%{{x:.3f}}: %{{y:.3f}}<extra></extra>", ), row=i, col=1, ) rl = _find_lead(leads, "II") if rl is not None: x, _, _ = _x_data_i(rl, x_axis) fig.add_trace( go.Scatter( x=x, y=rl.samples, mode="lines", name="II rhythm", line=dict(color=lead_color("II"), width=1), showlegend=False, hovertemplate="II<br>%{x:.3f}: %{y:.3f}<extra></extra>", ), row=total_rows, col=1, ) fig.update_xaxes(rangeslider=dict(visible=True), row=total_rows, col=1) header_text = _build_header_text(record) fig.add_annotation( text=header_text, xref="paper", yref="paper", x=0, y=1.02, showarrow=False, font=dict(size=11, family="monospace"), align="left", ) xlabel = "Sample" if x_axis == "samples" else "Time (s)" fig.update_layout( height=height, template="plotly_white", hovermode="closest", margin=dict(t=60), ) fig.update_xaxes(title_text=xlabel, row=total_rows, col=1) if show: fig.show() return fig