diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml new file mode 100644 index 00000000..0b7e206c --- /dev/null +++ b/.github/workflows/ci.yml @@ -0,0 +1,35 @@ +name: CI + +on: + push: + branches: [main] + pull_request: + +jobs: + test: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.12" + enable-cache: true + + - name: Install dependencies + run: uv sync --group dev + + - name: Install Playwright browsers + run: uv run playwright install --with-deps chromium + + - name: Run tests with coverage + run: uv run pytest + + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v5 + with: + files: coverage.xml + fail_ci_if_error: false + env: + CODECOV_TOKEN: ${{ secrets.CODECOV_TOKEN }} diff --git a/.github/workflows/docs.yml b/.github/workflows/docs.yml new file mode 100644 index 00000000..17a97dab --- /dev/null +++ b/.github/workflows/docs.yml @@ -0,0 +1,169 @@ +name: Docs + +on: + push: + # Publish to dev/ on every push to main … + branches: [main] + # … and to a versioned directory on every release tag. + tags: ["v*.*.*"] + pull_request: + branches: [main] + # Allow manual re-builds from the Actions tab. + workflow_dispatch: + +# Only one docs deployment should run at a time to avoid race conditions on +# the gh-pages branch. +concurrency: + group: docs-${{ github.ref }} + cancel-in-progress: true + +permissions: + contents: write # needed to push to gh-pages + +jobs: + # ── Build ────────────────────────────────────────────────────────────────── + # Runs on every push and every pull request. Treats warnings as errors so + # broken cross-references and bad docstrings are caught before merge. + build: + name: Build docs + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + # ── uv + Python ────────────────────────────────────────────────────── + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.13" + enable-cache: true + + # ── Dependencies ───────────────────────────────────────────────────── + # Install the package itself plus the [docs] optional-dependency group + # (sphinx, pydata-sphinx-theme, sphinx-gallery, pillow, playwright). + - name: Install dependencies (with docs extras) + run: uv sync --extra docs + + # Playwright ships the Python bindings but NOT the browser binaries. + # --with-deps also installs the OS-level shared libraries Chromium needs + # (libglib2, libnss3, etc.) on bare Ubuntu runners. + - name: Install Playwright browser + run: uv run playwright install chromium --with-deps + + # ── Build Pyodide wheel ─────────────────────────────────────────────── + # Produces docs/_static/wheels/anyplotlib-0.0.0-py3-none-any.whl so the + # in-browser Pyodide bridge can install the exact source tree that built + # these docs — no PyPI release required. + - name: Build Pyodide wheel + run: | + mkdir -p docs/_static/wheels + uv build --wheel --out-dir docs/_static/wheels/ + # Rename to the stable sentinel name micropip expects for URL installs. + cd docs/_static/wheels + for f in anyplotlib-*.whl; do + [ "$f" != "anyplotlib-0.0.0-py3-none-any.whl" ] && mv "$f" anyplotlib-0.0.0-py3-none-any.whl + done + + # ── Sphinx build ───────────────────────────────────────────────────── + # -W turns warnings into errors; --keep-going collects all of them. + - name: Build HTML documentation + run: | + uv run sphinx-build -b html docs build/html -W --keep-going + + # ── Upload built HTML as an artifact so it can be inspected on PRs ── + - name: Upload HTML artifact + uses: actions/upload-artifact@v4 + with: + name: docs-html + path: build/html + retention-days: 7 + + # ── Deploy ───────────────────────────────────────────────────────────────── + # Only runs after a successful build on pushes to main or release tags. + # Pull requests skip this job entirely. + deploy: + name: Deploy docs + needs: build + runs-on: ubuntu-latest + # Skip deployment for pull requests. + if: github.event_name != 'pull_request' + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + # ── uv + Python ────────────────────────────────────────────────────── + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.13" + enable-cache: true + + # ── Dependencies ───────────────────────────────────────────────────── + - name: Install dependencies (with docs extras) + run: uv sync --extra docs + + - name: Install Playwright browser + run: uv run playwright install chromium --with-deps + + # ── Determine deployment target ────────────────────────────────────── + # Release tag (refs/tags/v1.2.3) → destination = "v1.2.3" + # Everything else (push to main, manual dispatch) → destination = "dev" + - name: Determine deployment directory + id: target + shell: bash + run: | + if [[ "${GITHUB_REF}" == refs/tags/v* ]]; then + echo "dest_dir=${GITHUB_REF_NAME}" >> "$GITHUB_OUTPUT" + else + echo "dest_dir=dev" >> "$GITHUB_OUTPUT" + fi + + # ── Build Pyodide wheel ─────────────────────────────────────────────── + - name: Build Pyodide wheel + run: | + mkdir -p docs/_static/wheels + uv build --wheel --out-dir docs/_static/wheels/ + cd docs/_static/wheels + for f in anyplotlib-*.whl; do + [ "$f" != "anyplotlib-0.0.0-py3-none-any.whl" ] && mv "$f" anyplotlib-0.0.0-py3-none-any.whl + done + + # ── Sphinx build ───────────────────────────────────────────────────── + - name: Build HTML documentation + env: + DOCS_VERSION: ${{ steps.target.outputs.dest_dir }} + run: | + uv run sphinx-build -b html docs build/html -W --keep-going + + # ── Deploy to gh-pages ─────────────────────────────────────────────── + # keep_files: true preserves all existing directories on the branch so + # versioned releases accumulate rather than overwriting each other. + - name: Deploy to GitHub Pages + uses: peaceiris/actions-gh-pages@v4 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./build/html + destination_dir: ${{ steps.target.outputs.dest_dir }} + keep_files: true + commit_message: | + docs: deploy ${{ steps.target.outputs.dest_dir }} @ ${{ github.sha }} + + # ── Deploy root files (redirect + switcher) ────────────────────────── + # Places index.html and switcher.json at the root of gh-pages so the + # bare URL redirects to dev/ and the version switcher is always reachable. + - name: Deploy root redirect and switcher + uses: peaceiris/actions-gh-pages@v4 + with: + github_token: ${{ secrets.GITHUB_TOKEN }} + publish_dir: ./docs/_root + destination_dir: . + keep_files: true + commit_message: | + docs: update root redirect and switcher.json @ ${{ github.sha }} + diff --git a/.github/workflows/prepare_release.yml b/.github/workflows/prepare_release.yml new file mode 100644 index 00000000..cbcabe0f --- /dev/null +++ b/.github/workflows/prepare_release.yml @@ -0,0 +1,218 @@ +name: Prepare Release + +# Run manually from the Actions tab. +# Creates a branch + PR that bumps the version, builds the changelog, +# and updates the docs switcher — ready to review before tagging. +on: + workflow_dispatch: + inputs: + bump: + description: "Version component to bump" + required: true + type: choice + options: + - minor + - bugfix + - major + - pre-release # increments the bN counter on the current base version + beta: + description: "Mark as beta pre-release (adds bN suffix; always true for pre-release)" + required: false + type: boolean + default: false + +permissions: + contents: write # push branch + pull-requests: write # open PR + +jobs: + prepare: + name: Prepare release PR + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 0 + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.13" + enable-cache: true + + - name: Install dev dependencies + run: uv sync + + # ── Compute the new version ────────────────────────────────────────── + - name: Compute new version + id: version + env: + BUMP: ${{ inputs.bump }} + IS_BETA: ${{ inputs.beta }} + run: | + CURRENT=$(grep '^version = ' pyproject.toml | sed 's/version = "\(.*\)"/\1/') + export CURRENT_VERSION="$CURRENT" + + NEW_VERSION=$(python3 - <<'PYEOF' + import re, os + + current = os.environ["CURRENT_VERSION"] + bump = os.environ["BUMP"] + is_beta = os.environ["IS_BETA"].lower() == "true" + + m = re.match(r"^(\d+)\.(\d+)\.(\d+)(?:b(\d+))?", current) + major = int(m.group(1)) + minor = int(m.group(2)) + patch = int(m.group(3)) + beta_n = int(m.group(4)) if m.group(4) else None + + if bump == "major": + major, minor, patch = major + 1, 0, 0 + elif bump == "minor": + minor, patch = minor + 1, 0 + elif bump == "bugfix": + patch += 1 + elif bump == "pre-release": + # Keep the same base; just walk the beta counter forward. + is_beta = True + beta_n = (beta_n or 0) + 1 + + if is_beta: + if bump != "pre-release": + beta_n = 1 # fresh beta series for the new base + print(f"{major}.{minor}.{patch}b{beta_n}", end="") + else: + print(f"{major}.{minor}.{patch}", end="") + PYEOF + ) + + echo "new_version=$NEW_VERSION" >> "$GITHUB_OUTPUT" + echo "tag=v$NEW_VERSION" >> "$GITHUB_OUTPUT" + echo "branch=release/v$NEW_VERSION" >> "$GITHUB_OUTPUT" + echo "is_beta=${{ inputs.beta }}" >> "$GITHUB_OUTPUT" + echo "Bumping (${{ inputs.bump }}): $CURRENT → $NEW_VERSION" + + # ── Bump version strings ───────────────────────────────────────────── + - name: Bump version in pyproject.toml + run: | + sed -i 's/^version = ".*"/version = "${{ steps.version.outputs.new_version }}"/' pyproject.toml + + - name: Bump version in docs/conf.py + run: | + sed -i 's/^release = ".*"/release = "${{ steps.version.outputs.new_version }}"/' docs/conf.py + + # ── Build changelog ────────────────────────────────────────────────── + - name: Build changelog with towncrier + run: | + FRAGMENT_COUNT=$(find upcoming_changes -maxdepth 1 -name "*.rst" \ + ! -name "README.rst" | wc -l) + if [ "$FRAGMENT_COUNT" -eq 0 ]; then + echo "⚠ No news fragments found — skipping towncrier (CHANGELOG.rst unchanged)." + else + uvx towncrier build --yes --version "${{ steps.version.outputs.new_version }}" + fi + + # ── Update docs switcher.json ──────────────────────────────────────── + - name: Update docs/switcher.json + env: + VERSION_TAG: ${{ steps.version.outputs.tag }} + IS_BETA: ${{ inputs.beta }} + shell: python + run: | + import json, re, pathlib, os + + version = os.environ["VERSION_TAG"] + is_beta = os.environ["IS_BETA"].lower() == "true" + + path = pathlib.Path("docs/_root/switcher.json") + text = path.read_text() + # The file may contain a trailing comma; strip it before parsing. + text_clean = re.sub(r",(\s*[\]\}])", r"\1", text) + entries = json.loads(text_clean) + + # Remove any existing entry for this version (makes the step idempotent). + entries = [e for e in entries if e.get("version") != version] + + label = f"{version} (beta)" if is_beta else f"{version} (stable)" + url = f"https://cssfrancis.github.io/anyplotlib/{version}/" + # Insert right after the "dev" entry so newest stable floats to top. + entries.insert(1, {"name": label, "version": version, "url": url}) + + path.write_text(json.dumps(entries, indent=2) + "\n") + + # ── Update root redirect for stable releases ───────────────────────── + - name: Update root redirect (stable releases only) + if: ${{ inputs.beta == false && inputs.bump != 'pre-release' }} + env: + VERSION_TAG: ${{ steps.version.outputs.tag }} + shell: python + run: | + import re, pathlib, os + + version = os.environ["VERSION_TAG"] + path = pathlib.Path("docs/_root/index.html") + text = path.read_text() + + text = re.sub(r'(content="0; url=)[^"]+(")', rf"\g<1>{version}/\2", text) + text = re.sub(r'(rel="canonical" href=")[^"]+(")', rf"\g<1>{version}/\2", text) + text = re.sub(r'([^<]*)', rf"\g<1>{version}/\2", text) + text = re.sub(r'(Redirecting to )[^<]*()', + rf"\g<1>{version} documentation\2", text) + + path.write_text(text) + + # ── Commit and push ────────────────────────────────────────────────── + - name: Configure git + run: | + git config user.name "github-actions[bot]" + git config user.email "github-actions[bot]@users.noreply.github.com" + + - name: Commit release changes + run: | + git checkout -b "${{ steps.version.outputs.branch }}" + + # Stage version bumps, updated changelog, and consumed fragments. + git add pyproject.toml docs/conf.py CHANGELOG.rst + git add docs/_root/switcher.json docs/_root/index.html + git add -A upcoming_changes/ # stages deleted fragment files + + git commit -m "chore: prepare release ${{ steps.version.outputs.tag }}" + git push origin "${{ steps.version.outputs.branch }}" + + # ── Open pull request ──────────────────────────────────────────────── + - name: Open pull request + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TAG: ${{ steps.version.outputs.tag }} + BRANCH: ${{ steps.version.outputs.branch }} + run: | + gh pr create \ + --title "Release ${TAG}" \ + --base main \ + --head "${BRANCH}" \ + --body "## Release ${TAG} + + > Auto-generated by the **Prepare Release** workflow. + + ### What changed + - Version bumped to \`${TAG}\` in \`pyproject.toml\` and \`docs/conf.py\` + - \`CHANGELOG.rst\` updated from towncrier fragments + - \`docs/_root/switcher.json\` updated with the new version entry + $([ '${{ inputs.beta }}' = 'false' ] && echo '- Root redirect updated to point to this release' || echo '') + + ### Review checklist + - [ ] \`CHANGELOG.rst\` reads well — edit the fragment text directly if needed + - [ ] Version strings are correct in \`pyproject.toml\` and \`docs/conf.py\` + - [ ] \`switcher.json\` has the right label and URL + - [ ] CI passes + + ### After merging + Create and push the tag to trigger the Release and Docs workflows: + \`\`\`bash + git fetch origin + git tag ${TAG} origin/main + git push origin ${TAG} + \`\`\`" + diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml new file mode 100644 index 00000000..0ddf3410 --- /dev/null +++ b/.github/workflows/release.yml @@ -0,0 +1,109 @@ +name: Release + +# Fires when a version tag is pushed (manually, after the Prepare Release PR +# is merged and reviewed). +# +# Jobs: +# build - build wheel + sdist with uv +# publish - upload to PyPI via OIDC trusted publishing (no API token needed) +# release - create a GitHub Release with the dist files and changelog notes + +on: + push: + tags: ["v*.*.*"] + +permissions: + contents: write # create GitHub Releases and upload assets + id-token: write # OIDC token for PyPI trusted publishing + +jobs: + # -------------------------------------------------------------------------- + build: + name: Build distribution + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.13" + enable-cache: true + + - name: Build wheel and sdist + run: uv build + + - name: Upload dist artifact + uses: actions/upload-artifact@v4 + with: + name: dist + path: dist/ + if-no-files-found: error + retention-days: 7 + + # -------------------------------------------------------------------------- + publish: + name: Publish to PyPI + needs: build + runs-on: ubuntu-latest + + environment: + name: pypi + url: https://pypi.org/p/anyplotlib + + steps: + - name: Download dist artifact + uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + # Trusted publishing - no API token required. + # One-time setup on pypi.org: add a pending publisher for + # Owner: CSSFrancis Repo: anyplotlib Workflow: release.yml + - name: Publish to PyPI + uses: pypa/gh-action-pypi-publish@release/v1 + + # -------------------------------------------------------------------------- + release: + name: Create GitHub Release + needs: build + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Download dist artifact + uses: actions/download-artifact@v4 + with: + name: dist + path: dist/ + + - name: Extract release notes from CHANGELOG.rst + env: + TAG: ${{ github.ref_name }} + shell: python + run: | + import re, pathlib, os + tag = os.environ["TAG"] + text = pathlib.Path("CHANGELOG.rst").read_text() + parts = re.split(r"(?m)(?=^\S[^\n]*\n=+\n)", text) + notes = next((p.strip() for p in parts if p.strip().startswith(tag)), None) + fallback = "Release " + tag + "\n" + "=" * (len(tag) + 8) + "\n\nSee CHANGELOG.rst." + pathlib.Path("release_notes.rst").write_text(notes if notes else fallback) + + - name: Create GitHub Release + env: + GH_TOKEN: ${{ secrets.GITHUB_TOKEN }} + TAG: ${{ github.ref_name }} + run: | + PRERELEASE_FLAG="" + if [[ "$TAG" == *b* ]]; then PRERELEASE_FLAG="--prerelease"; fi + gh release create "$TAG" \ + --title "$TAG" \ + --notes-file release_notes.rst \ + $PRERELEASE_FLAG \ + dist/* diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml new file mode 100644 index 00000000..32cbf627 --- /dev/null +++ b/.github/workflows/tests.yml @@ -0,0 +1,80 @@ +name: Tests + +on: + push: + branches: [main] + pull_request: + +concurrency: + group: tests-${{ github.ref }} + cancel-in-progress: true + +jobs: + test: + name: Python ${{ matrix.python-version }} / ${{ matrix.os }} + runs-on: ${{ matrix.os }} + + strategy: + fail-fast: false + matrix: + os: [ubuntu-latest, macos-latest, windows-latest] + python-version: ["3.10", "3.11", "3.12", "3.13"] + exclude: + - os: macos-latest + python-version: "3.10" + - os: macos-latest + python-version: "3.11" + - os: windows-latest + python-version: "3.10" + - os: windows-latest + python-version: "3.11" + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + python-version: ${{ matrix.python-version }} + enable-cache: true + + - name: Install dependencies + run: uv sync + + - name: Install Playwright browsers (Linux) + if: runner.os == 'Linux' + run: uv run playwright install chromium --with-deps + + - name: Install Playwright browsers (macOS / Windows) + if: runner.os != 'Linux' + run: uv run playwright install chromium + + - name: Run tests + run: uv run pytest anyplotlib/tests/ -v --tb=short + + minimum-deps: + name: Minimum deps (Python 3.10 / ubuntu) + runs-on: ubuntu-latest + + steps: + - name: Checkout repository + uses: actions/checkout@v4 + + - name: Set up uv + uses: astral-sh/setup-uv@v5 + with: + python-version: "3.10" + enable-cache: true + + - name: Install dependencies at minimum versions + run: uv sync --resolution lowest-direct + + - name: Show installed versions + run: uv run pip list --format=columns + + - name: Install Playwright browsers + run: uv run playwright install chromium --with-deps + + - name: Run tests + run: uv run pytest anyplotlib/tests/ -v --tb=short diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..785d9799 --- /dev/null +++ b/.gitignore @@ -0,0 +1,53 @@ +# Python bytecode / caches +__pycache__/ +*.py[cod] +*$py.class +*.pyo + +# Distribution / packaging +dist/ +build/ +*.egg-info/ +*.egg +.eggs/ + +# Virtual environments +.venv/ +venv/ +env/ + +# Test / coverage artefacts +.pytest_cache/ +.coverage +coverage.xml +htmlcov/ + +# Jupyter notebooks checkpoints +.ipynb_checkpoints/ + +# Sphinx build output +docs/_build/ +docs/api/generated/ +docs/auto_examples/ +docs/sg_execution_times.rst +build/html/ +build/doctrees/ + +# Generated Pyodide wheel (built by workflow / make html — never commit) +docs/_static/wheels/ +docs/_static/anywidget_config.js + +# Editor / IDE +.idea/ +.vscode/ +*.swp +*.swo + +# macOS +.DS_Store + +# Git worktrees +.worktrees/ + +# Generated by Sphinx-Gallery (anywidget iframe HTML) — never commit +docs/_static/viewer_widgets/ diff --git a/AGENTS.md b/AGENTS.md new file mode 100644 index 00000000..ed0a6f59 --- /dev/null +++ b/AGENTS.md @@ -0,0 +1,127 @@ +# AGENTS.md — anyplotlib Codebase Guide + +## Architecture Overview + +`anyplotlib` is a Jupyter-compatible interactive plotting library. The key architectural split: + +- **`Figure`** (`anyplotlib/figure/_figure.py`) — the only `anywidget.AnyWidget` subclass. Owns all traitlets and is the Python↔JS bridge. +- **Plot objects** (`plot1d/`, `plot2d/`, `plot3d/`) — `Plot1D`, `PlotBar`, `Plot2D`, `PlotMesh`, `Plot3D` are **plain Python classes**, not widgets. They hold state in `_state` dicts and push to the Figure. Shared behaviour lives in `_base_plot.py` (`_BasePlot`, `_PanelMixin`, `_MarkerMixin`). +- **`Axes`** (`axes/_axes.py`) — grid-cell container; factory methods (`imshow`, `plot`, `bar`, `pcolormesh`, `plot_surface`, …) create plot objects and attach them. +- **`figure_esm.js`** — pure-JS canvas renderer (~4,400 lines); all rendering logic lives here. **Read `anyplotlib/FIGURE_ESM.md` first** — it is the section map. +- **`markers.py`** — static visual overlays (circles, arrows, lines, etc.) with a two-level dict registry: `plot.markers[type][name]`. +- **`widgets/`** — interactive draggable overlays (`RectangleWidget`, `CrosshairWidget`, etc.) that receive JS position updates. +- **`callbacks.py`** — event system: `Event` dataclass, `CallbackRegistry` (priority ordering, wildcard, pause/hold), `_EventMixin` (`add_event_handler`). +- **`embed.py`** — Jupyter-free embedding (Electron / web pages): `figure_state()`, `to_html()`/`save_html()`, `esm_path()`, and `FigureBridge` (transport-agnostic live Python↔JS sync). The JS counterpart is the `mount(el, state, opts)` export in `figure_esm.js`. See `docs/embedding.rst`. +- **`sphinx_anywidget/`** — Sphinx extension that makes anywidget figures live in docs pages via Pyodide (wheel builder, gallery scraper, `anywidget-figure` directive, `static/anywidget_bridge.js`). + +## Package layout + +``` +anyplotlib/ +├── __init__.py # public API re-exports +├── _base_plot.py # _BasePlot, _PanelMixin, _MarkerMixin +├── _utils.py # b64 encoding, linestyle/colormap helpers +├── _repr_utils.py # self-contained iframe HTML for non-kernel use +├── callbacks.py # Event, CallbackRegistry, _EventMixin +├── markers.py # MarkerRegistry, MarkerGroup +├── figure_esm.js # the entire JS renderer (see FIGURE_ESM.md) +├── figure/ # Figure widget, GridSpec/SubplotSpec, subplots() +├── axes/ # Axes, InsetAxes +├── plot1d/ # Plot1D, Line1D, PlotBar +├── plot2d/ # Plot2D, PlotMesh +├── plot3d/ # Plot3D (surface / scatter / line) +├── widgets/ # Widget base + 1D/2D widget classes +├── sphinx_anywidget/ # Sphinx/Pyodide extension (own test suite) +└── tests/ # main test suite, grouped by area +``` + +## Python ↔ JS Data Flow + +**Python → JS (push):** Every plot state mutation calls `plot._push()` → `figure._push(panel_id)` → serialises `_state` to JSON → writes to the dynamic traitlet `panel_{id}_json` (tagged `sync=True`) → JS observes and re-renders. + +**JS → Python (events/widgets):** JS interaction events (drags, clicks, zoom, keys) come through the `event_json` traitlet → dispatched by `Figure._dispatch_event()` → `Widget._update_from_js()` for widget drags, then `plot.callbacks.fire(event)`. + +**Adding state fields:** Add to `_state` in the constructor, include in `to_state_dict()`, and handle in `figure_esm.js`. + +## Key Patterns + +**`_push()` contract:** Any mutation to a plot's `_state` must end with `self._push()`. Forgetting this means changes won't appear in JS. + +**Marker kwargs use matplotlib names** — translated to wire format in `MarkerGroup.to_wire()`: +```python +plot.add_circles(offsets, name="g1", facecolors="#f00", edgecolors="#fff", radius=5) +plot.markers["circles"]["g1"].set(radius=8) # live update +``` + +**Widget (interactive overlay) pattern:** handlers register on the widget (or plot) via `add_event_handler` — directly or as a decorator: +```python +wid = plot.add_widget("crosshair", cx=64, cy=64) + +@wid.add_event_handler("pointer_move") # fires every drag frame — keep fast +def live(event): readout.value = f"({wid.cx:.1f}, {wid.cy:.1f})" + +@wid.add_event_handler("pointer_up") # fires once on release — safe for expensive work +def done(event): recompute(wid.cx, wid.cy) + +@plot.add_event_handler("pointer_settled", ms=400) # dwell-based settling +def settled(event): ... +``` + +**Label sizes and mini-TeX:** all label setters take an optional `fontsize` (CSS px), and label strings support a TeX subset inside `$...$` (superscripts `$10^{-3}$`, subscripts `$E_F$`, Greek `\alpha`, symbols `\times \AA \degree`) parsed at draw time by `_drawTex` in `figure_esm.js` — Python stores strings verbatim: +```python +plot.set_xlabel(r"$q_x$ ($\AA^{-1}$)", fontsize=13) +plot.set_tick_label_size(11) +``` + +**`subplots` squeeze behaviour** mirrors matplotlib: `(1,1)` → scalar `Axes`; `(1,N)`/`(N,1)` → 1-D array; `(M,N)` → 2-D array. + +**`GridSpec` indexing** mirrors matplotlib exactly, including negative indices, slices, and multi-cell spans — see `tests/test_layouts/test_gridspec.py`. + +## Developer Workflows + +```bash +# Install (uses uv) +uv sync +uv run playwright install chromium # one-time: browser for rendering tests + +# Run the full test suite (pytest testpaths cover both suites) +uv run pytest + +# Run a quick subset without coverage output +uv run pytest anyplotlib/tests/test_plot1d -q --no-cov + +# Build docs (Sphinx Gallery, outputs to build/html/) +make html +make clean # wipe build artefacts +``` + +Changelog entries: add a fragment file to `upcoming_changes/` (e.g. +`123.new_feature.rst`) — towncrier assembles `CHANGELOG.rst` at release time. + +## Key Files + +| File | Purpose | +|------|---------| +| `anyplotlib/figure/_figure.py` | `Figure` widget; layout engine; JS↔Python dispatch | +| `anyplotlib/figure/_gridspec.py` | `GridSpec`, `SubplotSpec` | +| `anyplotlib/figure/_subplots.py` | `subplots()` factory | +| `anyplotlib/axes/_axes.py` | `Axes` — plot factory methods | +| `anyplotlib/figure_esm.js` | All JS canvas rendering (~4,400 lines) | +| `anyplotlib/FIGURE_ESM.md` | Section map for `figure_esm.js` — read this before editing the JS | +| `anyplotlib/markers.py` | Static marker collections; `to_wire()` translation | +| `anyplotlib/widgets/` | Interactive overlay widgets | +| `anyplotlib/callbacks.py` | `CallbackRegistry`, `Event` dataclass, `_EventMixin` | +| `anyplotlib/tests/test_interactive/` | Callback + widget tests (good reference for event API) | +| `anyplotlib/tests/test_layouts/` | GridSpec / sizing pipeline / visual baseline tests | +| `Examples/` | Gallery examples (files must be named `plot_*.py`) | + +## Important Constraints + +- The **OO API only** — no `plt.plot()` style. Always create a `Figure` and call methods on `Axes`. +- Use **`import anyplotlib as apl`** in all examples, docs, and docstrings. +- Plot objects (`Plot2D` etc.) store all display state in `self._state` (plain dict). Never add traitlets to them. +- `Figure` adds per-panel traits **dynamically** (`add_traits(panel_{id}_json=...)`); check `has_trait()` before accessing. +- Colormap LUTs are built via colorcet (`_build_colormap_lut` in `_utils.py`) and serialised as `[[r,g,b], ...]` in `_state["colormap_data"]`; matplotlib is only a fallback and not a dependency. +- Docs examples in `Examples/` must have a module-level docstring (first lines) for Sphinx Gallery to pick them up; they are executed by `tests/test_examples`. +- Playwright tests share a session-scoped Chromium fixture (`anyplotlib/conftest.py`); they **error** (not skip) if browsers are missing — run `uv run playwright install chromium` first. +- When possible stop and ask questions if you're unsure about how something works. diff --git a/CHANGELOG.rst b/CHANGELOG.rst new file mode 100644 index 00000000..6cdc35c2 --- /dev/null +++ b/CHANGELOG.rst @@ -0,0 +1,19 @@ +========= +Changelog +========= + +All notable changes to **anyplotlib** are documented here. + +Fragment files in ``upcoming_changes/`` are assembled into this file by +`towncrier `_ when a release is prepared +(see ``upcoming_changes/README.rst`` for contributor instructions). + +.. towncrier release notes start + +v0.1.0 (2026-04-12) +==================== + +Initial release. Includes ``Figure``, ``Axes``, ``GridSpec``, ``subplots``, +``Plot1D``, ``Plot2D``, ``PlotMesh``, ``Plot3D``, ``PlotBar``, a full marker +system, interactive overlay widgets, and a two-tier callback registry. + diff --git a/Examples/Benchmarks/README.rst b/Examples/Benchmarks/README.rst new file mode 100644 index 00000000..fd905e30 --- /dev/null +++ b/Examples/Benchmarks/README.rst @@ -0,0 +1,8 @@ +Benchmarks +---------- + +Timing comparisons for the Python-side data-push pipeline in anyplotlib, +matplotlib, Plotly, and Bokeh. All measurements capture only the +**Python serialisation cost** — the bottleneck in a live Jupyter session +where new data must be encoded and dispatched to the browser on every frame. + diff --git a/Examples/Benchmarks/plot_benchmark_comparison.py b/Examples/Benchmarks/plot_benchmark_comparison.py new file mode 100644 index 00000000..07114e67 --- /dev/null +++ b/Examples/Benchmarks/plot_benchmark_comparison.py @@ -0,0 +1,559 @@ +""" +Plot Update Comparison +====================== + +There are a couple of different "costs" asscociated with rendering plots and images. There is +usually a Python-side cost as well as a browser-side rendering cost. We've broken down those +two costs here comparing different libraries for the first cost. The second is harder to +measure. We've done it for anyplotlib but doing it for `ipympl`, bokeh and plotly is a +little more difficult. + +* **Python pre-render** — everything that happens in the Python process before + bytes reach the browser (``timeit``-measured, no browser needed). +* **JS canvas render** — the actual canvas paint time measured inside headless + Chromium via Playwright (anyplotlib only; see the third and fourth charts). + +.. note:: + + The Python-side timings are pure-Python ``timeit`` benchmarks — no browser + is involved. The JS render timings use Playwright's + ``requestAnimationFrame`` loop and ``window._aplTiming`` to measure + inter-frame intervals in a real Chromium renderer. + +What each Python measurement covers +------------------------------------- + ++---------------+---------------------------------------------------------------+ +| Library | What is timed | ++===============+===============================================================+ +| anyplotlib | ``plot.set_data(data)`` — float → uint8 normalise → base64 | +| | encode → LUT rebuild → state-dict assembly → json.dumps → | +| | traitlet dispatch to JS renderer. | ++---------------+---------------------------------------------------------------+ +| ipympl | ``im.set_data(data); fig.canvas.draw()`` — fully rasterises | +| | the figure to an Agg pixel buffer, then encodes it as a PNG | +| | blob ready for the ipympl comm channel. This is the complete | +| | Python-side cost before the PNG is sent to the browser. | ++---------------+---------------------------------------------------------------+ +| Plotly | ``fig.data[0].z = data.tolist(); fig.to_json()`` — builds the | +| | full JSON blob that Plotly.js receives; every float becomes a | +| | decimal string. Plotly.js WebGL/SVG render is additional. | ++---------------+---------------------------------------------------------------+ +| Bokeh | ``source.data = {"image": [data]}; json_item(p)`` — builds | +| | the full JSON document patch that Bokeh.js receives. Canvas | +| | render is additional. | ++---------------+---------------------------------------------------------------+ + +""" +# sphinx_gallery_start_ignore +from __future__ import annotations + +import pathlib +import tempfile +import timeit +import warnings + +import matplotlib +matplotlib.use("Agg") # must be set before pyplot import — used for ipympl measurement +import matplotlib.pyplot as plt +import numpy as np + +# --------------------------------------------------------------------------- +# Optional library imports — degrade gracefully if not installed +# --------------------------------------------------------------------------- + +try: + from playwright.sync_api import sync_playwright as _sync_playwright + _HAS_PLAYWRIGHT = True +except ImportError: + _HAS_PLAYWRIGHT = False + warnings.warn("Playwright not installed — JS render timing omitted.", stacklevel=1) + +try: + import plotly.graph_objects as _go + _HAS_PLOTLY = True +except ImportError: + _HAS_PLOTLY = False + warnings.warn("Plotly not installed — Plotly bars omitted.", stacklevel=1) + +try: + from bokeh.plotting import figure as _bk_figure + from bokeh.models import ColumnDataSource as _CDS + from bokeh.embed import json_item as _json_item + _HAS_BOKEH = True +except ImportError: + _HAS_BOKEH = False + warnings.warn("Bokeh not installed — Bokeh bars omitted.", stacklevel=1) + +import anyplotlib as apl + +# --------------------------------------------------------------------------- +# Timing helpers +# --------------------------------------------------------------------------- + +_REPEATS = 5 +_NUMBER = 3 + + +def _timeit_min_ms(stmt) -> float: + """Return the best (minimum) per-call time in milliseconds.""" + raw = timeit.repeat(stmt=stmt, number=_NUMBER, repeat=_REPEATS) + return min(t / _NUMBER * 1000 for t in raw) + + +# rAF-paced bench loop — mirrors tests/conftest.py _run_bench. +# Each frame perturbs one state field so the blit-cache is invalidated and +# the full decode → LUT → render path executes every cycle. +_JS_BENCH = """ +([panelId, nWarmup, nSamples, field, delta]) => + new Promise((resolve, reject) => { + const total = nWarmup + nSamples; + let i = 0; + function step() { + if (i >= total) { + resolve(window._aplTiming ? window._aplTiming[panelId] : null); + return; + } + const key = 'panel_' + panelId + '_json'; + try { + const st = JSON.parse(window._aplModel.get(key)); + st[field] = (st[field] || 0) + delta; + window._aplModel.set(key, JSON.stringify(st)); + } catch(e) { reject(e); return; } + if (i === nWarmup - 1) { + if (window._aplTiming) delete window._aplTiming[panelId]; + } + i++; + requestAnimationFrame(step); + } + requestAnimationFrame(step); + }) +""" + + +def _measure_js_ms_all(pairs, n_warmup=3, n_samples=12): + """Measure JS render time for a list of (widget, panel_id, field, delta). + + Opens each widget in a shared headless Chromium session, runs the rAF + bench loop, and returns a list of mean_ms values (None on failure). + Only called when _HAS_PLAYWRIGHT is True. + """ + from anyplotlib._repr_utils import build_standalone_html + + results_js = [] + tmp_files = [] + try: + with _sync_playwright() as pw: + browser = pw.chromium.launch( + headless=True, + args=["--no-sandbox", "--disable-setuid-sandbox"], + ) + for pair in pairs: + widget, panel_id, field, delta = pair[:4] + # Per-pair timeout: large images take longer to decode and paint. + # Formula: max(30_000, sz*sz // 200) — scales from 30 s up for 4K+. + timeout_ms = pair[4] if len(pair) > 4 else 60_000 + html = build_standalone_html(widget, resizable=False) + html = html.replace( + "renderFn({ model, el });", + "renderFn({ model, el }); window._aplReady = true;", + ) + html = html.replace( + "const model = makeModel(STATE);", + "const model = makeModel(STATE);\nwindow._aplModel = model;", + ) + with tempfile.NamedTemporaryFile( + suffix=".html", mode="w", encoding="utf-8", delete=False + ) as fh: + fh.write(html) + tmp = pathlib.Path(fh.name) + tmp_files.append(tmp) + try: + page = browser.new_page() + page.goto(tmp.as_uri()) + page.wait_for_function( + "() => window._aplReady === true", timeout=timeout_ms + ) + page.evaluate( + "() => new Promise(r =>" + " requestAnimationFrame(() => requestAnimationFrame(r)))" + ) + timing = page.evaluate( + _JS_BENCH, + [panel_id, n_warmup, n_samples, field, delta], + ) + page.close() + results_js.append(timing["mean_ms"] if timing else None) + except Exception: + results_js.append(None) + browser.close() + finally: + for tmp in tmp_files: + tmp.unlink(missing_ok=True) + return results_js + + +# --------------------------------------------------------------------------- +# Benchmark configuration +# --------------------------------------------------------------------------- + +_SIZES_2D = [64, 256, 512, 1024, 2048] +_SIZES_1D = [100, 1_000, 10_000, 100_000] + +rng = np.random.default_rng(42) + +# Pre-generate fixed frames so array creation is outside the timing loops. +_frames_2d = {s: rng.uniform(size=(s, s)).astype(np.float32) for s in _SIZES_2D} +_frames_1d = {n: np.cumsum(rng.standard_normal(n)).astype(np.float32) + for n in _SIZES_1D} + +_LIBRARIES = ["anyplotlib", "ipympl", "plotly", "bokeh"] + +results_2d: dict[str, dict[int, float | None]] = {lib: {} for lib in _LIBRARIES} +results_1d: dict[str, dict[int, float | None]] = {lib: {} for lib in _LIBRARIES} + +# --------------------------------------------------------------------------- +# 2-D image benchmark +# --------------------------------------------------------------------------- + +for sz in _SIZES_2D: + data = _frames_2d[sz] + + # ── anyplotlib: normalize → uint8 → base64 → LUT → json push ──────────── + _fig_apl, _ax_apl = apl.subplots(1, 1, figsize=(min(sz, 640), min(sz, 640))) + _plot_apl = _ax_apl.imshow(data) + _update_frames = [rng.uniform(size=(sz, sz)).astype(np.float32) + for _ in range(_NUMBER)] + _idx = [0] + + def _make_apl_update(plot, frames, idx): + def _fn(): + plot.set_data(frames[idx[0] % len(frames)]) + idx[0] += 1 + return _fn + + results_2d["anyplotlib"][sz] = _timeit_min_ms( + _make_apl_update(_plot_apl, _update_frames, _idx) + ) + + # ── ipympl: set_data + full Agg rasterisation (PNG comm pathway) ──────── + _fig_mpl, _ax_mpl = plt.subplots() + _im_mpl = _ax_mpl.imshow(data, cmap="viridis") + _canvas_mpl = _fig_mpl.canvas + _new_mpl = rng.uniform(size=(sz, sz)).astype(np.float32) + + def _make_mpl_update(im, canvas, new_data): + def _fn(): + im.set_data(new_data) + canvas.draw() + return _fn + + results_2d["ipympl"][sz] = _timeit_min_ms( + _make_mpl_update(_im_mpl, _canvas_mpl, _new_mpl) + ) + plt.close(_fig_mpl) + + # ── Plotly: assign z list + serialise to JSON ──────────────────────────── + if _HAS_PLOTLY: + _pgo_fig = _go.Figure(_go.Heatmap(z=data.tolist())) + _new_plotly = rng.uniform(size=(sz, sz)).astype(np.float32).tolist() + + def _make_plotly_update(fig, new_z): + def _fn(): + fig.data[0].z = new_z + fig.to_json() + return _fn + + results_2d["plotly"][sz] = _timeit_min_ms( + _make_plotly_update(_pgo_fig, _new_plotly) + ) + else: + results_2d["plotly"][sz] = None + + # ── Bokeh: replace source.data + serialise full document ──────────────── + if _HAS_BOKEH: + _bk_src = _CDS(data={"image": [data], "x": [0], "y": [0], + "dw": [sz], "dh": [sz]}) + _bk_plot = _bk_figure(width=400, height=400) + _bk_plot.image(image="image", x="x", y="y", dw="dw", dh="dh", + source=_bk_src, palette="Viridis256") + _new_bokeh = rng.uniform(size=(sz, sz)).astype(np.float32) + + def _make_bokeh_update(src, new_data, plot, w, h): + def _fn(): + src.data = {"image": [new_data], "x": [0], "y": [0], + "dw": [w], "dh": [h]} + _json_item(plot) + return _fn + + results_2d["bokeh"][sz] = _timeit_min_ms( + _make_bokeh_update(_bk_src, _new_bokeh, _bk_plot, sz, sz) + ) + else: + results_2d["bokeh"][sz] = None + +# --------------------------------------------------------------------------- +# 1-D line benchmark +# --------------------------------------------------------------------------- + +for n_pts in _SIZES_1D: + xs = np.arange(n_pts, dtype=np.float32) + ys = _frames_1d[n_pts] + + # ── anyplotlib ─────────────────────────────────────────────────────────── + _fig_apl1, _ax_apl1 = apl.subplots(1, 1, figsize=(640, 320)) + _plot_apl1 = _ax_apl1.plot(ys) + _new_ys_apl = rng.standard_normal(n_pts).cumsum().astype(np.float32) + + def _make_apl1d(plot, new_y): + def _fn(): plot.set_data(new_y) + return _fn + + results_1d["anyplotlib"][n_pts] = _timeit_min_ms( + _make_apl1d(_plot_apl1, _new_ys_apl) + ) + + # ── ipympl: set_ydata + full Agg rasterisation (PNG comm pathway) ─────── + _fig_mpl1, _ax_mpl1 = plt.subplots() + (_line_mpl,) = _ax_mpl1.plot(xs, ys) + _new_ys_mpl = rng.standard_normal(n_pts).cumsum().astype(np.float32) + + def _make_mpl1d(line, canvas, new_y): + def _fn(): + line.set_ydata(new_y) + canvas.draw() + return _fn + + results_1d["ipympl"][n_pts] = _timeit_min_ms( + _make_mpl1d(_line_mpl, _fig_mpl1.canvas, _new_ys_mpl) + ) + plt.close(_fig_mpl1) + + # ── Plotly ─────────────────────────────────────────────────────────────── + if _HAS_PLOTLY: + _pgo_fig1 = _go.Figure(_go.Scatter(x=xs.tolist(), y=ys.tolist())) + _new_ys_plotly = rng.standard_normal(n_pts).cumsum().astype(np.float32).tolist() + + def _make_plotly1d(fig, new_y): + def _fn(): + fig.data[0].y = new_y + fig.to_json() + return _fn + + results_1d["plotly"][n_pts] = _timeit_min_ms( + _make_plotly1d(_pgo_fig1, _new_ys_plotly) + ) + else: + results_1d["plotly"][n_pts] = None + + # ── Bokeh ───────────────────────────────────────────────────────────────── + if _HAS_BOKEH: + _bk_src1 = _CDS(data={"x": xs.tolist(), "y": ys.tolist()}) + _bk_plot1 = _bk_figure(width=600, height=300) + _bk_plot1.line("x", "y", source=_bk_src1) + _new_ys_bokeh = rng.standard_normal(n_pts).cumsum().astype(np.float32).tolist() + + def _make_bokeh1d(src, plot, new_x, new_y): + def _fn(): + src.data = {"x": new_x, "y": new_y} + _json_item(plot) + return _fn + + results_1d["bokeh"][n_pts] = _timeit_min_ms( + _make_bokeh1d(_bk_src1, _bk_plot1, xs.tolist(), _new_ys_bokeh) + ) + else: + results_1d["bokeh"][n_pts] = None + +# --------------------------------------------------------------------------- +# JS render timing — anyplotlib only (headless Chromium via Playwright) +# --------------------------------------------------------------------------- +# _recordFrame() in figure_esm.js timestamps the *start* of every draw call, +# so the inter-frame interval captured by _aplTiming approximates the full +# JS render cycle: JSON.parse → uint8 decode → LUT expand → ImageBitmap → +# ctx.drawImage (2-D) or ctx.lineTo loop (1-D). + +results_2d_js: dict[int, float | None] = {s: None for s in _SIZES_2D} +results_1d_js: dict[int, float | None] = {n: None for n in _SIZES_1D} + +if _HAS_PLAYWRIGHT: + _pairs_2d_js = [] + for _sz in _SIZES_2D: + _fjs, _ajs = apl.subplots(1, 1, figsize=(min(_sz, 640), min(_sz, 640))) + _pjs = _ajs.imshow(_frames_2d[_sz]) + # Timeout scales with image area: larger images take longer to decode + # and paint in Chromium. Formula: max(30 s, sz²/200) ms. + _js_timeout = max(30_000, _sz * _sz // 200) + _pairs_2d_js.append((_fjs, _pjs._id, "display_min", 1e-4, _js_timeout)) + + for _sz, _t in zip(_SIZES_2D, _measure_js_ms_all(_pairs_2d_js)): + results_2d_js[_sz] = _t + + _pairs_1d_js = [] + for _npts in _SIZES_1D: + _fjs1, _ajs1 = apl.subplots(1, 1, figsize=(640, 320)) + _pjs1 = _ajs1.plot(_frames_1d[_npts]) + _pairs_1d_js.append((_fjs1, _pjs1._id, "view_x0", 1e-4)) + + for _npts, _t in zip(_SIZES_1D, _measure_js_ms_all(_pairs_1d_js)): + results_1d_js[_npts] = _t + +# --------------------------------------------------------------------------- +# Chart helpers +# --------------------------------------------------------------------------- + +_COLORS = { + "anyplotlib": "#1976D2", + "ipympl": "#E64A19", + "plotly": "#7B1FA2", + "bokeh": "#2E7D32", +} + +# Short legend labels shown inside the anyplotlib bar chart. +_LABELS = { + "anyplotlib": "anyplotlib (float→uint8→b64→json→traitlet)", + "ipympl": "ipympl (set_data + Agg render → PNG comm)", + "plotly": "Plotly (z=list + to_json)", + "bokeh": "Bokeh (source.data + json_item)", +} + + +def _results_to_array(results, sizes): + """Build a (N_sizes, N_libs) float array. + + Missing entries (None) become 0.0 — valid JSON, and invisible on a + log-scale axis where 0 is clamped to 1e-10 below the visible range. + Using NaN would produce bare ``NaN`` tokens that JSON.parse rejects, + silently blanking the chart. + """ + rows = [] + for s in sizes: + rows.append([ + results[lib].get(s) if results[lib].get(s) is not None else 0.0 + for lib in _LIBRARIES + ]) + return np.array(rows, dtype=float) + +# sphinx_gallery_end_ignore + +#%% +# --------------------------------------------------------------------------- +# 2-D image update (Python pre-render, all four libraries) +# --------------------------------------------------------------------------- + +# sphinx_gallery_start_ignore +_size_labels_2d = [f"{s}²" for s in _SIZES_2D] +_heights_2d = _results_to_array(results_2d, _SIZES_2D) + +fig2d, ax2d = apl.subplots(1, 1, figsize=(900, 480)) +ax2d.bar( + _size_labels_2d, + _heights_2d, + group_labels=[_LABELS[lib] for lib in _LIBRARIES], + group_colors=[_COLORS[lib] for lib in _LIBRARIES], + log_scale=True, + show_values=False, + width=0.85, + y_units="ms per call (log scale)", + units="Array size", +) +fig2d +# sphinx_gallery_end_ignore + +# %% +# --------------------------------------------------------------------------- +# 1-D line update (Python pre-render, all four libraries) +# --------------------------------------------------------------------------- + +# sphinx_gallery_start_ignore + +_size_labels_1d = [f"{n:,}" for n in _SIZES_1D] +_heights_1d = _results_to_array(results_1d, _SIZES_1D) + +fig1d, ax1d = apl.subplots(1, 1, figsize=(900, 480)) +ax1d.bar( + _size_labels_1d, + _heights_1d, + group_labels=[_LABELS[lib] for lib in _LIBRARIES], + group_colors=[_COLORS[lib] for lib in _LIBRARIES], + log_scale=True, + show_values=False, + width=0.85, + y_units="ms per call (log scale)", + units="Number of points", +) +fig1d +# sphinx_gallery_end_ignore + +# %% +# anyplotlib: Python prep vs JS canvas render +# ------------------------------------------- +# +# The two charts above show only the Python-side cost. The charts below add +# the JS render time for anyplotlib measured inside a real Chromium renderer +# via Playwright (``window._aplTiming`` populated by ``_recordFrame()`` in +# ``figure_esm.js``). The sum of both bars is the **total time-to-pixel** +# for an anyplotlib update. +# +# For ipympl, Plotly, and Bokeh the browser render cost is additional but not +# captured here — measuring it requires running their respective JS engines in +# a live browser session. +# +# .. note:: +# +# If Playwright is not installed the JS bars are absent (zero height) and +# a ``UserWarning`` is emitted at import time. Install Playwright +# (``pip install playwright && playwright install chromium``) to populate +# the JS timing columns. +# +# 2D Image Plotting Costs +# ----------------------- + +# sphinx_gallery_start_ignore + +_apl_py_2d = np.array([results_2d["anyplotlib"].get(s, 0.0) or 0.0 + for s in _SIZES_2D]) +_apl_js_2d = np.array([results_2d_js.get(s) or 0.0 for s in _SIZES_2D]) +_breakdown_2d = np.column_stack([_apl_py_2d, _apl_js_2d]) + +fig_bd2d, ax_bd2d = apl.subplots(1, 1, figsize=(700, 400)) +ax_bd2d.bar( + _size_labels_2d, + _breakdown_2d, + group_labels=["Python prep", "JS canvas render"], + group_colors=["#1976D2", "#4CAF50"], + log_scale=True, + show_values=False, + width=0.7, + y_units="ms per call (log scale)", + units="Array size — anyplotlib 2-D imshow", +) +fig_bd2d +# sphinx_gallery_end_ignore + +#%% +# Scatter Plotting Costs +# ------------------------- + +# sphinx_gallery_start_ignore +_apl_py_1d = np.array([results_1d["anyplotlib"].get(n, 0.0) or 0.0 + for n in _SIZES_1D]) +_apl_js_1d = np.array([results_1d_js.get(n) or 0.0 for n in _SIZES_1D]) +_breakdown_1d = np.column_stack([_apl_py_1d, _apl_js_1d]) + +fig_bd1d, ax_bd1d = apl.subplots(1, 1, figsize=(700, 400)) +ax_bd1d.bar( + _size_labels_1d, + _breakdown_1d, + group_labels=["Python prep", "JS canvas render"], + group_colors=["#1976D2", "#4CAF50"], + log_scale=True, + show_values=False, + width=0.7, + y_units="ms per call (log scale)", + units="Number of points — anyplotlib 1-D line", +) +fig_bd1d +# sphinx_gallery_end_ignore + +#%% diff --git a/Examples/Interactive/README.rst b/Examples/Interactive/README.rst new file mode 100644 index 00000000..3379ea8c --- /dev/null +++ b/Examples/Interactive/README.rst @@ -0,0 +1,6 @@ +Interactive Examples +==================== + +Examples that use the callback / event system to connect widget +interactions to live Python computations. + diff --git a/Examples/Interactive/plot_3d_spectral_viewer.py b/Examples/Interactive/plot_3d_spectral_viewer.py new file mode 100644 index 00000000..be714995 --- /dev/null +++ b/Examples/Interactive/plot_3d_spectral_viewer.py @@ -0,0 +1,229 @@ +""" +Interactive 3D Spectral Viewer +============================== + +A side-by-side viewer for a 3-D ``(y, x, energy)`` dataset. + +* **Left panel** — 2-D projection image (sum over the energy axis). + A draggable crosshair ROI selects the pixel whose spectrum appears on + the right. Press **i** to switch to an 8 × 8-pixel rectangle ROI + that integrates the enclosed area; press **i** again to revert. +* **Right panel** — 1-D spectrum extracted at the current ROI. Press + **s** to overlay an energy-span widget; on release the 2-D image + recomputes as the sum over the selected energy window. Press **s** + again to remove the span and restore the full-sum image. + +**Key bindings** + +.. list-table:: + :header-rows: 1 + :widths: 10 10 80 + + * - Panel + - Key + - Action + * - Image + - ``i`` + - Toggle crosshair / 8x8-px rectangle ROI. + Rectangle snaps to the pixel grid and integrates the spectrum live. + Press again to revert. + * - Spectrum + - ``s`` + - Add/remove an energy-span filter. + The 2-D image updates on release to show the sum over the selected + energy window. Press again to restore the full-sum image. + * - Both + - ``r`` + - Reset zoom / pan. +""" + +import numpy as np +import anyplotlib as apl + +# ── Synthetic (NY, NX, NE) dataset ───────────────────────────────────────── +rng = np.random.default_rng(7) + +NY, NX, NE = 64, 64, 256 +energy = np.linspace(100, 900, NE) # physical energy axis (eV) + +yy, xx = np.mgrid[0:NY, 0:NX] # spatial index grids + + +def _gauss2d(cx, cy, sigma): + return np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma ** 2)) + + +def _gauss1d(e, mu, sigma): + return np.exp(-0.5 * ((e - mu) / sigma) ** 2) + + +# Three Gaussian peaks with spatially-varying amplitudes +_peaks = [ + dict(e_mu=280.0, e_sig=18.0, cx=18, cy=18, sig2d=14), + dict(e_mu=500.0, e_sig=22.0, cx=46, cy=20, sig2d=13), + dict(e_mu=710.0, e_sig=28.0, cx=32, cy=48, sig2d=16), +] + +data = np.zeros((NY, NX, NE), dtype=np.float32) +for _p in _peaks: + _amp = _gauss2d(_p["cx"], _p["cy"], _p["sig2d"]) # (NY, NX) + _sp = _gauss1d(energy, _p["e_mu"], _p["e_sig"]) # (NE,) + data += (_amp[:, :, np.newaxis] * _sp[np.newaxis, np.newaxis, :]).astype(np.float32) + +data += rng.normal(scale=0.02, size=data.shape).astype(np.float32) + +img_full = data.sum(axis=-1).astype(float) # full-energy projection (NY, NX) + +# Initial ROI centre +CX0, CY0 = NX // 2, NY // 2 + +# ── Figure layout ─────────────────────────────────────────────────────────── +fig, (ax_img, ax_spec) = apl.subplots( + 1, 2, + figsize=(950, 460), + help=( + "Image — drag crosshair to pick a spectrum\n" + " — press i: toggle crosshair / 8×8 rectangle ROI\n" + "Spectrum — press s: add/remove energy-span filter" + ), +) + +# ── Left: 2-D projection image ────────────────────────────────────────────── +v_img = ax_img.imshow(img_full) +v_img.set_colormap("viridis") + +# ── Right: 1-D spectrum at initial position ───────────────────────────────── +v_spec = ax_spec.plot( + data[CY0, CX0, :].astype(float), + axes=[energy], + units="eV", + y_units="Intensity (a.u.)", + color="#4fc3f7", + linewidth=1.5, +) + +# ── Shared state (lists so closures can mutate them) ──────────────────────── +wid = [None] # active 2-D ROI widget +mode = ["crosshair"] # "crosshair" or "rectangle" +span_wid = [None] # active energy-span widget (or None) +_syncing = [False] # echo-loop guard for rectangle snap + +ROI_PX = 8 # rectangle ROI fixed size (pixels) + + +# ── Helpers ───────────────────────────────────────────────────────────────── + +def _snap_rect(x_raw, y_raw): + """Snap top-left corner to the nearest integer pixel, clamped to bounds.""" + x0 = int(np.clip(round(float(x_raw)), 0, NX - ROI_PX)) + y0 = int(np.clip(round(float(y_raw)), 0, NY - ROI_PX)) + return x0, y0 + + +def _wire_crosshair(w): + """Register pointer_move handler: update spectrum on every drag frame.""" + @w.add_event_handler("pointer_move") + def _ch_moved(event): + cx = int(np.clip(round(event.source.cx), 0, NX - 1)) + cy = int(np.clip(round(event.source.cy), 0, NY - 1)) + v_spec.set_data(data[cy, cx, :].astype(float), x_axis=energy) + + +def _wire_rectangle(w): + """Register pointer_move handler: snap widget to grid, integrate 8×8 region live.""" + @w.add_event_handler("pointer_move") + def _rect_moved(event): + if _syncing[0]: + return + _syncing[0] = True + try: + x0, y0 = _snap_rect( + event.source.x, + event.source.y, + ) + # Push snapped, fixed-size position back so the widget visually + # snaps to the pixel grid and stays exactly 8×8. + w.set(x=float(x0), y=float(y0), w=float(ROI_PX), h=float(ROI_PX)) + spec = data[y0:y0 + ROI_PX, x0:x0 + ROI_PX, :].mean(axis=(0, 1)) + v_spec.set_data(spec.astype(float), x_axis=energy) + finally: + _syncing[0] = False + + +# ── Install initial crosshair ──────────────────────────────────────────────── +wid[0] = v_img.add_widget( + "crosshair", + cx=float(CX0), cy=float(CY0), + color="#69f0ae", +) +_wire_crosshair(wid[0]) + + +# ── "i" — toggle crosshair ↔ 8×8 rectangle ───────────────────────────────── +@v_img.add_event_handler("key_down") +def _toggle_roi(event): + if event.key != 'i': + return + cur = wid[0] + v_img.remove_widget(cur) # remove old widget (Python ref still valid) + + if mode[0] == "crosshair": + # Preserve crosshair centre as rectangle anchor + cx_cur = float(cur.get("cx", CX0)) + cy_cur = float(cur.get("cy", CY0)) + x0, y0 = _snap_rect(cx_cur - ROI_PX / 2, cy_cur - ROI_PX / 2) + new_w = v_img.add_widget( + "rectangle", + x=float(x0), y=float(y0), + w=float(ROI_PX), h=float(ROI_PX), + color="#ffeb3b", + ) + _wire_rectangle(new_w) + wid[0] = new_w + mode[0] = "rectangle" + else: + # Restore crosshair at centre of old rectangle + rx = float(cur.get("x", CX0 - ROI_PX // 2)) + ry = float(cur.get("y", CY0 - ROI_PX // 2)) + cx_cur = rx + ROI_PX / 2 + cy_cur = ry + ROI_PX / 2 + new_w = v_img.add_widget( + "crosshair", + cx=float(np.clip(cx_cur, 0, NX - 1)), + cy=float(np.clip(cy_cur, 0, NY - 1)), + color="#69f0ae", + ) + _wire_crosshair(new_w) + wid[0] = new_w + mode[0] = "crosshair" + + +# ── "s" (spectrum panel) — add / remove energy-span filter ────────────────── +@v_spec.add_event_handler("key_down") +def _toggle_span(event): + if event.key != 's': + return + if span_wid[0] is None: + # Place span at 35 %–65 % of the energy range by default + e0 = float(energy[int(NE * 0.35)]) + e1 = float(energy[int(NE * 0.65)]) + sw = v_spec.add_range_widget(x0=e0, x1=e1, color="#ff7043") + span_wid[0] = sw + + @sw.add_event_handler("pointer_up") + def _span_released(ev): + x0_e = ev.source.x0 + x1_e = ev.source.x1 + if x0_e > x1_e: + x0_e, x1_e = x1_e, x0_e + mask = (energy >= x0_e) & (energy <= x1_e) + new_img = data[..., mask].sum(axis=-1).astype(float) if mask.any() else img_full + v_img.set_data(new_img) + else: + v_spec.remove_widget(span_wid[0]) + span_wid[0] = None + v_img.set_data(img_full) # restore full-energy projection + + +fig # Interactive + diff --git a/Examples/Interactive/plot_eels_explorer.py b/Examples/Interactive/plot_eels_explorer.py new file mode 100644 index 00000000..8c96fce7 --- /dev/null +++ b/Examples/Interactive/plot_eels_explorer.py @@ -0,0 +1,211 @@ +""" +EELS multi-spectrum explorer. +============================== + +Five synthetic EELS spectra (Carbon-rich, Nitride, Oxide, Silicide, +Mixed) stacked vertically on a single axis, each with known +characteristic edges and a power-law background. + +**Interaction** + +* **Click** a spectrum line — selects it (full opacity; others dim to + 25 %). +* **Dwell 250 ms** — shows eV position and intensity; nearby known + edges (C K, N K, O K, Ti L) are annotated. +* **Double-click** — places a permanent vertical edge marker on the + active spectrum. +* **Delete / Backspace** — removes the most recent marker on the + active spectrum. +* **Tab / Shift+Tab** — cycles the selection forward / backward. +""" +import numpy as np +import anyplotlib as apl + + +# ── synthetic data ───────────────────────────────────────────────────────────── + +ENERGY = np.linspace(50, 650, 1200) + +KNOWN_EDGES = {"C K": 284.0, "N K": 401.0, "O K": 532.0, "Ti L": 456.0} + +_SPECTRUM_DEFS = [ + {"name": "Carbon-rich", "color": "#4fc3f7", "edges": [("C K", 284, 0.6)]}, + {"name": "Nitride", "color": "#aed581", "edges": [("N K", 401, 0.5)]}, + {"name": "Oxide", "color": "#ff8a65", "edges": [("O K", 532, 0.7)]}, + {"name": "Silicide", "color": "#ba68c8", "edges": [("Si L", 99, 0.3)]}, + {"name": "Mixed", "color": "#fff176", "edges": [("C K", 284, 0.2), ("O K", 532, 0.15)]}, +] + + +def _power_law_bg(E, A=1e4, r=3.5): + return A * E ** (-r) + + +def _edge_onset(E, edge_ev, amplitude, width=20.0, decay=80.0): + onset = amplitude * (np.arctan((E - edge_ev) / (width / 6)) / np.pi + 0.5) + envelope = np.exp(-np.clip(E - edge_ev, 0, None) / decay) + return onset * envelope + + +def _make_spectrum(rng, defn, offset_y): + E = ENERGY + y = _power_law_bg(E) + for _, edge_ev, amp_frac in defn["edges"]: + peak = y.max() * amp_frac + y += _edge_onset(E, edge_ev, peak) + y += rng.normal(0, y.max() * 0.005, size=len(E)) + y = np.clip(y, 0, None) + y = y / y.max() + return y + offset_y + + +rng = np.random.default_rng(7) +spectra_y = [] +offset = 0.0 +for defn in _SPECTRUM_DEFS: + y = _make_spectrum(rng, defn, offset) + spectra_y.append(y) + offset += 1.2 * (y - offset).max() + + +# ── helpers ──────────────────────────────────────────────────────────────────── + +def _safe_remove(plot, marker_type: str, name: str) -> None: + try: + plot.remove_marker(marker_type, name) + except KeyError: + pass + + +# ── figure ───────────────────────────────────────────────────────────────────── + +# spectrum 0 is the primary line; spectra 1-4 are overlay lines +fig, ax = apl.subplots(1, 1, figsize=(800, 500)) +plot = ax.plot(spectra_y[0], axes=[ENERGY], color=_SPECTRUM_DEFS[0]["color"], linewidth=2.5) + +# overlay_lines[i] is the Line1D handle for spectrum i (None for the primary) +overlay_lines = [] +for i in range(1, len(_SPECTRUM_DEFS)): + defn = _SPECTRUM_DEFS[i] + line = plot.add_line(spectra_y[i], x_axis=ENERGY, color=defn["color"], linewidth=1.0) + overlay_lines.append(line) + +# spectra index → Line1D (or None for primary) +# lines[0] == None means "primary line", lines[1..] == Line1D handles +line_handles = [None] + overlay_lines # len == len(_SPECTRUM_DEFS) + +active_idx: int = 0 +markers_per_spectrum: list[list[str]] = [[] for _ in _SPECTRUM_DEFS] +_marker_counter = [0] + +info_label_mg = plot.add_texts( + offsets=np.array([[ENERGY[600], spectra_y[0][600]]]), + texts=[""], + name="info_label", + color="#00e5ff", + fontsize=11, +) + + +# ── selection helpers ─────────────────────────────────────────────────────────── + +def _set_overlay_line_props(lid: str, linewidth: float, alpha: float) -> None: + """Directly mutate an overlay line's entry in plot._state and push.""" + for entry in plot._state["extra_lines"]: + if entry["id"] == lid: + entry["linewidth"] = float(linewidth) + entry["alpha"] = float(alpha) + break + plot._push() + + +def _apply_selection(new_idx: int) -> None: + global active_idx + active_idx = new_idx + for i, handle in enumerate(line_handles): + if i == active_idx: + lw, alpha = 2.5, 1.0 + else: + lw, alpha = 1.0, 0.25 + if handle is None: + # primary line — use Plot1D setters + plot.set_linewidth(lw) + plot.set_alpha(alpha) + else: + _set_overlay_line_props(handle._lid, lw, alpha) + print(f"Selected: {_SPECTRUM_DEFS[active_idx]['name']}") + + +_apply_selection(0) + + +# ── event handlers ───────────────────────────────────────────────────────────── + +def _make_line_handler(idx: int): + def _handler(event) -> None: + _apply_selection(idx) + return _handler + + +# primary line click handler — line_id is None for the primary +plot.line.add_event_handler(_make_line_handler(0), "pointer_down") + +# overlay line click handlers +for i, handle in enumerate(overlay_lines, start=1): + handle.add_event_handler(_make_line_handler(i), "pointer_down") + + +def _on_settled(event) -> None: + if event.xdata is None: + return + ev = event.xdata + intensity = float(np.interp(ev, ENERGY, spectra_y[active_idx])) + label = f"eV: {ev:.1f} I: {intensity:.3f}" + for edge_name, edge_ev in KNOWN_EDGES.items(): + if abs(ev - edge_ev) < 15: + label += f"\n~ {edge_name}-edge" + y_pos = intensity + 0.05 + plot.markers["texts"]["info_label"].set( + offsets=np.array([[ev, y_pos]]), + texts=[label], + ) + + +def _on_double_click(event) -> None: + ev = event.xdata + _marker_counter[0] += 1 + name = f"edge_{active_idx}_{_marker_counter[0]}" + plot.add_vlines([ev], name=name) + markers_per_spectrum[active_idx].append(name) + print(f"Edge marker placed at {ev:.1f} eV on '{_SPECTRUM_DEFS[active_idx]['name']}'") + + +def _on_key(event) -> None: + global active_idx + if event.key in ("Delete", "Backspace"): + if not markers_per_spectrum[active_idx]: + return + name = markers_per_spectrum[active_idx].pop() + _safe_remove(plot, "vlines", name) + elif event.key == "Tab": + n = len(_SPECTRUM_DEFS) + if "shift" in event.modifiers: + new_idx = (active_idx - 1) % n + else: + new_idx = (active_idx + 1) % n + _apply_selection(new_idx) + + +plot.add_event_handler(_on_settled, "pointer_settled", ms=250) +plot.add_event_handler(_on_double_click, "double_click") +plot.add_event_handler(_on_key, "key_down") + +fig.set_help( + "Click a spectrum: select it\n" + "Dwell 250 ms: inspect eV + intensity\n" + "Double-click: place edge marker\n" + "Delete / Backspace: remove last marker\n" + "Tab / Shift+Tab: cycle selection" +) + +fig # interactive diff --git a/Examples/Interactive/plot_interactive_fft.py b/Examples/Interactive/plot_interactive_fft.py new file mode 100644 index 00000000..15a47b03 --- /dev/null +++ b/Examples/Interactive/plot_interactive_fft.py @@ -0,0 +1,180 @@ +""" +Interactive FFT ROI +=================== + +A draggable rectangle widget on a real-space image drives a live 2-D FFT +of the selected region, displayed in a side-by-side panel. + +**How it works** + +* The left panel shows a synthetic real-space image (a periodic lattice with + noise, similar to an atomic-resolution STEM image). +* A yellow rectangle widget marks the region-of-interest (ROI). +* Whenever the ROI is moved or resized the ``pointer_up`` event handler + re-computes ``numpy.fft.fft2`` on the cropped pixels, applies a Hann + window to reduce edge ringing, takes the log-magnitude, and pushes the + result into the right panel with :meth:`~anyplotlib.plot2d.Plot2D.update`. +* A second ``pointer_move`` event handler updates a lightweight text + readout (ROI size in pixels) on every drag frame without re-running + the FFT. + +**Interaction** + +* Drag the rectangle body to move the ROI. +* Drag any corner handle to resize it. +* The FFT panel refreshes automatically on mouse-release. + +.. note:: + The ``pointer_up`` / ``pointer_move`` event handlers are pure Python — + no kernel restart is needed after editing them. +""" + +import numpy as np +import anyplotlib as apl + +# ── Synthetic real-space image ──────────────────────────────────────────────── +# Periodic lattice (two overlapping sinusoidal gratings) + Gaussian envelope +# + shot noise. Mimics a crystalline region in an electron-microscopy image. + +N = 256 # image size (pixels) +rng = np.random.default_rng(42) + +x = np.arange(N) +XX, YY = np.meshgrid(x, x) + +# Two lattice periodicities (pixels) +a1, a2 = 22, 14 +theta = np.deg2rad(30) + +lattice = ( + np.cos(2 * np.pi * (XX * np.cos(theta) + YY * np.sin(theta)) / a1) + + 0.6 * np.cos(2 * np.pi * (XX * np.cos(theta + np.pi / 3) + + YY * np.sin(theta + np.pi / 3)) / a2) +) + +# Gaussian envelope (brighter in centre) +cx, cy = N // 2, N // 2 +gauss = np.exp(-((XX - cx) ** 2 + (YY - cy) ** 2) / (2 * (N * 0.35) ** 2)) + +image = gauss * lattice + rng.normal(scale=0.08, size=(N, N)) + +# Normalise to [0, 1] +image = (image - image.min()) / (image.max() - image.min()) + +# Physical axis: 0.1 Å / pixel +scale = 0.1 # Å per pixel +xy_px = np.arange(N) * scale # physical axis in Å + +# ── Figure layout: real-space (left) | FFT (right) ─────────────────────────── +fig, (ax_real, ax_fft) = apl.subplots( + 1, 2, + figsize=(900, 460), + sharex=False, + sharey=False, +) + +# ── Left panel: real-space image ────────────────────────────────────────────── +v_real = ax_real.imshow(image, axes=[xy_px, xy_px], units="Å") +v_real.set_colormap("gray") + +# Initial ROI: centred, 64 × 64 px +ROI_W, ROI_H = 64, 64 +roi_x0 = (N - ROI_W) // 2 # pixel coords (top-left corner) +roi_y0 = (N - ROI_H) // 2 + +wid = v_real.add_widget( + "rectangle", + color="#ffeb3b", + x=float(roi_x0), + y=float(roi_y0), + w=float(ROI_W), + h=float(ROI_H), +) + +# ── Right panel: FFT magnitude ──────────────────────────────────────────────── +def _compute_fft(img_full, x0, y0, w, h): + """Crop, window and FFT a region of *img_full*. + + Parameters + ---------- + img_full : ndarray, shape (N, N) – full real-space image (float) + x0, y0 : float – top-left corner of rectangle in pixel coords + w, h : float – width and height in pixels + + Returns + ------- + log_mag : ndarray – log10(1 + |FFT|), shifted so DC is at centre + freq_x : ndarray – spatial-frequency axis (1/Å), shape (w_int,) + freq_y : ndarray – spatial-frequency axis (1/Å), shape (h_int,) + """ + ih, iw = img_full.shape + + # Clamp ROI to image bounds + x0i = max(0, int(round(x0))) + y0i = max(0, int(round(y0))) + x1i = min(iw, x0i + max(1, int(round(w)))) + y1i = min(ih, y0i + max(1, int(round(h)))) + + crop = img_full[y0i:y1i, x0i:x1i].copy() + ch, cw = crop.shape + if ch < 2 or cw < 2: + # ROI too small — return a blank placeholder + blank = np.zeros((4, 4)) + f = np.fft.fftfreq(4, d=scale) + return blank, f, f + + # Hann window to suppress edge ringing + win_y = np.hanning(ch) + win_x = np.hanning(cw) + crop *= win_y[:, None] * win_x[None, :] + + # 2-D FFT → log magnitude, DC centred + fft2 = np.fft.fftshift(np.fft.fft2(crop)) + log_mag = np.log1p(np.abs(fft2)) + + # Spatial-frequency axes (cycles per Å) + freq_x = np.fft.fftshift(np.fft.fftfreq(cw, d=scale)) + freq_y = np.fft.fftshift(np.fft.fftfreq(ch, d=scale)) + + return log_mag, freq_x, freq_y + + +# Compute initial FFT and display it +_fft_init, _fx_init, _fy_init = _compute_fft(image, roi_x0, roi_y0, ROI_W, ROI_H) +v_fft = ax_fft.imshow(_fft_init, axes=[_fx_init, _fy_init], units="1/Å") +v_fft.set_colormap("inferno") + +# ── Callbacks ───────────────────────────────────────────────────────────────── + +@wid.add_event_handler("pointer_move") +def _roi_dragging(event): + """Fires on every drag frame — highlight rectangle while dragging.""" + # Cheaply pulse the widget colour to give live drag feedback. + for w in v_real._state["overlay_widgets"]: + if w["id"] == wid._id: + w["color"] = "#ff9800" # orange while dragging + break + v_real._push() + + +@wid.add_event_handler("pointer_up") +def _roi_released(event): + """Fires once on mouse-up — recompute and push the full FFT.""" + x0 = event.source.x + y0 = event.source.y + w = event.source.w + h = event.source.h + + # Restore widget colour to yellow + for widget in v_real._state["overlay_widgets"]: + if widget["id"] == wid._id: + widget["color"] = "#ffeb3b" + break + + log_mag, freq_x, freq_y = _compute_fft(image, x0, y0, w, h) + + # Push updated FFT into the right panel + v_fft.set_data(log_mag, x_axis=freq_x, y_axis=freq_y, units="1/\u00c5") + + +fig # Interactive diff --git a/Examples/Interactive/plot_interactive_fitting.py b/Examples/Interactive/plot_interactive_fitting.py new file mode 100644 index 00000000..02916a5d --- /dev/null +++ b/Examples/Interactive/plot_interactive_fitting.py @@ -0,0 +1,297 @@ +""" +Interactive 1-D Gaussian Fitting +================================= + +A noisy composite signal built from two Gaussians is displayed. Two +additional overlay lines show the individual **component** curves and a +white **sum** curve that always equals the current manual model. + +**Interaction** + +Click any coloured component line to reveal its control widgets: + +* **Circular handle** — drag to move the peak centre (μ) and amplitude (A). +* **Shaded range** — drag either edge to widen or narrow the width (σ). + +The sum curve updates on every drag frame. +Press **f** (with the plot canvas focused) to run a least-squares fit. +The components — and all active widgets — will snap to the fitted values, +and the sum curve will jump to the optimal fit. +Click a component line again to hide its widgets. +""" +# Packages required when running interactively in Pyodide (docs live mode). +_PYODIDE_PACKAGES = ["scipy"] + +import numpy as np +from scipy.optimize import curve_fit +import anyplotlib as apl + +# ── Gaussian helpers ─────────────────────────────────────────────────────── + +def gaussian(x, amp, mu, sigma): + return amp * np.exp(-0.5 * ((x - mu) / sigma) ** 2) + +# Half-width at half-maximum = sigma * _FWHM_K (full FWHM = 2 * sigma * _FWHM_K) +_FWHM_K = np.sqrt(2.0 * np.log(2.0)) + +# ── Data ─────────────────────────────────────────────────────────────────── + +x = np.linspace(0, 10, 500) + +TRUE_P = [ + dict(amp=1.0, mu=3.2, sigma=0.55), + dict(amp=0.75, mu=6.8, sigma=0.80), +] +COLORS = ["#ff6b6b", "#69db7c"] + +rng = np.random.default_rng(42) +signal = sum(gaussian(x, **p) for p in TRUE_P) + rng.normal(0, 0.03, len(x)) + +# Initial component guesses (slightly off from truth) +INIT_P = [ + dict(amp=1.0, mu=3.0, sigma=0.6), + dict(amp=0.7, mu=7.0, sigma=0.9), +] + +# ── Figure ───────────────────────────────────────────────────────────────── + +fig, ax = apl.subplots(1, 1, figsize=(720, 380), + help="Click a coloured line → show/hide its widgets\n" + "Drag circle handle → move peak center (μ) and amplitude (A)\n" + "Drag range edge → widen / narrow the width (σ)\n" + "press: f → run least-squares fit") +plot = ax.plot(signal, axes=[x], color="#adb5bd", linewidth=1.5, + alpha=0.6, label="data") +# +# Live sum of all components — this IS the fit after pressing 'f' +sum_line = plot.add_line( + sum(gaussian(x, **p) for p in INIT_P), x_axis=x, + color="#e0e0e0", linewidth=1.5, linestyle="dashed", label="sum", +) + +comp_lines = [ + plot.add_line(gaussian(x, **p), x_axis=x, + color=c, linewidth=2.0, + label=f"comp {i+1}") + for i, (p, c) in enumerate(zip(INIT_P, COLORS)) +] + + +# ── GaussianComponent ────────────────────────────────────────────────────── + +class GaussianComponent: + """Manages a PointWidget (peak) + RangeWidget (σ) for one component. + + Assign ``.model`` after constructing the ``Model`` so the component + can notify it on every drag frame. + """ + + def __init__(self, line, p, color): + self.line = line + self.amp = p["amp"] + self.mu = p["mu"] + self.sigma = p["sigma"] + self.color = color + self.model = None # injected after Model is constructed + self._active = False + self._syncing = False # guard against callback loops + self._pt = None # PointWidget — created once on first toggle + self._rng_w = None # RangeWidget + + def component_y(self): + return gaussian(x, self.amp, self.mu, self.sigma) + + def toggle(self): + if self._active: + self._pt.hide() + self._rng_w.hide() + self._active = False + else: + if self._pt is None: + self._pt = plot.add_point_widget(self.mu, self.amp, + color=self.color, + show_crosshair=False) + self._rng_w = plot.add_range_widget( + self.mu - self.sigma * _FWHM_K, + self.mu + self.sigma * _FWHM_K, + y=self.amp / 2.0, + color=self.color, + style="fwhm", + ) + self._wire() + else: + self._pt.show() + self._rng_w.show() + self._active = True + + def _wire(self): + @self._pt.add_event_handler("pointer_move") + def _peak_moved(event): + if self._syncing: + return + self._syncing = True + try: + self.amp = event.source.y + self.mu = event.source.x + self._rng_w.set(x0=self.mu - self.sigma * _FWHM_K, + x1=self.mu + self.sigma * _FWHM_K, + y=self.amp / 2.0) + self.line.set_data(self.component_y()) + if self.model: + self.model.update() + finally: + self._syncing = False + + @self._rng_w.add_event_handler("pointer_move") + def _range_moved(event): + if self._syncing: + return + self._syncing = True + try: + x0, x1 = event.source.x0, event.source.x1 + self.mu = (x0 + x1) / 2.0 + self.sigma = abs(x1 - x0) / (2.0 * _FWHM_K) + self._pt.set(x=self.mu) + self.line.set_data(self.component_y()) + if self.model: + self.model.update() + finally: + self._syncing = False + + def snap(self, amp: float, mu: float, sigma: float) -> None: + """Update parameters and snap **all** widgets to the new values. + + Creates and shows the point and FWHM range widgets if they do not + exist yet (so pressing **f** always reveals the fitted widths), then + updates their positions. Uses the ``_syncing`` guard so widget + callbacks do not fire during the programmatic update. + """ + self._syncing = True + try: + self.amp = amp + self.mu = mu + self.sigma = sigma + self.line.set_data(self.component_y()) + if self._pt is None: + # First fit — create widgets at the fitted position and show them. + self._pt = plot.add_point_widget(self.mu, self.amp, + color=self.color, + show_crosshair=False) + self._rng_w = plot.add_range_widget( + self.mu - self.sigma * _FWHM_K, + self.mu + self.sigma * _FWHM_K, + y=self.amp / 2.0, + color=self.color, + style="fwhm", + ) + self._wire() + self._active = True + else: + # Widgets already exist — move them to the new fitted position. + self._pt.set(x=self.mu, y=self.amp) + self._rng_w.set(x0=self.mu - self.sigma * _FWHM_K, + x1=self.mu + self.sigma * _FWHM_K, + y=self.amp / 2.0) + # If the user had hidden the widgets, bring them back. + if not self._active: + self._pt.show() + self._rng_w.show() + self._active = True + finally: + self._syncing = False + +# ── Model ────────────────────────────────────────────────────────────────── + +class Model: + """A list of GaussianComponents with a live sum line. + + ``update()`` redraws the sum line from the current component state and + is called on every drag frame. + + ``fit()`` runs a least-squares fit, snaps every component (and its + widgets) to the optimal parameters, then calls ``update()`` so the sum + line jumps to the best fit. It is also triggered by pressing **f**. + + Parameters + ---------- + components : list[GaussianComponent] + sum_line : Line1D + Always-live manual-sum / fit-result overlay. + x_data, y_data : ndarray + Observed signal to fit against. + """ + + def __init__(self, components, sum_line, x_data, y_data): + self.components = list(components) + self.sum_line = sum_line + self.x_data = x_data + self.y_data = y_data + + def update(self): + """Redraw the sum line as the manual sum of all components.""" + self.sum_line.set_data( + sum(c.component_y() for c in self.components) + ) + + def fit(self): + """Least-squares fit; snaps components and FWHM widgets to the result. + + Builds a generic n-Gaussian model from the component list and uses + their current state as the initial guess. On success every component + snaps to the fitted (amp, μ, σ): the component line, the peak handle, + **and** the FWHM range widget are all moved to the optimal values. + If a component's widgets have not been shown yet they are created and + revealed automatically. The sum line redraws as the best fit. + On failure the components are left unchanged. + """ + n = len(self.components) + p0 = [v for c in self.components for v in (c.amp, c.mu, c.sigma)] + lo = [v for c in self.components for v in (0, self.x_data[0], 1e-3)] + hi = [v for c in self.components + for v in (np.inf, self.x_data[-1], + self.x_data[-1] - self.x_data[0])] + + def _model_fn(x, *params): + return sum( + gaussian(x, params[3 * i], params[3 * i + 1], params[3 * i + 2]) + for i in range(n) + ) + + try: + popt, _ = curve_fit( + _model_fn, self.x_data, self.y_data, + p0=p0, bounds=(lo, hi), maxfev=3000 * n, + ) + for i, comp in enumerate(self.components): + comp.snap(popt[3 * i], popt[3 * i + 1], popt[3 * i + 2]) + self.update() + except RuntimeError: + pass # leave components unchanged if fit did not converge + +# ── Assemble ─────────────────────────────────────────────────────────────── + +components = [ + GaussianComponent(comp_lines[i], INIT_P[i], COLORS[i]) + for i in range(2) +] + +model = Model(components, sum_line, x, signal) +for comp in components: + comp.model = model + +# ── Key binding — press 'f' to fit ───────────────────────────────────────── + +@plot.add_event_handler("key_down") +def _on_fit(event): + if event.key != 'f': + return + model.fit() + +# ── Click handlers — toggle widgets per component ───────────────────────── + +for comp, line in zip(components, comp_lines): + @line.add_event_handler("pointer_down") + def _clicked(event, c=comp): + c.toggle() + +fig # Interactive \ No newline at end of file diff --git a/Examples/Interactive/plot_ipf_explorer.py b/Examples/Interactive/plot_ipf_explorer.py new file mode 100644 index 00000000..b3aed178 --- /dev/null +++ b/Examples/Interactive/plot_ipf_explorer.py @@ -0,0 +1,125 @@ +""" +Inverse Pole Figure (IPF) Explorer +================================== + +An EBSD-style orientation explorer for a synthetic polycrystal: + +* **Left panel** — IPF-Z orientation map, colored with the standard cubic + IPF key (red = ⟨001⟩, green = ⟨011⟩, blue = ⟨111⟩). Rendered as a + true-color RGB image. +* **Right panel** — the *reduced 3-D inverse pole figure*: every grain's + sample-Z direction, expressed in crystal coordinates and folded into the + cubic fundamental sector, plotted as an IPF-colored point cloud on a + shaded, wireframed unit sphere. + +Drag the crosshair on the map: the grain's orientation is marked with a +highlighted dot on the sphere, and the sphere **rotates so that direction +faces you**. Drag on the sphere to orbit freely; the next crosshair move +re-aims the camera. +""" + +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(42) + +# ── 1. Synthetic polycrystal: nearest-seed grain map ──────────────────────── +H = W = 192 +N_GRAINS = 60 + +seeds = rng.uniform(0, [H, W], size=(N_GRAINS, 2)) +yy, xx = np.mgrid[0:H, 0:W] +d2 = (yy[..., None] - seeds[:, 0]) ** 2 + (xx[..., None] - seeds[:, 1]) ** 2 +grain_id = np.argmin(d2, axis=-1) # (H, W) labels + + +# ── 2. Random orientation per grain (uniform rotations via quaternions) ───── +def random_rotations(n): + """Uniform random rotation matrices, shape (n, 3, 3) (Shoemake method).""" + u1, u2, u3 = rng.random((3, n)) + q = np.stack([ + np.sqrt(1 - u1) * np.sin(2 * np.pi * u2), + np.sqrt(1 - u1) * np.cos(2 * np.pi * u2), + np.sqrt(u1) * np.sin(2 * np.pi * u3), + np.sqrt(u1) * np.cos(2 * np.pi * u3), + ], axis=1) # (n, 4) unit quats + x, y, z, w = q.T + return np.stack([ + np.stack([1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], -1), + np.stack([2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], -1), + np.stack([2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], -1), + ], axis=1) + + +rotations = random_rotations(N_GRAINS) + +# Sample-Z expressed in each grain's crystal frame: d = Rᵀ · ẑ +dirs = rotations[:, 2, :] # row 2 of R == Rᵀ·ẑ + +# ── 3. Reduce to the cubic fundamental sector and IPF-color ──────────────── +# For cubic symmetry, sorting |components| ascending lands every direction +# in the standard 001–011–111 stereographic triangle. +reduced = np.sort(np.abs(dirs), axis=1) # (a ≤ b ≤ c) +a, b, c = reduced.T + +# Classic IPF key: distance to each triangle corner → R, G, B +rgb = np.stack([c - b, b - a, a], axis=1) +rgb /= rgb.max(axis=1, keepdims=True) + 1e-12 # vivid normalisation +grain_rgb_u8 = (rgb * 255).astype(np.uint8) # (N_GRAINS, 3) + +ipf_map = grain_rgb_u8[grain_id] # (H, W, 3) true-color + + +# ── 4. Figure: RGB map + reduced 3-D IPF point cloud ─────────────────────── +fig, (ax_map, ax_ipf) = apl.subplots( + 1, 2, figsize=(880, 420), + help="Drag the crosshair: the sphere rotates to face that grain's\n" + "crystal direction. Drag the sphere to orbit freely.") + +vmap = ax_map.imshow(ipf_map) # (H, W, 3) → RGB +vmap.set_title("IPF-Z orientation map") +cross = vmap.add_widget("crosshair", cx=W // 2, cy=H // 2, color="#ffffff") + +# reduced directions live on the unit sphere → fix bounds to keep the +# origin centred and the geometry origin-true +vipf = ax_ipf.scatter3d( + reduced[:, 0], reduced[:, 1], reduced[:, 2], + colors=grain_rgb_u8, point_size=6, + x_label="[100]", y_label="[010]", z_label="[001]", + bounds=((-1, 1),) * 3, zoom=1.4, +) +vipf.set_title("Reduced 3D IPF (cubic fundamental sector)") +# Shaded unit sphere with lat/long wireframe behind the direction vectors +vipf.set_sphere(1.0) + + +# ── 5. Crosshair → highlight + rotate-to-face ─────────────────────────────── +def face_camera(v): + """(azimuth°, elevation°) that aim the camera straight down *v*. + + With the turntable camera, the view faces unit vector ``v`` when + ``el = asin(vz)`` and ``az = atan2(vx, -vy)``. + """ + vx, vy, vz = v + el = np.degrees(np.arcsin(np.clip(vz, -1.0, 1.0))) + az = np.degrees(np.arctan2(vx, -vy)) + return az, el + + +def show_orientation(gid: int) -> None: + v = reduced[gid] + vipf.set_highlight(*v, color="#ffffff", size=8) + az, el = face_camera(v) + vipf.set_view(azimuth=az, elevation=el) + + +@cross.add_event_handler("pointer_move") +def on_move(event): + ix = int(np.clip(round(cross.cx), 0, W - 1)) + iy = int(np.clip(round(cross.cy), 0, H - 1)) + show_orientation(int(grain_id[iy, ix])) + + +show_orientation(int(grain_id[H // 2, W // 2])) + +fig # Interactive diff --git a/Examples/Interactive/plot_key_bindings.py b/Examples/Interactive/plot_key_bindings.py new file mode 100644 index 00000000..6e19b19f --- /dev/null +++ b/Examples/Interactive/plot_key_bindings.py @@ -0,0 +1,128 @@ +""" +Key-Press Widget Placement +========================== + +Demonstrates the ``key_down`` event handler API: press a key while the plot +is focused to add an overlay widget centred on the current cursor position, +or press **Backspace / Delete** to remove the last widget you clicked. + +**Key bindings** + ++-------------------------------+---------------------------+ +| Key | Action | ++===============================+===========================+ +| ``q`` | Add a rectangle | ++-------------------------------+---------------------------+ +| ``w`` | Add a circle | ++-------------------------------+---------------------------+ +| ``e`` | Add an annulus | ++-------------------------------+---------------------------+ +| ``Backspace`` (macOS ⌫) | Remove last-clicked | +| ``Delete`` (Windows / Linux) | | ++-------------------------------+---------------------------+ + +**Built-in 2-D shortcuts** (not overridden in this example): + ++-------+---------------------------+ +| Key | Action | ++=======+===========================+ +| ``r`` | Reset zoom / pan | ++-------+---------------------------+ +| ``c`` | Toggle colorbar | ++-------+---------------------------+ +| ``l`` | Toggle log scale | ++-------+---------------------------+ +| ``s`` | Toggle symlog scale | ++-------+---------------------------+ + +The cursor coordinates are available as ``event.xdata`` and ``event.ydata`` +in image-pixel space (column, row), so widgets are centred exactly where +the cursor was when the key was pressed. + +.. note:: + Move the mouse over the image first so the plot panel receives focus, + then press a key. On macOS the backspace key (⌫) is used for deletion; + on Windows / Linux use the **Delete** key. +""" + +import numpy as np +import anyplotlib as apl + +# ── Synthetic test image ────────────────────────────────────────────────────── +rng = np.random.default_rng(0) +N = 256 +x = np.linspace(0, 4 * np.pi, N) +XX, YY = np.meshgrid(x, x) +data = np.sin(XX) * np.cos(YY) + 0.15 * rng.standard_normal((N, N)) + +# ── Figure ──────────────────────────────────────────────────────────────────── +fig, ax = apl.subplots(figsize=(520, 520)) +plot = ax.imshow(data) + +# ── Key handlers ───────────────────────────────────────────────────────────── + +@plot.add_event_handler("key_down") +def add_rectangle(event): + """Press 'q' — add a rectangle centred on the cursor.""" + if event.key != 'q': + return + cx, cy = event.xdata, event.ydata + half_w, half_h = N * 0.08, N * 0.08 + plot.add_widget( + "rectangle", + x=cx - half_w, y=cy - half_h, + w=half_w * 2, h=half_h * 2, + color="#ffd54f", + ) + + +@plot.add_event_handler("key_down") +def add_circle(event): + """Press 'w' — add a circle centred on the cursor.""" + if event.key != 'w': + return + plot.add_widget( + "circle", + cx=event.xdata, cy=event.ydata, + r=N * 0.07, + color="#80cbc4", + ) + + +@plot.add_event_handler("key_down") +def add_annulus(event): + """Press 'e' — add an annulus centred on the cursor.""" + if event.key != 'e': + return + plot.add_widget( + "annular", + cx=event.xdata, cy=event.ydata, + r_outer=N * 0.12, + r_inner=N * 0.06, + color="#ce93d8", + ) + + +# macOS sends 'Backspace' for the ⌫ key; Windows/Linux send 'Delete'. +# Register both so the example works cross-platform. +@plot.add_event_handler("key_down") +def delete_last(event): + """Press Backspace/Delete — remove the last widget that was clicked.""" + if event.key not in ('Backspace', 'Delete'): + return + wid = event.last_widget_id + if wid and wid in {w.id for w in plot.list_widgets()}: + plot.remove_widget(wid) + + +# ── Catch-all handler (optional) — log every registered key press ───────────── + +@plot.add_event_handler("key_down") +def log_key(event): + xdata = event.xdata + ydata = event.ydata + pos = f"({xdata:.1f}, {ydata:.1f})" if xdata is not None else "n/a" + print(f"[key_down] key={event.key!r} img={pos}" + f" last_widget={event.last_widget_id!r}") + +fig # Interactive diff --git a/Examples/Interactive/plot_particle_picker.py b/Examples/Interactive/plot_particle_picker.py new file mode 100644 index 00000000..72a1ee33 --- /dev/null +++ b/Examples/Interactive/plot_particle_picker.py @@ -0,0 +1,209 @@ +""" +HAADF STEM nanoparticle picker. +================================= + +Synthetic HAADF-STEM image with 18 Gaussian nanoparticles on a Poisson +noise background. Candidate peaks are detected automatically using a +7×7 local-maximum filter and marked with small grey circles. + +**Interaction** + +* **Dwell 300 ms** over a candidate — shows the sub-pixel centroid, + peak intensity, and estimated FWHM in a floating label. +* **Double-click** — confirms the pick (green ring). +* **Shift+double-click** — marks the pick as uncertain (orange ring). +* **Delete / Backspace** — removes the confirmed pick nearest the + cursor. +* **c** — clears all picks. +""" +import numpy as np +import anyplotlib as apl + + +# ── synthetic data ───────────────────────────────────────────────────────────── + +def _make_stem_image(rng: np.random.Generator) -> np.ndarray: + img = rng.poisson(lam=5, size=(512, 512)).astype(np.float32) + for _ in range(18): + cx, cy = rng.integers(30, 482, size=2) + sigma = rng.uniform(4, 9) + peak = rng.uniform(80, 200) + r = int(np.ceil(3 * sigma)) + y0, y1 = max(0, cy - r), min(512, cy + r + 1) + x0, x1 = max(0, cx - r), min(512, cx + r + 1) + ys = np.arange(y0, y1)[:, None] + xs = np.arange(x0, x1)[None, :] + img[y0:y1, x0:x1] += peak * np.exp( + -((xs - cx) ** 2 + (ys - cy) ** 2) / (2 * sigma ** 2) + ) + return np.clip(img, 0, 255).astype(np.float32) + + +def _find_candidates(img: np.ndarray) -> list[tuple[int, int]]: + """Local maxima via 7x7 sliding-window max filter (pure NumPy).""" + from numpy.lib.stride_tricks import sliding_window_view + pad = 3 + padded = np.pad(img, pad, mode="edge") + windows = sliding_window_view(padded, (7, 7)) + local_max = windows.max(axis=(-2, -1)) + mask = (img == local_max) & (img > 20) + ys, xs = np.where(mask) + return list(zip(xs.tolist(), ys.tolist())) + + +def _parabolic_centroid(img: np.ndarray, r: int, c: int) -> tuple[float, float]: + def _delta(left, center, right): + denom = 2 * (2 * center - left - right) + return 0.0 if abs(denom) < 1e-6 else (right - left) / denom + + dc = _delta(float(img[r, c - 1]), float(img[r, c]), float(img[r, c + 1])) + dr = _delta(float(img[r - 1, c]), float(img[r, c]), float(img[r + 1, c])) + return c + dc, r + dr + + +def _gaussian_fwhm(profile: np.ndarray) -> float: + p = np.clip(profile.astype(float), 1e-6, None) + peak_idx = int(np.argmax(p)) + if peak_idx == 0 or peak_idx >= len(p) - 1: + return 2.0 + try: + a, b, c_ = np.log(p[peak_idx - 1]), np.log(p[peak_idx]), np.log(p[peak_idx + 1]) + sigma = np.sqrt(-1.0 / (2 * (a + c_ - 2 * b))) + except Exception: + return 2.0 + return 2.355 * abs(sigma) + + +def _safe_remove(plot, marker_type: str, name: str) -> None: + try: + plot.remove_marker(marker_type, name) + except KeyError: + pass + + +# ── build data ───────────────────────────────────────────────────────────────── + +rng = np.random.default_rng(42) +image = _make_stem_image(rng) +candidates = _find_candidates(image) + +# ── figure ───────────────────────────────────────────────────────────────────── + +fig, ax = apl.subplots(1, 1, figsize=(640, 640)) +plot = ax.imshow(image, cmap="gray") + +if candidates: + cand_arr = np.array(candidates, dtype=float) + plot.add_circles(cand_arr, name="candidates", radius=6, + facecolors="none", edgecolors="#555555") + +info_label = plot.add_widget("label", x=10, y=10, text="", color="#00e5ff", fontsize=11) + +picks: list[dict] = [] + + +# ── helpers ──────────────────────────────────────────────────────────────────── + +def _redraw_picks() -> None: + _safe_remove(plot, "circles", "picks_certain") + _safe_remove(plot, "circles", "picks_uncertain") + certain = [p for p in picks if not p["uncertain"]] + uncertain = [p for p in picks if p["uncertain"]] + if certain: + arr = np.array([[p["cx"], p["cy"]] for p in certain]) + plot.add_circles(arr, name="picks_certain", radius=10, + facecolors="none", edgecolors="#00ff88") + if uncertain: + arr = np.array([[p["cx"], p["cy"]] for p in uncertain]) + plot.add_circles(arr, name="picks_uncertain", radius=10, + facecolors="none", edgecolors="#ff9100") + + +def _nearest_candidate(x: float, y: float, max_dist: float = 12.0): + best, best_d = None, max_dist + for cx, cy in candidates: + d = float(np.hypot(cx - x, cy - y)) + if d < best_d: + best, best_d = (cx, cy), d + return best + + +def _nearest_pick_idx(x: float, y: float) -> int | None: + if not picks: + return None + dists = [float(np.hypot(p["cx"] - x, p["cy"] - y)) for p in picks] + return int(np.argmin(dists)) + + +def _inspect(cx_f: float, cy_f: float) -> tuple[float, float, float, float]: + """Return (sub_cx, sub_cy, intensity, fwhm) for the pixel at (cx_f, cy_f).""" + r = int(np.clip(round(cy_f), 4, 507)) + c = int(np.clip(round(cx_f), 4, 507)) + sub_cx, sub_cy = _parabolic_centroid(image, r, c) + intensity = float(image[r, c]) + row_profile = image[r, max(0, c - 4):min(512, c + 5)] + col_profile = image[max(0, r - 4):min(512, r + 5), c] + fwhm = (_gaussian_fwhm(row_profile) + _gaussian_fwhm(col_profile)) / 2 + return sub_cx, sub_cy, intensity, fwhm + + +# ── event handlers ───────────────────────────────────────────────────────────── + +def _on_settled(event) -> None: + if event.xdata is None or event.ydata is None: + return + hit = _nearest_candidate(event.xdata, event.ydata) + if hit is None: + info_label.set(text="") + return + hx, hy = hit + sub_cx, sub_cy, intensity, fwhm = _inspect(hx, hy) + info_label.set( + text=f"centroid ({sub_cx:.1f}, {sub_cy:.1f})\npeak {intensity:.0f}\nFWHM {fwhm:.2f} px", + x=hx + 12, + y=hy - 30, + ) + + +def _on_double_click(event) -> None: + if event.xdata is None or event.ydata is None: + return + hit = _nearest_candidate(event.xdata, event.ydata) + if hit is None: + return + sub_cx, sub_cy, intensity, fwhm = _inspect(*hit) + uncertain = "shift" in event.modifiers + picks.append({"cx": sub_cx, "cy": sub_cy, "intensity": intensity, + "fwhm": fwhm, "uncertain": uncertain}) + _redraw_picks() + tag = "uncertain" if uncertain else "certain" + print(f"Pick #{len(picks)} [{tag}]: ({sub_cx:.1f}, {sub_cy:.1f}) " + f"peak={intensity:.0f} FWHM={fwhm:.2f} px") + + +def _on_key(event) -> None: + if event.key in ("Delete", "Backspace"): + x = event.xdata if event.xdata is not None else 256.0 + y = event.ydata if event.ydata is not None else 256.0 + idx = _nearest_pick_idx(x, y) + if idx is not None: + picks.pop(idx) + _redraw_picks() + elif event.key == "c": + picks.clear() + _redraw_picks() + + +plot.add_event_handler(_on_settled, "pointer_settled", ms=300, delta=6) +plot.add_event_handler(_on_double_click, "double_click") +plot.add_event_handler(_on_key, "key_down") + +fig.set_help( + "Dwell 300 ms: inspect peak\n" + "Double-click: confirm pick (green)\n" + "Shift+double-click: uncertain pick (orange)\n" + "Delete / Backspace: remove nearest pick\n" + "c: clear all picks" +) + +fig # interactive diff --git a/Examples/Interactive/plot_point_widget.py b/Examples/Interactive/plot_point_widget.py new file mode 100644 index 00000000..1f1654c4 --- /dev/null +++ b/Examples/Interactive/plot_point_widget.py @@ -0,0 +1,109 @@ +""" +Draggable Point Widget +====================== + +Demonstrates the :class:`~anyplotlib.widgets.PointWidget` on a 1-D panel. + +A smooth curve ``f(x) = sin(x) · e^(−x/6)`` is shown together with a +cyan control point that the user can drag freely inside the plot area. + +**Interaction** + +* **Drag the point** anywhere inside the plot — the widget reports its + data-space ``(x, y)`` position on every frame via the + ``pointer_move`` event handler. +* **Release** — the ``pointer_up`` event handler snaps the point's + y-coordinate to the curve value at the dragged x and draws the + **tangent line** through that point. + +**What is computed on release** + +Given the dragged x position *xq*, the code evaluates: + +* **Curve value**: ``yq = f(xq)`` +* **Derivative** (central finite difference): ``dy/dx ≈ [f(xq+h) − f(xq−h)] / 2h`` +* **Tangent line**: ``y_tan(x) = yq + slope · (x − xq)`` + +The tangent line is added with :meth:`~anyplotlib.plot1d.Plot1D.add_line` +and the previous one is removed, so only one tangent is shown at a time. + +.. note:: + Move the point to an interesting part of the curve (e.g. a local maximum) + and release — the tangent will be horizontal there. +""" + +import numpy as np +import anyplotlib as apl + +# ── Curve ────────────────────────────────────────────────────────────────── +x = np.linspace(0.0, 4.0 * np.pi, 512) + +def f(t): + return np.sin(t) * np.exp(-t / 6.0) + +def df(t, h=1e-5): + """Central finite-difference derivative of f.""" + return (f(t + h) - f(t - h)) / (2.0 * h) + +y = f(x) + +# ── Figure ───────────────────────────────────────────────────────────────── +fig, ax = apl.subplots(figsize=(680, 340)) +plot = ax.plot(y, axes=[x], units="rad", + color="#4fc3f7", linewidth=2.0, label="f(x)") + +# ── Initial point widget — placed at the first local maximum ─────────────── +x0_init = float(x[np.argmax(y)]) +y0_init = float(np.max(y)) +pt = plot.add_point_widget(x0_init, y0_init, color="#00e5ff") + +# Track the current tangent line handle so we can replace it +_tangent_line: "apl.Line1D | None" = None # type: ignore[name-defined] + +def _draw_tangent(xq: float) -> None: + """Snap point to curve, compute slope, draw tangent overlay.""" + global _tangent_line + + # Evaluate curve and slope at xq + yq = float(f(xq)) + slope = float(df(xq)) + + # Snap the widget y to the curve (visual feedback) + pt._data["y"] = yq + pt._push_fn() + + # Tangent line spans the full visible x range + x_tan = np.array([float(x[0]), float(x[-1])]) + y_tan = yq + slope * (x_tan - xq) + + # Replace previous tangent + if _tangent_line is not None: + _tangent_line.remove() + _tangent_line = plot.add_line( + y_tan, x_axis=x_tan, + color="#ff7043", linewidth=1.5, + linestyle="dashed", + label=f"slope = {slope:+.3f}", + ) + +# Draw the tangent at the initial position +_draw_tangent(x0_init) + + +# ── Callbacks ────────────────────────────────────────────────────────────── + +@pt.add_event_handler("pointer_move") +def _live(event): + """Every drag frame — print the current widget position.""" + print(f" dragging x={event.source.x:.4f} y={event.source.y:.4f}", end="\r") + + +@pt.add_event_handler("pointer_up") +def _settled(event): + """On mouse-up — snap y to the curve and refresh the tangent line.""" + print(f" released x={event.source.x:.4f} ") + _draw_tangent(event.source.x) + + +fig # Interactive + diff --git a/Examples/Interactive/plot_segment_by_contrast.py b/Examples/Interactive/plot_segment_by_contrast.py new file mode 100644 index 00000000..0edab46a --- /dev/null +++ b/Examples/Interactive/plot_segment_by_contrast.py @@ -0,0 +1,245 @@ +""" +Interactive Contrast Segmentation +=================================== + +Click on any region of the image to flood-fill all pixels of similar +intensity — the union of all seeded regions is shown as a live +semi-transparent overlay on the original image. + +**Interaction** + ++-----------------------------------+-----------------------------------------+ +| Action | Effect | ++===================================+=========================================+ +| **Left-click** | Add a *positive* seed (green dot). | +| | Flood-fill grows from that pixel. | ++-----------------------------------+-----------------------------------------+ +| **Shift + left-click** | Add a *negative* seed (red dot). | +| | Subtracts that connected region from | +| | the current mask. | ++-----------------------------------+-----------------------------------------+ +| **Hover + Delete / Backspace** | Remove the nearest seed within | +| | 12 image-px of the cursor. | ++-----------------------------------+-----------------------------------------+ +| **+** / **=** | Increase tolerance (grow regions). | ++-----------------------------------+-----------------------------------------+ +| **-** | Decrease tolerance (shrink regions). | ++-----------------------------------+-----------------------------------------+ +| **c** (while focused) | Clear all seeds and reset mask. | ++-----------------------------------+-----------------------------------------+ + +The current boolean mask numpy array is always accessible as ``mask``. +The cursor position is exposed as ``event.xdata`` (column) and +``event.ydata`` (row) in image-pixel coordinates. + +.. note:: + Move the cursor over the plot so it receives keyboard focus before + pressing keys. The tolerance is shown in the plot title. +""" + +import numpy as np +import anyplotlib as apl + +# ── Synthetic multi-region image ────────────────────────────────────────────── +# Five Gaussian blobs at different intensity levels on a smooth background, +# plus mild Poisson-like noise — gives interesting connected regions to segment. + +N = 256 +rng = np.random.default_rng(7) + +xx, yy = np.meshgrid(np.arange(N), np.arange(N)) + +def _gauss(cx, cy, sigma, amplitude): + return amplitude * np.exp(-((xx - cx)**2 + (yy - cy)**2) / (2 * sigma**2)) + +image = ( + _gauss( 64, 72, 28, 0.85) # bright top-left blob + + _gauss(190, 60, 22, 0.70) # mid top-right blob + + _gauss(128, 128, 40, 0.55) # dim centre blob (large) + + _gauss( 55, 195, 20, 0.90) # bright bottom-left blob + + _gauss(200, 185, 30, 0.60) # mid bottom-right blob + + 0.08 * rng.standard_normal((N, N)) # noise +) +# Normalise to [0, 1] +image = (image - image.min()) / (image.max() - image.min()) + +# ── Segmentation: pure-numpy BFS flood-fill ─────────────────────────────────── + +def _bfs_region(img, row: int, col: int, tol: float) -> np.ndarray: + """Return a boolean mask for the connected region reachable from (row, col). + + Connectivity is 4-connected. A neighbour is accepted when + ``|img[neighbour] - centre_value| <= tol``, where *centre_value* is the + intensity of the seed pixel (fixed, not growing). + """ + H, W = img.shape + seed_val = img[row, col] + visited = np.zeros((H, W), dtype=bool) + visited[row, col] = True + stack = [(row, col)] + while stack: + r, c = stack.pop() + for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)): + nr, nc = r + dr, c + dc + if 0 <= nr < H and 0 <= nc < W and not visited[nr, nc]: + if abs(float(img[nr, nc]) - float(seed_val)) <= tol: + visited[nr, nc] = True + stack.append((nr, nc)) + return visited + + +def _compute_mask(img, pos_seeds, neg_seeds, tol): + """Union of positive-seed BFS regions minus any negative-seed regions.""" + if not pos_seeds: + return np.zeros(img.shape, dtype=bool) + combined = np.zeros(img.shape, dtype=bool) + for r, c in pos_seeds: + combined |= _bfs_region(img, r, c, tol) + for r, c in neg_seeds: + combined &= ~_bfs_region(img, r, c, tol) + return combined + + +# ── State ───────────────────────────────────────────────────────────────────── + +pos_seeds: list[tuple[int, int]] = [] # (row, col) +neg_seeds: list[tuple[int, int]] = [] # (row, col) +tolerance: float = 0.08 +mask = np.zeros((N, N), dtype=bool) # exposed numpy array + +TOL_STEP = 0.01 +TOL_MIN = 0.005 +TOL_MAX = 0.40 +SEED_RADIUS_PIXELS = 5 # marker radius for seed dots + +# ── Figure ──────────────────────────────────────────────────────────────────── + +fig, ax = apl.subplots(figsize=(520, 520), + help="Left-click → add positive seed (grow mask)\n" + "Shift + Left-click → add negative seed (shrink mask)\n" + "Hover + Delete → remove nearest seed\n" + "+ / - → increase / decrease tolerance\n" + "c → clear all seeds") + +plot = ax.imshow(image) +plot.set_colormap("gray") + +# ── Persistent marker groups ────────────────────────────────────────────────── +# Create named groups once so _refresh() can update them with .set() instead of +# clear_markers() + add_circles(). Placing the placeholder far off-screen means +# empty groups render nothing without needing a special empty-list code path. +_HIDDEN = [[-9999.0, -9999.0]] # off-screen placeholder for an empty group + +plot.add_circles(_HIDDEN, name="pos", + facecolors="#00c853", edgecolors="#ffffff", + radius=SEED_RADIUS_PIXELS) +plot.add_circles(_HIDDEN, name="neg", + facecolors="#b71c1c", edgecolors="#ffffff", + radius=SEED_RADIUS_PIXELS) + +# ── Helpers: marker refresh and mask push ──────────────────────────────────── + +def _refresh(): + """Recompute mask and push updated markers + overlay in one go. + + Updates the two persistent marker groups in-place (no clear → blank → add + cycle) so there is no visible flicker when a seed is removed. + Each group has its own fixed colour string so the JS fill_color field + always receives a valid CSS colour (not a mixed list). + """ + global mask + mask = _compute_mask(image, pos_seeds, neg_seeds, tolerance) + + # Update offsets for each group; fall back to off-screen placeholder when empty. + pos_offsets = [(c, r) for r, c in pos_seeds] or _HIDDEN + neg_offsets = [(c, r) for r, c in neg_seeds] or _HIDDEN + plot.markers["circles"]["pos"].set(offsets=pos_offsets) + plot.markers["circles"]["neg"].set(offsets=neg_offsets) + + # Transparent overlay — teal for positive mask regions. + plot.set_overlay_mask(mask, color="#00e5ff", alpha=0.38) + + +# ── Click handler ───────────────────────────────────────────────────────────── + +@plot.add_event_handler("pointer_down") +def _on_click(event): + """Left-click → positive seed; Shift+Left-click → negative seed.""" + # xdata = column, ydata = row (image-pixel coordinates) + col = int(round(float(event.xdata))) + row = int(round(float(event.ydata))) + # Clamp to image bounds + col = max(0, min(N - 1, col)) + row = max(0, min(N - 1, row)) + + if getattr(event, "shift_key", False): + neg_seeds.append((row, col)) + else: + pos_seeds.append((row, col)) + + _refresh() + + +# ── Key bindings ────────────────────────────────────────────────────────────── + +@plot.add_event_handler("key_down") +def _tol_up(event): + """Increase tolerance → flood-fill grows to wider intensity range.""" + if event.key not in ('+', '='): # '+' on most keyboards requires Shift; '=' is the unshifted key + return + global tolerance + tolerance = min(TOL_MAX, round(tolerance + TOL_STEP, 4)) + _refresh() + print(f" tolerance = {tolerance:.3f}", end="\r") + + +@plot.add_event_handler("key_down") +def _tol_down(event): + """Decrease tolerance → flood-fill shrinks to narrower range.""" + if event.key != '-': + return + global tolerance + tolerance = max(TOL_MIN, round(tolerance - TOL_STEP, 4)) + _refresh() + print(f" tolerance = {tolerance:.3f}", end="\r") + + +@plot.add_event_handler("key_down") +def _clear(event): + """Clear all seeds and reset the mask.""" + if event.key != 'c': + return + pos_seeds.clear() + neg_seeds.clear() + _refresh() + print(" seeds cleared", end="\r") + + +@plot.add_event_handler("key_down") +def _delete_nearest(event): + """Remove the seed (positive or negative) nearest to the cursor.""" + if event.key not in ('Delete', 'Backspace'): + return + cx = float(event.xdata) + cy = float(event.ydata) # ydata = row + + best_dist = float("inf") + best_list = None + best_idx = -1 + + for lst in (pos_seeds, neg_seeds): + for i, (r, c) in enumerate(lst): + d = (c - cx) ** 2 + (r - cy) ** 2 + if d < best_dist: + best_dist = d + best_list = lst + best_idx = i + + if best_list is not None and best_dist <= (12 ** 2): + best_list.pop(best_idx) + _refresh() + + +fig # Interactive + + diff --git a/Examples/Interactive/plot_segment_by_contrast_advanced.py b/Examples/Interactive/plot_segment_by_contrast_advanced.py new file mode 100644 index 00000000..e54e57a3 --- /dev/null +++ b/Examples/Interactive/plot_segment_by_contrast_advanced.py @@ -0,0 +1,355 @@ +""" +Advanced Interactive Contrast Segmentation (3 × 3 Grid) +========================================================= + +A 3 × 3 grid of synthetic images, each independently segmented by +flood-fill. Pass 8 or 9 images as ``images_flat``; the grid always +has 3 columns and enough rows to fit them all (last cell left blank +when 8 images are supplied). + +**Interaction** + +.. list-table:: + :header-rows: 1 + :widths: 30 70 + + * - Action + - Effect + * - **Left-click** + - Add a *positive* seed (green dot) on the clicked panel. + * - **Shift + Left-click** + - Add a *negative* seed (red dot) — subtracts that connected region + from the mask. + * - **Ctrl + Left-click** + - Add a polygon vertex to the *clip polygon* of the active panel. + The mask is restricted to pixels inside the polygon once at least + 3 vertices exist. + * - **Drag polygon vertex** + - Reposition any clip-polygon vertex; mask updates on mouse-up. + * - **Hover + Delete / Backspace** + - Remove the clip vertex or seed nearest to the cursor (≤ 15 px). + * - **+** / **=** + - Increase tolerance (grow regions). + * - **-** + - Decrease tolerance (shrink regions). + * - **c** + - Clear all seeds (keeps clip polygon). + * - **p** + - Clear the clip polygon. + +After interaction, the resulting boolean mask arrays are in ``masks_flat`` +(same order as ``images_flat``). + +.. note:: + Click on a panel first to give it keyboard focus, then use the key + bindings. +""" + +import math +import numpy as np +import anyplotlib as apl + +# ── Helpers ─────────────────────────────────────────────────────────────────── + +N = 192 # image size (pixels per side) for the synthetic demo images +NCOLS = 3 # fixed column count + +rng = np.random.default_rng(42) +xx, yy = np.meshgrid(np.arange(N), np.arange(N)) + + +def _gauss(cx, cy, sigma, amplitude): + return amplitude * np.exp(-((xx - cx) ** 2 + (yy - cy) ** 2) / (2 * sigma ** 2)) + + +def _make_image(seed): + """Synthesise a unique multi-blob test image.""" + r = np.random.default_rng(seed) + blobs = [ + (r.integers(30, N - 30), r.integers(30, N - 30), + r.integers(15, 35), r.uniform(0.5, 1.0)) + for _ in range(5) + ] + img = sum(_gauss(cx, cy, sig, amp) for cx, cy, sig, amp in blobs) + img += 0.06 * r.standard_normal((N, N)) + return (img - img.min()) / (img.max() - img.min()) + + +# ── Images — swap this list for your own (8 or 9 arrays of shape (H, W)) ───── + +images_flat = [_make_image(seed) for seed in range(1, 9)] # 8 images +# images_flat = [_make_image(seed) for seed in range(1, 10)] # uncomment for 9 + +# ── Grid geometry derived from the image list ───────────────────────────────── + +n_images = len(images_flat) +if n_images not in (8, 9): + raise ValueError(f"images_flat must contain 8 or 9 images, got {n_images}") + +NROWS = math.ceil(n_images / NCOLS) # 3 for both 8 and 9 + +# ── BFS flood-fill ──────────────────────────────────────────────────────────── + +def _bfs_region(img, row, col, tol): + H, W = img.shape + seed_val = float(img[row, col]) + visited = np.zeros((H, W), dtype=bool) + visited[row, col] = True + stack = [(row, col)] + while stack: + r, c = stack.pop() + for dr, dc in ((-1, 0), (1, 0), (0, -1), (0, 1)): + nr, nc = r + dr, c + dc + if 0 <= nr < H and 0 <= nc < W and not visited[nr, nc]: + if abs(float(img[nr, nc]) - seed_val) <= tol: + visited[nr, nc] = True + stack.append((nr, nc)) + return visited + + +def _compute_mask(img, pos_seeds, neg_seeds, tol, clip_poly): + """Flood-fill union, optionally restricted to a drawn polygon.""" + H, W = img.shape + if not pos_seeds: + return np.zeros((H, W), dtype=bool) + combined = np.zeros((H, W), dtype=bool) + for r, c in pos_seeds: + combined |= _bfs_region(img, r, c, tol) + for r, c in neg_seeds: + combined &= ~_bfs_region(img, r, c, tol) + + if clip_poly and len(clip_poly) >= 3: + # Pure-numpy even-odd ray-casting point-in-polygon + # Polygon vertices are [x, y] = [col, row] in image-pixel space + poly = np.asarray(clip_poly, dtype=float) # (K, 2) as [x, y] + rows = np.arange(H, dtype=float) + cols = np.arange(W, dtype=float) + gc, gr = np.meshgrid(cols, rows) # gc[r,c]=col, gr[r,c]=row + xs = gc.ravel() # x = col index + ys = gr.ravel() # y = row index + inside = np.zeros(H * W, dtype=bool) + n_v = len(poly) + xp, yp = poly[:, 0], poly[:, 1] + for i in range(n_v): + x1, y1 = xp[i], yp[i] + x2, y2 = xp[(i + 1) % n_v], yp[(i + 1) % n_v] + cond = ((y1 > ys) != (y2 > ys)) & ( + xs < (x2 - x1) * (ys - y1) / (y2 - y1 + 1e-12) + x1 + ) + inside ^= cond + combined &= inside.reshape(H, W) + + return combined + + +# ── Per-panel state (flat) ──────────────────────────────────────────────────── + +TOL_STEP = 0.01 +TOL_MIN = 0.005 +TOL_MAX = 0.40 +SEED_RADIUS = 4 +_HIDDEN = [[-9999.0, -9999.0]] +_OFFSCREEN_TRI = [[-9990.0, -9990.0], [-9989.0, -9990.0], [-9989.0, -9989.0]] + +_CMAPS = ["gray", "viridis", "plasma", "inferno", "magma", + "cividis", "hot", "cool", "bone"] + +panel_state = [ + {"pos_seeds": [], "neg_seeds": [], "tolerance": 0.08, "clip_poly": []} + for _ in range(n_images) +] +masks_flat = [np.zeros((N, N), dtype=bool) for _ in range(n_images)] +active_idx = [0] + +# ── Figure ──────────────────────────────────────────────────────────────────── + +fig, axes = apl.subplots( + NROWS, NCOLS, + figsize=(900, 900), + help=( + "Left-click → positive seed (grow)\n" + "Shift + Left-click → negative seed (shrink)\n" + "Ctrl + Left-click → add clip-polygon vertex\n" + "Drag polygon vertex → reposition (mask updates on release)\n" + "Delete / Backspace → remove nearest vertex or seed\n" + "+ / - → tolerance up / down\n" + "c → clear seeds\n" + "p → clear clip polygon" + ), +) + +# Flatten axes to a 1-D list (row-major, matches images_flat) +axes_flat = [axes[r][c] for r in range(NROWS) for c in range(NCOLS)] + +# Build plot objects only for panels that have an image +plots_flat = [] +clip_wids = [] # one PolygonWidget per panel + +for idx in range(n_images): + p = axes_flat[idx].imshow(images_flat[idx]) + p.set_colormap(_CMAPS[idx % len(_CMAPS)]) + + # Seed marker groups + p.add_circles(_HIDDEN, name="pos", + facecolors="#69f0ae", edgecolors="#ffffff", + radius=SEED_RADIUS) + p.add_circles(_HIDDEN, name="neg", + facecolors="#ff5252", edgecolors="#ffffff", + radius=SEED_RADIUS) + + # Preview dots for partial polygon (< 3 vertices — before widget takes over) + p.add_circles(_HIDDEN, name="clip_pts", + facecolors="#ffeb3b", edgecolors="#ffffff", + radius=3) + + # Draggable polygon widget — starts offscreen until ≥ 3 vertices are placed. + # The widget provides per-vertex handles that can be dragged in the browser. + wid = p.add_widget("polygon", color="#ffeb3b", vertices=_OFFSCREEN_TRI) + clip_wids.append(wid) + + plots_flat.append(p) + + +# ── Refresh helper ──────────────────────────────────────────────────────────── + +def _refresh(idx): + """Recompute mask and push all markers + overlay for panel ``idx``.""" + try: + st = panel_state[idx] + p = plots_flat[idx] + img = images_flat[idx] + + masks_flat[idx] = _compute_mask( + img, st["pos_seeds"], st["neg_seeds"], + st["tolerance"], st["clip_poly"], + ) + + # Seed marker dots + pos_off = [(c, r) for r, c in st["pos_seeds"]] or _HIDDEN + neg_off = [(c, r) for r, c in st["neg_seeds"]] or _HIDDEN + p.markers["circles"]["pos"].set(offsets=pos_off) + p.markers["circles"]["neg"].set(offsets=neg_off) + + # Clip polygon widget — show real vertices once we have ≥ 3, else offscreen + clip = st["clip_poly"] + if len(clip) >= 3: + clip_wids[idx].set(vertices=clip) + # Hide the preview dots (widget handles are enough) + p.markers["circles"]["clip_pts"].set(offsets=_HIDDEN) + else: + clip_wids[idx].set(vertices=_OFFSCREEN_TRI) + # Show partial-polygon vertex dots during the building phase + clip_off = [[v[0], v[1]] for v in clip] or _HIDDEN + p.markers["circles"]["clip_pts"].set(offsets=clip_off) + + # Mask overlay + p.set_overlay_mask(masks_flat[idx], color="#00e5ff", alpha=0.38) + + except Exception as exc: + import traceback + print(f"[panel {idx}] _refresh error: {exc}") + traceback.print_exc() + + +# ── Click & key handlers (one closure per panel) ────────────────────────────── + +def _make_handlers(idx): + p = plots_flat[idx] + wid = clip_wids[idx] + img = images_flat[idx] + H, W = img.shape + + # ── Polygon widget: sync vertices → panel_state after any drag ──────────── + @wid.add_event_handler("pointer_up") + def _poly_dragged(event): + active_idx[0] = idx + vs = wid.vertices # widget data is synced from JS before callbacks + if vs is None: + return + # Filter out any accidental off-screen dummy vertices + real = [[float(v[0]), float(v[1])] for v in vs + if abs(float(v[0])) < 9000 and abs(float(v[1])) < 9000] + panel_state[idx]["clip_poly"] = real + _refresh(idx) + + # ── Click: add seed or polygon vertex ───────────────────────────────────── + @p.add_event_handler("pointer_down") + def _on_click(event): + if event.xdata is None or event.ydata is None: + return + active_idx[0] = idx + st = panel_state[idx] + r_px = max(0, min(H - 1, int(round(float(event.ydata))))) + c_px = max(0, min(W - 1, int(round(float(event.xdata))))) + if "ctrl" in event.modifiers: + st["clip_poly"].append([float(c_px), float(r_px)]) + elif "shift" in event.modifiers: + st["neg_seeds"].append((r_px, c_px)) + else: + st["pos_seeds"].append((r_px, c_px)) + _refresh(idx) + + # ── Keys: tolerance, clear, delete-nearest ───────────────────────────────── + @p.add_event_handler("key_down") + def _on_key(event): + active_idx[0] = idx + st = panel_state[idx] + if event.key in ("+", "="): + st["tolerance"] = min(TOL_MAX, round(st["tolerance"] + TOL_STEP, 4)) + _refresh(idx) + elif event.key == "-": + st["tolerance"] = max(TOL_MIN, round(st["tolerance"] - TOL_STEP, 4)) + _refresh(idx) + elif event.key == "c": + st["pos_seeds"].clear() + st["neg_seeds"].clear() + _refresh(idx) + elif event.key == "p": + st["clip_poly"].clear() + _refresh(idx) + elif event.key in ("Delete", "Backspace"): + _delete_nearest(event) + + def _delete_nearest(event): + st = panel_state[idx] + if event.xdata is None or event.ydata is None: + return + cx = float(event.xdata) + cy = float(event.ydata) + HIT2 = 15 ** 2 # hit radius squared (px) + + # Check clip-polygon vertices first (they're on top visually) + best_dist = float("inf") + best_poly_i = -1 + for i, (vx, vy) in enumerate(st["clip_poly"]): + d = (vx - cx) ** 2 + (vy - cy) ** 2 + if d < best_dist: + best_dist = d + best_poly_i = i + + if best_poly_i >= 0 and best_dist <= HIT2: + st["clip_poly"].pop(best_poly_i) + _refresh(idx) + return + + # Otherwise check seeds + best_dist = float("inf") + best_list = None + best_i = -1 + for lst in (st["pos_seeds"], st["neg_seeds"]): + for i, (r, c) in enumerate(lst): + d = (c - cx) ** 2 + (r - cy) ** 2 + if d < best_dist: + best_dist = d + best_list = lst + best_i = i + + if best_list is not None and best_dist <= HIT2: + best_list.pop(best_i) + _refresh(idx) + + +_handlers = [_make_handlers(idx) for idx in range(n_images)] + +fig + diff --git a/Examples/Interactive/plot_spectra_roi_inspector.py b/Examples/Interactive/plot_spectra_roi_inspector.py new file mode 100644 index 00000000..eef36b9c --- /dev/null +++ b/Examples/Interactive/plot_spectra_roi_inspector.py @@ -0,0 +1,265 @@ +""" +ROI-to-spectrum inspector for a 3-D EDS hyperspectral dataset. +============================================================== + +A synthetic ``(256, 256, 300)`` EDS datacube — one 300-channel +spectrum per scan position. Four rectangular ROIs overlay the +total-counts image (HAADF proxy). Entering an ROI **sums all spectra +within the rectangle** (spatial sum over every scan position in the +box) and displays the result in the top-right panel. Draggable +coloured range widgets on the spectrum define the integration window +for each element; each bar height is the **channel sum of the ROI +spectrum within that window**. + +**Interaction** + +* **Move cursor inside an ROI** — spatially sums the spectra of all + scan positions inside the box; updates the line plot and bars live. +* **Drag an ROI rectangle** — repositions the ROI on the image. +* **Release drag** — recomputes the spatial sum spectrum for the new + position. +* **Drag a coloured range widget** on the spectrum — adjusts the + integration window for that element; bar heights update on every + drag frame. +""" +import numpy as np +import anyplotlib as apl + + +# ── synthetic 3-D hyperspectral datacube ────────────────────────────────────── +# Shape: (NY, NX, NC). dataset[y, x, :] is the 300-channel EDS spectrum at +# scan position (x, y). Each pixel is an independent Poisson draw from the +# expected spectrum for its phase. + +NY, NX, NC = 256, 256, 300 +ENERGY = np.linspace(0.1, 3.0, NC) # keV + +EDS_ELEMENTS = ["O", "Fe", "Al", "Si"] +_EDS_EV = [0.525, 0.710, 1.487, 1.740] # characteristic keV +_EDS_WIN = [(0.45, 0.61), (0.64, 0.80), (1.40, 1.58), (1.65, 1.83)] +_EDS_SIGMA = 0.025 +_EDS_COLORS = ["#ff8a65", "#ba68c8", "#4fc3f7", "#aed581"] + +_PEAKS = np.array([ + np.exp(-0.5 * ((ENERGY - ev) / _EDS_SIGMA) ** 2) + for ev in _EDS_EV +]) # shape (4, NC) + +# Per-phase element weight vectors [O, Fe, Al, Si] and expected total +# counts per pixel (determines peak-to-background ratio and brightness). +_PHASE_DEFS = [ + dict(weights=[0.10, 0.05, 0.65, 0.20], counts=80), # 0 Matrix + dict(weights=[0.05, 0.08, 0.12, 0.75], counts=200), # 1 Precipitate A + dict(weights=[0.12, 0.60, 0.18, 0.10], counts=150), # 2 Precipitate B + dict(weights=[0.62, 0.12, 0.18, 0.08], counts=110), # 3 Grain Boundary +] + + +def _expected_spectrum(phase_idx: int) -> np.ndarray: + p = _PHASE_DEFS[phase_idx] + bkg = 3.0 * np.exp(-ENERGY / 0.8) + spec = bkg + (_PEAKS * np.array(p["weights"])[:, None]).sum(axis=0) * p["counts"] + return np.clip(spec, 0, None).astype(np.float64) + + +def _make_dataset(rng: np.random.Generator) -> tuple[np.ndarray, np.ndarray]: + phases = np.zeros((NY, NX), dtype=np.int8) # 0 = Matrix + + # Precipitate A (Si-rich) — cluster in top-left quadrant + for cx, cy, r in [(60, 60, 30), (75, 50, 22), (45, 75, 20)]: + ys, xs = np.ogrid[:NY, :NX] + phases[(xs - cx) ** 2 + (ys - cy) ** 2 < r ** 2] = 1 + + # Precipitate B (Fe-rich) — cluster in bottom-right quadrant + for cx, cy, r in [(195, 195, 27), (180, 210, 20), (210, 180, 17)]: + ys, xs = np.ogrid[:NY, :NX] + phases[(xs - cx) ** 2 + (ys - cy) ** 2 < r ** 2] = 2 + + # Grain boundary — thin horizontal band + phases[120:135, :] = 3 + + dataset = np.empty((NY, NX, NC), dtype=np.float32) + flat = dataset.reshape(-1, NC) + phases_flat = phases.ravel() + for pidx, pdef in enumerate(_PHASE_DEFS): + sel = phases_flat == pidx + n = int(sel.sum()) + if n == 0: + continue + lam = _expected_spectrum(pidx) + flat[sel] = rng.poisson(lam, size=(n, NC)).astype(np.float32) + + return dataset, phases + + +rng = np.random.default_rng(99) +dataset, _phase_map = _make_dataset(rng) + +# Total-counts image used as the HAADF-proxy display image +_display_img = dataset.sum(axis=2) + + +# ── ROI definitions (r0, r1, c0, c1) in scan-pixel coordinates ──────────────── + +ROIS: dict[str, tuple[int, int, int, int]] = { + "Matrix": ( 25, 100, 155, 230), + "Precipitate A": ( 25, 100, 25, 100), + "Precipitate B": (155, 230, 155, 230), + "Grain Boundary": (115, 140, 25, 230), +} +_ROI_COLORS: dict[str, str] = { + "Matrix": "#4fc3f7", + "Precipitate A": "#aed581", + "Precipitate B": "#ff8a65", + "Grain Boundary": "#ba68c8", +} + + +def _sum_spectrum(r0: int, r1: int, c0: int, c1: int) -> np.ndarray: + """Spatial sum of all spectra within the ROI box.""" + r0 = max(0, min(NY - 1, r0)); r1 = max(1, min(NY, r1)) + c0 = max(0, min(NX - 1, c0)); c1 = max(1, min(NX, c1)) + return dataset[r0:r1, c0:c1, :].sum(axis=(0, 1)) + + +def _roi_at(x: float, y: float) -> str | None: + for name, (r0, r1, c0, c1) in ROIS.items(): + if c0 <= x <= c1 and r0 <= y <= r1: + return name + return None + + +# ── layout ───────────────────────────────────────────────────────────────────── + +fig = apl.Figure(figsize=(1100, 560)) +gs = apl.GridSpec(2, 2, width_ratios=[1, 1], height_ratios=[1, 1]) + +ax_img = fig.add_subplot(gs[:, 0]) # total-counts image — left column +ax_spec = fig.add_subplot(gs[0, 1]) # ROI sum spectrum — top right +ax_bar = fig.add_subplot(gs[1, 1]) # element bar chart — bottom right + +img_plot = ax_img.imshow(_display_img, cmap="gray") + +_init_spec = _sum_spectrum(*ROIS["Matrix"]).astype(np.float32) +spec_plot = ax_spec.plot(_init_spec, axes=[ENERGY], + color=_ROI_COLORS["Matrix"], linewidth=1.5, + units="keV", y_units="counts") +bar_plot = ax_bar.bar(EDS_ELEMENTS, [0.0] * 4) + + +# ── ROI rectangle overlays on the image ─────────────────────────────────────── + +_roi_widgets: dict[str, object] = {} +for roi_name, (r0, r1, c0, c1) in ROIS.items(): + w = img_plot.add_widget( + "rectangle", + x=float(c0), y=float(r0), + w=float(c1 - c0), h=float(r1 - r0), + color=_ROI_COLORS[roi_name], + ) + _roi_widgets[roi_name] = w + +status_label = img_plot.add_widget( + "label", x=4, y=248, text="Move cursor into an ROI", + color="#ffffff", fontsize=10, +) + + +# ── adjustable range widgets on the spectrum ─────────────────────────────────── + +range_widgets: dict[str, object] = {} +for elem, (lo, hi), color in zip(EDS_ELEMENTS, _EDS_WIN, _EDS_COLORS): + range_widgets[elem] = spec_plot.add_range_widget(lo, hi, color=color) + +_current_spectrum: list[np.ndarray] = [_init_spec.copy()] + + +def _channel_sum(x0: float, x1: float) -> float: + """Sum of ROI spectrum counts within the energy window [x0, x1].""" + mask = (ENERGY >= x0) & (ENERGY <= x1) + return float(_current_spectrum[0][mask].sum()) if mask.any() else 0.0 + + +def _update_bars() -> None: + heights = np.array([ + _channel_sum(range_widgets[e].x0, range_widgets[e].x1) + for e in EDS_ELEMENTS + ]) + max_h = heights.max() or 1.0 + bar_plot.set_data((heights / max_h).tolist()) + + +for _rw in range_widgets.values(): + _rw.add_event_handler(lambda event: _update_bars(), "pointer_move") + _rw.add_event_handler(lambda event: _update_bars(), "pointer_up") + +_update_bars() + + +# ── update helper ────────────────────────────────────────────────────────────── + +_current_roi: list[str | None] = [None] +_roi_dragging = False + + +def _update_for_roi(roi_name: str) -> None: + _current_roi[0] = roi_name + r0, r1, c0, c1 = ROIS[roi_name] + _current_spectrum[0] = _sum_spectrum(r0, r1, c0, c1).astype(np.float32) + spec_plot.set_data(_current_spectrum[0], x_axis=ENERGY) + spec_plot.set_color(_ROI_COLORS[roi_name]) + _update_bars() + n_pixels = (r1 - r0) * (c1 - c0) + status_label.set(text=f"ROI: {roi_name} ({n_pixels} px)") + + +# ── event handlers ───────────────────────────────────────────────────────────── + +def _on_move(event) -> None: + if _roi_dragging or event.xdata is None or event.ydata is None: + return + roi_name = _roi_at(event.xdata, event.ydata) + if roi_name is None or roi_name == _current_roi[0]: + return + _update_for_roi(roi_name) + + +def _on_enter(event) -> None: + status_label.set(text="Move cursor into an ROI") + + +def _on_leave(event) -> None: + status_label.set(text="Move cursor over image to inspect") + _current_roi[0] = None + + +img_plot.add_event_handler(_on_move, "pointer_move") +img_plot.add_event_handler(_on_enter, "pointer_enter") +img_plot.add_event_handler(_on_leave, "pointer_leave") + +for roi_name, widget in _roi_widgets.items(): + def _make_drag_handler(): + def _on_drag(event) -> None: + global _roi_dragging + _roi_dragging = True + return _on_drag + + def _make_release_handler(name, wgt): + def _on_release(event) -> None: + global _roi_dragging + _roi_dragging = False + x, y, w, h = wgt.x, wgt.y, wgt.w, wgt.h + ROIS[name] = (int(y), int(y + h), int(x), int(x + w)) + _update_for_roi(name) + return _on_release + + widget.add_event_handler(_make_drag_handler(), "pointer_move") + widget.add_event_handler(_make_release_handler(roi_name, widget), "pointer_up") + +fig.set_help( + "Move cursor inside an ROI: spatial sum spectrum + bars\n" + "Drag ROI rectangle: repositions ROI; release recomputes\n" + "Drag a coloured range widget: adjust element integration window" +) + +fig # Interactive diff --git a/Examples/Interactive/plot_threshold_explorer.py b/Examples/Interactive/plot_threshold_explorer.py new file mode 100644 index 00000000..bb45ff4d --- /dev/null +++ b/Examples/Interactive/plot_threshold_explorer.py @@ -0,0 +1,138 @@ +""" +Live intensity thresholding on a multi-phase STEM image. +========================================================= + +A side-by-side view: the left panel shows a synthetic 512×512 STEM +image with a red overlay marking pixels above the threshold; the right +panel shows a 32-bin intensity histogram with a yellow vertical line at +the current threshold value. + +**Interaction** + +* **Shift+Scroll** over the image — adjusts the threshold by ±2 per + wheel tick (plain scroll pans/zooms the image as normal). +* **Click** a histogram bar — jumps the threshold to that bin's upper + edge. +* **Dwell 400 ms** over the image — shows pixel coordinates and + intensity in the bottom-left label. +""" +import numpy as np +import anyplotlib as apl + + +# ── synthetic data ───────────────────────────────────────────────────────────── + +def _make_multiphase_image(rng: np.random.Generator) -> np.ndarray: + img = rng.normal(20, 5, (512, 512)).astype(np.float32) + + # Grain A — 6 large blobs + for _ in range(6): + cx, cy = rng.integers(60, 452, size=2) + r = rng.integers(40, 80) + ys, xs = np.ogrid[:512, :512] + mask = (xs - cx) ** 2 + (ys - cy) ** 2 < r ** 2 + img[mask] = rng.normal(80, 8, mask.sum()) + + # Grain B — 8 smaller blobs + for _ in range(8): + cx, cy = rng.integers(40, 472, size=2) + r = rng.integers(15, 35) + ys, xs = np.ogrid[:512, :512] + mask = (xs - cx) ** 2 + (ys - cy) ** 2 < r ** 2 + img[mask] = rng.normal(130, 10, mask.sum()) + + # Voids — 12 dark circular regions + for _ in range(12): + cx, cy = rng.integers(20, 492, size=2) + r = rng.integers(8, 20) + ys, xs = np.ogrid[:512, :512] + mask = (xs - cx) ** 2 + (ys - cy) ** 2 < r ** 2 + img[mask] = rng.normal(5, 2, mask.sum()) + + return np.clip(img, 0, 255).astype(np.float32) + + +rng = np.random.default_rng(13) +image = _make_multiphase_image(rng) + +NBINS = 32 +counts, bin_edges = np.histogram(image, bins=NBINS, range=(0, 255)) +bin_centers = 0.5 * (bin_edges[:-1] + bin_edges[1:]) +x_labels = [f"{int(v)}" for v in bin_centers] + +threshold = 100.0 + + +# ── figure ───────────────────────────────────────────────────────────────────── + +fig, (ax_img, ax_hist) = apl.subplots(1, 2, figsize=(900, 500)) + +img_plot = ax_img.imshow(image, cmap="gray") +hist_plot = ax_hist.bar(x_labels, counts.astype(float)) + +# Track the threshold vline widget so we can remove/replace it +_thresh_widget = None + + +def _pct_above(thresh: float) -> float: + return 100.0 * float((image >= thresh).sum()) / image.size + + +def _update_display(thresh: float) -> None: + global threshold, _thresh_widget + threshold = float(np.clip(thresh, 0, 255)) + mask = image >= threshold + img_plot.set_overlay_mask(mask, color="#ff0000", alpha=0.35) + # Remove old threshold line widget and add a new one + if _thresh_widget is not None: + try: + hist_plot.remove_widget(_thresh_widget) + except KeyError: + pass + _thresh_widget = hist_plot.add_vline_widget(threshold, color="#ffeb3b") + pct = _pct_above(threshold) + print(f"Threshold: {threshold:.0f} | {pct:.1f}% above") + + +_update_display(threshold) + +info_label = img_plot.add_widget("label", x=10, y=490, text="", color="#ffeb3b", fontsize=11) + + +# ── event handlers ───────────────────────────────────────────────────────────── + +def _on_wheel(event) -> None: + if "shift" not in event.modifiers: + return + delta = -2.0 * np.sign(event.dy) if event.dy != 0 else 0.0 + _update_display(threshold + delta) + + +def _on_bar_click(event) -> None: + idx = event.bar_index + if idx is None: + return + new_thresh = float(bin_edges[idx + 1]) + _update_display(new_thresh) + + +def _on_settled(event) -> None: + if event.xdata is None or event.ydata is None: + return + x = int(np.clip(round(event.xdata), 0, 511)) + y = int(np.clip(round(event.ydata), 0, 511)) + intensity = float(image[y, x]) + info_label.set(text=f"px ({x}, {y}): {intensity:.0f}", x=10, y=490) + + +img_plot.add_event_handler(_on_wheel, "wheel") +img_plot.add_event_handler(_on_settled, "pointer_settled", ms=400, delta=4) +hist_plot.add_event_handler(_on_bar_click, "pointer_down") + +fig.set_help( + "Shift+Scroll over image: adjust threshold ±2\n" + "Click histogram bar: jump to bin upper edge\n" + "Dwell 400 ms over image: inspect pixel intensity" +) + +fig # Interactive diff --git a/Examples/Interactive/plot_voxel_grain_explorer.py b/Examples/Interactive/plot_voxel_grain_explorer.py new file mode 100644 index 00000000..47f73fb9 --- /dev/null +++ b/Examples/Interactive/plot_voxel_grain_explorer.py @@ -0,0 +1,292 @@ +""" +3-D Voxel Grain Explorer +======================== + +An orthoslice viewer for a synthetic 3-D polycrystal (voxel grain map), +in the style of EBSD/tomography volume browsers: + +* **Top row** — the three orthogonal slices (XY, XZ, YZ) through the + current voxel, rendered as true-colour IPF-RGB images. Each carries a + draggable crosshair; the three crosshairs are **linked**: dragging one + moves the slice planes of the other two views. +* **Bottom left** — the grain volume rendered as **translucent shaded + voxels** with three draggable **plane widgets** (the slice selectors in + 3-D). Voxels lying on a selected plane render more opaque, so the + current slices glow inside the volume. Drag a plane along its normal to + re-slice — the 2-D views follow. +* **Bottom right** — the *reduced 3-D inverse pole figure*: the selected + voxel's grain orientation is highlighted on the wireframed unit sphere, + which **rotates to face that crystal direction**. + +Everything is bidirectionally linked: drag a crosshair OR a 3-D plane and +the other views re-cut, the voxel highlight moves, and the IPF re-aims. +Drag empty space on either 3-D panel to orbit it freely. +""" + +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(11) + +# ── 1. Synthetic 3-D polycrystal: nearest-seed voxel grain map ────────────── +N = 48 # volume is N³ voxels, indexed V[z, y, x] +N_GRAINS = 40 + +seeds = rng.uniform(0, N, size=(N_GRAINS, 3)) # (z, y, x) +zz, yy, xx = np.mgrid[0:N, 0:N, 0:N] +gid = np.zeros((N, N, N), dtype=np.int32) +best = np.full((N, N, N), np.inf) +for g, (sz, sy, sx) in enumerate(seeds): + d = (zz - sz) ** 2 + (yy - sy) ** 2 + (xx - sx) ** 2 + closer = d < best + gid[closer] = g + best[closer] = d[closer] + + +# ── 2. Orientations, cubic fundamental-sector reduction, IPF colours ──────── +def random_rotations(n): + """Uniform random rotation matrices, shape (n, 3, 3) (Shoemake method).""" + u1, u2, u3 = rng.random((3, n)) + q = np.stack([ + np.sqrt(1 - u1) * np.sin(2 * np.pi * u2), + np.sqrt(1 - u1) * np.cos(2 * np.pi * u2), + np.sqrt(u1) * np.sin(2 * np.pi * u3), + np.sqrt(u1) * np.cos(2 * np.pi * u3), + ], axis=1) + x, y, z, w = q.T + return np.stack([ + np.stack([1 - 2 * (y * y + z * z), 2 * (x * y - z * w), 2 * (x * z + y * w)], -1), + np.stack([2 * (x * y + z * w), 1 - 2 * (x * x + z * z), 2 * (y * z - x * w)], -1), + np.stack([2 * (x * z - y * w), 2 * (y * z + x * w), 1 - 2 * (x * x + y * y)], -1), + ], axis=1) + + +rotations = random_rotations(N_GRAINS) +dirs = rotations[:, 2, :] # Rᵀ·ẑ per grain +reduced = np.sort(np.abs(dirs), axis=1) # cubic 001–011–111 +a, b, c = reduced.T +rgb = np.stack([c - b, b - a, a], axis=1) +rgb /= rgb.max(axis=1, keepdims=True) + 1e-12 +grain_rgb_u8 = (rgb * 255).astype(np.uint8) # (N_GRAINS, 3) + +# ── 3. Voxels for the 3-D volume view ─────────────────────────────────────── +# Rather than a sparse random subsample of the whole volume (where the +# highlight marker floats in empty space because almost no cube sits at the +# selected voxel), render the voxels that actually lie ON the three slice +# planes. This anchors the highlight exactly where the slices intersect, +# shows real slice contents in 3-D, and scales: the on-plane count is +# ~3·(N/step)² regardless of N, so it stays fast even for a 256³ volume. +VSTEP = max(1, N // 48) # in-plane downsample → ~48² cubes per plane + +# Voxel cube size in data units. A touch larger than VSTEP so the three +# slabs read as solid sheets rather than a dotted grid. +VOXSIZE = float(VSTEP) * 1.3 + + +def slice_voxels(ix, iy, iz): + """Voxel centres + colours lying on the x=ix, y=iy, z=iz planes.""" + s = VSTEP + rng_ax = np.arange(0, N, s) + parts = [] + # z = iz plane (vary x, y) + yy2, xx2 = np.meshgrid(rng_ax, rng_ax, indexing="ij") + parts.append(np.column_stack([xx2.ravel(), yy2.ravel(), + np.full(xx2.size, iz)])) + # y = iy plane (vary x, z) + zz2, xx2 = np.meshgrid(rng_ax, rng_ax, indexing="ij") + parts.append(np.column_stack([xx2.ravel(), np.full(xx2.size, iy), + zz2.ravel()])) + # x = ix plane (vary y, z) + zz2, yy2 = np.meshgrid(rng_ax, rng_ax, indexing="ij") + parts.append(np.column_stack([np.full(yy2.size, ix), yy2.ravel(), + zz2.ravel()])) + pts = np.vstack(parts) # (M, 3) as (x,y,z) + cols = grain_rgb_u8[gid[pts[:, 2], pts[:, 1], pts[:, 0]]] # gid[z,y,x] + return pts, cols + +# ── 4. Figure: 3 slices on top, volume + IPF below ────────────────────────── +gs = apl.GridSpec(2, 3) +fig = apl.Figure(figsize=(960, 640), + help="Drag a crosshair: the other two slices re-cut, the\n" + "3-D voxel highlight moves, and the IPF sphere rotates\n" + "to the selected grain's crystal direction.\n" + "Drag the 3-D panels to orbit them freely.") + +ax_xy = fig.add_subplot(gs[0, 0]) +ax_xz = fig.add_subplot(gs[0, 1]) +ax_yz = fig.add_subplot(gs[0, 2]) +ax_vol = fig.add_subplot(gs[1, 0]) +ax_ipf = fig.add_subplot(gs[1, 1:3]) + +ix, iy, iz = N // 2, N // 2, N // 2 # integer slice indices +fx, fy, fz = float(ix), float(iy), float(iz) # smooth highlight pos + +px = [np.arange(N)] * 2 # pixel axes → gutters + +v_xy = ax_xy.imshow(grain_rgb_u8[gid[iz]], axes=px, units="vox") +v_xz = ax_xz.imshow(grain_rgb_u8[gid[:, iy, :]], axes=px, units="vox") +v_yz = ax_yz.imshow(grain_rgb_u8[gid[:, :, ix]], axes=px, units="vox") +v_xy.set_xlabel("x"); v_xy.set_ylabel("y") +v_xz.set_xlabel("x"); v_xz.set_ylabel("z") +v_yz.set_xlabel("y"); v_yz.set_ylabel("z") + +cw_xy = v_xy.add_widget("crosshair", cx=ix, cy=iy, color="#ffffff") +cw_xz = v_xz.add_widget("crosshair", cx=ix, cy=iz, color="#ffffff") +cw_yz = v_yz.add_widget("crosshair", cx=iy, cy=iz, color="#ffffff") + +_vpts, _vcols = slice_voxels(ix, iy, iz) +v_vol = ax_vol.voxels( + _vpts[:, 0], _vpts[:, 1], _vpts[:, 2], colors=_vcols, + size=VOXSIZE, alpha=0.55, + x_label="x", y_label="y", z_label="z", + bounds=((0, N - 1),) * 3, zoom=1.1, +) +v_vol.set_title("Grain volume — drag a plane to re-slice") + +# Three draggable slice-selector planes; on-plane voxels render opaque +pw_yz = v_vol.add_widget("plane", axis="x", position=ix, color="#ff5252", alpha=0.18) +pw_xz = v_vol.add_widget("plane", axis="y", position=iy, color="#69f0ae", alpha=0.18) +pw_xy = v_vol.add_widget("plane", axis="z", position=iz, color="#40c4ff", alpha=0.18) + +v_ipf = ax_ipf.scatter3d( + reduced[:, 0], reduced[:, 1], reduced[:, 2], + colors=grain_rgb_u8, point_size=6, + x_label="[100]", y_label="[010]", z_label="[001]", + bounds=((-1, 1),) * 3, zoom=1.4, +) +v_ipf.set_title("Reduced 3D IPF") +v_ipf.set_sphere(1.0) + + +# ── 5. Linked updates ──────────────────────────────────────────────────────── +def face_camera(v): + """Turntable (az°, el°) aiming the camera straight down unit vector v.""" + el = np.degrees(np.arcsin(np.clip(v[2], -1.0, 1.0))) + az = np.degrees(np.arctan2(v[0], -v[1])) + return az, el + + +_busy = [False] # programmatic widget.set() fires callbacks — guard re-entry + + +def update(source: str) -> None: + """Re-cut the other slices, move crosshairs/highlights, re-aim the IPF.""" + _busy[0] = True + try: + # Coalesce every panel mutation below into one push per panel — without + # this, a single crosshair drag fires ~8 full-state pushes across the + # comm boundary, which is the main source of Pyodide lag. + with fig.batch(): + if source != "xy": + v_xy.set_data(grain_rgb_u8[gid[iz]]) + cw_xy.set(cx=ix, cy=iy) + if source != "xz": + v_xz.set_data(grain_rgb_u8[gid[:, iy, :]]) + cw_xz.set(cx=ix, cy=iz) + if source != "yz": + v_yz.set_data(grain_rgb_u8[gid[:, :, ix]]) + cw_yz.set(cx=iy, cy=iz) + v_xy.set_title(f"XY slice — z={iz}") + v_xz.set_title(f"XZ slice — y={iy}") + v_yz.set_title(f"YZ slice — x={ix}") + + # 3-D slice-selector planes follow at the SMOOTH position (skipped for + # the one being dragged, so its own live position isn't overwritten). + if source != "px": + pw_yz.set(position=fx) + if source != "py": + pw_xz.set(position=fy) + if source != "pz": + pw_xy.set(position=fz) + + # Re-cut the 3-D slab voxels to the new slice indices so the volume + # view shows the actual slice contents (bounded ~3·(N/VSTEP)² voxels). + _p, _c = slice_voxels(ix, iy, iz) + v_vol.set_data(_p[:, 0], _p[:, 1], _p[:, 2]) + v_vol.set_point_colors(_c) + + # Highlight tracks the SMOOTH plane positions (fx,fy,fz) so the marker + # glides with the planes instead of jumping by whole voxels. + v_vol.set_highlight(fx, fy, fz, color="#ffffff", size=7) + + g = int(gid[iz, iy, ix]) + v_ipf.set_highlight(*reduced[g], color="#ffffff", size=8) + az, el = face_camera(reduced[g]) + v_ipf.set_view(azimuth=az, elevation=el) + finally: + _busy[0] = False + + +def _clipf(v): + """Clamp a float position to the volume range (kept smooth for the marker).""" + return float(np.clip(v, 0.0, N - 1)) + + +def _i(v): + """Round a float position to the nearest integer slice index.""" + return int(round(v)) + + +@cw_xy.add_event_handler("pointer_move") +def _moved_xy(event): + global ix, iy, fx, fy + if _busy[0]: + return + fx, fy = _clipf(cw_xy.cx), _clipf(cw_xy.cy) + ix, iy = _i(fx), _i(fy) + update("xy") + + +@cw_xz.add_event_handler("pointer_move") +def _moved_xz(event): + global ix, iz, fx, fz + if _busy[0]: + return + fx, fz = _clipf(cw_xz.cx), _clipf(cw_xz.cy) + ix, iz = _i(fx), _i(fz) + update("xz") + + +@cw_yz.add_event_handler("pointer_move") +def _moved_yz(event): + global iy, iz, fy, fz + if _busy[0]: + return + fy, fz = _clipf(cw_yz.cx), _clipf(cw_yz.cy) + iy, iz = _i(fy), _i(fz) + update("yz") + + +@pw_yz.add_event_handler("pointer_move") +def _plane_x(event): + global ix, fx + if _busy[0]: + return + fx = _clipf(pw_yz.position) + ix = _i(fx) + update("px") + + +@pw_xz.add_event_handler("pointer_move") +def _plane_y(event): + global iy, fy + if _busy[0]: + return + fy = _clipf(pw_xz.position) + iy = _i(fy) + update("py") + + +@pw_xy.add_event_handler("pointer_move") +def _plane_z(event): + global iz, fz + if _busy[0]: + return + fz = _clipf(pw_xy.position) + iz = _i(fz) + update("pz") + + +update("none") + +fig # Interactive diff --git a/Examples/Markers/plot_arrows.py b/Examples/Markers/plot_arrows.py index cb4c7f91..7c6eae59 100644 --- a/Examples/Markers/plot_arrows.py +++ b/Examples/Markers/plot_arrows.py @@ -3,27 +3,27 @@ ====== Draw vector arrows on a 2-D image with -:meth:`~anyplotlib.figure_plots.Plot2D.add_arrows`. +:meth:`~anyplotlib.plot2d.Plot2D.add_arrows`. Use ``markers["arrows"]["name"].set(...)`` to update them live. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(3) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") tails = rng.uniform(15, 100, (8, 2)) U = rng.uniform(-18, 18, 8) V = rng.uniform(-18, 18, 8) - v.add_arrows(tails, U, V, name="flow", edgecolors="#76ff03", linewidths=2.0, label="flow vectors") + fig # %% diff --git a/Examples/Markers/plot_circles.py b/Examples/Markers/plot_circles.py new file mode 100644 index 00000000..c10fb944 --- /dev/null +++ b/Examples/Markers/plot_circles.py @@ -0,0 +1,34 @@ +""" +Circles +======= + +Mark circular features on a 2-D image with +:meth:`~anyplotlib.plot2d.Plot2D.add_circles`. +Use ``markers["circles"]["name"].set(...)`` to update them live. +""" +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(0) +data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) +data = (data - data.min()) / (data.max() - data.min()) +xy = np.linspace(0, 10, 128) + +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) +v = ax.imshow(data, axes=[xy, xy], units="nm") + +centres = rng.uniform(15, 113, (8, 2)) +v.add_circles(centres, name="spots", radius=10, + edgecolors="#ff1744", facecolors="#ff174433", + labels=[f"#{i}" for i in range(8)]) + +fig + +# %% +# Live update +# ----------- +# Call ``.set()`` on the marker group to push any change immediately. + +v.markers["circles"]["spots"].set(radius=16, edgecolors="#ffcc00", + facecolors="#ffcc0033") +fig diff --git a/Examples/Markers/plot_ellipses.py b/Examples/Markers/plot_ellipses.py index 6f2be5de..c641c095 100644 --- a/Examples/Markers/plot_ellipses.py +++ b/Examples/Markers/plot_ellipses.py @@ -3,18 +3,18 @@ ======== Draw ellipses on a 2-D image with -:meth:`~anyplotlib.figure_plots.Plot2D.add_ellipses`. +:meth:`~anyplotlib.plot2d.Plot2D.add_ellipses`. Use ``markers["ellipses"]["name"].set(...)`` to update them live. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(2) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") centres = np.array([[32.0, 32.0], [64.0, 96.0], [96.0, 48.0]]) @@ -23,6 +23,7 @@ name="grains", edgecolors="#ff9100", facecolors="#ff910033", label="grains", labels=["A", "B", "C"]) + fig # %% @@ -34,4 +35,3 @@ edgecolors="#69f0ae", facecolors="#69f0ae33") fig - diff --git a/Examples/Markers/plot_horizontal_lines.py b/Examples/Markers/plot_horizontal_lines.py new file mode 100644 index 00000000..579e07d8 --- /dev/null +++ b/Examples/Markers/plot_horizontal_lines.py @@ -0,0 +1,29 @@ +""" +Horizontal Lines +================ + +Draw static horizontal threshold lines on a 1-D plot with +:meth:`~anyplotlib.plot1d.Plot1D.add_hlines`. +Use ``markers["hlines"]["name"].set(...)`` to update them live. +""" +import numpy as np +import anyplotlib as apl + +x = np.linspace(0, 4 * np.pi, 512) +signal = np.sin(x) + +fig, ax = apl.subplots(1, 1, figsize=(560, 300)) +v = ax.plot(signal, axes=[x], units="rad") + +v.add_hlines([0.5, 0.0, -0.5], name="thresholds", + color="#69f0ae", linewidths=1.5, + label="thresholds", labels=["+0.5", "zero", "-0.5"]) + +fig + +# %% +# Live update +# ----------- + +v.markers["hlines"]["thresholds"].set(color="#ff1744", linewidths=2.0) +fig diff --git a/Examples/Markers/plot_line_segments.py b/Examples/Markers/plot_line_segments.py new file mode 100644 index 00000000..2f901493 --- /dev/null +++ b/Examples/Markers/plot_line_segments.py @@ -0,0 +1,40 @@ +""" +Line Segments +============= + +Draw line segments on a 2-D image with +:meth:`~anyplotlib.plot2d.Plot2D.add_lines`. +Use ``markers["lines"]["name"].set(...)`` to update them live. +""" +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(4) +data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) +data = (data - data.min()) / (data.max() - data.min()) +xy = np.linspace(0, 10, 128) + +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) +v = ax.imshow(data, axes=[xy, xy], units="nm") + +segments = np.array([ + [[ 10.0, 10.0], [118.0, 10.0]], + [[118.0, 10.0], [118.0, 118.0]], + [[118.0, 118.0], [ 10.0, 118.0]], + [[ 10.0, 118.0], [ 10.0, 10.0]], + [[ 10.0, 10.0], [118.0, 118.0]], +]) +v.add_lines(segments, name="frame", + edgecolors="#00e5ff", linewidths=1.5, + label="frame", + labels=["top", "right", "bottom", "left", "diagonal"]) + +fig + +# %% +# Live update +# ----------- +# Update stroke colour and width for all segments at once. + +v.markers["lines"]["frame"].set(edgecolors="#ff9100", linewidths=2.5) +fig diff --git a/Examples/Markers/plot_points.py b/Examples/Markers/plot_points.py new file mode 100644 index 00000000..aa713b89 --- /dev/null +++ b/Examples/Markers/plot_points.py @@ -0,0 +1,32 @@ +""" +Points +====== + +Mark specific (x, y) positions on a 1-D plot with +:meth:`~anyplotlib.plot1d.Plot1D.add_points`. +Use ``markers["points"]["name"].set(...)`` to update them live. +""" +import numpy as np +import anyplotlib as apl + +x = np.linspace(0, 4 * np.pi, 512) +signal = np.sin(x) + +fig, ax = apl.subplots(1, 1, figsize=(560, 300)) +v = ax.plot(signal, axes=[x], units="rad") + +peak_x = np.array([np.pi / 2, 5 * np.pi / 2, 9 * np.pi / 2]) +offsets = np.column_stack([peak_x, np.sin(peak_x)]) +v.add_points(offsets, name="peaks", + sizes=8, color="#ff1744", facecolors="#ff174433", + label="peaks", labels=["P1", "P2", "P3"]) + +fig + +# %% +# Live update +# ----------- + +v.markers["points"]["peaks"].set(sizes=12, color="#ffcc00", + facecolors="#ffcc0033") +fig diff --git a/Examples/Markers/plot_polygons.py b/Examples/Markers/plot_polygons.py new file mode 100644 index 00000000..16af4076 --- /dev/null +++ b/Examples/Markers/plot_polygons.py @@ -0,0 +1,38 @@ +""" +Polygons +======== + +Draw closed polygons on a 2-D image with +:meth:`~anyplotlib.plot2d.Plot2D.add_polygons`. +Use ``markers["polygons"]["name"].set(...)`` to update them live. +""" +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(5) +data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) +data = (data - data.min()) / (data.max() - data.min()) +xy = np.linspace(0, 10, 128) + +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) +v = ax.imshow(data, axes=[xy, xy], units="nm") + +triangle = [[64.0, 10.0], [100.0, 60.0], [28.0, 60.0]] +hexagon = [[64.0 + 28 * np.cos(np.radians(60 * k)), + 95.0 + 28 * np.sin(np.radians(60 * k))] + for k in range(6)] +v.add_polygons([triangle, hexagon], name="shapes", + edgecolors="#69f0ae", facecolors="#69f0ae22", + linewidths=2.0, + label="shapes", labels=["triangle", "hexagon"]) + +fig + +# %% +# Live update +# ----------- +# Change the stroke and fill colour of every polygon at once. + +v.markers["polygons"]["shapes"].set(edgecolors="#e040fb", + facecolors="#e040fb33") +fig diff --git a/Examples/Markers/plot_rectangles.py b/Examples/Markers/plot_rectangles.py new file mode 100644 index 00000000..32a7b279 --- /dev/null +++ b/Examples/Markers/plot_rectangles.py @@ -0,0 +1,34 @@ +""" +Rectangles +========== + +Draw bounding boxes on a 2-D image with +:meth:`~anyplotlib.plot2d.Plot2D.add_rectangles`. +Use ``markers["rectangles"]["name"].set(...)`` to update them live. +""" +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(1) +data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) +data = (data - data.min()) / (data.max() - data.min()) +xy = np.linspace(0, 10, 128) + +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) +v = ax.imshow(data, axes=[xy, xy], units="nm") + +centres = rng.uniform(20, 108, (5, 2)) +v.add_rectangles(centres, widths=22, heights=14, name="boxes", + edgecolors="#00e5ff", facecolors="#00e5ff22", + labels=[f"R{i}" for i in range(5)]) + +fig + +# %% +# Live update +# ----------- + +v.markers["rectangles"]["boxes"].set(widths=30, heights=20, + edgecolors="#ff9100", + facecolors="#ff910033") +fig diff --git a/Examples/Markers/plot_squares.py b/Examples/Markers/plot_squares.py index 125dfe23..8b170967 100644 --- a/Examples/Markers/plot_squares.py +++ b/Examples/Markers/plot_squares.py @@ -3,18 +3,18 @@ ======= Draw squares on a 2-D image with -:meth:`~anyplotlib.figure_plots.Plot2D.add_squares`. +:meth:`~anyplotlib.plot2d.Plot2D.add_squares`. Use ``markers["squares"]["name"].set(...)`` to update them live. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(6) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") centres = np.array([[32.0, 32.0], [64.0, 64.0], [96.0, 96.0], @@ -24,6 +24,7 @@ name="tiles", edgecolors="#00e5ff", facecolors="#00e5ff22", label="tiles", labels=[f"T{i}" for i in range(5)]) + fig # %% diff --git a/Examples/Markers/plot_texts.py b/Examples/Markers/plot_texts.py index d0898ffe..8ea2891d 100644 --- a/Examples/Markers/plot_texts.py +++ b/Examples/Markers/plot_texts.py @@ -3,18 +3,18 @@ =========== Place text annotations on a 2-D image with -:meth:`~anyplotlib.figure_plots.Plot2D.add_texts`. +:meth:`~anyplotlib.plot2d.Plot2D.add_texts`. Use ``markers["texts"]["name"].set(...)`` to update them live. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(7) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") v.add_texts([[4.0, 4.0], [4.0, 116.0], [88.0, 4.0], [88.0, 116.0]], @@ -22,6 +22,7 @@ name="corners", color="#ffeb3b", fontsize=12, label="corners") + fig # %% @@ -31,4 +32,3 @@ v.markers["texts"]["corners"].set(color="#e040fb", fontsize=14) fig - diff --git a/Examples/Markers/plot_vertical_lines.py b/Examples/Markers/plot_vertical_lines.py new file mode 100644 index 00000000..d3d53c8b --- /dev/null +++ b/Examples/Markers/plot_vertical_lines.py @@ -0,0 +1,29 @@ +""" +Vertical Lines +============== + +Draw static vertical marker lines on a 1-D plot with +:meth:`~anyplotlib.plot1d.Plot1D.add_vlines`. +Use ``markers["vlines"]["name"].set(...)`` to update them live. +""" +import numpy as np +import anyplotlib as apl + +x = np.linspace(0, 4 * np.pi, 512) +signal = np.sin(x) + +fig, ax = apl.subplots(1, 1, figsize=(560, 300)) +v = ax.plot(signal, axes=[x], units="rad") + +v.add_vlines([np.pi, 2 * np.pi, 3 * np.pi], name="pi_mult", + color="#00e5ff", linewidths=1.5, + label="pi multiples", labels=["\u03c0", "2\u03c0", "3\u03c0"]) + +fig + +# %% +# Live update +# ----------- + +v.markers["vlines"]["pi_mult"].set(color="#ff9100", linewidths=2.0) +fig diff --git a/Examples/PlotTypes/README.rst b/Examples/PlotTypes/README.rst new file mode 100644 index 00000000..8c02b0e9 --- /dev/null +++ b/Examples/PlotTypes/README.rst @@ -0,0 +1,3 @@ +Plot Types +---------- +A collection of short examples showing different plot types. \ No newline at end of file diff --git a/Examples/PlotTypes/plot_3d.py b/Examples/PlotTypes/plot_3d.py new file mode 100644 index 00000000..7829d613 --- /dev/null +++ b/Examples/PlotTypes/plot_3d.py @@ -0,0 +1,74 @@ +""" +3D Plotting +=========== + +Demonstrate the three 3-D geometry types supported by +:meth:`~anyplotlib.Axes.plot_surface`, +:meth:`~anyplotlib.Axes.scatter3d`, and +:meth:`~anyplotlib.Axes.plot3d`. +Drag to rotate, scroll to zoom, press **R** to reset the view. +""" +import numpy as np +import anyplotlib as apl + +# ── Surface ─────────────────────────────────────────────────────────────────── +x = np.linspace(-3, 3, 60) +y = np.linspace(-3, 3, 60) +XX, YY = np.meshgrid(x, y) +ZZ = np.sin(np.sqrt(XX ** 2 + YY ** 2)) + +fig, ax = apl.subplots(1, 1, figsize=(520, 480)) +surf = ax.plot_surface(XX, YY, ZZ, + colormap="viridis", + x_label="x", y_label="y", z_label="sin(r)") + +fig + +# %% +# Scatter plot +# ------------ + +rng = np.random.default_rng(1) +n = 300 +theta = rng.uniform(0, 2 * np.pi, n) +phi = rng.uniform(0, np.pi, n) +r = rng.uniform(0.6, 1.0, n) +xs = r * np.sin(phi) * np.cos(theta) +ys = r * np.sin(phi) * np.sin(theta) +zs = r * np.cos(phi) + +fig2, ax2 = apl.subplots(1, 1, figsize=(480, 480)) +sc = ax2.scatter3d(xs, ys, zs, + color="#4fc3f7", point_size=3, + x_label="x", y_label="y", z_label="z") + +fig2 + +# %% +# 3-D line — parametric helix +# ---------------------------- + +t = np.linspace(0, 4 * np.pi, 300) +hx = np.cos(t) +hy = np.sin(t) +hz = t / (4 * np.pi) + +fig3, ax3 = apl.subplots(1, 1, figsize=(480, 480)) +ln = ax3.plot3d(hx, hy, hz, + color="#ff7043", linewidth=2, + x_label="cos t", y_label="sin t", z_label="t") + +fig3 + +# %% +# Update the surface data live +# ---------------------------- +# Call :meth:`~anyplotlib.Plot3D.set_data` to replace the geometry +# without recreating the panel. + +ZZ2 = np.cos(np.sqrt(XX ** 2 + YY ** 2)) +surf.set_data(XX, YY, ZZ2) +surf.set_colormap("plasma") +surf.set_view(azimuth=30, elevation=40) + +fig diff --git a/Examples/PlotTypes/plot_bar.py b/Examples/PlotTypes/plot_bar.py new file mode 100644 index 00000000..dfb966ee --- /dev/null +++ b/Examples/PlotTypes/plot_bar.py @@ -0,0 +1,151 @@ +""" +Bar Chart +========= + +Demonstrate :meth:`~anyplotlib.Axes.bar` with: + +* **Matplotlib-aligned API** — ``ax.bar(x, height, width, bottom, …)`` +* Vertical and horizontal orientations, per-bar colours, category labels +* **Grouped bars** — pass a 2-D *height* array ``(N, G)`` +* **Log-scale value axis** — ``log_scale=True`` +* Live data updates via :meth:`~anyplotlib.PlotBar.set_data` +""" +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(7) + +# ── 1. Vertical bar chart — monthly sales ──────────────────────────────────── +# The first positional argument is now *x* (positions or labels), matching +# ``matplotlib.pyplot.bar(x, height, width=0.8, bottom=0.0, ...)``. +months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"] +sales = np.array([42, 55, 48, 63, 71, 68, 74, 81, 66, 59, 52, 78], + dtype=float) + +fig1, ax1 = apl.subplots(1, 1, figsize=(640, 340)) +bar1 = ax1.bar( + months, # x — category strings become x_labels automatically + sales, # height + width=0.6, + color="#4fc3f7", + show_values=True, + units="Month", + y_units="Units sold", +) +fig1 + +# %% +# Horizontal bar chart — ranked items +# ------------------------------------- +# Set ``orient="h"`` for a horizontal layout. Pass a list of CSS colours +# to ``colors`` to give each bar its own colour. + +categories = ["NumPy", "SciPy", "Matplotlib", "Pandas", "Scikit-learn", + "PyTorch", "TensorFlow", "JAX", "Polars", "Dask"] +scores = np.array([95, 88, 91, 87, 83, 79, 76, 72, 68, 65], dtype=float) + +palette = [ + "#ef5350", "#ec407a", "#ab47bc", "#7e57c2", "#42a5f5", + "#26c6da", "#26a69a", "#66bb6a", "#d4e157", "#ffa726", +] + +fig2, ax2 = apl.subplots(1, 1, figsize=(540, 400)) +bar2 = ax2.bar( + categories, + scores, + orient="h", + colors=palette, + width=0.65, + show_values=True, + y_units="Popularity score", +) +fig2 + +# %% +# Grouped bar chart — quarterly comparison +# ----------------------------------------- +# Pass a 2-D *height* array of shape ``(N, G)`` to draw *G* bars side by +# side for each category. Provide ``group_labels`` to show a legend and +# ``group_colors`` to customise each group's colour. + +quarters = ["Jan", "Feb", "Mar", "Apr", "May", "Jun"] +q_data = np.array([ + [42, 58, 51], # Jan — Q1, Q2, Q3 + [55, 61, 59], # Feb + [48, 70, 65], # Mar + [63, 75, 71], # Apr + [71, 69, 80], # May + [68, 83, 77], # Jun +], dtype=float) # shape (6, 3) → 6 categories, 3 groups + +fig3, ax3 = apl.subplots(1, 1, figsize=(680, 340)) +bar3 = ax3.bar( + quarters, + q_data, + width=0.8, + group_labels=["Q1", "Q2", "Q3"], + group_colors=["#4fc3f7", "#ff7043", "#66bb6a"], + show_values=False, + y_units="Sales", +) +fig3 + +# %% +# Log-scale value axis +# --------------------- +# Set ``log_scale=True`` for a logarithmic value axis. Non-positive values +# are clamped to ``1e-10`` — no error is raised. Tick marks are placed at +# each decade (10⁰, 10¹, 10², …) with faint minor gridlines at 2×, 3×, 5× +# multiples. + +log_labels = ["A", "B", "C", "D", "E"] +log_vals = np.array([1, 10, 100, 1_000, 10_000], dtype=float) + +fig4, ax4 = apl.subplots(1, 1, figsize=(500, 340)) +bar4 = ax4.bar( + log_labels, + log_vals, + log_scale=True, + color="#ab47bc", + show_values=True, + y_units="Count (log scale)", +) +fig4 + +# %% +# Side-by-side comparison — update data live +# ------------------------------------------- +# Place two :class:`~anyplotlib.PlotBar` panels in one figure. +# Call :meth:`~anyplotlib.PlotBar.set_data` to swap in Q2 data — +# the value-axis range recalculates automatically. + +q1 = np.array([42, 55, 48, 63, 71, 68, 74, 81, 66, 59, 52, 78], dtype=float) +q2 = np.array([58, 61, 70, 75, 69, 83, 90, 88, 77, 64, 71, 95], dtype=float) +all_months = ["Jan", "Feb", "Mar", "Apr", "May", "Jun", + "Jul", "Aug", "Sep", "Oct", "Nov", "Dec"] + +fig5, (ax_left, ax_right) = apl.subplots(1, 2, figsize=(820, 320)) +bar_left = ax_left.bar( + all_months, q1, width=0.6, + color="#4fc3f7", show_values=False, y_units="Q1 sales", +) +bar_right = ax_right.bar( + all_months, q1, width=0.6, + color="#ff7043", show_values=False, y_units="Q2 sales", +) +bar_right.set_data(q2) # swap in Q2 — axis range recalculates automatically + +fig5 + +# %% +# Mutate colours, annotations, and scale at runtime +# -------------------------------------------------- +# :meth:`~anyplotlib.PlotBar.set_color` repaints all bars, +# :meth:`~anyplotlib.PlotBar.set_show_values` toggles labels, +# :meth:`~anyplotlib.PlotBar.set_log_scale` switches the +# value-axis between linear and logarithmic. + +bar1.set_color("#ff7043") +bar1.set_show_values(False) +fig1 diff --git a/Examples/PlotTypes/plot_gridspec_custom.py b/Examples/PlotTypes/plot_gridspec_custom.py new file mode 100644 index 00000000..31e3dd4a --- /dev/null +++ b/Examples/PlotTypes/plot_gridspec_custom.py @@ -0,0 +1,187 @@ +""" +Custom Grid Layouts with GridSpec +================================== + +:class:`~anyplotlib.GridSpec` lets you build multi-panel figures where panels +have different sizes and span multiple grid cells. This gallery shows the most +common patterns. + +All examples use the **bare** ``Figure + GridSpec`` workflow — the figure's +grid dimensions are inferred automatically from the GridSpec the first time +``add_subplot`` is called. + +Overview +-------- + +1. **Side-by-side spectra** — two equal 1-D panels in one row (``1×2`` grid). +2. **Image + spectra** — image spanning full width, two spectra below + (``2×2`` grid with ``height_ratios=[3, 1]``). +3. **Image + histogram** — classic EM layout: large image on top, thin + histogram strip below (``2×1`` grid with ``height_ratios=[3, 1]``). +4. **Three-column** — three equal columns in a single row (``1×3`` grid). +5. **Asymmetric widths** — wide overview left, narrow detail right + (``1×2`` grid with ``width_ratios=[2, 1]``). +6. **Complex** — spanning top panel plus two bottom panels (``2×2`` grid). +""" +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(42) +t = np.linspace(0.0, 2.0 * np.pi, 512) + +# ── 1. Side-by-side spectra (1×2, equal widths) ─────────────────────────────── +# %% +# Side-by-side spectra +# -------------------- +# The simplest multi-panel case: two 1-D spectra in one row. Each panel +# receives exactly half the figure width with a full-height inner plot area. +# Both panels share the same height so their axes baselines align visually. + +gs1 = apl.GridSpec(1, 2) +fig1 = apl.Figure(figsize=(720, 280)) + +sp_left = fig1.add_subplot(gs1[0, 0]).plot( + np.sin(t) + rng.normal(scale=0.05, size=len(t)), + color="#4fc3f7", label="channel A") + +sp_right = fig1.add_subplot(gs1[0, 1]).plot( + np.cos(t) + rng.normal(scale=0.05, size=len(t)), + color="#ff7043", label="channel B") + +fig1 # Interactive + +# ── 2. Image + two spectra (2×2, height_ratios=[3, 1]) ──────────────────────── +# %% +# Image on top, two spectra below +# -------------------------------- +# A ``2×2`` grid with ``height_ratios=[3, 1]`` puts a wide image in the upper +# three-quarters and two comparison spectra side-by-side in the lower quarter. +# +# The spanning subplot ``gs2[0, :]`` covers all columns in row 0, so the image +# gets the full figure width. + +N = 128 +x = np.linspace(-4, 4, N) +y = np.linspace(-4, 4, N) +XX, YY = np.meshgrid(x, y) +image = np.exp(-(XX**2 + YY**2) / 4) + 0.3 * np.exp(-((XX - 2)**2 + YY**2) / 1) +image += rng.normal(scale=0.03, size=image.shape) + +gs2 = apl.GridSpec(2, 2, height_ratios=[3, 1]) +fig2 = apl.Figure(figsize=(640, 560)) + +fig2.add_subplot(gs2[0, :]).imshow(image.astype(np.float32), cmap="inferno") + +row_profile = image[N // 2, :] +col_profile = image[:, N // 2] + +fig2.add_subplot(gs2[1, 0]).plot( + row_profile, axes=[x], units="nm", + color="#4fc3f7", label="row profile") + +fig2.add_subplot(gs2[1, 1]).plot( + col_profile, axes=[y], units="nm", + color="#ff7043", label="col profile") + +fig2 # Interactive + +# ── 3. Image + histogram (2×1, height_ratios=[3, 1]) ────────────────────────── +# %% +# Image + histogram strip +# ----------------------- +# A ``2×1`` grid with ``height_ratios=[3, 1]`` is the classic layout for +# showing an image with its intensity histogram below. The image occupies +# three-quarters of the height; the histogram strip the remaining quarter. + +gs3 = apl.GridSpec(2, 1, height_ratios=[3, 1]) +fig3 = apl.Figure(figsize=(500, 600)) + +fig3.add_subplot(gs3[0, 0]).imshow(image.astype(np.float32), cmap="viridis") + +counts, edges = np.histogram(image.ravel(), bins=64) +bin_centers = 0.5 * (edges[:-1] + edges[1:]) +fig3.add_subplot(gs3[1, 0]).plot( + counts.astype(float), axes=[bin_centers], + color="#aed581", label="histogram") + +fig3 # Interactive + +# ── 4. Three equal columns (1×3) ────────────────────────────────────────────── +# %% +# Three-column layout +# ------------------- +# A ``1×3`` grid gives three equal panels that are easy to compare visually. +# Useful for showing the same quantity at three different conditions or times. + +gs4 = apl.GridSpec(1, 3) +fig4 = apl.Figure(figsize=(900, 240)) + +spectra = [ + np.sin(t * (i + 1)) + rng.normal(scale=0.08, size=len(t)) + for i in range(3) +] +colors = ["#4fc3f7", "#ff7043", "#aed581"] +labels = ["f₁", "f₂", "f₃"] + +for i, (data, color, label) in enumerate(zip(spectra, colors, labels)): + fig4.add_subplot(gs4[0, i]).plot(data, color=color, label=label) + +fig4 # Interactive + +# ── 5. Asymmetric widths (1×2, width_ratios=[2, 1]) ────────────────────────── +# %% +# Asymmetric column widths +# ------------------------ +# ``width_ratios=[2, 1]`` makes the left panel twice as wide as the right. +# A common use-case is a broad overview spectrum on the left and a zoomed +# detail region on the right. + +energy = np.linspace(280, 295, 1024) +peak = np.exp(-0.5 * ((energy - 284.8) / 0.3)**2) +peak2 = 0.35 * np.exp(-0.5 * ((energy - 286.2) / 0.3)**2) +spectrum = peak + peak2 + 0.1 * np.exp(-0.05 * (energy - 280)) \ + + rng.normal(scale=0.01, size=len(energy)) + +gs5 = apl.GridSpec(1, 2, width_ratios=[2, 1]) +fig5 = apl.Figure(figsize=(720, 260)) + +fig5.add_subplot(gs5[0, 0]).plot( + spectrum, axes=[energy], units="eV", + color="#4fc3f7", label="survey") + +mask = (energy >= 283.5) & (energy <= 286.5) +fig5.add_subplot(gs5[0, 1]).plot( + spectrum[mask], axes=[energy[mask]], units="eV", + color="#ff7043", label="detail") + +fig5 # Interactive + +# ── 6. Complex layout: spanning top + two bottom (2×2, height_ratios=[2, 1]) ── +# %% +# Complex layout: spanning top panel +# ----------------------------------- +# A ``2×2`` grid where ``gs6[0, :]`` spans both columns creates a wide panel +# on top (e.g. a summed spectrum) with two comparison panels below it. +# ``height_ratios=[2, 1]`` gives the top panel twice the height of each bottom +# panel. + +summed = spectrum + rng.normal(scale=0.02, size=len(energy)) +diff1 = rng.normal(scale=0.05, size=len(energy)) +diff2 = rng.normal(scale=0.05, size=len(energy)) + +gs6 = apl.GridSpec(2, 2, height_ratios=[2, 1]) +fig6 = apl.Figure(figsize=(720, 480)) + +fig6.add_subplot(gs6[0, :]).plot( + summed, axes=[energy], units="eV", + color="#4fc3f7", label="summed") + +fig6.add_subplot(gs6[1, 0]).plot( + diff1, axes=[energy], units="eV", + color="#ff7043", label="Δ channel 1") + +fig6.add_subplot(gs6[1, 1]).plot( + diff2, axes=[energy], units="eV", + color="#aed581", label="Δ channel 2") + +fig6 # Interactive diff --git a/Examples/PlotTypes/plot_image2d.py b/Examples/PlotTypes/plot_image2d.py new file mode 100644 index 00000000..c568aff6 --- /dev/null +++ b/Examples/PlotTypes/plot_image2d.py @@ -0,0 +1,130 @@ +""" +2D Image with Histogram +======================= + +Display a 2-D image with physical axes, a colourmap, and an interactive +histogram below — all wired together with draggable threshold widgets. + +Layout +------ +A :class:`~anyplotlib.GridSpec` with two rows puts the image +on top and a bar-chart histogram below. Two +:class:`~anyplotlib.widgets.VLineWidget` handles on the histogram mark the +``display_min`` / ``display_max`` thresholds; dragging them updates the +image colour scale in real time. + +Key bindings on the image panel: **R** reset view · **C** toggle colorbar · +**L** / **S** cycle colour-scale modes. + +New ``imshow`` parameters +------------------------- +``cmap`` + Colormap name passed directly to :meth:`~anyplotlib.Axes.imshow` + (e.g. ``"viridis"``, ``"inferno"``). Defaults to ``"gray"``. +``vmin`` / ``vmax`` + Colormap clipping limits in data units. Values outside the range are + clamped to the colormap endpoints. Defaults to the data min/max. +``origin`` + ``"upper"`` (default) places row 0 at the top (image convention). + ``"lower"`` places row 0 at the bottom (scientific / matrix convention) + and automatically reverses the y-axis so tick values increase upward. +""" +import numpy as np +import anyplotlib as apl + + +rng = np.random.default_rng(1) + +# ── Synthetic diffraction pattern ───────────────────────────────────────────── +N = 256 +x = np.linspace(-5, 5, N) # physical axis in nm +y = np.linspace(-5, 5, N) +XX, YY = np.meshgrid(x, y) +R = np.sqrt(XX ** 2 + YY ** 2) + + +def _ring(r, r0, width, amp): + return amp * np.exp(-0.5 * ((r - r0) / width) ** 2) + + +image = ( + _ring(R, 0.0, 0.30, 1.00) # central spot + + _ring(R, 2.1, 0.15, 0.55) # first-order ring + + _ring(R, 4.2, 0.15, 0.25) # second-order ring + + rng.normal(scale=0.04, size=(N, N)) +) + +# ── Layout: image (top, 3×) + histogram bar chart (bottom, 1×) ──────────────── +gs = apl.GridSpec(2, 1, height_ratios=[3, 1]) +fig = apl.Figure(figsize=(500, 640)) +ax_img = fig.add_subplot(gs[0, 0]) +ax_hist = fig.add_subplot(gs[1, 0]) + +# ── Image panel — cmap, vmin, vmax supplied directly to imshow ──────────────── +vmin_init = float(image.min()) +vmax_init = float(image.max()) + +# Pass cmap, vmin, and vmax directly — no separate set_colormap / set_clim call +# needed for the initial display. +v = ax_img.imshow(image, axes=[x, y], units="nm", + cmap="inferno", vmin=vmin_init, vmax=vmax_init) + +# First-order spot markers in the same physical coordinates used by imshow +spot_nm = np.array([[ 2.1, 0.0], [-2.1, 0.0], + [ 0.0, 2.1], [ 0.0, -2.1]]) +v.add_circles(spot_nm, name="spots", radius=7, + edgecolors="#00e5ff", facecolors="#00e5ff22", + labels=["g1", "g1_bar", "g2", "g2_bar"]) + +# ── Histogram bar chart ──────────────────────────────────────────────────────── +counts, edges = np.histogram(image.ravel(), bins=64) +bin_centers = 0.5 * (edges[:-1] + edges[1:]) + +h = ax_hist.bar(counts, x_centers=bin_centers, orient="v", + color="#4fc3f7", y_units="count") + +# ── Draggable threshold handles on the histogram ────────────────────────────── +wlo = h.add_vline_widget(vmin_init, color="#ff6e40") # low-threshold handle +whi = h.add_vline_widget(vmax_init, color="#ffffff") # high-threshold handle + + +@wlo.add_event_handler("pointer_up") +def _apply_low(event): + """Update image display_min when the low handle is released.""" + v.set_clim(vmin=event.source.x) + + +@whi.add_event_handler("pointer_up") +def _apply_high(event): + """Update image display_max when the high handle is released.""" + v.set_clim(vmax=event.source.x) + + +fig # Interactive + +# %% +# Adjust colour map and display range +# ------------------------------------ +# :meth:`~anyplotlib.Plot2D.set_colormap` switches the palette; +# :meth:`~anyplotlib.Plot2D.set_clim` adjusts the display range. +# Both are equivalent to passing ``cmap`` / ``vmin`` / ``vmax`` at construction. + +v.set_colormap("viridis") +v.set_clim(vmin=0.0, vmax=0.8) + +fig + +# %% +# origin='lower' — scientific / matrix convention +# ------------------------------------------------ +# Passing ``origin='lower'`` places row 0 of the data at the *bottom* of the +# image, matching the matplotlib / scientific convention. The y-axis is +# automatically reversed so tick values still increase upward. + +mat = np.arange(64, dtype=float).reshape(8, 8) # row 0 = small values + +fig2, ax2 = apl.subplots() +v2 = ax2.imshow(mat, cmap="plasma", origin="lower") + +fig2 # Interactive + diff --git a/Examples/PlotTypes/plot_inset.py b/Examples/PlotTypes/plot_inset.py new file mode 100644 index 00000000..0f88a4fe --- /dev/null +++ b/Examples/PlotTypes/plot_inset.py @@ -0,0 +1,90 @@ +""" +Inset Plots +=========== + +Floating informational sub-plots that overlay the main figure — useful for +displaying supplementary data alongside a primary image, as seen in orientation +mapping, phase analysis, and similar workflows. + +Each inset has a **title bar** with two buttons: + +* **−** (minimize) — collapses the inset to its title bar only. +* **⤢** (maximize) — expands the inset to ~72 % of the figure, centred. + Click **⤡** to restore. + +Multiple insets sharing the same ``corner`` auto-stack so they never overlap +in the minimised or normal state. + +Python-side state can also be set programmatically:: + + inset.minimize() + inset.maximize() + inset.restore() + print(inset.inset_state) # "normal" | "minimized" | "maximized" +""" + +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(42) + +# ── Helpers — synthetic data ────────────────────────────────────────────────── + +def _diffraction(N=256): + """Simulated diffraction pattern (Gaussian rings).""" + y, x = np.ogrid[-N//2:N//2, -N//2:N//2] + r = np.hypot(x, y) + img = np.zeros((N, N)) + for r0, sigma, amp in [(40, 6, 1.0), (80, 8, 0.6), (120, 10, 0.3)]: + img += amp * np.exp(-((r - r0) ** 2) / (2 * sigma ** 2)) + img += rng.normal(0, 0.04, img.shape) + return img + +def _phase_map(N=128): + """Fake two-phase orientation map.""" + img = rng.integers(0, 4, (N, N), dtype=np.uint8) + # blob of phase 2 in the centre + cy, cx = N // 2, N // 2 + yy, xx = np.ogrid[:N, :N] + img[((yy - cy)**2 + (xx - cx)**2) < (N // 4)**2] = np.uint8(5) + return img.astype(float) + +def _pole_figure(N=96): + """Simulated pole-figure intensity (radial Gaussian blob).""" + y, x = np.ogrid[-N//2:N//2, -N//2:N//2] + r = np.hypot(x, y) + return np.exp(-(r ** 2) / (2 * (N // 6) ** 2)) + rng.normal(0, 0.02, (N, N)) + +def _virtual_adf(N=128): + """Annular dark-field signal for a simple lattice.""" + y, x = np.mgrid[:N, :N] + return (np.sin(y * 0.4) * np.cos(x * 0.4)) ** 2 + rng.normal(0, 0.05, (N, N)) + +# ── Build figure ────────────────────────────────────────────────────────────── + +fig, ax = apl.subplots(1, 1, figsize=(660, 500)) + +# Primary large image: diffraction pattern +main = ax.imshow(_diffraction(256), cmap="inferno") + +# ── Inset 1: phase map (top-right) ─────────────────────────────────────────── +inset_phase = fig.add_inset(0.27, 0.27, corner="top-right", title="Phase Map") +inset_phase.imshow(_phase_map(128), cmap="tab10") + +# ── Inset 2: pole figure — stacks below inset 1 in the same corner ──────────── +inset_pole = fig.add_inset(0.27, 0.27, corner="top-right", title="Pole Figure") +inset_pole.imshow(_pole_figure(96), cmap="hot") + +# ── Inset 3: virtual ADF (bottom-left) ──────────────────────────────────────── +inset_adf = fig.add_inset(0.27, 0.27, corner="bottom-left", title="Virtual ADF") +inset_adf.imshow(_virtual_adf(128), cmap="gray") + +# ── Inset 4: 1-D line profile (bottom-right) ───────────────────────────────── +x_nm = np.linspace(0, 10, 256) +profile = np.sin(x_nm * 3.5) * np.exp(-x_nm * 0.18) + rng.normal(0, 0.05, 256) + +inset_line = fig.add_inset(0.30, 0.22, corner="bottom-right", title="Line Profile") +inset_line.plot(profile, axes=[x_nm], units="nm", color="#4fc3f7", linewidth=1.5) + +fig + diff --git a/Examples/PlotTypes/plot_label_formatting.py b/Examples/PlotTypes/plot_label_formatting.py new file mode 100644 index 00000000..e6bbd320 --- /dev/null +++ b/Examples/PlotTypes/plot_label_formatting.py @@ -0,0 +1,42 @@ +""" +Label Sizes and Scientific (TeX) Formatting +=========================================== + +Axis labels, titles, and the colorbar label accept an optional ``fontsize`` +(in CSS pixels) and support a small TeX subset inside ``$...$`` for +scientific notation — superscripts, subscripts, Greek letters, and common +symbols — rendered directly on the canvas with no MathJax dependency: + +* ``$10^{-3}$``, ``$x^2$`` — exponents +* ``$E_F$``, ``$k_{B}T$`` — subscripts +* ``$\\alpha$ … $\\Omega$``, ``\\mu``, ``\\Delta`` — Greek letters +* ``\\times``, ``\\pm``, ``\\AA``, ``\\degree``, ``\\propto``, ``\\partial`` — symbols +* ``$\\mathrm{...}$`` — upright text inside math +""" +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(7) + +fig, (ax_img, ax_spec) = apl.subplots(1, 2, figsize=(880, 380)) + +# ── 2-D panel: diffraction-style image with TeX axis labels ──────────────── +data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) +q = np.linspace(-2.5, 2.5, 128) +img = ax_img.imshow(data, axes=[q, q], units="") +img.set_title(r"$|F(q)|^2$", fontsize=12) +img.set_xlabel(r"$q_x$ ($\AA^{-1}$)", fontsize=13) +img.set_ylabel(r"$q_y$ ($\AA^{-1}$)", fontsize=13) +img.set_colorbar_visible(True) +img.set_colorbar_label(r"Counts $\times 10^{3}$") + +# ── 1-D panel: spectrum with sized, TeX-formatted labels ─────────────────── +energy = np.linspace(0, 3, 512) +spectrum = np.exp(-((energy - 1.2) / 0.15) ** 2) + 0.05 * rng.random(512) +spec = ax_spec.plot(spectrum, axes=[energy], color="#ff7043") +spec.set_title(r"Plasmon peak near $E_p$", fontsize=12) +spec.set_xlabel(r"$\Delta E$ (eV)", fontsize=12) +spec.set_ylabel(r"Intensity ($10^{-3}$ counts)", fontsize=12) +spec.set_tick_label_size(11) + +fig diff --git a/Examples/PlotTypes/plot_line_styles.py b/Examples/PlotTypes/plot_line_styles.py new file mode 100644 index 00000000..2c38c527 --- /dev/null +++ b/Examples/PlotTypes/plot_line_styles.py @@ -0,0 +1,159 @@ +""" +1D Line Styles +============== + +Demonstrates the line-style, opacity, and per-point marker parameters +available on :meth:`~anyplotlib.Axes.plot` and +:meth:`~anyplotlib.Plot1D.add_line`. + +Four separate figures are shown: + +1. **Linestyles** – all four dash patterns on one panel with a legend. +2. **Alpha (transparency)** – two overlapping sine waves, each at 40 % opacity. +3. **Marker symbols** – all seven supported symbols, each on its own offset + curve. +4. **Combined** – dashed + semi-transparent + circle-marker overlay on a solid + primary line; demonstrates post-construction setters. +""" +import numpy as np +import anyplotlib as apl + +t256 = np.linspace(0.0, 2.0 * np.pi, 256) # dense — good for dashes / alpha +t24 = np.linspace(0.0, 2.0 * np.pi, 24) # sparse — makes markers visible + +# ── 1. Linestyles ───────────────────────────────────────────────────────────── +fig1, ax1 = apl.subplots(1, 1, figsize=(580, 300)) + +plot1 = ax1.plot(np.sin(t256), color="#4fc3f7", linewidth=2, + linestyle="solid", label="solid") +plot1.add_line(np.sin(t256) + 0.6, color="#ff7043", linewidth=2, + linestyle="dashed", label="dashed (\"--\")") +plot1.add_line(np.sin(t256) + 1.2, color="#aed581", linewidth=2, + linestyle="dotted", label="dotted (\":\")") +plot1.add_line(np.sin(t256) + 1.8, color="#ce93d8", linewidth=2, + linestyle="dashdot", label="dashdot (\"-.\")") + +fig1 + +# %% +# The ``ls`` shorthand +# -------------------- +# Each linestyle has a single-character (or two-character) shorthand that +# matches the matplotlib convention: +# +# * ``"-"`` → ``"solid"`` +# * ``"--"`` → ``"dashed"`` +# * ``":"`` → ``"dotted"`` +# * ``"-."`` → ``"dashdot"`` +# +# The shorthands work on both :meth:`~anyplotlib.Axes.plot` +# and :meth:`~anyplotlib.Plot1D.add_line`: + +fig2a, ax2a = apl.subplots(1, 1, figsize=(440, 220)) +p = ax2a.plot(np.sin(t256), ls="-", color="#4fc3f7", label='ls="-"') +p.add_line(np.sin(t256) + 0.8, ls="--", color="#ff7043", label='ls="--"') +p.add_line(np.sin(t256) + 1.6, ls=":", color="#aed581", label='ls=":"') +fig2a + +# %% +# Alpha (opacity) +# --------------- +# ``alpha`` controls line opacity on a 0–1 scale. Values below 1 let +# overlapping curves show through each other — useful for comparing signals +# that share the same amplitude range. + +fig2, ax2 = apl.subplots(1, 1, figsize=(580, 300)) + +plot2 = ax2.plot(np.sin(t256), color="#4fc3f7", alpha=0.4, linewidth=3, + label="sin α=0.4") +plot2.add_line(np.cos(t256), color="#ff7043", alpha=0.4, linewidth=3, + label="cos α=0.4") + +fig2 + +# %% +# Marker symbols +# -------------- +# Set ``marker`` to place a symbol at every data point. Use a **sparse** +# x-axis (few points) so the individual markers are legible. +# ``markersize`` is the radius (circles / diamonds) or half-side-length +# (squares, triangles) in canvas pixels. +# +# Supported symbols: +# +# * ``"o"`` — circle +# * ``"s"`` — square +# * ``"^"`` — triangle-up +# * ``"v"`` — triangle-down +# * ``"D"`` — diamond +# * ``"+"`` — plus (stroke-only) +# * ``"x"`` — cross (stroke-only) +# * ``"none"`` — no marker (default) + +SYMBOLS = [ + ("o", "#4fc3f7"), + ("s", "#ff7043"), + ("^", "#aed581"), + ("v", "#ce93d8"), + ("D", "#ffcc02"), + ("+", "#80cbc4"), + ("x", "#ef9a9a"), +] + +fig3, ax3 = apl.subplots(1, 1, figsize=(580, 380)) + +plot3 = ax3.plot( + np.sin(t24) + (0 - 3) * 0.9, + color=SYMBOLS[0][1], linewidth=1.5, + marker=SYMBOLS[0][0], markersize=5, + label=f'marker="{SYMBOLS[0][0]}"', +) +for i, (sym, col) in enumerate(SYMBOLS[1:], 1): + plot3.add_line( + np.sin(t24) + (i - 3) * 0.9, + color=col, linewidth=1.5, + marker=sym, markersize=5, + label=f'marker="{sym}"', + ) + +fig3 + +# %% +# Combined — linestyle + alpha + marker +# -------------------------------------- +# All three style parameters can be combined freely on the same line or on +# separate overlay lines. + +fig4, ax4 = apl.subplots(1, 1, figsize=(580, 300)) + +# Dense solid primary line +plot4 = ax4.plot(np.sin(t256), color="#4fc3f7", linewidth=2, + label="sin (solid)") + +# Sparse dashed overlay with circle markers and reduced opacity +plot4.add_line(np.cos(t24), color="#ff7043", linewidth=2, + linestyle="dashed", alpha=0.75, + marker="o", markersize=5, + label="cos (dashed, α=0.75, marker='o')") + +fig4 + +# %% +# Post-construction setters +# ------------------------- +# Every primary-line style property has a matching setter method. These +# mutate ``_state`` and push the change to the canvas immediately — no +# need to recreate the panel. + +fig5, ax5 = apl.subplots(1, 1, figsize=(440, 220)) +plot5 = ax5.plot(np.sin(t256), color="#4fc3f7", linewidth=1.5) + +# Change style via setters +plot5.set_color("#ff7043") +plot5.set_linewidth(2.5) +plot5.set_linestyle("dashdot") # equivalent: plot5.set_linestyle("-.") +plot5.set_alpha(0.8) +plot5.set_marker("o", markersize=5) + +fig5 + diff --git a/Examples/plot_pcolormesh.py b/Examples/PlotTypes/plot_pcolormesh.py similarity index 89% rename from Examples/plot_pcolormesh.py rename to Examples/PlotTypes/plot_pcolormesh.py index a9d298ba..c5a79b74 100644 --- a/Examples/plot_pcolormesh.py +++ b/Examples/PlotTypes/plot_pcolormesh.py @@ -2,17 +2,17 @@ pcolormesh — non-linear axes ============================ -Demonstrate :meth:`~anyplotlib.figure_plots.Axes.pcolormesh` with non-uniform +Demonstrate :meth:`~anyplotlib.Axes.pcolormesh` with non-uniform (log-spaced) x-edges and irregularly-spaced y-edges, mirroring ``matplotlib.axes.Axes.pcolormesh``. -The key difference from :meth:`~anyplotlib.figure_plots.Axes.imshow` is that +The key difference from :meth:`~anyplotlib.Axes.imshow` is that ``pcolormesh`` takes **edge** arrays (length N+1 and M+1 for an (M, N) data -array) rather than centre arrays. This enables fully non-linear axes where +array) rather than center arrays. This enables fully non-linear axes where each cell can have a different width/height in data coordinates. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(42) @@ -36,7 +36,7 @@ [y_centres[-1] + (y_centres[-1] - y_centres[-2]) / 2]]) # ── Plot ────────────────────────────────────────────────────────────────────── -fig, ax = vw.subplots(1, 1, figsize=(560, 460)) +fig, ax = apl.subplots(1, 1, figsize=(560, 460)) mesh = ax.pcolormesh(data, x_edges=x_edges, y_edges=y_edges, units="arb.") mesh.set_colormap("viridis") fig diff --git a/Examples/PlotTypes/plot_spectra1d.py b/Examples/PlotTypes/plot_spectra1d.py new file mode 100644 index 00000000..347565dc --- /dev/null +++ b/Examples/PlotTypes/plot_spectra1d.py @@ -0,0 +1,112 @@ +""" +1D Spectra +========== + +Plot a 1-D spectrum with a physical x-axis (energy in eV) using +:meth:`~anyplotlib.Axes.plot`. + +The spectrum contains a broad background and three Gaussian peaks. +Circle markers highlight the peak positions using +:meth:`~anyplotlib.Plot1D.add_points`, and a range widget +selects a region of interest. A model fit is overlaid with a dashed line, +and the background component is shown as a semi-transparent dotted curve with +diamond markers. + +Pan and zoom with the mouse; press **R** to reset the view. +""" +import numpy as np +import anyplotlib as apl + +rng = np.random.default_rng(0) + +# ── Synthetic XPS-style spectrum ────────────────────────────────────────────── +energy = np.linspace(280, 295, 512) # binding energy axis (eV) + +def gaussian(x, mu, sigma, amp): + return amp * np.exp(-0.5 * ((x - mu) / sigma) ** 2) + +background = 0.4 * np.exp(-0.08 * (energy - 280)) + +# Background + three peaks (C 1s region) +spectrum = ( + background + + gaussian(energy, 284.8, 0.4, 1.0) # C–C / C–H + + gaussian(energy, 286.2, 0.4, 0.35) # C–O + + gaussian(energy, 288.0, 0.4, 0.18) # C=O + + rng.normal(scale=0.015, size=len(energy)) +) + +# ── Plot ────────────────────────────────────────────────────────────────────── +fig, ax = apl.subplots(1, 1, figsize=(620, 340)) +v = ax.plot(spectrum, axes=[energy], units="eV", y_units="Intensity (a.u.)", + color="#4fc3f7", linewidth=1.5) + +# ── Peak markers (add_points collection) ────────────────────────────────────── +peak_energies = np.array([284.8, 286.2, 288.0]) +peak_offsets = np.column_stack([ + peak_energies, + np.interp(peak_energies, energy, spectrum), +]) +v.add_points(peak_offsets, name="peaks", + sizes=7, color="#ff1744", facecolors="#ff174433", + labels=["C\u2013C", "C\u2013O", "C=O"]) + +# ── Region-of-interest widget ───────────────────────────────────────────────── +v.add_range_widget(x0=285.8, x1=288.8, color="#00e5ff") + +fig + +# %% +# Overlay a model fit — linestyle and alpha +# ----------------------------------------- +# Use :meth:`~anyplotlib.Plot1D.add_line` to overlay additional +# curves. Here the noiseless model fit is drawn as a **dashed** line so it +# is visually distinct from the noisy measured spectrum. The ``alpha`` +# parameter makes the fit semi-transparent so the data underneath remains +# readable. +# +# The y-axis range is expanded automatically to accommodate any overlay line +# whose values fall outside the current bounds. + +fit = ( + background + + gaussian(energy, 284.8, 0.4, 1.0) + + gaussian(energy, 286.2, 0.4, 0.35) + + gaussian(energy, 288.0, 0.4, 0.18) +) +v.add_line(fit, x_axis=energy, + color="#ffcc00", linewidth=2.0, + linestyle="dashed", alpha=0.85, + label="fit") + +fig + +# %% +# Background component — dotted line with markers +# ------------------------------------------------ +# Draw the exponential background component as a **dotted** curve. Passing +# ``marker="D"`` places a diamond at every data point (useful when the line +# is sparse or when you want to emphasise individual sample positions). +# ``markersize`` controls the half-size of the symbol in pixels. + +# Sub-sample to keep the marker plot readable +step = 32 +v.add_line(background[::step], x_axis=energy[::step], + color="#ce93d8", linewidth=1.2, + linestyle="dotted", alpha=0.9, + marker="D", markersize=3, + label="background") + +fig + +# %% +# Post-construction setters +# ------------------------- +# All primary-line style properties can be changed after the panel is created +# without rebuilding it. This is useful in interactive notebooks where you +# want to tweak the appearance of the main trace. + +v.set_alpha(0.9) # slightly reduce primary-line opacity +v.set_linewidth(2.0) # thicker stroke for the main spectrum + +fig diff --git a/Examples/Widgets/plot_widget1d_hline.py b/Examples/Widgets/plot_widget1d_hline.py index 8cfb70e1..e753163e 100644 --- a/Examples/Widgets/plot_widget1d_hline.py +++ b/Examples/Widgets/plot_widget1d_hline.py @@ -3,18 +3,18 @@ ========================== A draggable horizontal line on a 1-D plot panel. -Add it with :meth:`~anyplotlib.figure_plots.Plot1D.add_hline_widget`. +Add it with :meth:`~anyplotlib.plot1d.Plot1D.add_hline_widget`. Drag the line up or down to change the selected y value. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl x = np.linspace(0, 4 * np.pi, 512) signal = np.sin(x) -fig, ax = vw.subplots(1, 1, figsize=(560, 300)) +fig, ax = apl.subplots(1, 1, figsize=(560, 300)) v = ax.plot(signal, axes=[x], units="rad") v.add_hline_widget(y=0.5, color="#69f0ae") -fig +fig diff --git a/Examples/Widgets/plot_widget1d_range.py b/Examples/Widgets/plot_widget1d_range.py index 7eecdf3c..0a9ab0c8 100644 --- a/Examples/Widgets/plot_widget1d_range.py +++ b/Examples/Widgets/plot_widget1d_range.py @@ -3,19 +3,19 @@ ================ A draggable range selector on a 1-D plot panel with two handles. -Add it with :meth:`~anyplotlib.figure_plots.Plot1D.add_range_widget`. +Add it with :meth:`~anyplotlib.plot1d.Plot1D.add_range_widget`. Drag either handle to resize the selected interval, or drag the band to move it. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl x = np.linspace(0, 4 * np.pi, 512) signal = np.sin(x) -fig, ax = vw.subplots(1, 1, figsize=(560, 300)) +fig, ax = apl.subplots(1, 1, figsize=(560, 300)) v = ax.plot(signal, axes=[x], units="rad") v.add_range_widget(x0=np.pi, x1=2 * np.pi, color="#ffeb3b") -fig +fig diff --git a/Examples/Widgets/plot_widget1d_vline.py b/Examples/Widgets/plot_widget1d_vline.py index 88debdb9..7699f479 100644 --- a/Examples/Widgets/plot_widget1d_vline.py +++ b/Examples/Widgets/plot_widget1d_vline.py @@ -3,17 +3,18 @@ ======================== A draggable vertical line on a 1-D plot panel. -Add it with :meth:`~anyplotlib.figure_plots.Plot1D.add_vline_widget`. +Add it with :meth:`~anyplotlib.plot1d.Plot1D.add_vline_widget`. Drag the line left or right to change the selected x position. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl x = np.linspace(0, 4 * np.pi, 512) signal = np.sin(x) -fig, ax = vw.subplots(1, 1, figsize=(560, 300)) +fig, ax = apl.subplots(1, 1, figsize=(560, 300)) v = ax.plot(signal, axes=[x], units="rad") v.add_vline_widget(x=np.pi, color="#e040fb") + fig diff --git a/Examples/Widgets/plot_widget2d_annular.py b/Examples/Widgets/plot_widget2d_annular.py index b79faa6e..f15b156f 100644 --- a/Examples/Widgets/plot_widget2d_annular.py +++ b/Examples/Widgets/plot_widget2d_annular.py @@ -6,16 +6,16 @@ Drag the inner or outer ring to adjust the radii. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(2) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") v.add_widget("annular", color="#00e5ff", cx=64, cy=64, r_outer=40, r_inner=20) -fig +fig diff --git a/Examples/Widgets/plot_widget2d_circle.py b/Examples/Widgets/plot_widget2d_circle.py index afb0d62a..589c6906 100644 --- a/Examples/Widgets/plot_widget2d_circle.py +++ b/Examples/Widgets/plot_widget2d_circle.py @@ -3,21 +3,20 @@ ================= A draggable, resizable circle overlay on a 2-D image panel. -Add it with :meth:`~anyplotlib.figure_plots.Plot2D.add_widget` using -``kind="circle"``, or via the convenience wrapper -``add_widget("circle", ...)``. +Add it with :meth:`~anyplotlib.plot2d.Plot2D.add_widget` using +``kind="circle"``. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(0) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") v.add_widget("circle", color="#e040fb", cx=64, cy=64, r=20) -fig +fig diff --git a/Examples/Widgets/plot_widget2d_crosshair.py b/Examples/Widgets/plot_widget2d_crosshair.py index 3b29b0a6..6352125c 100644 --- a/Examples/Widgets/plot_widget2d_crosshair.py +++ b/Examples/Widgets/plot_widget2d_crosshair.py @@ -6,16 +6,16 @@ on a 2-D image panel. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(3) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") v.add_widget("crosshair", color="#69f0ae", cx=64, cy=64) -fig +fig diff --git a/Examples/Widgets/plot_widget2d_label.py b/Examples/Widgets/plot_widget2d_label.py index 30055587..30fc2ab2 100644 --- a/Examples/Widgets/plot_widget2d_label.py +++ b/Examples/Widgets/plot_widget2d_label.py @@ -7,16 +7,17 @@ and ``fontsize``. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(5) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") v.add_widget("label", color="#ff1744", x=10, y=10, text="Region A", fontsize=14) + fig diff --git a/Examples/Widgets/plot_widget2d_polygon.py b/Examples/Widgets/plot_widget2d_polygon.py index 2a1ab9e7..fa3d2de2 100644 --- a/Examples/Widgets/plot_widget2d_polygon.py +++ b/Examples/Widgets/plot_widget2d_polygon.py @@ -6,18 +6,18 @@ Pass ``vertices`` as a list of ``[x, y]`` pixel coordinates. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(4) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") v.add_widget("polygon", color="#ff9100", vertices=[[32, 16], [96, 16], [112, 80], [64, 112], [16, 80]]) -fig +fig diff --git a/Examples/Widgets/plot_widget2d_rectangle.py b/Examples/Widgets/plot_widget2d_rectangle.py index afa2a904..52579e56 100644 --- a/Examples/Widgets/plot_widget2d_rectangle.py +++ b/Examples/Widgets/plot_widget2d_rectangle.py @@ -3,20 +3,20 @@ ==================== A draggable, resizable rectangle overlay on a 2-D image panel. -Add it with :meth:`~anyplotlib.figure_plots.Plot2D.add_widget` using +Add it with :meth:`~anyplotlib.plot2d.Plot2D.add_widget` using ``kind="rectangle"``. """ import numpy as np -import anyplotlib as vw +import anyplotlib as apl rng = np.random.default_rng(1) data = rng.standard_normal((128, 128)).cumsum(0).cumsum(1) data = (data - data.min()) / (data.max() - data.min()) xy = np.linspace(0, 10, 128) -fig, ax = vw.subplots(1, 1, figsize=(460, 460)) +fig, ax = apl.subplots(1, 1, figsize=(460, 460)) v = ax.imshow(data, axes=[xy, xy], units="nm") v.add_widget("rectangle", color="#ffeb3b", x=24, y=24, w=80, h=60) -fig +fig diff --git a/Examples/plot_image2d.py b/Examples/plot_image2d.py deleted file mode 100644 index 758e37cf..00000000 --- a/Examples/plot_image2d.py +++ /dev/null @@ -1,59 +0,0 @@ -""" -2D Image -======== -Display a 2-D image with physical axes using -:meth:`~anyplotlib.figure_plots.Axes.imshow`. -The image is a synthetic STEM-like diffraction pattern with a physical -length scale in nanometres. Circle markers highlight the first-order -diffraction spots, and an annular integration widget is placed over the -central beam. Pan and zoom with the mouse; press **R** to reset the view, -**H** to toggle the histogram, **L** / **S** to cycle colour-scale modes. -""" -import numpy as np -import anyplotlib as vw -rng = np.random.default_rng(1) -# ── Synthetic diffraction pattern ───────────────────────────────────────────── -N = 256 -x = np.linspace(-5, 5, N) # physical axis in nm -y = np.linspace(-5, 5, N) -XX, YY = np.meshgrid(x, y) -R = np.sqrt(XX ** 2 + YY ** 2) -def _ring(r, r0, width, amp): - return amp * np.exp(-0.5 * ((r - r0) / width) ** 2) -image = ( - _ring(R, 0.0, 0.30, 1.00) # central spot - + _ring(R, 2.1, 0.15, 0.55) # first-order ring - + _ring(R, 4.2, 0.15, 0.25) # second-order ring - + rng.normal(scale=0.04, size=(N, N)) -) -# ── Plot ─────────────────────────────────────────────────────────────────────── -fig, ax = vw.subplots(1, 1, figsize=(500, 500)) -v = ax.imshow(image, axes=[x, y], units="nm") -v.set_colormap("inferno") -# ── First-order spot markers ─────────────────────────────────────────────────── -# imshow axes are centre arrays: pixel = (phys - x[0]) / (x[1] - x[0]) -dx = x[1] - x[0] -def phys_to_px(val): - return (np.asarray(val) - x[0]) / dx -spot_nm = np.array([[ 2.1, 0.0], [-2.1, 0.0], - [ 0.0, 2.1], [ 0.0, -2.1]]) -spot_px = np.column_stack([phys_to_px(spot_nm[:, 0]), - phys_to_px(spot_nm[:, 1])]) -v.add_circles(spot_px, name="spots", radius=7, - edgecolors="#00e5ff", facecolors="#00e5ff22", - labels=["g1", "g1_bar", "g2", "g2_bar"]) -# ── Annular integration widget ───────────────────────────────────────────────── -cx = cy = float(phys_to_px(0.0)) -v.add_widget("annular", color="#ffcc00", - cx=cx, cy=cy, - r_outer=float(phys_to_px(2.8) - phys_to_px(0.0)), - r_inner=float(phys_to_px(1.2) - phys_to_px(0.0))) -fig -# %% -# Adjust display range and colour map -# ------------------------------------- -# :meth:`~anyplotlib.figure_plots.Plot2D.set_clim` clips the colour scale; -# :meth:`~anyplotlib.figure_plots.Plot2D.set_colormap` switches the palette. -v.set_clim(vmin=0.0, vmax=0.8) -v.set_colormap("viridis") -fig diff --git a/Examples/plot_spectra1d.py b/Examples/plot_spectra1d.py deleted file mode 100644 index eac54435..00000000 --- a/Examples/plot_spectra1d.py +++ /dev/null @@ -1,66 +0,0 @@ -""" -1D Spectra -========== - -Plot a 1-D spectrum with a physical x-axis (energy in eV) using -:meth:`~anyplotlib.figure_plots.Axes.plot`. - -The spectrum contains a broad background and three Gaussian peaks. -Vertical-line markers highlight the peak positions, and a range widget -selects a region of interest. Pan and zoom with the mouse; press **R** -to reset the view. -""" -import numpy as np -import anyplotlib as vw - -rng = np.random.default_rng(0) - -# ── Synthetic XPS-style spectrum ───────────────────────────────────────────── -energy = np.linspace(280, 295, 512) # binding energy axis (eV) - -def gaussian(x, mu, sigma, amp): - return amp * np.exp(-0.5 * ((x - mu) / sigma) ** 2) - -# Background + three peaks (C 1s region) -spectrum = ( - 0.4 * np.exp(-0.08 * (energy - 280)) # exponential background - + gaussian(energy, 284.8, 0.4, 1.0) # C–C / C–H - + gaussian(energy, 286.2, 0.4, 0.35) # C–O - + gaussian(energy, 288.0, 0.4, 0.18) # C=O - + rng.normal(scale=0.015, size=len(energy)) -) - -# ── Plot ────────────────────────────────────────────────────────────────────── -fig, ax = vw.subplots(1, 1, figsize=(620, 320)) -v = ax.plot(spectrum, axes=[energy], units="eV", y_units="Intensity (a.u.)") - -# ── Peak markers ────────────────────────────────────────────────────────────── -peak_energies = np.array([284.8, 286.2, 288.0]) -peak_offsets = np.column_stack([ - peak_energies, - np.interp(peak_energies, energy, spectrum), -]) -v.add_points(peak_offsets, name="peaks", - edgecolors="#ff1744", facecolors="#ff174433", sizes=7, - labels=["C–C", "C–O", "C=O"]) - -# ── Region-of-interest widget ───────────────────────────────────────────────── -v.add_range_widget(x0=285.8, x1=288.8, color="#00e5ff") - -fig - -# %% -# Overlay a second spectrum -# ------------------------- -# Use :meth:`~anyplotlib.figure_plots.Plot1D.add_line` to overlay additional -# curves — useful for comparing reference spectra or fits. - -fit = ( - 0.4 * np.exp(-0.08 * (energy - 280)) - + gaussian(energy, 284.8, 0.4, 1.0) - + gaussian(energy, 286.2, 0.4, 0.35) - + gaussian(energy, 288.0, 0.4, 0.18) -) -v.add_line(fit, x_axis=energy, color="#ffcc00", linewidth=1.5, label="fit") -fig - diff --git a/LICENSE b/LICENSE new file mode 100644 index 00000000..115f7b00 --- /dev/null +++ b/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Carter Francis + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/Makefile b/Makefile index b97de95f..0005703e 100644 --- a/Makefile +++ b/Makefile @@ -18,3 +18,8 @@ help: # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). %: Makefile @$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) + +clean: + rm -rf $(BUILDDIR)/* + rm -rf docs/auto_examples/ + rm -rf docs/api/generated/ diff --git a/ORIX_BACKEND_PLAN.md b/ORIX_BACKEND_PLAN.md new file mode 100644 index 00000000..74f80787 --- /dev/null +++ b/ORIX_BACKEND_PLAN.md @@ -0,0 +1,133 @@ +# Making anyplotlib a plotting backend for orix + +Goal: render orix's IPF / stereographic / pole-figure plots **natively in +anyplotlib** so orix (and SpyDE) can drop matplotlib for them. + +## How orix plots today (matplotlib) + +orix builds everything on **matplotlib Axes subclasses + registered projections**: + +- `orix/plot/stereographic_plot.py` — `StereographicPlot(name="stereographic")` + subclasses `matplotlib.axes.Axes`. It overrides `plot` / `scatter` / `text` to + first project spherical → (x, y) via `orix.projections.StereographicProjection` + (pure numpy, **no matplotlib**), then call `super().plot/scatter/text(x, y, …)` + in **data coordinates** with `set_aspect("equal")`. +- `inverse_pole_figure_plot.py` — `InversePoleFigurePlot` (subclass of the above) + draws the fundamental-sector outline + `[hkl]` corner labels. +- `IPFColorKeyTSL.plot()` — fills the sector with colour (scatter / mesh of + projected directions) on such an axis. + +So the orix side that's *not* matplotlib is just the **projection math** +(`vector2xy`); the rendering is plain 2-D matplotlib: `scatter`, `plot`, `text`, +`fill`/patches, in **data coords**, aspect-equal. + +## The gap in anyplotlib + +anyplotlib's 2-D is **image-centric**. Marker/overlay groups (`add_points`, +`add_lines`, `add_polygons`, `add_texts`, `add_circles`) with `transform="data"` +map offsets through **`_imgFitRect` (image pixels)** — i.e. an offset `(x, y)` is +treated as image column/row, *not* the axis's `x_axis`/`y_axis` data values +(confirmed in `figure_esm.js`: all 2-D coordinate fns derive from `_imgFitRect`; +markers use it, never `st.x_axis`). There is no "blank axis with x/y limits + +data-coord scatter/line/polygon/text" surface — which is exactly what orix needs. + +(This is also why the SpyDE IPF-refine triangle is currently a matplotlib raster: +its overlays in stereographic coords collapsed into the image's top-left corner.) + +## Staged plan + +**Stage 1 — data-coordinate overlays for `Plot2D` (foundation).** +Make marker groups honour the panel's `x_axis`/`y_axis` when present: a +`transform="data"` offset `(x, y)` maps via the axis values → image fraction → +canvas (the matplotlib `imshow(extent=…)` + `scatter` alignment). Smallest change +that (a) unblocks a fully-native IPF triangle over the heatmap imshow and (b) +proves the data→pixel plumbing. Touches `markers.py` (wire) + `figure_esm.js` +(`drawMarkers2d` coord branch). Demo/test: native IPF heatmap triangle. + +**Stage 2 — a coordinate-only 2-D axis (no image).** +`ax.set_xlim/ylim` + `set_aspect("equal")` on a panel with **no imshow**, where +`add_points/lines/polygons/texts` live in data coords. This is the general +"matplotlib-Axes-like 2-D" surface orix's `StereographicPlot` draws onto. Likely +a lightweight `Plot2DCoords` (or extend `Plot2D` to allow `data_bounds` without an +image) reusing the Stage-1 transform. + +**Stage 3 — orix targets anyplotlib (lives in ORIX, not here).** +The stereographic projection + IPF / pole-figure plotting **belongs in orix** and +already exists there (`StereographicProjection`, `StereographicPlot`, +`IPFColorKeyTSL`). anyplotlib stays domain-agnostic — it must NOT know about +stereographic projections. The integration is an **orix-side** change: refactor +orix's plotting to draw through a backend (matplotlib OR anyplotlib's `axes2d` +surface) — `vector2xy` (orix) → `PlotXY.scatter/plot/fill/text` (anyplotlib). +anyplotlib's only job is to be a complete-enough generic 2-D backend. + +## Align with matplotlib's model + +matplotlib's data drawing is two ideas we should mirror: + +1. **`transData` = a composed transform chain.** `transData = transScale (log/lin) + + transLimits (data Bbox → unit [0,1] box) + transAxes (unit box → display)`. + i.e. **data → [0,1] via the axis limits → pixels via the axes rect**. + `set_aspect("equal")` is `apply_aspect()` — it adjusts the box (and limits) so a + data unit is the same length on x and y. + - anyplotlib's **1-D path already does this**: marker offsets are normalised to + `[0,1]` by the x/y data range, then `_tc2d(fx,fy)=[r.x+fx*r.w, r.y+(1-fy)*r.h]` + maps the unit box → the panel rect. That's exactly `transLimits` → `transAxes`. + - So the coordinate axis just needs **explicit `xlim`/`ylim`** as the transData + domain + an aspect step, reusing the same unit-box→rect mapping. + +2. **Scatter is a `Collection` (offsets + per-point props), not N artists.** + `ax.scatter(x,y,c=,s=)` → one `PathCollection`: an offsets array drawn with a + shared marker path, per-point colours/sizes. anyplotlib's `MarkerGroup` + (`add_points`/`add_circles`) is **already this** — offsets + `facecolors`/ + `sizes` arrays. So `ax.scatter` becomes a thin wrapper returning a points + MarkerGroup positioned via transData; `plot`→a polyline marker group (`Line2D`), + `fill`/polygons→`add_polygons` (`Polygon`/`PathCollection`), `text`→`add_texts`. + +So the coordinate axis = **the 1-D unit-box→rect transform driven by explicit +xlim/ylim (+ aspect), with the existing collection-style markers as the artists** — +semantically the same as matplotlib's `transData` + `PathCollection`. + +## API sketch (matplotlib-parity) + +```python +ax = fig.add_axes2d() # blank data-coord axis (no image) +ax.set_xlim(-1, 1); ax.set_ylim(-1, 1); ax.set_aspect("equal") +ax.scatter(xs, ys, c=colors, s=8) # -> PathCollection-style MarkerGroup +ax.plot(ex, ey, color="w") # -> Line2D-style polyline +ax.fill(px, py, facecolor="…") # -> Polygon +ax.text(x, y, r"$[111]$") # -> Text +``` + +## Status + +**Stage 2 landed:** `Axes.axes2d()` → `PlotXY` (`anyplotlib/plotxy/`). It reuses +the 1-D data→canvas transform (`kind="1d"`, hidden curve) so `scatter`/`plot`/ +`fill`/`text` draw as collection markers in **data coords** with no renderer +change. `set_xlim`/`set_ylim`/`set_aspect`. Tests in `tests/test_plotxy/` +(5 pass incl. a chromium render); demo = a native IPF triangle (fill + scatter + +labels in data coords). + +**Two renderer gaps the demo exposed — both now CLOSED:** +1. **Per-point scatter colours — DONE.** `drawMarkers1d` `points` now reads + per-offset `facecolors`/`color` arrays (matplotlib `PathCollection`), so + `scatter(c=[...])` renders the IPF colour-key gradient. +2. **`aspect="equal"` — DONE.** `_plotRect1d(p)` applies matplotlib + `apply_aspect`: when `state.aspect==='equal'` it shrinks + centres the panel + box so one data unit spans equal pixels on x and y. Baked into the shared rect + helper, so draw / markers / overlay / hit-test all use the identical adjusted + box (matplotlib's transData derives from the axes box). A wide-panel IPF + triangle now renders undistorted (`tests/test_plotxy`: + `test_aspect_equal_renders_square` vs `test_aspect_auto_fills_panel`). + +**Then the orix side (in the orix repo, not here):** the stereographic / IPF / +pole-figure plotting STAYS in orix; refactor it to draw through a backend so +`vector2xy` (orix) feeds `PlotXY.scatter/plot/fill/text` (anyplotlib). anyplotlib +stays generic — finishing (1) + (2) makes it a complete-enough backend. + +## Recommendation + +Built **Stage 2 (chosen): the coordinate-only 2-D axis** as above — +reuse the 1-D `transLimits→transAxes` unit-box transform with explicit xlim/ylim + +aspect, expose `scatter`/`plot`/`fill`/`text` as collection-style artists. Demo = +a native IPF fundamental-sector triangle (filled colour-key + outline + `[hkl]` +labels) drawn purely with these primitives. diff --git a/README.md b/README.md new file mode 100644 index 00000000..92392321 --- /dev/null +++ b/README.md @@ -0,0 +1,119 @@ +# anyplotlib + +[![codecov](https://codecov.io/gh/CSSFrancis/anyplotlib/branch/main/graph/badge.svg)](https://codecov.io/gh/CSSFrancis/anyplotlib) +[![Tests](https://github.com/CSSFrancis/anyplotlib/actions/workflows/tests.yml/badge.svg)](https://github.com/CSSFrancis/anyplotlib/actions/workflows/tests.yml) + +**anyplotlib** is a fast, interactive plotting library for Jupyter, built on +[anywidget](https://anywidget.dev/) and a pure-JavaScript canvas renderer. +It follows matplotlib's object-oriented API — create a `Figure`, call methods +on `Axes` — so switching is often a one-line change: + +```python +import anyplotlib as apl + +fig, ax = apl.subplots(1, 1) # same shape as plt.subplots(1, 1) +ax.imshow(data) # pan, zoom, and inspect — live +fig # display in a Jupyter cell +``` + +If you have used matplotlib's OO interface, you already know most of +anyplotlib. What you gain is interactivity that stays fast on large data — +without a kernel round-trip per frame. + +## Why another plotting library? + +Matplotlib is a superb tool for publication-quality static figures, but its +interactive notebook story (`ipympl`) re-renders the whole figure on the +Python side for every frame. anyplotlib makes the opposite trade-off: + +- **All rendering happens in the browser.** Python serialises compact state + (raw image bytes, base64-encoded float arrays) once; pan/zoom/drag never + touch the kernel. +- **Each image, line collection, or marker group is a single canvas object**, + so blitting works and drag interactions run at full frame rate. +- **The scope is deliberately limited.** The OO API only (no `plt.plot()` + global state), a curated set of plot types and marker styles, and raster + canvas output rather than vector graphics. For print-quality SVG/PDF + figures, matplotlib remains the right tool. + +## Features + +- **Plot types** — `plot` (1-D lines with markers, linestyles, legends, log y), + `imshow` (2-D images with colormaps, colorbars, scale bars, overlay masks), + `pcolormesh` (non-uniform 2-D meshes), `bar` (grouped, horizontal, log, + value labels), and 3-D `plot_surface` / `scatter3d` / `plot3d`. +- **Layouts** — `subplots`, matplotlib-compatible `GridSpec` indexing + (slices, spans, negative indices), `width_ratios`/`height_ratios`, + `sharex`/`sharey` linked pan-zoom, and floating inset axes with + minimize/maximize. +- **Markers** — static overlays (points, circles, ellipses, rectangles, + polygons, arrows, line segments, text, h/v lines) with matplotlib-style + kwargs and live `.set()` updates. +- **Widgets** — draggable overlays (`RectangleWidget`, `CircleWidget`, + `AnnularWidget`, `CrosshairWidget`, `PolygonWidget`, `VLineWidget`, + `HLineWidget`, `RangeWidget`, …) that report positions back to Python. +- **Events** — a two-tier callback system: `pointer_move` fires every drag + frame for cheap updates; `pointer_settled` / `pointer_up` fire once for + expensive recomputation. Plus `key_down`, `wheel`, `double_click`, and + per-line scoped handlers. +- **Interactive docs** — the bundled `anyplotlib.sphinx_anywidget` extension + makes any anywidget figure live in Sphinx Gallery pages via Pyodide — no + kernel or server needed. +- **Embeddable anywhere** — figures don't require Jupyter. Export + self-contained HTML (`fig.save_html("plot.html")`), mount the renderer + directly in an Electron app or web page via the JS `mount()` API, or run a + live Python backend over any transport with `anyplotlib.embed.FigureBridge` + (full callback support). See the embedding guide in the docs. + +```python +import numpy as np +import anyplotlib as apl + +fig, (ax_img, ax_spec) = apl.subplots(1, 2, figsize=(900, 400)) +img = ax_img.imshow(stack.mean(axis=2), cmap="viridis") +spec = ax_spec.plot(stack[64, 64], units="eV") + +cross = img.add_widget("crosshair", cx=64, cy=64) + +@cross.add_event_handler("pointer_move") # every drag frame — keep it cheap +def update(event): + spec.set_data(stack[int(cross.cy), int(cross.cx)]) +``` + +## Installation + +```bash +pip install anyplotlib +``` + +Works anywhere anywidget does: JupyterLab, Jupyter Notebook, VS Code, +PyCharm, Google Colab, and marimo. Dependencies are intentionally light: +`anywidget`, `numpy`, `traitlets`, and `colorcet` (no matplotlib required). + +## Documentation + +Full docs, a live example gallery (interactive in the browser — no install), +and the event-system guide are at +**[cssfrancis.github.io/anyplotlib](https://cssfrancis.github.io/anyplotlib/)**. + +## Development + +```bash +git clone https://github.com/CSSFrancis/anyplotlib +cd anyplotlib +uv sync # install with dev dependencies +uv run playwright install chromium # browsers for rendering tests +uv run pytest # full suite (unit + Playwright + visual) +make html # build the docs locally +``` + +The architecture is a single `anywidget.AnyWidget` (`Figure`) that owns all +traitlets; plot objects are plain Python classes that serialise their state +dicts to per-panel traits, and `figure_esm.js` renders them. See +[AGENTS.md](AGENTS.md) for the codebase guide and +[`anyplotlib/FIGURE_ESM.md`](anyplotlib/FIGURE_ESM.md) for a map of the JS +renderer. + +## License + +MIT — see [LICENSE](LICENSE). diff --git a/RELEASE_PLAN.md b/RELEASE_PLAN.md new file mode 100644 index 00000000..efa49618 --- /dev/null +++ b/RELEASE_PLAN.md @@ -0,0 +1,104 @@ +# anyplotlib 0.1.0 — Release Plan + +Status as of 2026-06-12: `pyproject.toml` already says `0.1.0`, `CHANGELOG.rst` +already contains a `v0.1.0 (2026-04-12)` section, but **no git tag exists and +nothing is on PyPI** (the name `anyplotlib` is still available). The release +automation (`prepare_release.yml` → tag → `release.yml` OIDC publish) is built +and ready; what remains is mostly housekeeping. + +## Phase 1 — Clean the working tree (blockers) + +- [ ] **Decide on the uncommitted `anywidget_bridge.js` work** (+611 lines: a + HyperSpy/Enthought-traits shim for Pyodide). It is experimental and + unrelated to core plotting — either finish it on a feature branch or + stash it. Don't let it ride into the release commit unreviewed. +- [ ] **Commit or drop `Examples/Interactive/plot_segment_by_contrast_advanced.py`** + (untracked). If kept, it runs in docs CI — verify it executes. +- [ ] **Commit `uv.lock`** (currently untracked). CI uses `uv sync`; a + committed lockfile makes CI and contributor environments reproducible. +- [ ] Commit the audit fixes from this session: `LICENSE`, packaging excludes, + classifier/keywords, colormap-fallback fix, `Plot3D` geometry refactor, + `vw` → `apl` alias standardization, README/AGENTS.md/FIGURE_ESM.md + updates. + +## Phase 2 — Reconcile the changelog and version + +The Prepare Release workflow can only bump *up* from 0.1.0, so for this first +release do the changelog manually: + +- [ ] Fold the three pending `upcoming_changes/` fragments (6, 9, 11) into the + existing `v0.1.0` section of `CHANGELOG.rst` (or run + `uvx towncrier build --version 0.1.0` after deleting the stale section), + update the date, and delete the consumed fragments. +- [ ] Verify `docs/conf.py` `release` string matches `0.1.0`. +- [ ] Verify `docs/_root/switcher.json` has (or will get) a `v0.1.0` entry. + +## Phase 3 — One-time PyPI setup + +- [ ] On pypi.org, add a **pending trusted publisher**: + Owner `CSSFrancis`, repo `anyplotlib`, workflow `release.yml`, + environment `pypi` (matches the `environment:` block in release.yml). +- [ ] Create the `pypi` environment in the GitHub repo settings (release.yml + references it; publishing fails without it). + +## Phase 4 — Pre-tag verification + +- [ ] CI green on `main` (tests.yml matrix: 3.10–3.13 × linux/mac/win, plus + lowest-direct resolution job). +- [ ] `uv build`, then sanity-check the artifacts: + `uvx twine check dist/*` and install the wheel in a fresh venv, + `python -c "import anyplotlib"`. (After this session's packaging fix the + wheel no longer ships `anyplotlib/tests/` and PNG baselines — confirm + it is ~250 KB, not ~890 KB.) +- [ ] Build docs locally (`make html`) and click through the interactive + gallery — the Pyodide bridge loads the wheel built from the release + commit. +- [ ] Smoke-test in a real JupyterLab session: `subplots`, `imshow` + widget + drag, `plot` + vline widget, `bar`, `plot_surface`, inset. + +## Phase 5 — Ship + +```bash +git fetch origin +git tag v0.1.0 origin/main +git push origin v0.1.0 +``` + +This triggers `release.yml` (build → PyPI publish → GitHub Release with +changelog notes) and the docs deploy. Afterwards: + +- [ ] Verify `pip install anyplotlib` works from a clean environment. +- [ ] Verify the GitHub Release notes rendered correctly. +- [ ] Check the versioned docs URL and the root redirect. + +## Post-0.1.0 backlog (quality items from the audit, none blocking) + +1. **Duplicate CI**: `ci.yml` and `tests.yml` both run pytest on every + push/PR (ubuntu + 3.12 overlaps). Move the Codecov upload into the + tests.yml ubuntu/3.12 job and delete `ci.yml`. +2. **Colormap fidelity**: with colorcet installed (a hard dependency), + `"viridis"` silently renders as colorcet `bmy` and `"inferno"` as `kb` + (black→blue) — visually very different from the matplotlib maps users + expect. Consider embedding real 256-entry LUTs for the half-dozen most + common matplotlib names (a few KB) instead of aliasing. +3. **Add a linter/formatter**: no ruff/flake8 config exists. Add `ruff` + (lint + format) to the dev group and CI; the codebase is clean enough + that adoption should be cheap. +4. **Coverage in `addopts`**: `--cov` on every local `pytest` run slows quick + iterations and overwrites `coverage.xml`. Consider moving coverage flags + into the CI invocation only. +5. **Typing**: annotations are partial (`_fig: object`, untyped dicts). + If type-checking is a goal, add `py.typed` + mypy/pyright gradually. +6. **`Axes.imshow` silently drops RGB channels** (`data[:, :, 0]`). Either + render RGB properly or raise with a clear message; silent channel + dropping will surprise matplotlib users. +7. **`figure_esm.js` size** (~4,400 lines, one closure): consider an + esbuild-based bundling step so the JS can live in modules while anywidget + still receives a single `_esm` string. Until then, keep + `FIGURE_ESM.md` regenerated (instructions are in its header). +8. **`Event` dataclass breadth**: plot-type-specific fields (`bar_index`, + `ray`, `line_id`) live on the universal event. Fine at this scale; if + event types grow, consider per-kind payload dataclasses. +9. **Large-scale 3-D rendering (WebGPU)**: scoped in `WEBGPU_PLAN.md` — + phased, demand-gated, canvas fallback contract. Phase 0 (canvas cheats + + `voxels_from_volume` resampling API) is worth shipping independently. diff --git a/WEBGPU_PLAN.md b/WEBGPU_PLAN.md new file mode 100644 index 00000000..44360d0b --- /dev/null +++ b/WEBGPU_PLAN.md @@ -0,0 +1,234 @@ +# WebGPU-on-demand rendering — scoping document + +Status: **Phases 1–2 prototyped & hardware-verified** (2026-06-13). +Instanced points (Phase 1) and voxels (Phase 2) render on the GPU with +canvas fallback; projection + shaders validated on an NVIDIA Pascal GPU via +offscreen-texture readback. Remaining: binary-trait transport for >200k +payloads, the flagged CI smoke job, and Phase 3 (OIT translucency). +Owner: @CSSFrancis +Prerequisite reading: `anyplotlib/FIGURE_ESM.md` (3D drawing, voxels, plane widgets) + +## 1. Goal + +Render **large point clouds and voxel volumes** interactively — targets: + +| Workload | Today (Canvas2D) | Target (WebGPU) | +|---|---|---| +| `scatter3d` points | ~50k usable | **1M @ ≥30fps** | +| `voxels` cubes | ~10k (≤30k after Phase 0) | **500k @ ≥30fps** | +| Plane-drag re-slice | O(N) re-blit | **uniform update, 60fps at any N** | + +…without causing problems: every figure must keep working everywhere it works +today (Jupyter, Pyodide docs, Electron embed, headless CI), with no new JS +dependencies and no behaviour change for users below the GPU threshold. + +## 2. Non-goals + +- **Not** replacing Canvas2D — it remains the universal baseline, the + fallback, the small-N path, and the fully-CI-tested path, forever. +- **No WebGL2** — we go straight to WebGPU; maintaining three paths is worse + than two. (Decided 2026-06: choosing in 2026, not 2023.) +- **No three.js / no bundler** — raw WebGPU API, WGSL shaders as inline + strings in the single-file ESM. +- **No 2D pipeline changes** — images/lines/bars stay Canvas2D. +- **No WebGPU compute in early phases** (see Phase 4). + +## 3. Coverage & the fallback contract + +As of mid-2026: Chromium Win/Mac/Android and Electron ✓ (since 2023/24), +Safari ≥26 ✓ (Sept 2025), Firefox Windows ✓ / macOS recent / Linux rolling +out, Chrome Linux driver-dependent. Weak populations for *our* users: Linux +workstations, remote-desktop/VM sessions (no adapter even in supporting +browsers), older Safari. Estimated 15–25 % of scientific users today. + +**Contract:** WebGPU is a progressive enhancement. `navigator.gpu` present +→ `requestAdapter()` resolves → device created → *then* a panel may switch. +Any failure at any point (including mid-session device loss) lands on the +Canvas2D path silently and permanently for that session. A figure must never +render nothing because GPU was attempted. + +## 4. Architecture + +### 4.1 Activation policy + +- Python: `gpu="auto" | True | False` kwarg on `scatter3d()` / `voxels()` + → state field `gpu_mode`. Default `"auto"`. +- JS (`auto`): attempt WebGPU only when `vertices_count > GPU_THRESHOLD` + (initial: 20 000 — at/below this Canvas2D is already smooth, so the + fallback population loses nothing). `True` forces an attempt at any count + (still falls back); `False` never attempts. + +### 4.2 Device lifecycle (the async-init problem) + +- One **module-level singleton** `_gpuDevicePromise` (adapter + device + requested once per page, on first demand). +- Per-panel state `p._gpu ∈ {undefined, 'pending', 'active', 'unavailable'}`. +- First frame is ALWAYS Canvas2D (render() stays synchronous). When the + device promise resolves, the panel builds its buffers/pipeline, flips to + `'active'`, and redraws; on rejection → `'unavailable'`. +- `device.lost.then(...)`: mark every GPU panel `'unavailable'`, drop GPU + resources, redraw via Canvas2D. Never re-attempt within the session. + +### 4.3 Canvas split — decorations stay 2D + +Add one `gpuCanvas` to the 3D panel stack, *below* `plotCanvas`: + +``` +gpuCanvas (WebGPU) geometry only: instanced points / cubes +plotCanvas (2D ctx) axes, ticks, labels (_drawTex), reference sphere, + plane-widget quads, highlight — unchanged code, + drawn on a now-transparent background +overlayCanvas / markersCanvas / statusBar — unchanged +``` + +This is the key cost-control decision: **all decoration, label, TeX, sphere, +plane-widget, and highlight code is reused verbatim**; only the instanced +geometry moves to the GPU. The camera matrix is shared (same turntable +`_rot3` semantics → one orthographic view-projection matrix uniform). + +### 4.4 Pipelines + +- **Points**: instanced screen-facing quads (point_size px), per-instance + position (f32×3) + colour (unorm8×4). Fragment discards outside the disc. +- **Voxels**: one 36-vertex cube, instanced; per-instance position + colour. + Per-face shading via vertex normals (match the 0.82/0.68/1.0 canvas look). + Depth buffer → **no sorting at all**. +- **Slice emphasis & planes as uniforms**: plane axis/position/count go into + a uniform buffer; the fragment shader computes emphasis + (`|pos[axis] − plane| ≤ size/2`). Plane drags therefore re-render with a + **uniform write only** — no geometry re-upload, no Python round-trip + needed for the visual. +- Wire format already fits: `vertices_b64` (f32) and `point_colors_b64` (u8) + upload to GPUBuffers unchanged. + +### 4.5 Transparency strategy + +- Phase 1–2 GPU mode is **opaque** (depth-tested). For ≥100k elements this + reads *better* than alpha soup; it differs visually from the canvas + translucent look — documented, and `voxel_alpha` still applies on the + canvas path. +- Phase 3 adds weighted-blended OIT (two extra render targets + composite + pass) to restore the translucent-volume aesthetic at scale. Gate: only + build if genuinely needed after using opaque mode in practice. + +### 4.6 Capability feedback → adaptive budgets (Python) + +JS reports the outcome once per panel via the existing state echo: a +`_gpu_active: true|false` field written into the panel state (no new event +type needed). Python exposes `plot.gpu_active`. The resampling helper +(Phase 0) uses it: send full-resolution boundary voxels to GPU clients, +auto-stride to ≤20k for canvas clients. **No client ever receives a payload +it can't render.** + +### 4.7 Payload reality check (often the real bottleneck) + +1M points = 12 MB f32 → ~16 MB as b64-in-JSON through the comm. Phase 2 +includes moving large geometry to **binary traits** (ipywidgets/anywidget +support binary buffers; `_repr_utils._widget_state` already handles `bytes`) +with b64 kept for small payloads and the standalone/Pyodide paths. Without +this, the wire — not the GPU — caps practical sizes around ~200k points. + +## 5. Phases + +### Phase 0 — Canvas cheats + resampling API (no GPU code; do first) +*~2–3 days. Worth shipping regardless of WebGPU.* + +1. Interaction LOD: stride the draw set 2–4× while a drag is active; full + set on release/settle. +2. Analytic back-to-front order for grid voxels (camera octant → lexicographic + traversal; kills the O(n log n) sort). +3. Layered plane-drag cache: bake the translucent base cloud to a bitmap; + redraw only the emphasized slice voxels per drag frame. +4. `Axes.voxels_from_volume(vol, *, max_voxels=15000, mode="boundary"|"stride", + colors=...)` — formalises the explorer example's hand-rolled extraction. + +**Acceptance:** 25–30k voxels orbit smoothly on canvas (bench: orbit ≤35 ms +software); plane drag ≤10 ms at 20k; new benchmarks committed. + +### Phase 1 — GPU infrastructure + instanced points +*~4–5 days. The risk-retiring phase.* + +Device singleton, `gpuCanvas` stack integration, async swap, device-lost +fallback, `gpu_mode`/`_gpu_active` plumbing, instanced point pipeline. + +**Acceptance:** +- 1M points orbit ≥30fps on a real GPU (manual + flagged CI job). +- Kill switch verified: adapter-absent, mid-session device loss, and + `gpu=False` all render identically to today via canvas (automated). +- Embedding `mount()` and the Pyodide docs page work in GPU mode + (verify WebGPU inside the gallery iframes — srcdoc/permission policy). + +### Phase 2 — Instanced voxels + shader slice emphasis + binary traits +*~3–4 days.* + +Cube pipeline, plane uniforms (emphasis in-shader), plane-drag = uniform +update, binary-trait transport for large buffers. + +**Acceptance:** 500k cubes orbit ≥30fps; plane drag 60fps at 500k; voxel +grain explorer runs a 192³-extracted volume (~150k boundary voxels) live. + +### Phase 3 — Translucency (weighted-blended OIT) *(gated)* +*~4–6 days. Only if opaque mode proves insufficient in real use.* + +**Acceptance:** GPU translucent render within visual tolerance of the canvas +look at N ≤ 4k (screenshot comparison), correct at 500k. + +### Phase 4 — Future options *(not scoped)* +GPU compute culling/LOD, surfaces/lines on GPU, picking via ID buffer. + +## 6. Testing & CI strategy + +- **Canvas path keeps 100 % of today's coverage** and remains the default CI + matrix — GPU never reduces existing test fidelity. +- New **flagged headless GPU smoke job** (ubuntu): Chromium with + `--enable-unsafe-webgpu --enable-features=Vulkan` on lavapipe/SwiftShader- + Vulkan; tests `pytest.skip` cleanly when `requestAdapter()` yields null so + the job can never hard-fail on runner GPU availability. +- Fallback tests run in the NORMAL suite (no flags): assert `_gpu_active` + is false and rendering matches canvas baselines when GPU is absent — + this is the path that protects "no problems". +- Benchmarks: `js_gpu_points_1M`, `js_gpu_voxels_500k` added to the existing + hardware-gated baseline framework (recorded on a real-GPU machine). +- Phase 3 parity: SSIM-style screenshot comparison GPU vs canvas at small N. + +## 7. Risks + +| Risk | Severity | Mitigation | +|---|---|---| +| Async init race / blank first paint | High | First frame always canvas; swap on resolve; `'pending'` state | +| CI has no GPU adapter | High | Skip-on-unavailable smoke job; canvas keeps full coverage | +| Device lost mid-session | Med | Permanent per-session fallback; tested by forcing `device.destroy()` | +| Comm payload size (≥200k pts) | High | Phase 2 binary traits; capability-aware resampling caps payloads | +| Opaque-vs-translucent visual surprise | Med | Document; Phase 3 OIT; `gpu=False` escape hatch | +| WebGPU inside docs iframes (permission policy) | Med | Verify in Phase 1 acceptance; fall back if blocked | +| Safari/WGSL implementation quirks | Low-Med | Stick to core WGSL, no extensions; manual Safari pass per phase | +| Two render paths drift apart | Med | Shared camera/constants; parity screenshots; FIGURE_ESM.md section per path | + +## 8. Decision gates + +- **Gate A (after Phase 0):** if resampled canvas + linked slices satisfies + the 512×512×300 workflow in practice, pause here — GPU work is demand- + driven, not speculative. +- **Gate B (after Phase 1):** confirmed-working fallback matrix + real-GPU + point benchmark before any voxel pipeline work. +- **Gate C (before Phase 3):** a concrete use case that opaque mode cannot + serve. + +## 9. API sketch + +```python +# Python +plot = ax.voxels_from_volume(gid_volume, max_voxels=15_000, + mode="boundary", colors=grain_rgb) # Phase 0 +plot = ax.voxels(x, y, z, colors=c, gpu="auto") # Phase 2 +plot.gpu_active # bool, after first render echo +plot = ax.scatter3d(x, y, z, colors=c, gpu=True) # Phase 1 +``` + +```js +// JS internals (figure_esm.js) +_gpuDevice() // module singleton → Promise +p._gpu // 'pending' | 'active' | 'unavailable' +_buildPointPipeline(device, p) / _buildVoxelPipeline(device, p) +_drawGpu3d(p) // geometry; decorations still drawn by draw3d's 2D code +``` diff --git a/anyplotlib/FIGURE_ESM.md b/anyplotlib/FIGURE_ESM.md new file mode 100644 index 00000000..3eec2802 --- /dev/null +++ b/anyplotlib/FIGURE_ESM.md @@ -0,0 +1,303 @@ +# FIGURE_ESM.md — Navigator for `figure_esm.js` + +`figure_esm.js` is **~4,640 lines** and one big closure. Everything lives inside +`function render({ model, el })` so that all helpers share the same scope +(`theme`, `PAD_*`, `panels` Map, etc.). This document is a section map so you +can jump straight to the relevant code without reading the whole file. + +> **Keeping this file fresh:** line numbers drift as the JS evolves. The +> section banners are greppable — regenerate the quick-reference with +> `rg -n '^\s*// ──' anyplotlib/figure_esm.js` and function anchors with +> `rg -n '^\s*function \w+' anyplotlib/figure_esm.js`. Update this file +> whenever a PR moves a section by more than ~50 lines. + +--- + +## Sizing contract + +``` +Rule 1 – Grid tracks are always pure ratio math. + col_px[i] = fig_width × width_ratios[i] / Σ width_ratios + row_px[r] = fig_height × height_ratios[r] / Σ height_ratios + No exceptions. No 2-D special-casing. Both Python + (_compute_cell_sizes) and JS (_applyFigResizeDOM) follow this rule. + +Rule 2 – All panels in the same grid column have the same canvas width. + All panels in the same grid row have the same canvas height. + (Follows automatically from Rule 1.) + +Rule 3 – Images are displayed "contain" (letterbox / pillarbox). + _imgFitRect(iw, ih, cw, ch) → largest rect of aspect iw:ih + that fits inside cw×ch, centred. + +Rule 4 – Zoom is relative to the fit-rect. + zoom=1 → fit-rect exactly filled by the whole image. + zoom=Z → a 1/Z portion of the image fills the fit-rect. + +Rule 5 – Text never clips. Optional gutters earn real layout space: + the colorbar (strip + label, _cbWidth) is subtracted from the + image width; the 2D title strip (_padT) grows for large or TeX + titles; 1D/bar titles clamp their drawn size to the fixed strip + (_titlePx); edge tick labels are nudged inward. +``` + +--- + +## Quick-reference: function anchors + +| Section / function | Line | +|--------------------|------| +| Shared plot-area padding (`PAD_*`) | 9 | +| Theme (dark/light detection) | 15 | +| Shared math helpers | 53 | +| b64 array decode helpers | 95 | +| **Rich-text (mini-TeX) engine**: `_texRuns` / `_texLayout` / `_drawTex` | 147 / 214 / 236 | +| **2D gutter geometry**: `_cbWidth` / `_padT` / `_titlePx` | 287 / 299 / 309 | +| **Layout engine** `applyLayout` | 590 | +| `_buildCanvasStack` | 656 | +| `_createPanelDOM` | 763 | +| `_createInsetDOM` / `_applyAllInsetStates` | 846 / 968 | +| `_resizePanelDOM` | 1027 | +| **2D drawing**: `_imgFitRect` | 1176 | +| `draw2d` | 1258 | +| `drawScaleBar2d` / `drawColorbar2d` | 1360 / 1436 | +| `_drawAxes2d` (ticks, labels, title) | 1491 | +| `drawOverlay2d` / `drawMarkers2d` | 1629 / 1685 | +| **3D drawing**: `draw3d` | 1833 | +| Event emission `_emitEvent` | 2031 | +| 3D event handlers `_attachEvents3d` | 2059 | +| **1D drawing**: `draw1d` | 2177 | +| `drawOverlay1d` / `drawMarkers1d` | 2516 / 2586 | +| Marker hit-test `_markerHitTest2d` | 2787 | +| Panel event dispatch `_attachPanelEvents` | 2905 | +| 2D events `_attachEvents2d` | 2928 | +| 1D events `_attachEvents1d` | 3201 | +| 2D widget drag `_ovHitTest2d` / `_doDrag2d` | 3409 / 3491 | +| 1D widget drag `_canvasXToFrac1d` … | 3565 | +| Shared-axis propagation `_getShareGroups` | 3650 | +| Figure resize `_applyFigResizeDOM` | 3714 | +| **Bar chart**: `_barGeom` / `drawBar` / `_attachEventsBar` | 3902 / 3965 / 4341 | +| Generic redraw `_redrawPanel` | 4531 | + +--- + +## Rich-text (mini-TeX) label engine + +Canvas cannot run MathJax, so labels support a small TeX subset inside +`$...$` delimiters — superscripts/subscripts (`$10^{-3}$`, `$E_F$`), Greek +letters (`\alpha`…`\Omega`), and symbols (`\times`, `\AA`, `\degree`, +`\propto`, …; see `_TEX_SYM`). `\mathrm{...}` gives upright text; math-mode +letters are italic. Python stores label strings verbatim — all parsing +happens here at draw time. + +| Function | Purpose | +|----------|---------| +| `_texRuns(text)` | Parse a label into runs `[{t, lvl, it}]` — lvl 0/+1/−1, it = italic | +| `_texLayout(ctx, text, px, weight, family)` | Measure runs; sup/sub at 0.68×, dy −0.28/+0.16 em from a shared alphabetic baseline | +| `_drawTex(ctx, text, x, y, px, opts)` | Draw a label. `opts: {align, weight, family}`. Fast path (no `$`) is a single `fillText`. Respects caller's `fillStyle`/`textBaseline`. | + +**Baseline conversion gotcha:** `TextMetrics.fontBoundingBoxAscent` is +measured **relative to the current `textBaseline`**, not alphabetic. +`_drawTex` therefore measures the ascent under the caller's baseline AND +under `alphabetic`, and shifts by the difference — this makes TeX text land +at exactly the same height a plain `fillText` would. + +**All axis labels, titles, the colorbar label, 3D axis labels, and log tick +labels (`$10^{N}$`) render through `_drawTex`.** Font sizes come from state +with fallbacks to the historical defaults: `title_size||11`, +`x_label_size`/`y_label_size` (11 for 2D, 9 for 1D units, 10 for bar, 11 for +3D), `tick_size||10`, `colorbar_label_size||10`. + +## 2D gutter geometry helpers + +| Function | Purpose | +|----------|---------| +| `_cbWidth(st)` | Width reserved for the colorbar: 0 when hidden, else `16 + (label ? label_size+8 : 0)`. Subtracted from the image width in `_resizePanelDOM` / `_resizePanelCSS` so the strip + label always fit inside the panel. | +| `_padT(st)` | 2D title-strip height: `PAD_T` (12) for default-size plain titles (pixel-identical layouts); grows to `ceil(size*1.3)+2..4` for `title_size > 11` or TeX titles (superscript rise). Stored as `p._padT`. | +| `_titlePx(st)` | Drawn title size for fixed-strip panels (1D/bar): clamps to 11 (10 for TeX titles) so nothing clips. | + +`draw2d` calls `_resizePanelDOM` on every state push, so colorbar/title +geometry changes (visibility, label, sizes) re-layout automatically. + +--- + +## Layout / panel details + +#### `applyLayout()` (line 590) +Reads `layout_json`. Builds CSS grid tracks from `panel_specs[].panel_width/height`. +Creates panels that don't exist yet, resizes existing ones, removes stale ones. +Also creates/updates inset panels from `inset_specs`. + +#### `_createPanelDOM(id, kind, pw, ph, spec)` (line 763) +Builds all canvas/DOM elements for one panel (via `_buildCanvasStack`), +stores the **`p` object** in `panels`, subscribes to +`change:panel_{id}_json`, runs the initial draw. + +**DOM structure by kind:** +| kind | elements | +|------|----------| +| `'2d'` | `plotWrap > plotCanvas + overlayCanvas + markersCanvas + yAxisCanvas + xAxisCanvas + cbCanvas + scaleBar + statusBar + titleCanvas` | +| `'3d'` | `wrap3 > plotCanvas + overlayCanvas + markersCanvas + statusBar` | +| `'1d'` / `'bar'` | `wrap > plotCanvas + overlayCanvas + markersCanvas + statusBar` | + +#### `_resizePanelDOM(id, pw, ph)` (line 1027) +Updates `canvas.width / canvas.height` (DPR-scaled) for every canvas in the +panel. For 2D, computes `imgX/imgY/imgW/imgH` from the gutters +(`PAD_*`, `_padT`, `_cbWidth`) and stores them on `p` plus `p._cbW`/`p._padT`. + +#### The `p` (panel) object — key fields +```js +p.id, p.kind, p.pw, p.ph +p.state // parsed JSON from panel_{id}_json (full plot state dict) +p.imgX, p.imgY, p.imgW, p.imgH // 2D inner image area (gutters removed) +p._cbW, p._padT // 2D gutter geometry at last layout +p.plotCanvas/.overlayCanvas/.markersCanvas (+ 2D: x/yAxisCanvas, cbCanvas, +p.titleCanvas, p.scaleBar), p.statusBar +p.blitCache // { bitmap, bytesKey, lutKey, w, h } — ImageBitmap cache +p.ovDrag / p.ovDrag2d / p.isPanning +``` + +--- + +## 2D drawing (from line 1176) + +Key state fields: +``` +st.image_b64, st.image_width/height +st.zoom, st.center_x/y +st.display_min/max, st.raw_min/max, st.scale_mode +st.colormap_data [[r,g,b], ...] × 256 +st.x_axis, st.y_axis, st.axis_visible +st.markers, st.overlay_widgets, st.overlay_mask_b64/_color/_alpha +st.title_size, st.x_label_size, st.y_label_size, st.tick_size, +st.colorbar_label_size (label font sizes; optional) +``` + +| Function | Line | Purpose | +|----------|------|---------| +| **`_imgFitRect(iw,ih,cw,ch)`** | **1176** | Largest rect of aspect `iw:ih` centred in `cw×ch`; all 2-D coordinate functions derive from this | +| `draw2d(p)` | 1258 | Main render: `_resizePanelDOM` → decode → LUT → ImageBitmap → blit; then mask, axes, scale bar, colorbar, overlay, markers | +| `drawScaleBar2d(p)` | 1360 | Physical scale bar | +| `drawColorbar2d(p)` | 1436 | Gradient strip + min/max marks + rotated label centred in the `_cbWidth` gutter | +| `_drawAxes2d(p)` | 1491 | Ticks (edge labels nudged inward both axes), axis labels + title via `_drawTex` | +| `drawOverlay2d(p)` / `drawMarkers2d(p)` | 1629 / 1685 | Widgets / marker groups | + +Zoom model: at `zoom=1` the whole image fills the fit-rect; at `zoom=Z>1` a +`1/Z` region fills it. `_imgToCanvas2d` / `_canvasToImg2d` must stay exact +inverses of the blit geometry. + +--- + +## 3D drawing (line ~1840) +Orthographic projection; geometry b64-decoded and cached. `draw3d` sorts +triangles, draws axes with per-axis `_drawTex` labels (`x/y/z_label_size`). + +- **Camera** (`_rot3`): turntable with matplotlib azim/elev semantics — + azimuth spins about the DATA z-axis, elevation tilts toward the viewer. + Faces unit vector v when `el = asin(vz)`, `az = atan2(vx, -vy)`. +- **Scatter colours**: `st.point_colors_b64` (uint8 RGB triplets) gives + per-point colours; empty string falls back to `st.color`. +- **Highlight**: `st.highlight = {x,y,z,color,size}` draws an emphasised + ringed dot on top of everything (semi-transparent on the far side). +- **Reference sphere**: `st.sphere = {radius,color,alpha,wireframe}` draws a + shaded silhouette disk + lat/long wireframe behind the geometry; far-side + wireframe segments and scatter points are dimmed. +- **Voxels** (`geom_type 'voxels'`): shaded translucent cubes at the vertex + centres. `st.voxel_size`, `st.voxel_alpha`, `st.voxel_slice_alpha`. + Performance design (budget ~3–6 µs/cube, ≤ ~20k cubes interactive): + cube-corner screen offsets + face visibility computed once per frame; + per-(colour, emphasis) sprites blitted with integer-snapped `drawImage` + (≤256 unique colours; falls back to path fills above); typed-array + projection + depth-sort cached per (geometry generation, view, panel + size) so camera-static redraws (plane drags) only re-blit. Benchmarks: + `test_bench_voxels_orbit` / `test_bench_voxels_reblit`. +- **Echo guard**: `_attachEvents3d` writes interaction state via + `_writeState()` (sets `p._selfWrite`), and the panel-json listener skips + self-writes — without this every drag frame paid a second + JSON.parse + full redraw. +- **Touch bridge** (`_attachTouch`, called from `_attachPanelEvents` for + every panel kind): translates touch gestures into the *existing* mouse / + wheel handlers via real `MouseEvent` / `WheelEvent` dispatch — 1-finger → + mousedown/move/up, 2-finger pinch → wheel (anchored at the gesture + midpoint via `p.mouseX/Y`), double-tap → dblclick. `move`/`up` go to + `document` (handlers listen there for off-canvas drags); `down`/`wheel`/ + `dblclick` go to the overlay canvas. Overlay canvases set + `touch-action:none` so the browser yields gestures to the plot. No + handler rewrites — a working mouse interaction is automatically a working + touch one. +- **Geometry channel** (perf): plots that declare `_GEOM_KEYS` on the Python + side (Plot2D, Plot3D) split heavy keys (`vertices_b64`, `image_b64`, + `colormap_data`, …) into a second `panel__geom` trait, re-sent only + when their content hash changes; the view trait carries `_geom_rev`. JS + caches the decoded geom (`p._geomCache`/`p._geomRev`) and `_applyGeom` + splices it into the state before every draw, so view-only updates + (highlight, camera, planes, title) never re-parse or re-transmit + geometry. Both the `change:panel__geom` and `change:panel__json` + listeners call `_applyGeom`; the geom trait is loaded before the first + draw. Pairs with `Figure.batch()` push-coalescing on the Python side. +- **WebGPU path** (progressive enhancement, additive): scatter points + (`_GPU_POINT_WGSL`) and voxels (`_GPU_VOXEL_WGSL`) render instanced on the + GPU when available and above threshold (`GPU_POINT_THRESHOLD` 20k / + `GPU_VOXEL_THRESHOLD` 8k); `gpu_mode` ∈ auto/always/off. `gpuCanvas` sits + below `plotCanvas`; decorations always draw on the 2D `plotCanvas` over a + transparent background. `_gpuMatrix` reproduces the canvas projection + EXACTLY (verify numerically — the y-coefficients are NOT negated: canvas + screen-y-down and NDC-y-up cancel). Voxel slice emphasis + per-face shade + are uniforms, so plane drags are a uniform write. Every failure path + (no `navigator.gpu`, null adapter, device lost, draw throw) sets + `p._gpu='unavailable'` and the Canvas2D path renders unchanged. **Testing: + use offscreen-texture readback (`copyTextureToBuffer`), NOT screenshots — + the WebGPU swapchain doesn't snapshot reliably under automation.** +- **Plane widgets** (`st.overlay_widgets`, type `'plane'`): translucent + draggable slice selectors. `draw3d` caches screen quads + the axis screen + direction on `p._3dPlanes`; `_attachEvents3d` hit-tests them on mousedown + (plane drag wins over orbit) and drags along the normal. Voxels within + half a voxel of a plane render at `voxel_slice_alpha`. NOTE: during drags + re-resolve widgets by id in `p.state` — object references go stale because + the model echo replaces `p.state` on every `save_changes()`. +- `st.data_bounds` may be fixed from Python (`bounds=` kwarg) so geometry + normalisation stays origin-true (unit-sphere direction vectors). + +## Events +- `_emitEvent(panelId, eventType, widgetId, extraData)` (line 2031) writes + `{source:'js', ...}` to `model.event_json`; `eventType` is any + `pointer_*` / `key_*` / `wheel` / `double_click` string + (see `callbacks.VALID_EVENT_TYPES`). +- Kind-specific attach functions: 3D 2059, 2D 2928, 1D 3201, bar 4341. +- Widget drag: 2D hit-test/drag 3409/3491; 1D from 3565. + +## 1D drawing (line 2177) +`draw1d` renders series (b64 decode cache), axes, ticks (log ticks as TeX +`$10^{N}$`; edge labels nudged inward), grid, legend, units labels + title +via `_drawTex` (title size clamped via `_titlePx`). + +## Bar chart (lines 3902–4530) +`_barGeom` (3902) computes per-bar geometry incl. grouped offsets and +log-scale mappers; `drawBar` (3965) renders grid, bars, value labels, ticks +(log ticks as TeX superscripts, category edge labels nudged inward), legend, +labels + clamped title; `_attachEventsBar` (4341) handles drag/hover/click. +Bar zoom/pan modifies `st.data_min/max` (value axis); `view_x0/x1` stays 0/1. + +--- + +## Key data flows + +``` +Python push: + plot._push() → figure._push(id) → panel_{id}_json trait changes + → model.on('change:panel_{id}_json') → p.state = JSON.parse(...) + → _redrawPanel(p) + +JS → Python (widget drag): + _doDrag2d / _doDrag1d → updates p.state.overlay_widgets in-place + → _emitEvent(id, 'pointer_move', widgetId, {…}) + → model.set('event_json', …) + save_changes() + → Python Figure._on_event() → Widget._update_from_js() + CallbackRegistry.fire() + +JS → Python (3D rotate / zoom): + _attachEvents3d → model.set('panel_{id}_json', …) + save_changes() + +Python → JS (set widget position from Python): + widget.set(…) → Figure._push_widget → event_json with source:'python' + → model.on('change:event_json') patches overlay_widgets + redraws +``` diff --git a/anyplotlib/__init__.py b/anyplotlib/__init__.py index f0985027..2671c4df 100644 --- a/anyplotlib/__init__.py +++ b/anyplotlib/__init__.py @@ -1,4 +1,45 @@ from anyplotlib.figure import Figure, GridSpec, SubplotSpec, subplots -from anyplotlib.figure_plots import PlotMesh, Plot3D +from anyplotlib.axes import Axes, InsetAxes +from anyplotlib.plot1d import Plot1D, PlotBar +from anyplotlib.plot1d._plot1d import Line1D +from anyplotlib.plot2d import Plot2D, PlotMesh +from anyplotlib.plot3d import Plot3D +from anyplotlib.plotxy import PlotXY +from anyplotlib.callbacks import CallbackRegistry, Event +from anyplotlib import embed +from anyplotlib.markers import MarkerRegistry, MarkerGroup +from anyplotlib.widgets import ( + Widget, RectangleWidget, CircleWidget, AnnularWidget, + CrosshairWidget, PolygonWidget, LabelWidget, + VLineWidget, HLineWidget, RangeWidget, PlaneWidget, +) -__all__ = ["Figure", "GridSpec", "SubplotSpec", "subplots", "PlotMesh", "Plot3D"] +# ── Global help flag ────────────────────────────────────────────────────── +# Set to False to suppress help badges on all figures in this session. +# Default True: badges appear whenever a figure has help text set. +show_help: bool = True + +_COLOR_CYCLE: list[str] = [ + "#4fc3f7", "#ff7043", "#aed581", "#ffd54f", + "#ba68c8", "#4db6ac", "#f06292", "#90a4ae", + "#ffb74d", "#a5d6a7", +] + + +def get_color_cycle() -> list[str]: + """Return the default color cycle as a list of CSS hex strings.""" + return list(_COLOR_CYCLE) + + +__all__ = [ + "Figure", "GridSpec", "SubplotSpec", "subplots", + "Axes", "InsetAxes", "Plot1D", "Plot2D", "PlotMesh", "Plot3D", "PlotBar", + "Line1D", + "CallbackRegistry", "Event", + "MarkerRegistry", "MarkerGroup", + "Widget", "RectangleWidget", "CircleWidget", "AnnularWidget", + "CrosshairWidget", "PolygonWidget", "LabelWidget", + "VLineWidget", "HLineWidget", "RangeWidget", "PlaneWidget", + "show_help", "get_color_cycle", + "embed", +] diff --git a/anyplotlib/_base_plot.py b/anyplotlib/_base_plot.py new file mode 100644 index 00000000..0dbe7be9 --- /dev/null +++ b/anyplotlib/_base_plot.py @@ -0,0 +1,229 @@ +""" +_base_plot.py +============= +Shared base classes and mixins for all plot panel types. +""" + +from __future__ import annotations + +from contextlib import contextmanager + +from anyplotlib.callbacks import _EventMixin + + +class _BasePlot(_EventMixin): + """Universal base for Plot1D, Plot2D, PlotBar, and Plot3D. + + Contains methods identical across all four panel types and helper + utilities used by view-setter and widget-adder methods. + + Subclasses must define: + _state : dict — the panel state dict + _push() -> None — serialize state and write to parent Figure + """ + + def configure_pointer_settled(self, ms: int, delta: float = 4) -> None: + """Configure the pointer-settled event threshold (ms and pixel delta).""" + self._state["pointer_settled_ms"] = ms + self._state["pointer_settled_delta"] = delta + self._push() + + _configure_pointer_settled = configure_pointer_settled + + #: Mini-TeX formatting note shared by all label setters. + #: + #: Label strings support a small TeX subset inside ``$...$`` delimiters, + #: rendered by the JS canvas engine (no MathJax needed): + #: + #: * ``$10^{-3}$`` / ``$x^2$`` — superscripts (exponents) + #: * ``$E_F$`` / ``$k_{B}T$`` — subscripts + #: * ``$\\alpha$ … $\\Omega$`` — Greek letters + #: * ``\\times \\cdot \\pm \\degree \\AA \\infty \\propto \\approx`` + #: ``\\leq \\geq \\neq \\partial \\nabla \\hbar \\rightarrow`` — symbols + #: * ``$\\mathrm{...}$`` — upright text inside math (letters in + #: math mode are italic by default) + #: + #: Example: ``plot.set_xlabel(r"$q$ ($\\AA^{-1}$)", fontsize=14)`` + + def _set_label(self, key: str, label: str, size_key: str, + fontsize: float | None) -> None: + """Store a label string (TeX subset allowed) and its optional size.""" + self._state[key] = str(label) + if fontsize is not None: + self._state[size_key] = float(fontsize) + self._push() + + def set_title(self, label: str, fontsize: float | None = None) -> None: + """Set the panel title. + + Parameters + ---------- + label : str + Title text. Supports the mini-TeX subset (``$10^{-3}$``, + ``$\\alpha$``, …) — see the class notes on label formatting. + fontsize : float, optional + Font size in CSS pixels. Default 11. On 2-D panels the title + strip grows to fit larger sizes. 1-D and bar titles render in a + fixed 12-px strip, so the drawn size is clamped to 11 there. + """ + self._set_label("title", label, "title_size", fontsize) + + def set_axis_off(self) -> None: + self._state["axis_visible"] = False + self._push() + + def set_axis_on(self) -> None: + self._state["axis_visible"] = True + self._push() + + @contextmanager + def _python_view_push(self): + """Context manager for view setters that must signal _view_from_python. + + Sets the flag on entry, yields for state mutations, then pushes + and clears the flag on exit. + """ + self._state["_view_from_python"] = True + try: + yield + finally: + self._push() + self._state["_view_from_python"] = False + + def _make_widget_push_fn(self, widget): + """Return a targeted-push closure for a widget. + + Replaces the repeated _tp / _targeted_push closures in every + add_*_widget method. + """ + plot_ref, wid_id = self, widget._id + def _push(): + if plot_ref._fig is not None: + fields = {k: v for k, v in widget._data.items() + if k not in ("id", "type")} + plot_ref._fig._push_widget(plot_ref._id, wid_id, fields) + return _push + + +class _PanelMixin: + """Mixin for panels that support interactive widgets and tick control. + + Shared by Plot1D, Plot2D, and PlotBar. Provides _push (with widget + serialization), widget management, and tick visibility control. + + Subclasses must define: + _state : dict + _fig : object + _id : str + _widgets : dict[str, Widget] + """ + + def _push(self) -> None: + if self._fig is None: + return + self._state["overlay_widgets"] = [w.to_dict() for w in self._widgets.values()] + self._fig._push(self._id) + + def set_tick_label_size(self, size: float) -> None: + """Set the font size of the tick (axis number) labels in CSS pixels. + + Applies to both axes of the panel. Default 10. + + Parameters + ---------- + size : float + Tick label font size in pixels. + """ + self._state["tick_size"] = float(size) + self._push() + + def set_ticks_visible(self, visible: bool, *, x: bool | None = None, + y: bool | None = None) -> None: + if x is None and y is None: + self._state["x_ticks_visible"] = bool(visible) + self._state["y_ticks_visible"] = bool(visible) + else: + if x is not None: + self._state["x_ticks_visible"] = bool(x) + if y is not None: + self._state["y_ticks_visible"] = bool(y) + self._push() + + def get_widget(self, wid): + """Return the Widget object by ID string or Widget instance.""" + from anyplotlib.widgets import Widget + if isinstance(wid, Widget): + wid = wid.id + try: + return self._widgets[wid] + except KeyError: + raise KeyError(wid) + + def remove_widget(self, wid) -> None: + """Remove a widget by ID string or Widget instance.""" + from anyplotlib.widgets import Widget + if isinstance(wid, Widget): + wid = wid.id + if wid not in self._widgets: + raise KeyError(wid) + del self._widgets[wid] + self._push() + + def list_widgets(self) -> list: + """Return a list of all active widget objects on this panel.""" + return list(self._widgets.values()) + + def clear_widgets(self) -> None: + """Remove all interactive overlay widgets from this panel.""" + self._widgets.clear() + self._push() + + +class _MarkerMixin: + """Mixin for panels that support static marker collections. + + Shared by Plot1D and Plot2D. + + Subclasses must define: + _state : dict + markers : MarkerRegistry + _push() -> None + """ + + def _push_markers(self) -> None: + self._state["markers"] = self.markers.to_wire_list() + self._push() + + def _add_marker(self, mtype: str, name, **kwargs): + return self.markers.add(mtype, name, **kwargs) + + def remove_marker(self, marker_type: str, name: str) -> None: + """Remove a named marker collection by type and name. + + Parameters + ---------- + marker_type : str + Collection type, e.g. ``"points"``, ``"vlines"``. + name : str + The name used when the collection was created. + """ + self.markers.remove(marker_type, name) + + def clear_markers(self) -> None: + """Remove all marker collections from this panel.""" + self.markers.clear() + + def list_markers(self) -> list: + """Return a summary list of all marker collections on this panel. + + Returns + ------- + list of dict + Each dict has keys ``"type"``, ``"name"``, and ``"n"`` + (number of markers in the collection). + """ + out = [] + for mtype, td in self.markers._types.items(): + for name, g in td.items(): + out.append({"type": mtype, "name": name, "n": g._count()}) + return out diff --git a/anyplotlib/_electron.py b/anyplotlib/_electron.py new file mode 100644 index 00000000..fbbd5c1f --- /dev/null +++ b/anyplotlib/_electron.py @@ -0,0 +1,74 @@ +""" +_electron.py +============ +Electron app bridge for anyplotlib figures. + +Registers figures so their trait changes are forwarded to the Electron +renderer via stdout, and provides dispatch_event() so the renderer can +send interaction events back to Python. +""" +from __future__ import annotations + +import json +import sys +import uuid + +_figures: dict[str, object] = {} # fig_id -> Figure + + +def register(fig) -> str: + """Register *fig* for bidirectional state sync and return its fig_id.""" + fig_id = uuid.uuid4().hex[:8] + _figures[fig_id] = fig + + def _on_change(change): + name = change["name"] + value = change["new"] + if isinstance(value, (bytes, bytearray)): + import base64 + value = {"buffer": base64.b64encode(value).decode()} + emit({"type": "state_update", "fig_id": fig_id, "key": name, "value": value}) + + for name in fig.traits(sync=True): + if not name.startswith("_"): + try: + fig.observe(_on_change, names=[name]) + except Exception: + pass + + return fig_id + + +def resize_figure(fig_id: str, width: int, height: int) -> None: + """Update fig_width / fig_height and push new layout to the iframe.""" + fig = _figures.get(fig_id) + if fig is None: + return + try: + # Batch both trait changes so _on_resize fires only once each. + with fig.hold_trait_notifications(): + fig.fig_width = int(width) + fig.fig_height = int(height) + except Exception: + pass + + +def dispatch_event(fig_id: str, event_json: str) -> None: + """Apply a frontend interaction event to the registered figure.""" + fig = _figures.get(fig_id) + if fig is None: + return + try: + # Figure.show() registers Figure objects which use _dispatch_event(raw_json_str). + # Standalone widgets use _update_from_js(dict, event_type). + if hasattr(fig, "_dispatch_event"): + fig._dispatch_event(event_json) + elif hasattr(fig, "_update_from_js"): + fig._update_from_js(json.loads(event_json)) + except Exception: + pass + + +def emit(obj: dict) -> None: + sys.stdout.write(f"PLOTAPP:{json.dumps(obj, default=str)}\n") + sys.stdout.flush() diff --git a/anyplotlib/_repr_utils.py b/anyplotlib/_repr_utils.py index 42be9d61..f8e57622 100644 --- a/anyplotlib/_repr_utils.py +++ b/anyplotlib/_repr_utils.py @@ -7,7 +7,7 @@ Strategy -------- - and 1. Serialise every synced traitlet value to a plain JSON dict. +1. Serialise every synced traitlet value to a plain JSON dict. 2. Embed that dict and the widget's ``_esm`` source directly in the page. 3. Provide a minimal model shim (get/set/on/save_changes) so the ESM's render() function works without any Jupyter comm infrastructure. @@ -24,6 +24,11 @@ from html import escape from uuid import uuid4 +# Maximum display width (px) for the non-resizable notebook embed. +# Figures wider than this are scaled down proportionally via CSS transform. +# 860 px fits comfortably in a standard JupyterLab / VS Code notebook cell. +MAX_NOTEBOOK_WIDTH = 860 + # --------------------------------------------------------------------------- # Trait serialisation @@ -124,18 +129,51 @@ def _widget_px(widget) -> tuple[int, int]:
""" -def build_standalone_html(widget, *, resizable: bool = True) -> str: +def build_standalone_html(widget, *, resizable: bool = True, + fig_id: str | None = None) -> str: """Return a self-contained HTML page that renders *widget* interactively. Parameters @@ -183,6 +231,9 @@ def build_standalone_html(widget, *, resizable: bool = True) -> str: When ``True`` (default) the widget's built-in resize handle is preserved. When ``False`` the handle is hidden via CSS and the page is sized exactly to the widget's natural dimensions. + fig_id : str or None + When provided, embedded as ``FIG_ID`` so the parent-page bridge + can route ``postMessage`` state updates to this iframe. """ state = _widget_state(widget) @@ -200,12 +251,14 @@ def build_standalone_html(widget, *, resizable: bool = True) -> str: extra_css=extra_css, state_json=json.dumps(state, default=str), esm_json=json.dumps(esm), + fig_id_json=json.dumps(fig_id), ) def repr_html_iframe(widget, *, resizable: bool = False, + max_width: int = MAX_NOTEBOOK_WIDTH, max_height: int = 800) -> str: - """Return a centred ``' f'' + f'' + f'' ) else: - # Resizable — fill container width, auto-resize height after render. + # ── Resizable embed (fills cell width, auto-sizes height) ────────── return ( f'' + f'' + f'' + f'' + ) + else: + return ( + f'' + ) + diff --git a/anyplotlib/sphinx_anywidget/_scraper.py b/anyplotlib/sphinx_anywidget/_scraper.py new file mode 100644 index 00000000..7cda5858 --- /dev/null +++ b/anyplotlib/sphinx_anywidget/_scraper.py @@ -0,0 +1,390 @@ +""" +sphinx_anywidget/_scraper.py +============================= + +Generic Sphinx Gallery image scraper for any ``anywidget.AnyWidget`` subclass. + +Drop-in replacement for the anyplotlib-specific ``_sg_html_scraper.ViewerScraper``. +Works with **any** library built on anywidget — just add the scraper to your +``sphinx_gallery_conf["image_scrapers"]``. + +Interactive tagging +------------------- +If a code block's last expression line contains a ``# Interactive`` comment +(case-insensitive), the scraper: + +* embeds the full example Python source in a ``' + f'' + ) + + +# --------------------------------------------------------------------------- +# Scraper +# --------------------------------------------------------------------------- + +class AnywidgetScraper: + """Sphinx Gallery image scraper for any ``anywidget.AnyWidget`` subclass. + """ + + def __init__(self): + # Maps src_file → list of fig_ids emitted so far (creation order). + self._example_figs: dict = {} + + def __repr__(self) -> str: + return "AnywidgetScraper()" + + def __call__(self, block, block_vars, gallery_conf): + globals_dict = block_vars.get("example_globals", {}) + widget = _find_widget(globals_dict) + if widget is None: + return "" + + src_file = str(block_vars.get("src_file", "")) + + # ── detect # Interactive tag ────────────────────────────────────── + block_source = block[1] if isinstance(block, (list, tuple)) else "" + is_interactive = bool(_INTERACTIVE_RE.search(block_source)) + + # ── assign a stable fig_id and fig_index ───────────────────────── + if src_file not in self._example_figs: + self._example_figs[src_file] = [] + fig_index = len(self._example_figs[src_file]) + + # ── 1. Write the thumbnail PNG ──────────────────────────────────── + image_path_iterator = block_vars["image_path_iterator"] + png_path = Path(next(image_path_iterator)) + png_path.parent.mkdir(parents=True, exist_ok=True) + png_path.write_bytes(_make_thumbnail_png(widget)) + + fig_id = png_path.stem # stable, unique stem from Sphinx Gallery + self._example_figs[src_file].append(fig_id) + + # ── 2. Write the standalone HTML ────────────────────────────────── + try: + from anyplotlib.sphinx_anywidget._repr_utils import ( + build_standalone_html, _widget_px, + ) + docs_dir = Path(gallery_conf["src_dir"]) + widgets_dir = docs_dir / "_static" / "viewer_widgets" + widgets_dir.mkdir(parents=True, exist_ok=True) + + html_name = png_path.stem + ".html" + html_path = widgets_dir / html_name + + inner_html = build_standalone_html(widget, resizable=False, fig_id=fig_id) + html_path.write_text(inner_html, encoding="utf-8") + w, h = _widget_px(widget) + have_html = True + except Exception as exc: + print(f"[sphinx_anywidget] WARNING: could not write iframe HTML: {exc}") + have_html = False + + # ── 3. Return rST ───────────────────────────────────────────────── + if have_html: + try: + src_dir = Path(gallery_conf["src_dir"]) + page_dir = png_path.parent.parent # strip /images + rel_parts = page_dir.relative_to(src_dir).parts + depth = len(rel_parts) + except Exception: + depth = 1 + prefix = "../" * depth + src = f"{prefix}_static/viewer_widgets/{html_name}" + + iframe_block = _iframe_html( + src, w, h, + fig_id=fig_id, + interactive=is_interactive, + ) + + rst = "\n\n.. raw:: html\n\n " + iframe_block + "\n\n" + + if is_interactive: + # Embed the example Python source so the Pyodide bridge can + # re-execute it and wire live callbacks. + python_src = "" + try: + python_src = Path(src_file).read_text(encoding="utf-8") + except Exception: + pass + + if python_src: + data_src = _html_escape(_json.dumps(python_src), quote=True) + + # Detect _PYODIDE_PACKAGES = [...] in the source. + _pkg_attr = "" + m = _PYODIDE_PACKAGES_RE.search(python_src) + if m: + try: + import ast as _ast + pkgs = _ast.literal_eval(m.group(1)) + if pkgs: + _pkg_attr = ( + f' data-pyodide-packages=' + f'"{_html_escape(_json.dumps(pkgs), quote=True)}"' + ) + except Exception: + pass + + # Detect _PYODIDE_MOCK_PACKAGES = [...] in the source. + _mock_attr = "" + m2 = _PYODIDE_MOCK_PACKAGES_RE.search(python_src) + if m2: + try: + import ast as _ast2 + mock_pkgs = _ast2.literal_eval(m2.group(1)) + if mock_pkgs: + _mock_attr = ( + f' data-pyodide-mock-packages=' + f'"{_html_escape(_json.dumps(mock_pkgs), quote=True)}"' + ) + except Exception: + pass + + python_block = ( + f'' + ) + rst += "\n\n.. raw:: html\n\n " + python_block + "\n\n" + + return rst + else: + return ( + f"\n\n.. image:: {png_path.name}\n" + f" :width: 100%\n\n" + ) + + +# Back-compat alias used by the existing anyplotlib docs. +ViewerScraper = AnywidgetScraper + diff --git a/anyplotlib/sphinx_anywidget/_wheel_builder.py b/anyplotlib/sphinx_anywidget/_wheel_builder.py new file mode 100644 index 00000000..fba6915c --- /dev/null +++ b/anyplotlib/sphinx_anywidget/_wheel_builder.py @@ -0,0 +1,84 @@ +""" +sphinx_anywidget/_wheel_builder.py +==================================== + +Builds a project wheel at docs-build time so the Pyodide bridge can install +the exact library version that generated the docs — no PyPI release required. +""" + +from __future__ import annotations + +import re +import subprocess +from pathlib import Path + + +def build_wheel( + static_dir: Path, + package_name: str, + project_root: Path, +) -> "Path | None": + """Build a pure-Python wheel into *static_dir/wheels/*. + + The wheel is renamed to ``{package_name}-0.0.0-py3-none-any.whl``. + ``0.0.0`` is a valid PEP 440 sentinel micropip accepts for URL installs. + + Parameters + ---------- + static_dir : + The docs ``_static`` directory; a ``wheels/`` sub-dir is created. + package_name : + PyPI / importable name (e.g. ``"anyplotlib"``). + project_root : + Directory containing ``pyproject.toml`` / ``setup.py``. + + Returns + ------- + Path or None + Path to the written wheel, or *None* on failure. + """ + wheels_dir = static_dir / "wheels" + wheels_dir.mkdir(parents=True, exist_ok=True) + + # PEP 427 normalises distribution names: hyphens and dots → underscores. + normalised = re.sub(r"[-.]", "_", package_name) + + stable = wheels_dir / f"{normalised}-0.0.0-py3-none-any.whl" + + # Build into a temporary sub-directory so we never clobber the existing + # stable wheel until we know the new build actually succeeded. + import tempfile + with tempfile.TemporaryDirectory(dir=wheels_dir) as tmp_str: + tmp_dir = Path(tmp_str) + result = subprocess.run( + [ + "uv", "build", "--wheel", + "--out-dir", str(tmp_dir), + str(project_root), + ], + capture_output=True, + text=True, + ) + + if result.returncode != 0: + print( + f"\n[sphinx_anywidget] WARNING: wheel build failed " + f"for {package_name!r}:\n{result.stderr}" + ) + return None + + new_wheels = sorted(tmp_dir.glob(f"{normalised}*.whl")) + if not new_wheels: + print(f"\n[sphinx_anywidget] WARNING: no wheel found for {package_name!r}") + return None + + # Build succeeded — now replace the stable wheel atomically. + stable.unlink(missing_ok=True) + # Remove any other stale versioned wheels before moving the new one. + for old in wheels_dir.glob(f"{normalised}*.whl"): + old.unlink(missing_ok=True) + new_wheels[-1].rename(stable) + # ASCII only: Windows consoles (cp1252) can't print '→' during builds + print(f"[sphinx_anywidget] wheel -> {stable}") + return stable + diff --git a/anyplotlib/sphinx_anywidget/static/anywidget_bridge.js b/anyplotlib/sphinx_anywidget/static/anywidget_bridge.js new file mode 100644 index 00000000..fbf2c96b --- /dev/null +++ b/anyplotlib/sphinx_anywidget/static/anywidget_bridge.js @@ -0,0 +1,1099 @@ +/** + * anywidget_bridge.js + * + * Generic Pyodide bridge for anywidget-based interactive documentation. + * + * Architecture + * ──────────── + * Parent page (this script) + * ├─ Per-figure ⚡ badge (in .awi-badge div, rendered by _scraper.py) + * ├─ Pyodide WASM runtime (loaded once from CDN on first ⚡ click) + * ├─ Package wheel at _static/wheels/{pkg}-0.0.0-py3-none-any.whl + * ├─ ", + ] + mock_js = "\n".join(mock_lines) + + parent_html = ( + "\n\n" + f"bridge test - {fig_id}\n" + f"{mock_js}\n" + "\n" + "\n" + "\n" + "
\n" + f"
\n" + f" \n" + f"
\n" + f" \n" + "
\n" + "
\n" + "
\n" + f"\n" + "" + ) + + parent_path = base_dir / f"{fig_id}_parent.html" + parent_path.write_text(parent_html, encoding="utf-8") + return parent_path + + +# ============================================================================= +# Tier 2 -- iframe postMessage tests (browser only, no HTTP server) +# ============================================================================= + +class TestIframeMessaging: + """Verify the awi_state postMessage protocol via the standalone iframe. + + The ``interact_page`` fixture opens the figure HTML as a top-level page + (``window.parent === window``), so outbound awi_event forwarding is + disabled. Tests focus on the *inbound* direction: awi_state updates the + model. + """ + + def _open_fig(self, interact_page): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.plot(np.sin(np.linspace(0, 2 * np.pi, 64)), color="#4fc3f7") + panel_id = list(fig._plots_map.keys())[0] + plot = list(fig._plots_map.values())[0] + page = interact_page(fig) + return fig, plot, panel_id, page + + def test_awi_state_updates_model_key(self, interact_page): + """Posting {type:'awi_state', key, value} updates the model.""" + fig, plot, panel_id, page = self._open_fig(interact_page) + raw = page.evaluate(f"() => window._aplModel.get('panel_{panel_id}_json')") + assert raw is not None + curr = json.loads(raw) + curr["__sentinel__"] = "hello" + new_json = json.dumps(curr) + page.evaluate( + "() => window.postMessage(" + + json.dumps({"type": "awi_state", + "key": f"panel_{panel_id}_json", + "value": new_json}) + + ", '*')" + ) + _rafter(page) + updated = json.loads( + page.evaluate(f"() => window._aplModel.get('panel_{panel_id}_json')") + ) + assert updated.get("__sentinel__") == "hello" + + def test_no_echo_in_standalone_mode(self, interact_page): + """No awi_event is echoed back in standalone mode (FIG_ID is null).""" + fig, plot, panel_id, page = self._open_fig(interact_page) + raw = json.loads( + page.evaluate(f"() => window._aplModel.get('panel_{panel_id}_json')") + ) + raw["__flag__"] = 1 + new_json = json.dumps(raw) + page.evaluate( + "() => {" + " window._aplEventsSeen = 0;" + " window.addEventListener('message', (e) => {" + " if (e.data && e.data.type === 'awi_event') window._aplEventsSeen++;" + " });" + "}" + ) + page.evaluate( + "() => window.postMessage(" + + json.dumps({"type": "awi_state", + "key": f"panel_{panel_id}_json", + "value": new_json}) + + ", '*')" + ) + _rafter(page) + assert page.evaluate("() => window._aplEventsSeen") == 0 + + def test_awi_state_fires_change_listeners(self, interact_page): + """Posting awi_state triggers on('change:...') listeners.""" + fig, plot, panel_id, page = self._open_fig(interact_page) + page.evaluate( + f"() => {{" + f" window._aplChangeCount = 0;" + f" window._aplModel.on('change:panel_{panel_id}_json'," + f" () => window._aplChangeCount++);" + f"}}" + ) + raw = json.loads( + page.evaluate(f"() => window._aplModel.get('panel_{panel_id}_json')") + ) + raw["__change__"] = 1 + new_json = json.dumps(raw) + page.evaluate( + "() => window.postMessage(" + + json.dumps({"type": "awi_state", + "key": f"panel_{panel_id}_json", + "value": new_json}) + + ", '*')" + ) + _rafter(page) + assert page.evaluate("() => window._aplChangeCount") >= 1 + + def test_layout_json_push_updates_model(self, interact_page): + """layout_json can be updated via awi_state.""" + fig, plot, panel_id, page = self._open_fig(interact_page) + layout = json.loads( + page.evaluate("() => window._aplModel.get('layout_json') || '{}'") + ) + layout["__layout_sentinel__"] = "bridge_test" + new_json = json.dumps(layout) + page.evaluate( + "() => window.postMessage(" + + json.dumps({"type": "awi_state", "key": "layout_json", "value": new_json}) + + ", '*')" + ) + _rafter(page) + updated = json.loads( + page.evaluate("() => window._aplModel.get('layout_json') || '{}'") + ) + assert updated.get("__layout_sentinel__") == "bridge_test" + + +# ============================================================================= +# Tier 3 -- Full bridge mock-boot tests (HTTP server + mock Pyodide) +# ============================================================================= + +class TestFullBridgeBoot: + """Boot anywidget_bridge.js end-to-end via a mock loadPyodide. + + Each test builds a parent HTML page and serves it from the shared + ``http_server`` fixture. All Pyodide network I/O is replaced by the JS + mock so tests complete in milliseconds. + """ + + def _open(self, browser, base_url, parent_path, timeout=15_000): + url = f"{base_url}/{parent_path.name}" + page = browser.new_page() + page.goto(url, wait_until="domcontentloaded", timeout=timeout) + return page + + def _basic_fig(self): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.plot(np.sin(np.linspace(0, 2 * np.pi, 64)), color="#50fa7b") + panel_id = list(fig._plots_map.keys())[0] + return fig, panel_id + + def test_button_appears_when_iframe_present(self, http_server, _pw_browser): + """The activate button is injected on any page with a data-awi-fig iframe.""" + base_url, base_dir = http_server + fig, _ = self._basic_fig() + parent = _build_parent_page(fig, "btn_test_001", base_dir=base_dir) + page = self._open(_pw_browser, base_url, parent) + page.wait_for_function( + "() => !!document.querySelector('button.awi-activate-btn')", + timeout=5_000, + ) + tooltip = page.evaluate( + "() => document.querySelector('button.awi-activate-btn').title" + ) + assert "interactive" in tooltip.lower() + page.close() + + def test_boot_completes_all_mock_steps(self, http_server, _pw_browser): + """Clicking the button runs through all expected mock Pyodide boot steps.""" + base_url, base_dir = http_server + fig, _ = self._basic_fig() + parent = _build_parent_page(fig, "boot_test_001", base_dir=base_dir) + page = self._open(_pw_browser, base_url, parent) + _click_and_wait_boot(page) + steps = page.evaluate("() => window._APL_BOOT_STEPS") + for step in ("loadPyodide", "micropip_install", "stub_anywidget", + "install_monkey_patch", "run_example"): + assert step in steps, f"Step {step!r} missing; got {steps}" + page.close() + + def test_anywidgetPush_is_function_after_boot(self, http_server, _pw_browser): + """window._anywidgetPush must be a function after the push-hook step.""" + base_url, base_dir = http_server + fig, _ = self._basic_fig() + parent = _build_parent_page(fig, "apush_test_001", base_dir=base_dir) + page = self._open(_pw_browser, base_url, parent) + _click_and_wait_boot(page) + assert page.evaluate( + "() => typeof window._anywidgetPush === 'function'" + ), "window._anywidgetPush not installed" + page.close() + + def test_state_pushed_into_iframe_model(self, http_server, _pw_browser): + """After boot the iframe's model contains the figure's panel JSON.""" + base_url, base_dir = http_server + fig, panel_id = self._basic_fig() + expected = fig._plots_map[panel_id].to_state_dict() + parent = _build_parent_page(fig, "state_push_001", base_dir=base_dir) + page = self._open(_pw_browser, base_url, parent) + _click_and_wait_boot(page) + _wait_for_iframe_model(page, "state_push_001", panel_id) + raw = page.evaluate( + "() => {" + " const el = document.querySelector('iframe[data-awi-fig=\"state_push_001\"]');" + f" return el && el.contentWindow ? el.contentWindow._aplModel.get('panel_{panel_id}_json') : null;" + "}" + ) + assert raw is not None, "panel JSON not delivered to iframe model" + assert json.loads(raw).get("kind") == expected.get("kind") + page.close() + + def test_layout_json_pushed_into_iframe(self, http_server, _pw_browser): + """layout_json is delivered to the iframe model.""" + base_url, base_dir = http_server + fig, _ = self._basic_fig() + parent = _build_parent_page(fig, "layout_push_001", base_dir=base_dir) + page = self._open(_pw_browser, base_url, parent) + _click_and_wait_boot(page) + page.wait_for_function( + "() => {" + " const el = document.querySelector('iframe[data-awi-fig=\"layout_push_001\"]');" + " if (!el || !el.contentWindow) return false;" + " const mdl = el.contentWindow._aplModel;" + " if (!mdl) return false;" + " const raw = mdl.get('layout_json');" + " return typeof raw === 'string' && raw.length > 10;" + "}", + timeout=8_000, + ) + raw = page.evaluate( + "() => {" + " const el = document.querySelector('iframe[data-awi-fig=\"layout_push_001\"]');" + " return el.contentWindow._aplModel.get('layout_json');" + "}" + ) + assert raw is not None + assert "panel_specs" in json.loads(raw) + page.close() + + def test_event_message_forwarded_to_parent(self, http_server, _pw_browser): + """awi_event messages from the iframe arrive at the parent window.""" + base_url, base_dir = http_server + fig, panel_id = self._basic_fig() + parent = _build_parent_page(fig, "event_fwd_001", base_dir=base_dir) + page = self._open(_pw_browser, base_url, parent) + _click_and_wait_boot(page) + page.evaluate( + "() => {" + " window._aplReceivedEvents = [];" + " window.addEventListener('message', (e) => {" + " if (e.data && e.data.type === 'awi_event')" + " window._aplReceivedEvents.push(e.data);" + " });" + "}" + ) + fake_event = json.dumps({ + "event_type": "on_release", "panel_id": panel_id, + "widget_id": "w_fake", "x": 42.0, + }) + page.evaluate( + "() => window.postMessage(" + + json.dumps({"type": "awi_event", + "figId": "event_fwd_001", + "data": fake_event}) + + ", '*')" + ) + _rafter(page) + events = page.evaluate("() => window._aplReceivedEvents") + assert len(events) >= 1, "No awi_event reached the parent message bus" + assert events[0]["figId"] == "event_fwd_001" + page.close() + + def test_multiple_panels_all_receive_state(self, http_server, _pw_browser): + """All panels in a multi-panel figure have their state pushed.""" + base_url, base_dir = http_server + fig, axes = apl.subplots(1, 2, figsize=(700, 300)) + axes[0].plot(np.zeros(32)) + axes[1].plot(np.ones(32) * 0.5) + panel_ids = list(fig._plots_map.keys()) + assert len(panel_ids) == 2 + parent = _build_parent_page(fig, "multi_panel_001", base_dir=base_dir) + page = self._open(_pw_browser, base_url, parent) + _click_and_wait_boot(page) + for pid in panel_ids: + _wait_for_iframe_model(page, "multi_panel_001", pid) + for pid in panel_ids: + raw = page.evaluate( + "() => {" + " const el = document.querySelector('iframe[data-awi-fig=\"multi_panel_001\"]');" + f" return el && el.contentWindow ? el.contentWindow._aplModel.get('panel_{pid}_json') : null;" + "}" + ) + assert raw is not None, f"Panel {pid!r} state not pushed" + page.close() + + def test_button_shows_error_on_boot_failure(self, http_server, _pw_browser): + """If Pyodide boot fails the button switches to the error state.""" + base_url, base_dir = http_server + fig, _ = self._basic_fig() + parent = _build_parent_page(fig, "error_test_001", base_dir=base_dir) + html = (base_dir / "error_test_001_parent.html").read_text(encoding="utf-8") + # Patch mock to throw immediately on loadPyodide + html = html.replace( + "window.loadPyodide = async function() {", + "window.loadPyodide = async function() { throw new Error('mock boot failure'); //", + ) + (base_dir / "error_test_001_parent.html").write_text(html, encoding="utf-8") + page = self._open(_pw_browser, base_url, parent) + page.wait_for_function( + "() => !!document.querySelector('button.awi-activate-btn')", + timeout=5_000, + ) + page.click("button.awi-activate-btn") + page.wait_for_function( + "() => {" + " const btn = document.querySelector('button.awi-activate-btn');" + " return btn && btn.dataset.state === 'error';" + "}", + timeout=10_000, + ) + label = page.evaluate( + "() => document.querySelector('button.awi-activate-btn').title" + ) + assert "mock boot failure" in label + page.close() diff --git a/anyplotlib/tests/test_documentation/test_push_hook.py b/anyplotlib/tests/test_documentation/test_push_hook.py new file mode 100644 index 00000000..900d6dad --- /dev/null +++ b/anyplotlib/tests/test_documentation/test_push_hook.py @@ -0,0 +1,158 @@ +""" +tests/test_documentation/test_push_hook.py +========================================== + +Unit tests for the Python→JS state-push pathway. + +These tests require **no browser** — they call ``_push()`` / ``_push_layout()`` +directly and inspect the resulting traitlet values. They cover the same +ground that older tests exercised via ``_pyodide_push_hook``; the hook is now +gone and state flows through standard ``sync=True`` traitlets instead. + +Related browser tests (iframe postMessage, full mock-boot) live in +``test_bridge.py``. +""" + +from __future__ import annotations + +import json + +import numpy as np +import pytest + +import anyplotlib as apl +import anyplotlib.figure as _af + + +# ───────────────────────────────────────────────────────────────────────────── +# Helper shared by multiple tests +# ───────────────────────────────────────────────────────────────────────────── + +def _capture_fig_state(fig) -> dict[str, str]: + """Return ``{trait_name: json_string}`` for layout + every panel trait. + + Reads traitlet values directly after calling the push methods. This + works even when the value hasn't changed (traitlets suppress duplicate + change events, so an observe-based approach would return nothing on a + second call with the same state). + """ + fig._push_layout() + for pid in list(fig._plots_map): + fig._push(pid) + + captured: dict[str, str] = {} + captured["layout_json"] = fig.layout_json + for tname in fig.trait_names(): + if tname.startswith("panel_") and tname.endswith("_json"): + captured[tname] = getattr(fig, tname) + return captured + + +# ───────────────────────────────────────────────────────────────────────────── +# Tests +# ───────────────────────────────────────────────────────────────────────────── + +class TestPushHook: + """Verify _push() / _push_layout() write to sync=True traitlets correctly.""" + + def test_push_does_not_crash(self): + """Normal mode: _push() succeeds without error.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.plot(np.zeros(16)) # must not raise + + def test_layout_json_written_on_create(self): + """layout_json traitlet is set when a figure is created.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + parsed = json.loads(fig.layout_json) + assert "panel_specs" in parsed, ( + f"layout_json missing 'panel_specs': {list(parsed.keys())}" + ) + + def test_panel_json_written_after_plot(self): + """panel_*_json traitlet is set when a plot is added.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.plot(np.sin(np.linspace(0, 2 * np.pi, 64))) + + panel_keys = [ + k for k in fig.trait_names() + if k.startswith("panel_") and k.endswith("_json") + ] + assert len(panel_keys) >= 1, "Expected at least one panel_*_json trait" + for k in panel_keys: + parsed = json.loads(getattr(fig, k)) + assert "kind" in parsed, ( + f"panel JSON missing 'kind': {list(parsed.keys())}" + ) + + def test_observe_fires_on_push(self): + """traitlets.observe() fires when _push() writes a panel trait.""" + seen: list[str] = [] + + def _watch(change): + seen.append(change["name"]) + + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + fig.observe(_watch) + ax.plot(np.zeros(8)) + fig.unobserve(_watch) + + assert any(k.startswith("panel_") for k in seen), ( + f"Expected a panel_* trait change; got: {seen}" + ) + + def test_panel_id_deterministic(self): + """Panel IDs derived from SubplotSpec must be identical across rebuilds.""" + ids: list[str] = [] + for _ in range(3): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.plot(np.zeros(8)) + ids.append(list(fig._plots_map.keys())[0]) + assert ids[0] == ids[1] == ids[2], ( + f"Panel ID must be deterministic; got {ids}" + ) + + def test_panel_ids_unique_in_multiplot(self): + """Each panel in a multi-panel figure has a unique ID.""" + fig, axes = apl.subplots(1, 3, figsize=(900, 300)) + for ax in axes: + ax.plot(np.zeros(8)) + ids = list(fig._plots_map.keys()) + assert len(ids) == len(set(ids)), f"Panel IDs not unique: {ids}" + + def test_panel_id_matches_grid_position(self): + """Panel IDs encode the SubplotSpec row/col bounds.""" + fig, axes = apl.subplots(2, 2, figsize=(600, 400)) + for ax in np.asarray(axes).flat: + ax.plot(np.zeros(4)) + ids = set(fig._plots_map.keys()) + for pid in ids: + assert pid.startswith("p"), f"Unexpected panel ID format: {pid!r}" + + def test_dispatch_event_callable_without_kernel(self): + """_dispatch_event() can be called directly as the Pyodide bridge does.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.plot(np.zeros(16)) + raw = json.dumps({ + "event_type": "on_zoom", + "panel_id": list(fig._plots_map.keys())[0], + "source": "js", + }) + fig._dispatch_event(raw) # must not raise + + def test_capture_fig_state_helper(self): + """_capture_fig_state returns both layout_json and panel JSON(s).""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.plot(np.zeros(32)) + state = _capture_fig_state(fig) + assert "layout_json" in state, ( + f"Expected layout_json; got {list(state.keys())}" + ) + panel_keys = [k for k in state if k.startswith("panel_")] + assert len(panel_keys) >= 1, "Expected at least one panel_ key" + + def test_no_pyodide_push_hook_attribute(self): + """figure module no longer exposes _pyodide_push_hook.""" + assert not hasattr(_af, "_pyodide_push_hook"), ( + "_pyodide_push_hook should not exist on figure module" + ) + diff --git a/anyplotlib/tests/test_documentation/test_scraper.py b/anyplotlib/tests/test_documentation/test_scraper.py new file mode 100644 index 00000000..a8f38030 --- /dev/null +++ b/anyplotlib/tests/test_documentation/test_scraper.py @@ -0,0 +1,123 @@ +""" +tests/test_documentation/test_scraper.py +========================================= + +Tests for the Playwright-based scraper thumbnail functionality. + +Two sections: + +1. **PNG format validation** — verifies ``_make_thumbnail_png`` returns a valid + PNG array for common figure types. No Playwright required. + +2. **Dark-theme validation** — checks the top-left pixel of the thumbnail is + dark-blue (matching the library's dark theme). Requires Playwright; skipped + automatically when not installed. +""" + +from __future__ import annotations + +import importlib.util as _ilu + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.sphinx_anywidget._scraper import _make_thumbnail_png +from anyplotlib.tests._png_utils import decode_png + + +# ───────────────────────────────────────────────────────────────────────────── +# Shared fixtures +# ───────────────────────────────────────────────────────────────────────────── + +@pytest.fixture +def line_fig(): + fig, ax = apl.subplots(1, 1, figsize=(400, 250)) + ax.plot(np.sin(np.linspace(0, 2 * np.pi, 128)), color="#4fc3f7") + return fig + + +@pytest.fixture +def imshow_fig(): + fig, ax = apl.subplots(1, 1, figsize=(320, 320)) + data = np.linspace(0, 1, 64 * 64, dtype=np.float32).reshape(64, 64) + ax.imshow(data) + return fig + + +@pytest.fixture +def multi_panel_fig(): + fig, axes = apl.subplots(1, 2, figsize=(640, 300)) + axes[0].plot(np.cos(np.linspace(0, 2 * np.pi, 64))) + axes[1].imshow( + np.random.default_rng(0).uniform(0, 1, (32, 32)).astype(np.float32) + ) + return fig + + +# ───────────────────────────────────────────────────────────────────────────── +# Helper +# ───────────────────────────────────────────────────────────────────────────── + +def _decode_thumbnail(fig, label: str): + """Return the decoded RGBA/RGB array for *fig*'s thumbnail, asserting PNG.""" + png = _make_thumbnail_png(fig) + assert png[:4] == b"\x89PNG", f"[{label}] result is not a PNG" + arr = decode_png(png) + assert arr.ndim == 3, f"[{label}] expected H×W×C array, got shape {arr.shape}" + assert arr.shape[2] in (3, 4), ( + f"[{label}] expected RGB/RGBA, got {arr.shape[2]} channels" + ) + return arr + + +# ───────────────────────────────────────────────────────────────────────────── +# Section 1 — PNG format validation (no Playwright required) +# ───────────────────────────────────────────────────────────────────────────── + +class TestThumbnailFormat: + """Verify that _make_thumbnail_png produces a well-formed PNG for each + common figure type.""" + + def test_thumbnail_1d_line(self, line_fig): + _decode_thumbnail(line_fig, "1D line") + + def test_thumbnail_2d_imshow(self, imshow_fig): + _decode_thumbnail(imshow_fig, "2D imshow") + + def test_thumbnail_multi_panel(self, multi_panel_fig): + _decode_thumbnail(multi_panel_fig, "multi-panel") + + +# ───────────────────────────────────────────────────────────────────────────── +# Section 2 — Dark-theme pixel validation (requires Playwright) +# ───────────────────────────────────────────────────────────────────────────── + +_requires_playwright = pytest.mark.skipif( + _ilu.find_spec("playwright") is None, + reason="playwright not installed", +) + + +@_requires_playwright +class TestThumbnailDarkTheme: + """Verify the top-left pixel of each thumbnail is dark-blue, matching the + library's default dark theme. These tests are skipped when Playwright is + not installed.""" + + def _assert_dark_theme(self, fig, label: str) -> None: + arr = _decode_thumbnail(fig, label) + r, g, b = int(arr[0, 0, 0]), int(arr[0, 0, 1]), int(arr[0, 0, 2]) + assert (b > r) and (b > 30), ( + f"[{label}] expected a dark-theme thumbnail " + f"(top-left RGB=({r},{g},{b}))" + ) + + def test_dark_theme_1d_line(self, line_fig): + self._assert_dark_theme(line_fig, "1D line") + + def test_dark_theme_2d_imshow(self, imshow_fig): + self._assert_dark_theme(imshow_fig, "2D imshow") + + def test_dark_theme_multi_panel(self, multi_panel_fig): + self._assert_dark_theme(multi_panel_fig, "multi-panel") diff --git a/anyplotlib/tests/test_embed/__init__.py b/anyplotlib/tests/test_embed/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/anyplotlib/tests/test_embed/test_embed_api.py b/anyplotlib/tests/test_embed/test_embed_api.py new file mode 100644 index 00000000..5dbb421e --- /dev/null +++ b/anyplotlib/tests/test_embed/test_embed_api.py @@ -0,0 +1,133 @@ +""" +Unit tests for anyplotlib.embed — kernel-free embedding API. + +Covers figure_state / to_html / save_html / esm_path / Figure.to_html and +the transport-agnostic FigureBridge (outbound forwarding, inbound event +dispatch, echo suppression, dynamic panel traits). +""" +from __future__ import annotations + +import json + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.embed import ( + FigureBridge, esm_path, figure_state, save_html, to_html, +) + + +def _fig_with_image(): + fig, ax = apl.subplots(1, 1, figsize=(320, 240)) + plot = ax.imshow(np.zeros((16, 16), dtype=np.float32)) + return fig, plot + + +class TestFigureState: + def test_contains_core_keys(self): + fig, plot = _fig_with_image() + state = figure_state(fig) + assert "layout_json" in state + assert state["fig_width"] == 320 + assert f"panel_{plot._id}_json" in state + + def test_panel_state_is_json(self): + fig, plot = _fig_with_image() + state = figure_state(fig) + panel = json.loads(state[f"panel_{plot._id}_json"]) + assert panel["kind"] == "2d" + + +class TestHtmlExport: + def test_to_html_is_self_contained(self): + fig, plot = _fig_with_image() + html = to_html(fig) + assert html.startswith("") + assert "function render" in html # inlined ESM + assert f"panel_{plot._id}_json" in html # inlined state + + def test_figure_methods(self, tmp_path): + fig, _ = _fig_with_image() + assert fig.to_html() == to_html(fig) + out = fig.save_html(tmp_path / "fig.html") + assert out.read_text(encoding="utf-8") == fig.to_html() + + def test_save_html(self, tmp_path): + fig, _ = _fig_with_image() + p = save_html(fig, tmp_path / "plot.html", resizable=False) + assert p.exists() and p.stat().st_size > 10_000 + + def test_esm_path_exports_mount(self): + src = esm_path().read_text(encoding="utf-8") + assert "export function mount" in src + assert "export function createLocalModel" in src + + +class TestFigureBridge: + def test_outbound_forwarding(self): + fig, plot = _fig_with_image() + sent = [] + FigureBridge(fig, send=lambda k, v: sent.append(k)) + plot.set_title("hello") + assert f"panel_{plot._id}_json" in sent + + def test_outbound_layout_changes(self): + fig, _ = _fig_with_image() + sent = [] + FigureBridge(fig, send=lambda k, v: sent.append(k)) + fig.fig_width = 500 + assert "fig_width" in sent and "layout_json" in sent + + def test_dynamic_panel_traits_forwarded(self): + """Panels added AFTER bridge creation must still forward.""" + fig = apl.Figure(1, 2, figsize=(400, 200)) + sent = [] + FigureBridge(fig, send=lambda k, v: sent.append(k)) + plot = fig.add_subplot((0, 0)).plot(np.zeros(8)) + plot.set_title("late panel") + assert f"panel_{plot._id}_json" in sent + + def test_inbound_event_dispatches_callbacks(self): + fig, plot = _fig_with_image() + bridge = FigureBridge(fig, send=lambda k, v: None) + got = [] + + @plot.add_event_handler("pointer_down") + def on_down(event): + got.append((event.event_type, event.xdata)) + + bridge.receive("event_json", json.dumps({ + "panel_id": plot._id, "event_type": "pointer_down", + "x": 5, "y": 6, "xdata": 1.5, "ydata": 2.5, "button": 0, + })) + assert got == [("pointer_down", 1.5)] + + def test_inbound_no_echo(self): + """receive() must not re-send the same key back.""" + fig, plot = _fig_with_image() + sent = [] + bridge = FigureBridge(fig, send=lambda k, v: sent.append(k)) + key = f"panel_{plot._id}_json" + new_state = json.dumps({**plot.to_state_dict(), "title": "from js"}) + bridge.receive(key, new_state) + assert key not in sent + assert getattr(fig, key) == new_state + + def test_inbound_unknown_key_ignored(self): + fig, _ = _fig_with_image() + bridge = FigureBridge(fig, send=lambda k, v: None) + bridge.receive("panel_doesnotexist_json", "{}") # must not raise + + def test_snapshot_matches_figure_state(self): + fig, _ = _fig_with_image() + bridge = FigureBridge(fig, send=lambda k, v: None) + assert bridge.snapshot() == figure_state(fig) + + def test_close_stops_forwarding(self): + fig, plot = _fig_with_image() + sent = [] + bridge = FigureBridge(fig, send=lambda k, v: sent.append(k)) + bridge.close() + plot.set_title("after close") + assert sent == [] diff --git a/anyplotlib/tests/test_embed/test_embed_mount.py b/anyplotlib/tests/test_embed/test_embed_mount.py new file mode 100644 index 00000000..f2f19926 --- /dev/null +++ b/anyplotlib/tests/test_embed/test_embed_mount.py @@ -0,0 +1,237 @@ +""" +Playwright tests for the JS `mount()` embedding entry point. + +These build a page that uses ONLY the public embedding contract — import +``figure_esm.js``, call ``mount(el, state, opts)`` — exactly as an Electron +app would. No anywidget shim, no Jupyter, no `_repr_utils` template. +""" +from __future__ import annotations + +import json +import pathlib +import tempfile + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.embed import esm_path, figure_state + +_MOUNT_PAGE = """ + + +
+ +""" + + +@pytest.fixture +def mount_page(_pw_browser): + """Open a figure via the public mount() API; return the live Page.""" + pages, paths = [], [] + + def _open(fig): + html = (_MOUNT_PAGE + .replace("__STATE__", json.dumps(figure_state(fig))) + .replace("__ESM__", json.dumps(esm_path().read_text(encoding="utf-8")))) + with tempfile.NamedTemporaryFile( + suffix=".html", mode="w", encoding="utf-8", delete=False + ) as fh: + fh.write(html) + tmp = pathlib.Path(fh.name) + paths.append(tmp) + page = _pw_browser.new_page() + pages.append(page) + page.goto(tmp.as_uri()) + page.wait_for_function("() => window._aplReady === true", timeout=15_000) + page.evaluate( + "() => new Promise(r => requestAnimationFrame(() => requestAnimationFrame(r)))" + ) + return page + + yield _open + for p in pages: + try: + p.close() + except Exception: + pass + for f in paths: + f.unlink(missing_ok=True) + + +def _fig_with_image(): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + q = np.linspace(0, 10, 32) + plot = ax.imshow(np.random.default_rng(0).random((32, 32)), axes=[q, q]) + return fig, plot + + +def _plot_canvas_ink(page) -> int: + return page.evaluate("""() => { + const c = document.querySelector('#host canvas'); + if (!c) return -1; + const d = c.getContext('2d').getImageData(0, 0, c.width, c.height).data; + let n = 0; + for (let i = 3; i < d.length; i += 4) if (d[i] > 0) n++; + return n; + }""") + + +class TestMountRenders: + def test_canvases_created_with_ink(self, mount_page): + fig, _ = _fig_with_image() + page = mount_page(fig) + n_canvas = page.evaluate("() => document.querySelectorAll('#host canvas').length") + assert n_canvas >= 3, f"expected canvas stack, got {n_canvas}" + assert _plot_canvas_ink(page) > 1000, "image canvas has no rendered pixels" + + def test_multiple_mounts_one_page_mdi_style(self, mount_page): + """Two figures in one page must not interfere (MDI sub-windows).""" + fig, _ = _fig_with_image() + page = mount_page(fig) + # Mount a second, independent figure into a fresh container. + fig2, _ = _fig_with_image() + state2 = json.dumps(figure_state(fig2)) + page.evaluate(f"""() => {{ + const div = document.createElement('div'); + div.id = 'host2'; + document.body.appendChild(div); + const esm = {json.dumps(esm_path().read_text(encoding="utf-8"))}; + const blobUrl = URL.createObjectURL(new Blob([esm], {{type:'text/javascript'}})); + return import(blobUrl).then(mod => {{ + window._handle2 = mod.mount(div, {state2}, {{}}); + }}); + }}""") + page.wait_for_function("() => window._handle2 !== undefined", timeout=15_000) + n1 = page.evaluate("() => document.querySelectorAll('#host canvas').length") + n2 = page.evaluate("() => document.querySelectorAll('#host2 canvas').length") + assert n1 >= 3 and n2 >= 3 + + def test_dispose_clears_dom(self, mount_page): + fig, _ = _fig_with_image() + page = mount_page(fig) + page.evaluate("() => window._handle.dispose()") + n = page.evaluate("() => document.querySelectorAll('#host canvas').length") + assert n == 0 + + +class TestMountLiveUpdates: + def test_set_panel_state_rerenders(self, mount_page): + """setPanelState() with a new title must draw title pixels.""" + fig, plot = _fig_with_image() + page = mount_page(fig) + + def title_ink(): + return page.evaluate("""() => { + const tc = Array.from(document.querySelectorAll('#host canvas')) + .find(c => c.style.zIndex === '8'); + if (!tc) return -1; + const d = tc.getContext('2d').getImageData(0,0,tc.width,tc.height).data; + let n = 0; + for (let i = 3; i < d.length; i += 4) if (d[i] > 0) n++; + return n; + }""") + + assert title_ink() == 0 + new_state = {**plot.to_state_dict(), "title": "Live from JS"} + page.evaluate( + "(args) => window._handle.setPanelState(args[0], args[1])", + [plot._id, new_state], + ) + page.wait_for_timeout(150) + assert title_ink() > 0, "setPanelState() did not re-render the title" + + def test_apply_update_does_not_echo(self, mount_page): + """applyUpdate() (Python → JS path) must not bounce back via onSync.""" + fig, plot = _fig_with_image() + page = mount_page(fig) + new_state = json.dumps({**plot.to_state_dict(), "title": "no echo"}) + page.evaluate( + "(args) => window._handle.applyUpdate('panel_' + args[0] + '_json', args[1])", + [plot._id, new_state], + ) + page.wait_for_timeout(100) + syncs = page.evaluate("() => window._syncs.map(s => s.key)") + assert f"panel_{plot._id}_json" not in syncs + + +class TestMountEvents: + def test_pointer_event_reaches_onevent_and_onsync(self, mount_page): + fig, plot = _fig_with_image() + page = mount_page(fig) + # Click the centre of the image area. + page.mouse.move(200, 150) + page.mouse.down() + page.mouse.up() + page.wait_for_timeout(200) + + events = page.evaluate("() => window._events") + assert any(e.get("event_type") == "pointer_down" for e in events), ( + f"no pointer_down in onEvent stream: {[e.get('event_type') for e in events]}" + ) + assert all(e.get("panel_id") == plot._id + for e in events if "panel_id" in e) + syncs = page.evaluate("() => window._syncs.map(s => s.key)") + assert "event_json" in syncs, "event_json was not flushed through onSync" + + +class TestBridgeRoundTrip: + """End-to-end Level-3 pattern: mount() in a real browser wired to a live + Python FigureBridge, with the test harness acting as the transport + (in an Electron app this would be a WebSocket / IPC pipe).""" + + def test_full_round_trip(self, mount_page): + from anyplotlib.embed import FigureBridge + + fig, plot = _fig_with_image() + clicks = [] + + @plot.add_event_handler("pointer_down") + def on_click(event): + clicks.append((event.xdata, event.ydata)) + + outbound = [] # Python → JS queue + bridge = FigureBridge(fig, send=lambda k, v: outbound.append((k, v))) + page = mount_page(fig) + + # ── JS → Python: click in the browser, pump onSync into the bridge ── + page.mouse.move(200, 150) + page.mouse.down() + page.mouse.up() + page.wait_for_timeout(200) + for s in page.evaluate("() => window._syncs"): + bridge.receive(s["key"], s["value"]) + assert clicks, "browser click did not reach the Python callback" + assert clicks[0][0] is not None, "event lost its data coordinates" + + # ── Python → JS: set_title streams back into rendered pixels ── + outbound.clear() + plot.set_title("From Python") + assert outbound, "Python mutation produced no bridge messages" + for k, v in outbound: + page.evaluate("(a) => window._handle.applyUpdate(a[0], a[1])", [k, v]) + page.wait_for_timeout(150) + title_ink = page.evaluate("""() => { + const tc = Array.from(document.querySelectorAll('#host canvas')) + .find(c => c.style.zIndex === '8'); + if (!tc) return -1; + const d = tc.getContext('2d').getImageData(0,0,tc.width,tc.height).data; + let n = 0; + for (let i = 3; i < d.length; i += 4) if (d[i] > 0) n++; + return n; + }""") + assert title_ink > 0, "Python set_title did not render in the browser" + bridge.close() diff --git a/anyplotlib/tests/test_examples/__init__.py b/anyplotlib/tests/test_examples/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/anyplotlib/tests/test_examples/test_interactive_examples.py b/anyplotlib/tests/test_examples/test_interactive_examples.py new file mode 100644 index 00000000..157ae253 --- /dev/null +++ b/anyplotlib/tests/test_examples/test_interactive_examples.py @@ -0,0 +1,28 @@ +"""Smoke tests: each EM example script must import and execute without error.""" +import importlib.util +import pathlib + +import pytest + +EXAMPLES_DIR = pathlib.Path(__file__).parents[3] / "Examples" / "Interactive" + +SCRIPTS = [ + "plot_particle_picker.py", + "plot_eels_explorer.py", + "plot_threshold_explorer.py", + "plot_spectra_roi_inspector.py", + "plot_voxel_grain_explorer.py", +] + + +def _exec_script(name: str) -> None: + path = EXAMPLES_DIR / name + mod_name = f"_smoke_ex_{path.stem}" + spec = importlib.util.spec_from_file_location(mod_name, path) + mod = importlib.util.module_from_spec(spec) + spec.loader.exec_module(mod) + + +@pytest.mark.parametrize("script", SCRIPTS) +def test_example_executes(script: str) -> None: + _exec_script(script) diff --git a/anyplotlib/tests/test_interactive/__init__.py b/anyplotlib/tests/test_interactive/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/anyplotlib/tests/test_interactive/_event_test_utils.py b/anyplotlib/tests/test_interactive/_event_test_utils.py new file mode 100644 index 00000000..9e87ddd4 --- /dev/null +++ b/anyplotlib/tests/test_interactive/_event_test_utils.py @@ -0,0 +1,35 @@ +"""Shared helpers for event system Playwright tests.""" +from __future__ import annotations + +# Layout constants (match figure_esm.js) +PAD_L, PAD_R, PAD_T, PAD_B = 58, 12, 12, 42 +GRID_PAD = 8 + + +def _collect_events(page) -> None: + """Monkey-patch model.set to accumulate all event_json payloads in window._aplAllEvents.""" + page.evaluate("""() => { + window._aplAllEvents = []; + const orig = window._aplModel.set.bind(window._aplModel); + window._aplModel.set = (k, v) => { + if (k === 'event_json') { + try { window._aplAllEvents.push(JSON.parse(v)); } catch(_) {} + } + return orig(k, v); + }; + }""") + + +def _get_events(page, event_type=None) -> list: + """Return collected events, optionally filtered by event_type.""" + events = page.evaluate("() => window._aplAllEvents") + if event_type: + return [e for e in events if e.get("event_type") == event_type] + return events + + +def _plot_center_page(fig_w: int = 400, fig_h: int = 300) -> tuple[int, int]: + """Return page coords for the center of the plot area.""" + cx = GRID_PAD + PAD_L + (fig_w - PAD_L - PAD_R) // 2 + cy = GRID_PAD + PAD_T + (fig_h - PAD_T - PAD_B) // 2 + return cx, cy diff --git a/anyplotlib/tests/test_interactive/test_blit_audit.py b/anyplotlib/tests/test_interactive/test_blit_audit.py new file mode 100644 index 00000000..5feb745f --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_blit_audit.py @@ -0,0 +1,496 @@ +""" +tests/test_interactive/test_blit_audit.py +========================================== + +Playwright tests that audit canvas redraws and verify blitting behaviour. + +What we are testing +-------------------- +1. **Blit cache correctness** — The ``blitCache`` in ``figure_esm.js`` keyed + on ``(b64, lutKey, w, h)`` must be a genuine cache: adding a marker must + NOT create a new ``OffscreenCanvas`` (GPU texture), while changes to the + LUT parameters (display_min/max, scale_mode, cmap) MUST create a new one. + +2. **No flash on marker add** — Since markers live on a separate + ``markersCanvas`` layer, adding a marker should only clear-and-redraw that + layer. The base ``plotCanvas`` texture must be preserved (blitted, not + rebuilt). + +3. **Draw-call auditing** — Each ``model.set(panel__json, ...)`` call + triggers exactly one ``draw2d`` invocation. We count draw calls via an + injected Proxy on ``window._aplTiming`` that increments + ``window._aplDrawCount[id]`` on every timing assignment. + +Instrumentation strategy +------------------------- +Two counters are injected via ``page.add_init_script()`` before any page JS: + +**OffscreenCanvas counter** — wraps the global class: + + window._aplBitmapRebuildCount = 0 + class _TrackedOffscreen extends OffscreenCanvas { + constructor(w, h) { super(w, h); window._aplBitmapRebuildCount++; } + } + globalThis.OffscreenCanvas = _TrackedOffscreen; + +After the initial render this counter equals 1. Each blit-cache miss bumps +it by 1; a cache hit leaves it unchanged. + +**Draw-call counter** — intercepts ``_aplTiming[id]`` property assignments: + + window._aplTiming = new Proxy({}, { + set(target, key, value) { + window._aplDrawCount[key]++; + ... + } + }); + +``_recordFrame`` in ``figure_esm.js`` sets ``window._aplTiming[id]`` every +draw when ``n >= 2`` (rolling buffer has at least 2 entries). The very first +draw (n=1) is not counted, so ``_aplDrawCount[id] = total_draws - 1``. +Delta tests are used throughout to avoid dependence on this off-by-one. +""" +from __future__ import annotations + +import pathlib +import tempfile + +import numpy as np +import pytest + +import anyplotlib as apl + +# --------------------------------------------------------------------------- +# Init script: injects both counters before page JS runs +# --------------------------------------------------------------------------- + +_INSTRUMENTATION_SCRIPT = """ +(function () { + // ── OffscreenCanvas rebuild counter ─────────────────────────────────────── + window._aplBitmapRebuildCount = 0; + const _OrigOffscreen = globalThis.OffscreenCanvas; + class _TrackedOffscreen extends _OrigOffscreen { + constructor(w, h) { + super(w, h); + window._aplBitmapRebuildCount++; + } + } + globalThis.OffscreenCanvas = _TrackedOffscreen; + + // ── Draw-call counter via _aplTiming Proxy ──────────────────────────────── + // _recordFrame() in figure_esm.js does: + // if (!window._aplTiming) window._aplTiming = {}; // skipped: proxy is truthy + // window._aplTiming[p.id] = { count: n, ... }; // triggers our setter + // This fires on every draw after the rolling buffer reaches n >= 2. + window._aplDrawCount = {}; + window._aplTiming = new Proxy({}, { + set: function(target, key, value) { + if (typeof key === 'string') { + window._aplDrawCount[key] = (window._aplDrawCount[key] || 0) + 1; + } + return Reflect.set(target, key, value); + } + }); +})(); +""" + + +# --------------------------------------------------------------------------- +# blit_page fixture +# --------------------------------------------------------------------------- + +@pytest.fixture +def blit_page(_pw_browser): + """Like ``bench_page`` but injects rebuild + draw-call counters. + + Uses ``page.add_init_script()`` to wrap ``OffscreenCanvas`` and + ``window._aplTiming`` *before* the page's ``render()`` function runs. + + Usage:: + + def test_something(blit_page): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32))) + page = blit_page(fig) + assert _get_rebuild_count(page) == 1 + """ + from anyplotlib.tests.conftest import _build_interact_html + + _pages: list = [] + _paths: list = [] + + def _open(widget): + html = _build_interact_html(widget) + with tempfile.NamedTemporaryFile( + suffix=".html", mode="w", encoding="utf-8", delete=False + ) as fh: + fh.write(html) + tmp = pathlib.Path(fh.name) + _paths.append(tmp) + + page = _pw_browser.new_page() + _pages.append(page) + # Inject counters BEFORE navigation so they wrap globals at startup. + page.add_init_script(_INSTRUMENTATION_SCRIPT) + page.goto(tmp.as_uri()) + page.wait_for_function("() => window._aplReady === true", timeout=30_000) + page.evaluate( + "() => new Promise(r => requestAnimationFrame(() => requestAnimationFrame(r)))" + ) + return page + + yield _open + + for page in _pages: + try: + page.close() + except Exception: + pass + for path in _paths: + path.unlink(missing_ok=True) + + +# --------------------------------------------------------------------------- +# JS helpers +# --------------------------------------------------------------------------- + +def _get_rebuild_count(page) -> int: + """Number of OffscreenCanvas instances created (= bitmap rebuilds).""" + return page.evaluate("() => window._aplBitmapRebuildCount") + + +def _get_draw_count(page, panel_id: str) -> int: + """Monotonic draw-call count for *panel_id* via the _aplTiming proxy. + + Returns 0 after the initial render (n=1, proxy not yet set) and + increments by 1 for each subsequent draw. Delta comparisons are + therefore reliable: draw_after - draw_before == draws_triggered. + """ + return page.evaluate( + "([id]) => (window._aplDrawCount && window._aplDrawCount[id] || 0)", + [panel_id], + ) + + +def _set_panel_state(page, panel_id: str, update: dict) -> None: + """Merge *update* into the panel state and push to the model synchronously.""" + page.evaluate( + """([id, patch]) => { + const key = 'panel_' + id + '_json'; + const st = JSON.parse(window._aplModel.get(key)); + Object.assign(st, patch); + window._aplModel.set(key, JSON.stringify(st)); + }""", + [panel_id, update], + ) + + +def _add_circle_markers(page, panel_id: str, offsets=None) -> None: + """Append a circle marker group to the panel state (no image data change).""" + if offsets is None: + offsets = [[16, 16]] + page.evaluate( + """([id, offsets]) => { + const key = 'panel_' + id + '_json'; + const st = JSON.parse(window._aplModel.get(key)); + const existing = st.markers || []; + existing.push({ + type: 'circles', + offsets: offsets, + sizes: [5], + color: '#ff0000', + }); + st.markers = existing; + window._aplModel.set(key, JSON.stringify(st)); + }""", + [panel_id, offsets], + ) + + +def _wait_raf(page) -> None: + """Wait two rAF ticks so canvas compositing catches up.""" + page.evaluate( + "() => new Promise(r => requestAnimationFrame(() => requestAnimationFrame(r)))" + ) + + +# ══════════════════════════════════════════════════════════════════════════════ +# Blit cache correctness +# ══════════════════════════════════════════════════════════════════════════════ + +class TestBlitCacheCorrectness: + """The blit cache key (b64 string + LUT params) must be honoured.""" + + def _make_page(self, blit_page): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + page = blit_page(fig) + return page, plot + + def test_initial_render_creates_one_bitmap(self, blit_page): + """After initial render exactly one OffscreenCanvas has been created.""" + page, plot = self._make_page(blit_page) + count = _get_rebuild_count(page) + assert count == 1, ( + f"Expected 1 OffscreenCanvas after initial render, got {count}" + ) + + def test_adding_marker_does_not_rebuild_bitmap(self, blit_page): + """Adding a marker uses the cached bitmap — no new OffscreenCanvas. + + This is the core 'no flash' assertion: markers live on a separate + canvas layer, so the base image texture must not be invalidated. + """ + page, plot = self._make_page(blit_page) + count_before = _get_rebuild_count(page) + + _add_circle_markers(page, plot._id) + _wait_raf(page) + + count_after = _get_rebuild_count(page) + assert count_after == count_before, ( + f"Adding a marker must NOT create a new OffscreenCanvas " + f"(before={count_before}, after={count_after})" + ) + + def test_adding_multiple_markers_does_not_rebuild_bitmap(self, blit_page): + """Adding N markers sequentially causes 0 extra OffscreenCanvas creations.""" + page, plot = self._make_page(blit_page) + count_before = _get_rebuild_count(page) + + for i in range(5): + _add_circle_markers(page, plot._id, offsets=[[i * 5, i * 5]]) + _wait_raf(page) + + count_after = _get_rebuild_count(page) + assert count_after == count_before, ( + f"Adding 5 markers must not rebuild the bitmap " + f"(before={count_before}, after={count_after})" + ) + + def test_lut_change_invalidates_cache(self, blit_page): + """Changing display_min (LUT key) creates exactly one new OffscreenCanvas.""" + page, plot = self._make_page(blit_page) + count_before = _get_rebuild_count(page) + + _set_panel_state(page, plot._id, {"display_min": -0.5}) + _wait_raf(page) + + count_after = _get_rebuild_count(page) + assert count_after == count_before + 1, ( + f"Changing display_min must trigger one bitmap rebuild " + f"(before={count_before}, after={count_after})" + ) + + def test_lut_change_then_marker_add_reuses_new_bitmap(self, blit_page): + """After a LUT rebuild, subsequent marker adds still hit the cache.""" + page, plot = self._make_page(blit_page) + + # Invalidate cache with LUT change + _set_panel_state(page, plot._id, {"display_min": -0.5}) + count_after_lut = _get_rebuild_count(page) + + # Marker add must reuse the updated bitmap + _add_circle_markers(page, plot._id) + _wait_raf(page) + + count_after_marker = _get_rebuild_count(page) + assert count_after_marker == count_after_lut, ( + "After LUT rebuild, marker add must still use the cached bitmap. " + f"(after_lut={count_after_lut}, after_marker={count_after_marker})" + ) + + def test_display_max_change_invalidates_cache(self, blit_page): + """Changing display_max also invalidates the blit cache.""" + page, plot = self._make_page(blit_page) + count_before = _get_rebuild_count(page) + + _set_panel_state(page, plot._id, {"display_max": 2.0}) + _wait_raf(page) + + count_after = _get_rebuild_count(page) + assert count_after > count_before, ( + "Changing display_max must trigger a bitmap rebuild" + ) + + def test_pan_does_not_rebuild_bitmap(self, blit_page): + """Changing center_x/y (pan) does not rebuild the bitmap.""" + page, plot = self._make_page(blit_page) + count_before = _get_rebuild_count(page) + + _set_panel_state(page, plot._id, {"center_x": 0.6, "center_y": 0.4}) + _wait_raf(page) + + count_after = _get_rebuild_count(page) + assert count_after == count_before, ( + f"Pan (center_x/y change) must not rebuild the bitmap " + f"(before={count_before}, after={count_after})" + ) + + def test_zoom_does_not_rebuild_bitmap(self, blit_page): + """Changing zoom does not rebuild the bitmap.""" + page, plot = self._make_page(blit_page) + count_before = _get_rebuild_count(page) + + _set_panel_state(page, plot._id, {"zoom": 2.0}) + _wait_raf(page) + + count_after = _get_rebuild_count(page) + assert count_after == count_before, ( + f"Zoom change must not rebuild the bitmap " + f"(before={count_before}, after={count_after})" + ) + + +# ══════════════════════════════════════════════════════════════════════════════ +# Draw-call count audit +# ══════════════════════════════════════════════════════════════════════════════ + +class TestDrawCallAudit: + """Each state mutation must trigger exactly one draw2d call. + + Draw counts use _aplDrawCount which increments on every _aplTiming[id] + assignment (after n≥2 frames). The very first draw (n=1) is not counted, + so deltas are used: draw_after - draw_before == draws_triggered_by_action. + """ + + def _make_page(self, blit_page): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + page = blit_page(fig) + return page, plot + + def test_draw_count_baseline_after_initial_render(self, blit_page): + """After initial render only, draw count = 0 (1 draw occurred, n=1 < threshold).""" + page, plot = self._make_page(blit_page) + count = _get_draw_count(page, plot._id) + assert count == 0, ( + f"After initial render, draw count must be 0 (n=1 not yet counted). " + f"Got {count} — indicates unexpected extra draws during setup." + ) + + def test_marker_add_triggers_exactly_one_draw(self, blit_page): + """Adding a single marker triggers exactly one additional draw2d call.""" + page, plot = self._make_page(blit_page) + draw_before = _get_draw_count(page, plot._id) + + _add_circle_markers(page, plot._id) + + draw_after = _get_draw_count(page, plot._id) + assert draw_after == draw_before + 1, ( + f"Adding a marker must trigger exactly 1 draw " + f"(before={draw_before}, after={draw_after}, delta={draw_after - draw_before})" + ) + + def test_n_marker_adds_trigger_n_draws(self, blit_page): + """Adding N markers sequentially triggers exactly N draw2d calls.""" + page, plot = self._make_page(blit_page) + draw_before = _get_draw_count(page, plot._id) + + n = 5 + for i in range(n): + _add_circle_markers(page, plot._id, offsets=[[i * 4, i * 4]]) + + draw_after = _get_draw_count(page, plot._id) + assert draw_after == draw_before + n, ( + f"Adding {n} markers must trigger exactly {n} draws " + f"(before={draw_before}, after={draw_after}, delta={draw_after - draw_before})" + ) + + def test_lut_change_triggers_exactly_one_draw(self, blit_page): + """A LUT parameter change triggers exactly one draw2d call.""" + page, plot = self._make_page(blit_page) + draw_before = _get_draw_count(page, plot._id) + + _set_panel_state(page, plot._id, {"display_min": -0.5}) + + draw_after = _get_draw_count(page, plot._id) + assert draw_after == draw_before + 1, ( + f"LUT change must trigger exactly 1 draw " + f"(before={draw_before}, after={draw_after})" + ) + + def test_pan_triggers_exactly_one_draw(self, blit_page): + """A Python-side pan update triggers exactly one draw2d call.""" + page, plot = self._make_page(blit_page) + draw_before = _get_draw_count(page, plot._id) + + _set_panel_state(page, plot._id, {"center_x": 0.6}) + + draw_after = _get_draw_count(page, plot._id) + assert draw_after == draw_before + 1, ( + "Python-side pan update must trigger exactly 1 draw " + f"(before={draw_before}, after={draw_after})" + ) + + +# ══════════════════════════════════════════════════════════════════════════════ +# No-flash integration test +# ══════════════════════════════════════════════════════════════════════════════ + +class TestNoFlashOnMarkerAdd: + """End-to-end: adding a marker must not flash (no bitmap rebuild + 1 draw).""" + + def test_no_flash_single_marker(self, blit_page): + """Single marker add: one extra draw, zero extra bitmap rebuilds.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow( + np.random.default_rng(0).standard_normal((64, 64)).astype(np.float32) + ) + page = blit_page(fig) + + rebuild_before = _get_rebuild_count(page) + draw_before = _get_draw_count(page, plot._id) + + _add_circle_markers(page, plot._id, offsets=[[32, 32]]) + _wait_raf(page) + + rebuild_after = _get_rebuild_count(page) + draw_after = _get_draw_count(page, plot._id) + + assert rebuild_after == rebuild_before, ( + "Adding a marker must not rebuild the GPU bitmap (would cause a flash). " + f"OffscreenCanvas count: {rebuild_before} → {rebuild_after}" + ) + assert draw_after == draw_before + 1, ( + f"Expected exactly 1 new draw call, got {draw_after - draw_before}" + ) + + def test_no_flash_multiple_markers_on_real_image(self, blit_page): + """Multiple marker adds on a real image: zero bitmap rebuilds throughout.""" + rng = np.random.default_rng(42) + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(rng.standard_normal((128, 128)).astype(np.float32)) + page = blit_page(fig) + + rebuild_before = _get_rebuild_count(page) + + for i in range(4): + _add_circle_markers( + page, plot._id, + offsets=[[int(rng.integers(10, 118)), int(rng.integers(10, 118))]] + ) + _wait_raf(page) + + rebuild_after = _get_rebuild_count(page) + assert rebuild_after == rebuild_before, ( + "4 sequential marker adds must not rebuild the bitmap. " + f"OffscreenCanvas count: {rebuild_before} → {rebuild_after}" + ) + + def test_flash_does_occur_on_lut_change(self, blit_page): + """Sanity: changing LUT params DOES create a new OffscreenCanvas.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + page = blit_page(fig) + + rebuild_before = _get_rebuild_count(page) + + _set_panel_state(page, plot._id, {"display_min": -1.0, "display_max": 1.0}) + _wait_raf(page) + + rebuild_after = _get_rebuild_count(page) + assert rebuild_after > rebuild_before, ( + "LUT change must create a new OffscreenCanvas (confirms counter works). " + f"OffscreenCanvas count: {rebuild_before} → {rebuild_after}" + ) diff --git a/anyplotlib/tests/test_interactive/test_callbacks_playwright.py b/anyplotlib/tests/test_interactive/test_callbacks_playwright.py new file mode 100644 index 00000000..6920f467 --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_callbacks_playwright.py @@ -0,0 +1,491 @@ +""" +tests/test_interactive/test_callbacks_playwright.py +==================================================== + +Playwright integration tests for the callback system. + +Each test exercises the full JS → Python dispatch pipeline: + 1. ``interact_page(fig)`` opens the standalone HTML in headless Chromium. + 2. ``_collect_events(page)`` intercepts every ``event_json`` write on the + JS model shim so we can verify the browser emitted the right payload. + 3. ``page.mouse.*`` / ``page.keyboard.*`` fires real browser events. + 4. ``_sim(fig, plot, event_type, ...)`` replays the same payload through + ``fig._dispatch_event`` to verify the Python handler receives it. + +Because the standalone HTML has no live Python kernel, steps 3 and 4 are +independent but complementary: step 3 confirms JS sends the event; step 4 +confirms Python processes it. + +Coordinate system (mirrors figure_esm.js constants) +---------------------------------------------------- + PAD_L=58 PAD_R=12 PAD_T=12 PAD_B=42 GRID_PAD=8 + 400×300 figure → plot area page-coords: x≈66, y≈20, w≈330, h≈246 +""" +from __future__ import annotations + +import json + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.tests.test_interactive._event_test_utils import ( + _collect_events, + _get_events, + _plot_center_page, + GRID_PAD, + PAD_L, PAD_R, PAD_T, PAD_B, +) + +FIG_W, FIG_H = 400, 300 + + +# --------------------------------------------------------------------------- +# Shared helpers +# --------------------------------------------------------------------------- + +def _sim(fig, plot, event_type: str, **fields) -> None: + """Drive the Python dispatch path directly (no browser needed).""" + payload = {"source": "js", "panel_id": plot._id, "event_type": event_type} + payload.update(fields) + fig._dispatch_event(json.dumps(payload)) + + +def _make_1d(interact_page): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.plot(np.sin(np.linspace(0, 6.28, 128))) + page = interact_page(fig) + _collect_events(page) + return fig, plot, page + + +def _make_2d(interact_page): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + page = interact_page(fig) + _collect_events(page) + return fig, plot, page + + +def _center(): + return _plot_center_page(FIG_W, FIG_H) + + +def _plot_left_edge(): + """Page x-coordinate of the left edge of the plot area.""" + return GRID_PAD + PAD_L + 5 + + +def _plot_top_edge(): + """Page y-coordinate of the top edge of the plot area.""" + return GRID_PAD + PAD_T + 5 + + +def _outside_plot(): + """Page coords clearly outside the plot area (title bar region).""" + return GRID_PAD + 10, GRID_PAD + 5 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 1. Event types — JS emission verified with Playwright +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestEventTypesJsEmission: + """Verify each event type is emitted by the JS engine on real interactions.""" + + def test_pointer_down_emitted(self, interact_page): + _, _, page = _make_2d(interact_page) + cx, cy = _center() + page.mouse.move(cx, cy) + page.mouse.down() + page.wait_for_timeout(80) + page.mouse.up() + events = _get_events(page, "pointer_down") + assert len(events) >= 1, "pointer_down should be emitted on click" + + def test_pointer_up_emitted(self, interact_page): + # pointer_up fires on significant drag release (not a plain click). + _, _, page = _make_2d(interact_page) + cx, cy = _center() + page.mouse.move(cx, cy) + page.mouse.down() + page.mouse.move(cx + 50, cy, steps=10) + page.mouse.up() + page.wait_for_timeout(100) + events = _get_events(page, "pointer_up") + assert len(events) >= 1, "pointer_up should be emitted after a drag release" + + def test_pointer_move_emitted(self, interact_page): + # pointer_move fires on every mousemove over a 3D panel. + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + x = np.linspace(-1, 1, 8) + X, Y = np.meshgrid(x, x) + plot = ax.plot_surface(X, Y, X ** 2 + Y ** 2) + page = interact_page(fig) + _collect_events(page) + cx, cy = _center() + page.mouse.move(cx, cy) + page.mouse.down() + page.mouse.move(cx + 30, cy, steps=8) + page.mouse.up() + page.wait_for_timeout(50) + events = _get_events(page, "pointer_move") + assert len(events) > 0, "pointer_move events should fire during 3D drag" + + def test_double_click_emitted(self, interact_page): + _, _, page = _make_2d(interact_page) + cx, cy = _center() + page.mouse.dblclick(cx, cy) + page.wait_for_timeout(100) + events = _get_events(page, "double_click") + assert len(events) >= 1, "double_click should be emitted on dblclick" + + def test_wheel_emitted(self, interact_page): + _, _, page = _make_2d(interact_page) + cx, cy = _center() + page.mouse.move(cx, cy) + page.mouse.wheel(0, -100) + page.wait_for_timeout(80) + events = _get_events(page, "wheel") + assert len(events) >= 1, "wheel event should be emitted on scroll" + + def test_key_down_emitted(self, interact_page): + _, _, page = _make_2d(interact_page) + cx, cy = _center() + page.mouse.move(cx, cy) + page.mouse.down() + page.mouse.up() + page.wait_for_timeout(50) + page.keyboard.press("r") + page.wait_for_timeout(80) + events = _get_events(page, "key_down") + assert len(events) >= 1, "key_down should be emitted on key press" + + def test_pointer_enter_emitted(self, interact_page): + _, _, page = _make_2d(interact_page) + ox, oy = _outside_plot() + px = _plot_left_edge() + py = _plot_top_edge() + page.mouse.move(ox, oy) + page.wait_for_timeout(30) + page.mouse.move(px, py, steps=5) + page.wait_for_timeout(80) + events = _get_events(page, "pointer_enter") + assert len(events) >= 1, "pointer_enter should fire when mouse enters plot area" + + def test_pointer_leave_emitted(self, interact_page): + _, _, page = _make_2d(interact_page) + cx, cy = _center() + page.mouse.move(cx, cy) + page.wait_for_timeout(30) + ox, oy = _outside_plot() + page.mouse.move(ox, oy, steps=5) + page.wait_for_timeout(80) + events = _get_events(page, "pointer_leave") + assert len(events) >= 1, "pointer_leave should fire when mouse exits plot area" + + def test_pointer_down_has_required_fields(self, interact_page): + _, _, page = _make_2d(interact_page) + cx, cy = _center() + page.mouse.click(cx, cy) + page.wait_for_timeout(100) + events = _get_events(page, "pointer_down") + assert events, "No pointer_down events collected" + e = events[0] + for field in ("event_type", "x", "y", "button", "buttons", "modifiers"): + assert field in e, f"pointer_down missing field {field!r}" + + def test_pointer_down_has_xdata_ydata(self, interact_page): + _, _, page = _make_2d(interact_page) + cx, cy = _center() + page.mouse.click(cx, cy) + page.wait_for_timeout(100) + events = _get_events(page, "pointer_down") + assert events + e = events[0] + assert "xdata" in e and "ydata" in e, "2D pointer_down should carry xdata/ydata" + + def test_wheel_has_dx_dy_fields(self, interact_page): + _, _, page = _make_2d(interact_page) + cx, cy = _center() + page.mouse.move(cx, cy) + page.mouse.wheel(0, -120) + page.wait_for_timeout(80) + events = _get_events(page, "wheel") + assert events + e = events[0] + assert "dy" in e or "dx" in e, "wheel event should carry dx or dy" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 2. Python dispatch — via _sim + real Python handlers +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPythonDispatch: + """Verify Python callback machinery processes dispatched events correctly.""" + + def test_pointer_down_calls_handler(self, interact_page): + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler(lambda e: received.append(e.event_type), "pointer_down") + _sim(fig, plot, "pointer_down", x=200, y=150, xdata=16.0, ydata=16.0) + assert received == ["pointer_down"] + + def test_pointer_move_calls_handler(self, interact_page): + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler(lambda e: received.append(e.xdata), "pointer_move") + _sim(fig, plot, "pointer_move", x=200, y=150, xdata=8.0, ydata=8.0) + assert received == [8.0] + + def test_double_click_calls_handler(self, interact_page): + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler(lambda e: received.append(True), "double_click") + _sim(fig, plot, "double_click", x=200, y=150) + assert received == [True] + + def test_wheel_calls_handler(self, interact_page): + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler(lambda e: received.append(e.dy), "wheel") + _sim(fig, plot, "wheel", dx=0.0, dy=-1.0) + assert received == [-1.0] + + def test_key_down_calls_handler(self, interact_page): + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler(lambda e: received.append(e.key), "key_down") + _sim(fig, plot, "key_down", key="r") + assert received == ["r"] + + def test_wildcard_handler_receives_all_event_types(self, interact_page): + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler(lambda e: received.append(e.event_type), "*") + for etype in ("pointer_down", "pointer_up", "pointer_move", "wheel"): + _sim(fig, plot, etype, x=100, y=100) + assert received == ["pointer_down", "pointer_up", "pointer_move", "wheel"] + + def test_priority_order_respected(self, interact_page): + fig, plot, page = _make_2d(interact_page) + order = [] + plot.add_event_handler( + lambda e: order.append("low"), "pointer_down", order=1 + ) + plot.add_event_handler( + lambda e: order.append("high"), "pointer_down", order=0 + ) + _sim(fig, plot, "pointer_down", x=100, y=100) + assert order == ["high", "low"] + + def test_stop_propagation_halts_chain(self, interact_page): + fig, plot, page = _make_2d(interact_page) + called = [] + + def first(e): + called.append("first") + e.stop_propagation = True + + plot.add_event_handler(first, "pointer_down", order=0) + plot.add_event_handler(lambda e: called.append("second"), "pointer_down", order=1) + _sim(fig, plot, "pointer_down", x=100, y=100) + assert called == ["first"] + + def test_disconnect_stops_delivery(self, interact_page): + fig, plot, page = _make_2d(interact_page) + received = [] + fn = lambda e: received.append(1) + plot.add_event_handler(fn, "pointer_down") + plot.remove_handler(fn) + _sim(fig, plot, "pointer_down", x=100, y=100) + assert received == [] + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 3. pause_events — JS emission + Python dispatch combined +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPauseEventsPlaywright: + """pause_events drops events in the Python callback layer.""" + + def test_pause_suppresses_pointer_move_handler(self, interact_page): + """JS fires pointer_move; Python handler does not receive it while paused.""" + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler(lambda e: received.append(1), "pointer_move") + + with plot.pause_events("pointer_move"): + cx, cy = _center() + page.mouse.move(cx, cy) + page.mouse.move(cx + 20, cy, steps=5) + page.wait_for_timeout(50) + # JS events are sent to model; Python dispatch is paused + _sim(fig, plot, "pointer_move", x=200, y=150) + _sim(fig, plot, "pointer_move", x=210, y=150) + + assert received == [], ( + "pause_events should prevent handler from firing during the context" + ) + + def test_pause_allows_other_types_through(self, interact_page): + fig, plot, page = _make_2d(interact_page) + move_received = [] + down_received = [] + plot.add_event_handler(lambda e: move_received.append(1), "pointer_move") + plot.add_event_handler(lambda e: down_received.append(1), "pointer_down") + + with plot.pause_events("pointer_move"): + _sim(fig, plot, "pointer_move", x=100, y=100) + _sim(fig, plot, "pointer_down", x=100, y=100) + + assert move_received == [] + assert down_received == [1] + + def test_events_resume_after_pause_context(self, interact_page): + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler(lambda e: received.append(1), "pointer_move") + + with plot.pause_events("pointer_move"): + _sim(fig, plot, "pointer_move", x=100, y=100) + + _sim(fig, plot, "pointer_move", x=110, y=100) + assert received == [1], "Handler should fire after pause context exits" + + def test_js_still_emits_events_during_pause(self, interact_page): + """The browser still emits events during Python pause — only dispatch is suppressed. + + Uses a 3D panel because pointer_move fires on every mousemove there. + """ + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + x = np.linspace(-1, 1, 8) + X, Y = np.meshgrid(x, x) + plot = ax.plot_surface(X, Y, X ** 2 + Y ** 2) + page = interact_page(fig) + _collect_events(page) + + with plot.pause_events("pointer_move"): + cx, cy = _center() + page.mouse.move(cx, cy) + page.mouse.down() + page.mouse.move(cx + 40, cy, steps=8) + page.mouse.up() + page.wait_for_timeout(50) + + js_events = _get_events(page, "pointer_move") + assert len(js_events) > 0, ( + "JS should still emit pointer_move even while Python pause is active" + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 4. hold_events — buffers and flushes on context exit +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestHoldEventsPlaywright: + """hold_events buffers Python callbacks and flushes them on context exit.""" + + def test_hold_buffers_during_context(self, interact_page): + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler( + lambda e: received.append(e.dwell_ms), + "pointer_settled", + ms=50, + delta=2, + ) + + with plot.hold_events("pointer_settled"): + _sim(fig, plot, "pointer_settled", x=200, y=150, dwell_ms=100.0) + _sim(fig, plot, "pointer_settled", x=205, y=150, dwell_ms=110.0) + assert received == [], "Buffered events should not fire inside hold context" + + assert len(received) == 2, "Both buffered events should flush on exit" + + def test_hold_flush_preserves_order(self, interact_page): + fig, plot, page = _make_2d(interact_page) + order = [] + plot.add_event_handler( + lambda e: order.append(e.x), + "pointer_settled", + ms=50, + ) + + with plot.hold_events("pointer_settled"): + for x in (10, 20, 30, 40): + _sim(fig, plot, "pointer_settled", x=x, y=100, dwell_ms=60.0) + + assert order == [10, 20, 30, 40] + + def test_hold_non_held_type_fires_immediately(self, interact_page): + fig, plot, page = _make_2d(interact_page) + move_calls = [] + settled_calls = [] + plot.add_event_handler(lambda e: move_calls.append(1), "pointer_move") + plot.add_event_handler( + lambda e: settled_calls.append(1), "pointer_settled", ms=50 + ) + + with plot.hold_events("pointer_settled"): + _sim(fig, plot, "pointer_move", x=100, y=100) + _sim(fig, plot, "pointer_settled", x=100, y=100, dwell_ms=60.0) + assert move_calls == [1], "pointer_move not held — should fire immediately" + assert settled_calls == [], "pointer_settled should still be buffered" + + assert settled_calls == [1] + + def test_pause_inside_hold_drops_not_buffers(self, interact_page): + """An event that matches both hold and pause: pause wins, event is dropped.""" + fig, plot, page = _make_2d(interact_page) + received = [] + plot.add_event_handler(lambda e: received.append(1), "pointer_move") + + with plot.hold_events("pointer_move"): + with plot.pause_events("pointer_move"): + _sim(fig, plot, "pointer_move", x=100, y=100) + + assert received == [], "pause inside hold should drop the event entirely" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 5. pointer_settled — real dwell detection via Playwright +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPointerSettledPlaywright: + def test_pointer_settled_fires_after_dwell(self, interact_page): + """After the mouse stops moving, pointer_settled is emitted by JS. + + The handler must be registered BEFORE interact_page() so the settled + dwell config (ms/delta) is baked into the serialised state that the + standalone HTML page loads. + """ + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + plot.add_event_handler(lambda e: None, "pointer_settled", ms=100, delta=2) + page = interact_page(fig) + _collect_events(page) + + cx, cy = _center() + page.mouse.move(cx, cy) + page.wait_for_timeout(400) + + events = _get_events(page, "pointer_settled") + assert len(events) >= 1, "pointer_settled should fire after dwell timeout" + + def test_pointer_settled_not_fired_on_rapid_movement(self, interact_page): + """Continuous rapid movement suppresses pointer_settled.""" + fig, plot, page = _make_2d(interact_page) + plot.add_event_handler(lambda e: None, "pointer_settled", ms=300, delta=2) + + cx, cy = _center() + for _ in range(8): + page.mouse.move(cx, cy) + page.mouse.move(cx + 60, cy, steps=4) + page.mouse.move(cx, cy, steps=4) + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_settled") + assert len(events) == 0, ( + "pointer_settled should not fire during continuous rapid movement" + ) diff --git a/anyplotlib/tests/test_interactive/test_callbacks_unit.py b/anyplotlib/tests/test_interactive/test_callbacks_unit.py new file mode 100644 index 00000000..9d53fdf9 --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_callbacks_unit.py @@ -0,0 +1,528 @@ +""" +tests/test_interactive/test_callbacks_unit.py +============================================== + +Pure-Python unit tests for the callback system. No browser required. + +These tests cover: + - ``Event`` dataclass fields and defaults + - ``CallbackRegistry`` connect / disconnect / fire / priority / wildcards + - ``pause_events`` / ``hold_events`` context-manager semantics + - ``_EventMixin`` registration, decoration, and removal API + - Regression: old callback API is gone from all plot types + - ``fig.close()`` fires the ``close`` event on every panel +""" +from __future__ import annotations + +import time + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.callbacks import Event, CallbackRegistry, VALID_EVENT_TYPES, _EventMixin + + +# ── Event dataclass ─────────────────────────────────────────────────────────── + +class TestEvent: + def test_required_fields(self): + e = Event(event_type="pointer_down", source=None) + assert e.event_type == "pointer_down" + assert e.source is None + + def test_time_stamp_auto_set(self): + before = time.perf_counter() + e = Event(event_type="pointer_down") + after = time.perf_counter() + assert before <= e.time_stamp <= after + + def test_modifiers_default_empty_list(self): + e = Event(event_type="pointer_move") + assert e.modifiers == [] + assert isinstance(e.modifiers, list) + + def test_pointer_fields_default_none(self): + e = Event(event_type="pointer_move") + assert e.x is None + assert e.y is None + assert e.button is None + assert e.buttons == 0 + assert e.xdata is None + assert e.ydata is None + assert e.ray is None + assert e.line_id is None + assert e.dwell_ms is None + + def test_wheel_fields_default_none(self): + e = Event(event_type="wheel") + assert e.dx is None + assert e.dy is None + + def test_key_field_default_none(self): + e = Event(event_type="key_down") + assert e.key is None + + def test_bar_fields_default_none(self): + e = Event(event_type="pointer_down") + assert e.bar_index is None + assert e.value is None + assert e.x_label is None + assert e.group_index is None + + def test_stop_propagation_default_false(self): + e = Event(event_type="pointer_down") + assert e.stop_propagation is False + + def test_all_fields_settable(self): + e = Event( + event_type="pointer_down", + source="plot", + modifiers=["ctrl", "shift"], + x=100, y=200, + button=0, buttons=1, + xdata=3.14, ydata=2.71, + line_id="abc12345", + bar_index=2, value=99.5, x_label="Jan", group_index=1, + dx=10.0, dy=-5.0, + key="q", + ) + assert e.modifiers == ["ctrl", "shift"] + assert e.x == 100 + assert e.xdata == 3.14 + assert e.line_id == "abc12345" + assert e.bar_index == 2 + assert e.key == "q" + assert e.dx == 10.0 + assert e.dy == -5.0 + + def test_no_data_dict_attribute(self): + e = Event(event_type="pointer_move") + assert not hasattr(e, "data") + + def test_repr_includes_event_type(self): + e = Event(event_type="pointer_down", x=10, y=20) + assert "pointer_down" in repr(e) + + def test_stop_propagation_not_in_repr(self): + e = Event(event_type="pointer_down", stop_propagation=True) + assert "stop_propagation" not in repr(e) + + +# ── CallbackRegistry ────────────────────────────────────────────────────────── + +class TestCallbackRegistry: + def test_connect_returns_int_cid(self): + reg = CallbackRegistry() + cid = reg.connect("pointer_down", lambda e: None) + assert isinstance(cid, int) + + def test_fire_calls_handler(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_down", lambda e: calls.append(e.event_type)) + reg.fire(Event("pointer_down")) + assert calls == ["pointer_down"] + + def test_fire_only_matching_type(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_down", lambda e: calls.append("down")) + reg.connect("pointer_up", lambda e: calls.append("up")) + reg.fire(Event("pointer_down")) + assert calls == ["down"] + + def test_disconnect_by_cid(self): + reg = CallbackRegistry() + calls = [] + cid = reg.connect("pointer_down", lambda e: calls.append(1)) + reg.disconnect(cid) + reg.fire(Event("pointer_down")) + assert calls == [] + + def test_disconnect_silent_if_not_found(self): + reg = CallbackRegistry() + reg.disconnect(999) # should not raise + + def test_wildcard_receives_all_types(self): + reg = CallbackRegistry() + calls = [] + reg.connect("*", lambda e: calls.append(e.event_type)) + reg.fire(Event("pointer_down")) + reg.fire(Event("key_down")) + reg.fire(Event("wheel")) + assert calls == ["pointer_down", "key_down", "wheel"] + + def test_priority_order(self): + reg = CallbackRegistry() + order = [] + reg.connect("pointer_down", lambda e: order.append("second"), order=1) + reg.connect("pointer_down", lambda e: order.append("first"), order=0) + reg.fire(Event("pointer_down")) + assert order == ["first", "second"] + + def test_same_priority_fires_in_registration_order(self): + reg = CallbackRegistry() + order = [] + reg.connect("pointer_down", lambda e: order.append("a"), order=0) + reg.connect("pointer_down", lambda e: order.append("b"), order=0) + reg.fire(Event("pointer_down")) + assert order == ["a", "b"] + + def test_stop_propagation(self): + reg = CallbackRegistry() + calls = [] + def handler_a(e): + calls.append("a") + e.stop_propagation = True + reg.connect("pointer_down", handler_a, order=0) + reg.connect("pointer_down", lambda e: calls.append("b"), order=1) + reg.fire(Event("pointer_down")) + assert calls == ["a"] + + def test_disconnect_fn_by_reference(self): + reg = CallbackRegistry() + calls = [] + fn = lambda e: calls.append(1) + reg.connect("pointer_down", fn) + reg.disconnect_fn(fn) + reg.fire(Event("pointer_down")) + assert calls == [] + + def test_disconnect_fn_specific_type(self): + reg = CallbackRegistry() + calls = [] + fn = lambda e: calls.append(e.event_type) + reg.connect("pointer_down", fn) + reg.connect("pointer_up", fn) + reg.disconnect_fn(fn, "pointer_down") + reg.fire(Event("pointer_down")) + reg.fire(Event("pointer_up")) + assert calls == ["pointer_up"] + + def test_bool_true_when_handlers_present(self): + reg = CallbackRegistry() + assert not bool(reg) + reg.connect("pointer_down", lambda e: None) + assert bool(reg) + + def test_invalid_event_type_raises(self): + reg = CallbackRegistry() + with pytest.raises(ValueError, match="Invalid event_type"): + reg.connect("on_click", lambda e: None) + + def test_connect_same_fn_multiple_types(self): + reg = CallbackRegistry() + calls = [] + fn = lambda e: calls.append(e.event_type) + reg.connect("pointer_down", fn) + reg.connect("pointer_up", fn) + reg.fire(Event("pointer_down")) + reg.fire(Event("pointer_up")) + assert calls == ["pointer_down", "pointer_up"] + + +# ── pause_events / hold_events ──────────────────────────────────────────────── + +class TestPauseHold: + def test_pause_drops_events(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_move", lambda e: calls.append(1)) + with reg.pause_events("pointer_move"): + reg.fire(Event("pointer_move")) + assert calls == [] + + def test_pause_handlers_intact_after_exit(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_move", lambda e: calls.append(1)) + with reg.pause_events("pointer_move"): + reg.fire(Event("pointer_move")) + reg.fire(Event("pointer_move")) + assert calls == [1] + + def test_pause_all_types_when_no_args(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_down", lambda e: calls.append("down")) + reg.connect("key_down", lambda e: calls.append("key")) + with reg.pause_events(): + reg.fire(Event("pointer_down")) + reg.fire(Event("key_down")) + assert calls == [] + + def test_pause_only_specified_type(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_move", lambda e: calls.append("move")) + reg.connect("pointer_down", lambda e: calls.append("down")) + with reg.pause_events("pointer_move"): + reg.fire(Event("pointer_move")) + reg.fire(Event("pointer_down")) + assert calls == ["down"] + + def test_pause_nested_same_type(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_move", lambda e: calls.append(1)) + with reg.pause_events("pointer_move"): + with reg.pause_events("pointer_move"): + reg.fire(Event("pointer_move")) + reg.fire(Event("pointer_move")) # still paused + reg.fire(Event("pointer_move")) # now fires + assert calls == [1] + + def test_hold_buffers_and_flushes_on_exit(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_settled", lambda e: calls.append(1)) + with reg.hold_events("pointer_settled"): + reg.fire(Event("pointer_settled")) + reg.fire(Event("pointer_settled")) + assert calls == [] + assert calls == [1, 1] + + def test_hold_fires_non_held_types_immediately(self): + reg = CallbackRegistry() + move_calls = [] + settled_calls = [] + reg.connect("pointer_move", lambda e: move_calls.append(1)) + reg.connect("pointer_settled", lambda e: settled_calls.append(1)) + with reg.hold_events("pointer_settled"): + reg.fire(Event("pointer_move")) + reg.fire(Event("pointer_settled")) + assert move_calls == [1] + assert settled_calls == [1] + + def test_hold_events_in_order(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_settled", lambda e: calls.append(e.x)) + with reg.hold_events(): + reg.fire(Event("pointer_settled", x=1)) + reg.fire(Event("pointer_settled", x=2)) + reg.fire(Event("pointer_settled", x=3)) + assert calls == [1, 2, 3] + + def test_pause_wins_over_hold(self): + reg = CallbackRegistry() + calls = [] + reg.connect("pointer_move", lambda e: calls.append(1)) + with reg.hold_events("pointer_move"): + with reg.pause_events("pointer_move"): + reg.fire(Event("pointer_move")) + assert calls == [] + + +# ── _EventMixin ─────────────────────────────────────────────────────────────── + +class _FakePlot(_EventMixin): + def __init__(self): + self.callbacks = CallbackRegistry() + self._settled_config = (0, 0) + + def _configure_pointer_settled(self, ms: int, delta: float) -> None: + self._settled_config = (ms, delta) + + +class TestEventMixin: + def test_functional_form_single_type(self): + plot = _FakePlot() + calls = [] + plot.add_event_handler(lambda e: calls.append(e.event_type), "pointer_down") + plot.callbacks.fire(Event("pointer_down")) + assert calls == ["pointer_down"] + + def test_functional_form_multi_type(self): + plot = _FakePlot() + calls = [] + fn = lambda e: calls.append(e.event_type) + plot.add_event_handler(fn, "pointer_down", "pointer_up") + plot.callbacks.fire(Event("pointer_down")) + plot.callbacks.fire(Event("pointer_up")) + assert calls == ["pointer_down", "pointer_up"] + + def test_decorator_form_single_type(self): + plot = _FakePlot() + calls = [] + @plot.add_event_handler("pointer_move") + def handler(e): + calls.append(e.event_type) + plot.callbacks.fire(Event("pointer_move")) + assert calls == ["pointer_move"] + + def test_decorator_form_multi_type(self): + plot = _FakePlot() + calls = [] + @plot.add_event_handler("pointer_down", "key_down") + def handler(e): + calls.append(e.event_type) + plot.callbacks.fire(Event("pointer_down")) + plot.callbacks.fire(Event("key_down")) + assert calls == ["pointer_down", "key_down"] + + def test_wildcard_decorator(self): + plot = _FakePlot() + calls = [] + @plot.add_event_handler("*") + def handler(e): + calls.append(e.event_type) + plot.callbacks.fire(Event("pointer_down")) + plot.callbacks.fire(Event("wheel")) + assert calls == ["pointer_down", "wheel"] + + def test_remove_handler_by_fn(self): + plot = _FakePlot() + calls = [] + fn = lambda e: calls.append(1) + plot.add_event_handler(fn, "pointer_down") + plot.remove_handler(fn) + plot.callbacks.fire(Event("pointer_down")) + assert calls == [] + + def test_remove_handler_by_fn_specific_type(self): + plot = _FakePlot() + calls = [] + fn = lambda e: calls.append(e.event_type) + plot.add_event_handler(fn, "pointer_down", "pointer_up") + plot.remove_handler(fn, "pointer_down") + plot.callbacks.fire(Event("pointer_down")) + plot.callbacks.fire(Event("pointer_up")) + assert calls == ["pointer_up"] + + def test_remove_handler_by_cid(self): + plot = _FakePlot() + calls = [] + cid = plot.callbacks.connect("pointer_down", lambda e: calls.append(1)) + plot.remove_handler(cid) + plot.callbacks.fire(Event("pointer_down")) + assert calls == [] + + def test_pointer_settled_configures_on_connect(self): + plot = _FakePlot() + plot.add_event_handler(lambda e: None, "pointer_settled", ms=400, delta=5) + assert plot._settled_config == (400, 5) + + def test_pointer_settled_clears_on_last_disconnect(self): + plot = _FakePlot() + fn = lambda e: None + plot.add_event_handler(fn, "pointer_settled", ms=400, delta=5) + plot.remove_handler(fn) + assert plot._settled_config == (0, 0) + + def test_ms_delta_without_settled_raises(self): + plot = _FakePlot() + with pytest.raises(ValueError, match="ms/delta"): + plot.add_event_handler(lambda e: None, "pointer_down", ms=400) + + def test_pause_events_delegates_to_registry(self): + plot = _FakePlot() + calls = [] + plot.add_event_handler(lambda e: calls.append(1), "pointer_move") + with plot.pause_events("pointer_move"): + plot.callbacks.fire(Event("pointer_move")) + assert calls == [] + + def test_hold_events_delegates_to_registry(self): + plot = _FakePlot() + calls = [] + plot.add_event_handler(lambda e: calls.append(1), "pointer_settled") + with plot.hold_events("pointer_settled"): + plot.callbacks.fire(Event("pointer_settled")) + assert calls == [] + assert calls == [1] + + +# ── Regression: old API is gone ────────────────────────────────────────────── + +class TestRegressionOldAPIGone: + def test_plot1d_no_on_click(self): + fig, ax = apl.subplots(1, 1) + plot = ax.plot(np.zeros(10)) + assert not hasattr(plot, "on_click") + + def test_plot1d_no_on_changed(self): + fig, ax = apl.subplots(1, 1) + plot = ax.plot(np.zeros(10)) + assert not hasattr(plot, "on_changed") + + def test_plot1d_no_on_release(self): + fig, ax = apl.subplots(1, 1) + plot = ax.plot(np.zeros(10)) + assert not hasattr(plot, "on_release") + + def test_plot2d_no_on_click(self): + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(np.zeros((32, 32))) + assert not hasattr(plot, "on_click") + + def test_widget_no_on_changed(self): + fig, ax = apl.subplots(1, 1) + plot = ax.plot(np.zeros(10)) + w = plot.add_vline_widget(5.0) + assert not hasattr(w, "on_changed") + + def test_event_no_phys_x(self): + e = Event(event_type="pointer_down", xdata=3.14) + assert not hasattr(e, "phys_x") + assert e.xdata == 3.14 + + def test_plot3d_no_on_click(self): + x = np.linspace(-2, 2, 10) + XX, YY = np.meshgrid(x, x) + fig, ax = apl.subplots(1, 1) + plot = ax.plot_surface(XX, YY, np.zeros_like(XX)) + assert not hasattr(plot, "on_click") + + def test_plotbar_no_on_click(self): + fig, ax = apl.subplots(1, 1) + plot = ax.bar(["A", "B"], [1.0, 2.0]) + assert not hasattr(plot, "on_click") + + def test_line1d_no_on_hover(self): + fig, ax = apl.subplots(1, 1) + plot = ax.plot(np.zeros(10)) + line = plot.add_line(np.zeros(10)) + assert not hasattr(line, "on_hover") + + +# ── fig.close() ────────────────────────────────────────────────────────────── + +class TestFigureClose: + def test_close_in_valid_event_types(self): + assert "close" in VALID_EVENT_TYPES + + def test_figure_close_sets_closed_flag(self): + fig, ax = apl.subplots(1, 1) + ax.plot(np.zeros(10)) + assert not getattr(fig, "_closed", False) + fig.close() + assert fig._closed is True + + def test_figure_close_fires_event_on_plot(self): + fig, ax = apl.subplots(1, 1) + plot = ax.plot(np.zeros(10)) + received = [] + plot.callbacks.connect("close", lambda e: received.append(e.event_type)) + fig.close() + assert received == ["close"] + + def test_figure_close_fires_on_all_panels(self): + fig, (ax1, ax2) = apl.subplots(1, 2) + p1 = ax1.plot(np.zeros(10)) + p2 = ax2.imshow(np.zeros((8, 8))) + counts = [0, 0] + p1.callbacks.connect("close", lambda e: counts.__setitem__(0, counts[0] + 1)) + p2.callbacks.connect("close", lambda e: counts.__setitem__(1, counts[1] + 1)) + fig.close() + assert counts == [1, 1] + + def test_figure_close_is_idempotent(self): + fig, ax = apl.subplots(1, 1) + plot = ax.plot(np.zeros(10)) + received = [] + plot.callbacks.connect("close", lambda e: received.append(e)) + fig.close() + fig.close() + assert len(received) == 1 diff --git a/anyplotlib/tests/test_interactive/test_event_pause_hold.py b/anyplotlib/tests/test_interactive/test_event_pause_hold.py new file mode 100644 index 00000000..2a21e70d --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_event_pause_hold.py @@ -0,0 +1,214 @@ +""" +tests/test_interactive/test_event_pause_hold.py +================================================ + +Tests for ``pause_events`` and ``hold_events`` Python-side context managers. + +``pause_events`` and ``hold_events`` operate on the ``CallbackRegistry`` +after events have been dispatched to Python. The Figure's ``_dispatch_event`` +method is the entry point: it builds an ``Event`` and calls +``plot.callbacks.fire()``. When paused, ``fire()`` drops the event; when +held, ``fire()`` buffers it and flushes on context exit. + +In the standalone Playwright setup there is no real Python kernel — the model +is a JS-only shim. Python handlers are therefore not reachable from the +browser. These tests drive the Python dispatch path directly via +``fig._dispatch_event(json_str)`` to verify pause/hold semantics end-to-end, +with a Playwright test verifying JS actually sends the expected events. +""" +from __future__ import annotations + +import json + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.tests.test_interactive._event_test_utils import ( + _collect_events, + _get_events, + GRID_PAD, +) + +FIG_W, FIG_H = 400, 300 + + +def _sim(fig, plot, event_type: str, **fields) -> None: + """Simulate a JS event by calling fig._dispatch_event directly.""" + payload = {"source": "js", "panel_id": plot._id, "event_type": event_type} + payload.update(fields) + fig._dispatch_event(json.dumps(payload)) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 1. pause_events — Python-side dispatch simulation +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPauseIntegration: + def test_pause_drops_pointer_move(self): + """pause_events suppresses Python handler calls for pointer_move.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((64, 64))) + received = [] + plot.add_event_handler(lambda e: received.append(1), "pointer_move") + + with plot.pause_events("pointer_move"): + _sim(fig, plot, "pointer_move", x=100, y=100) + _sim(fig, plot, "pointer_move", x=110, y=100) + + assert received == [], ( + f"pause_events should drop all pointer_move calls; got {len(received)}" + ) + + def test_events_resume_after_pause_exits(self): + """pointer_move handler fires again after pause_events context exits.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((64, 64))) + received = [] + plot.add_event_handler(lambda e: received.append(1), "pointer_move") + + with plot.pause_events("pointer_move"): + _sim(fig, plot, "pointer_move", x=100, y=100) + + assert received == [], "No events during pause" + + # After context exits, moves fire again + _sim(fig, plot, "pointer_move", x=120, y=100) + assert len(received) == 1, ( + "pointer_move should fire after pause_events context exits" + ) + + def test_pause_only_specified_type(self): + """pause_events('pointer_move') does not suppress pointer_down.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((64, 64))) + move_calls = [] + down_calls = [] + plot.add_event_handler(lambda e: move_calls.append(1), "pointer_move") + plot.add_event_handler(lambda e: down_calls.append(1), "pointer_down") + + with plot.pause_events("pointer_move"): + _sim(fig, plot, "pointer_move", x=100, y=100) + _sim(fig, plot, "pointer_down", x=100, y=100) + + assert move_calls == [], "pointer_move should be paused" + assert len(down_calls) == 1, "pointer_down should not be paused" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 2. hold_events — buffers and flushes on exit +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestHoldIntegration: + def test_hold_buffers_pointer_settled_and_flushes_on_exit(self): + """pointer_settled is buffered during hold and flushed on context exit.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((32, 32))) + received = [] + plot.add_event_handler( + lambda e: received.append(e), + "pointer_settled", + ms=200, + delta=4, + ) + + with plot.hold_events("pointer_settled"): + _sim(fig, plot, "pointer_settled", x=100, y=100, dwell_ms=250.0) + _sim(fig, plot, "pointer_settled", x=101, y=100, dwell_ms=260.0) + assert received == [], "Handler should not be called while holding" + + assert len(received) == 2, ( + f"Both buffered events should flush on context exit; got {len(received)}" + ) + + def test_hold_is_type_specific(self): + """hold_events('pointer_settled') does not delay pointer_move.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((32, 32))) + + move_received = [] + settled_received = [] + plot.add_event_handler( + lambda e: move_received.append(1), "pointer_move" + ) + plot.add_event_handler( + lambda e: settled_received.append(1), + "pointer_settled", + ms=200, + delta=4, + ) + + with plot.hold_events("pointer_settled"): + _sim(fig, plot, "pointer_move", x=100, y=100) + _sim(fig, plot, "pointer_settled", x=100, y=100, dwell_ms=250.0) + + # pointer_move fires immediately + assert len(move_received) == 1, ( + "pointer_move should not be held when only pointer_settled is held" + ) + # pointer_settled is still buffered + assert settled_received == [], ( + "pointer_settled should not have fired yet (still inside hold)" + ) + + # On exit, buffered pointer_settled is flushed + assert len(settled_received) == 1, ( + "pointer_settled should flush on context exit" + ) + + def test_hold_flush_preserves_event_order(self): + """Buffered events are flushed in the order they were received.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((32, 32))) + order = [] + plot.add_event_handler( + lambda e: order.append(e.x), + "pointer_settled", + ms=200, + ) + + with plot.hold_events("pointer_settled"): + for xval in (10, 20, 30): + _sim(fig, plot, "pointer_settled", x=xval, y=100, dwell_ms=210.0) + + assert order == [10, 20, 30], ( + f"Events should flush in order; got {order}" + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 3. Playwright smoke test — JS sends pointer_move during drag on 3D panel +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPlaywrightJSSends: + """Verify JS actually emits pointer_move events that could be paused/held. + + This confirms the JS side of the pipeline is working; the pause/hold + semantics are tested purely in Python (above) since the standalone shim + has no real Python kernel. + """ + + def test_3d_drag_sends_pointer_move_events(self, interact_page): + """A drag on a 3D panel emits multiple pointer_move event_json payloads.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + x = np.linspace(-1, 1, 8) + X, Y = np.meshgrid(x, x) + Z = X ** 2 + Y ** 2 + plot = ax.plot_surface(X, Y, Z) + + page = interact_page(fig) + _collect_events(page) + + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + page.mouse.move(cx, cy) + page.mouse.down() + page.mouse.move(cx + 40, cy, steps=6) + page.mouse.up() + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_move") + assert len(events) > 0, ( + "JS should emit pointer_move events during a 3D drag; " + "these are what pause_events/hold_events would intercept in Python" + ) diff --git a/anyplotlib/tests/test_interactive/test_event_plots.py b/anyplotlib/tests/test_interactive/test_event_plots.py new file mode 100644 index 00000000..b95a0e44 --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_event_plots.py @@ -0,0 +1,295 @@ +""" +tests/test_interactive/test_event_plots.py +========================================== + +Playwright tests verifying that the JS event system correctly emits the new +event types introduced in the event system redesign. + +Coordinate system (mirrors figure_esm.js constants): + PAD_L=58 PAD_R=12 PAD_T=12 PAD_B=42 GRID_PAD=8 + For a 400×300 fig: plot rect = {x:58, y:12, w:330, h:246} + Page coords = canvas coords + 8 +""" +from __future__ import annotations + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.tests.test_interactive._event_test_utils import ( + _collect_events, + _get_events, + _plot_center_page, + GRID_PAD, +) + +FIG_W, FIG_H = 400, 300 + + +# ── fixtures ────────────────────────────────────────────────────────────────── + +def _make_2d_page(interact_page): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((32, 32))) + page = interact_page(fig) + _collect_events(page) + return page, plot + + +def _make_3d_page(interact_page): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + x = np.linspace(-1, 1, 8) + X, Y = np.meshgrid(x, x) + Z = X ** 2 + Y ** 2 + plot = ax.plot_surface(X, Y, Z) + page = interact_page(fig) + _collect_events(page) + return page, plot + + +# ═══════════════════════════════════════════════════════════════════════════════ +# pointer_down — 2D click emits correct fields +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPointerDown: + def test_2d_click_emits_pointer_down_fields(self, interact_page): + """Short click on a 2D panel emits pointer_down with required fields.""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.click(px, py) + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_down") + assert len(events) >= 1, "Expected at least one pointer_down event" + e = events[0] + for field in ("event_type", "x", "y", "button", "buttons", "modifiers", "time_stamp"): + assert field in e, f"pointer_down missing field {field!r}" + assert e["event_type"] == "pointer_down" + assert isinstance(e["modifiers"], list) + + def test_2d_pointer_down_has_xdata_ydata(self, interact_page): + """Plot2D pointer_down includes xdata and ydata fields.""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.click(px, py) + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_down") + assert len(events) >= 1 + e = events[0] + assert "xdata" in e, "2D pointer_down must include xdata" + assert "ydata" in e, "2D pointer_down must include ydata" + assert e["xdata"] is not None + assert e["ydata"] is not None + + def test_ctrl_click_modifiers(self, interact_page): + """Ctrl+click produces modifiers=['ctrl'] on pointer_down.""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.keyboard.down("Control") + page.mouse.click(px, py) + page.keyboard.up("Control") + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_down") + assert len(events) >= 1 + assert "ctrl" in events[0].get("modifiers", []), ( + f"Expected 'ctrl' in modifiers, got {events[0].get('modifiers')!r}" + ) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# pointer_up — fires after mousedown + mousemove + mouseup +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPointerUp: + def test_fires_after_drag(self, interact_page): + """pointer_up fires after a drag sequence.""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + 30, py, steps=5) + page.mouse.up() + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_up") + assert len(events) >= 1, "Expected at least one pointer_up event" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# pointer_move — fires during drag (3D panel emits it on every mousemove) +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPointerMove: + def test_fires_during_drag(self, interact_page): + """pointer_move events fire during a drag on a 3D panel.""" + page, plot = _make_3d_page(interact_page) + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + + page.mouse.move(cx, cy) + page.mouse.down() + page.mouse.move(cx + 40, cy, steps=8) + page.mouse.up() + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_move") + assert len(events) > 0, "Expected pointer_move events during 3D drag" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# pointer_enter / pointer_leave +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPointerEnterLeave: + def test_pointer_enter_fires_on_mouseenter(self, interact_page): + """pointer_enter fires when mouse enters the canvas.""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + # Start outside, move inside + page.mouse.move(0, 0) + page.wait_for_timeout(50) + page.mouse.move(px, py) + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_enter") + assert len(events) >= 1, "Expected pointer_enter event" + e = events[0] + # button should be null when no button is held + assert e.get("button") is None, ( + f"pointer_enter button should be null, got {e.get('button')!r}" + ) + + def test_pointer_leave_fires_on_mouseleave(self, interact_page): + """pointer_leave fires when mouse leaves the canvas.""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.wait_for_timeout(50) + page.mouse.move(0, 0) + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_leave") + assert len(events) >= 1, "Expected pointer_leave event" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# double_click +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestDoubleClick: + def test_fires_on_dblclick(self, interact_page): + """double_click event fires on a browser dblclick.""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.dblclick(px, py) + page.wait_for_timeout(100) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "Expected double_click event" + assert events[0].get("button") == 0 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# wheel +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestWheel: + def test_fires_with_dy_field(self, interact_page): + """wheel event fires with a dy field when scrolling.""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.wait_for_timeout(50) + page.mouse.wheel(0, 120) + page.wait_for_timeout(100) + + events = _get_events(page, "wheel") + assert len(events) >= 1, "Expected wheel event" + e = events[0] + assert "dy" in e, "wheel event must include dy field" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# key_down / key_up +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestKeyEvents: + def test_key_down_fires_on_keypress(self, interact_page): + """key_down fires for any keypress (not just registered shortcuts).""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + # Focus canvas via mouseenter + page.mouse.move(px, py) + page.wait_for_timeout(50) + + page.keyboard.press("q") + page.wait_for_timeout(100) + + events = _get_events(page, "key_down") + assert len(events) >= 1, "Expected key_down event" + e = events[0] + assert e.get("key") == "q", f"Expected key='q', got {e.get('key')!r}" + + def test_key_up_fires_on_key_release(self, interact_page): + """key_up fires when a key is released.""" + page, plot = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.wait_for_timeout(50) + + page.keyboard.down("z") + page.wait_for_timeout(30) + page.keyboard.up("z") + page.wait_for_timeout(100) + + events = _get_events(page, "key_up") + assert len(events) >= 1, "Expected key_up event" + e = events[0] + assert e.get("key") == "z", f"Expected key='z', got {e.get('key')!r}" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Plot3D — pointer_down absent, wheel present +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPlot3DEvents: + def test_3d_pointer_down_no_xdata(self, interact_page): + """3D panels do not emit pointer_down events (no click detection in 3D).""" + page, plot = _make_3d_page(interact_page) + _collect_events(page) + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + + page.mouse.move(cx, cy) + page.mouse.click(cx, cy) + page.wait_for_timeout(300) + + events = _get_events(page, "pointer_down") + assert len(events) == 0, "3D panels should not emit pointer_down events" + + def test_3d_wheel_fires(self, interact_page): + """Plot3D emits a wheel event on scroll.""" + page, plot = _make_3d_page(interact_page) + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + + page.mouse.move(cx, cy) + page.wait_for_timeout(50) + page.mouse.wheel(0, 120) + page.wait_for_timeout(100) + + wheel_events = _get_events(page, "wheel") + assert len(wheel_events) >= 1, "Expected wheel event from 3D panel" + assert "dy" in wheel_events[0] diff --git a/anyplotlib/tests/test_interactive/test_event_settled.py b/anyplotlib/tests/test_interactive/test_event_settled.py new file mode 100644 index 00000000..8655af6f --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_event_settled.py @@ -0,0 +1,185 @@ +""" +tests/test_interactive/test_event_settled.py +============================================ + +Pure-Python unit tests and Playwright integration tests for the +``pointer_settled`` event. + +Pure-Python tests verify that connecting / disconnecting a handler updates +the ``pointer_settled_ms`` / ``pointer_settled_delta`` state fields. + +Playwright tests verify that the JS dwell timer fires after the configured +dwell period and suppresses when the pointer keeps moving. +""" +from __future__ import annotations + +import time + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.tests.test_interactive._event_test_utils import ( + _collect_events, + _get_events, + _plot_center_page, +) + +FIG_W, FIG_H = 400, 300 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Pure-Python: state field updates on connect / disconnect +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestSettledConfig: + def test_default_state_before_any_handler(self): + """pointer_settled_ms starts at 0 and delta at 4 before any handler.""" + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(np.zeros((32, 32))) + assert plot._state["pointer_settled_ms"] == 0 + assert plot._state["pointer_settled_delta"] == 4 + + def test_state_set_on_first_connect(self): + """Connecting a pointer_settled handler sets ms and delta in _state.""" + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(np.zeros((32, 32))) + plot.add_event_handler(lambda e: None, "pointer_settled", ms=400, delta=5) + assert plot._state["pointer_settled_ms"] == 400 + assert plot._state["pointer_settled_delta"] == 5 + + def test_state_cleared_on_last_disconnect(self): + """Removing the last pointer_settled handler resets ms to 0.""" + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(np.zeros((32, 32))) + fn = lambda e: None + plot.add_event_handler(fn, "pointer_settled", ms=400, delta=5) + plot.remove_handler(fn) + assert plot._state["pointer_settled_ms"] == 0 + assert plot._state["pointer_settled_delta"] == 0 + + def test_multiple_handlers_use_last_configured_ms(self): + """Adding a second handler overrides ms/delta with the new values.""" + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(np.zeros((32, 32))) + fn1 = lambda e: None + fn2 = lambda e: None + plot.add_event_handler(fn1, "pointer_settled", ms=300, delta=4) + plot.add_event_handler(fn2, "pointer_settled", ms=500, delta=8) + assert plot._state["pointer_settled_ms"] == 500 + assert plot._state["pointer_settled_delta"] == 8 + + def test_remove_one_handler_keeps_nonzero_ms(self): + """Removing one handler when another remains keeps ms non-zero.""" + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(np.zeros((32, 32))) + fn1 = lambda e: None + fn2 = lambda e: None + plot.add_event_handler(fn1, "pointer_settled", ms=400) + plot.add_event_handler(fn2, "pointer_settled", ms=400) + plot.remove_handler(fn1) + # fn2 is still connected — ms must remain non-zero + assert plot._state["pointer_settled_ms"] > 0 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# Playwright: dwell timer fires / suppresses correctly +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestSettledPlaywright: + def _make_page(self, interact_page, ms: int = 200, delta: int = 4): + """Create a 2D imshow with a pointer_settled handler at ms=200.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((32, 32))) + received = [] + plot.add_event_handler( + lambda e: received.append(e), + "pointer_settled", + ms=ms, + delta=delta, + ) + page = interact_page(fig) + _collect_events(page) + return page, plot, received + + def test_no_timer_when_no_handler(self, interact_page): + """pointer_settled_ms stays 0 in JS when no handler is connected.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((32, 32))) + # No handler — do NOT call add_event_handler + page = interact_page(fig) + + ms_val = page.evaluate( + f"() => JSON.parse(window._aplModel.get('panel_{plot._id}_json')).pointer_settled_ms" + ) + assert ms_val == 0, ( + f"pointer_settled_ms should be 0 when no handler connected, got {ms_val}" + ) + + def test_fires_after_hold(self, interact_page): + """pointer_settled fires after the pointer holds still for >= ms.""" + page, plot, received = self._make_page(interact_page, ms=200) + px, py = _plot_center_page() + + # Confirm JS sees the configured ms + ms_in_js = page.evaluate( + f"() => JSON.parse(window._aplModel.get('panel_{plot._id}_json')).pointer_settled_ms" + ) + assert ms_in_js == 200, f"JS pointer_settled_ms should be 200, got {ms_in_js}" + + # Move into panel and hold still for 350 ms (well past 200 ms threshold) + page.mouse.move(px, py) + page.wait_for_timeout(350) + + events = _get_events(page, "pointer_settled") + assert len(events) >= 1, ( + "pointer_settled should fire after holding still for >= 200 ms" + ) + e = events[0] + assert "dwell_ms" in e, "pointer_settled must include dwell_ms" + assert e["dwell_ms"] >= 200, ( + f"dwell_ms should be >= 200, got {e['dwell_ms']:.1f}" + ) + + def test_does_not_fire_if_moving(self, interact_page): + """pointer_settled does not fire if the pointer keeps moving.""" + page, plot, received = self._make_page(interact_page, ms=300) + px, py = _plot_center_page() + + # Keep moving for 250 ms (less than 300 ms threshold) + page.mouse.move(px, py) + for _ in range(8): + px += 5 + page.mouse.move(px, py) + page.wait_for_timeout(30) + + events = _get_events(page, "pointer_settled") + assert len(events) == 0, ( + "pointer_settled should not fire while the pointer is still moving" + ) + + def test_fires_again_after_re_settle(self, interact_page): + """pointer_settled fires a second time after a second dwell period.""" + page, plot, received = self._make_page(interact_page, ms=200) + px, py = _plot_center_page() + + def _settled_count(): + return "() => window._aplAllEvents.filter(e => e.event_type === 'pointer_settled').length" + + # First dwell — wait for the event rather than sleeping a fixed amount + page.mouse.move(px, py) + page.wait_for_function(f"{_settled_count()} >= 1", timeout=2000) + assert len(_get_events(page, "pointer_settled")) >= 1, ( + "First pointer_settled should have fired" + ) + + # Move away to reset the timer, then hold for a second dwell period + page.mouse.move(px + 30, py + 30) + page.wait_for_timeout(50) # ensure the move is processed before re-entering + page.mouse.move(px, py) + page.wait_for_function(f"{_settled_count()} >= 2", timeout=2000) + + second_count = len(_get_events(page, "pointer_settled")) + assert second_count >= 2, ( + f"Expected at least 2 pointer_settled events, got {second_count}" + ) diff --git a/anyplotlib/tests/test_interactive/test_events_regression.py b/anyplotlib/tests/test_interactive/test_events_regression.py new file mode 100644 index 00000000..fdac7e88 --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_events_regression.py @@ -0,0 +1,1078 @@ +""" +tests/test_interactive/test_events_regression.py +================================================= + +Regression tests for event isolation in figure_esm.js. + +Core invariants verified here +------------------------------ +1. double_click fires on dblclick and is NOT consumed/suppressed by the + pan/drag machinery or the single-click candidate logic. +2. A true drag (significant movement) does NOT emit pointer_down; it emits + pointer_up instead. +3. A short single click emits exactly one pointer_down (no spurious extras). +4. Right-click (button=2) does not trigger the left-click event path. +5. The wheel event fires independently of click/drag state. +6. Event ordering on a double-click: pointer_down ×2 → double_click. +7. A drag followed immediately by a double-click: double_click still fires. + +Coordinate system (mirrors figure_esm.js constants) +---------------------------------------------------- + PAD_L=58 PAD_R=12 PAD_T=12 PAD_B=42 GRID_PAD=8 + For a 400×300 fig: plot area = {x:66, y:20, w:330, h:246} + (page coords = canvas coords + GRID_PAD) +""" +from __future__ import annotations + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.tests.test_interactive._event_test_utils import ( + _collect_events, + _get_events, + _plot_center_page, + GRID_PAD, +) + + +def _clear_events(page) -> None: + """Clear the accumulated event list without re-wrapping the model setter.""" + page.evaluate("() => { window._aplAllEvents = []; }") + +FIG_W, FIG_H = 400, 300 + +# Large enough move to clear the 4 px² drag threshold (>4 px in one direction). +DRAG_DISTANCE = 40 + + +# ── page factories ───────────────────────────────────────────────────────────── + +def _make_2d_page(interact_page): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((32, 32))) + page = interact_page(fig) + _collect_events(page) + return page, plot + + +def _make_1d_page(interact_page): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.plot(np.sin(np.linspace(0, 2 * np.pi, 128))) + page = interact_page(fig) + _collect_events(page) + return page, plot + + +def _make_3d_page(interact_page): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + x = np.linspace(-1, 1, 8) + X, Y = np.meshgrid(x, x) + Z = X ** 2 + Y ** 2 + plot = ax.plot_surface(X, Y, Z) + page = interact_page(fig) + _collect_events(page) + return page, plot + + +# ══════════════════════════════════════════════════════════════════════════════ +# Double-click isolation +# ══════════════════════════════════════════════════════════════════════════════ + +class TestDoubleClickIsolation: + """double_click must fire even when the pan/drag machinery is active.""" + + # ── Click-cascade prerequisites (expose the e.preventDefault() bug) ─────── + # + # Playwright's page.mouse.dblclick() injects dblclick via CDP (clickCount=2), + # bypassing the browser's click → dblclick cascade entirely. To detect the + # real regression we must verify the prerequisite: that `click` fires after + # mousedown + mouseup. Chrome suppresses `click` when mousedown calls + # e.preventDefault(), which breaks every real user double-click. + + def test_click_fires_after_mousedown_2d(self, interact_page): + """click fires after mousedown+mouseup on the 2D canvas (dblclick prerequisite). + + Chrome spec: mousedown.preventDefault() suppresses the subsequent click. + Without click, the browser's dblclick cascade breaks for real users. + This test directly verifies the precondition: no e.preventDefault() in + the 2D pan mousedown must allow click to propagate. + """ + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.evaluate("""() => { + window._aplClickCount = 0; + document.addEventListener('click', () => window._aplClickCount++, true); + }""") + + page.mouse.move(px, py) + page.mouse.down() + page.mouse.up() + page.wait_for_timeout(50) + + click_count = page.evaluate("() => window._aplClickCount") + assert click_count >= 1, ( + "click must fire after mousedown+mouseup on the 2D canvas. " + "e.preventDefault() on mousedown suppresses click → breaks dblclick " + "for real users. Fix: remove preventDefault from the 2D pan mousedown." + ) + + def test_click_fires_after_mousedown_1d(self, interact_page): + """click fires after mousedown+mouseup on the 1D canvas (dblclick prerequisite).""" + page, _ = _make_1d_page(interact_page) + px, py = _plot_center_page() + + page.evaluate("""() => { + window._aplClickCount = 0; + document.addEventListener('click', () => window._aplClickCount++, true); + }""") + + page.mouse.move(px, py) + page.mouse.down() + page.mouse.up() + page.wait_for_timeout(50) + + click_count = page.evaluate("() => window._aplClickCount") + assert click_count >= 1, ( + "click must fire after mousedown+mouseup on the 1D canvas. " + "e.preventDefault() on mousedown suppresses click → breaks dblclick." + ) + + def test_click_fires_after_mousedown_3d(self, interact_page): + """click fires after mousedown+mouseup on the 3D canvas (dblclick prerequisite).""" + page, _ = _make_3d_page(interact_page) + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + + page.evaluate("""() => { + window._aplClickCount = 0; + document.addEventListener('click', () => window._aplClickCount++, true); + }""") + + page.mouse.move(cx, cy) + page.mouse.down() + page.mouse.up() + page.wait_for_timeout(50) + + click_count = page.evaluate("() => window._aplClickCount") + assert click_count >= 1, ( + "click must fire after mousedown+mouseup on the 3D canvas. " + "e.preventDefault() on mousedown suppresses click → breaks dblclick." + ) + + # ── Synthetic dblclick tests (page.mouse.dblclick uses CDP clickCount=2) ── + + def test_dblclick_fires_on_2d_panel(self, interact_page): + """double_click is emitted when the user double-clicks a 2D panel.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "double_click must fire on dblclick" + assert events[0].get("button") == 0, "double_click button should be 0" + + def test_dblclick_fires_on_1d_panel(self, interact_page): + """double_click is emitted when the user double-clicks a 1D panel.""" + page, _ = _make_1d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "double_click must fire on 1D dblclick" + + def test_dblclick_fires_on_3d_panel(self, interact_page): + """double_click is emitted when the user double-clicks a 3D panel.""" + page, _ = _make_3d_page(interact_page) + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + + page.mouse.dblclick(cx, cy) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "double_click must fire on 3D dblclick" + + def test_dblclick_fires_after_preceding_drag(self, interact_page): + """double_click still fires after a preceding drag sequence. + + This guards the regression where the isPanning flag or the drag + document-level listener could prevent subsequent dblclick events. + """ + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + # Perform a drag first + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + DRAG_DISTANCE, py, steps=8) + page.mouse.up() + page.wait_for_timeout(100) + + # Now double-click: double_click must still fire + _clear_events(page) + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, ( + "double_click must fire after a preceding drag — " + "isPanning flag must not suppress dblclick" + ) + + def test_dblclick_has_correct_coordinates(self, interact_page): + """double_click payload carries plausible x/y coordinates.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1 + e = events[0] + # x/y should be within the canvas bounds (0..FIG_W, 0..FIG_H) + assert "x" in e and "y" in e, "double_click must carry x, y fields" + assert 0 <= e["x"] <= FIG_W, f"double_click x={e['x']} out of range" + assert 0 <= e["y"] <= FIG_H, f"double_click y={e['y']} out of range" + + def test_double_click_event_order(self, interact_page): + """On dblclick: pointer_down fires before double_click. + + The expected sequence is: pointer_down(×1-2) then double_click. + We verify that the last event in the sequence is double_click (not + the first), so the double_click is never emitted before its preceding + single-click path has had a chance to run. + """ + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + all_events = _get_events(page) + # At minimum: pointer_down events and double_click + event_types = [e.get("event_type") for e in all_events] + assert "double_click" in event_types, "double_click must be in event sequence" + last_relevant = [t for t in event_types if t in ("pointer_down", "double_click")] + assert last_relevant, "Expected pointer_down and/or double_click events" + assert last_relevant[-1] == "double_click", ( + f"double_click must be the last in the click sequence, got {last_relevant}" + ) + + +# ══════════════════════════════════════════════════════════════════════════════ +# Drag vs click distinction +# ══════════════════════════════════════════════════════════════════════════════ + +class TestDragVsClick: + """Drag and single-click are mutually exclusive event paths on 2D panels.""" + + def test_single_click_emits_pointer_down(self, interact_page): + """A short stationary click emits exactly one pointer_down.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.click(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "pointer_down") + assert len(events) == 1, ( + f"Expected exactly 1 pointer_down on single click, got {len(events)}" + ) + + def test_significant_drag_does_not_emit_pointer_down(self, interact_page): + """A drag with significant motion clears the click candidate → no pointer_down.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.mouse.down() + # Move well past the 4 px threshold + page.mouse.move(px + DRAG_DISTANCE, py, steps=10) + page.mouse.up() + page.wait_for_timeout(150) + + pd_events = _get_events(page, "pointer_down") + assert len(pd_events) == 0, ( + f"Drag must not emit pointer_down (click candidate should be cleared), " + f"got {len(pd_events)} pointer_down events" + ) + + def test_significant_drag_emits_pointer_up(self, interact_page): + """A drag emits pointer_up on release.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + DRAG_DISTANCE, py, steps=10) + page.mouse.up() + page.wait_for_timeout(150) + + pu_events = _get_events(page, "pointer_up") + assert len(pu_events) >= 1, "Drag must emit at least one pointer_up on release" + + def test_drag_then_click_emits_pointer_down(self, interact_page): + """After a drag completes, a subsequent short click fires pointer_down.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + # Drag first + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + DRAG_DISTANCE, py, steps=10) + page.mouse.up() + page.wait_for_timeout(100) + + # Reset event collector + _clear_events(page) + + # Short click + page.mouse.click(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "pointer_down") + assert len(events) == 1, ( + "After a drag, a short click must still emit pointer_down" + ) + + def test_small_movement_still_registers_as_click(self, interact_page): + """Movement within the 2 px click threshold still triggers pointer_down.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.mouse.down() + # Move less than 2 px — within the distance² ≤ 25 threshold + page.mouse.move(px + 1, py + 1, steps=2) + page.mouse.up() + page.wait_for_timeout(150) + + events = _get_events(page, "pointer_down") + assert len(events) == 1, ( + "Tiny movement within click threshold must still produce pointer_down" + ) + + +# ══════════════════════════════════════════════════════════════════════════════ +# Button filtering +# ══════════════════════════════════════════════════════════════════════════════ + +class TestButtonFiltering: + """Non-primary buttons must not trigger the 2D left-click event path.""" + + def test_right_click_does_not_emit_pointer_down(self, interact_page): + """Right-click (button=2) on a 2D panel does not emit pointer_down. + + The mousedown handler returns early for button !== 0, so no + clickCandidate is set and pointer_down must not fire. + """ + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.click(px, py, button="right") + page.wait_for_timeout(150) + + events = _get_events(page, "pointer_down") + assert len(events) == 0, ( + "Right-click must not emit pointer_down (button !== 0 guard)" + ) + + def test_middle_click_does_not_emit_pointer_down(self, interact_page): + """Middle-click (button=1) on a 2D panel does not emit pointer_down.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.click(px, py, button="middle") + page.wait_for_timeout(150) + + events = _get_events(page, "pointer_down") + assert len(events) == 0, ( + "Middle-click must not emit pointer_down (button !== 0 guard)" + ) + + def test_left_click_emits_pointer_down(self, interact_page): + """Sanity-check: left-click still emits pointer_down after button tests.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.click(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "pointer_down") + assert len(events) == 1, "Left-click must emit pointer_down" + + +# ══════════════════════════════════════════════════════════════════════════════ +# Wheel independence +# ══════════════════════════════════════════════════════════════════════════════ + +class TestWheelIndependence: + """Wheel events fire independently of click/drag state.""" + + def test_wheel_after_click_still_fires(self, interact_page): + """wheel event fires correctly after a preceding click.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.click(px, py) + page.wait_for_timeout(50) + + _clear_events(page) + page.mouse.move(px, py) + page.mouse.wheel(0, 120) + page.wait_for_timeout(100) + + events = _get_events(page, "wheel") + assert len(events) >= 1, "wheel must fire after a preceding click" + assert "dy" in events[0], "wheel event must carry dy field" + + def test_wheel_during_drag_does_not_suppress_dblclick(self, interact_page): + """wheel event during an active pan does not block subsequent dblclick.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + # Drag + wheel + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + DRAG_DISTANCE, py, steps=5) + page.mouse.wheel(0, 120) + page.mouse.up() + page.wait_for_timeout(100) + + _clear_events(page) + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "double_click must fire after wheel+drag sequence" + + +# ══════════════════════════════════════════════════════════════════════════════ +# 1D panel event specifics +# ══════════════════════════════════════════════════════════════════════════════ + +class TestPlot1DEvents: + """1D panel event path regression tests.""" + + def test_1d_single_click_emits_pointer_down_when_near_line(self, interact_page): + """Short 1D click near the plotted line emits pointer_down.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + # Flat line at y=0; the plot centre is near the line. + ax.plot(np.zeros(128)) + page = interact_page(fig) + _clear_events(page) + + px, py = _plot_center_page() + page.mouse.click(px, py) + page.wait_for_timeout(150) + + # pointer_down fires when the hit-test finds the line; if not found + # the event is simply not emitted — so we verify count is 0 or 1. + events = _get_events(page, "pointer_down") + # Not asserting exact count because line hit depends on render geometry. + # Key guarantee: no error raised, and no spurious extra pointer_down events. + assert isinstance(events, list) + + def test_1d_drag_does_not_emit_pointer_down(self, interact_page): + """A 1D drag larger than 5 px does not emit pointer_down.""" + page, _ = _make_1d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + 30, py, steps=10) + page.mouse.up() + page.wait_for_timeout(150) + + events = _get_events(page, "pointer_down") + assert len(events) == 0, ( + "1D drag must not emit pointer_down (distance guard)" + ) + + def test_1d_dblclick_fires_double_click(self, interact_page): + """1D panel dblclick emits double_click, not blocked by pan state.""" + page, _ = _make_1d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "1D dblclick must emit double_click" + + def test_1d_pointer_up_fires_on_drag(self, interact_page): + """1D drag emits pointer_up on release.""" + page, _ = _make_1d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + 30, py, steps=10) + page.mouse.up() + page.wait_for_timeout(150) + + events = _get_events(page, "pointer_up") + assert len(events) >= 1, "1D drag must emit pointer_up on release" + + +# ══════════════════════════════════════════════════════════════════════════════ +# 3D panel event specifics +# ══════════════════════════════════════════════════════════════════════════════ + +class TestPlot3DEvents: + """3D panel event regression tests.""" + + def test_3d_dblclick_fires_double_click(self, interact_page): + """3D panel dblclick emits double_click despite drag being active.""" + page, _ = _make_3d_page(interact_page) + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + + page.mouse.dblclick(cx, cy) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "3D dblclick must emit double_click" + + def test_3d_drag_emits_pointer_move(self, interact_page): + """3D drag emits pointer_move events (not blocked by drag state).""" + page, _ = _make_3d_page(interact_page) + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + + page.mouse.move(cx, cy) + page.mouse.down() + page.mouse.move(cx + 40, cy, steps=8) + page.mouse.up() + page.wait_for_timeout(150) + + events = _get_events(page, "pointer_move") + assert len(events) > 0, "3D drag must emit pointer_move events" + + def test_3d_dblclick_fires_after_drag(self, interact_page): + """3D double_click fires after a preceding drag sequence.""" + page, _ = _make_3d_page(interact_page) + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + + # Drag first + page.mouse.move(cx, cy) + page.mouse.down() + page.mouse.move(cx + 40, cy, steps=8) + page.mouse.up() + page.wait_for_timeout(100) + + _clear_events(page) + page.mouse.dblclick(cx, cy) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, ( + "3D double_click must fire after preceding drag" + ) + + def test_3d_wheel_fires_independently(self, interact_page): + """3D wheel event fires even during/after a drag.""" + page, _ = _make_3d_page(interact_page) + cx = FIG_W // 2 + GRID_PAD + cy = FIG_H // 2 + GRID_PAD + + page.mouse.move(cx, cy) + page.mouse.wheel(0, 120) + page.wait_for_timeout(100) + + events = _get_events(page, "wheel") + assert len(events) >= 1, "3D wheel must fire" + assert "dy" in events[0] + + +# ══════════════════════════════════════════════════════════════════════════════ +# Pointer enter / leave +# ══════════════════════════════════════════════════════════════════════════════ + +class TestPointerEnterLeave: + """pointer_enter and pointer_leave must fire independently of click/drag.""" + + def test_pointer_enter_fires_after_drag(self, interact_page): + """pointer_enter fires when entering after a drag on another part of the page.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + # Leave canvas, do a drag outside, then re-enter + page.mouse.move(0, 0) + page.wait_for_timeout(50) + page.mouse.move(px, py) + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_enter") + assert len(events) >= 1, "pointer_enter must fire on canvas entry" + + def test_pointer_leave_fires_after_drag(self, interact_page): + """pointer_leave fires when leaving even if a drag is in progress.""" + page, _ = _make_2d_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.wait_for_timeout(30) + _clear_events(page) + + # Move outside the figure entirely + page.mouse.move(0, 0) + page.wait_for_timeout(100) + + events = _get_events(page, "pointer_leave") + assert len(events) >= 1, "pointer_leave must fire on canvas exit" + + +# ══════════════════════════════════════════════════════════════════════════════ +# HAADF STEM nanoparticle picker regression +# ══════════════════════════════════════════════════════════════════════════════ + +class TestParticlePickerDblClick: + """Regression tests mirroring the HAADF STEM nanoparticle picker example. + + The picker's ``_on_double_click`` handler starts with:: + + if event.xdata is None or event.ydata is None: + return + + So if the JS ``double_click`` event payload does not include ``xdata`` and + ``ydata``, every pick silently fails. These tests reproduce that exact + failure mode. + """ + + def test_dblclick_payload_includes_xdata_ydata(self, interact_page): + """double_click event on a 2D imshow carries non-None xdata and ydata. + + Root cause: the dblclick handler in figure_esm.js was emitting only + canvas-pixel ``x``/``y``, not the image-space ``xdata``/``ydata`` + that Python handlers receive as ``event.xdata``/``event.ydata``. + The particle picker's guard ``if event.xdata is None: return`` meant + every double-click was silently dropped. + """ + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((64, 64))) + page = interact_page(fig) + _collect_events(page) + + px, py = _plot_center_page() + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "double_click must fire on dblclick" + e = events[0] + assert "xdata" in e, "double_click payload must include xdata" + assert "ydata" in e, "double_click payload must include ydata" + assert e["xdata"] is not None, "xdata must not be None" + assert e["ydata"] is not None, "ydata must not be None" + + def test_dblclick_xdata_ydata_are_image_coords(self, interact_page): + """xdata/ydata in double_click are image-space coordinates (0..N range). + + For a 64×64 image, a click at the canvas centre should produce + xdata and ydata near 32 (the image midpoint). + """ + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((64, 64))) + page = interact_page(fig) + _collect_events(page) + + px, py = _plot_center_page() + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1 + e = events[0] + # Image is 64×64; centre click should land roughly in the middle half. + assert 10 <= e["xdata"] <= 54, ( + f"xdata={e['xdata']:.1f} out of expected range for 64×64 image centre click" + ) + assert 10 <= e["ydata"] <= 54, ( + f"ydata={e['ydata']:.1f} out of expected range for 64×64 image centre click" + ) + + def test_dblclick_with_circles_markers_present(self, interact_page): + """double_click still carries xdata/ydata when circles markers are on the plot. + + The particle picker adds candidate circles before any interaction. + This test ensures markers don't interfere with the dblclick coordinate + computation. + """ + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((64, 64))) + # Mirror the particle picker: add candidate circles + candidates = np.array([[16.0, 16.0], [48.0, 48.0], [32.0, 32.0]]) + plot.add_circles(candidates, name="candidates", radius=6, + facecolors="none", edgecolors="#555555") + page = interact_page(fig) + _collect_events(page) + + px, py = _plot_center_page() + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "double_click must fire with circles present" + e = events[0] + assert e.get("xdata") is not None, "xdata must not be None with circles present" + assert e.get("ydata") is not None, "ydata must not be None with circles present" + + def test_dblclick_after_pan_carries_xdata_ydata(self, interact_page): + """After a pan (which shifts the viewport), dblclick still carries xdata/ydata. + + The particle picker is used with zoom/pan interactions before picking. + xdata/ydata must track the panned viewport, not the raw canvas offset. + """ + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((64, 64))) + page = interact_page(fig) + _collect_events(page) + + px, py = _plot_center_page() + + # Pan the viewport + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + 30, py + 20, steps=8) + page.mouse.up() + page.wait_for_timeout(100) + _clear_events(page) + + # Now double-click — xdata/ydata must reflect the panned position + page.mouse.dblclick(px, py) + page.wait_for_timeout(150) + + events = _get_events(page, "double_click") + assert len(events) >= 1, "double_click must fire after a pan" + e = events[0] + assert e.get("xdata") is not None, "xdata must not be None after pan" + assert e.get("ydata") is not None, "ydata must not be None after pan" + + +# ══════════════════════════════════════════════════════════════════════════════ +# HAADF STEM nanoparticle picker — dwell/settle regression +# ══════════════════════════════════════════════════════════════════════════════ + +class TestParticlePickerDwell: + """Regression tests mirroring the particle picker's pointer_settled handler. + + The picker's ``_on_settled`` starts with:: + + if event.xdata is None or event.ydata is None: + return + + So ``pointer_settled`` must include ``xdata``/``ydata`` for the dwell + inspection to work. These tests reproduce that exact failure mode and + guard the fix. + """ + + def _make_picker_page(self, interact_page, ms: int = 200): + """Build a page that mirrors the particle picker setup. + + Uses ms=200 so the test doesn't have to wait the full 300 ms of the + real example. The panel state is serialised into the standalone HTML + so JS sees ``pointer_settled_ms = 200`` without needing a Python kernel. + """ + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((64, 64))) + # Mirrors the picker: add candidate circles + candidates = np.array([[16.0, 16.0], [48.0, 48.0], [32.0, 32.0]]) + plot.add_circles(candidates, name="candidates", radius=6, + facecolors="none", edgecolors="#555555") + # Register a dummy handler so pointer_settled_ms is baked into state + plot.add_event_handler(lambda e: None, "pointer_settled", ms=ms, delta=6) + page = interact_page(fig) + _collect_events(page) + return page, plot + + def test_settled_payload_includes_xdata_ydata(self, interact_page): + """pointer_settled event on a 2D imshow carries non-None xdata and ydata. + + Root cause: the setTimeout callback in figure_esm.js was emitting only + canvas-pixel ``x``/``y``. The particle picker's guard + ``if event.xdata is None: return`` therefore caused every dwell + inspection to be silently skipped. + """ + page, plot = self._make_picker_page(interact_page) + px, py = _plot_center_page() + + # Move into the plot area and hold still — wait for the event + page.mouse.move(px, py) + page.wait_for_function( + "() => window._aplAllEvents.some(e => e.event_type === 'pointer_settled')", + timeout=2000, + ) + + events = _get_events(page, "pointer_settled") + assert len(events) >= 1, "pointer_settled must fire after dwell" + e = events[0] + assert "xdata" in e, "pointer_settled payload must include xdata" + assert "ydata" in e, "pointer_settled payload must include ydata" + assert e["xdata"] is not None, "xdata must not be None" + assert e["ydata"] is not None, "ydata must not be None" + + def test_settled_xdata_ydata_are_image_coords(self, interact_page): + """xdata/ydata in pointer_settled are image-space coordinates (0..N range). + + For a 64×64 image, a dwell at the canvas centre should produce + xdata and ydata near 32. + """ + page, plot = self._make_picker_page(interact_page) + px, py = _plot_center_page() + + page.mouse.move(px, py) + page.wait_for_function( + "() => window._aplAllEvents.some(e => e.event_type === 'pointer_settled')", + timeout=2000, + ) + + events = _get_events(page, "pointer_settled") + assert len(events) >= 1 + e = events[0] + assert 10 <= e["xdata"] <= 54, ( + f"xdata={e['xdata']:.1f} out of expected range for 64×64 image centre dwell" + ) + assert 10 <= e["ydata"] <= 54, ( + f"ydata={e['ydata']:.1f} out of expected range for 64×64 image centre dwell" + ) + + def test_settled_fires_after_configured_ms(self, interact_page): + """pointer_settled fires after the configured dwell period (ms=200). + + Guards the full pipeline: Python sets pointer_settled_ms in state → + state is serialised to HTML → JS reads it and arms the setTimeout → + event fires after the dwell period with dwell_ms >= 200. + """ + page, plot = self._make_picker_page(interact_page, ms=200) + px, py = _plot_center_page() + + # Verify JS received the configured ms value + ms_in_js = page.evaluate( + f"() => JSON.parse(window._aplModel.get('panel_{plot._id}_json')).pointer_settled_ms" + ) + assert ms_in_js == 200, f"JS pointer_settled_ms should be 200, got {ms_in_js}" + + page.mouse.move(px, py) + page.wait_for_function( + "() => window._aplAllEvents.some(e => e.event_type === 'pointer_settled')", + timeout=2000, + ) + + events = _get_events(page, "pointer_settled") + e = events[0] + assert "dwell_ms" in e, "pointer_settled must carry dwell_ms" + assert e["dwell_ms"] >= 200, ( + f"dwell_ms={e['dwell_ms']:.0f} should be >= 200" + ) + assert e.get("xdata") is not None, "xdata must be present after dwell" + assert e.get("ydata") is not None, "ydata must be present after dwell" + + def test_settled_not_fired_while_moving(self, interact_page): + """pointer_settled does not fire while the pointer keeps moving. + + The particle picker should only inspect a candidate when the user + deliberately hovers over it — not during panning. + """ + page, plot = self._make_picker_page(interact_page, ms=200) + px, py = _plot_center_page() + + # Keep moving for ~240 ms (less than 200 ms settle threshold between moves) + page.mouse.move(px, py) + for _ in range(8): + px += 5 + page.mouse.move(px, py) + page.wait_for_timeout(25) + + events = _get_events(page, "pointer_settled") + assert len(events) == 0, ( + "pointer_settled must not fire while pointer is continuously moving" + ) + + def test_settled_fires_after_pan_with_xdata_ydata(self, interact_page): + """After a pan, pointer_settled still carries correct xdata/ydata. + + The particle picker is frequently used after navigating the image. + The settled event must report the panned position, not the original + canvas position. + """ + page, plot = self._make_picker_page(interact_page) + px, py = _plot_center_page() + + # Pan the viewport first + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + 30, py + 20, steps=8) + page.mouse.up() + page.wait_for_timeout(50) + _clear_events(page) + + # Now hold still over the same canvas position + page.mouse.move(px, py) + page.wait_for_function( + "() => window._aplAllEvents.some(e => e.event_type === 'pointer_settled')", + timeout=2000, + ) + + events = _get_events(page, "pointer_settled") + assert len(events) >= 1, "pointer_settled must fire after pan + dwell" + e = events[0] + assert e.get("xdata") is not None, "xdata must not be None after pan" + assert e.get("ydata") is not None, "ydata must not be None after pan" + + +# ══════════════════════════════════════════════════════════════════════════════ +# Marker pixel-centre alignment (_imgToCanvas2d +0.5 fix) +# ══════════════════════════════════════════════════════════════════════════════ + +class TestMarkerPixelCenterAlignment: + """Circle markers must be drawn at (ix+0.5)*scale, not ix*scale. + + Each rendered image pixel i occupies canvas [i*scale, (i+1)*scale). + Its visual centre is at (i+0.5)*scale. Previously _imgToCanvas2d used + ix*scale (the leading/top-left edge), so every marker appeared shifted + 0.5*scale pixels up and to the left — visibly wrong when zoomed in. + + This regression test directly samples the markersCanvas pixel at the + point that lies on the circle ring only when the centre is correct. + """ + + def test_circle_drawn_at_pixel_center(self, interact_page): + """Circle at image pixel (8,8) is rendered at canvas centre (136,136). + + Setup: 16×16 image. 2D panels always reserve PAD_T=12px at the top, + so to get scale=16 we need imgW=imgH=256, which requires: + FIG_W=256, FIG_H=256+12=268 (no axes → no left/bottom gutters) + imgW=256, imgH=268-12=256 → scale=min(256/16,256/16)=16 + + correct centre = (8+0.5)*16 = 136 + old wrong centre = 8*16 = 128 + + A radius-0.5 circle (canvas radius 8) centred at (136,136) has its + ring passing through canvas (144,136). The old wrong circle would + have its ring passing through canvas (136,128) instead. + We sample (144,136) and require non-zero alpha. + """ + PAD_T = 12 + IMG_W = IMG_H = 16 + FIG_W = IMG_W * 16 # 256 — so imgW = FIG_W = 256, scale=16 + FIG_H = IMG_H * 16 + PAD_T # 268 — so imgH = FIG_H - PAD_T = 256, scale=16 + + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(np.zeros((IMG_H, IMG_W))) + # radius=0.5 image-px → 8 canvas-px at scale=16 + plot.add_circles(np.array([[8.0, 8.0]]), radius=0.5) + + page = interact_page(fig) + page.wait_for_timeout(300) + + alpha = page.evaluate("""() => { + const dpr = window.devicePixelRatio || 1; + // markersCanvas: pointer-events:none, z-index:6, visible + const mk = Array.from(document.querySelectorAll('canvas')) + .find(c => c.style.pointerEvents === 'none' && + c.style.zIndex === '6' && + c.style.display !== 'none' && + c.width > 0); + if (!mk) return -1; + const ctx = mk.getContext('2d'); + // If circle centre is at (136,136), the ring (r=8) passes through (144,136). + // Check a 3px neighbourhood to be robust against sub-pixel rendering. + let maxAlpha = 0; + for (let dx = -1; dx <= 1; dx++) { + for (let dy = -1; dy <= 1; dy++) { + const bx = Math.round((144 + dx) * dpr); + const by = Math.round((136 + dy) * dpr); + const d = ctx.getImageData(bx, by, 1, 1).data; + maxAlpha = Math.max(maxAlpha, d[3]); + } + } + return maxAlpha; + }""") + + assert alpha > 0, ( + "Circle ring should appear near canvas (144, 136) when the centre " + "is at (8+0.5)*16=136. alpha=0 means _imgToCanvas2d is still " + "placing the circle at the leading edge (8*16=128) instead of the " + "pixel centre (8.5*16=136)." + ) + + +# ══════════════════════════════════════════════════════════════════════════════ +# Modifier keys in key_down and pointer_down events +# ══════════════════════════════════════════════════════════════════════════════ + +class TestModifierKeys: + """Verify that modifier keys (ctrl, shift, alt) appear in event payloads. + + The JS _modifiers() helper always runs; these tests lock that invariant + so future refactors can't silently drop modifier detection. + """ + + def test_shift_modifier_in_key_down(self, interact_page): + """Shift+a fires key_down with modifiers=['shift'].""" + page, _ = _make_2d_page(interact_page) + cx, cy = _plot_center_page(FIG_W, FIG_H) + page.mouse.move(cx, cy) + _clear_events(page) + page.keyboard.press('Shift+a') + page.wait_for_timeout(80) + key_events = [e for e in _get_events(page, 'key_down') + if e.get('key', '').lower() == 'a'] + assert key_events, "key_down must fire for Shift+a" + assert 'shift' in key_events[-1].get('modifiers', []), ( + "Shift key must appear in modifiers list" + ) + + def test_ctrl_modifier_in_key_down(self, interact_page): + """Ctrl+a fires key_down with modifiers=['ctrl'].""" + page, _ = _make_2d_page(interact_page) + cx, cy = _plot_center_page(FIG_W, FIG_H) + page.mouse.move(cx, cy) + _clear_events(page) + page.keyboard.press('Control+a') + page.wait_for_timeout(80) + key_events = [e for e in _get_events(page, 'key_down') + if e.get('key', '').lower() == 'a'] + assert key_events, "key_down must fire for Ctrl+a" + assert 'ctrl' in key_events[-1].get('modifiers', []), ( + "Ctrl key must appear in modifiers list" + ) + + def test_no_modifier_on_plain_key(self, interact_page): + """Plain key press carries an empty modifiers list.""" + page, _ = _make_2d_page(interact_page) + cx, cy = _plot_center_page(FIG_W, FIG_H) + page.mouse.move(cx, cy) + _clear_events(page) + page.keyboard.press('a') + page.wait_for_timeout(80) + key_events = [e for e in _get_events(page, 'key_down') + if e.get('key', '').lower() == 'a'] + assert key_events, "key_down must fire for plain 'a'" + assert key_events[-1].get('modifiers', None) == [], ( + "Plain key must have empty modifiers list" + ) + + def test_shift_modifier_in_pointer_down(self, interact_page): + """pointer_down with Shift held carries modifiers=['shift'].""" + page, _ = _make_2d_page(interact_page) + cx, cy = _plot_center_page(FIG_W, FIG_H) + _clear_events(page) + page.keyboard.down('Shift') + page.mouse.click(cx, cy) + page.keyboard.up('Shift') + page.wait_for_timeout(80) + ptr_events = _get_events(page, 'pointer_down') + assert ptr_events, "pointer_down must fire on click" + assert 'shift' in ptr_events[-1].get('modifiers', []), ( + "Shift held during click must appear in pointer_down modifiers" + ) diff --git a/anyplotlib/tests/test_interactive/test_title.py b/anyplotlib/tests/test_interactive/test_title.py new file mode 100644 index 00000000..b8d05101 --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_title.py @@ -0,0 +1,157 @@ +""" +Playwright tests verifying 2D title rendering. + +Title rendering +--------------- +2D image panels always reserve a PAD_T (12 px) strip at the top, matching 1D +behaviour. ``set_title(...)`` draws text in that strip via a dedicated +``titleCanvas`` (z-index 8) above the plotCanvas. The title must be visible +(non-zero alpha pixels) regardless of whether physical axes are provided. +""" +from __future__ import annotations + +import numpy as np + +import anyplotlib as apl + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _title_pixel_count(page) -> int: + """Count non-transparent pixels in the titleCanvas (z-index:8).""" + return page.evaluate("""() => { + const tc = Array.from(document.querySelectorAll('canvas')) + .find(c => c.style.zIndex === '8'); + if (!tc) return -1; + const ctx = tc.getContext('2d'); + const d = ctx.getImageData(0, 0, tc.width, tc.height).data; + let n = 0; + for (let i = 3; i < d.length; i += 4) { if (d[i] > 0) n++; } + return n; + }""") + + +def _title_canvas_info(page) -> dict: + """Return display/position/size info about the titleCanvas.""" + return page.evaluate("""() => { + const tc = Array.from(document.querySelectorAll('canvas')) + .find(c => c.style.zIndex === '8'); + if (!tc) return null; + return { + display: tc.style.display, + top: tc.style.top, + left: tc.style.left, + cssWidth: tc.style.width, + cssHeight: tc.style.height, + physW: tc.width, + physH: tc.height, + }; + }""") + + +# ══════════════════════════════════════════════════════════════════════════════ +# 2D title rendering +# ══════════════════════════════════════════════════════════════════════════════ + +class TestTitle2DRendering: + """Title text must appear above the image in the PAD_T strip.""" + + def test_title_canvas_visible_without_axes(self, interact_page): + """titleCanvas is display:block for imshow WITHOUT explicit axes.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + plot.set_title("Plain imshow title") + page = interact_page(fig) + page.wait_for_timeout(200) + + info = _title_canvas_info(page) + assert info is not None, "titleCanvas not found (z-index:8 canvas missing)" + assert info["display"] == "block", ( + f"titleCanvas must be display:block, got {info['display']!r}" + ) + + def test_title_canvas_visible_with_axes(self, interact_page): + """titleCanvas is display:block for imshow WITH explicit axes.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow( + np.zeros((32, 32), dtype=np.float32), + axes=[np.linspace(0, 10, 32)] * 2, + units="nm", + ) + plot.set_title("Physical axes title") + page = interact_page(fig) + page.wait_for_timeout(200) + + info = _title_canvas_info(page) + assert info is not None + assert info["display"] == "block" + + def test_title_text_renders_pixels(self, interact_page): + """set_title() produces non-transparent pixels in the titleCanvas.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + plot.set_title("Hello World") + page = interact_page(fig) + page.wait_for_timeout(200) + + n = _title_pixel_count(page) + assert n > 0, ( + "set_title() must produce visible pixels in titleCanvas. " + f"Got {n} non-zero alpha pixels — title is not rendering." + ) + + def test_empty_title_produces_no_pixels(self, interact_page): + """An empty (unset) title leaves titleCanvas transparent.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + # No set_title call + page = interact_page(fig) + page.wait_for_timeout(200) + + n = _title_pixel_count(page) + assert n == 0, ( + f"Empty title must leave titleCanvas transparent, got {n} pixels" + ) + + def test_title_canvas_in_top_strip(self, interact_page): + """titleCanvas top=0 and height=PAD_T (12 px) — sits above the image.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + plot.set_title("Position check") + page = interact_page(fig) + page.wait_for_timeout(200) + + info = _title_canvas_info(page) + assert info is not None + assert info["top"] == "0px", ( + f"titleCanvas must sit at top:0, got top={info['top']!r}" + ) + assert info["cssHeight"] == "12px", ( + f"titleCanvas height must be PAD_T=12px, got {info['cssHeight']!r}" + ) + + def test_title_above_image_not_overlapping(self, interact_page): + """titleCanvas sits in the 12px gutter above the plotCanvas (no overlap). + + The plotCanvas must start at top ≥ 12px so the title strip is + unobstructed. + """ + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + plot.set_title("No overlap check") + page = interact_page(fig) + page.wait_for_timeout(200) + + plot_canvas_top = page.evaluate("""() => { + // z-index auto = plotCanvas (the image canvas) + const canvases = Array.from(document.querySelectorAll('canvas')); + const pc = canvases.find(c => !c.style.zIndex && c.style.position === 'absolute'); + return pc ? pc.style.top : null; + }""") + + assert plot_canvas_top is not None, "plotCanvas not found" + top_px = int(plot_canvas_top.replace("px", "")) + assert top_px >= 12, ( + f"plotCanvas top must be >= 12px (PAD_T) so title is above image, " + f"got top={top_px}px" + ) diff --git a/anyplotlib/tests/test_interactive/test_touch.py b/anyplotlib/tests/test_interactive/test_touch.py new file mode 100644 index 00000000..58532a36 --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_touch.py @@ -0,0 +1,229 @@ +""" +Touch input tests — the touch-to-mouse bridge in figure_esm.js makes plots +usable on iPad / iPhone: + + * 1-finger drag → pan / orbit / drag a widget / ROI / marker / plane + * 2-finger pinch → zoom (mapped to wheel) + * double-tap → dblclick → the panel's double_click event (picking / + app callbacks); reset-zoom is the ``r`` key, unchanged + +These drive the SAME handlers the mouse uses, via synthesised MouseEvent / +WheelEvent, so a passing mouse interaction implies a passing touch one. The +tests use Playwright's touch emulation (``has_touch=True``) and dispatch raw +TouchEvents (Playwright has no high-level multi-touch drag helper). +""" +from __future__ import annotations + +import json +import pathlib +import tempfile + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.tests.conftest import _build_interact_html + + +# ── touch-enabled page fixture ──────────────────────────────────────────────── + +@pytest.fixture +def touch_page(_pw_browser): + """Open a figure in a touch-enabled context; return the live Page.""" + contexts, paths = [], [] + + def _open(widget): + html = _build_interact_html(widget) + with tempfile.NamedTemporaryFile( + suffix=".html", mode="w", encoding="utf-8", delete=False + ) as fh: + fh.write(html) + tmp = pathlib.Path(fh.name) + paths.append(tmp) + ctx = _pw_browser.new_context(has_touch=True, + viewport={"width": 600, "height": 600}) + contexts.append(ctx) + page = ctx.new_page() + page.goto(tmp.as_uri()) + page.wait_for_function("() => window._aplReady === true", timeout=15_000) + page.evaluate( + "() => new Promise(r => requestAnimationFrame(() => requestAnimationFrame(r)))" + ) + return page + + yield _open + for c in contexts: + try: + c.close() + except Exception: + pass + for p in paths: + p.unlink(missing_ok=True) + + +# ── touch-gesture helpers (raw TouchEvent dispatch) ─────────────────────────── + +_OVERLAY = "[...document.querySelectorAll('canvas')].find(x => x.style.zIndex === '5')" + + +def _overlay_box(page): + return page.evaluate( + f"() => {{ const c = {_OVERLAY}; const r = c.getBoundingClientRect();" + f" return {{ x: r.x, y: r.y, w: r.width, h: r.height }}; }}") + + +def _touch_drag(page, x0, y0, dx, dy, steps=6): + page.evaluate( + f"""([x0,y0,dx,dy,steps]) => {{ + const c = {_OVERLAY}; + const mk = (x,y) => new Touch({{identifier:1, target:c, clientX:x, clientY:y}}); + const tev = (t,x,y) => new TouchEvent(t, {{ + touches: t==='touchend' ? [] : [mk(x,y)], + changedTouches:[mk(x,y)], targetTouches: t==='touchend'?[]:[mk(x,y)], + bubbles:true, cancelable:true }}); + c.dispatchEvent(tev('touchstart', x0, y0)); + for (let i=1;i<=steps;i++) + c.dispatchEvent(tev('touchmove', x0+dx*i/steps, y0+dy*i/steps)); + c.dispatchEvent(tev('touchend', x0+dx, y0+dy)); + }}""", [x0, y0, dx, dy, steps]) + + +def _touch_tap(page, x, y): + page.evaluate( + f"""([x,y]) => {{ + const c = {_OVERLAY}; + const t = new Touch({{identifier:1, target:c, clientX:x, clientY:y}}); + c.dispatchEvent(new TouchEvent('touchstart', {{touches:[t],changedTouches:[t],bubbles:true,cancelable:true}})); + c.dispatchEvent(new TouchEvent('touchend', {{touches:[],changedTouches:[t],bubbles:true,cancelable:true}})); + }}""", [x, y]) + + +def _pinch(page, cx, cy, start_half=20, end_half=130, steps=8): + """Two-finger pinch centred at (cx,cy); spread = zoom in.""" + page.evaluate( + f"""([cx,cy,sh,eh,steps]) => {{ + const c = {_OVERLAY}; + const mk = (id,x,y) => new Touch({{identifier:id, target:c, clientX:x, clientY:y}}); + const tev = (t,ts) => new TouchEvent(t, {{touches:ts, changedTouches:ts, targetTouches:ts, bubbles:true, cancelable:true}}); + c.dispatchEvent(tev('touchstart', [mk(1,cx-sh,cy), mk(2,cx+sh,cy)])); + for (let i=1;i<=steps;i++) {{ + const h = sh + (eh-sh)*i/steps; + c.dispatchEvent(tev('touchmove', [mk(1,cx-h,cy), mk(2,cx+h,cy)])); + }} + c.dispatchEvent(tev('touchend', [])); + }}""", [cx, cy, start_half, end_half, steps]) + + +def _panel_state(page, pid): + return json.loads(page.evaluate(f"() => window._aplModel.get('panel_{pid}_json')")) + + +def _no_errors(page): + errs = [] + page.on("pageerror", lambda e: errs.append(str(e))) + return errs + + +# ── tests ───────────────────────────────────────────────────────────────────── + +class TestTouch2D: + def test_one_finger_drags_crosshair_widget(self, touch_page): + fig, ax = apl.subplots(1, 1, figsize=(400, 400)) + plot = ax.imshow(np.zeros((64, 64), dtype=np.float32)) + cw = plot.add_widget("crosshair", cx=32, cy=32, color="#ff0000") + page = touch_page(fig) + b = _overlay_box(page) + # crosshair at image-centre → overlay-centre (no axis gutters for plain imshow) + cx, cy = b["x"] + b["w"] * 0.5, b["y"] + b["h"] * 0.5 + before = _panel_state(page, plot._id)["overlay_widgets"][0] + _touch_drag(page, cx, cy, -80, -60) + page.wait_for_timeout(150) + after = _panel_state(page, plot._id)["overlay_widgets"][0] + assert abs(after["cx"] - before["cx"]) > 3 or abs(after["cy"] - before["cy"]) > 3, \ + f"crosshair did not move on 1-finger drag: {before} -> {after}" + + def test_pinch_zooms_image(self, touch_page): + fig, ax = apl.subplots(1, 1, figsize=(400, 400)) + plot = ax.imshow(np.zeros((64, 64), dtype=np.float32)) + page = touch_page(fig) + b = _overlay_box(page) + z0 = _panel_state(page, plot._id)["zoom"] + _pinch(page, b["x"] + b["w"] * 0.5, b["y"] + b["h"] * 0.5) + page.wait_for_timeout(150) + z1 = _panel_state(page, plot._id)["zoom"] + assert z1 > z0 + 0.1, f"pinch-out did not zoom in: {z0} -> {z1}" + + def test_no_console_errors(self, touch_page): + fig, ax = apl.subplots(1, 1, figsize=(400, 400)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + plot.add_widget("crosshair", cx=16, cy=16) + page = touch_page(fig) + errs = _no_errors(page) + b = _overlay_box(page) + _touch_drag(page, b["x"] + b["w"]*0.5, b["y"] + b["h"]*0.5, 40, 30) + _pinch(page, b["x"] + b["w"]*0.5, b["y"] + b["h"]*0.5) + page.wait_for_timeout(150) + assert not errs, f"touch interaction raised errors: {errs}" + + +class TestTouch3D: + def test_one_finger_orbits(self, touch_page): + fig, ax = apl.subplots(1, 1, figsize=(360, 360)) + g = np.linspace(-2, 2, 16); X, Y = np.meshgrid(g, g) + v = ax.plot_surface(X, Y, np.sin(np.sqrt(X**2 + Y**2)), azimuth=-60) + page = touch_page(fig) + b = _overlay_box(page) + az0 = _panel_state(page, v._id)["azimuth"] + _touch_drag(page, b["x"] + b["w"]*0.5, b["y"] + b["h"]*0.5, 90, 0) + page.wait_for_timeout(150) + az1 = _panel_state(page, v._id)["azimuth"] + assert abs(az1 - az0) > 5, f"3-D did not orbit on 1-finger drag: {az0} -> {az1}" + + def test_pinch_zooms(self, touch_page): + fig, ax = apl.subplots(1, 1, figsize=(360, 360)) + g = np.linspace(-2, 2, 16); X, Y = np.meshgrid(g, g) + v = ax.plot_surface(X, Y, np.sin(np.sqrt(X**2 + Y**2))) + page = touch_page(fig) + b = _overlay_box(page) + z0 = _panel_state(page, v._id)["zoom"] + _pinch(page, b["x"] + b["w"]*0.5, b["y"] + b["h"]*0.5) + page.wait_for_timeout(150) + z1 = _panel_state(page, v._id)["zoom"] + assert z1 != z0, f"3-D pinch did not change zoom: {z0} -> {z1}" + + +class TestTouch1D: + def test_one_finger_drags_vline(self, touch_page): + fig, ax = apl.subplots(1, 1, figsize=(400, 260)) + p = ax.plot(np.sin(np.linspace(0, 6, 100))) + p.add_vline_widget(50.0) + page = touch_page(fig) + b = _overlay_box(page) + # vline x=50/99 maps into the data area [PAD_L, w-PAD_R] = [58, w-12] + PAD_L, PAD_R = 58, 12 + line_x = b["x"] + PAD_L + (50/99.0) * (b["w"] - PAD_L - PAD_R) + cy = b["y"] + b["h"] * 0.5 + before = _panel_state(page, p._id)["overlay_widgets"][0]["x"] + _touch_drag(page, line_x, cy, 80, 0) + page.wait_for_timeout(150) + after = _panel_state(page, p._id)["overlay_widgets"][0]["x"] + assert abs(after - before) > 2, f"vline did not move on touch drag: {before} -> {after}" + + +class TestTouchDoubleTap: + def test_double_tap_fires_double_click(self, touch_page): + """A quick second tap near the first synthesises a dblclick, which the + 2-D handler turns into a ``double_click`` event (for picking / app + callbacks) — exactly as a mouse double-click does.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 400)) + plot = ax.imshow(np.zeros((64, 64), dtype=np.float32)) + page = touch_page(fig) + b = _overlay_box(page) + cx, cy = b["x"] + b["w"]*0.5, b["y"] + b["h"]*0.5 + _touch_tap(page, cx, cy) + _touch_tap(page, cx, cy) # second tap within 300ms → dblclick + page.wait_for_timeout(120) + ev = json.loads(page.evaluate("() => window._aplModel.get('event_json')")) + assert ev.get("event_type") == "double_click", \ + f"double-tap did not fire double_click: {ev.get('event_type')}" + assert ev.get("panel_id") == plot._id diff --git a/anyplotlib/tests/test_interactive/test_widgets.py b/anyplotlib/tests/test_interactive/test_widgets.py new file mode 100644 index 00000000..17c6dd23 --- /dev/null +++ b/anyplotlib/tests/test_interactive/test_widgets.py @@ -0,0 +1,1380 @@ +""" +tests/test_interactive/test_widgets.py +======================================= + +Tests for the Widget class system and the event_json dispatch pipeline. + +Covers: + * Widget creation, attribute access, set(), to_dict(), __setattr__ + * add_event_handler / remove_handler (new _EventMixin API) + * _update_from_js — always fires for pointer_up/pointer_down + * Widget visibility — hide() / show() + * Plot2D / Plot1D widget integration (add / remove / list / clear) + * Figure event_json dispatch (JS→Python path via _simulate_js_event) + * End-to-end FFT example with simulated JS drag + * Interactive fitting scenario (PointWidget + RangeWidget + line.on_click) + +Callback infrastructure (Event, CallbackRegistry, plot-level callbacks, +Figure routing) is tested in ``test_callbacks.py``. +""" + +from __future__ import annotations + +import json +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.callbacks import Event +from anyplotlib.widgets import ( + Widget, RectangleWidget, CircleWidget, AnnularWidget, + CrosshairWidget, PolygonWidget, LabelWidget, + VLineWidget, HLineWidget, RangeWidget, +) + + +# ───────────────────────────────────────────────────────────────────────────── +# Helper: simulate JS sending an interaction event +# ───────────────────────────────────────────────────────────────────────────── + +def _simulate_js_event(fig, plot, event_type: str, *, widget_id=None, **fields): + """Simulate what JS does when the user interacts with a widget. + + JS writes to event_json: + { source:"js", panel_id, event_type, widget_id?, ...fields } + """ + payload = {"source": "js", "panel_id": plot._id, "event_type": event_type} + if widget_id is not None: + payload["widget_id"] = widget_id if isinstance(widget_id, str) else widget_id._id + payload.update(fields) + fig._on_event({"new": json.dumps(payload)}) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 1. Widget class unit tests +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestWidgetBase: + def test_rectangle_attributes(self): + w = RectangleWidget(lambda: None, x=10, y=20, w=30, h=40) + assert w.x == 10.0 and w.y == 20.0 and w.w == 30.0 and w.h == 40.0 + assert w._type == "rectangle" + + def test_circle_attributes(self): + w = CircleWidget(lambda: None, cx=5, cy=6, r=7) + assert w.cx == 5.0 and w.r == 7.0 + + def test_annular_validates(self): + with pytest.raises(ValueError, match="r_inner"): + AnnularWidget(lambda: None, cx=0, cy=0, r_outer=5, r_inner=10) + + def test_polygon_validates(self): + with pytest.raises(ValueError, match="3 vertices"): + PolygonWidget(lambda: None, vertices=[[0, 0], [1, 1]]) + + def test_set_updates_and_pushes(self): + pushed = [] + w = RectangleWidget(lambda: pushed.append(1), x=0, y=0, w=10, h=10) + w.set(x=50) + assert w.x == 50.0 + assert len(pushed) == 1 + + def test_set_no_push(self): + pushed = [] + w = RectangleWidget(lambda: pushed.append(1), x=0, y=0, w=10, h=10) + w.set(_push=False, x=50) + assert w.x == 50.0 + assert len(pushed) == 0 + + def test_to_dict(self): + w = CircleWidget(lambda: None, cx=1, cy=2, r=3) + d = w.to_dict() + assert d["cx"] == 1.0 and d["type"] == "circle" and "id" in d + + def test_get(self): + w = RectangleWidget(lambda: None, x=10, y=20, w=30, h=40) + assert w.get("x") == 10.0 + assert w.get("missing", 99) == 99 + + def test_unknown_attr_raises(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=1, h=1) + with pytest.raises(AttributeError, match="no_such"): + _ = w.no_such + + def test_repr(self): + w = RectangleWidget(lambda: None, x=1, y=2, w=3, h=4) + assert "RectangleWidget" in repr(w) and "1" in repr(w) + + def test_id_property(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=1, h=1) + assert isinstance(w.id, str) and len(w.id) == 8 + + def test_setattr_routes_through_set(self): + """Public attribute assignment should call set() and push.""" + pushed = [] + w = RectangleWidget(lambda: pushed.append(1), x=0, y=0, w=10, h=10) + w.x = 40.0 + assert w.x == pytest.approx(40.0) + assert len(pushed) == 1 # set() triggered the push + + def test_setattr_private_bypasses_set(self): + """Private attributes must not go through set().""" + pushed = [] + w = RectangleWidget(lambda: pushed.append(1), x=0, y=0, w=10, h=10) + pushed.clear() + w._custom = "private" + assert len(pushed) == 0 + + def test_setattr_callbacks_bypasses_set(self): + """'callbacks' attribute assignment must never go through set().""" + from anyplotlib.callbacks import CallbackRegistry + pushed = [] + w = RectangleWidget(lambda: pushed.append(1), x=0, y=0, w=10, h=10) + pushed.clear() + w.callbacks = CallbackRegistry() # must not crash or push + assert len(pushed) == 0 + + +class TestWidgetCallbacks: + def test_on_changed_fires(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + results = [] + w.add_event_handler(lambda event: results.append(w.x), "pointer_move") + w.set(x=42) + assert results == [42.0] + + def test_on_changed_event_source_is_widget(self): + w = CircleWidget(lambda: None, cx=0, cy=0, r=5) + received = [] + w.add_event_handler(lambda event: received.append(event.source), "pointer_move") + w.set(cx=10) + assert received[0] is w + + def test_multiple_callbacks(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + a, b = [], [] + w.add_event_handler(lambda event: a.append(1), "pointer_move") + w.add_event_handler(lambda event: b.append(1), "pointer_move") + w.set(x=1) + assert len(a) == 1 and len(b) == 1 + + def test_disconnect_by_fn(self): + """Disconnecting using the function object should work.""" + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + results = [] + fn = lambda event: results.append(1) + w.add_event_handler(fn, "pointer_move") + w.set(x=1); assert len(results) == 1 + w.remove_handler(fn) + w.set(x=2); assert len(results) == 1 + + def test_disconnect_by_cid(self): + """Disconnecting using remove_handler with a callable should work.""" + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + results = [] + fn = lambda event: results.append(1) + w.add_event_handler(fn, "pointer_move") + w.remove_handler(fn) + w.set(x=2) + assert results == [] + + def test_disconnect_nonexistent_silent(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + w.remove_handler(9999) + + def test_on_release_decorator(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + results = [] + w.add_event_handler(lambda event: results.append(event.event_type), "pointer_up") + w.callbacks.fire(Event("pointer_up", w)) + assert results == ["pointer_up"] + + def test_on_click_decorator(self): + w = CircleWidget(lambda: None, cx=0, cy=0, r=5) + results = [] + w.add_event_handler(lambda event: results.append(event.event_type), "pointer_down") + w.callbacks.fire(Event("pointer_down", w)) + assert results == ["pointer_down"] + + +class TestWidgetUpdateFromJs: + def test_update_returns_true_on_change(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + assert w._update_from_js({"x": 5.0}) + + def test_update_returns_false_on_no_change(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10, color="#00e5ff") + assert not w._update_from_js( + {"id": w.id, "type": "rectangle", + "x": 0.0, "y": 0.0, "w": 10.0, "h": 10.0, "color": "#00e5ff"}) + + def test_update_fires_on_changed_when_changed(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + results = [] + w.add_event_handler(lambda event: results.append(event.x), "pointer_move") + w._update_from_js({"x": 99.0}) + assert results == [99.0] + + def test_update_does_not_fire_on_changed_if_unchanged(self): + w = RectangleWidget(lambda: None, x=5, y=5, w=10, h=10, color="#abc") + results = [] + w.add_event_handler(lambda event: results.append(1), "pointer_move") + w._update_from_js({"x": 5.0, "y": 5.0, "w": 10.0, "h": 10.0, "color": "#abc"}) + assert results == [] + + def test_update_always_fires_on_release(self): + """pointer_up fires even when nothing changed (drag ended in place).""" + w = RectangleWidget(lambda: None, x=5, y=5, w=10, h=10) + results = [] + w.add_event_handler(lambda event: results.append(1), "pointer_up") + w._update_from_js({"x": 5.0, "y": 5.0, "w": 10.0, "h": 10.0}, + event_type="pointer_up") + assert results == [1] + + def test_update_always_fires_on_click(self): + """pointer_down fires even when nothing changed.""" + w = CrosshairWidget(lambda: None, cx=16.0, cy=16.0) + results = [] + w.add_event_handler(lambda event: results.append(1), "pointer_down") + w._update_from_js({"cx": 16.0, "cy": 16.0}, event_type="pointer_down") + assert results == [1] + + def test_id_and_type_ignored(self): + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + old_id = w.id + w._update_from_js({"id": "FAKE", "type": "FAKE", "x": 1.0}) + assert w.id == old_id and w._type == "rectangle" + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 2. Plot2D widget integration +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPlot2DWidgets: + def test_add_widget_returns_widget_object(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=10, y=10, w=20, h=20) + assert isinstance(w, RectangleWidget) and w.x == 10.0 + + def test_add_circle(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("circle", cx=16, cy=16, r=5) + assert isinstance(w, CircleWidget) and w.cx == 16.0 + + def test_add_crosshair(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + assert isinstance(v.add_widget("crosshair"), CrosshairWidget) + + def test_add_annular(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + assert isinstance(v.add_widget("annular", r_outer=10, r_inner=5), AnnularWidget) + + def test_add_polygon(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("polygon", vertices=[[0,0],[10,0],[10,10],[0,10]]) + assert isinstance(w, PolygonWidget) + + def test_add_label(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("label", x=5, y=5, text="hello") + assert isinstance(w, LabelWidget) and w.text == "hello" + + def test_invalid_kind_raises(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + with pytest.raises(ValueError): + v.add_widget("nonexistent") + + def test_get_widget_by_id(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=1, y=2, w=3, h=4) + assert v.get_widget(w.id) is w + + def test_get_widget_by_object(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("circle") + assert v.get_widget(w) is w + + def test_remove_widget(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle") + v.remove_widget(w) + assert len(v.list_widgets()) == 0 + + def test_list_widgets(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + v.add_widget("circle"); v.add_widget("rectangle") + assert len(v.list_widgets()) == 2 + + def test_clear_widgets(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + v.add_widget("circle"); v.add_widget("rectangle") + v.clear_widgets() + assert v.list_widgets() == [] + + def test_widget_set_updates_state_dict(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=0, y=0, w=10, h=10) + w.set(x=99) + found = [wd for wd in v.to_state_dict()["overlay_widgets"] if wd["id"] == w.id] + assert found[0]["x"] == 99.0 + + def test_to_state_dict_includes_widgets(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + v.add_widget("circle", cx=1, cy=2, r=3) + d = v.to_state_dict() + assert len(d["overlay_widgets"]) == 1 + assert d["overlay_widgets"][0]["cx"] == 1.0 + + def test_setattr_moves_widget(self): + """widget.x = 40 triggers push and updates _data.""" + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=0.0, y=0.0, w=10.0, h=10.0) + w.x = 40.0 + assert w.x == pytest.approx(40.0) + d = v.to_state_dict()["overlay_widgets"] + assert d[0]["x"] == pytest.approx(40.0) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 3. Plot1D widget integration +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestPlot1DWidgets: + def test_add_vline_returns_widget(self): + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + w = v.add_vline_widget(x=10.0) + assert isinstance(w, VLineWidget) and w.x == 10.0 + + def test_add_hline_returns_widget(self): + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + w = v.add_hline_widget(y=0.5) + assert isinstance(w, HLineWidget) and w.y == 0.5 + + def test_add_range_returns_widget(self): + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + w = v.add_range_widget(x0=10, x1=20) + assert isinstance(w, RangeWidget) and w.x0 == 10.0 + + def test_remove_widget(self): + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + w = v.add_vline_widget(x=5) + v.remove_widget(w) + assert len(v.list_widgets()) == 0 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 4. Figure event_json dispatch (the JS→Python path) +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestEventJsonDispatch: + """Simulate what JS does: write event_json with source:"js". + Verify that Widget callbacks fire correctly.""" + + def test_rectangle_drag_fires_on_changed(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=10, y=10, w=20, h=20) + results = [] + w.add_event_handler(lambda event: results.append((event.x, event.y)), "pointer_move") + + _simulate_js_event(fig, v, "pointer_move", widget_id=w, x=50.0, y=60.0) + + assert len(results) == 1 + assert results[0] == (50.0, 60.0) + assert w.x == 50.0 and w.y == 60.0 + + def test_no_change_no_on_changed_callback(self): + """pointer_move must NOT fire when nothing actually changed.""" + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=10, y=10, w=20, h=20) + results = [] + w.add_event_handler(lambda event: results.append(1), "pointer_move") + + _simulate_js_event(fig, v, "pointer_move", widget_id=w, + x=10.0, y=10.0, w=20.0, h=20.0) + assert results == [] + + def test_on_release_always_fires(self): + """pointer_up fires even when position didn't change.""" + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=10, y=10, w=20, h=20) + results = [] + w.add_event_handler(lambda event: results.append(1), "pointer_up") + + _simulate_js_event(fig, v, "pointer_up", widget_id=w, + x=10.0, y=10.0, w=20.0, h=20.0) + assert len(results) == 1 + + def test_on_click_fires(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("crosshair", cx=16.0, cy=16.0) + results = [] + w.add_event_handler(lambda event: results.append(w.cx), "pointer_down") + + _simulate_js_event(fig, v, "pointer_down", widget_id=w, cx=16.0, cy=16.0) + assert len(results) == 1 + assert results[0] == pytest.approx(16.0) + + def test_on_click_line1d_overlay_fires(self): + """Line1D.add_event_handler fires when JS sends pointer_down with the matching line_id.""" + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + line = v.add_line(np.ones(64), color="#ff0000") + results = [] + line.add_event_handler(lambda event: results.append(event.line_id), "pointer_down") + + _simulate_js_event(fig, v, "pointer_down", line_id=line.id) + assert len(results) == 1 + assert results[0] == line.id + + def test_on_click_line1d_primary_fires(self): + """Line1D.add_event_handler on the primary line fires when JS sends pointer_down with no line_id.""" + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + results = [] + v.line.add_event_handler(lambda event: results.append(1), "pointer_down") + + # No line_id in payload → event.line_id is None → matches primary + _simulate_js_event(fig, v, "pointer_down") + assert len(results) == 1 + + def test_on_click_line1d_wrong_id_no_fire(self): + """Line1D.add_event_handler does NOT fire when the JS event carries a different line_id.""" + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + line = v.add_line(np.ones(64), color="#00ff00") + results = [] + line.add_event_handler(lambda event: results.append(1), "pointer_down") + + _simulate_js_event(fig, v, "pointer_down", line_id="completely-wrong-id") + assert results == [] + + def test_circle_drag(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("circle", cx=16, cy=16, r=5) + results = [] + w.add_event_handler(lambda event: results.append(w.cx), "pointer_move") + + _simulate_js_event(fig, v, "pointer_move", widget_id=w, cx=25.0) + assert results == [25.0] + + def test_python_set_does_not_echo(self): + """Python widget.set() triggers pointer_move once (from set itself), + but the subsequent event_json push must NOT re-fire callbacks.""" + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=10, y=10, w=20, h=20) + results = [] + w.add_event_handler(lambda event: results.append("cb"), "pointer_move") + + w.set(x=99) + assert results == ["cb"] # one fire from set() + results.clear() + + # The push to event_json has source:"python" — must be ignored + assert results == [] + + def test_multi_widget_only_changed_fires(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w1 = v.add_widget("circle", cx=10, cy=10, r=5) + w2 = v.add_widget("rectangle", x=0, y=0, w=10, h=10) + r1, r2 = [], [] + w1.add_event_handler(lambda e: r1.append(1), "pointer_move") + w2.add_event_handler(lambda e: r2.append(1), "pointer_move") + + _simulate_js_event(fig, v, "pointer_move", widget_id=w2, x=50.0, y=50.0) + assert r1 == [] + assert len(r2) == 1 + + def test_multi_panel_routing(self): + fig, (ax1, ax2) = apl.subplots(1, 2) + v1 = ax1.imshow(np.zeros((16, 16))) + v2 = ax2.imshow(np.zeros((16, 16))) + w1 = v1.add_widget("circle", cx=8, cy=8, r=3) + w2 = v2.add_widget("circle", cx=8, cy=8, r=3) + r1, r2 = [], [] + w1.add_event_handler(lambda e: r1.append(1), "pointer_move") + w2.add_event_handler(lambda e: r2.append(1), "pointer_move") + + _simulate_js_event(fig, v1, "pointer_move", widget_id=w1, cx=12.0) + assert len(r1) == 1 and r2 == [] + + def test_1d_vline_drag(self): + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + w = v.add_vline_widget(x=10.0) + results = [] + w.add_event_handler(lambda event: results.append(w.x), "pointer_move") + + _simulate_js_event(fig, v, "pointer_move", widget_id=w, x=30.0) + assert results == [30.0] + + def test_1d_range_drag(self): + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + w = v.add_range_widget(x0=10, x1=20) + results = [] + w.add_event_handler(lambda event: results.append((w.x0, w.x1)), "pointer_move") + + _simulate_js_event(fig, v, "pointer_move", widget_id=w, x0=15.0, x1=25.0) + assert results == [(15.0, 25.0)] + + def test_disconnect_prevents_callback(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=0, y=0, w=10, h=10) + results = [] + fn = lambda event: results.append(1) + w.add_event_handler(fn, "pointer_move") + w.remove_handler(fn) + + _simulate_js_event(fig, v, "pointer_move", widget_id=w, x=50.0) + assert results == [] + + def test_widget_state_synced_after_js_event(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("rectangle", x=0, y=0, w=10, h=10) + + _simulate_js_event(fig, v, "pointer_move", widget_id=w, + x=77.0, y=88.0, w=33.0, h=44.0) + assert w.x == 77.0 and w.y == 88.0 and w.w == 33.0 and w.h == 44.0 + + def test_widget_x_readback_after_js_event(self): + """After a JS event, reading widget.x returns the updated value.""" + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("circle", cx=0.0, cy=0.0, r=5.0) + + _simulate_js_event(fig, v, "pointer_up", widget_id=w, cx=20.0, cy=30.0) + assert w.cx == pytest.approx(20.0) + assert w.cy == pytest.approx(30.0) + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 5. End-to-end FFT example +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestInteractiveFft: + """End-to-end: two panels, rectangle widget, simulate JS events, + verify callback fires and updates the FFT panel.""" + + @staticmethod + def _compute_fft(img, x0, y0, w, h, scale=0.1): + ih, iw = img.shape + x0i = max(0, int(round(x0))); y0i = max(0, int(round(y0))) + x1i = min(iw, x0i + max(1, int(round(w)))) + y1i = min(ih, y0i + max(1, int(round(h)))) + crop = img[y0i:y1i, x0i:x1i].copy() + ch, cw = crop.shape + if ch < 2 or cw < 2: + f = np.fft.fftfreq(4, d=scale) + return np.zeros((4, 4)), f, f + crop *= np.hanning(ch)[:, None] * np.hanning(cw)[None, :] + fft2 = np.fft.fftshift(np.fft.fft2(crop)) + log_mag = np.log1p(np.abs(fft2)) + return (log_mag, + np.fft.fftshift(np.fft.fftfreq(cw, d=scale)), + np.fft.fftshift(np.fft.fftfreq(ch, d=scale))) + + def test_drag_rectangle_updates_fft(self): + N = 64 + rng = np.random.default_rng(0) + img = rng.standard_normal((N, N)).cumsum(0).cumsum(1) + img = (img - img.min()) / (img.max() - img.min()) + scale = 0.1 + xy = np.arange(N) * scale + + fig, (ax_real, ax_fft) = apl.subplots(1, 2, figsize=(600, 300)) + v_real = ax_real.imshow(img, axes=[xy, xy], units="Å") + + ROI_W, ROI_H = 32, 32 + roi_x0, roi_y0 = 16, 16 + rect = v_real.add_widget("rectangle", + x=float(roi_x0), y=float(roi_y0), + w=float(ROI_W), h=float(ROI_H)) + + fft0, fx0, fy0 = self._compute_fft(img, roi_x0, roi_y0, ROI_W, ROI_H) + v_fft = ax_fft.imshow(fft0, axes=[fx0, fy0], units="1/Å") + + initial_b64 = v_fft._state["image_b64"] + updates = [] + + @rect.add_event_handler("pointer_move") + def on_rect_changed(event): + log_mag, freq_x, freq_y = self._compute_fft( + img, rect.x, rect.y, rect.w, rect.h) + v_fft.set_data(log_mag, x_axis=freq_x, y_axis=freq_y, units="1/Å") + updates.append({"x": rect.x, "y": rect.y, + "w": rect.w, "h": rect.h}) + + _simulate_js_event(fig, v_real, "pointer_move", widget_id=rect, + x=0.0, y=0.0, w=48.0, h=48.0) + + assert len(updates) == 1 + assert updates[0]["x"] == 0.0 and updates[0]["w"] == 48.0 + assert v_fft._state["image_b64"] != initial_b64 + + def test_multiple_drags_fire_multiple_callbacks(self): + N = 32 + img = np.random.default_rng(1).random((N, N)) + fig, ax = apl.subplots(1, 1) + v = ax.imshow(img) + rect = v.add_widget("rectangle", x=0, y=0, w=16, h=16) + count = [0] + rect.add_event_handler(lambda e: count.__setitem__(0, count[0] + 1), "pointer_move") + + for i in range(5): + _simulate_js_event(fig, v, "pointer_move", widget_id=rect, x=float(i)) + + # Only fires when something actually changed — first fire is from x=0 + # (which equals the initial value, no change), then 1,2,3,4 = 4 fires + assert count[0] == 4 + + def test_drag_then_disconnect(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + rect = v.add_widget("rectangle", x=0, y=0, w=10, h=10) + results = [] + fn = lambda e: results.append(1) + rect.add_event_handler(fn, "pointer_move") + + _simulate_js_event(fig, v, "pointer_move", widget_id=rect, x=5.0) + assert len(results) == 1 + + rect.remove_handler(fn) + _simulate_js_event(fig, v, "pointer_move", widget_id=rect, x=10.0) + assert len(results) == 1 + + def test_on_release_after_drags(self): + N = 32 + img = np.random.default_rng(2).random((N, N)) + fig, ax = apl.subplots(1, 1) + v = ax.imshow(img) + rect = v.add_widget("rectangle", x=0, y=0, w=16, h=16) + drag_count = [0]; release_count = [0] + + rect.add_event_handler(lambda e: drag_count.__setitem__(0, drag_count[0] + 1), "pointer_move") + rect.add_event_handler(lambda e: release_count.__setitem__(0, release_count[0] + 1), "pointer_up") + + for i in range(1, 6): + _simulate_js_event(fig, v, "pointer_move", widget_id=rect, x=float(i)) + _simulate_js_event(fig, v, "pointer_up", widget_id=rect, x=5.0) + + assert drag_count[0] == 5 + assert release_count[0] == 1 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 6. Widget visibility (hide / show) +# ═══════════════════════════════════════════════════════════════════════════════ + +class TestWidgetVisibility: + """Unit tests for Widget.hide(), Widget.show(), and Widget.visible.""" + + def test_visible_default_true(self): + """A freshly created widget is visible by default.""" + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + assert w.visible is True + + def test_hide_sets_visible_false(self): + """hide() marks the widget as not visible.""" + w = CircleWidget(lambda: None, cx=5, cy=5, r=3) + w.hide() + assert w.visible is False + + def test_show_restores_visible(self): + """show() after hide() restores visibility.""" + w = CircleWidget(lambda: None, cx=5, cy=5, r=3) + w.hide() + w.show() + assert w.visible is True + + def test_hide_calls_push(self): + """hide() must call push_fn exactly once.""" + pushed = [] + w = RectangleWidget(lambda: pushed.append(1), x=0, y=0, w=10, h=10) + pushed.clear() + w.hide() + assert len(pushed) == 1 + + def test_show_calls_push(self): + """show() must call push_fn exactly once.""" + pushed = [] + w = RectangleWidget(lambda: pushed.append(1), x=0, y=0, w=10, h=10) + pushed.clear() + w.show() + assert len(pushed) == 1 + + def test_hide_does_not_fire_on_changed(self): + """hide() must NOT fire pointer_move callbacks.""" + w = CircleWidget(lambda: None, cx=0, cy=0, r=5) + fired = [] + w.add_event_handler(lambda e: fired.append(1), "pointer_move") + w.hide() + assert fired == [] + + def test_show_does_not_fire_on_changed(self): + """show() must NOT fire pointer_move callbacks.""" + w = CircleWidget(lambda: None, cx=0, cy=0, r=5) + fired = [] + w.add_event_handler(lambda e: fired.append(1), "pointer_move") + w.hide() + w.show() + assert fired == [] + + def test_visible_in_to_dict_after_hide(self): + """to_dict() reflects visible=False after hide().""" + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + w.hide() + assert w.to_dict()["visible"] is False + + def test_visible_in_to_dict_after_show(self): + """to_dict() reflects visible=True after show().""" + w = RectangleWidget(lambda: None, x=0, y=0, w=10, h=10) + w.hide() + w.show() + assert w.to_dict()["visible"] is True + + def test_visible_in_state_dict_after_hide(self): + """The panel state dict propagates visible=False for a hidden widget.""" + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + w = v.add_vline_widget(x=5.0) + w.hide() + widgets = v.to_state_dict()["overlay_widgets"] + entry = next(e for e in widgets if e["id"] == w.id) + assert entry["visible"] is False + + def test_visible_in_state_dict_after_show(self): + """The panel state dict propagates visible=True after show().""" + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + w = v.add_vline_widget(x=5.0) + w.hide() + w.show() + widgets = v.to_state_dict()["overlay_widgets"] + entry = next(e for e in widgets if e["id"] == w.id) + assert entry["visible"] is True + + def test_hide_then_show_widget_still_draggable(self): + """After show(), a JS drag event fires callbacks as normal.""" + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((32, 32))) + w = v.add_widget("circle", cx=10, cy=10, r=5) + fired = [] + w.add_event_handler(lambda e: fired.append(w.cx), "pointer_move") + w.hide() + w.show() + _simulate_js_event(fig, v, "pointer_move", widget_id=w, cx=20.0) + assert fired == [20.0] + + def test_hide_show_1d_range_widget(self): + """hide/show round-trip works for a RangeWidget.""" + fig, ax = apl.subplots(1, 1) + v = ax.plot(np.zeros(64)) + w = v.add_range_widget(x0=10, x1=20) + w.hide() + assert w.visible is False + w.show() + assert w.visible is True + + def test_multiple_hide_calls_idempotent(self): + """Calling hide() twice leaves visible=False, pushes twice.""" + pushed = [] + w = CircleWidget(lambda: pushed.append(1), cx=0, cy=0, r=5) + pushed.clear() + w.hide() + w.hide() + assert w.visible is False + assert len(pushed) == 2 # each hide() pushes once + + def test_multiple_show_calls_idempotent(self): + """Calling show() twice leaves visible=True, pushes twice.""" + pushed = [] + w = CircleWidget(lambda: pushed.append(1), cx=0, cy=0, r=5) + pushed.clear() + w.show() + w.show() + assert w.visible is True + assert len(pushed) == 2 + + +# ═══════════════════════════════════════════════════════════════════════════════ +# 7. Interactive Fitting — plot_interactive_fitting.py scenario +# ═══════════════════════════════════════════════════════════════════════════════ + +from anyplotlib.widgets import RangeWidget as _RangeWidget2, PointWidget as _PointWidget2 + + +def _gaussian(x, amp, mu, sigma): + return amp * np.exp(-0.5 * ((x - mu) / sigma) ** 2) + + +def _two_gaussians(x, a1, mu1, s1, a2, mu2, s2): + return _gaussian(x, a1, mu1, s1) + _gaussian(x, a2, mu2, s2) + + +class _GaussianController: + """Mirror of GaussianController from plot_interactive_fitting.py.""" + + def __init__(self, plot, line, p, color, x, fit_callback): + self._plot = plot + self.line = line + self.amp = p["amp"] + self.mu = p["mu"] + self.sigma = p["sigma"] + self.color = color + self._x = x + self._refit = fit_callback + self._active = False + self._syncing = False + self._pt = None + self._rng_w = None + + def component_y(self): + return _gaussian(self._x, self.amp, self.mu, self.sigma) + + def toggle(self): + if self._active: + self._pt.hide() + self._rng_w.hide() + self._active = False + else: + if self._pt is None: + self._pt = self._plot.add_point_widget(self.mu, self.amp, + color=self.color) + self._rng_w = self._plot.add_range_widget( + self.mu - self.sigma, self.mu + self.sigma, + color=self.color, + ) + self._wire() + else: + self._pt.show() + self._rng_w.show() + self._active = True + + def _wire(self): + @self._pt.add_event_handler("pointer_move") + def _peak_moved(event): + if self._syncing: + return + self._syncing = True + try: + self.amp = self._pt.y + self.mu = self._pt.x + self._rng_w.set(x0=self.mu - self.sigma, + x1=self.mu + self.sigma) + self.line.set_data(self.component_y()) + self._refit() + finally: + self._syncing = False + + @self._rng_w.add_event_handler("pointer_move") + def _range_moved(event): + if self._syncing: + return + self._syncing = True + try: + x0, x1 = self._rng_w.x0, self._rng_w.x1 + self.mu = (x0 + x1) / 2.0 + self.sigma = abs(x1 - x0) / 2.0 + self._pt.set(x=self.mu) + self.line.set_data(self.component_y()) + self._refit() + finally: + self._syncing = False + + +class TestInteractiveFitting: + """End-to-end tests mirroring plot_interactive_fitting.py. + + Validates widget hide/show toggle, PointWidget and RangeWidget drag + callbacks, and the live refit flow — all without a browser. + """ + + def _build(self): + """Return (fig, plot, controllers, fit_line, x, signal).""" + from scipy.optimize import curve_fit + + x = np.linspace(0, 10, 200) + TRUE_P = [ + dict(amp=1.0, mu=3.2, sigma=0.55), + dict(amp=0.75, mu=6.8, sigma=0.80), + ] + COLORS = ["#ff6b6b", "#69db7c"] + rng = np.random.default_rng(0) + signal = sum(_gaussian(x, **p) for p in TRUE_P) + rng.normal(0, 0.03, len(x)) + + INIT_P = [ + dict(amp=1.0, mu=3.0, sigma=0.6), + dict(amp=0.7, mu=7.0, sigma=0.9), + ] + + fig, ax = apl.subplots(1, 1, figsize=(600, 300)) + plot = ax.plot(signal, axes=[x], color="#adb5bd") + + comp_lines = [ + plot.add_line(_gaussian(x, **p), x_axis=x, color=c) + for i, (p, c) in enumerate(zip(INIT_P, COLORS)) + ] + + fit_line = plot.add_line( + sum(_gaussian(x, **p) for p in INIT_P), x_axis=x, + color="#ffd43b", linestyle="dashed", + ) + + refit_calls = [0] + + def _refit(): + c0, c1 = controllers[0], controllers[1] + p0 = [c0.amp, c0.mu, c0.sigma, c1.amp, c1.mu, c1.sigma] + lo = [0, x[0], 1e-3, 0, x[0], 1e-3] + hi = [np.inf, x[-1], x[-1]-x[0], np.inf, x[-1], x[-1]-x[0]] + try: + popt, _ = curve_fit(_two_gaussians, x, signal, p0=p0, + bounds=(lo, hi), maxfev=3000) + fit_line.set_data(_two_gaussians(x, *popt)) + except RuntimeError: + fit_line.set_data(sum(c.component_y() for c in controllers)) + refit_calls[0] += 1 + + controllers = [ + _GaussianController(plot, comp_lines[i], INIT_P[i], COLORS[i], + x, _refit) + for i in range(2) + ] + + return fig, plot, controllers, fit_line, x, signal, refit_calls + + # ── toggle creates widgets ──────────────────────────────────────────────── + + def test_toggle_once_creates_point_and_range_widgets(self): + """First toggle creates a PointWidget and a RangeWidget.""" + _, plot, ctrls, *_ = self._build() + ctrl = ctrls[0] + assert ctrl._pt is None and ctrl._rng_w is None + ctrl.toggle() + assert ctrl._pt is not None + assert ctrl._rng_w is not None + assert ctrl._active is True + + def test_toggle_once_adds_two_widgets_to_plot(self): + """After first toggle, the plot has exactly 2 new widgets.""" + _, plot, ctrls, *_ = self._build() + ctrl = ctrls[0] + ctrl.toggle() + assert len(plot.list_widgets()) == 2 + + def test_widgets_visible_after_first_toggle(self): + """Widgets created on first toggle are visible.""" + _, plot, ctrls, *_ = self._build() + ctrl = ctrls[0] + ctrl.toggle() + assert ctrl._pt.visible is True + assert ctrl._rng_w.visible is True + + # ── toggle hides widgets ────────────────────────────────────────────────── + + def test_toggle_twice_hides_widgets(self): + """Second toggle hides the point and range widgets.""" + _, plot, ctrls, *_ = self._build() + ctrl = ctrls[0] + ctrl.toggle() # activate + ctrl.toggle() # deactivate + assert ctrl._active is False + assert ctrl._pt.visible is False + assert ctrl._rng_w.visible is False + + def test_toggle_twice_widgets_still_in_plot(self): + """Hidden widgets are NOT removed from the plot — they stay but are hidden.""" + _, plot, ctrls, *_ = self._build() + ctrl = ctrls[0] + ctrl.toggle() + ctrl.toggle() + # Still registered — just hidden + assert len(plot.list_widgets()) == 2 + + # ── toggle shows widgets again ──────────────────────────────────────────── + + def test_toggle_three_times_reshows_widgets(self): + """Third toggle re-shows the existing widgets without creating new ones.""" + _, plot, ctrls, *_ = self._build() + ctrl = ctrls[0] + ctrl.toggle() # create + show + pt_id = ctrl._pt.id + rng_id = ctrl._rng_w.id + ctrl.toggle() # hide + ctrl.toggle() # re-show + assert ctrl._active is True + assert ctrl._pt.visible is True + assert ctrl._rng_w.visible is True + # Same objects — not recreated + assert ctrl._pt.id == pt_id + assert ctrl._rng_w.id == rng_id + assert len(plot.list_widgets()) == 2 + + # ── PointWidget drag updates component line ─────────────────────────────── + + def test_point_drag_updates_component_amp_and_mu(self): + """Simulating a PointWidget drag updates amp and mu on the controller.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + ctrl.toggle() + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrl._pt, x=3.5, y=0.9) + + assert ctrl.mu == pytest.approx(3.5) + assert ctrl.amp == pytest.approx(0.9) + + def test_point_drag_updates_range_widget_position(self): + """Dragging the point recentres the range widget around new mu.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + ctrl.toggle() + original_sigma = ctrl.sigma + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrl._pt, x=4.0, y=1.0) + + expected_x0 = 4.0 - original_sigma + expected_x1 = 4.0 + original_sigma + assert ctrl._rng_w.x0 == pytest.approx(expected_x0) + assert ctrl._rng_w.x1 == pytest.approx(expected_x1) + + def test_point_drag_updates_component_line_data(self): + """After a PointWidget drag, the component line data reflects new params.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + ctrl.toggle() + + old_data = _gaussian(x, ctrl.amp, ctrl.mu, ctrl.sigma).copy() + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrl._pt, x=4.0, y=0.8) + + # Find the extra_line entry for comp_lines[0] + lid = ctrl.line.id + entry = next(e for e in plot._state["extra_lines"] if e["id"] == lid) + new_y = entry["data"] + expected_y = _gaussian(x, 0.8, 4.0, ctrl.sigma) + np.testing.assert_allclose(new_y, expected_y, rtol=1e-10) + + def test_point_drag_triggers_refit(self): + """Dragging the PointWidget calls the refit callback.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + ctrl.toggle() + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrl._pt, x=3.5, y=0.9) + + assert refit_calls[0] >= 1 + + def test_point_drag_updates_fit_line(self): + """After a point drag, the fit line data changes.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + ctrl.toggle() + + lid = fit_line.id + entry_before = next(e for e in plot._state["extra_lines"] if e["id"] == lid) + old_fit = entry_before["data"].copy() + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrl._pt, x=4.5, y=0.5) + + entry_after = next(e for e in plot._state["extra_lines"] if e["id"] == lid) + assert not np.array_equal(entry_after["data"], old_fit) + + # ── RangeWidget drag updates component line ─────────────────────────────── + + def test_range_drag_updates_mu_and_sigma(self): + """Simulating a RangeWidget drag updates mu and sigma on the controller.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + ctrl.toggle() + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrl._rng_w, x0=2.5, x1=4.5) + + assert ctrl.mu == pytest.approx(3.5) + assert ctrl.sigma == pytest.approx(1.0) + + def test_range_drag_recentres_point_widget(self): + """Dragging the range widget moves the point widget to the new centre.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + ctrl.toggle() + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrl._rng_w, x0=2.0, x1=5.0) + + assert ctrl._pt.x == pytest.approx(3.5) + + def test_range_drag_updates_component_line_data(self): + """After a RangeWidget drag, the component line reflects the new sigma.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + ctrl.toggle() + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrl._rng_w, x0=2.5, x1=4.5) + + lid = ctrl.line.id + entry = next(e for e in plot._state["extra_lines"] if e["id"] == lid) + expected_y = _gaussian(x, ctrl.amp, 3.5, 1.0) + np.testing.assert_allclose(entry["data"], expected_y, rtol=1e-10) + + def test_range_drag_triggers_refit(self): + """Dragging the RangeWidget calls the refit callback.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + ctrl.toggle() + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrl._rng_w, x0=2.5, x1=4.5) + + assert refit_calls[0] >= 1 + + # ── both controllers independent ───────────────────────────────────────── + + def test_two_controllers_independent(self): + """Dragging ctrl[0] does not affect ctrl[1] state.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrls[0].toggle() + ctrls[1].toggle() + + old_mu1 = ctrls[1].mu + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrls[0]._pt, x=3.8, y=1.1) + + assert ctrls[1].mu == pytest.approx(old_mu1) + + def test_both_controllers_active_at_same_time(self): + """Both controllers can be active simultaneously with no crosstalk.""" + _, plot, ctrls, *_ = self._build() + ctrls[0].toggle() + ctrls[1].toggle() + assert len(plot.list_widgets()) == 4 + assert ctrls[0]._active and ctrls[1]._active + + def test_hide_one_leaves_other_visible(self): + """Hiding ctrl[0] does not affect ctrl[1] visibility.""" + _, plot, ctrls, *_ = self._build() + ctrls[0].toggle() # activate + ctrls[1].toggle() # activate + ctrls[0].toggle() # hide + assert ctrls[0]._pt.visible is False + assert ctrls[1]._pt.visible is True + + # ── line click toggles controller ───────────────────────────────────────── + + def test_line_click_activates_controller(self): + """Simulating a click on a component line activates its controller.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + + # Wire up the line click handler (same as the example) + @ctrl.line.add_event_handler("pointer_down") + def _clicked(event, c=ctrl): + c.toggle() + + # Simulate JS sending a pointer_down event for comp_lines[0] + fig._on_event({"new": __import__("json").dumps({ + "source": "js", + "panel_id": plot._id, + "event_type": "pointer_down", + "line_id": ctrl.line.id, + })}) + + assert ctrl._active is True + assert ctrl._pt is not None + + def test_line_click_twice_hides_widgets(self): + """Two clicks on the same component line toggle it off again.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + + @ctrl.line.add_event_handler("pointer_down") + def _clicked(event, c=ctrl): + c.toggle() + + import json as _json + + def _click(): + fig._on_event({"new": _json.dumps({ + "source": "js", + "panel_id": plot._id, + "event_type": "pointer_down", + "line_id": ctrl.line.id, + })}) + + _click() # → active + _click() # → hidden + + assert ctrl._active is False + assert ctrl._pt.visible is False + + def test_line_click_wrong_line_id_no_toggle(self): + """A click on a different line ID does NOT toggle this controller.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = self._build() + ctrl = ctrls[0] + + @ctrl.line.add_event_handler("pointer_down") + def _clicked(event, c=ctrl): + c.toggle() + + import json as _json + fig._on_event({"new": _json.dumps({ + "source": "js", + "panel_id": plot._id, + "event_type": "pointer_down", + "line_id": "completely-wrong-id", + })}) + + assert ctrl._active is False # was never toggled + + # ── example-mirroring tests ─────────────────────────────────────────────── + + def _build_with_click_handlers(self): + """Same as _build() but wires line click → ctrl.toggle() for both + components, exactly as the for-loop in plot_interactive_fitting.py.""" + result = self._build() + _, _, controllers, *_ = result + for ctrl in controllers: + @ctrl.line.add_event_handler("pointer_down") + def _clicked(event, c=ctrl): + c.toggle() + return result + + def test_example_both_lines_clickable(self): + """Clicking each component line activates its controller and makes + the widgets visible — mirrors the click-handler loop in the example.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = \ + self._build_with_click_handlers() + + # Click component 0 + _simulate_js_event(fig, plot, "pointer_down", line_id=ctrls[0].line.id) + assert ctrls[0]._active is True + assert ctrls[0]._pt is not None + assert ctrls[0]._rng_w is not None + assert ctrls[0]._pt.visible is True + assert ctrls[0]._rng_w.visible is True + assert ctrls[1]._active is False # other controller untouched + + # Click component 1 + _simulate_js_event(fig, plot, "pointer_down", line_id=ctrls[1].line.id) + assert ctrls[1]._active is True + assert ctrls[1]._pt.visible is True + assert ctrls[1]._rng_w.visible is True + + def test_example_click_shows_widgets_registered_in_plot(self): + """After clicking a component line its widgets appear in list_widgets().""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = \ + self._build_with_click_handlers() + + assert len(plot.list_widgets()) == 0 + + _simulate_js_event(fig, plot, "pointer_down", line_id=ctrls[0].line.id) + assert len(plot.list_widgets()) == 2 # PointWidget + RangeWidget + + _simulate_js_event(fig, plot, "pointer_down", line_id=ctrls[1].line.id) + assert len(plot.list_widgets()) == 4 # +2 for ctrl[1] + + def test_example_second_click_hides_widgets(self): + """Second click hides widgets but keeps them registered in the plot.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = \ + self._build_with_click_handlers() + + def _click(ctrl): + _simulate_js_event(fig, plot, "pointer_down", + line_id=ctrl.line.id) + + _click(ctrls[0]) # show + assert ctrls[0]._active is True and ctrls[0]._pt.visible is True + + _click(ctrls[0]) # hide + assert ctrls[0]._active is False + assert ctrls[0]._pt.visible is False + assert ctrls[0]._rng_w.visible is False + assert len(plot.list_widgets()) == 2 # still registered, just hidden + + def test_example_third_click_reshows_same_widgets(self): + """Third click re-shows the same widget objects without recreating them.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = \ + self._build_with_click_handlers() + + def _click(ctrl): + _simulate_js_event(fig, plot, "pointer_down", + line_id=ctrl.line.id) + + _click(ctrls[0]) + pt_id = ctrls[0]._pt.id + rng_id = ctrls[0]._rng_w.id + + _click(ctrls[0]) # hide + _click(ctrls[0]) # re-show + + assert ctrls[0]._active is True + assert ctrls[0]._pt.visible is True + assert ctrls[0]._rng_w.visible is True + assert ctrls[0]._pt.id == pt_id # same objects, not recreated + assert ctrls[0]._rng_w.id == rng_id + assert len(plot.list_widgets()) == 2 + + def test_example_click_then_drag_updates_fit(self): + """Full flow: click to activate → drag PointWidget → fit line changes.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = \ + self._build_with_click_handlers() + + _simulate_js_event(fig, plot, "pointer_down", line_id=ctrls[0].line.id) + assert ctrls[0]._active is True + + lid = fit_line.id + fit_before = next( + e for e in plot._state["extra_lines"] if e["id"] == lid + )["data"].copy() + + _simulate_js_event(fig, plot, "pointer_move", + widget_id=ctrls[0]._pt, x=4.0, y=0.8) + + fit_after = next( + e for e in plot._state["extra_lines"] if e["id"] == lid + )["data"] + assert not np.array_equal(fit_after, fit_before) + assert refit_calls[0] >= 1 + + def test_example_wrong_line_id_not_clickable(self): + """A click event for an unknown line ID activates no controller.""" + fig, plot, ctrls, fit_line, x, signal, refit_calls = \ + self._build_with_click_handlers() + + _simulate_js_event(fig, plot, "pointer_down", line_id="no-such-line") + assert ctrls[0]._active is False + assert ctrls[1]._active is False diff --git a/anyplotlib/tests/test_labels/__init__.py b/anyplotlib/tests/test_labels/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/anyplotlib/tests/test_labels/test_label_api.py b/anyplotlib/tests/test_labels/test_label_api.py new file mode 100644 index 00000000..a4cc7714 --- /dev/null +++ b/anyplotlib/tests/test_labels/test_label_api.py @@ -0,0 +1,132 @@ +""" +Unit tests for the label font-size API and TeX pass-through. + +Covers: + * fontsize kwarg on set_xlabel / set_ylabel / set_zlabel / set_title / + set_colorbar_label for every panel type + * fontsize=None leaves the size state untouched (JS falls back to defaults) + * set_tick_label_size + * TeX-formatted label strings are stored verbatim (parsing happens in JS) +""" +from __future__ import annotations + +import numpy as np +import pytest + +import anyplotlib as apl + + +def _imshow(): + fig, ax = apl.subplots(1, 1) + return ax.imshow(np.zeros((8, 8))) + + +def _plot(): + fig, ax = apl.subplots(1, 1) + return ax.plot(np.zeros(16)) + + +def _bar(): + fig, ax = apl.subplots(1, 1) + return ax.bar(["a", "b"], [1.0, 2.0]) + + +def _surface(): + fig, ax = apl.subplots(1, 1) + g = np.linspace(-1, 1, 8) + XX, YY = np.meshgrid(g, g) + return ax.plot_surface(XX, YY, XX * YY) + + +class TestFontsizeKwarg: + def test_plot2d_xlabel_fontsize(self): + v = _imshow() + v.set_xlabel("x", fontsize=14) + assert v._state["x_label"] == "x" + assert v._state["x_label_size"] == 14.0 + + def test_plot2d_ylabel_fontsize(self): + v = _imshow() + v.set_ylabel("y", fontsize=16) + assert v._state["y_label"] == "y" + assert v._state["y_label_size"] == 16.0 + + def test_plot2d_colorbar_label_fontsize(self): + v = _imshow() + v.set_colorbar_label("counts", fontsize=13) + assert v._state["colorbar_label"] == "counts" + assert v._state["colorbar_label_size"] == 13.0 + + def test_plot1d_label_fontsize_maps_to_units(self): + v = _plot() + v.set_xlabel("eV", fontsize=12) + v.set_ylabel("counts", fontsize=11) + assert v._state["units"] == "eV" + assert v._state["x_label_size"] == 12.0 + assert v._state["y_units"] == "counts" + assert v._state["y_label_size"] == 11.0 + + def test_plotbar_label_fontsize(self): + v = _bar() + v.set_xlabel("category", fontsize=12) + v.set_ylabel("value", fontsize=13) + assert v._state["x_label_size"] == 12.0 + assert v._state["y_label_size"] == 13.0 + + def test_plot3d_label_fontsize(self): + v = _surface() + v.set_xlabel("x", fontsize=14) + v.set_ylabel("y", fontsize=15) + v.set_zlabel("z", fontsize=16) + assert v._state["x_label_size"] == 14.0 + assert v._state["y_label_size"] == 15.0 + assert v._state["z_label_size"] == 16.0 + + def test_title_fontsize_all_panel_types(self): + for make in (_imshow, _plot, _bar, _surface): + v = make() + v.set_title("T", fontsize=12) + assert v._state["title"] == "T" + assert v._state["title_size"] == 12.0 + + +class TestFontsizeNoneKeepsState: + def test_none_does_not_create_size_key(self): + v = _imshow() + v.set_xlabel("x") + assert "x_label_size" not in v._state + + def test_none_does_not_overwrite_previous_size(self): + v = _imshow() + v.set_xlabel("x", fontsize=18) + v.set_xlabel("renamed") # no fontsize — keep 18 + assert v._state["x_label"] == "renamed" + assert v._state["x_label_size"] == 18.0 + + +class TestTickLabelSize: + @pytest.mark.parametrize("make", [_imshow, _plot, _bar]) + def test_set_tick_label_size(self, make): + v = make() + v.set_tick_label_size(14) + assert v._state["tick_size"] == 14.0 + + +class TestTexPassThrough: + """Python stores TeX strings verbatim; all parsing happens at JS draw time.""" + + def test_tex_label_stored_verbatim(self): + v = _imshow() + label = r"$q$ ($\AA^{-1}$)" + v.set_xlabel(label) + assert v._state["x_label"] == label + + def test_tex_exponent_title(self): + v = _plot() + v.set_title(r"Intensity $\times 10^{-3}$") + assert v._state["title"] == r"Intensity $\times 10^{-3}$" + + def test_tex_subscript_colorbar(self): + v = _imshow() + v.set_colorbar_label(r"$E_F$ (eV)") + assert v._state["colorbar_label"] == r"$E_F$ (eV)" diff --git a/anyplotlib/tests/test_labels/test_label_rendering.py b/anyplotlib/tests/test_labels/test_label_rendering.py new file mode 100644 index 00000000..933b70aa --- /dev/null +++ b/anyplotlib/tests/test_labels/test_label_rendering.py @@ -0,0 +1,137 @@ +""" +Playwright tests for label font sizes and mini-TeX rendering. + +Strategy +-------- +Canvas text cannot be read back as strings, so these tests assert on *ink*: + +* a larger ``fontsize`` must produce more non-background pixels in the + axis gutter than a smaller one; +* a TeX string like ``$10^{-3}$`` must render *narrower* than the same + characters drawn literally (``10^{-3}``) — the ``$`` delimiters are + consumed and the exponent shrinks to a superscript; +* TeX titles must produce visible pixels in the title canvas. +""" +from __future__ import annotations + +import numpy as np + +import anyplotlib as apl + +PAD_B = 42 # bottom axis gutter height (PAD_* constants in figure_esm.js) + + +# ── helpers ─────────────────────────────────────────────────────────────────── + +def _x_gutter(img: np.ndarray) -> np.ndarray: + """Return the bottom PAD_B-row strip of a widget screenshot.""" + return img[-PAD_B:, :, :3].astype(int) + + +def _ink_mask(strip: np.ndarray) -> np.ndarray: + """Boolean mask of pixels that differ from the strip's corner colour.""" + bg = strip[2, 2] + return np.abs(strip - bg).sum(axis=-1) > 60 + + +def _x_gutter_ink(take_screenshot, label: str, fontsize=None) -> np.ndarray: + """Render an imshow with the given x label; return the gutter ink mask.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow( + np.zeros((32, 32), dtype=np.float32), + axes=[np.linspace(0, 10, 32)] * 2, + units="nm", + ) + if fontsize is None: + plot.set_xlabel(label) + else: + plot.set_xlabel(label, fontsize=fontsize) + return _ink_mask(_x_gutter(take_screenshot(fig))) + + +def _title_pixel_count(page) -> int: + """Count non-transparent pixels in the 2D titleCanvas (z-index:8).""" + return page.evaluate("""() => { + const tc = Array.from(document.querySelectorAll('canvas')) + .find(c => c.style.zIndex === '8'); + if (!tc) return -1; + const ctx = tc.getContext('2d'); + const d = ctx.getImageData(0, 0, tc.width, tc.height).data; + let n = 0; + for (let i = 3; i < d.length; i += 4) { if (d[i] > 0) n++; } + return n; + }""") + + +# ══════════════════════════════════════════════════════════════════════════════ + + +class TestFontsizeRendering: + def test_larger_fontsize_more_ink(self, take_screenshot): + small = _x_gutter_ink(take_screenshot, "Distance", fontsize=9) + large = _x_gutter_ink(take_screenshot, "Distance", fontsize=18) + assert large.sum() > small.sum() * 1.3, ( + f"fontsize=18 must draw more label ink than fontsize=9 " + f"(got {large.sum()} vs {small.sum()})" + ) + + def test_tick_label_size_changes_ink(self, take_screenshot): + def gutter_ink(tick_size): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow( + np.zeros((32, 32), dtype=np.float32), + axes=[np.linspace(0, 10, 32)] * 2, + ) + if tick_size: + plot.set_tick_label_size(tick_size) + return _ink_mask(_x_gutter(take_screenshot(fig))).sum() + + assert gutter_ink(16) > gutter_ink(None) * 1.2, ( + "set_tick_label_size(16) must draw more tick ink than the default" + ) + + +class TestTexRendering: + def test_tex_label_renders_ink(self, take_screenshot): + ink = _x_gutter_ink(take_screenshot, r"$10^{-3}$ m") + assert ink.sum() > 0, "TeX label must render visible pixels" + + def test_tex_consumes_dollars_and_shrinks_exponent(self, take_screenshot): + """$10^{-3}$ must be narrower than the literal text 10^{-3}. + + The TeX path drops the two ``$`` delimiters and the ``^{}`` braces + and renders ``-3`` at ~0.68× size, so its ink must span fewer + columns than the literal 7-glyph string. + """ + tex = _x_gutter_ink(take_screenshot, r"$10^{-3}$") + lit = _x_gutter_ink(take_screenshot, "10^{-3}") + # Width = number of columns containing any ink in the label row band. + # Restrict to the bottom 14 rows where the centred label is drawn, + # away from tick numbers at the top of the gutter. + tex_cols = np.flatnonzero(tex[-14:, :].any(axis=0)) + lit_cols = np.flatnonzero(lit[-14:, :].any(axis=0)) + assert len(tex_cols) > 0 and len(lit_cols) > 0 + tex_w = tex_cols[-1] - tex_cols[0] + lit_w = lit_cols[-1] - lit_cols[0] + assert tex_w < lit_w, ( + f"TeX '$10^{{-3}}$' must render narrower than literal '10^{{-3}}' " + f"(got {tex_w} vs {lit_w} px)" + ) + + def test_tex_title_renders_pixels(self, interact_page): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32), dtype=np.float32)) + plot.set_title(r"$\sigma^2 = \langle x^2 \rangle$") + page = interact_page(fig) + page.wait_for_timeout(200) + n = _title_pixel_count(page) + assert n > 0, "TeX title must produce visible pixels in titleCanvas" + + def test_greek_and_symbols_render(self, take_screenshot): + ink = _x_gutter_ink(take_screenshot, r"$\Delta E$ ($\mu$eV) $\times$ $\AA$") + assert ink.sum() > 0 + + def test_plain_label_unaffected(self, take_screenshot): + """A label with no $ must render through the fast path identically.""" + ink = _x_gutter_ink(take_screenshot, "plain label, no math") + assert ink.sum() > 0 diff --git a/anyplotlib/tests/test_labels/test_no_clipping.py b/anyplotlib/tests/test_labels/test_no_clipping.py new file mode 100644 index 00000000..85b302c3 --- /dev/null +++ b/anyplotlib/tests/test_labels/test_no_clipping.py @@ -0,0 +1,102 @@ +""" +Playwright regression tests: labels, titles, and tick text must never be +clipped by their canvas bounds. + +Strategy: read back each text-bearing canvas with ``getImageData`` and assert +no ink (non-transparent pixel) sits on the canvas's first/last row. Text +whose glyphs are cut by the canvas edge always leaves ink on the edge row, so +"no ink on the edge" ⇒ "nothing clipped vertically". + +The 2D title canvas is fully transparent except for the title, making it the +cleanest probe for both the dynamic title strip and the TeX superscript rise. +""" +from __future__ import annotations + +import numpy as np +import pytest + +import anyplotlib as apl + + +def _title_ink_rows(page) -> dict: + """Return {h, minRow, maxRow} of ink in the 2D titleCanvas (z-index 8).""" + return page.evaluate("""() => { + const tc = Array.from(document.querySelectorAll('canvas')) + .find(c => c.style.zIndex === '8'); + if (!tc) return null; + const d = tc.getContext('2d').getImageData(0, 0, tc.width, tc.height).data; + let minR = 1e9, maxR = -1; + for (let y = 0; y < tc.height; y++) for (let x = 0; x < tc.width; x++) { + if (d[(y * tc.width + x) * 4 + 3] > 0) { + if (y < minR) minR = y; + if (y > maxR) maxR = y; + } + } + return { h: tc.height, minRow: minR, maxRow: maxR }; + }""") + + +def _open_imshow_with_title(interact_page, title, fontsize=None): + fig, ax = apl.subplots(1, 1, figsize=(460, 380)) + q = np.linspace(-2.3, 2.3, 64) + plot = ax.imshow(np.zeros((64, 64), dtype=np.float32), axes=[q, q], units="nm") + if fontsize is None: + plot.set_title(title) + else: + plot.set_title(title, fontsize=fontsize) + page = interact_page(fig) + page.wait_for_timeout(150) + return page + + +class TestTitleNeverClipped: + @pytest.mark.parametrize("title,fontsize", [ + ("Plain gyp TX", None), # default plain — baseline case + (r"TeX: $|F(q)|^2$ gyp", None), # default TeX — strip grows for sup + (r"Large $x^2$ gyp", 16), # large TeX — strip grows + ("Plain large gyp", 16), # large plain + (r"XL $y_i^2$ gyp", 22), # extreme, sup + sub + descenders + ]) + def test_title_ink_within_strip(self, interact_page, title, fontsize): + page = _open_imshow_with_title(interact_page, title, fontsize) + r = _title_ink_rows(page) + assert r is not None and r["maxRow"] >= 0, "title produced no ink" + assert r["minRow"] > 0, ( + f"title ink touches the top edge (clipped ascender/superscript): {r}" + ) + assert r["maxRow"] < r["h"] - 1, ( + f"title ink touches the bottom edge (clipped descender): {r}" + ) + + +class TestColorbarLabelVisible: + def test_colorbar_label_renders_in_reserved_gutter(self, interact_page): + """The image must shrink so the colorbar strip + label fit the panel.""" + fig, ax = apl.subplots(1, 1, figsize=(460, 380)) + q = np.linspace(-2.3, 2.3, 64) + plot = ax.imshow(np.zeros((64, 64), dtype=np.float32), axes=[q, q]) + plot.set_colorbar_visible(True) + plot.set_colorbar_label(r"counts $\times 10^{3}$") + page = interact_page(fig) + page.wait_for_timeout(150) + + res = page.evaluate("""() => { + // cbCanvas: the only canvas right of the image, width > 16 + const panel = 460; + for (const c of document.querySelectorAll('canvas')) { + const left = parseFloat(c.style.left || '0'); + if (c.width > 16 && c.width < 80 && left > 300) { + // entire canvas must sit inside the panel width + const fits = left + c.width <= panel; + // ink in the label gutter (x > 16) + const d = c.getContext('2d').getImageData(16, 0, c.width - 16, c.height).data; + let ink = 0; + for (let i = 3; i < d.length; i += 4) if (d[i] > 0) ink++; + return { w: c.width, left, fits, labelInk: ink }; + } + } + return null; + }""") + assert res is not None, "colorbar canvas not found" + assert res["fits"], f"colorbar extends past the panel edge: {res}" + assert res["labelInk"] > 0, f"colorbar label has no visible ink: {res}" diff --git a/anyplotlib/tests/test_layouts/__init__.py b/anyplotlib/tests/test_layouts/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/anyplotlib/tests/test_layouts/test_batch.py b/anyplotlib/tests/test_layouts/test_batch.py new file mode 100644 index 00000000..21e6a548 --- /dev/null +++ b/anyplotlib/tests/test_layouts/test_batch.py @@ -0,0 +1,90 @@ +"""Tests for Figure.batch() push coalescing — the linked-view lag fix.""" +from __future__ import annotations + +import numpy as np +import anyplotlib as apl + + +def _fig3(): + fig = apl.Figure(figsize=(600, 200)) + gs = apl.GridSpec(1, 3) + axs = [fig.add_subplot(gs[0, c]) for c in range(3)] + px = [np.arange(16)] * 2 + plots = [a.imshow(np.zeros((16, 16, 3), dtype=np.uint8), axes=px) for a in axs] + return fig, plots + + +def _count_pushes(fig): + calls = {"n": 0} + orig = type(fig)._push + def counting(self, pid): + # count only real trait writes (batch dirty-marking returns early) + if not self._batching: + calls["n"] += 1 + return orig(self, pid) + type(fig)._push = counting + return calls, lambda: setattr(type(fig), "_push", orig) + + +class TestBatch: + def test_coalesces_multiple_pushes_per_panel(self): + fig, plots = _fig3() + calls, restore = _count_pushes(fig) + try: + with fig.batch(): + for p in plots: + p.set_data(np.ones((16, 16, 3), dtype=np.uint8)) + p.set_title("x") # 2nd mutation, same panel + # 3 panels × 2 mutations each = 6 mutations → 3 pushes + assert calls["n"] == 3, f"expected 3 coalesced pushes, got {calls['n']}" + finally: + restore() + + def test_without_batch_pushes_per_mutation(self): + fig, plots = _fig3() + calls, restore = _count_pushes(fig) + try: + for p in plots: + p.set_data(np.ones((16, 16, 3), dtype=np.uint8)) + p.set_title("x") + assert calls["n"] == 6, f"expected 6 pushes, got {calls['n']}" + finally: + restore() + + def test_batch_applies_state(self): + fig, plots = _fig3() + with fig.batch(): + plots[0].set_title("hello") + assert plots[0]._state["title"] == "hello" + # trait reflects the change after the block + import json + st = json.loads(getattr(fig, f"panel_{plots[0]._id}_json")) + assert st["title"] == "hello" + + def test_nested_batch_is_transparent(self): + fig, plots = _fig3() + calls, restore = _count_pushes(fig) + try: + with fig.batch(): + with fig.batch(): + plots[0].set_title("a") + plots[1].set_title("b") + assert calls["n"] == 2 + finally: + restore() + + def test_3d_view_and_highlight_coalesce(self): + fig = apl.Figure(figsize=(300, 300)) + ax = fig.add_subplot(apl.GridSpec(1, 1)[0, 0]) + v = ax.scatter3d(np.zeros(4), np.zeros(4), np.zeros(4), + bounds=((-1, 1),) * 3) + calls, restore = _count_pushes(fig) + try: + with fig.batch(): + v.set_highlight(0.1, 0.2, 0.3) + v.set_view(azimuth=10, elevation=20) + assert calls["n"] == 1, f"expected 1 coalesced push, got {calls['n']}" + assert v._state["highlight"]["x"] == 0.1 + assert v._state["azimuth"] == 10 + finally: + restore() diff --git a/anyplotlib/tests/test_layouts/test_geom_channel.py b/anyplotlib/tests/test_layouts/test_geom_channel.py new file mode 100644 index 00000000..9660b016 --- /dev/null +++ b/anyplotlib/tests/test_layouts/test_geom_channel.py @@ -0,0 +1,71 @@ +"""Tests for the geometry channel: heavy geometry rides a separate trait and +is re-sent only when it actually changes (view updates don't re-transmit it).""" +from __future__ import annotations + +import json +import numpy as np +import anyplotlib as apl + + +def _scatter(): + fig = apl.Figure(figsize=(300, 300)) + ax = fig.add_subplot(apl.GridSpec(1, 1)[0, 0]) + v = ax.scatter3d(np.zeros(8), np.zeros(8), np.zeros(8), + bounds=((-1, 1),) * 3, + colors=np.tile([1, 2, 3], (8, 1)).astype(np.uint8)) + return fig, v + + +class TestGeomChannel: + def test_geom_trait_allocated(self): + fig, v = _scatter() + assert fig.has_trait(f"panel_{v._id}_geom") + + def test_view_trait_excludes_geometry(self): + fig, v = _scatter() + view = json.loads(getattr(fig, f"panel_{v._id}_json")) + for k in ("vertices_b64", "faces_b64", "point_colors_b64", "colormap_data"): + assert k not in view, f"{k} leaked into the view trait" + assert view["_geom_rev"] >= 1 + + def test_geom_trait_contains_geometry(self): + fig, v = _scatter() + geom = json.loads(getattr(fig, f"panel_{v._id}_geom")) + assert "vertices_b64" in geom and "point_colors_b64" in geom + + def test_highlight_does_not_resend_geometry(self): + fig, v = _scatter() + gkey = f"panel_{v._id}_geom" + before = getattr(fig, gkey) + rev_before = json.loads(getattr(fig, f"panel_{v._id}_json"))["_geom_rev"] + v.set_highlight(0.1, 0.2, 0.3) + assert getattr(fig, gkey) == before, "geometry re-sent on highlight move" + rev_after = json.loads(getattr(fig, f"panel_{v._id}_json"))["_geom_rev"] + assert rev_after == rev_before, "geom_rev bumped without geometry change" + assert json.loads(getattr(fig, f"panel_{v._id}_json"))["highlight"]["x"] == 0.1 + + def test_view_change_does_not_resend_geometry(self): + fig, v = _scatter() + before = getattr(fig, f"panel_{v._id}_geom") + v.set_view(azimuth=42, elevation=15) + assert getattr(fig, f"panel_{v._id}_geom") == before + assert json.loads(getattr(fig, f"panel_{v._id}_json"))["azimuth"] == 42 + + def test_geometry_change_bumps_rev_and_resends(self): + fig, v = _scatter() + gkey = f"panel_{v._id}_geom" + before = getattr(fig, gkey) + rev_before = json.loads(getattr(fig, f"panel_{v._id}_json"))["_geom_rev"] + v.set_data(np.ones(8) * 3, np.ones(8) * 4, np.ones(8) * 5) # new geometry + assert getattr(fig, gkey) != before, "geometry change not re-sent" + rev_after = json.loads(getattr(fig, f"panel_{v._id}_json"))["_geom_rev"] + assert rev_after == rev_before + 1, "geom_rev not bumped on geometry change" + + def test_plot_without_geom_keys_unaffected(self): + # Plot1D declares no _GEOM_KEYS → single-trait path, no geom trait. + fig = apl.Figure(figsize=(300, 200)) + ax = fig.add_subplot(apl.GridSpec(1, 1)[0, 0]) + p = ax.plot(np.sin(np.linspace(0, 6, 64))) + assert not fig.has_trait(f"panel_{p._id}_geom") + view = json.loads(getattr(fig, f"panel_{p._id}_json")) + assert "data_b64" in view # geometry stays inline for non-split plots diff --git a/anyplotlib/tests/test_layouts/test_gridspec.py b/anyplotlib/tests/test_layouts/test_gridspec.py new file mode 100644 index 00000000..5c604a85 --- /dev/null +++ b/anyplotlib/tests/test_layouts/test_gridspec.py @@ -0,0 +1,1141 @@ +""" +tests/test_gridspec.py +====================== + +Tests for GridSpec / SubplotSpec indexing, the figure sizing pipeline +(_compute_cell_sizes), and per-panel plot-area alignment. + +Sizing contract (all measured at the *canvas* level, before PAD margins): + - All panels in the same grid column have the same canvas width (pw). + - All panels in the same grid row have the same canvas height (ph). + - Grid tracks are pure ratio math — no aspect-locking. + col_px[i] = fig_width * width_ratios[i] / sum(width_ratios) + row_px[r] = fig_height * height_ratios[r] / sum(height_ratios) + - For N equal-ratio columns inside figsize (fw, fh): + each column width == fw / N (within 1 px rounding). + - width_ratios / height_ratios scale the tracks proportionally. + - The total figure area is not exceeded: sum(col tracks) <= fw, + sum(row tracks) <= fh. + - Images are rendered "contain" (letterboxed) in JS — the Python layout + engine never modifies tracks because of image content. + +Alignment contract (inner plot-area coordinates, shared PAD constants): + - PAD_L=58 PAD_R=12 PAD_T=12 PAD_B=42 + - The inner plot/image area for any panel kind is: + x=PAD_L, y=PAD_T, w=pw-PAD_L-PAD_R, h=ph-PAD_T-PAD_B + - All panels in the same column share pw → same left/right edges. + - All panels in the same row share ph → same top/bottom edges. +""" + +from __future__ import annotations + +import json +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.figure import Figure +from anyplotlib.figure import GridSpec, SubplotSpec +from anyplotlib.axes import Axes # noqa: F401 + +# PAD constants must match figure_esm.js (used in panel-alignment tests) +PAD_L, PAD_R, PAD_T, PAD_B = 58, 12, 12, 42 + + +# ───────────────────────────────────────────────────────────────────────────── +# Helpers +# ───────────────────────────────────────────────────────────────────────────── + +def _sizes(fig: Figure) -> dict[str, tuple[int, int]]: + """Return {panel_id: (panel_width, panel_height)} from layout_json.""" + layout = json.loads(fig.layout_json) + return {s["id"]: (s["panel_width"], s["panel_height"]) + for s in layout["panel_specs"]} + + +def _specs(fig: Figure) -> list[dict]: + return json.loads(fig.layout_json)["panel_specs"] + + +def _layout(fig: Figure) -> dict: + return json.loads(fig.layout_json) + + +def approx(a, b, tol=1): + """True when two integer pixel values are within `tol` pixels.""" + return abs(a - b) <= tol + + +# ───────────────────────────────────────────────────────────────────────────── +# Part 1 – GridSpec / SubplotSpec indexing +# ───────────────────────────────────────────────────────────────────────────── + +class TestGridSpecIndexing: + + def test_integer_index(self): + gs = GridSpec(3, 3) + s = gs[1, 2] + assert s.row_start == 1 and s.row_stop == 2 + assert s.col_start == 2 and s.col_stop == 3 + + def test_negative_index(self): + gs = GridSpec(3, 4) + s = gs[-1, -2] + assert s.row_start == 2 and s.row_stop == 3 # last row + assert s.col_start == 2 and s.col_stop == 3 # second-to-last col + + def test_full_slice(self): + gs = GridSpec(2, 4) + s = gs[0, :] # entire first row + assert s.row_start == 0 and s.row_stop == 1 + assert s.col_start == 0 and s.col_stop == 4 + + def test_partial_slice(self): + gs = GridSpec(3, 4) + s = gs[1, 1:3] + assert s.row_start == 1 and s.row_stop == 2 + assert s.col_start == 1 and s.col_stop == 3 + + def test_row_span(self): + gs = GridSpec(4, 2) + s = gs[1:3, 0] + assert s.row_start == 1 and s.row_stop == 3 + assert s.col_start == 0 and s.col_stop == 1 + + def test_full_row_and_col_span(self): + gs = GridSpec(3, 3) + s = gs[:, :] + assert s.row_start == 0 and s.row_stop == 3 + assert s.col_start == 0 and s.col_stop == 3 + + def test_last_row_full_col_span(self): + """gs[-1, :] should select the last row across all columns.""" + gs = GridSpec(3, 4) + s = gs[-1, :] + assert s.row_start == 2 and s.row_stop == 3 + assert s.col_start == 0 and s.col_stop == 4 + + def test_multi_row_multi_col_span(self): + """gs[0:2, 1:3] spans rows 0–1 and cols 1–2.""" + gs = GridSpec(4, 4) + s = gs[0:2, 1:3] + assert s.row_start == 0 and s.row_stop == 2 + assert s.col_start == 1 and s.col_stop == 3 + + # --- error cases --- + + def test_slice_step_raises(self): + gs = GridSpec(3, 3) + with pytest.raises(ValueError, match="step"): + _ = gs[0, 0:3:2] + + def test_out_of_bounds_int_row_raises(self): + gs = GridSpec(2, 2) + with pytest.raises(IndexError): + _ = gs[5, 0] + + def test_out_of_bounds_int_col_raises(self): + gs = GridSpec(2, 2) + with pytest.raises(IndexError): + _ = gs[0, 10] + + def test_out_of_bounds_negative_raises(self): + gs = GridSpec(2, 2) + with pytest.raises(IndexError): + _ = gs[-5, 0] + + def test_empty_slice_raises(self): + """A slice that produces no cells (start >= stop) must raise.""" + gs = GridSpec(3, 3) + with pytest.raises(IndexError): + _ = gs[2:2, 0] # start == stop → empty + + def test_bad_index_raises(self): + gs = GridSpec(2, 2) + with pytest.raises(IndexError): + _ = gs[0] # must be 2-tuple + + def test_wrong_index_type_raises(self): + gs = GridSpec(2, 2) + with pytest.raises(IndexError): + _ = gs["a", 0] + + def test_wrong_width_ratios_length_raises(self): + with pytest.raises(ValueError, match="width_ratios"): + GridSpec(2, 3, width_ratios=[1, 2]) # length 2 ≠ ncols 3 + + def test_wrong_height_ratios_length_raises(self): + with pytest.raises(ValueError, match="height_ratios"): + GridSpec(2, 3, height_ratios=[1, 2, 3]) # length 3 ≠ nrows 2 + + def test_default_ratios_are_equal(self): + gs = GridSpec(2, 3) + assert gs.width_ratios == [1, 1, 1] + assert gs.height_ratios == [1, 1] + + def test_custom_ratios_stored(self): + gs = GridSpec(2, 3, width_ratios=[2, 1, 1], height_ratios=[3, 1]) + assert gs.width_ratios == [2, 1, 1] + assert gs.height_ratios == [3, 1] + + def test_subplot_spec_parent_gs(self): + """SubplotSpec must reference the GridSpec it came from.""" + gs = GridSpec(2, 2) + s = gs[0, 1] + assert s._gs is gs + + def test_subplot_spec_repr(self): + gs = GridSpec(2, 2) + s = gs[0, 1] + r = repr(s) + assert "0:1" in r and "1:2" in r + + def test_gridspec_repr(self): + gs = GridSpec(3, 4) + assert "3" in repr(gs) and "4" in repr(gs) + + +# ───────────────────────────────────────────────────────────────────────────── +# Part 2 – subplots() convenience API +# ───────────────────────────────────────────────────────────────────────────── + +class TestSubplotsAPI: + + def test_1x1_returns_scalar_axes(self): + fig, ax = apl.subplots(1, 1) + assert isinstance(ax, Axes) + + def test_1xN_returns_1d_array(self): + fig, axs = apl.subplots(1, 3) + assert axs.shape == (3,) + assert all(isinstance(a, Axes) for a in axs) + + def test_Nx1_returns_1d_array(self): + fig, axs = apl.subplots(3, 1) + assert axs.shape == (3,) + + def test_NxM_returns_2d_array(self): + fig, axs = apl.subplots(2, 3) + assert axs.shape == (2, 3) + + def test_axes_specs_match_positions(self): + fig, axs = apl.subplots(2, 3) + for r in range(2): + for c in range(3): + ax = axs[r, c] + assert ax._spec.row_start == r + assert ax._spec.col_start == c + + def test_figure_nrows_ncols(self): + fig, _ = apl.subplots(3, 4) + assert fig._nrows == 3 and fig._ncols == 4 + + def test_figsize_stored(self): + fig, _ = apl.subplots(1, 1, figsize=(800, 600)) + assert fig.fig_width == 800 and fig.fig_height == 600 + + def test_width_ratios_forwarded(self): + fig, _ = apl.subplots(1, 3, width_ratios=[2, 1, 1]) + assert fig._width_ratios == [2, 1, 1] + + def test_height_ratios_forwarded(self): + fig, _ = apl.subplots(3, 1, height_ratios=[1, 2, 1]) + assert fig._height_ratios == [1, 2, 1] + + def test_sharex_stored(self): + fig, _ = apl.subplots(2, 1, sharex=True) + assert fig._sharex is True + + def test_gridspec_kw_width_ratios(self): + """gridspec_kw={'width_ratios': ...} should work like width_ratios=.""" + fig1, _ = apl.subplots(1, 2, width_ratios=[2, 1], figsize=(300, 100)) + fig2, _ = apl.subplots(1, 2, gridspec_kw={"width_ratios": [2, 1]}, figsize=(300, 100)) + assert fig1._width_ratios == fig2._width_ratios == [2, 1] + + def test_gridspec_kw_height_ratios(self): + fig1, _ = apl.subplots(2, 1, height_ratios=[3, 1], figsize=(100, 400)) + fig2, _ = apl.subplots(2, 1, gridspec_kw={"height_ratios": [3, 1]}, figsize=(100, 400)) + assert fig1._height_ratios == fig2._height_ratios == [3, 1] + + +# ───────────────────────────────────────────────────────────────────────────── +# Part 3 – _compute_cell_sizes: equal-ratio grids (no images) +# ───────────────────────────────────────────────────────────────────────────── + +class TestEqualRatioSizing: + """1D-only panels, equal ratios: each track should be fw/ncols × fh/nrows.""" + + def test_1x1_1d(self): + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + v = ax.plot(np.zeros(10)) + pw, ph = _sizes(fig)[v._id] + assert pw == 400 and ph == 300 + + def test_2x1_equal_heights(self): + fig, axs = apl.subplots(2, 1, figsize=(400, 600)) + v0 = axs[0].plot(np.zeros(10)) + v1 = axs[1].plot(np.zeros(10)) + s = _sizes(fig) + pw0, ph0 = s[v0._id] + pw1, ph1 = s[v1._id] + assert pw0 == pw1, f"widths should match: {pw0} vs {pw1}" + assert approx(ph0, 300) and approx(ph1, 300), \ + f"each row height should be 300, got {ph0}, {ph1}" + + def test_1x2_equal_widths(self): + fig, axs = apl.subplots(1, 2, figsize=(600, 300)) + v0 = axs[0].plot(np.zeros(10)) + v1 = axs[1].plot(np.zeros(10)) + s = _sizes(fig) + pw0, ph0 = s[v0._id] + pw1, ph1 = s[v1._id] + assert ph0 == ph1, f"heights should match: {ph0} vs {ph1}" + assert approx(pw0, 300) and approx(pw1, 300), \ + f"each column width should be 300, got {pw0}, {pw1}" + + def test_3x3_equal_all(self): + fig, axs = apl.subplots(3, 3, figsize=(600, 600)) + # Attach 1D plots to all 9 cells + plots = [[axs[r, c].plot(np.zeros(10)) for c in range(3)] for r in range(3)] + s = _sizes(fig) + for r in range(3): + for c in range(3): + pw, ph = s[plots[r][c]._id] + assert approx(pw, 200), f"[{r},{c}] pw={pw}, expected 200" + assert approx(ph, 200), f"[{r},{c}] ph={ph}, expected 200" + + def test_total_width_not_exceeded(self): + fig, axs = apl.subplots(1, 3, figsize=(500, 200)) + plots = [axs[c].plot(np.zeros(10)) for c in range(3)] + s = _sizes(fig) + total_w = sum(s[p._id][0] for p in plots) + assert total_w <= 500 + 3, f"total_w={total_w} exceeds figsize width 500" + + def test_total_height_not_exceeded(self): + fig, axs = apl.subplots(3, 1, figsize=(200, 500)) + plots = [axs[r].plot(np.zeros(10)) for r in range(3)] + s = _sizes(fig) + total_h = sum(s[p._id][1] for p in plots) + assert total_h <= 500 + 3, f"total_h={total_h} exceeds figsize height 500" + + +# ───────────────────────────���───────────────────────────────────────────────── +# Part 4 – _compute_cell_sizes: width_ratios / height_ratios +# ───────────────────────────────────────────────────────────────────────────── + +class TestRatioSizing: + """Verify that width/height ratios correctly scale the tracks.""" + + def test_2col_2to1_width_ratio(self): + """Left column 2× wider than right column.""" + fig, axs = apl.subplots(1, 2, figsize=(600, 200), + width_ratios=[2, 1]) + v0 = axs[0].plot(np.zeros(10)) + v1 = axs[1].plot(np.zeros(10)) + s = _sizes(fig) + pw0 = s[v0._id][0] + pw1 = s[v1._id][0] + # expected: pw0 = 400, pw1 = 200 + assert approx(pw0, 400, tol=2), f"left pw={pw0}, expected 400" + assert approx(pw1, 200, tol=2), f"right pw={pw1}, expected 200" + assert approx(pw0, 2 * pw1, tol=2), f"pw0 should be 2×pw1: {pw0} vs {pw1}" + + def test_2row_3to1_height_ratio(self): + """Top row 3× taller than bottom row.""" + fig, axs = apl.subplots(2, 1, figsize=(200, 800), + height_ratios=[3, 1]) + v0 = axs[0].plot(np.zeros(10)) + v1 = axs[1].plot(np.zeros(10)) + s = _sizes(fig) + ph0 = s[v0._id][1] + ph1 = s[v1._id][1] + assert approx(ph0, 600, tol=2), f"top ph={ph0}, expected 600" + assert approx(ph1, 200, tol=2), f"bottom ph={ph1}, expected 200" + assert approx(ph0, 3 * ph1, tol=3), f"ph0 should be 3×ph1: {ph0} vs {ph1}" + + def test_3col_equal_after_normalisation(self): + """Ratios [2, 2, 2] → same as [1, 1, 1] → equal widths.""" + fig_eq, axs_eq = apl.subplots(1, 3, figsize=(600, 100)) + fig_rat, axs_rat = apl.subplots(1, 3, figsize=(600, 100), + width_ratios=[2, 2, 2]) + for i in range(3): + axs_eq[i].plot(np.zeros(5)) + axs_rat[i].plot(np.zeros(5)) + + s_eq = sorted(pw for pw, ph in _sizes(fig_eq).values()) + s_rat = sorted(pw for pw, ph in _sizes(fig_rat).values()) + for a, b in zip(s_eq, s_rat): + assert approx(a, b, tol=1), f"equal vs scaled ratio: {a} vs {b}" + + def test_ratios_reflected_in_layout_json(self): + """layout_json must carry the correct ratios.""" + fig, _ = apl.subplots(2, 2, + width_ratios=[1, 3], + height_ratios=[2, 1]) + layout = _layout(fig) + assert layout["width_ratios"] == [1, 3] + assert layout["height_ratios"] == [2, 1] + + def test_nrows_ncols_in_layout_json(self): + fig, _ = apl.subplots(3, 4) + layout = _layout(fig) + assert layout["nrows"] == 3 + assert layout["ncols"] == 4 + + def test_panel_row_col_indices_in_layout_json(self): + """Each panel spec must carry the correct row/col start-stop.""" + fig, axs = apl.subplots(2, 3, figsize=(600, 400)) + for r in range(2): + for c in range(3): + axs[r, c].plot(np.zeros(5)) + specs = {(s["row_start"], s["col_start"]): s for s in _specs(fig)} + for r in range(2): + for c in range(3): + s = specs[(r, c)] + assert s["row_start"] == r and s["row_stop"] == r + 1 + assert s["col_start"] == c and s["col_stop"] == c + 1 + + +# ───────────────────────────────────────────────────────────────────────────── +# Part 5 – _compute_cell_sizes: 2D panels obey pure ratio math (no aspect-lock) +# ───────────────────────────────────────────────────────────────────────────── + +class Test2DPanelLayout: + """2D panels must receive exactly the canvas size their grid cell dictates. + + Images are rendered "contain" (letterboxed) by the JS renderer, so the + Python layout engine never shrinks tracks to match image aspect ratios. + """ + + def test_2d_panel_gets_full_cell_width(self): + """A 2D panel's canvas width equals the grid-ratio column width.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + v = ax.imshow(np.zeros((128, 128))) + pw, ph = _sizes(fig)[v._id] + assert pw == 400, f"expected pw=400, got {pw}" + assert ph == 300, f"expected ph=300, got {ph}" + + def test_2d_nonsquare_canvas_from_nonsquare_figsize(self): + """Non-square figsize → non-square canvas even for a square image.""" + fig, ax = apl.subplots(1, 1, figsize=(600, 200)) + v = ax.imshow(np.zeros((128, 128))) + pw, ph = _sizes(fig)[v._id] + assert pw == 600 and ph == 200, f"expected 600×200, got {pw}×{ph}" + + def test_wide_image_does_not_shrink_canvas(self): + """Wide image (2:1) in a square cell — canvas stays square.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 400)) + v = ax.imshow(np.zeros((128, 256))) # H=128, W=256 + pw, ph = _sizes(fig)[v._id] + assert pw == 400 and ph == 400, f"expected 400×400, got {pw}×{ph}" + + def test_tall_image_does_not_shrink_canvas(self): + """Tall image (1:2) in a square cell — canvas stays square.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 400)) + v = ax.imshow(np.zeros((256, 128))) # H=256, W=128 + pw, ph = _sizes(fig)[v._id] + assert pw == 400 and ph == 400, f"expected 400×400, got {pw}×{ph}" + + def test_2d_and_1d_same_row_same_height(self): + """2D and 1D panels in the same row must have the same canvas height.""" + fig, axs = apl.subplots(1, 2, figsize=(800, 400)) + v2d = axs[0].imshow(np.zeros((128, 128))) + v1d = axs[1].plot(np.zeros(256)) + s = _sizes(fig) + ph2d = s[v2d._id][1] + ph1d = s[v1d._id][1] + assert ph2d == ph1d, \ + f"same-row panels must have equal height: 2D={ph2d}, 1D={ph1d}" + + def test_2d_and_1d_same_col_same_width(self): + """2D and 1D panels in the same column must have the same canvas width.""" + fig, axs = apl.subplots(2, 1, figsize=(400, 800)) + v2d = axs[0].imshow(np.zeros((128, 128))) + v1d = axs[1].plot(np.zeros(256)) + s = _sizes(fig) + pw2d = s[v2d._id][0] + pw1d = s[v1d._id][0] + assert pw2d == pw1d, \ + f"same-col panels must have equal width: 2D={pw2d}, 1D={pw1d}" + + def test_image_does_not_affect_sibling_panel_size(self): + """Adding an image to one panel must NOT change a sibling panel's dimensions. + + This is the key regression test for the old aspect-lock bug: + a square image in row-0 of a height_ratios=[2,1] layout used to + shrink the shared column from 800 px to 333 px. + """ + fig, axs = apl.subplots(2, 1, figsize=(800, 600), + height_ratios=[2, 1]) + v2d = axs[0].imshow(np.zeros((256, 256))) + v1d = axs[1].plot(np.zeros(10)) + s = _sizes(fig) + pw2d, ph2d = s[v2d._id] + pw1d, ph1d = s[v1d._id] + # Both panels must share the full figure width + assert pw2d == 800, f"2D panel width should be 800, got {pw2d}" + assert pw1d == 800, f"1D panel width should be 800, got {pw1d}" + # Heights follow height_ratios=[2,1] → 400 and 200 + assert approx(ph2d, 400, tol=2), f"2D panel height should be ~400, got {ph2d}" + assert approx(ph1d, 200, tol=2), f"1D panel height should be ~200, got {ph1d}" + + def test_two_2d_panels_same_col_same_width(self): + """Two 2D panels with different aspect ratios in the same column + must both get the column width — no convergence loop needed.""" + fig, axs = apl.subplots(2, 1, figsize=(400, 800)) + vA = axs[0].imshow(np.zeros((128, 128))) # square + vB = axs[1].imshow(np.zeros((128, 64))) # wide + s = _sizes(fig) + pwA, phA = s[vA._id] + pwB, phB = s[vB._id] + assert pwA == pwB == 400, \ + f"Both 2D panels in same col must have pw=400: {pwA}, {pwB}" + + def test_minimum_canvas_size_floor(self): + """Even a tiny figsize must produce canvas size ≥ 64 px.""" + fig, ax = apl.subplots(1, 1, figsize=(10, 10)) + v = ax.imshow(np.zeros((128, 128))) + pw, ph = _sizes(fig)[v._id] + assert pw >= 64 and ph >= 64, f"min size floor: {pw}×{ph}" + + +# ───────────────────────────────────────────────────────────────────────────── +# Part 6 – layout_json structure and live update +# ───────────────────────────────────────────────────────────────────────────── + +class TestLayoutJson: + + def test_layout_json_has_required_keys(self): + fig, _ = apl.subplots(2, 2) + layout = _layout(fig) + for key in ("nrows", "ncols", "width_ratios", "height_ratios", + "fig_width", "fig_height", "panel_specs", "share_groups"): + assert key in layout, f"missing key '{key}' in layout_json" + + def test_panel_specs_has_required_keys(self): + fig, ax = apl.subplots(1, 1) + ax.plot(np.zeros(5)) + spec = _specs(fig)[0] + for key in ("id", "kind", "row_start", "row_stop", + "col_start", "col_stop", "panel_width", "panel_height"): + assert key in spec, f"missing key '{key}' in panel_spec" + + def test_panel_kind_1d(self): + fig, ax = apl.subplots(1, 1) + ax.plot(np.zeros(5)) + assert _specs(fig)[0]["kind"] == "1d" + + def test_panel_kind_2d(self): + fig, ax = apl.subplots(1, 1) + ax.imshow(np.zeros((32, 32))) + assert _specs(fig)[0]["kind"] == "2d" + + def test_sharex_group_in_layout(self): + fig, axs = apl.subplots(2, 1, sharex=True) + axs[0].plot(np.zeros(5)) + axs[1].plot(np.zeros(5)) + layout = _layout(fig) + assert "x" in layout["share_groups"], "sharex=True must produce 'x' share group" + group = layout["share_groups"]["x"][0] + assert len(group) == 2 + + def test_sharey_group_in_layout(self): + fig, axs = apl.subplots(1, 2, sharey=True) + axs[0].plot(np.zeros(5)) + axs[1].plot(np.zeros(5)) + layout = _layout(fig) + assert "y" in layout["share_groups"] + + def test_no_share_groups_when_false(self): + fig, axs = apl.subplots(2, 1) + axs[0].plot(np.zeros(5)) + axs[1].plot(np.zeros(5)) + layout = _layout(fig) + assert layout["share_groups"] == {} + + def test_layout_updates_on_resize(self): + """fig_width/fig_height change must propagate into layout_json.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.plot(np.zeros(5)) + fig.fig_width = 800 + fig.fig_height = 600 + layout = _layout(fig) + assert layout["fig_width"] == 800 + assert layout["fig_height"] == 600 + pw, ph = list(_sizes(fig).values())[0] + assert pw == 800 and ph == 600 + + def test_panel_sizes_update_after_adding_second_panel(self): + """ + Add a second panel after the first. Both must get updated sizes + (the column or row track must be recalculated). + """ + fig, axs = apl.subplots(2, 1, figsize=(400, 400)) + v0 = axs[0].plot(np.zeros(5)) + # At this point only one panel is registered + s_before = _sizes(fig)[v0._id] + # Add second panel + v1 = axs[1].plot(np.zeros(5)) + s = _sizes(fig) + ph0 = s[v0._id][1] + ph1 = s[v1._id][1] + assert ph0 == ph1, f"after adding 2nd panel, row heights must equalise: {ph0} vs {ph1}" + assert approx(ph0, 200, tol=2), f"each row should be 200 px, got {ph0}" + + def test_panel_count_in_layout(self): + fig, axs = apl.subplots(2, 3, figsize=(600, 400)) + for r in range(2): + for c in range(3): + axs[r, c].plot(np.zeros(5)) + assert len(_specs(fig)) == 6 + + def test_figure_repr(self): + fig, _ = apl.subplots(2, 3, figsize=(600, 400)) + r = repr(fig) + assert "2x3" in r + + def test_get_axes_order(self): + """get_axes() must return axes sorted row-major (top-left → bottom-right).""" + fig, axs = apl.subplots(2, 2, figsize=(400, 400)) + for r in range(2): + for c in range(2): + axs[r, c].plot(np.zeros(5)) + ordered = fig.get_axes() + positions = [(a._spec.row_start, a._spec.col_start) for a in ordered] + assert positions == sorted(positions), f"axes not in row-major order: {positions}" + + +# ───────────────────────────────────────────────────────────────────────────── +# Part 7 – edge cases +# ───────────────────────────────────────────────────────────────────────────── + +class TestEdgeCases: + + def test_single_row_many_cols(self): + fig, axs = apl.subplots(1, 5, figsize=(500, 100)) + plots = [axs[c].plot(np.zeros(5)) for c in range(5)] + s = _sizes(fig) + widths = [s[p._id][0] for p in plots] + heights = [s[p._id][1] for p in plots] + # All same height + assert len(set(heights)) == 1, f"all heights must be equal: {heights}" + # Each ~100 px wide + for w in widths: + assert approx(w, 100, tol=2), f"width {w} should be ≈100" + + def test_single_col_many_rows(self): + fig, axs = apl.subplots(5, 1, figsize=(100, 500)) + plots = [axs[r].plot(np.zeros(5)) for r in range(5)] + s = _sizes(fig) + widths = [s[p._id][0] for p in plots] + heights = [s[p._id][1] for p in plots] + assert len(set(widths)) == 1, f"all widths must be equal: {widths}" + for h in heights: + assert approx(h, 100, tol=2), f"height {h} should be ≈100" + + def test_add_subplot_by_int(self): + """add_subplot(int) should map correctly to row/col.""" + fig = Figure(2, 3, figsize=(600, 400)) + ax = fig.add_subplot(4) # index 4 → row=1, col=1 + assert ax._spec.row_start == 1 + assert ax._spec.col_start == 1 + + def test_add_subplot_by_tuple(self): + fig = Figure(2, 3, figsize=(600, 400)) + ax = fig.add_subplot((0, 2)) + assert ax._spec.row_start == 0 + assert ax._spec.col_start == 2 + + def test_add_subplot_by_subplot_spec(self): + fig = Figure(3, 3, figsize=(300, 300)) + gs = GridSpec(3, 3) + spec = gs[1:3, 0:2] + ax = fig.add_subplot(spec) + assert ax._spec.row_start == 1 + assert ax._spec.row_stop == 3 + assert ax._spec.col_start == 0 + assert ax._spec.col_stop == 2 + + def test_replacing_plot_preserves_panel_id(self): + """Calling imshow/plot a second time on the same Axes must reuse panel id.""" + fig, ax = apl.subplots(1, 1) + v1 = ax.plot(np.zeros(5)) + pid1 = v1._id + v2 = ax.imshow(np.zeros((32, 32))) + pid2 = v2._id + assert pid1 == pid2, "replacing plot must reuse the same panel id" + + def test_2d_canvas_equals_cell_allocation(self): + """Non-square figsize with a square image → canvas equals the full cell + (no aspect-lock shrinking). The image is letterboxed by the JS renderer.""" + fig, ax = apl.subplots(1, 1, figsize=(600, 300)) + v = ax.imshow(np.zeros((128, 128))) + pw, ph = _sizes(fig)[v._id] + assert pw == 600 and ph == 300, \ + f"canvas should equal full figsize 600×300, got {pw}×{ph}" + + def test_layout_json_is_valid_json(self): + fig, axs = apl.subplots(2, 2, figsize=(400, 400)) + for r in range(2): + for c in range(2): + axs[r, c].plot(np.zeros(5)) + # Should not raise + json.loads(fig.layout_json) + + def test_add_subplot_bad_type_raises(self): + fig = Figure(2, 2, figsize=(200, 200)) + with pytest.raises(TypeError): + fig.add_subplot("bad") + + def test_add_subplot_by_subplot_spec_is_identity(self): + """add_subplot(SubplotSpec) must use the spec exactly — no re-wrapping.""" + fig = Figure(3, 3, figsize=(300, 300)) + gs = GridSpec(3, 3) + spec = gs[1:3, 0:2] + ax = fig.add_subplot(spec) + assert ax._spec is spec # same object, not a copy + + def test_figure_add_subplot_with_gridspec_typical_workflow(self): + """Mirror the typical matplotlib workflow: + gs = GridSpec(2, 2); fig.add_subplot(gs[0, :]); etc.""" + fig = Figure(2, 2, figsize=(400, 400)) + gs = GridSpec(2, 2) + ax_top = fig.add_subplot(gs[0, :]) # top row, full width + ax_bl = fig.add_subplot(gs[1, 0]) # bottom-left + ax_br = fig.add_subplot(gs[1, 1]) # bottom-right + assert ax_top._spec.col_start == 0 and ax_top._spec.col_stop == 2 + assert ax_bl._spec.row_start == 1 and ax_bl._spec.col_start == 0 + assert ax_br._spec.row_start == 1 and ax_br._spec.col_start == 1 + + def test_figsize_in_layout_json(self): + fig, ax = apl.subplots(1, 1, figsize=(777, 555)) + ax.plot(np.zeros(5)) + layout = _layout(fig) + assert layout["fig_width"] == 777 + assert layout["fig_height"] == 555 + + +# ───────────────────────────────────────────────────────────────────────────── +# Part 8 – Panel alignment +# ───────────────────────────────────────────────────────────────────────────── + +def _plot_area(pw: int, ph: int) -> tuple[int, int, int, int]: + """Return (x, y, w, h) of the inner plot/image area for any panel kind. + + Both 1-D and 2-D panels use the same PAD constants in figure_esm.js, + so as long as Python assigns the same (pw, ph) to sibling panels they + are guaranteed to be pixel-aligned inside the shared canvas grid cell. + """ + return PAD_L, PAD_T, pw - PAD_L - PAD_R, ph - PAD_T - PAD_B + + +class TestPanelAlignment: + """Same-row / same-column panels must share canvas dimensions and + therefore produce identical inner plot-area coordinates.""" + + # ── two-row, one-column ─────────────────────────────────────────────── + + def test_2row_1col_same_width(self): + fig, axs = apl.subplots(2, 1, figsize=(600, 600)) + v2d = axs[0].imshow(np.random.rand(128, 128)) + v1d = axs[1].plot(np.sin(np.linspace(0, 6, 256))) + s = _sizes(fig) + pw2d = s[v2d._id][0] + pw1d = s[v1d._id][0] + assert pw2d == pw1d, ( + f"Panels in same column must have equal width: 2D={pw2d}, 1D={pw1d}" + ) + + def test_2row_1col_left_edge_aligned(self): + """Left edge of the 2D image area and 1D plot area must both be PAD_L.""" + fig, axs = apl.subplots(2, 1, figsize=(600, 600)) + v2d = axs[0].imshow(np.random.rand(128, 128)) + v1d = axs[1].plot(np.sin(np.linspace(0, 6, 256))) + s = _sizes(fig) + x2d = _plot_area(*s[v2d._id])[0] + x1d = _plot_area(*s[v1d._id])[0] + assert x2d == x1d == PAD_L, ( + f"Left edge must be PAD_L={PAD_L}: 2D={x2d}, 1D={x1d}" + ) + + def test_2row_1col_plot_area_widths_equal(self): + """Plot-area widths must match when panels share a column.""" + fig, axs = apl.subplots(2, 1, figsize=(600, 600)) + v2d = axs[0].imshow(np.random.rand(128, 128)) + v1d = axs[1].plot(np.sin(np.linspace(0, 6, 256))) + s = _sizes(fig) + w2d = _plot_area(*s[v2d._id])[2] + w1d = _plot_area(*s[v1d._id])[2] + assert w2d == w1d, f"Plot area widths: 2D={w2d}, 1D={w1d}" + + # ── one-row, two-column ─────────────────────────────────────────────── + + def test_1row_2col_same_height(self): + fig, axs = apl.subplots(1, 2, figsize=(800, 400)) + v2d = axs[0].imshow(np.random.rand(64, 64)) + v1d = axs[1].plot(np.cos(np.linspace(0, 6, 256))) + s = _sizes(fig) + ph2d = s[v2d._id][1] + ph1d = s[v1d._id][1] + assert ph2d == ph1d, ( + f"Panels in same row must have equal height: 2D={ph2d}, 1D={ph1d}" + ) + + def test_1row_2col_top_bottom_aligned(self): + """Top and bottom y-coordinates of plot areas must match across the row.""" + fig, axs = apl.subplots(1, 2, figsize=(800, 400)) + v2d = axs[0].imshow(np.random.rand(64, 64)) + v1d = axs[1].plot(np.cos(np.linspace(0, 6, 256))) + s = _sizes(fig) + y2d, h2d = _plot_area(*s[v2d._id])[1], _plot_area(*s[v2d._id])[3] + y1d, h1d = _plot_area(*s[v1d._id])[1], _plot_area(*s[v1d._id])[3] + assert y2d == y1d == PAD_T, f"Top y: 2D={y2d}, 1D={y1d}" + assert h2d == h1d, f"Plot area heights: 2D={h2d}, 1D={h1d}" + + # ── 2D panel canvas equals its grid cell ───────────────────────────── + + def test_square_image_gets_square_canvas(self): + """A 128×128 image in a 500×500 figsize → canvas is 500×500 (pw == ph). + Images are letterboxed in JS; the Python layout never changes the cell.""" + fig, axs = apl.subplots(1, 1, figsize=(500, 500)) + v2d = axs.imshow(np.random.rand(128, 128)) + pw, ph = _sizes(fig)[v2d._id] + assert pw == ph, f"Square figsize must give pw==ph: pw={pw}, ph={ph}" + + def test_wide_image_canvas_equals_cell(self): + """A 2:1 image in a square cell gets a square canvas — no aspect-lock.""" + fig, axs = apl.subplots(1, 1, figsize=(512, 512)) + v2d = axs.imshow(np.random.rand(128, 256)) # w=256, h=128 + pw, ph = _sizes(fig)[v2d._id] + assert pw == 512 and ph == 512, ( + f"Canvas should equal full figsize 512×512, got {pw}×{ph}" + ) + + # ── non-square 2D panel plus 1D panel — column width consistent ─────── + + def test_nonsquare_2d_and_1d_same_column(self): + """A tall non-square image in a 2-row, 1-col layout must not affect the + 1D panel's canvas width — both must equal the column track width.""" + fig, axs = apl.subplots(2, 1, figsize=(600, 800)) + v2d = axs[0].imshow(np.random.rand(256, 128)) # tall image + v1d = axs[1].plot(np.random.rand(256)) + s = _sizes(fig) + pw2d = s[v2d._id][0] + pw1d = s[v1d._id][0] + assert pw2d == pw1d, ( + f"Same-column panels must have equal width: 2D={pw2d}, 1D={pw1d}" + ) + + # ── plot-area dimensions are positive ───────────────────────────────── + + def test_plot_areas_positive(self): + fig, axs = apl.subplots(2, 1, figsize=(400, 400)) + v2d = axs[0].imshow(np.random.rand(64, 64)) + v1d = axs[1].plot(np.random.rand(128)) + for pid, (pw, ph) in _sizes(fig).items(): + x, y, w, h = _plot_area(pw, ph) + assert w > 0, f"Panel {pid}: plot area width must be positive, got {w}" + assert h > 0, f"Panel {pid}: plot area height must be positive, got {h}" + + +# ───────────────────────────────────────────────────────────────────────────── +# Part 9 – Figure + GridSpec workflow (bare Figure auto-syncs to GridSpec) +# ───────────────────────────────────────────────────────────────────────────── + +class TestFigureGridSpecWorkflow: + """Tests for the Figure + GridSpec workflow where Figure is created without + explicit nrows/ncols and auto-syncs its grid from the parent GridSpec. + + The typical pattern under test:: + + gs = GridSpec(2, 2, height_ratios=[3, 1]) + fig = Figure(figsize=(800, 600)) # defaults to nrows=1, ncols=1 + ax = fig.add_subplot(gs[0, :]) # Figure adopts 2×2 grid from gs + + Without the auto-sync, panels at row_start≥1 would get ph=0 (floored to 64) + because the Figure only knows about 1 row track. + """ + + def test_auto_sync_nrows_from_gridspec(self): + """Figure auto-updates _nrows when GridSpec has more rows.""" + gs = GridSpec(2, 1) + fig = Figure(figsize=(400, 400)) + fig.add_subplot(gs[0, 0]) + fig.add_subplot(gs[1, 0]) + assert fig._nrows == 2, f"nrows should auto-sync to 2, got {fig._nrows}" + assert fig._ncols == 1 + + def test_auto_sync_ncols_from_gridspec(self): + """Figure auto-updates _ncols when GridSpec has more columns.""" + gs = GridSpec(1, 3) + fig = Figure(figsize=(600, 200)) + fig.add_subplot(gs[0, 0]) + fig.add_subplot(gs[0, 1]) + fig.add_subplot(gs[0, 2]) + assert fig._ncols == 3, f"ncols should auto-sync to 3, got {fig._ncols}" + assert fig._nrows == 1 + + def test_auto_sync_height_ratios_from_gridspec(self): + """height_ratios from the GridSpec are adopted into the Figure.""" + gs = GridSpec(2, 1, height_ratios=[3, 1]) + fig = Figure(figsize=(400, 800)) + fig.add_subplot(gs[0, 0]) + assert fig._height_ratios == [3, 1], ( + f"height_ratios should be [3, 1], got {fig._height_ratios}" + ) + + def test_auto_sync_width_ratios_from_gridspec(self): + """width_ratios from the GridSpec are adopted into the Figure.""" + gs = GridSpec(1, 2, width_ratios=[2, 1]) + fig = Figure(figsize=(600, 200)) + fig.add_subplot(gs[0, 0]) + assert fig._width_ratios == [2, 1], ( + f"width_ratios should be [2, 1], got {fig._width_ratios}" + ) + + def test_gridspec_height_ratios_applied_to_sizes(self): + """Panels at correct heights according to GridSpec height_ratios.""" + gs = GridSpec(2, 1, height_ratios=[3, 1]) + fig = Figure(figsize=(400, 800)) + v0 = fig.add_subplot(gs[0, 0]).plot(np.zeros(10)) + v1 = fig.add_subplot(gs[1, 0]).plot(np.zeros(10)) + s = _sizes(fig) + ph0 = s[v0._id][1] + ph1 = s[v1._id][1] + assert approx(ph0, 600, tol=2), ( + f"top panel should be 600px (3/4 of 800), got {ph0}" + ) + assert approx(ph1, 200, tol=2), ( + f"bottom panel should be 200px (1/4 of 800), got {ph1}" + ) + assert approx(ph0, 3 * ph1, tol=3), ( + f"3:1 height ratio not met: {ph0} vs {ph1}" + ) + + def test_gridspec_width_ratios_applied_to_sizes(self): + """Panels at correct widths according to GridSpec width_ratios.""" + gs = GridSpec(1, 2, width_ratios=[2, 1]) + fig = Figure(figsize=(600, 200)) + v0 = fig.add_subplot(gs[0, 0]).plot(np.zeros(10)) + v1 = fig.add_subplot(gs[0, 1]).plot(np.zeros(10)) + s = _sizes(fig) + pw0 = s[v0._id][0] + pw1 = s[v1._id][0] + assert approx(pw0, 400, tol=2), ( + f"left panel should be 400px (2/3 of 600), got {pw0}" + ) + assert approx(pw1, 200, tol=2), ( + f"right panel should be 200px (1/3 of 600), got {pw1}" + ) + + def test_two_spectra_side_by_side_not_squished(self): + """Two 1D spectra side by side must each get half the figure width.""" + gs = GridSpec(1, 2) + fig = Figure(figsize=(800, 300)) + v0 = fig.add_subplot(gs[0, 0]).plot(np.zeros(100)) + v1 = fig.add_subplot(gs[0, 1]).plot(np.zeros(100)) + s = _sizes(fig) + pw0, ph0 = s[v0._id] + pw1, ph1 = s[v1._id] + assert approx(pw0, 400, tol=2), ( + f"left spectrum should be 400px wide, got {pw0}" + ) + assert approx(pw1, 400, tol=2), ( + f"right spectrum should be 400px wide, got {pw1}" + ) + assert ph0 == ph1 == 300, ( + f"both spectra should be 300px tall: {ph0}, {ph1}" + ) + # Inner plot area must be substantial (not 64px-floor squished) + inner_w = pw0 - PAD_L - PAD_R + assert inner_w > 200, ( + f"inner plot width should be >200px, got {inner_w} " + f"(panel was squished if ≤64)" + ) + + def test_image_and_two_spectra_correct_ratios(self): + """Image spanning top row (3×), two spectra below (1×) side by side. + + This is the canonical use-case the bug report describes: when using + GridSpec with a bare Figure, the second-row spectra used to get floored + to 64px because Figure._height_ratios had only 1 track. + """ + gs = GridSpec(2, 2, height_ratios=[3, 1]) + fig = Figure(figsize=(800, 800)) + v_img = fig.add_subplot(gs[0, :]).imshow(np.zeros((64, 64))) + v_sp1 = fig.add_subplot(gs[1, 0]).plot(np.zeros(100)) + v_sp2 = fig.add_subplot(gs[1, 1]).plot(np.zeros(100)) + s = _sizes(fig) + + pw_img, ph_img = s[v_img._id] + pw_sp1, ph_sp1 = s[v_sp1._id] + pw_sp2, ph_sp2 = s[v_sp2._id] + + # Image spans full width + assert pw_img == 800, f"image should span full width 800, got {pw_img}" + # Image gets 3/4 of height = 600px + assert approx(ph_img, 600, tol=2), ( + f"image should be 600px tall (3/4 of 800), got {ph_img}" + ) + # Each spectrum gets half width + assert approx(pw_sp1, 400, tol=2), ( + f"left spectrum width should be 400, got {pw_sp1}" + ) + assert approx(pw_sp2, 400, tol=2), ( + f"right spectrum width should be 400, got {pw_sp2}" + ) + # Spectra get 1/4 of height = 200px (not 64px floor!) + assert approx(ph_sp1, 200, tol=2), ( + f"spectrum height should be 200px (1/4 of 800), not 64 floor, got {ph_sp1}" + ) + assert ph_sp1 == ph_sp2, ( + f"both spectra must have the same height: {ph_sp1} vs {ph_sp2}" + ) + + def test_explicit_figure_dims_beat_smaller_gridspec(self): + """When Figure has explicit nrows/ncols >= GridSpec, Figure values win.""" + gs = GridSpec(2, 1, height_ratios=[1, 1]) # equal ratios + fig = Figure(2, 1, figsize=(400, 800), height_ratios=[3, 1]) # explicit 3:1 + v0 = fig.add_subplot(gs[0, 0]).plot(np.zeros(10)) + v1 = fig.add_subplot(gs[1, 0]).plot(np.zeros(10)) + s = _sizes(fig) + ph0 = s[v0._id][1] + ph1 = s[v1._id][1] + # Figure's [3:1] must win over GridSpec's [1:1] + assert approx(ph0, 600, tol=2), ( + f"Figure's 3:1 ratio must be preserved: top={ph0}, expected 600" + ) + assert approx(ph1, 200, tol=2), ( + f"Figure's 3:1 ratio must be preserved: bottom={ph1}, expected 200" + ) + + def test_layout_json_nrows_ncols_after_auto_sync(self): + """layout_json must reflect the auto-synced nrows/ncols.""" + gs = GridSpec(3, 2) + fig = Figure(figsize=(600, 600)) + fig.add_subplot(gs[0, 0]).plot(np.zeros(5)) + fig.add_subplot(gs[1, 0]).plot(np.zeros(5)) + fig.add_subplot(gs[2, 0]).plot(np.zeros(5)) + layout = _layout(fig) + assert layout["nrows"] == 3, ( + f"layout_json nrows should be 3, got {layout['nrows']}" + ) + assert layout["ncols"] == 2, ( + f"layout_json ncols should be 2, got {layout['ncols']}" + ) + + def test_second_row_panel_not_floored_to_64(self): + """Regression: panel at row_start=1 with a 1-row Figure used to be floored to 64px.""" + gs = GridSpec(2, 1) + fig = Figure(figsize=(400, 400)) + _ = fig.add_subplot(gs[0, 0]).plot(np.zeros(5)) + v1 = fig.add_subplot(gs[1, 0]).plot(np.zeros(5)) + s = _sizes(fig) + ph1 = s[v1._id][1] + assert ph1 > 64, ( + f"Row-1 panel must NOT be floored to 64px; got ph={ph1}. " + "This indicates the Figure failed to auto-sync its nrows from the GridSpec." + ) + assert approx(ph1, 200, tol=2), ( + f"Row-1 panel should be 200px (half of 400), got {ph1}" + ) + + def test_three_row_gridspec_all_panels_correct_height(self): + """All three panels in a 3-row GridSpec (equal ratios) get 1/3 of height.""" + gs = GridSpec(3, 1) + fig = Figure(figsize=(400, 600)) + plots = [fig.add_subplot(gs[r, 0]).plot(np.zeros(5)) for r in range(3)] + s = _sizes(fig) + for i, v in enumerate(plots): + ph = s[v._id][1] + assert approx(ph, 200, tol=2), ( + f"Panel {i} should be 200px (1/3 of 600), got {ph}" + ) + + def test_spanning_subplot_correct_size(self): + """gs[0, :] spanning all columns must get the full figure width.""" + gs = GridSpec(2, 3, height_ratios=[2, 1]) + fig = Figure(figsize=(900, 600)) + v_top = fig.add_subplot(gs[0, :]).plot(np.zeros(10)) # spans 3 cols + v_bl = fig.add_subplot(gs[1, 0]).plot(np.zeros(10)) + v_bm = fig.add_subplot(gs[1, 1]).plot(np.zeros(10)) + v_br = fig.add_subplot(gs[1, 2]).plot(np.zeros(10)) + s = _sizes(fig) + + pw_top, ph_top = s[v_top._id] + assert pw_top == 900, f"spanning subplot should be full width 900, got {pw_top}" + assert approx(ph_top, 400, tol=2), ( + f"spanning subplot should be 400px (2/3 of 600), got {ph_top}" + ) + + # Bottom row: each panel = 300px wide, 200px tall + for label, v in [("bottom-left", v_bl), ("bottom-mid", v_bm), ("bottom-right", v_br)]: + pw, ph = s[v._id] + assert approx(pw, 300, tol=2), f"{label} width should be 300, got {pw}" + assert approx(ph, 200, tol=2), f"{label} height should be 200, got {ph}" + + +# ───────────────────────────────────────────────────────────────────────────── +# subplots_adjust +# ───────────────────────────────────────────────────────────────────────────── + +class TestSubplotsAdjust: + + def test_hspace_in_layout_json(self): + fig, _ = apl.subplots(2, 1, figsize=(400, 400)) + fig.subplots_adjust(hspace=0.3) + layout = _layout(fig) + assert abs(layout['hspace'] - 0.3) < 1e-9 + + def test_wspace_in_layout_json(self): + fig, _ = apl.subplots(1, 2, figsize=(400, 200)) + fig.subplots_adjust(wspace=0.2) + layout = _layout(fig) + assert abs(layout['wspace'] - 0.2) < 1e-9 + + def test_defaults_are_none(self): + fig, _ = apl.subplots(2, 2, figsize=(400, 400)) + layout = _layout(fig) + assert layout['hspace'] is None + assert layout['wspace'] is None + + def test_both_together(self): + fig, _ = apl.subplots(2, 2, figsize=(600, 600)) + fig.subplots_adjust(hspace=0.15, wspace=0.25) + layout = _layout(fig) + assert abs(layout['hspace'] - 0.15) < 1e-9 + assert abs(layout['wspace'] - 0.25) < 1e-9 + + def test_retriggers_layout_push(self): + fig, _ = apl.subplots(2, 1, figsize=(400, 400)) + before = fig.layout_json + fig.subplots_adjust(hspace=0.1) + assert fig.layout_json != before + + +# =========================================================================== +# hspace / wspace initial-value contract +# =========================================================================== + +class TestHspaceWspaceInitialState: + def test_initial_hspace_is_none(self): + """Before subplots_adjust the internal value is None (browser 4px default).""" + fig, _ = apl.subplots(2, 2) + assert fig._hspace is None + assert fig._wspace is None + + def test_subplots_adjust_zero_stores_zero(self): + """subplots_adjust(hspace=0.0) must store 0.0, not None.""" + fig, _ = apl.subplots(2, 1) + fig.subplots_adjust(hspace=0.0, wspace=0.0) + assert fig._hspace == 0.0 + assert fig._wspace == 0.0 + + def test_subplots_adjust_zero_appears_in_layout(self): + fig, _ = apl.subplots(2, 2) + fig.subplots_adjust(hspace=0.0, wspace=0.0) + layout = json.loads(fig.layout_json) + assert layout["hspace"] == pytest.approx(0.0) + assert layout["wspace"] == pytest.approx(0.0) + + diff --git a/anyplotlib/tests/test_layouts/test_inset.py b/anyplotlib/tests/test_layouts/test_inset.py new file mode 100644 index 00000000..7c67bd33 --- /dev/null +++ b/anyplotlib/tests/test_layouts/test_inset.py @@ -0,0 +1,386 @@ +""" +Tests for InsetAxes — floating overlay inset panels. + +Unit tests +---------- +Covers: + - Creation via fig.add_inset() + - layout_json inset_specs content + - All four corners + - Multi-inset stacking (same corner) + - State transitions (minimize / maximize / restore) + - Python-side property inset_state + - _on_event dispatch for on_inset_state_change + - pcolormesh and 1D insets + - Invalid corner raises ValueError + - Figure resize keeps inset fracs correct + - plot._id registered in _plots_map + +Visual regression tests +----------------------- +Pixel-accurate rendering checks for inset panels in a headless Chromium +browser. Each test renders a deterministic Figure and compares it against +a golden PNG in ``tests/baselines/``. + +Generate / refresh baselines:: + + uv run pytest tests/test_layouts/test_inset.py --update-baselines -v + +Normal CI run (fails on regression):: + + uv run pytest tests/test_layouts/test_inset.py -v +""" +from __future__ import annotations + +import json +import pathlib + +import numpy as np +import pytest +import anyplotlib as apl +from anyplotlib.axes import InsetAxes + + +# ── helpers (unit tests) ────────────────────────────────────────────────────── + +def _make_fig(): + fig, ax = apl.subplots(1, 1, figsize=(640, 480)) + ax.imshow(np.zeros((64, 64))) + return fig + + +def _inset_spec(fig, plot_id): + layout = json.loads(fig.layout_json) + return next(s for s in layout["inset_specs"] if s["id"] == plot_id) + + +# ── creation ───────────────────────────────────────────────────────────────── + +def test_add_inset_returns_inset_axes(): + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3, corner="top-right", title="T") + assert isinstance(inset, InsetAxes) + + +def test_inset_imshow_returns_plot2d(): + from anyplotlib import Plot2D + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3) + plot = inset.imshow(np.zeros((32, 32))) + assert isinstance(plot, Plot2D) + + +def test_inset_plot_returns_plot1d(): + from anyplotlib import Plot1D + fig = _make_fig() + inset = fig.add_inset(0.3, 0.2, corner="bottom-left") + plot = inset.plot(np.zeros(64)) + assert isinstance(plot, Plot1D) + + +def test_inset_pcolormesh_returns_plotmesh(): + from anyplotlib import PlotMesh + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3, corner="bottom-right") + plot = inset.pcolormesh(np.zeros((8, 8)), + np.linspace(0, 1, 9), np.linspace(0, 1, 9)) + assert isinstance(plot, PlotMesh) + + +# ── layout JSON ────────────────────────────────────────────────────────────── + +def test_inset_spec_in_layout_json(): + fig = _make_fig() + inset = fig.add_inset(0.25, 0.25, corner="top-left", title="Phase") + plot = inset.imshow(np.zeros((32, 32))) + + layout = json.loads(fig.layout_json) + assert "inset_specs" in layout + assert len(layout["inset_specs"]) == 1 + spec = layout["inset_specs"][0] + assert spec["id"] == plot._id + assert spec["kind"] == "2d" + assert spec["corner"] == "top-left" + assert spec["title"] == "Phase" + assert spec["w_frac"] == pytest.approx(0.25) + assert spec["h_frac"] == pytest.approx(0.25) + assert spec["inset_state"] == "normal" + + +def test_multiple_insets_in_layout(): + fig = _make_fig() + for corner in ("top-right", "top-left", "bottom-right", "bottom-left"): + inset = fig.add_inset(0.2, 0.2, corner=corner, title=corner) + inset.imshow(np.zeros((16, 16))) + + layout = json.loads(fig.layout_json) + assert len(layout["inset_specs"]) == 4 + corners = {s["corner"] for s in layout["inset_specs"]} + assert corners == {"top-right", "top-left", "bottom-right", "bottom-left"} + + +def test_inset_panel_width_height_computed_from_fracs(): + fig = _make_fig() # 640×480 + inset = fig.add_inset(0.25, 0.30, corner="top-right") + inset.imshow(np.zeros((32, 32))) + + spec = _inset_spec(fig, inset._plot._id) + assert spec["panel_width"] == max(64, round(640 * 0.25)) + assert spec["panel_height"] == max(64, round(480 * 0.30)) + + +def test_inset_registered_in_plots_map(): + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3) + plot = inset.imshow(np.zeros((32, 32))) + assert plot._id in fig._plots_map + assert plot._id in fig._insets_map + + +# ── stacking (same corner) ─────────────────────────────────────────────────── + +def test_two_insets_same_corner(): + fig = _make_fig() + i1 = fig.add_inset(0.25, 0.25, corner="top-right", title="A") + i1.imshow(np.zeros((32, 32))) + i2 = fig.add_inset(0.25, 0.25, corner="top-right", title="B") + i2.imshow(np.zeros((32, 32))) + + layout = json.loads(fig.layout_json) + tr = [s for s in layout["inset_specs"] if s["corner"] == "top-right"] + assert len(tr) == 2 + + +# ── state transitions ──────────────────────────────────────────────────────── + +@pytest.mark.parametrize("method,expected", [ + ("minimize", "minimized"), + ("maximize", "maximized"), + ("restore", "normal"), +]) +def test_state_transition(method, expected): + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3) + plot = inset.imshow(np.zeros((32, 32))) + + getattr(inset, method)() + assert inset.inset_state == expected + assert _inset_spec(fig, plot._id)["inset_state"] == expected + + +def test_state_idempotent(): + """Calling minimize() twice doesn't trigger an extra _push_layout.""" + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3) + inset.imshow(np.zeros((32, 32))) + + inset.minimize() + layout_before = fig.layout_json + inset.minimize() # already minimized — should be a no-op + assert fig.layout_json == layout_before + + +def test_restore_from_minimized(): + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3) + inset.imshow(np.zeros((32, 32))) + inset.minimize() + inset.restore() + assert inset.inset_state == "normal" + + +def test_maximize_then_restore(): + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3) + inset.imshow(np.zeros((32, 32))) + inset.maximize() + assert inset.inset_state == "maximized" + inset.restore() + assert inset.inset_state == "normal" + + +# ── on_inset_state_change event (JS→Python path) ───────────────────────────── + +def test_on_event_inset_state_change(): + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3) + plot = inset.imshow(np.zeros((32, 32))) + + # Simulate a JS button click delivering on_inset_state_change + fig.event_json = json.dumps({ + "source": "js", + "panel_id": plot._id, + "event_type": "inset_state_change", + "new_state": "minimized", + }) + + assert inset.inset_state == "minimized" + assert _inset_spec(fig, plot._id)["inset_state"] == "minimized" + + +def test_on_event_inset_state_restore_via_event(): + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3) + plot = inset.imshow(np.zeros((32, 32))) + inset.minimize() + + fig.event_json = json.dumps({ + "source": "js", + "panel_id": plot._id, + "event_type": "inset_state_change", + "new_state": "normal", + }) + assert inset.inset_state == "normal" + + +# ── figure resize updates inset dimensions ─────────────────────────────────── + +def test_resize_updates_inset_panel_size(): + fig = _make_fig() + inset = fig.add_inset(0.3, 0.3) + plot = inset.imshow(np.zeros((32, 32))) + + fig.fig_width = 800 + fig.fig_height = 600 + + spec = _inset_spec(fig, plot._id) + assert spec["panel_width"] == max(64, round(800 * 0.3)) + assert spec["panel_height"] == max(64, round(600 * 0.3)) + + +# ── corner validation ───────────────────────────────────────────────────────── + +def test_invalid_corner_raises(): + fig = _make_fig() + with pytest.raises(ValueError, match="corner"): + fig.add_inset(0.3, 0.3, corner="centre").imshow(np.zeros((4, 4))) + + +# ── repr ───────────────────────────────────────────────────────────────────── + +def test_repr(): + fig = _make_fig() + inset = fig.add_inset(0.28, 0.28, corner="top-right", title="T") + inset.imshow(np.zeros((32, 32))) + r = repr(inset) + assert "InsetAxes" in r + assert "top-right" in r + assert "normal" in r + + +# ───────────────────────────────────────────────────────────────────────────── +# Visual regression tests +# ───────────────────────────────────────────────────────────────────────────── + +BASELINES = pathlib.Path(__file__).parent / "baselines" + + +def _check(name: str, arr: np.ndarray, update: bool) -> None: + from anyplotlib.tests._png_utils import decode_png, encode_png, compare_arrays + + path = BASELINES / f"{name}.png" + + if update: + BASELINES.mkdir(exist_ok=True) + path.write_bytes(encode_png(arr)) + pytest.skip(f"Baseline updated: {path.name}") + + if not path.exists(): + pytest.skip( + f"No baseline for {name!r} — run with --update-baselines to create it" + ) + + expected = decode_png(path.read_bytes()) + ok, msg = compare_arrays(arr, expected) + assert ok, f"Visual regression [{name}]: {msg}" + + +def _main_fig(): + """640×480 figure with a grayscale 64×64 imshow — the inset host.""" + rng = np.random.default_rng(0) + fig, ax = apl.subplots(1, 1, figsize=(640, 480)) + ax.imshow(rng.uniform(0.0, 1.0, (64, 64)).astype(np.float32)) + return fig + + +class TestInsetVisual: + """Pixel-level visual regression tests for the floating inset panel system.""" + + # ── single inset, normal state ───────────────────────────────────────── + + def test_inset_normal_2d(self, take_screenshot, update_baselines): + """2-D inset in top-right corner, normal state.""" + rng = np.random.default_rng(1) + fig = _main_fig() + inset = fig.add_inset(0.30, 0.30, corner="top-right", title="Zoom") + inset.imshow(rng.uniform(0.0, 1.0, (32, 32)).astype(np.float32), + cmap="viridis") + arr = take_screenshot(fig) + _check("inset_normal_2d", arr, update_baselines) + + def test_inset_minimized(self, take_screenshot, update_baselines): + """Inset collapsed to title bar only after minimize().""" + rng = np.random.default_rng(2) + fig = _main_fig() + inset = fig.add_inset(0.30, 0.30, corner="top-right", title="Phase") + inset.imshow(rng.uniform(0.0, 1.0, (32, 32)).astype(np.float32)) + inset.minimize() + arr = take_screenshot(fig) + _check("inset_minimized", arr, update_baselines) + + def test_inset_maximized(self, take_screenshot, update_baselines): + """Inset expanded to ~72 % of figure after maximize().""" + rng = np.random.default_rng(3) + fig = _main_fig() + inset = fig.add_inset(0.30, 0.30, corner="top-right", title="Detail") + inset.imshow(rng.uniform(0.0, 1.0, (32, 32)).astype(np.float32), + cmap="inferno") + inset.maximize() + arr = take_screenshot(fig) + _check("inset_maximized", arr, update_baselines) + + # ── two insets stacked in the same corner ────────────────────────────── + + def test_inset_stacked(self, take_screenshot, update_baselines): + """Two insets sharing top-right corner stack with constant gap.""" + rng = np.random.default_rng(4) + fig = _main_fig() + i1 = fig.add_inset(0.28, 0.25, corner="top-right", title="A") + i1.imshow(rng.uniform(0.0, 1.0, (32, 32)).astype(np.float32)) + i2 = fig.add_inset(0.28, 0.25, corner="top-right", title="B") + i2.imshow(rng.uniform(0.0, 1.0, (32, 32)).astype(np.float32), + cmap="hot") + arr = take_screenshot(fig) + _check("inset_stacked", arr, update_baselines) + + # ── 1-D line inset ───────────────────────────────────────────────────── + + def test_inset_1d(self, take_screenshot, update_baselines): + """1-D line plot inset in bottom-right corner.""" + rng = np.random.default_rng(5) + fig = _main_fig() + inset = fig.add_inset(0.32, 0.22, corner="bottom-right", + title="Profile") + t = np.linspace(0.0, 2 * np.pi, 128) + inset.plot(np.sin(t) + rng.normal(0, 0.05, 128), + color="#4fc3f7", linewidth=1.5) + arr = take_screenshot(fig) + _check("inset_1d", arr, update_baselines) + + # ── stacked with one minimized (restack test) ────────────────────────── + + def test_inset_stacked_one_minimized(self, take_screenshot, update_baselines): + """Two insets in same corner; first minimized — second shifts up.""" + rng = np.random.default_rng(6) + fig = _main_fig() + i1 = fig.add_inset(0.28, 0.25, corner="bottom-left", title="Min") + i1.imshow(rng.uniform(0.0, 1.0, (32, 32)).astype(np.float32)) + i2 = fig.add_inset(0.28, 0.25, corner="bottom-left", title="Normal") + i2.imshow(rng.uniform(0.0, 1.0, (32, 32)).astype(np.float32), + cmap="viridis") + i1.minimize() + arr = take_screenshot(fig) + _check("inset_stacked_one_minimized", arr, update_baselines) + + + diff --git a/anyplotlib/tests/test_layouts/test_interaction.py b/anyplotlib/tests/test_layouts/test_interaction.py new file mode 100644 index 00000000..002c6163 --- /dev/null +++ b/anyplotlib/tests/test_layouts/test_interaction.py @@ -0,0 +1,1153 @@ +""" +tests/test_interaction.py +========================= + +Real browser interaction tests using headless Chromium (Playwright). + +These tests open the widget's standalone HTML in a real browser, fire +actual mouse events (mousedown → mousemove → mouseup), and verify that: + + * Widget positions update correctly in the panel JSON state. + * ``pointer_move`` events are emitted during a drag. + * ``pointer_up`` events are emitted on mouseup with the correct widget ID. + +All coordinate maths mirrors the JavaScript constants exactly: + PAD_L=58 PAD_R=12 PAD_T=12 PAD_B=42 gridDiv padding=8 px + +For a 400×240 figure the plot rectangle in canvas space is: + r = {x:58, y:12, w:330, h:186} + +Canvas coords → page coords: page_x = canvas_x + 8, page_y = canvas_y + 8 + +Run: + uv run pytest tests/test_interaction.py -v + +""" +from __future__ import annotations + +import json + +import numpy as np +import pytest + +import anyplotlib as apl + +# ── layout constants (must match figure_esm.js) ─────────────────────────── +PAD_L, PAD_R, PAD_T, PAD_B = 58, 12, 12, 42 +GRID_PAD = 8 # gridDiv padding:8px — offset of canvas from page origin + + +# ── coordinate helpers ──────────────────────────────────────────────────── + +def _plot_rect(pw: int, ph: int) -> dict: + """Return the 1-D/1-D-bar plot rectangle (mirrors _plotRect1d in JS).""" + return dict(x=PAD_L, y=PAD_T, w=pw - PAD_L - PAD_R, h=ph - PAD_T - PAD_B) + + +def _data_to_frac(x_val: float, n_samples: int) -> float: + """Data value → [0,1] fraction for a uniform x_axis = arange(n_samples).""" + return x_val / (n_samples - 1) + + +def _frac_to_canvas_x(frac: float, r: dict, + view_x0: float = 0.0, view_x1: float = 1.0) -> float: + """Fraction along the data axis → canvas-space x pixel.""" + return r["x"] + ((frac - view_x0) / (view_x1 - view_x0)) * r["w"] + + +def _val_to_canvas_y(val: float, data_min: float, data_max: float, + r: dict) -> float: + """Data value → canvas-space y pixel (mirrors _valToPy1d in JS).""" + return r["y"] + r["h"] - ((val - data_min) / (data_max - data_min)) * r["h"] + + +def _to_page(canvas_x: float, canvas_y: float) -> tuple[int, int]: + """Canvas-space (x, y) → integer page-space (x, y).""" + return int(round(canvas_x)) + GRID_PAD, int(round(canvas_y)) + GRID_PAD + + +def _rafter(page) -> None: + """Wait one requestAnimationFrame so any pending draw/commit settles.""" + page.evaluate("() => new Promise(r => requestAnimationFrame(r))") + + +def _panel_state(page, panel_id: str) -> dict: + """Return the parsed panel JSON from the model.""" + raw = page.evaluate(f"() => window._aplModel.get('panel_{panel_id}_json')") + return json.loads(raw) + + +def _event(page) -> dict: + """Return the last parsed event_json from the model.""" + raw = page.evaluate("() => window._aplModel.get('event_json')") + return json.loads(raw) + + +# ── shared figure parameters ────────────────────────────────────────────── +FIG_W, FIG_H = 400, 240 +N = 100 # number of data samples; x_axis = [0, 1, …, 99] + +# ── CSS-scale simulation helpers ────────────────────────────────────────── + +def _to_page_scaled(canvas_x: float, canvas_y: float, scale: float) -> tuple[int, int]: + """Canvas coords → page coords when outerDiv has transform:scale(s) origin top-left. + + With transform-origin:top left the canvas (at layout offset GRID_PAD inside + outerDiv) maps to visual page position:: + + page_x = (GRID_PAD + canvas_x) * scale + page_y = (GRID_PAD + canvas_y) * scale + + These are the coordinates a user would actually click on screen. + """ + return ( + int(round((GRID_PAD + canvas_x) * scale)), + int(round((GRID_PAD + canvas_y) * scale)), + ) + + +def _inject_scale(page, scale: float = 0.75) -> float: + """Simulate a narrow Jupyter cell so ``_applyScale()`` computes AND maintains scale. + + The naive approach — directly setting ``transform:scale(s)`` on + ``.apl-outer`` — breaks under drag because every ``model.save_changes()`` + in the standalone shim fires the ``change:layout_json`` callback, which + schedules ``requestAnimationFrame(_applyScale)``. ``_applyScale`` then + recomputes ``s = cellW/nativeW`` and, seeing a cell that looks native-width, + silently **removes** the manually-injected transform. + + Instead, we constrain ``#widget-root`` to ``nativeW * scale`` pixels so + that ``_applyScale`` reads ``cellW = cell_w``, derives the same ``s``, and + keeps re-applying it on every rAF — including those triggered during drag. + + The ``out.style.width = nativeW + 'px'`` pin below is a **defensive + guard** only: since ``.apl-outer`` now carries ``min-width: max-content`` + in its CSS class, ``outerDiv.offsetWidth`` already equals the true native + figure width even when the parent ``scaleWrap`` has been narrowed. + Without ``min-width: max-content`` the ``inline-block`` would shrink to + ``cellW``, making ``_applyScale`` compute ``s = cellW/cellW = 1.0``. + + Returns the actual scale factor applied (the ``s`` passed to the transform). + """ + native_w = page.evaluate( + "() => { const o = document.querySelector('.apl-outer'); return o ? o.offsetWidth : 0; }" + ) + cell_w = max(10, int(round(native_w * scale))) + # 1. Pin outerDiv width explicitly (defensive — min-width:max-content in + # .apl-outer CSS already prevents shrinkage, but this is cheap insurance + # for any edge-case where the class hasn't fully applied yet). + # 2. Constrain #widget-root so _applyScale reads cellW = cell_w on every + # rAF — including those triggered by save_changes() during drag. + # 3. Apply the transform immediately (same formula as _applyScale) so the + # scale takes effect without waiting for a rAF cycle to fire. + actual_s = page.evaluate(f"""() => {{ + const el = document.getElementById('widget-root'); + const out = document.querySelector('.apl-outer'); + if (!out || !el) return 1.0; + // Defensive pin (redundant when min-width:max-content is active) + const nativeW = out.offsetWidth; + out.style.width = nativeW + 'px'; + // Constrain container so _applyScale re-derives s on every rAF + el.style.maxWidth = '{cell_w}px'; + el.style.overflow = 'visible'; + // Apply scale immediately (mirrors _applyScale formula) + const s = Math.min(1.0, {cell_w} / nativeW); + out.style.transformOrigin = 'top left'; + out.style.transform = s < 1 ? 'scale(' + s + ')' : ''; + return s; + }}""") + _rafter(page) + return float(actual_s) if actual_s else scale + + +# ═══════════════════════════════════════════════════════════════════════════ +# VLine drag tests +# ═══════════════════════════════════════════════════════════════════════════ + +class TestVLineDrag1D: + """Drag a VLineWidget on a 1-D panel and verify JS state + events.""" + + def _make_fig(self): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + t = np.arange(N, dtype=float) + plot = ax.plot(np.sin(2 * np.pi * t / N)) + return fig, plot + + def _vline_page_coords(self, x_data: float, ph_override: int = FIG_H) -> tuple[int, int]: + r = _plot_rect(FIG_W, ph_override) + frac = _data_to_frac(x_data, N) + cx = _frac_to_canvas_x(frac, r) + # Use mid-height so the |mx - px| <= 5 hit-test branch fires reliably. + cy = r["y"] + r["h"] // 2 + return _to_page(cx, cy) + + def test_position_changes_after_drag(self, interact_page): + """Dragging the vline left updates its x value in the panel state.""" + fig, plot = self._make_fig() + vline = plot.add_vline_widget(50.0) + panel_id = plot._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + # Click on the vline at x=50, drag left to approximately x=20. + px_start, py_start = self._vline_page_coords(50.0) + frac_end = _data_to_frac(20.0, N) + cx_end = _frac_to_canvas_x(frac_end, r) + px_end, py_end = _to_page(cx_end, r["y"] + r["h"] // 2) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + new_x = _panel_state(page, panel_id)["overlay_widgets"][0]["x"] + assert new_x < 35, f"VLine should have moved left; got x={new_x:.2f}" + assert new_x > 5, f"VLine should not have overshot; got x={new_x:.2f}" + + def test_release_event_widget_id(self, interact_page): + """pointer_up event_json carries the correct widget ID.""" + fig, plot = self._make_fig() + vline = plot.add_vline_widget(50.0) + wid_id = vline._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + px_start, py_start = self._vline_page_coords(50.0) + cx_end = _frac_to_canvas_x(_data_to_frac(30.0, N), r) + px_end, py_end = _to_page(cx_end, r["y"] + r["h"] // 2) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + ev = _event(page) + assert ev["event_type"] == "pointer_up", f"Expected pointer_up, got {ev['event_type']!r}" + assert ev["widget_id"] == wid_id, ( + f"Event widget_id {ev['widget_id']!r} != expected {wid_id!r}" + ) + + def test_on_changed_events_during_drag(self, interact_page): + """pointer_move events are emitted for every mousemove during drag.""" + fig, plot = self._make_fig() + vline = plot.add_vline_widget(50.0) + wid_id = vline._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + # Patch model.set to accumulate event_json writes before the drag. + page.evaluate("""() => { + window._aplAllEvents = []; + const orig = window._aplModel.set.bind(window._aplModel); + window._aplModel.set = (k, v) => { + if (k === 'event_json') { + try { window._aplAllEvents.push(JSON.parse(v)); } catch(_) {} + } + return orig(k, v); + }; + }""") + + px_start, py_start = self._vline_page_coords(50.0) + cx_end = _frac_to_canvas_x(_data_to_frac(25.0, N), r) + px_end, py_end = _to_page(cx_end, r["y"] + r["h"] // 2) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=8) + page.mouse.up() + _rafter(page) + + events = page.evaluate("() => window._aplAllEvents") + + changed = [e for e in events if e.get("event_type") == "pointer_move" and e.get("widget_id") == wid_id] + released = [e for e in events if e.get("event_type") == "pointer_up"] + + assert len(changed) > 0, "Expected at least one pointer_move event with correct widget_id during drag" + assert len(released) >= 1, f"Expected at least one pointer_up, got {len(released)}" + + def test_drag_right_increases_x(self, interact_page): + """Dragging the vline right increases its x value.""" + fig, plot = self._make_fig() + vline = plot.add_vline_widget(20.0) + panel_id = plot._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + px_start, py_start = self._vline_page_coords(20.0) + cx_end = _frac_to_canvas_x(_data_to_frac(60.0, N), r) + px_end, py_end = _to_page(cx_end, r["y"] + r["h"] // 2) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + new_x = _panel_state(page, panel_id)["overlay_widgets"][0]["x"] + assert new_x > 35, f"VLine should have moved right; got x={new_x:.2f}" + assert new_x < 80, f"VLine should not have overshot; got x={new_x:.2f}" + + +# ═══════════════════════════════════════════════════════════════════════════ +# HLine drag tests +# ═══════════════════════════════════════════════════════════════════════════ + +class TestHLineDrag1D: + """Drag an HLineWidget on a 1-D panel and verify JS state.""" + + def _make_fig(self): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + t = np.arange(N, dtype=float) + # Sine so data spans [-1, 1]; padding gives data_min≈-1.1, data_max≈1.1. + plot = ax.plot(np.sin(2 * np.pi * t / N)) + return fig, plot + + def test_drag_up_increases_y(self, interact_page): + """Dragging the hline upward increases its data-y value.""" + fig, plot = self._make_fig() + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + hline = plot.add_hline_widget(0.0) + panel_id = plot._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + # Canvas coords for y=0.0 (mid-range). + cy_start = _val_to_canvas_y(0.0, data_min, data_max, r) + # Use mid-plot x so we're safely inside the plot area. + cx_mid = r["x"] + r["w"] // 2 + px_start, py_start = _to_page(cx_mid, cy_start) + + # Drag up by 40 canvas pixels. + px_end, py_end = _to_page(cx_mid, cy_start - 40) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + new_y = _panel_state(page, panel_id)["overlay_widgets"][0]["y"] + assert new_y > 0.2, f"HLine should have moved up; got y={new_y:.3f}" + assert new_y < data_max, f"HLine should stay within data range; got y={new_y:.3f}" + + def test_drag_down_decreases_y(self, interact_page): + """Dragging the hline downward decreases its data-y value.""" + fig, plot = self._make_fig() + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + hline = plot.add_hline_widget(0.0) + panel_id = plot._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + cy_start = _val_to_canvas_y(0.0, data_min, data_max, r) + cx_mid = r["x"] + r["w"] // 2 + px_start, py_start = _to_page(cx_mid, cy_start) + # Drag down by 40 canvas pixels. + px_end, py_end = _to_page(cx_mid, cy_start + 40) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + new_y = _panel_state(page, panel_id)["overlay_widgets"][0]["y"] + assert new_y < -0.2, f"HLine should have moved down; got y={new_y:.3f}" + assert new_y > data_min, f"HLine should stay within data range; got y={new_y:.3f}" + + def test_release_event_widget_id(self, interact_page): + """pointer_up carries the hline widget ID.""" + fig, plot = self._make_fig() + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + hline = plot.add_hline_widget(0.0) + wid_id = hline._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + cy_start = _val_to_canvas_y(0.0, data_min, data_max, r) + cx_mid = r["x"] + r["w"] // 2 + px_start, py_start = _to_page(cx_mid, cy_start) + px_end, py_end = _to_page(cx_mid, cy_start - 30) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=8) + page.mouse.up() + _rafter(page) + + ev = _event(page) + assert ev["event_type"] == "pointer_up" + assert ev["widget_id"] == wid_id + + +# ═══════════════════════════════════════════════════════════════════════════ +# Point widget drag tests +# ═══════════════════════════════════════════════════════════════════════════ + +class TestPointDrag1D: + """Drag a PointWidget on a 1-D panel — verifies 2-D free movement.""" + + # Hit-test radius for the point handle (HR+4 = 11 px, from the JS). + _HIT_R = 11 + + def _make_fig(self, x_init: float = 50.0, y_init: float = 0.0): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + t = np.arange(N, dtype=float) + plot = ax.plot(np.sin(2 * np.pi * t / N)) + pt = plot.add_point_widget(x_init, y_init) + return fig, plot, pt + + def _point_page_coords(self, x_data: float, y_data: float, + data_min: float, data_max: float) -> tuple[int, int]: + r = _plot_rect(FIG_W, FIG_H) + frac = _data_to_frac(x_data, N) + cx = _frac_to_canvas_x(frac, r) + cy = _val_to_canvas_y(y_data, data_min, data_max, r) + return _to_page(cx, cy) + + def test_drag_changes_both_x_and_y(self, interact_page): + """Dragging the point widget updates both x and y in the panel state.""" + fig, plot, pt = self._make_fig(50.0, 0.0) + panel_id = plot._id + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + # Start on the point at (x=50, y=0). + px_start, py_start = self._point_page_coords(50.0, 0.0, data_min, data_max) + + # Drag to approximately x=30, y=0.4. + cx_end = _frac_to_canvas_x(_data_to_frac(30.0, N), r) + cy_end = _val_to_canvas_y(0.4, data_min, data_max, r) + px_end, py_end = _to_page(cx_end, cy_end) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=12) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + assert ws["x"] < 45, f"Point x should have moved left; got x={ws['x']:.2f}" + assert ws["x"] > 10, f"Point x should not have overshot; got x={ws['x']:.2f}" + assert ws["y"] > 0.1, f"Point y should have moved up; got y={ws['y']:.3f}" + assert ws["y"] < 0.9, f"Point y should not have overshot; got y={ws['y']:.3f}" + + def test_release_event_widget_id(self, interact_page): + """pointer_up event carries the point widget's ID.""" + fig, plot, pt = self._make_fig(50.0, 0.0) + wid_id = pt._id + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + px_start, py_start = self._point_page_coords(50.0, 0.0, data_min, data_max) + cx_end = _frac_to_canvas_x(_data_to_frac(70.0, N), r) + cy_end = _val_to_canvas_y(-0.3, data_min, data_max, r) + px_end, py_end = _to_page(cx_end, cy_end) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + ev = _event(page) + assert ev["event_type"] == "pointer_up" + assert ev["widget_id"] == wid_id + + def test_on_changed_events_during_drag(self, interact_page): + """pointer_move events fire on every mousemove step during drag.""" + fig, plot, pt = self._make_fig(50.0, 0.0) + wid_id = pt._id + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + page.evaluate("""() => { + window._aplAllEvents = []; + const orig = window._aplModel.set.bind(window._aplModel); + window._aplModel.set = (k, v) => { + if (k === 'event_json') { + try { window._aplAllEvents.push(JSON.parse(v)); } catch(_) {} + } + return orig(k, v); + }; + }""") + + px_start, py_start = self._point_page_coords(50.0, 0.0, data_min, data_max) + cx_end = _frac_to_canvas_x(_data_to_frac(30.0, N), r) + cy_end = _val_to_canvas_y(0.5, data_min, data_max, r) + px_end, py_end = _to_page(cx_end, cy_end) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=8) + page.mouse.up() + _rafter(page) + + events = page.evaluate("() => window._aplAllEvents") + changed = [e for e in events if e.get("event_type") == "pointer_move" and e.get("widget_id") == wid_id] + released = [e for e in events if e.get("event_type") == "pointer_up"] + + assert len(changed) > 0, "Expected pointer_move events with correct widget_id during drag" + assert len(released) >= 1, f"Expected at least one pointer_up, got {len(released)}" + + def test_drag_right_and_down(self, interact_page): + """Dragging right+down increases x and decreases y.""" + fig, plot, pt = self._make_fig(30.0, 0.4) + panel_id = plot._id + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + px_start, py_start = self._point_page_coords(30.0, 0.4, data_min, data_max) + cx_end = _frac_to_canvas_x(_data_to_frac(65.0, N), r) + cy_end = _val_to_canvas_y(-0.4, data_min, data_max, r) + px_end, py_end = _to_page(cx_end, cy_end) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=12) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + assert ws["x"] > 50, f"Point x should have moved right; got x={ws['x']:.2f}" + assert ws["y"] < 0.1, f"Point y should have moved down; got y={ws['y']:.3f}" + + def test_drag_outside_plot_clamps_to_boundary(self, interact_page): + """Dragging past the plot edge clamps the point to the plot boundary.""" + fig, plot, pt = self._make_fig(50.0, 0.0) + panel_id = plot._id + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + px_start, py_start = self._point_page_coords(50.0, 0.0, data_min, data_max) + + # Drag far to the right and up — well outside the plot area. + far_right = r["x"] + r["w"] + 80 # 80 px past the right edge + far_up = r["y"] - 60 # 60 px above the top edge + px_end, py_end = _to_page(far_right, far_up) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + # x should be clamped to ≤ the rightmost data value (99). + assert ws["x"] <= N - 1 + 1, ( + f"Point x should be clamped to data range; got x={ws['x']:.2f}" + ) + # y should be clamped to ≤ data_max. + assert ws["y"] <= data_max + 0.01, ( + f"Point y should be clamped to data_max; got y={ws['y']:.3f}" + ) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Range widget drag tests +# ═══════════════════════════════════════════════════════════════════════════ + +class TestRangeDrag1D: + """Drag a RangeWidget's edges and body on a 1-D panel.""" + + def _make_fig(self, x0: float = 20.0, x1: float = 70.0): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.plot(np.zeros(N)) + rw = plot.add_range_widget(x0, x1) + return fig, plot, rw + + def test_right_edge_drag_moves_x1(self, interact_page): + """Dragging the right edge inward decreases x1.""" + fig, plot, rw = self._make_fig(20.0, 70.0) + panel_id = plot._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + # Right-edge canvas x at x1=70. + cx_right = _frac_to_canvas_x(_data_to_frac(70.0, N), r) + cy = r["y"] + r["h"] // 2 + px_start, py_start = _to_page(cx_right, cy) + + # Drag the right edge left to approximately x1=50. + cx_new = _frac_to_canvas_x(_data_to_frac(50.0, N), r) + px_end, py_end = _to_page(cx_new, cy) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + assert ws["x1"] < 65, f"Range right edge should have moved left; got x1={ws['x1']:.2f}" + assert abs(ws["x0"] - 20.0) < 5, ( + f"Range left edge should be ~20 (unchanged); got x0={ws['x0']:.2f}" + ) + + def test_left_edge_drag_moves_x0(self, interact_page): + """Dragging the left edge rightward increases x0.""" + fig, plot, rw = self._make_fig(20.0, 70.0) + panel_id = plot._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + # Left-edge canvas x at x0=20. + cx_left = _frac_to_canvas_x(_data_to_frac(20.0, N), r) + cy = r["y"] + r["h"] // 2 + px_start, py_start = _to_page(cx_left, cy) + + # Drag the left edge right to approximately x0=40. + cx_new = _frac_to_canvas_x(_data_to_frac(40.0, N), r) + px_end, py_end = _to_page(cx_new, cy) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + assert ws["x0"] > 30, f"Range left edge should have moved right; got x0={ws['x0']:.2f}" + assert abs(ws["x1"] - 70.0) < 5, ( + f"Range right edge should be ~70 (unchanged); got x1={ws['x1']:.2f}" + ) + + def test_body_drag_shifts_both_edges(self, interact_page): + """Dragging the range body shifts both x0 and x1 by the same amount.""" + fig, plot, rw = self._make_fig(30.0, 60.0) + panel_id = plot._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + # Body midpoint canvas x (safely inside the body, away from edges). + cx_mid = _frac_to_canvas_x(_data_to_frac(45.0, N), r) + cy = r["y"] + r["h"] // 2 + px_start, py_start = _to_page(cx_mid, cy) + + # Drag right by 33 canvas pixels (≈ 10 data units on a 330-px plot). + px_end, py_end = _to_page(cx_mid + 33, cy) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + # Both edges should have moved right by roughly the same amount. + delta_x0 = ws["x0"] - 30.0 + delta_x1 = ws["x1"] - 60.0 + assert delta_x0 > 2, f"x0 should have moved right; Δx0={delta_x0:.2f}" + assert delta_x1 > 2, f"x1 should have moved right; Δx1={delta_x1:.2f}" + assert abs(delta_x0 - delta_x1) < 3, ( + f"Both edges should shift by the same amount; Δx0={delta_x0:.2f}, Δx1={delta_x1:.2f}" + ) + + def test_release_event_widget_id(self, interact_page): + """pointer_up event carries the range widget's ID.""" + fig, plot, rw = self._make_fig(30.0, 70.0) + wid_id = rw._id + + page = interact_page(fig) + r = _plot_rect(FIG_W, FIG_H) + + cx_right = _frac_to_canvas_x(_data_to_frac(70.0, N), r) + cy = r["y"] + r["h"] // 2 + px_start, py_start = _to_page(cx_right, cy) + px_end, py_end = _to_page(cx_right - 40, cy) + + page.mouse.move(px_start, py_start) + page.mouse.down() + page.mouse.move(px_end, py_end, steps=10) + page.mouse.up() + _rafter(page) + + ev = _event(page) + assert ev["event_type"] == "pointer_up" + assert ev["widget_id"] == wid_id + + +# ═══════════════════════════════════════════════════════════════════════════ +# CSS-scale interaction tests (simulate narrow Jupyter cell) +# ═══════════════════════════════════════════════════════════════════════════ + +class TestScaledInteraction1D: + """Detect the _applyScale coordinate-mismatch bug. + + When ``_applyScale`` applies ``transform:scale(s)`` to ``outerDiv`` (because + the Jupyter output cell is narrower than the figure), every event handler + that computes ``e.clientX - getBoundingClientRect().left`` receives + *visual* coordinates in the range ``[0, pw*s]`` rather than *canvas* + coordinates in ``[0, pw]``. Hit tests then miss by factor ``1/s``. + + Each test below: + * calls ``_inject_scale(page, 0.75)`` to apply the transform directly — + exactly what ``_applyScale`` would do in Jupyter, + * clicks at the **visual** position of the widget handle (what the user + actually sees and clicks on), and + * asserts the widget moved. + + **Expected outcomes:** + * **FAIL** with the current code (hit test misses → position unchanged). + * **PASS** after the event-coordinate fix (``_clientPos`` helper divides + by the scale factor so hit tests receive true canvas coordinates). + """ + + _SCALE = 0.75 + + def test_vline_drag_under_scale(self, interact_page): + """VLine drag at visual position must move the widget under CSS scale.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.plot(np.arange(N, dtype=float)) + vline = plot.add_vline_widget(50.0) + panel_id = plot._id + + page = interact_page(fig) + s = _inject_scale(page, self._SCALE) + r = _plot_rect(FIG_W, FIG_H) + + cx = _frac_to_canvas_x(_data_to_frac(50.0, N), r) + cy = r["y"] + r["h"] // 2 + px_s, py_s = _to_page_scaled(cx, cy, s) + + cx_end = _frac_to_canvas_x(_data_to_frac(20.0, N), r) + px_e, py_e = _to_page_scaled(cx_end, cy, s) + + page.mouse.move(px_s, py_s) + page.mouse.down() + page.mouse.move(px_e, py_e, steps=10) + page.mouse.up() + _rafter(page) + + new_x = _panel_state(page, panel_id)["overlay_widgets"][0]["x"] + assert new_x < 35, ( + f"VLine should have moved left under scale s={s:.2f}; " + f"got x={new_x:.2f} (unchanged=50.0 means hit missed)" + ) + + def test_hline_drag_under_scale(self, interact_page): + """HLine drag at visual position must move the widget under CSS scale.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + t = np.arange(N, dtype=float) + plot = ax.plot(np.sin(2 * np.pi * t / N)) + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + hline = plot.add_hline_widget(0.0) + panel_id = plot._id + + page = interact_page(fig) + s = _inject_scale(page, self._SCALE) + r = _plot_rect(FIG_W, FIG_H) + + cy = _val_to_canvas_y(0.0, data_min, data_max, r) + cx = r["x"] + r["w"] // 2 + px_s, py_s = _to_page_scaled(cx, cy, s) + px_e, py_e = _to_page_scaled(cx, cy - 40, s) # drag up + + page.mouse.move(px_s, py_s) + page.mouse.down() + page.mouse.move(px_e, py_e, steps=10) + page.mouse.up() + _rafter(page) + + new_y = _panel_state(page, panel_id)["overlay_widgets"][0]["y"] + assert new_y > 0.2, ( + f"HLine should have moved up under scale s={s:.2f}; " + f"got y={new_y:.3f} (unchanged=0.0 means hit missed)" + ) + + def test_range_drag_under_scale(self, interact_page): + """Range edge drag at visual position must move x1 under CSS scale.""" + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.plot(np.zeros(N)) + rw = plot.add_range_widget(20.0, 70.0) + panel_id = plot._id + + page = interact_page(fig) + s = _inject_scale(page, self._SCALE) + r = _plot_rect(FIG_W, FIG_H) + + cx_right = _frac_to_canvas_x(_data_to_frac(70.0, N), r) + cy = r["y"] + r["h"] // 2 + px_s, py_s = _to_page_scaled(cx_right, cy, s) + + cx_new = _frac_to_canvas_x(_data_to_frac(50.0, N), r) + px_e, py_e = _to_page_scaled(cx_new, cy, s) + + page.mouse.move(px_s, py_s) + page.mouse.down() + page.mouse.move(px_e, py_e, steps=10) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + assert ws["x1"] < 65, ( + f"Range right edge should have moved under scale s={s:.2f}; " + f"got x1={ws['x1']:.2f} (unchanged=70.0 means hit missed)" + ) + + def test_point_drag_under_scale(self, interact_page): + """Point drag at visual position must move the widget under CSS scale. + + This is the exact failure mode the notebook user experiences: the cyan + handle is visible but unresponsive because the hit-test coordinates are + off by factor 1/s. + """ + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + t = np.arange(N, dtype=float) + plot = ax.plot(np.sin(2 * np.pi * t / N)) + data_min = plot._state["data_min"] + data_max = plot._state["data_max"] + pt = plot.add_point_widget(50.0, 0.0) + panel_id = plot._id + + page = interact_page(fig) + s = _inject_scale(page, self._SCALE) + r = _plot_rect(FIG_W, FIG_H) + + cx_s = _frac_to_canvas_x(_data_to_frac(50.0, N), r) + cy_s = _val_to_canvas_y(0.0, data_min, data_max, r) + px_s, py_s = _to_page_scaled(cx_s, cy_s, s) + + cx_e = _frac_to_canvas_x(_data_to_frac(30.0, N), r) + cy_e = _val_to_canvas_y(0.4, data_min, data_max, r) + px_e, py_e = _to_page_scaled(cx_e, cy_e, s) + + page.mouse.move(px_s, py_s) + page.mouse.down() + page.mouse.move(px_e, py_e, steps=12) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + assert ws["x"] < 45, ( + f"Point x should have moved under scale s={s:.2f}; " + f"got x={ws['x']:.2f} (unchanged=50.0 means hit missed)" + ) + + +# ═══════════════════════════════════════════════════════════════════════════ +# Extra CSS-scale tests — pan, 2D widget drag, bar widget drag +# ═══════════════════════════════════════════════════════════════════════════ + +def _to_page_scaled_2d(ccx: float, ccy: float, scale: float) -> tuple[int, int]: + """Canvas coords for a 2D overlayCanvas → page coords under CSS scale. + + For plain imshow panels (no physical axes) the overlayCanvas starts at + (0, 0) inside the panel wrap, so the only offset is GRID_PAD:: + + page_x = (GRID_PAD + ccx) * scale + page_y = (GRID_PAD + ccy) * scale + """ + return ( + int(round((GRID_PAD + ccx) * scale)), + int(round((GRID_PAD + ccy) * scale)), + ) + + +class TestScaledInteractionExtra: + """Additional scale tests covering pan, 2-D widget drag, and bar chart drag. + + All tests in this class apply ``transform:scale(0.75)`` to ``outerDiv`` + before firing mouse events, mirroring exactly what ``_applyScale()`` does + in a narrow Jupyter cell. + + Expected outcomes: + * **FAIL** with the current code (coordinates off by factor 1/s). + * **PASS** after applying the ``_clientPos`` fix in ``figure_esm.js``. + """ + + _SCALE = 0.75 + + # ── 1D pan under scale ──────────────────────────────────────────────── + + def test_1d_pan_under_scale(self, interact_page): + """Panning by N visual pixels must move the view by N/s canvas pixels. + + The broken code computes ``dx = (e.clientX - panStart.mx) / r.w`` + where the numerator is in visual (CSS-scaled) pixels and the denominator + is in canvas pixels. At s=0.75 this under-pans by factor s. + + Geometry (s=0.75, span=0.4, drag=-165 visual px): + broken nx0 = 0.3 + 165/330 × 0.4 = 0.500 + correct nx0 = 0.3 + 165/247.5 × 0.4 = 0.567 + + The assertion ``view_x0 > 0.52`` cleanly separates the two outcomes. + """ + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.plot(np.zeros(N)) + panel_id = plot._id + + # Pre-zoom to [0.3, 0.7] so there is room to pan rightward. + plot._state["view_x0"] = 0.3 + plot._state["view_x1"] = 0.7 + plot._push() + + page = interact_page(fig) + s = _inject_scale(page, self._SCALE) + r = _plot_rect(FIG_W, FIG_H) + + # Drag left (pan view to the right) starting at the plot mid-point. + cx_start = r["x"] + r["w"] // 2 + cy_mid = r["y"] + r["h"] // 2 + px_s, py_s = _to_page_scaled(cx_start, cy_mid, s) + + # Move 165 visual pixels to the left. + px_e = px_s - 165 + py_e = py_s + + page.mouse.move(px_s, py_s) + page.mouse.down() + page.mouse.move(px_e, py_e, steps=10) + page.mouse.up() + _rafter(page) + + st = _panel_state(page, panel_id) + x0 = st["view_x0"] + assert x0 > 0.52, ( + f"Pan under scale s={s:.2f} under-panned; " + f"got view_x0={x0:.3f} (broken≈0.500, correct≈0.567)" + ) + + # ── 2D crosshair drag under scale ───────────────────────────────────── + + def test_2d_crosshair_drag_under_scale(self, interact_page): + """Crosshair drag at its visual position must move the widget under CSS scale. + + The broken ``_doDrag2d`` and ``_attachEvents2d`` use raw + ``e.clientX - getBoundingClientRect().left``, which is in visual pixels, + while ``_imgToCanvas2d`` works in canvas pixels. At s=0.75 the + initial hit misses by ~47 px (well outside HR+4=13) so the drag is + never started. + """ + rng = np.random.default_rng(42) + img = rng.standard_normal((128, 128)) + + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + plot = ax.imshow(img) + panel_id = plot._id + + # Crosshair centred in the image at (cx=64, cy=64) in image-pixel space. + plot.add_widget("crosshair", cx=64.0, cy=64.0) + + page = interact_page(fig) + s = _inject_scale(page, self._SCALE) + + # At zoom=1, center=(0.5,0.5), image 128×128 in panel 400×240, + # no physical axes so imgW=400, imgH=240: + # fit = min(400/128, 240/128) = 1.875 + # fr = {x:80, y:0, w:240, h:240} + # canvas pos of (64,64): ccx = 80+(64/128)*240 = 200, ccy = 120 + ccx, ccy = 200.0, 120.0 + px_s, py_s = _to_page_scaled_2d(ccx, ccy, s) + + # Drag 40 canvas pixels to the right and down. + px_e, py_e = _to_page_scaled_2d(ccx + 40, ccy + 30, s) + + page.mouse.move(px_s, py_s) + page.mouse.down() + page.mouse.move(px_e, py_e, steps=10) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + assert ws["cx"] > 64 + 5, ( + f"Crosshair cx should have moved right under scale s={s:.2f}; " + f"got cx={ws['cx']:.2f} (unchanged=64.0 means hit missed)" + ) + + # ── bar-chart vline drag under scale ────────────────────────────────── + + def test_bar_vline_drag_under_scale(self, interact_page): + """Bar-chart VLine drag at visual position must move the widget under CSS scale. + + ``_attachEventsBar`` calls ``_ovHitTest1d(e.clientX-rect.left, …)`` + which has the same coordinate bug as ``_attachEvents1d``. + """ + values = np.array([1.0, 3.0, 2.0, 4.0, 2.5]) + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + bar_plot = ax.bar(values) + panel_id = bar_plot._id + + # x_axis = [-0.5, 4.5]; vline at x=2.0 (middle bar) → frac=0.5 → canvas_x=223 + vline = bar_plot.add_vline_widget(2.0) + + page = interact_page(fig) + s = _inject_scale(page, self._SCALE) + r = _plot_rect(FIG_W, FIG_H) + + # Canvas x for vline at data x=2.0 (frac=0.5): PAD_L + 0.5*r.w = 58+165 = 223 + cx_vline = PAD_L + 0.5 * r["w"] + cy_mid = r["y"] + r["h"] // 2 + + # Visual click position under scale + px_s, py_s = _to_page_scaled(cx_vline, cy_mid, s) + + # Drag left to approximately x=0.5 (frac≈0.1 → canvas_x≈91) + cx_end = PAD_L + 0.1 * r["w"] + px_e, py_e = _to_page_scaled(cx_end, cy_mid, s) + + page.mouse.move(px_s, py_s) + page.mouse.down() + page.mouse.move(px_e, py_e, steps=10) + page.mouse.up() + _rafter(page) + + ws = _panel_state(page, panel_id)["overlay_widgets"][0] + assert ws["x"] < 1.5, ( + f"Bar VLine should have moved under scale s={s:.2f}; " + f"got x={ws['x']:.3f} (unchanged=2.0 means hit missed)" + ) + + # ── bar-chart pointer_down under scale ─────────────────────────────────── + + def test_bar_click_under_scale(self, interact_page): + """Bar pointer_down fires with correct bar_index when clicking at the + visual (scaled) position of a bar. + + The test clicks at a position that is correct in *visual* (scaled) + coordinates but would be wrong in unscaled canvas coordinates. + ``_clientPos`` must undo the CSS transform so the hit-test operates + in canvas space. + + Bar geometry (vertical, 5 bars, default bar_width=0.7, FIG_W=400): + slotPx = 330 / 5 = 66 + bar 2 centre_x = 58 + 2.5 × 66 = 223 (canvas px) + bar 2 y-range = [barTopY, basePx] — computed from data_min/data_max + click y = midpoint of bar 2's y-span (safely inside the bar) + """ + values = np.array([1.0, 3.0, 2.0, 4.0, 2.5]) + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + bar_plot = ax.bar(values) + panel_id = bar_plot._id + + # Read axis bounds computed by PlotBar (includes 7 % padding above max). + data_min = bar_plot._state["data_min"] # == 0.0 + data_max = bar_plot._state["data_max"] # ≈ 4.28 + + page = interact_page(fig) + s = _inject_scale(page, self._SCALE) + r = _plot_rect(FIG_W, FIG_H) + + # Bar geometry + n_bars = len(values) + slot_px = r["w"] / n_bars + cx_bar2 = r["x"] + (2 + 0.5) * slot_px # 58 + 165 = 223 + + # y-coordinate: midpoint between bar 2's top and the baseline (bottom) + bar_top_y = _val_to_canvas_y(values[2], data_min, data_max, r) + baseline_y = _val_to_canvas_y(0.0, data_min, data_max, r) + cy_bar2 = (bar_top_y + baseline_y) / 2 # well inside the bar + + # Scaled visual click position + px_click, py_click = _to_page_scaled(cx_bar2, cy_bar2, s) + + page.mouse.click(px_click, py_click) + _rafter(page) + + ev = _event(page) + assert ev.get("event_type") == "pointer_down", ( + f"Expected pointer_down event under scale s={s:.2f}; " + f"got event_type={ev.get('event_type')!r} " + f"(missing means _clientPos failed to undo the CSS transform)" + ) + assert ev.get("bar_index") == 2, ( + f"Expected bar_index=2 under scale s={s:.2f}; " + f"got bar_index={ev.get('bar_index')!r}" + ) + + +# ═══════════════════════════════════════════════════════════════════════════ +# 2D imshow click vs drag tests +# ═══════════════════════════════════════════════════════════════════════════ + +class TestImshow2DClickVsDrag: + """Verify that a short tap on a 2D imshow panel emits ``pointer_down`` while a + longer drag emits only a pan ``pointer_up`` — and not a ``pointer_down``.""" + + def _make_fig(self): + fig, ax = apl.subplots(1, 1, figsize=(FIG_W, FIG_H)) + data = np.arange(64 * 64, dtype=np.float32).reshape(64, 64) + plot = ax.imshow(data) + return fig, plot + + def _img_center_page(self) -> tuple[int, int]: + """Page coordinates of the centre of the image area.""" + # For a 2D panel the canvas fills the full cell; the image is letterboxed + # inside the PAD region. The centre of the drawable area: + cx = PAD_L + (FIG_W - PAD_L - PAD_R) // 2 + cy = PAD_T + (FIG_H - PAD_T - PAD_B) // 2 + return _to_page(cx, cy) + + def test_short_click_emits_on_click(self, interact_page): + """A short mousedown/up without movement fires a ``pointer_down`` event.""" + fig, plot = self._make_fig() + panel_id = plot._id + + page = interact_page(fig) + px, py = self._img_center_page() + + # Single click (no movement) — Playwright's click() is a + # mousedown + mouseup without intermediate moves, so _dist2 == 0 + # and _dt is well within the 350 ms threshold. + page.mouse.click(px, py) + _rafter(page) + + ev = _event(page) + assert ev.get("event_type") == "pointer_down", ( + f"Expected pointer_down from a short tap; got {ev.get('event_type')!r}" + ) + assert "img_x" in ev and "img_y" in ev, ( + "pointer_down event must include img_x and img_y coordinates" + ) + + def test_drag_does_not_emit_on_click(self, interact_page): + """A visible drag (> 5 px) pans the image and must NOT fire ``pointer_down``.""" + fig, plot = self._make_fig() + panel_id = plot._id + + page = interact_page(fig) + px, py = self._img_center_page() + + # Move mouse to start, press, drag 40 px right, release. + page.mouse.move(px, py) + page.mouse.down() + page.mouse.move(px + 40, py, steps=10) + page.mouse.up() + _rafter(page) + + ev = _event(page) + assert ev.get("event_type") != "pointer_down", ( + f"Expected pan (pointer_up), not pointer_down after a drag; " + f"got {ev.get('event_type')!r}" + ) diff --git a/anyplotlib/tests/test_layouts/test_visual.py b/anyplotlib/tests/test_layouts/test_visual.py new file mode 100644 index 00000000..4487dc55 --- /dev/null +++ b/anyplotlib/tests/test_layouts/test_visual.py @@ -0,0 +1,343 @@ +""" +tests/test_visual.py +==================== + +Pixel-level visual regression tests. + +Each test: + 1. Builds a deterministic Figure using the OO API. + 2. Renders it in headless Chromium via the ``take_screenshot`` fixture + (see conftest.py) — the *exact* JS renderer the user sees in a notebook. + 3. Compares the result against a golden PNG in ``tests/baselines/``. + +Workflow +-------- +Generate / refresh baselines (first run or after intentional visual change):: + + uv run pytest tests/test_visual.py --update-baselines -v + +Normal CI run (fails on regression):: + + uv run pytest tests/test_visual.py -v + +Comparison tolerance +-------------------- +* Per-pixel tolerance: 8 DN (≈3 % of 255) on any channel. +* Maximum bad-pixel fraction: 2 % of all pixels. + +These values absorb sub-pixel anti-aliasing differences between Chromium +versions while still catching genuine rendering regressions. +""" +from __future__ import annotations + +import pathlib + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.tests._png_utils import decode_png, encode_png, compare_arrays + +BASELINES = pathlib.Path(__file__).parent.parent / "baselines" + + +# --------------------------------------------------------------------------- +# Helper +# --------------------------------------------------------------------------- + +def _check(name: str, arr: np.ndarray, update: bool) -> None: + """Assert *arr* matches the baseline named *name*, or write it if *update*.""" + path = BASELINES / f"{name}.png" + + if update: + BASELINES.mkdir(exist_ok=True) + path.write_bytes(encode_png(arr)) + pytest.skip(f"Baseline updated: {path.name}") + + if not path.exists(): + pytest.skip( + f"No baseline for {name!r} — run with --update-baselines to create it" + ) + + expected = decode_png(path.read_bytes()) + ok, msg = compare_arrays(arr, expected) + assert ok, f"Visual regression [{name}]: {msg}" + + +# --------------------------------------------------------------------------- +# Tests +# --------------------------------------------------------------------------- + +class TestVisual: + """Pixel-accurate rendering checks for each plot kind.""" + + # ── 2-D image ────────────────────────────────────────────────────────── + + def test_imshow_gradient(self, take_screenshot, update_baselines): + """Grayscale linear gradient — exercises the 2-D colormap + LUT path.""" + fig, ax = apl.subplots(1, 1, figsize=(320, 320)) + data = np.linspace(0.0, 1.0, 64 * 64, dtype=np.float32).reshape(64, 64) + ax.imshow(data) + arr = take_screenshot(fig) + _check("imshow_gradient", arr, update_baselines) + + def test_imshow_checkerboard(self, take_screenshot, update_baselines): + """High-contrast checkerboard — exercises sharp edge rendering.""" + fig, ax = apl.subplots(1, 1, figsize=(256, 256)) + board = np.indices((32, 32)).sum(axis=0) % 2 # 0/1 alternating + ax.imshow(board.astype(np.float32)) + arr = take_screenshot(fig) + _check("imshow_checkerboard", arr, update_baselines) + + def test_imshow_viridis(self, take_screenshot, update_baselines): + """2-D image with viridis colormap — exercises non-gray LUT path.""" + fig, ax = apl.subplots(1, 1, figsize=(320, 256)) + rng = np.random.default_rng(0) + data = rng.uniform(0.0, 1.0, (48, 64)).astype(np.float32) + plot = ax.imshow(data) + plot.set_colormap("viridis") + arr = take_screenshot(fig) + _check("imshow_viridis", arr, update_baselines) + + # ── 1-D line ─────────────────────────────────────────────────────────── + + def test_plot1d_sine(self, take_screenshot, update_baselines): + """Single sine wave — exercises the 1-D line renderer.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 240)) + t = np.linspace(0.0, 2.0 * np.pi, 256) + ax.plot(np.sin(t)) + arr = take_screenshot(fig) + _check("plot1d_sine", arr, update_baselines) + + def test_plot1d_multi(self, take_screenshot, update_baselines): + """Multiple overlaid 1-D lines — exercises add_line() layering.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 240)) + t = np.linspace(0.0, 2.0 * np.pi, 256) + plot = ax.plot(np.sin(t), color="#4fc3f7") + plot.add_line(np.cos(t), color="#ff7043") + arr = take_screenshot(fig) + _check("plot1d_multi", arr, update_baselines) + + def test_plot1d_dashed(self, take_screenshot, update_baselines): + """Dashed primary line — exercises linestyle→setLineDash path.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 240)) + t = np.linspace(0.0, 2.0 * np.pi, 256) + ax.plot(np.sin(t), color="#ff7043", linestyle="dashed", linewidth=2) + arr = take_screenshot(fig) + _check("plot1d_dashed", arr, update_baselines) + + def test_plot1d_alpha(self, take_screenshot, update_baselines): + """Semi-transparent overlapping lines — exercises globalAlpha path.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 240)) + t = np.linspace(0.0, 2.0 * np.pi, 256) + plot = ax.plot(np.sin(t), color="#4fc3f7", alpha=0.4) + plot.add_line(np.cos(t), color="#ff7043", alpha=0.4) + arr = take_screenshot(fig) + _check("plot1d_alpha", arr, update_baselines) + + def test_plot1d_markers(self, take_screenshot, update_baselines): + """Circle markers along a sparse line — exercises marker render path.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 240)) + t = np.linspace(0.0, 2.0 * np.pi, 24) + ax.plot(np.sin(t), color="#4fc3f7", marker="o", markersize=4) + arr = take_screenshot(fig) + _check("plot1d_markers", arr, update_baselines) + + def test_plot1d_all_linestyles(self, take_screenshot, update_baselines): + """All four linestyles on one panel — exercises every dash pattern.""" + fig, ax = apl.subplots(1, 1, figsize=(440, 300)) + t = np.linspace(0.0, 2.0 * np.pi, 256) + plot = ax.plot(np.sin(t), color="#4fc3f7", linestyle="solid", label="solid") + plot.add_line(np.sin(t) + 0.6, color="#ff7043", linestyle="dashed", label="dashed") + plot.add_line(np.sin(t) + 1.2, color="#aed581", linestyle="dotted", label="dotted") + plot.add_line(np.sin(t) + 1.8, color="#ce93d8", linestyle="dashdot", label="dashdot") + arr = take_screenshot(fig) + _check("plot1d_all_linestyles", arr, update_baselines) + + def test_plot1d_marker_symbols(self, take_screenshot, update_baselines): + """All seven marker symbols on one panel.""" + fig, ax = apl.subplots(1, 1, figsize=(440, 380)) + t = np.linspace(0.0, 2.0 * np.pi, 20) + symbols = [("o", "#4fc3f7"), ("s", "#ff7043"), ("^", "#aed581"), + ("v", "#ce93d8"), ("D", "#ffcc02"), ("+", "#80cbc4"), + ("x", "#ef9a9a")] + plot = ax.plot(np.sin(t) - 3.0, color=symbols[0][1], + marker=symbols[0][0], markersize=5, label=symbols[0][0]) + for i, (sym, col) in enumerate(symbols[1:], 1): + plot.add_line(np.sin(t) + (i - 3) * 1.0, color=col, + marker=sym, markersize=5, label=sym) + arr = take_screenshot(fig) + _check("plot1d_marker_symbols", arr, update_baselines) + + # ── pcolormesh ───────────────────────────────────────────────────────── + + def test_pcolormesh_uniform(self, take_screenshot, update_baselines): + """Uniform-grid pcolormesh with sine × cosine pattern.""" + x = np.linspace(0.0, 2.0 * np.pi, 33) # 32 cells → 33 edges + y = np.linspace(0.0, 2.0 * np.pi, 33) + Xc = (x[:-1] + x[1:]) / 2 + Yc = (y[:-1] + y[1:]) / 2 + Z = np.sin(Xc[np.newaxis, :]) * np.cos(Yc[:, np.newaxis]) + fig, ax = apl.subplots(1, 1, figsize=(320, 320)) + ax.pcolormesh(Z, x_edges=x, y_edges=y) + arr = take_screenshot(fig) + _check("pcolormesh_uniform", arr, update_baselines) + + # ── 3-D surface ──────────────────────────────────────────────────────── + + def test_plot3d_surface(self, take_screenshot, update_baselines): + """3-D paraboloid surface — exercises the software 3-D renderer.""" + x = np.linspace(-1.5, 1.5, 24) + y = np.linspace(-1.5, 1.5, 24) + X, Y = np.meshgrid(x, y) + Z = X ** 2 + Y ** 2 + fig, ax = apl.subplots(1, 1, figsize=(320, 320)) + ax.plot_surface(X, Y, Z, colormap="viridis") + arr = take_screenshot(fig) + _check("plot3d_surface", arr, update_baselines) + + # ── bar chart ────────────────────────────────────────────────────────── + + def test_bar_basic(self, take_screenshot, update_baselines): + """Basic vertical bar chart — exercises the bar renderer end-to-end.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.bar(["Jan", "Feb", "Mar", "Apr", "May"], + [42, 55, 48, 61, 37], + color="#4fc3f7") + arr = take_screenshot(fig) + _check("bar_basic", arr, update_baselines) + + # ── multi-panel layout ───────────────────────────────────────────────── + + def test_subplots_2x1(self, take_screenshot, update_baselines): + """Two-row figure: image on top, 1-D line below.""" + fig, axs = apl.subplots(2, 1, figsize=(320, 480)) + data = np.linspace(0.0, 1.0, 32 * 32).reshape(32, 32).astype(np.float32) + axs[0].imshow(data) + t = np.linspace(0.0, 2.0 * np.pi, 128) + axs[1].plot(np.sin(t)) + arr = take_screenshot(fig) + _check("subplots_2x1", arr, update_baselines) + + # ── GridSpec layouts ─────────────────────────────────────────────────── + + def test_gridspec_side_by_side_1d(self, take_screenshot, update_baselines): + """Two 1-D spectra side by side — exercises 1×2 GridSpec layout. + + Verifies that side-by-side spectra are not squished and each occupies + exactly half the figure width with a reasonable inner plot area. + """ + gs = apl.GridSpec(1, 2) + fig = apl.Figure(figsize=(640, 240)) + t = np.linspace(0.0, 2.0 * np.pi, 256) + fig.add_subplot(gs[0, 0]).plot(np.sin(t), color="#4fc3f7") + fig.add_subplot(gs[0, 1]).plot(np.cos(t), color="#ff7043") + arr = take_screenshot(fig) + _check("gridspec_side_by_side_1d", arr, update_baselines) + + def test_gridspec_image_two_spectra(self, take_screenshot, update_baselines): + """Image on top (3×height), two 1-D spectra below (1×height) side by side. + + This is the canonical layout that exposed the squishing bug: bare + Figure + GridSpec with height_ratios caused row-1 panels to be floored + to 64px. The image should occupy 3/4 of the height; each spectrum 1/4. + """ + gs = apl.GridSpec(2, 2, height_ratios=[3, 1]) + fig = apl.Figure(figsize=(480, 480)) + data = np.linspace(0.0, 1.0, 32 * 32).reshape(32, 32).astype(np.float32) + fig.add_subplot(gs[0, :]).imshow(data) + t = np.linspace(0.0, 2.0 * np.pi, 128) + fig.add_subplot(gs[1, 0]).plot(np.sin(t), color="#4fc3f7") + fig.add_subplot(gs[1, 1]).plot(np.cos(t), color="#ff7043") + arr = take_screenshot(fig) + _check("gridspec_image_two_spectra", arr, update_baselines) + + def test_gridspec_height_ratio_image_histogram(self, take_screenshot, update_baselines): + """Image (3×) + histogram (1×) with explicit height_ratios via GridSpec.""" + gs = apl.GridSpec(2, 1, height_ratios=[3, 1]) + fig = apl.Figure(figsize=(400, 400)) + rng = np.random.default_rng(42) + data = rng.uniform(0.0, 1.0, (32, 32)).astype(np.float32) + fig.add_subplot(gs[0, 0]).imshow(data, cmap="viridis") + counts = np.histogram(data.ravel(), bins=32)[0].astype(float) + fig.add_subplot(gs[1, 0]).plot(counts, color="#aed581") + arr = take_screenshot(fig) + _check("gridspec_height_ratio_image_histogram", arr, update_baselines) + + def test_gridspec_3col_equal_spectra(self, take_screenshot, update_baselines): + """Three equal-width 1-D spectra in a single row — 1×3 GridSpec.""" + gs = apl.GridSpec(1, 3) + fig = apl.Figure(figsize=(720, 200)) + rng = np.random.default_rng(7) + t = np.linspace(0.0, 2.0 * np.pi, 200) + colors = ["#4fc3f7", "#ff7043", "#aed581"] + for i, color in enumerate(colors): + noise = rng.normal(scale=0.1, size=len(t)) + fig.add_subplot(gs[0, i]).plot(np.sin(t * (i + 1)) + noise, color=color) + arr = take_screenshot(fig) + _check("gridspec_3col_equal_spectra", arr, update_baselines) + + def test_gridspec_asymmetric_width_ratios(self, take_screenshot, update_baselines): + """2:1 width ratio: wide spectrum left, narrow spectrum right.""" + gs = apl.GridSpec(1, 2, width_ratios=[2, 1]) + fig = apl.Figure(figsize=(480, 200)) + t = np.linspace(0.0, 2.0 * np.pi, 256) + fig.add_subplot(gs[0, 0]).plot(np.sin(t), color="#4fc3f7") + fig.add_subplot(gs[0, 1]).plot(np.cos(t), color="#ff7043") + arr = take_screenshot(fig) + _check("gridspec_asymmetric_width_ratios", arr, update_baselines) + + def test_gridspec_spanning_top_two_bottom(self, take_screenshot, update_baselines): + """Full-width spectrum on top (gs[0, :]), two spectra below (gs[1, 0:2]).""" + gs = apl.GridSpec(2, 2, height_ratios=[2, 1]) + fig = apl.Figure(figsize=(480, 360)) + t = np.linspace(0.0, 4.0 * np.pi, 512) + fig.add_subplot(gs[0, :]).plot(np.sin(t), color="#4fc3f7") + fig.add_subplot(gs[1, 0]).plot(np.sin(2 * t), color="#ff7043") + fig.add_subplot(gs[1, 1]).plot(np.cos(2 * t), color="#aed581") + arr = take_screenshot(fig) + _check("gridspec_spanning_top_two_bottom", arr, update_baselines) + + # ── Phase 4 — labels, title, colorbar label, axis visibility ─────────── + + def test_plot1d_title(self, take_screenshot, update_baselines): + """1-D plot with set_title — title text drawn in top PAD area.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 240)) + p = ax.plot(np.sin(np.linspace(0, 2 * np.pi, 256)), color="#4fc3f7") + p.set_title("Sine Wave") + arr = take_screenshot(fig) + _check("plot1d_title", arr, update_baselines) + + def test_plot1d_axis_off(self, take_screenshot, update_baselines): + """1-D plot with set_axis_off — tick labels hidden.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 240)) + p = ax.plot(np.sin(np.linspace(0, 2 * np.pi, 256)), color="#4fc3f7") + p.set_axis_off() + arr = take_screenshot(fig) + _check("plot1d_axis_off", arr, update_baselines) + + def test_imshow_labels(self, take_screenshot, update_baselines): + """2-D image with x_label, y_label, title, and colorbar_label.""" + fig, ax = apl.subplots(1, 1, figsize=(400, 400)) + x = np.linspace(0.0, 10.0, 64) + p = ax.imshow( + np.random.default_rng(0).uniform(size=(64, 64)), + axes=[x, x], units="nm", + ) + p.set_xlabel("x (nm)") + p.set_ylabel("y (nm)") + p.set_title("Test Image") + p.set_colorbar_visible(True) + p.set_colorbar_label("Intensity") + arr = take_screenshot(fig) + _check("imshow_labels", arr, update_baselines) + + def test_imshow_axis_off(self, take_screenshot, update_baselines): + """2-D image with set_axis_off — axis gutters hidden.""" + fig, ax = apl.subplots(1, 1, figsize=(320, 320)) + x = np.linspace(0.0, 5.0, 32) + p = ax.imshow(np.zeros((32, 32)), axes=[x, x], units="nm") + p.set_axis_off() + arr = take_screenshot(fig) + _check("imshow_axis_off", arr, update_baselines) + diff --git a/anyplotlib/tests/test_markers/__init__.py b/anyplotlib/tests/test_markers/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/anyplotlib/tests/test_markers/test_marker_transforms.py b/anyplotlib/tests/test_markers/test_marker_transforms.py new file mode 100644 index 00000000..25aeb63b --- /dev/null +++ b/anyplotlib/tests/test_markers/test_marker_transforms.py @@ -0,0 +1,224 @@ +""" +tests/test_markers/test_marker_transforms.py +============================================= +Tests for the coordinate transform parameter on marker collections. + +Exercises: transform="data" (default), transform="axes", transform="display", +invalid transform, all add_* methods on both Plot1D and Plot2D, and that +set() preserves the transform. +""" +from __future__ import annotations + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.markers import MarkerGroup + + +def _push_noop(): + pass + + +def _group(mtype, **kwargs): + return MarkerGroup(mtype, "g1", kwargs, _push_noop) + + +def _make_plot2d(): + fig, ax = apl.subplots(1, 1) + return ax.imshow(np.zeros((32, 32))) + + +def _make_plot1d(): + fig, ax = apl.subplots(1, 1) + return ax.plot(np.zeros(32)) + + +# --------------------------------------------------------------------------- +# MarkerGroup — wire-format round-trips +# --------------------------------------------------------------------------- + +class TestTransformWireFormat: + + def test_transform_default_is_data(self): + g = _group("circles", offsets=[[1.0, 2.0]], radius=5) + w = g.to_wire("gid") + assert w["transform"] == "data" + + def test_transform_axes_round_trips(self): + g = _group("texts", offsets=[[0.05, 0.95]], texts=["(3, 7)"], + transform="axes") + w = g.to_wire("gid") + assert w["transform"] == "axes" + + def test_transform_display_round_trips(self): + g = _group("circles", offsets=[[8.0, 8.0]], transform="display") + w = g.to_wire("gid") + assert w["transform"] == "display" + + def test_transform_data_explicit(self): + g = _group("rectangles", offsets=[[0.0, 0.0]], widths=10, heights=10, + transform="data") + w = g.to_wire("gid") + assert w["transform"] == "data" + + def test_all_2d_types_emit_transform(self): + types_and_kwargs = [ + ("circles", dict(offsets=[[1, 2]], radius=5)), + ("arrows", dict(offsets=[[1, 2]], U=1, V=1)), + ("ellipses", dict(offsets=[[1, 2]], widths=4, heights=3)), + ("lines", dict(segments=[[[0, 0], [1, 1]]])), + ("rectangles", dict(offsets=[[1, 2]], widths=4, heights=3)), + ("squares", dict(offsets=[[1, 2]], widths=4)), + ("polygons", dict(vertices_list=[[[0,0],[1,0],[0.5,1]]])), + ("texts", dict(offsets=[[1, 2]], texts=["hi"])), + ] + for mtype, kwargs in types_and_kwargs: + g = _group(mtype, transform="axes", **kwargs) + w = g.to_wire("gid") + assert w["transform"] == "axes", f"Failed for type {mtype!r}" + + def test_1d_types_emit_transform(self): + types_and_kwargs = [ + ("points", dict(offsets=[1.0, 2.0])), + ("vlines", dict(offsets=[1.0, 2.0])), + ("hlines", dict(offsets=[1.0, 2.0])), + ] + for mtype, kwargs in types_and_kwargs: + g = _group(mtype, transform="axes", **kwargs) + w = g.to_wire("gid") + assert w["transform"] == "axes", f"Failed for type {mtype!r}" + + +# --------------------------------------------------------------------------- +# Validation +# --------------------------------------------------------------------------- + +class TestTransformValidation: + + def test_invalid_transform_raises_on_init(self): + with pytest.raises(ValueError, match="transform"): + _group("circles", offsets=[[1, 2]], transform="screen") + + def test_invalid_transform_raises_on_set(self): + g = _group("circles", offsets=[[1, 2]]) + with pytest.raises(ValueError, match="transform"): + g.set(transform="bad") + + def test_valid_transforms_do_not_raise(self): + for tfm in ("data", "axes", "display"): + _group("circles", offsets=[[1, 2]], transform=tfm) # no error + + +# --------------------------------------------------------------------------- +# set() preserves transform +# --------------------------------------------------------------------------- + +class TestTransformPreservedOnSet: + + def test_set_does_not_reset_transform(self): + g = _group("circles", offsets=[[1, 2]], radius=5, transform="axes") + g.set(radius=10) + w = g.to_wire("gid") + assert w["transform"] == "axes" + + def test_set_can_update_transform(self): + g = _group("circles", offsets=[[1, 2]], transform="axes") + g.set(transform="display") + w = g.to_wire("gid") + assert w["transform"] == "display" + + +# --------------------------------------------------------------------------- +# Plot2D add_* methods accept transform kwarg +# --------------------------------------------------------------------------- + +class TestPlot2DTransformKwarg: + + def setup_method(self): + self.plot = _make_plot2d() + + def test_add_circles_transform_axes(self): + g = self.plot.add_circles([[10, 10]], name="c", transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_add_points_transform_axes(self): + g = self.plot.add_points([[10, 10]], name="p", transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_add_texts_transform_axes(self): + g = self.plot.add_texts([[0.05, 0.95]], ["label"], name="t", + transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_add_rectangles_transform_display(self): + g = self.plot.add_rectangles([[5, 5]], widths=10, heights=10, name="r", + transform="display") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "display" + + def test_add_arrows_transform_axes(self): + g = self.plot.add_arrows([[5, 5]], U=1, V=1, name="a", transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_add_ellipses_transform_axes(self): + g = self.plot.add_ellipses([[5, 5]], widths=4, heights=3, name="e", + transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_add_lines_transform_axes(self): + g = self.plot.add_lines([[[0, 0], [1, 1]]], name="l", transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_add_squares_transform_axes(self): + g = self.plot.add_squares([[5, 5]], widths=4, name="s", transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_add_polygons_transform_axes(self): + verts = [[[0, 0], [1, 0], [0.5, 1]]] + g = self.plot.add_polygons(verts, name="pg", transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_default_transform_is_data(self): + g = self.plot.add_texts([[5, 5]], ["hi"], name="t2") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "data" + + +# --------------------------------------------------------------------------- +# Plot1D add_* methods accept transform kwarg +# --------------------------------------------------------------------------- + +class TestPlot1DTransformKwarg: + + def setup_method(self): + self.plot = _make_plot1d() + + def test_add_vlines_transform_axes(self): + self.plot.add_vlines([0.5], name="v", transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_add_hlines_transform_axes(self): + self.plot.add_hlines([0.5], name="h", transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_add_texts_transform_axes(self): + self.plot.add_texts([[0.05, 0.95]], ["label"], name="t", + transform="axes") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "axes" + + def test_default_transform_is_data(self): + self.plot.add_vlines([0.5], name="v2") + wire = self.plot.markers.to_wire_list() + assert wire[0]["transform"] == "data" diff --git a/anyplotlib/tests/test_markers/test_markers.py b/anyplotlib/tests/test_markers/test_markers.py new file mode 100644 index 00000000..1971cb57 --- /dev/null +++ b/anyplotlib/tests/test_markers/test_markers.py @@ -0,0 +1,726 @@ +""" +tests/test_markers.py +===================== + +Tests for the marker system (MarkerGroup, MarkerTypeDict, MarkerRegistry) +and the high-level add_* helpers on Plot2D, Plot1D, and PlotMesh. + +Exercises all marker types from the Examples/Markers gallery: + circles, arrows, ellipses, lines, rectangles, squares, polygons, texts, + points, vlines, hlines. + +Also covers: + * set() — live update + * remove() / clear() + * auto-naming (circles_1, circles_2, …) + * to_wire() output structure for every type + * to_wire() validation errors + * MarkerTypeDict dict-like interface (contains, iter, len, keys/values/items, pop) + * MarkerRegistry allowed-type restriction +""" +from __future__ import annotations + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.markers import MarkerGroup, MarkerTypeDict, MarkerRegistry + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_plot2d(): + fig, ax = apl.subplots(1, 1) + data = np.random.default_rng(0).standard_normal((64, 64)) + return ax.imshow(data) + + +def _make_plot1d(): + fig, ax = apl.subplots(1, 1) + return ax.plot(np.sin(np.linspace(0, 2 * np.pi, 128))) + + +def _make_mesh(): + fig, ax = apl.subplots(1, 1) + data = np.ones((8, 12)) + x_edges = np.linspace(0, 12, 13) + y_edges = np.linspace(0, 8, 9) + return ax.pcolormesh(data, x_edges=x_edges, y_edges=y_edges) + + +# --------------------------------------------------------------------------- +# MarkerGroup — to_wire() for every type +# --------------------------------------------------------------------------- + +def _push_noop(): + pass + + +class TestMarkerGroupToWire: + + def _group(self, mtype, **kwargs): + return MarkerGroup(mtype, "g1", kwargs, _push_noop) + + # ── 2-D types ─────────────────────────────────────────────────────────── + + def test_circles_basic(self): + g = self._group("circles", offsets=[[10.0, 20.0], [30.0, 40.0]], radius=5) + w = g.to_wire("gid") + assert w["type"] == "circles" + assert len(w["offsets"]) == 2 + assert len(w["sizes"]) == 2 + assert w["sizes"][0] == pytest.approx(5.0) + + def test_circles_with_facecolors(self): + g = self._group("circles", offsets=[[0.0, 0.0]], facecolors="#ff0000", alpha=0.5) + w = g.to_wire("gid") + assert "fill_color" in w + assert w["fill_alpha"] == pytest.approx(0.5) + + def test_arrows_basic(self): + g = self._group("arrows", offsets=[[0.0, 0.0]], U=1.0, V=2.0, linewidths=2.0) + w = g.to_wire("gid") + assert w["type"] == "arrows" + assert len(w["U"]) == 1 + assert len(w["V"]) == 1 + assert w["linewidth"] == pytest.approx(2.0) + + def test_ellipses_basic(self): + g = self._group("ellipses", + offsets=[[32.0, 32.0], [64.0, 96.0]], + widths=30, heights=14, angles=[0.0, 45.0]) + w = g.to_wire("gid") + assert w["type"] == "ellipses" + assert len(w["widths"]) == 2 + assert len(w["heights"]) == 2 + + def test_ellipses_with_fill(self): + g = self._group("ellipses", offsets=[[0.0, 0.0]], widths=10, heights=5, + facecolors="#00ff00", alpha=0.4) + w = g.to_wire("gid") + assert "fill_color" in w + + def test_lines_single_segment(self): + g = self._group("lines", segments=[[0.0, 0.0], [10.0, 10.0]]) + w = g.to_wire("gid") + assert w["type"] == "lines" + assert len(w["segments"]) == 1 + + def test_lines_multi_segment(self): + segs = [[[0.0, 0.0], [5.0, 5.0]], [[5.0, 5.0], [10.0, 0.0]]] + g = self._group("lines", segments=segs) + w = g.to_wire("gid") + assert len(w["segments"]) == 2 + + def test_lines_bad_shape(self): + g = self._group("lines", segments=[[[0.0, 0.0], [1.0, 2.0], [3.0, 4.0]]]) + with pytest.raises(ValueError): + g.to_wire("gid") + + def test_rectangles_basic(self): + g = self._group("rectangles", offsets=[[10.0, 10.0]], widths=20, heights=10) + w = g.to_wire("gid") + assert w["type"] == "rectangles" + + def test_rectangles_with_fill(self): + g = self._group("rectangles", offsets=[[0.0, 0.0]], widths=5, heights=5, + facecolors="#0000ff", alpha=0.2) + w = g.to_wire("gid") + assert "fill_color" in w + + def test_squares_basic(self): + g = self._group("squares", offsets=[[32.0, 32.0]], widths=20, angles=[15.0]) + w = g.to_wire("gid") + assert w["type"] == "squares" + + def test_squares_with_fill(self): + g = self._group("squares", offsets=[[0.0, 0.0]], widths=10, + facecolors="#ff00ff", alpha=0.3) + w = g.to_wire("gid") + assert "fill_color" in w + + def test_polygons_basic(self): + tri = [[0.0, 0.0], [10.0, 0.0], [5.0, 8.0]] + g = self._group("polygons", vertices_list=[tri]) + w = g.to_wire("gid") + assert w["type"] == "polygons" + assert len(w["vertices_list"]) == 1 + + def test_polygons_with_fill(self): + tri = [[0.0, 0.0], [10.0, 0.0], [5.0, 8.0]] + g = self._group("polygons", vertices_list=[tri], facecolors="#aaa", alpha=0.5) + w = g.to_wire("gid") + assert "fill_color" in w + + def test_polygons_bad_vertex(self): + bad = [[0.0, 0.0], [1.0, 1.0]] # only 2 points — must be ≥3 + g = self._group("polygons", vertices_list=[bad]) + with pytest.raises(ValueError): + g.to_wire("gid") + + def test_texts_basic(self): + g = self._group("texts", offsets=[[10.0, 20.0]], texts=["hello"], fontsize=14) + w = g.to_wire("gid") + assert w["type"] == "texts" + assert w["texts"] == ["hello"] + assert w["fontsize"] == 14 + + # ── 1-D types ─────────────────────────────────────────────────────────── + + def test_points_basic(self): + g = self._group("points", offsets=[1.0, 2.0, 3.0], sizes=7, color="#ff0000") + w = g.to_wire("gid") + assert w["type"] == "points" + assert len(w["offsets"]) == 3 + + def test_points_with_fill(self): + g = self._group("points", offsets=[1.0], facecolors="#00ff00", alpha=0.6) + w = g.to_wire("gid") + assert "fill_color" in w + + def test_vlines_basic(self): + g = self._group("vlines", offsets=[1.0, 2.5, 4.0]) + w = g.to_wire("gid") + assert w["type"] == "vlines" + assert len(w["offsets"]) == 3 + assert all(len(r) == 1 for r in w["offsets"]) + + def test_hlines_basic(self): + g = self._group("hlines", offsets=[0.5, 1.0]) + w = g.to_wire("gid") + assert w["type"] == "hlines" + assert len(w["offsets"]) == 2 + + def test_unknown_type_raises(self): + g = self._group("stars", offsets=[[0.0, 0.0]]) + with pytest.raises(ValueError, match="Unknown marker type"): + g.to_wire("gid") + + # ── Optional common fields ─────────────────────────────────────────────── + + def test_label_included(self): + g = self._group("circles", offsets=[[0.0, 0.0]], label="my label") + w = g.to_wire("gid") + assert w["label"] == "my label" + + def test_labels_included(self): + g = self._group("circles", offsets=[[0.0, 0.0], [1.0, 1.0]], + labels=["A", "B"]) + w = g.to_wire("gid") + assert w["labels"] == ["A", "B"] + + def test_hover_edgecolors(self): + g = self._group("circles", offsets=[[0.0, 0.0]], hover_edgecolors="#ff0") + w = g.to_wire("gid") + assert w["hover_color"] == "#ff0" + + def test_hover_facecolors(self): + g = self._group("circles", offsets=[[0.0, 0.0]], hover_facecolors="#0f0") + w = g.to_wire("gid") + assert w["hover_facecolor"] == "#0f0" + + +# --------------------------------------------------------------------------- +# MarkerGroup — set() triggers push +# --------------------------------------------------------------------------- + +class TestMarkerGroupSet: + + def test_set_updates_data(self): + calls = [] + g = MarkerGroup("circles", "g", {"offsets": [[0.0, 0.0]], "radius": 5}, + lambda: calls.append(1)) + g.set(radius=10) + assert g._data["radius"] == 10 + assert len(calls) == 1 + + def test_count_zero_when_no_offsets(self): + g = MarkerGroup("circles", "g", {}, _push_noop) + assert g._count() == 0 + + +# --------------------------------------------------------------------------- +# MarkerTypeDict +# --------------------------------------------------------------------------- + +class TestMarkerTypeDict: + + def _td(self): + calls = [] + td = MarkerTypeDict("circles", lambda: calls.append(1)) + return td, calls + + def test_setitem_triggers_push(self): + td, calls = self._td() + g = MarkerGroup("circles", "g", {"offsets": [[0.0, 0.0]]}, _push_noop) + td["g"] = g + assert len(calls) == 1 + + def test_delitem_triggers_push(self): + td, calls = self._td() + g = MarkerGroup("circles", "g", {"offsets": [[0.0, 0.0]]}, _push_noop) + td._groups["g"] = g + del td["g"] + assert len(calls) == 1 + + def test_contains(self): + td, _ = self._td() + g = MarkerGroup("circles", "g", {}, _push_noop) + td._groups["g"] = g + assert "g" in td + assert "x" not in td + + def test_iter(self): + td, _ = self._td() + g = MarkerGroup("circles", "g", {}, _push_noop) + td._groups["g"] = g + assert list(td) == ["g"] + + def test_len(self): + td, _ = self._td() + assert len(td) == 0 + td._groups["a"] = MarkerGroup("circles", "a", {}, _push_noop) + assert len(td) == 1 + + def test_keys_values_items(self): + td, _ = self._td() + g = MarkerGroup("circles", "g", {}, _push_noop) + td._groups["g"] = g + assert "g" in td.keys() + assert g in td.values() + assert ("g", g) in td.items() + + def test_pop_triggers_push(self): + td, calls = self._td() + g = MarkerGroup("circles", "g", {}, _push_noop) + td._groups["g"] = g + result = td.pop("g") + assert result is g + assert len(calls) == 1 + + def test_pop_default(self): + td, _ = self._td() + result = td.pop("missing", None) + assert result is None + + def test_to_wire_list(self): + td, _ = self._td() + g = MarkerGroup("circles", "g", {"offsets": [[5.0, 5.0]]}, _push_noop) + td._groups["g"] = g + wl = td.to_wire_list() + assert len(wl) == 1 + assert wl[0]["type"] == "circles" + + +# --------------------------------------------------------------------------- +# MarkerRegistry +# --------------------------------------------------------------------------- + +class TestMarkerRegistry: + + def _reg(self, allowed=None): + calls = [] + reg = MarkerRegistry(lambda: calls.append(1), allowed=allowed) + return reg, calls + + def test_auto_creates_type_dict(self): + reg, _ = self._reg() + td = reg["circles"] + assert isinstance(td, MarkerTypeDict) + assert "circles" in reg + + def test_allowed_restriction(self): + reg, _ = self._reg(allowed=frozenset({"circles"})) + with pytest.raises(ValueError, match="not allowed"): + reg["arrows"] + + def test_add_returns_marker_group(self): + reg, calls = self._reg() + g = reg.add("circles", name="g1", offsets=[[0.0, 0.0]], radius=5) + assert isinstance(g, MarkerGroup) + assert len(calls) == 1 + + def test_add_auto_name(self): + reg, _ = self._reg() + g1 = reg.add("circles", offsets=[[0.0, 0.0]]) + g2 = reg.add("circles", offsets=[[1.0, 1.0]]) + assert g1._name == "circles_1" + assert g2._name == "circles_2" + + def test_remove(self): + reg, calls = self._reg() + reg.add("circles", name="g1", offsets=[[0.0, 0.0]]) + n_before = len(calls) + reg.remove("circles", "g1") + assert len(calls) > n_before + + def test_clear(self): + reg, calls = self._reg() + reg.add("circles", name="g1", offsets=[[0.0, 0.0]]) + reg.clear() + assert "circles" not in reg + + def test_iter(self): + reg, _ = self._reg() + reg.add("circles", name="g1", offsets=[[0.0, 0.0]]) + assert "circles" in list(reg) + + def test_to_wire_list(self): + reg, _ = self._reg() + reg.add("circles", name="g1", offsets=[[10.0, 20.0]], radius=4) + wl = reg.to_wire_list() + assert len(wl) == 1 + assert wl[0]["type"] == "circles" + + def test_auto_name_with_custom_names(self): + """Auto-naming should not be confused by custom-named groups.""" + reg, _ = self._reg() + reg.add("circles", name="my_spot", offsets=[[0.0, 0.0]]) + g = reg.add("circles", offsets=[[1.0, 1.0]]) + assert g._name == "circles_1" + + +# --------------------------------------------------------------------------- +# Plot2D high-level add_* helpers (from Examples/Markers) +# --------------------------------------------------------------------------- + +class TestPlot2DMarkerHelpers: + + def test_add_circles(self): + v = _make_plot2d() + centres = np.array([[10.0, 20.0], [30.0, 40.0]]) + v.add_circles(centres, name="spots", radius=10, + edgecolors="#ff1744", facecolors="#ff174433", + labels=["A", "B"]) + assert "spots" in v.markers["circles"] + wl = v.markers.to_wire_list() + assert any(w["type"] == "circles" for w in wl) + + def test_add_circles_set(self): + v = _make_plot2d() + v.add_circles([[5.0, 5.0]], name="c", radius=5) + v.markers["circles"]["c"].set(radius=12, edgecolors="#ffcc00") + assert v.markers["circles"]["c"]._data["radius"] == 12 + + def test_add_arrows(self): + v = _make_plot2d() + tails = np.array([[20.0, 20.0], [60.0, 60.0]]) + U = np.array([5.0, -5.0]) + V = np.array([5.0, -5.0]) + v.add_arrows(tails, U, V, name="flow", edgecolors="#76ff03") + assert "flow" in v.markers["arrows"] + + def test_add_arrows_set(self): + v = _make_plot2d() + v.add_arrows([[5.0, 5.0]], [1.0], [1.0], name="arr") + v.markers["arrows"]["arr"].set(edgecolors="#ff9100", linewidths=2.5) + assert v.markers["arrows"]["arr"]._data["edgecolors"] == "#ff9100" + + def test_add_ellipses(self): + v = _make_plot2d() + centres = np.array([[32.0, 32.0], [64.0, 96.0]]) + v.add_ellipses(centres, widths=30, heights=14, angles=[0.0, 45.0], + name="grains", edgecolors="#ff9100") + assert "grains" in v.markers["ellipses"] + + def test_add_ellipses_set(self): + v = _make_plot2d() + v.add_ellipses([[0.0, 0.0]], widths=10, heights=5, name="e") + v.markers["ellipses"]["e"].set(widths=20) + assert v.markers["ellipses"]["e"]._data["widths"] == 20 + + def test_add_rectangles(self): + v = _make_plot2d() + centres = np.array([[20.0, 20.0], [50.0, 50.0]]) + v.add_rectangles(centres, widths=22, heights=14, name="boxes", + edgecolors="#00e5ff") + assert "boxes" in v.markers["rectangles"] + + def test_add_rectangles_set(self): + v = _make_plot2d() + v.add_rectangles([[5.0, 5.0]], widths=10, heights=5, name="r") + v.markers["rectangles"]["r"].set(widths=20, heights=10) + assert v.markers["rectangles"]["r"]._data["widths"] == 20 + + def test_add_squares(self): + v = _make_plot2d() + centres = np.array([[32.0, 32.0], [64.0, 64.0]]) + v.add_squares(centres, widths=20, angles=[0, 15], name="tiles") + assert "tiles" in v.markers["squares"] + + def test_add_squares_set(self): + v = _make_plot2d() + v.add_squares([[5.0, 5.0]], widths=10, name="s") + v.markers["squares"]["s"].set(widths=20, edgecolors="#e040fb") + assert v.markers["squares"]["s"]._data["widths"] == 20 + + def test_add_polygons(self): + v = _make_plot2d() + tri = [[10.0, 5.0], [20.0, 25.0], [0.0, 25.0]] + v.add_polygons([tri], name="poly", edgecolors="#ff9100") + assert "poly" in v.markers["polygons"] + + def test_add_texts(self): + v = _make_plot2d() + v.add_texts([[10.0, 10.0], [30.0, 30.0]], texts=["A", "B"], + name="labels", color="#ffffff", fontsize=12) + assert "labels" in v.markers["texts"] + + def test_add_lines_2d(self): + v = _make_plot2d() + segs = [[[5.0, 5.0], [20.0, 20.0]], [[20.0, 20.0], [40.0, 5.0]]] + v.add_lines(segs, name="segs") + assert "segs" in v.markers["lines"] + + def test_remove_marker(self): + v = _make_plot2d() + v.add_circles([[0.0, 0.0]], name="c") + v.remove_marker("circles", "c") + assert "c" not in v.markers["circles"] + + def test_clear_markers(self): + v = _make_plot2d() + v.add_circles([[0.0, 0.0]], name="c1") + v.add_circles([[1.0, 1.0]], name="c2") + v.clear_markers() + assert v.markers.to_wire_list() == [] + + def test_list_markers(self): + v = _make_plot2d() + v.add_circles([[0.0, 0.0], [1.0, 1.0]], name="c") + info = v.list_markers() + assert any(d["name"] == "c" and d["n"] == 2 for d in info) + + +# --------------------------------------------------------------------------- +# Plot1D marker helpers +# --------------------------------------------------------------------------- + +class TestPlot1DMarkerHelpers: + + def test_add_points(self): + v = _make_plot1d() + offsets = np.column_stack([[1.0, 2.0, 3.0], [0.5, 0.8, 0.3]]) + v.add_points(offsets, name="peaks", sizes=7, color="#ff1744") + assert "peaks" in v.markers["points"] + + def test_add_vlines(self): + v = _make_plot1d() + v.add_vlines([1.0, 2.0, 3.0], name="marks", color="#00e5ff") + assert "marks" in v.markers["vlines"] + + def test_add_hlines(self): + v = _make_plot1d() + v.add_hlines([0.5, -0.5], name="levels", color="#ff9100") + assert "levels" in v.markers["hlines"] + + def test_remove_marker_1d(self): + v = _make_plot1d() + v.add_vlines([1.0], name="m") + v.remove_marker("vlines", "m") + assert "m" not in v.markers["vlines"] + + def test_clear_markers_1d(self): + v = _make_plot1d() + v.add_vlines([1.0], name="v1") + v.add_hlines([0.5], name="h1") + v.clear_markers() + assert v.markers.to_wire_list() == [] + + +# --------------------------------------------------------------------------- +# PlotMesh marker helpers +# --------------------------------------------------------------------------- + +class TestPlotMeshMarkerHelpers: + + def test_add_circles_mesh(self): + mesh = _make_mesh() + pts = np.array([[2.0, 2.0], [6.0, 4.0]]) + mesh.add_circles(pts, name="peaks", radius=0.5, edgecolors="#ff1744") + assert "peaks" in mesh.markers["circles"] + + def test_add_lines_mesh(self): + mesh = _make_mesh() + segs = [[[1.0, 1.0], [5.0, 5.0]]] + mesh.add_lines(segs, name="path", edgecolors="#00e5ff") + assert "path" in mesh.markers["lines"] + + def test_mesh_disallows_arrows(self): + mesh = _make_mesh() + with pytest.raises(ValueError, match="not allowed"): + mesh.add_arrows([[0.0, 0.0]], [1.0], [1.0]) + + +# --------------------------------------------------------------------------- +# MarkerGroup.remove() +# --------------------------------------------------------------------------- + +class TestMarkerGroupRemove: + + def test_remove_deletes_from_parent(self): + p = _make_plot2d() + g = p.add_circles([[10.0, 20.0]], name="dot", radius=3) + assert "dot" in p.markers["circles"] + g.remove() + assert "dot" not in p.markers["circles"] + + def test_remove_triggers_push(self): + calls = [] + td = MarkerTypeDict("circles", lambda: calls.append(1)) + g = td._add("g", {"offsets": [[0.0, 0.0]], "radius": 2}) + calls.clear() + g.remove() + assert len(calls) == 1 + + def test_remove_no_parent_raises(self): + g = MarkerGroup("circles", "g", {"offsets": [[0.0, 0.0]]}, _push_noop) + with pytest.raises(RuntimeError, match="no parent"): + g.remove() + + def test_remove_1d_group(self): + p = _make_plot1d() + g = p.add_vlines([0.5, 1.5], name="marks") + assert "marks" in p.markers["vlines"] + g.remove() + assert "marks" not in p.markers["vlines"] + + +# =========================================================================== +# _KNOWN_1D completeness — arrows and squares +# =========================================================================== + +class TestKnown1dArrowsSquares: + def test_arrows_in_known_1d(self): + assert "arrows" in MarkerRegistry._KNOWN_1D + + def test_squares_in_known_1d(self): + assert "squares" in MarkerRegistry._KNOWN_1D + + def test_add_arrows_does_not_raise(self): + p = _make_plot1d() + offsets = np.column_stack([np.linspace(0, 1, 5), np.zeros(5)]) + p.add_arrows(offsets, U=0.05, V=0.1) + + def test_add_squares_does_not_raise(self): + p = _make_plot1d() + offsets = np.column_stack([np.linspace(0, 1, 3), np.zeros(3)]) + p.add_squares(offsets, widths=0.05) + + def test_add_arrows_wire_format(self): + p = _make_plot1d() + offsets = np.array([[0.1, 0.2], [0.5, 0.6]]) + p.add_arrows(offsets, U=0.1, V=0.2, name="arr") + wires = [m for m in p._state["markers"] if m["type"] == "arrows"] + assert len(wires) == 1 + w = wires[0] + assert "U" in w and "V" in w + assert len(w["U"]) == 2 + assert len(w["offsets"]) == 2 + + def test_add_squares_wire_format(self): + p = _make_plot1d() + offsets = np.array([[0.1, 0.2], [0.5, 0.6]]) + p.add_squares(offsets, widths=0.1, name="sq") + wires = [m for m in p._state["markers"] if m["type"] == "squares"] + assert len(wires) == 1 + w = wires[0] + assert "widths" in w + assert len(w["widths"]) == 2 + + +# =========================================================================== +# drawMarkers1d new types — wire format correctness +# =========================================================================== + +class TestMarkers1dNewTypes: + """add_rectangles/ellipses/polygons/arrows/squares on Plot1D produce + correct wire-format dicts that the JS drawMarkers1d handler will receive.""" + + def _plot(self): + x = np.linspace(0, 2 * np.pi, 64) + fig, ax = apl.subplots(1, 1) + return ax.plot(np.sin(x), axes=[x]) + + def _wire(self, p, type_): + return [m for m in p._state["markers"] if m["type"] == type_] + + def test_add_rectangles_wire(self): + p = self._plot() + offsets = np.array([[1.0, 0.5], [3.0, -0.5]]) + p.add_rectangles(offsets, widths=0.2, heights=0.1, name="rects") + ws = self._wire(p, "rectangles") + assert len(ws) == 1 + w = ws[0] + assert "widths" in w and "heights" in w + assert len(w["offsets"]) == 2 + + def test_add_squares_wire(self): + p = self._plot() + offsets = np.array([[1.0, 0.5], [3.0, -0.5]]) + p.add_squares(offsets, widths=0.1, name="sq") + ws = self._wire(p, "squares") + assert len(ws) == 1 + assert "widths" in ws[0] + + def test_add_ellipses_wire(self): + p = self._plot() + offsets = np.array([[1.0, 0.5], [4.0, 0.0]]) + p.add_ellipses(offsets, widths=0.3, heights=0.15, name="ellip") + ws = self._wire(p, "ellipses") + assert len(ws) == 1 + w = ws[0] + assert "widths" in w and "heights" in w and "angles" in w + + def test_add_polygons_wire(self): + p = self._plot() + tri = np.array([[0.5, 0.0], [1.0, 0.5], [1.5, 0.0]]) + p.add_polygons([tri], name="poly") + ws = self._wire(p, "polygons") + assert len(ws) == 1 + assert "vertices_list" in ws[0] + assert len(ws[0]["vertices_list"]) == 1 + + def test_add_arrows_wire(self): + p = self._plot() + offsets = np.array([[1.0, 0.0], [3.0, 0.5]]) + p.add_arrows(offsets, U=0.2, V=0.1, name="arrows") + ws = self._wire(p, "arrows") + assert len(ws) == 1 + w = ws[0] + assert "U" in w and "V" in w + assert len(w["U"]) == 2 + + + +# =========================================================================== +# Top-level exports +# =========================================================================== + +class TestTopLevelExports: + def test_line1d_exported(self): + import anyplotlib as apl + assert hasattr(apl, "Line1D") + from anyplotlib import Line1D + assert Line1D is not None + + def test_marker_registry_exported(self): + import anyplotlib as apl + assert hasattr(apl, "MarkerRegistry") + from anyplotlib import MarkerRegistry + assert MarkerRegistry is not None + + def test_marker_group_exported(self): + import anyplotlib as apl + assert hasattr(apl, "MarkerGroup") + from anyplotlib import MarkerGroup + assert MarkerGroup is not None + + def test_line1d_data_length_not_in_wire(self): + """data_length must not appear in to_state_dict() wire output.""" + fig, ax = apl.subplots(1, 1) + p = ax.plot(np.linspace(0, 1, 64)) + wire = p.to_state_dict() + assert "data_length" not in wire diff --git a/anyplotlib/tests/test_plot1d/__init__.py b/anyplotlib/tests/test_plot1d/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/anyplotlib/tests/test_plot1d/test_plot1d.py b/anyplotlib/tests/test_plot1d/test_plot1d.py new file mode 100644 index 00000000..fedd5313 --- /dev/null +++ b/anyplotlib/tests/test_plot1d/test_plot1d.py @@ -0,0 +1,1017 @@ +""" +tests/test_plot1d/test_plot1d.py +================================= + +Unit tests for Plot1D — covering: + + * _norm_linestyle helper + * Default state values + * Construction via Axes.plot() (linestyle, ls shorthand, alpha, marker) + * Setter methods: set_color, set_linewidth, set_linestyle, set_alpha, + set_marker, set_data + * data property (read-only view) + * line property returning Line1D + * add_line() / remove_line() / clear_lines() and Line1D handle + * add_line() field parity (linestyle/alpha/marker in extra_lines dicts) + * State-dict round-trip (to_state_dict) + * Data-range recomputation (data_min / data_max) after overlay changes + * add_span() / remove_span() / clear_spans() + * add_vline_widget() / add_hline_widget() / add_range_widget() + * Widget management: get_widget, remove_widget, list_widgets, clear_widgets + * Marker helpers: add_points, add_vlines, add_hlines, + list_markers, remove_marker, clear_markers +""" +from __future__ import annotations + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib._utils import _norm_linestyle +from anyplotlib.plot1d import Plot1D +from anyplotlib.plot1d._plot1d import Line1D + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _plot(n: int = 128, **kwargs) -> Plot1D: + """Create a Plot1D attached to a one-panel Figure with deterministic data.""" + fig, ax = apl.subplots(1, 1) + data = np.sin(np.linspace(0, 2 * np.pi, n)) + return ax.plot(data, **kwargs) + + +def _plot_lin(n: int = 32, **kwargs) -> Plot1D: + """Create a Plot1D with linspace data (useful for range tests).""" + fig, ax = apl.subplots(1, 1) + return ax.plot(np.linspace(0.0, 1.0, n), **kwargs) + + +t = np.linspace(0, 2 * np.pi, 128) + + +# =========================================================================== +# _norm_linestyle +# =========================================================================== + +class TestNormLinestyle: + + def test_canonical_names_round_trip(self): + for ls in ("solid", "dashed", "dotted", "dashdot"): + assert _norm_linestyle(ls) == ls + + def test_shorthand_solid(self): + assert _norm_linestyle("-") == "solid" + + def test_shorthand_dashed(self): + assert _norm_linestyle("--") == "dashed" + + def test_shorthand_dotted(self): + assert _norm_linestyle(":") == "dotted" + + def test_shorthand_dashdot(self): + assert _norm_linestyle("-.") == "dashdot" + + def test_invalid_raises(self): + with pytest.raises(ValueError, match="Unknown linestyle"): + _norm_linestyle("loose") + + def test_invalid_empty_raises(self): + with pytest.raises(ValueError): + _norm_linestyle("") + + +# =========================================================================== +# Default state values +# =========================================================================== + +class TestPlot1DDefaults: + + def test_linestyle_default(self): + p = _plot_lin() + assert p._state["line_linestyle"] == "solid" + + def test_alpha_default(self): + p = _plot_lin() + assert p._state["line_alpha"] == 1.0 + + def test_marker_default(self): + p = _plot_lin() + assert p._state["line_marker"] == "none" + + def test_markersize_default(self): + p = _plot_lin() + assert p._state["line_markersize"] == 4.0 + + +# =========================================================================== +# Construction via Axes.plot() +# =========================================================================== + +class TestPlot1DConstruction: + + def test_linestyle_dashed(self): + p = _plot(linestyle="dashed") + assert p._state["line_linestyle"] == "dashed" + + def test_linestyle_dotted(self): + p = _plot(linestyle="dotted") + assert p._state["line_linestyle"] == "dotted" + + def test_linestyle_dashdot(self): + p = _plot(linestyle="-.") + assert p._state["line_linestyle"] == "dashdot" + + def test_ls_shorthand(self): + p = _plot(ls="--") + assert p._state["line_linestyle"] == "dashed" + + def test_ls_shorthand_takes_precedence_over_linestyle(self): + p = _plot_lin(linestyle="solid", ls="--") + assert p._state["line_linestyle"] == "dashed" + + def test_ls_only(self): + p = _plot_lin(ls=":") + assert p._state["line_linestyle"] == "dotted" + + def test_alpha_stored(self): + p = _plot(alpha=0.4) + assert p._state["line_alpha"] == pytest.approx(0.4) + + def test_marker_stored(self): + p = _plot(marker="s", markersize=5) + assert p._state["line_marker"] == "s" + assert p._state["line_markersize"] == pytest.approx(5.0) + + def test_markersize_stored(self): + p = _plot_lin(marker="s", markersize=8.0) + assert p._state["line_markersize"] == pytest.approx(8.0) + + def test_marker_none_string(self): + p = _plot_lin(marker="none") + assert p._state["line_marker"] == "none" + + def test_invalid_linestyle_raises(self): + with pytest.raises(ValueError, match="Unknown linestyle"): + _plot_lin(linestyle="zigzag") + + def test_all_known_markers(self): + for sym in ("o", "s", "^", "v", "D", "+", "x", "none"): + p = _plot_lin(marker=sym) + assert p._state["line_marker"] == sym + + +# =========================================================================== +# Setter methods +# =========================================================================== + +class TestPlot1DSetters: + + def test_set_color(self): + p = _plot(color="#4fc3f7") + p.set_color("#ff7043") + assert p._state["line_color"] == "#ff7043" + + def test_set_linewidth(self): + p = _plot() + p.set_linewidth(3.0) + assert p._state["line_linewidth"] == pytest.approx(3.0) + + def test_set_linestyle_canonical(self): + p = _plot_lin() + p.set_linestyle("dotted") + assert p._state["line_linestyle"] == "dotted" + + def test_set_linestyle_word(self): + p = _plot() + p.set_linestyle("dashed") + assert p._state["line_linestyle"] == "dashed" + + def test_set_linestyle_shorthand_dashdot(self): + p = _plot() + p.set_linestyle("-.") + assert p._state["line_linestyle"] == "dashdot" + + def test_set_linestyle_shorthand_colon(self): + p = _plot_lin() + p.set_linestyle(":") + assert p._state["line_linestyle"] == "dotted" + + def test_set_linestyle_invalid_raises(self): + p = _plot_lin() + with pytest.raises(ValueError): + p.set_linestyle("bad") + + def test_set_alpha(self): + p = _plot() + p.set_alpha(0.5) + assert p._state["line_alpha"] == pytest.approx(0.5) + + def test_set_marker_with_size(self): + p = _plot() + p.set_marker("o", markersize=6) + assert p._state["line_marker"] == "o" + assert p._state["line_markersize"] == pytest.approx(6.0) + + def test_set_marker_symbol_only(self): + p = _plot_lin() + p.set_marker("D") + assert p._state["line_marker"] == "D" + + def test_set_marker_no_size_leaves_default(self): + p = _plot_lin() + p.set_marker("^") + assert p._state["line_markersize"] == pytest.approx(4.0) + + def test_set_marker_none_normalised(self): + p = _plot_lin(marker="o") + p.set_marker(None) # type: ignore[arg-type] + assert p._state["line_marker"] == "none" + + def test_setters_chain_without_error(self): + """Multiple setter calls in sequence must not raise.""" + p = _plot_lin() + p.set_color("#aabbcc") + p.set_linewidth(2.5) + p.set_linestyle("--") + p.set_alpha(0.8) + p.set_marker("o", markersize=6) + assert p._state["line_linestyle"] == "dashed" + assert p._state["line_alpha"] == pytest.approx(0.8) + assert p._state["line_marker"] == "o" + + def test_set_data_replaces_primary(self): + p = _plot(n=64) + new_data = np.cos(np.linspace(0, 2 * np.pi, 64)) + p.set_data(new_data) + np.testing.assert_allclose(p._state["data"], new_data) + + def test_set_data_with_new_x_axis(self): + p = _plot(n=32) + y = np.ones(32) + x = np.linspace(10, 42, 32) + p.set_data(y, x_axis=x) + np.testing.assert_allclose(p._state["x_axis"], x) + + def test_set_data_updates_units(self): + p = _plot() + p.set_data(np.zeros(128), units="eV") + assert p._state["units"] == "eV" + + def test_set_data_2d_raises(self): + p = _plot() + with pytest.raises(ValueError): + p.set_data(np.ones((4, 4))) + + def test_data_property_readonly(self): + p = _plot() + arr = p.data + assert not arr.flags.writeable + + def test_line_property_returns_line1d(self): + p = _plot() + assert isinstance(p.line, Line1D) + assert p.line.id is None + + +# =========================================================================== +# Overlay lines (add_line / remove_line / clear_lines / Line1D handle) +# =========================================================================== + +class TestPlot1DOverlayLines: + + def test_add_line_returns_line1d(self): + p = _plot() + line = p.add_line(np.cos(t)) + assert isinstance(line, Line1D) + assert line.id is not None + + def test_add_line_stored_in_extra_lines(self): + p = _plot() + p.add_line(np.cos(t), color="#ff7043", label="cos") + assert len(p._state["extra_lines"]) == 1 + assert p._state["extra_lines"][0]["color"] == "#ff7043" + + def test_add_line_linestyle_alpha_marker(self): + p = _plot() + p.add_line(np.cos(t), linestyle="dashed", alpha=0.75, marker="o", markersize=5) + entry = p._state["extra_lines"][0] + assert entry["linestyle"] == "dashed" + assert entry["alpha"] == pytest.approx(0.75) + assert entry["marker"] == "o" + + def test_add_line_ls_shorthand(self): + p = _plot() + p.add_line(np.cos(t), ls=":") + assert p._state["extra_lines"][0]["linestyle"] == "dotted" + + def test_add_multiple_lines(self): + p = _plot() + p.add_line(np.cos(t)) + p.add_line(np.cos(t) * 0.5) + assert len(p._state["extra_lines"]) == 2 + + def test_remove_line_by_id(self): + p = _plot() + line = p.add_line(np.cos(t)) + p.remove_line(line.id) + assert len(p._state["extra_lines"]) == 0 + + def test_remove_line_by_line1d(self): + p = _plot() + line = p.add_line(np.cos(t)) + p.remove_line(line) + assert len(p._state["extra_lines"]) == 0 + + def test_remove_line_bad_id_raises(self): + p = _plot() + with pytest.raises(KeyError): + p.remove_line("nonexistent") + + def test_clear_lines(self): + p = _plot() + p.add_line(np.cos(t)) + p.add_line(np.cos(2 * t)) + p.clear_lines() + assert p._state["extra_lines"] == [] + + def test_line1d_set_data(self): + p = _plot() + line = p.add_line(np.cos(t)) + new_y = np.zeros(128) + line.set_data(new_y) + entry = next(e for e in p._state["extra_lines"] if e["id"] == line.id) + np.testing.assert_allclose(entry["data"], new_y) + + def test_line1d_set_data_primary_raises(self): + p = _plot() + primary = Line1D(p, None) + with pytest.raises(ValueError, match="primary line"): + primary.set_data(np.zeros(10)) + + def test_line1d_set_data_bad_id_raises(self): + p = _plot() + phantom = Line1D(p, "deadbeef") + with pytest.raises(KeyError): + phantom.set_data(np.zeros(128)) + + def test_line1d_remove(self): + p = _plot() + line = p.add_line(np.cos(t)) + line.remove() + assert len(p._state["extra_lines"]) == 0 + + def test_line1d_remove_primary_raises(self): + p = _plot() + primary = Line1D(p, None) + with pytest.raises(ValueError): + primary.remove() + + def test_line1d_eq_str(self): + p = _plot() + line = p.add_line(np.cos(t)) + assert line == line.id + assert not (line == "other") + + def test_line1d_hash(self): + p = _plot() + line = p.add_line(np.cos(t)) + d = {line: "val"} + assert d[line] == "val" + + def test_line1d_str(self): + p = _plot() + line = p.add_line(np.cos(t)) + assert str(line) == line.id + + +# =========================================================================== +# add_line() field parity +# =========================================================================== + +class TestAddLineParity: + + def _extra(self, **kwargs) -> dict: + p = _plot_lin() + p.add_line(np.ones(32), **kwargs) + return p._state["extra_lines"][0] + + def test_default_linestyle(self): + assert self._extra()["linestyle"] == "solid" + + def test_linestyle_stored(self): + assert self._extra(linestyle="dashed")["linestyle"] == "dashed" + + def test_ls_shorthand(self): + assert self._extra(ls=":")["linestyle"] == "dotted" + + def test_ls_overrides_linestyle(self): + assert self._extra(linestyle="solid", ls="--")["linestyle"] == "dashed" + + def test_default_alpha(self): + assert self._extra()["alpha"] == pytest.approx(1.0) + + def test_alpha_stored(self): + assert self._extra(alpha=0.4)["alpha"] == pytest.approx(0.4) + + def test_default_marker(self): + assert self._extra()["marker"] == "none" + + def test_marker_stored(self): + ex = self._extra(marker="o", markersize=6.0) + assert ex["marker"] == "o" + assert ex["markersize"] == pytest.approx(6.0) + + def test_invalid_linestyle_raises(self): + p = _plot_lin() + with pytest.raises(ValueError): + p.add_line(np.ones(32), linestyle="bad") + + def test_multiple_extra_lines_independent(self): + p = _plot_lin() + p.add_line(np.ones(32), linestyle="dashed", alpha=0.5) + p.add_line(np.ones(32), linestyle="dotted", alpha=0.8) + assert p._state["extra_lines"][0]["linestyle"] == "dashed" + assert p._state["extra_lines"][1]["linestyle"] == "dotted" + assert p._state["extra_lines"][0]["alpha"] == pytest.approx(0.5) + assert p._state["extra_lines"][1]["alpha"] == pytest.approx(0.8) + + +# =========================================================================== +# State-dict round-trip (to_state_dict) +# =========================================================================== + +class TestStateDict: + + def test_primary_keys_present(self): + p = _plot_lin(linestyle="dotted", alpha=0.7, marker="s", markersize=5.0) + sd = p.to_state_dict() + assert sd["line_linestyle"] == "dotted" + assert sd["line_alpha"] == pytest.approx(0.7) + assert sd["line_marker"] == "s" + assert sd["line_markersize"] == pytest.approx(5.0) + + def test_extra_line_keys_present(self): + p = _plot_lin() + p.add_line(np.zeros(32), linestyle="dashdot", alpha=0.6, marker="D") + sd = p.to_state_dict() + ex = sd["extra_lines"][0] + assert ex["linestyle"] == "dashdot" + assert ex["alpha"] == pytest.approx(0.6) + assert ex["marker"] == "D" + + +# =========================================================================== +# Data-range recomputation +# =========================================================================== + +class TestDataRangeRecompute: + """data_min/data_max must always cover all visible lines.""" + + def test_add_line_expands_range_upward(self): + p = _plot_lin() + primary_max = p._state["data_max"] + p.add_line(np.full(32, 5.0)) + assert p._state["data_max"] > primary_max + assert p._state["data_max"] >= 5.0 + + def test_add_line_expands_range_downward(self): + p = _plot_lin() + primary_min = p._state["data_min"] + p.add_line(np.full(32, -5.0)) + assert p._state["data_min"] < primary_min + assert p._state["data_min"] <= -5.0 + + def test_add_line_both_directions(self): + p = _plot_lin() + p.add_line(np.full(32, 10.0)) + p.add_line(np.full(32, -10.0)) + assert p._state["data_max"] >= 10.0 + assert p._state["data_min"] <= -10.0 + + def test_remove_line_shrinks_range(self): + p = _plot_lin() + lid = p.add_line(np.full(32, 100.0)) + assert p._state["data_max"] >= 100.0 + p.remove_line(lid) + assert p._state["data_max"] < 10.0 + + def test_clear_lines_restores_primary_range(self): + p = _plot_lin() + original_min = p._state["data_min"] + original_max = p._state["data_max"] + p.add_line(np.full(32, 50.0)) + p.add_line(np.full(32, -50.0)) + p.clear_lines() + assert p._state["data_min"] == pytest.approx(original_min) + assert p._state["data_max"] == pytest.approx(original_max) + + def test_range_includes_padding(self): + """5 % padding must be applied after recompute.""" + p = _plot_lin() + p.add_line(np.zeros(32) + 3.0) + assert p._state["data_max"] >= 3.0 * 1.05 - 0.01 + + def test_overlay_within_bounds_does_not_change_range(self): + p = _plot_lin() + pre_min = p._state["data_min"] + pre_max = p._state["data_max"] + p.add_line(np.full(32, 0.5)) + assert p._state["data_min"] == pytest.approx(pre_min) + assert p._state["data_max"] == pytest.approx(pre_max) + + def test_sin_overlay_expands_max(self): + p = _plot() + old_max = p._state["data_max"] + p.add_line(np.sin(t) + 5) + assert p._state["data_max"] > old_max + + +# =========================================================================== +# Spans +# =========================================================================== + +class TestPlot1DSpans: + + def test_add_span_returns_id(self): + p = _plot() + sid = p.add_span(1.0, 2.0) + assert isinstance(sid, str) + assert len(p._state["spans"]) == 1 + + def test_add_span_y_axis(self): + p = _plot() + p.add_span(0.5, 0.8, axis="y", color="#ff0000") + assert p._state["spans"][0]["axis"] == "y" + + def test_remove_span(self): + p = _plot() + sid = p.add_span(1.0, 2.0) + p.remove_span(sid) + assert p._state["spans"] == [] + + def test_remove_span_bad_id_raises(self): + p = _plot() + with pytest.raises(KeyError): + p.remove_span("nonexistent") + + def test_clear_spans(self): + p = _plot() + p.add_span(1.0, 2.0) + p.add_span(3.0, 4.0) + p.clear_spans() + assert p._state["spans"] == [] + + +# =========================================================================== +# Widgets +# =========================================================================== + +class TestPlot1DWidgets: + + def test_add_vline_widget(self): + p = _plot() + w = p.add_vline_widget(1.5, color="#ff6e40") + assert w is not None + assert len(p._widgets) == 1 + + def test_add_hline_widget(self): + p = _plot() + p.add_hline_widget(0.5) + assert len(p._widgets) == 1 + + def test_add_range_widget(self): + p = _plot() + p.add_range_widget(1.0, 3.0) + assert len(p._widgets) == 1 + + def test_get_widget_by_id(self): + p = _plot() + w = p.add_vline_widget(1.0) + assert p.get_widget(w.id) is w + + def test_get_widget_by_widget(self): + p = _plot() + w = p.add_vline_widget(1.0) + assert p.get_widget(w) is w + + def test_get_widget_missing_raises(self): + p = _plot() + with pytest.raises(KeyError): + p.get_widget("bad_id") + + def test_remove_widget(self): + p = _plot() + w = p.add_vline_widget(1.0) + p.remove_widget(w) + assert len(p._widgets) == 0 + + def test_remove_widget_missing_raises(self): + p = _plot() + with pytest.raises(KeyError): + p.remove_widget("bad_id") + + def test_list_widgets(self): + p = _plot() + p.add_vline_widget(1.0) + p.add_hline_widget(0.5) + assert len(p.list_widgets()) == 2 + + def test_clear_widgets(self): + p = _plot() + p.add_vline_widget(1.0) + p.add_hline_widget(0.5) + p.clear_widgets() + assert p.list_widgets() == [] + + +# =========================================================================== +# Marker helpers +# =========================================================================== + +class TestPlot1DMarkerHelpers: + + def test_add_points_with_facecolors(self): + p = _plot() + offsets = np.column_stack([[1.0, 2.0], [0.5, 0.8]]) + p.add_points(offsets, name="peaks", sizes=7, + color="#ff1744", facecolors="#ff174433") + wl = p.markers.to_wire_list() + assert any(w["type"] == "points" for w in wl) + + def test_list_markers_count(self): + p = _plot() + offsets = np.column_stack([[1.0, 2.0, 3.0], [0.1, 0.2, 0.3]]) + p.add_points(offsets, name="pts") + info = p.list_markers() + assert any(d["name"] == "pts" and d["n"] == 3 for d in info) + + def test_remove_marker(self): + p = _plot() + p.add_vlines([1.0, 2.0], name="m") + p.remove_marker("vlines", "m") + assert p.markers.to_wire_list() == [] + + def test_clear_markers(self): + p = _plot() + p.add_vlines([1.0], name="v") + p.add_hlines([0.5], name="h") + p.clear_markers() + assert p.markers.to_wire_list() == [] + + +# =========================================================================== +# Phase 2 — Plot1D state methods +# =========================================================================== + +class TestPlot1DProperties: + + def test_color_property(self): + p = _plot(color="#ff0000") + assert p.color == "#ff0000" + + def test_x_property_returns_ndarray(self): + p = _plot_lin(32) + x = p.x + assert isinstance(x, np.ndarray) + assert len(x) == 32 + + def test_y_property_returns_ndarray(self): + data = np.linspace(0.0, 1.0, 64) + fig, ax = apl.subplots(1, 1) + p = ax.plot(data) + y = p.y + assert isinstance(y, np.ndarray) + assert len(y) == 64 + + +class TestPlot1DLabels: + + def test_set_xlabel_updates_units(self): + p = _plot() + p.set_xlabel("Energy (eV)") + assert p._state["units"] == "Energy (eV)" + + def test_set_ylabel_updates_y_units(self): + p = _plot() + p.set_ylabel("Counts") + assert p._state["y_units"] == "Counts" + + def test_set_title(self): + p = _plot() + p.set_title("Spectrum") + assert p._state["title"] == "Spectrum" + + def test_default_title_empty(self): + p = _plot() + assert p._state["title"] == "" + + +class TestPlot1DAxisLimits: + + def test_set_xlim_changes_view(self): + p = _plot_lin(64) + p.set_xlim(10, 50) + assert p._state["view_x0"] != 0.0 or p._state["view_x1"] != 1.0 + + def test_set_ylim_stores_y_range(self): + p = _plot() + p.set_ylim(-2.0, 2.0) + assert p._state["y_range"] == [-2.0, 2.0] + + def test_get_ylim_returns_data_bounds(self): + data = np.array([0.0, 1.0, 2.0, 3.0, 4.0]) + fig, ax = apl.subplots(1, 1) + p = ax.plot(data) + lo, hi = p.get_ylim() + assert lo < hi + assert lo <= 0.0 + assert hi >= 4.0 + + def test_get_xbound_returns_x_range(self): + p = _plot_lin(32) + lo, hi = p.get_xbound() + assert lo == pytest.approx(0.0) + assert hi == pytest.approx(31.0) + + +class TestPlot1DAxisVisibility: + + def test_set_axis_off(self): + p = _plot() + assert p._state["axis_visible"] is True + p.set_axis_off() + assert p._state["axis_visible"] is False + + def test_set_ticks_visible_false(self): + p = _plot() + p.set_ticks_visible(False) + assert p._state["x_ticks_visible"] is False + assert p._state["y_ticks_visible"] is False + + def test_set_ticks_visible_per_axis(self): + p = _plot() + p.set_ticks_visible(False, x=True, y=False) + assert p._state["x_ticks_visible"] is True + assert p._state["y_ticks_visible"] is False + + +# =========================================================================== +# Phase 5 — step-mid linestyle + semilogy / yscale +# =========================================================================== + +class TestNormLinestyleStepMid: + + def test_step_mid_accepted(self): + from anyplotlib._utils import _norm_linestyle + assert _norm_linestyle("step-mid") == "step-mid" + + def test_steps_mid_alias(self): + from anyplotlib._utils import _norm_linestyle + assert _norm_linestyle("steps-mid") == "step-mid" + + def test_step_mid_stored_in_state(self): + fig, ax = apl.subplots(1, 1) + p = ax.plot(np.zeros(16), linestyle="step-mid") + assert p._state["line_linestyle"] == "step-mid" + + def test_step_mid_via_set_linestyle(self): + p = _plot() + p.set_linestyle("step-mid") + assert p._state["line_linestyle"] == "step-mid" + + +class TestSemilogy: + + def test_semilogy_sets_yscale_log(self): + fig, ax = apl.subplots(1, 1) + p = ax.semilogy(np.logspace(0, 3, 64)) + assert p._state["yscale"] == "log" + + def test_yscale_stored_in_state(self): + fig, ax = apl.subplots(1, 1) + p = ax.plot(np.zeros(16), yscale="log") + assert p._state["yscale"] == "log" + + def test_yscale_default_is_linear(self): + p = _plot() + assert p._state["yscale"] == "linear" + + def test_semilogy_passes_kwargs(self): + fig, ax = apl.subplots(1, 1) + p = ax.semilogy(np.ones(16), color="#ff0000") + assert p._state["line_color"] == "#ff0000" + assert p._state["yscale"] == "log" + + +# =========================================================================== +# set_ylim / get_ylim +# =========================================================================== + +class TestSetGetYlim: + def test_get_ylim_default_returns_data_bounds(self): + p = _plot() + lo, hi = p.get_ylim() + assert lo == pytest.approx(p._state["data_min"]) + assert hi == pytest.approx(p._state["data_max"]) + + def test_set_ylim_stored_in_state(self): + p = _plot() + p.set_ylim(-2.0, 5.0) + assert p._state["y_range"] == [-2.0, 5.0] + + def test_get_ylim_after_set_ylim(self): + p = _plot() + p.set_ylim(-1.5, 3.0) + lo, hi = p.get_ylim() + assert lo == pytest.approx(-1.5) + assert hi == pytest.approx(3.0) + + def test_y_range_not_cleared_by_reset_view(self): + p = _plot() + p.set_ylim(-1.0, 1.0) + p.reset_view() + lo, hi = p.get_ylim() + assert lo == pytest.approx(-1.0) + assert hi == pytest.approx(1.0) + + def test_y_range_in_state_dict(self): + p = _plot() + p.set_ylim(0.0, 10.0) + assert p.to_state_dict()["y_range"] == [0.0, 10.0] + + def test_y_range_none_by_default(self): + assert _plot()._state["y_range"] is None + + def test_y_range_propagated_to_state_dict(self): + p = _plot() + p.set_ylim(-5.0, 5.0) + assert p.to_state_dict()["y_range"] == [-5.0, 5.0] + + def test_markers_state_dict_contains_y_range(self): + p = _plot() + p.set_ylim(0.0, 10.0) + assert p.to_state_dict()["y_range"] == [0.0, 10.0] + + +# =========================================================================== +# get_xlim +# =========================================================================== + +class TestGetXlim: + def test_get_xlim_full_view(self): + fig, ax = apl.subplots(1, 1) + x = np.linspace(0.0, 10.0, 64) + p = ax.plot(np.sin(x), axes=[x]) + lo, hi = p.get_xlim() + assert lo == pytest.approx(0.0, abs=0.01) + assert hi == pytest.approx(10.0, abs=0.01) + + def test_get_xlim_after_set_xlim(self): + fig, ax = apl.subplots(1, 1) + x = np.linspace(0.0, 10.0, 64) + p = ax.plot(np.sin(x), axes=[x]) + p.set_xlim(2.0, 8.0) + lo, hi = p.get_xlim() + assert lo == pytest.approx(2.0, abs=0.1) + assert hi == pytest.approx(8.0, abs=0.1) + + def test_get_xlim_default_x_axis(self): + p = _plot_lin(n=100) + lo, hi = p.get_xlim() + assert lo == pytest.approx(0.0, abs=0.01) + assert hi == pytest.approx(99.0, abs=0.01) + + +# =========================================================================== +# _view_from_python flag +# =========================================================================== + +class TestViewFromPython: + def test_initial_view_from_python_false(self): + assert _plot()._state["_view_from_python"] is False + + def test_set_view_clears_flag_after_push(self): + p = _plot() + p.set_view(x0=0.2, x1=0.8) + assert p._state["_view_from_python"] is False + + def test_reset_view_clears_flag_after_push(self): + p = _plot() + p.set_view(x0=0.2, x1=0.8) + p.reset_view() + assert p._state["_view_from_python"] is False + + def test_set_xlim_clears_flag_after_push(self): + fig, ax = apl.subplots(1, 1) + x = np.linspace(0, 10, 64) + p = ax.plot(np.sin(x), axes=[x]) + p.set_xlim(2.0, 8.0) + assert p._state["_view_from_python"] is False + assert p._state["view_x0"] != 0.0 or p._state["view_x1"] != 1.0 + + def test_view_from_python_present_in_state_dict(self): + p = _plot() + p.set_view(x0=0.1, x1=0.9) + sd = p.to_state_dict() + assert "_view_from_python" in sd + assert sd["_view_from_python"] is False + + +# =========================================================================== +# add_line default color +# =========================================================================== + +class TestAddLineDefaultColor: + def test_default_color_is_not_white(self): + import inspect + p = _plot() + default = inspect.signature(p.add_line).parameters["color"].default + assert default != "#ffffff" + assert default == "#4fc3f7" + + def test_add_line_uses_default_color_in_state(self): + p = _plot() + p.add_line(np.linspace(-1, 1, 128)) + assert p._state["extra_lines"][-1]["color"] == "#4fc3f7" + + + +# =========================================================================== +# set_axis_on (Plot1D) +# =========================================================================== + +class TestSetAxisOnPlot1D: + def test_set_axis_on_restores(self): + p = _plot() + p.set_axis_off() + assert p._state["axis_visible"] is False + p.set_axis_on() + assert p._state["axis_visible"] is True + + def test_set_axis_on_default_state(self): + p = _plot() + p.set_axis_on() + assert p._state["axis_visible"] is True + + +# =========================================================================== +# M4: set_yscale on Plot1D +# =========================================================================== + +class TestSetYscale: + def test_set_yscale_log(self): + p = _plot() + p.set_yscale("log") + assert p._state["yscale"] == "log" + + def test_set_yscale_linear(self): + p = _plot() + p.set_yscale("log") + p.set_yscale("linear") + assert p._state["yscale"] == "linear" + + def test_set_yscale_invalid(self): + p = _plot() + with pytest.raises(ValueError): + p.set_yscale("symlog") + + +# =========================================================================== +# m2: configure_pointer_settled public on Plot1D +# =========================================================================== + +class TestPlot1DConfigurePointerSettled: + def test_public_method_exists(self): + p = _plot() + assert hasattr(p, "configure_pointer_settled") + assert callable(p.configure_pointer_settled) + + def test_sets_state(self): + p = _plot() + p.configure_pointer_settled(200, 5) + assert p._state["pointer_settled_ms"] == 200 + assert p._state["pointer_settled_delta"] == 5 + + +# =========================================================================== +# m3: direct tests for set_title/xlabel/ylabel and set_axis_on on Plot1D +# =========================================================================== + +class TestPlot1DDisplayMethods: + def test_set_title(self): + p = _plot() + p.set_title("My Plot") + assert p._state["title"] == "My Plot" + + def test_set_xlabel(self): + p = _plot() + p.set_xlabel("Time (s)") + assert p._state["units"] == "Time (s)" + + def test_set_ylabel(self): + p = _plot() + p.set_ylabel("Amplitude") + assert p._state["y_units"] == "Amplitude" diff --git a/anyplotlib/tests/test_plot1d/test_plotbar.py b/anyplotlib/tests/test_plot1d/test_plotbar.py new file mode 100644 index 00000000..52e66eed --- /dev/null +++ b/anyplotlib/tests/test_plot1d/test_plotbar.py @@ -0,0 +1,995 @@ +""" +tests/test_plot1d/test_plotbar.py +================================== + +Unit tests for PlotBar (bar chart) — covering: + + * Construction: defaults and explicit matplotlib-aligned API + (bar(x, height, width, bottom, ...), string x → category labels) + * State dict contents and data integrity + * Orientation: vertical / horizontal + * Colour options: single colour, per-bar colours, group colours + * Bar-width, baseline/bottom, show_values flags + * x (positions or category labels) and x_labels + * Range / padding calculations + * Grouped bars: 2-D height array, group_labels, group_colors + * Log scale: log_scale flag, clamping, set_log_scale() + * set_data(): value replacement and axis recalculation + * Display-setting mutations: set_color, set_colors, set_show_values, set_log_scale + * _push() contract: state propagated to Figure; layout_json kind == "bar" + * Callback API: on_click (incl. group_index/group_value), on_changed, disconnect + * Widgets: add_vline_widget, add_hline_widget, add_range_widget, add_point_widget, + get_widget, remove_widget, list_widgets, clear_widgets + * Edge cases: single bar, negative values, all-equal values, large N, float values + * Validation errors for bad inputs + * repr() +""" +from __future__ import annotations + +import json + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.callbacks import CallbackRegistry, Event +from anyplotlib.plot1d import PlotBar + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _make_bar(values=None, **kwargs) -> PlotBar: + """Create a PlotBar attached to a one-panel Figure (values-only call).""" + if values is None: + values = [1, 2, 3, 4, 5] + fig, ax = apl.subplots(1, 1) + return ax.bar(values, **kwargs) + + +def _bar(x, height=None, **kwargs) -> PlotBar: + """Create a PlotBar via the full bar(x, height, ...) API.""" + fig, ax = apl.subplots(1, 1) + if height is not None: + return ax.bar(x, height, **kwargs) + return ax.bar(x, **kwargs) + + +def _state(plot: PlotBar) -> dict: + return plot.to_state_dict() + + +# =========================================================================== +# 1. Construction — defaults +# =========================================================================== + +class TestPlotBarDefaults: + + def test_kind_is_bar(self): + assert _state(_make_bar())["kind"] == "bar" + + def test_values_stored_as_2d(self): + values = [10, 20, 30] + p = _make_bar(values) + assert _state(p)["values"] == pytest.approx(np.array([[10.0], [20.0], [30.0]])) + + def test_numpy_array_accepted(self): + arr = np.array([1.0, 2.0, 3.0]) + p = _make_bar(arr) + assert _state(p)["values"] == pytest.approx(np.array([[1.0], [2.0], [3.0]])) + + def test_default_x_centers(self): + assert _state(_make_bar([5, 6, 7]))["x_centers"] == pytest.approx([0.0, 1.0, 2.0]) + + def test_default_orient_is_vertical(self): + assert _state(_make_bar())["orient"] == "v" + + def test_default_baseline_is_zero(self): + assert _state(_make_bar())["baseline"] == pytest.approx(0.0) + + def test_default_bar_width(self): + assert _state(_make_bar())["bar_width"] == pytest.approx(0.8) + + def test_default_show_values_false(self): + assert _state(_make_bar())["show_values"] is False + + def test_default_color(self): + assert _state(_make_bar())["bar_color"] == "#4fc3f7" + + def test_default_bar_colors_empty(self): + assert _state(_make_bar())["bar_colors"] == [] + + def test_default_x_labels_empty(self): + assert _state(_make_bar())["x_labels"] == [] + + def test_default_units_empty(self): + st = _state(_make_bar()) + assert st["units"] == "" + assert st["y_units"] == "" + + def test_default_groups_is_one(self): + assert _state(_make_bar())["groups"] == 1 + + def test_default_log_scale_false(self): + assert _state(_make_bar())["log_scale"] is False + + +# =========================================================================== +# 2. Construction — explicit / matplotlib-aligned arguments +# =========================================================================== + +class TestPlotBarExplicitArgs: + + def test_x_as_numeric_positions(self): + p = _bar([0, 1, 2], [10, 20, 30]) + st = _state(p) + assert st["x_centers"] == pytest.approx([0.0, 1.0, 2.0]) + assert st["values"] == pytest.approx(np.array([[10.0], [20.0], [30.0]])) + + def test_x_as_string_labels(self): + months = ["Jan", "Feb", "Mar"] + p = _bar(months, [10, 20, 30]) + st = _state(p) + assert st["x_labels"] == months + assert st["x_centers"] == pytest.approx([0.0, 1.0, 2.0]) + + def test_width_parameter(self): + p = _bar([0, 1, 2], [1, 2, 3], width=0.5) + assert _state(p)["bar_width"] == pytest.approx(0.5) + + def test_bottom_parameter(self): + p = _bar([0, 1, 2], [1, 2, 3], bottom=5.0) + assert _state(p)["baseline"] == pytest.approx(5.0) + + def test_orient_h(self): + assert _bar(["A", "B"], [10, 20], orient="h")._state["orient"] == "h" + + def test_orient_v_default(self): + assert _bar([1, 2], [5, 6])._state["orient"] == "v" + + def test_show_values_kwarg(self): + assert _bar([1, 2, 3], [10, 20, 30], show_values=True)._state["show_values"] is True + + def test_custom_color(self): + assert _make_bar(color="#ff0000")._state["bar_color"] == "#ff0000" + + def test_custom_colors_list(self): + palette = ["#ff0000", "#00ff00", "#0000ff"] + p = _bar([1, 2, 3], [10, 20, 30], colors=palette) + assert _state(p)["bar_colors"] == palette + + def test_legacy_x_centers(self): + assert _state(_make_bar([1, 2, 3], x_centers=[10, 20, 30]))["x_centers"] == pytest.approx([10.0, 20.0, 30.0]) + + def test_legacy_x_labels(self): + assert _state(_make_bar([1, 2, 3], x_labels=["A", "B", "C"]))["x_labels"] == ["A", "B", "C"] + + def test_legacy_bar_width(self): + assert _state(_make_bar(bar_width=0.5))["bar_width"] == pytest.approx(0.5) + + def test_legacy_baseline(self): + assert _state(_make_bar(baseline=5.0))["baseline"] == pytest.approx(5.0) + + def test_units_and_y_units(self): + st = _state(_make_bar(units="category", y_units="count")) + assert st["units"] == "category" + assert st["y_units"] == "count" + + def test_axes_bar_returns_plotbar_instance(self): + fig, ax = apl.subplots(1, 1) + assert isinstance(ax.bar([1, 2, 3]), PlotBar) + + def test_orient_invalid_raises(self): + with pytest.raises(ValueError): + _bar([1, 2], [5, 6], orient="diagonal") + + +# =========================================================================== +# 3. Range / padding calculations +# =========================================================================== + +class TestPlotBarRange: + + def test_data_max_exceeds_max_value(self): + assert _state(_make_bar([1, 2, 3, 4, 5]))["data_max"] > 5.0 + + def test_data_min_at_baseline_for_positive_values(self): + assert _state(_make_bar([1, 2, 3, 4, 5], baseline=0.0))["data_min"] <= 0.0 + + def test_negative_values_extend_data_min(self): + assert _state(_make_bar([-3, -1, 0, 2]))["data_min"] < -3.0 + + def test_data_max_gt_data_min(self): + st = _state(_make_bar([1, 2, 3])) + assert st["data_max"] > st["data_min"] + + def test_all_equal_values_padded(self): + st = _state(_make_bar([5, 5, 5])) + assert st["data_max"] > st["data_min"] + + def test_baseline_above_all_values(self): + assert _state(_make_bar([1, 2, 3], baseline=10.0))["data_max"] >= 10.0 + + def test_baseline_below_all_values(self): + assert _state(_make_bar([5, 6, 7], baseline=-5.0))["data_min"] <= -5.0 + + +# =========================================================================== +# 4. Grouped bars +# =========================================================================== + +class TestPlotBarGrouped: + + def test_2d_height_creates_groups(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(["A", "B", "C"], [[1, 2], [3, 4], [5, 6]]) + st = _state(p) + assert st["groups"] == 2 + assert st["values"] == pytest.approx(np.array([[1, 2], [3, 4], [5, 6]])) + + def test_numpy_2d_height(self): + arr = np.array([[10, 20], [30, 40]]) + fig, ax = apl.subplots(1, 1) + assert _state(ax.bar([0, 1], arr))["groups"] == 2 + + def test_grouped_2d_height_with_group_labels(self): + data = np.array([[1, 2, 3], [4, 5, 6]], dtype=float) + bar = _bar(["A", "B"], data, group_labels=["G1", "G2", "G3"]) + assert bar._state["groups"] == 3 + assert bar._state["group_labels"] == ["G1", "G2", "G3"] + + def test_group_labels_stored(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(["A", "B"], [[1, 2], [3, 4]], group_labels=["G1", "G2"]) + assert _state(p)["group_labels"] == ["G1", "G2"] + + def test_group_colors_stored(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(["A", "B"], [[1, 2], [3, 4]], group_colors=["#f00", "#0f0"]) + assert _state(p)["group_colors"] == ["#f00", "#0f0"] + + def test_default_group_colors_assigned_for_multi_group(self): + """Multi-group without explicit group_colors gets a default palette.""" + fig, ax = apl.subplots(1, 1) + gc = _state(ax.bar(["A", "B"], [[1, 2], [3, 4]]))["group_colors"] + assert len(gc) == 2 + assert all(c.startswith("#") for c in gc) + + def test_grouped_default_colors_count(self): + data = np.ones((3, 2)) + assert len(_bar([1, 2, 3], data)._state["group_colors"]) == 2 + + def test_single_group_colors_empty_by_default(self): + assert _state(_make_bar([1, 2, 3]))["group_colors"] == [] + + def test_3d_height_raises(self): + with pytest.raises(ValueError): + _bar([1], np.ones((1, 2, 3))) + + def test_set_data_2d_values(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(["A", "B"], [[1, 2], [3, 4]]) + p.set_data([[10, 20], [30, 40]]) + assert _state(p)["values"] == pytest.approx(np.array([[10, 20], [30, 40]])) + + def test_set_data_group_count_mismatch_raises(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(["A", "B"], [[1, 2], [3, 4]]) # groups=2 + with pytest.raises(ValueError, match="Group count"): + p.set_data([[1, 2, 3], [4, 5, 6]]) + + +# =========================================================================== +# 5. Log scale +# =========================================================================== + +class TestPlotBarLogScale: + + def test_log_scale_flag_stored(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar([0, 1, 2], [1, 10, 100], log_scale=True) + assert _state(p)["log_scale"] is True + + def test_log_scale_data_min_positive(self): + """data_min must be > 0 when log_scale=True.""" + fig, ax = apl.subplots(1, 1) + assert _state(ax.bar([0, 1, 2], [1, 10, 100], log_scale=True))["data_min"] > 0.0 + + def test_log_scale_negative_values_clamped(self): + """Negative values are clamped for display, not raised.""" + fig, ax = apl.subplots(1, 1) + st = _state(ax.bar([0, 1, 2], [-5, 10, 100], log_scale=True)) + assert st["log_scale"] is True + assert st["data_min"] > 0.0 + + def test_log_scale_all_negative_clamped(self): + """All-negative values → data_min clamps to 1e-10.""" + fig, ax = apl.subplots(1, 1) + assert _state(ax.bar([0, 1], [-3, -1], log_scale=True))["data_min"] > 0.0 + + def test_set_log_scale_on(self): + p = _make_bar([1, 10, 100]) + p.set_log_scale(True) + st = _state(p) + assert st["log_scale"] is True + assert st["data_min"] > 0.0 + + def test_set_log_scale_off(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar([0, 1, 2], [1, 10, 100], log_scale=True) + p.set_log_scale(False) + assert _state(p)["log_scale"] is False + + def test_set_log_scale_push(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar([0, 1, 2], [1, 10, 100]) + p.set_log_scale(True) + data = json.loads(getattr(fig, f"panel_{p._id}_json")) + assert data["log_scale"] is True + + +# =========================================================================== +# 6. set_data() — value replacement +# =========================================================================== + +class TestPlotBarSetData: + + def test_update_replaces_values(self): + p = _make_bar([1, 2, 3]) + p.set_data([10, 20, 30]) + assert _state(p)["values"] == pytest.approx(np.array([[10.0], [20.0], [30.0]])) + + def test_update_recalculates_data_max(self): + p = _make_bar([1, 2, 3]) + p.set_data([100, 200, 300]) + assert _state(p)["data_max"] > 300.0 + + def test_update_recalculates_data_min(self): + p = _make_bar([1, 2, 3]) + p.set_data([-50, -20, -10]) + assert _state(p)["data_min"] < -50.0 + + def test_update_with_new_x_centers(self): + p = _make_bar([1, 2, 3]) + p.set_data([4, 5, 6], x_centers=[0.5, 1.5, 2.5]) + assert _state(p)["x_centers"] == pytest.approx([0.5, 1.5, 2.5]) + + def test_update_with_new_x(self): + p = _make_bar([1, 2, 3]) + p.set_data([4, 5, 6], x=[0.5, 1.5, 2.5]) + assert _state(p)["x_centers"] == pytest.approx([0.5, 1.5, 2.5]) + + def test_update_with_new_x_labels(self): + p = _make_bar([1, 2, 3], x_labels=["a", "b", "c"]) + p.set_data([4, 5, 6], x_labels=["x", "y", "z"]) + assert _state(p)["x_labels"] == ["x", "y", "z"] + + def test_update_preserves_orient(self): + p = _make_bar([1, 2, 3], orient="h") + p.set_data([4, 5, 6]) + assert _state(p)["orient"] == "h" + + def test_update_preserves_baseline(self): + p = _make_bar([1, 2, 3], baseline=2.0) + p.set_data([10, 20, 30]) + assert _state(p)["baseline"] == pytest.approx(2.0) + + def test_set_data_range_recalculated(self): + bar = _bar([1, 2, 3], [10, 20, 30]) + old_max = bar._state["data_max"] + bar.set_data([100, 200, 300]) + assert bar._state["data_max"] > old_max + + def test_set_data_bad_ndim_raises(self): + p = _make_bar([1, 2, 3]) + with pytest.raises(ValueError, match="1-D or 2-D"): + p.set_data(np.zeros((2, 2, 2))) + + +# =========================================================================== +# 7. Display-setting mutations +# =========================================================================== + +class TestPlotBarDisplayMutations: + + def test_set_color(self): + p = _make_bar() + p.set_color("#abcdef") + assert _state(p)["bar_color"] == "#abcdef" + + def test_set_colors(self): + p = _make_bar([1, 2, 3]) + p.set_colors(["red", "green", "blue"]) + assert _state(p)["bar_colors"] == ["red", "green", "blue"] + + def test_set_show_values_true(self): + p = _make_bar(show_values=False) + p.set_show_values(True) + assert _state(p)["show_values"] is True + + def test_set_show_values_false(self): + p = _make_bar(show_values=True) + p.set_show_values(False) + assert _state(p)["show_values"] is False + + +# =========================================================================== +# 8. _push() / Figure integration +# =========================================================================== + +class TestPlotBarPush: + + def test_panel_trait_exists_after_attach(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar([1, 2, 3]) + assert fig.has_trait(f"panel_{p._id}_json") + + def test_panel_json_contains_kind_bar(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar([1, 2, 3]) + data = json.loads(getattr(fig, f"panel_{p._id}_json")) + assert data["kind"] == "bar" + + def test_panel_json_values_after_update(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar([1, 2, 3]) + p.set_data([7, 8, 9]) + data = json.loads(getattr(fig, f"panel_{p._id}_json")) + assert data["values"] == pytest.approx(np.array([[7.0], [8.0], [9.0]])) + + def test_panel_json_color_after_set_color(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar([1, 2, 3]) + p.set_color("#112233") + data = json.loads(getattr(fig, f"panel_{p._id}_json")) + assert data["bar_color"] == "#112233" + + def test_push_without_figure_is_noop(self): + p = PlotBar([1, 2, 3]) + p._push() # must not raise + + def test_layout_json_kind_bar(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar([10, 20, 30]) + layout = json.loads(fig.layout_json) + panel_spec = next(s for s in layout["panel_specs"] if s["id"] == p._id) + assert panel_spec["kind"] == "bar" + + +# =========================================================================== +# 9. Callback API +# =========================================================================== + +class TestPlotBarCallbacks: + + def test_has_callback_registry(self): + assert isinstance(_make_bar().callbacks, CallbackRegistry) + + def test_on_click_decorator_returns_fn(self): + p = _make_bar() + fn = lambda e: None + result = p.add_event_handler(fn, "pointer_down") + assert result is fn + + def test_on_click_stamps_event_types(self): + p = _make_bar() + + @p.add_event_handler("pointer_down") + def cb(event): pass + + assert hasattr(cb, "_event_types") and "pointer_down" in cb._event_types + + def test_on_click_fires(self): + p = _make_bar() + fired = [] + + @p.add_event_handler("pointer_down") + def cb(event): fired.append(event) + + p.callbacks.fire(Event("pointer_down", p, bar_index=2, value=3.0, + group_index=0)) + assert len(fired) == 1 + + def test_on_click_event_data_with_group(self): + p = _make_bar([10, 20, 30]) + fired = [] + + @p.add_event_handler("pointer_down") + def cb(event): fired.append(event) + + p.callbacks.fire(Event("pointer_down", p, + bar_index=1, value=20.0, + group_index=0, + x_label="B")) + ev = fired[0] + assert ev.bar_index == 1 + assert ev.value == pytest.approx(20.0) + assert ev.group_index == 0 + assert ev.x_label == "B" + + def test_on_click_grouped_event(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(["A", "B"], [[1, 10], [2, 20]]) + fired = [] + + @p.add_event_handler("pointer_down") + def cb(event): fired.append(event) + + p.callbacks.fire(Event("pointer_down", p, + bar_index=1, group_index=1, + value=20.0, + x_label="B")) + assert fired[0].group_index == 1 + assert fired[0].value == pytest.approx(20.0) + + def test_on_changed_fires(self): + p = _make_bar() + fired = [] + + @p.add_event_handler("pointer_move") + def cb(event): fired.append(event) + + p.callbacks.fire(Event("pointer_move", p)) + assert len(fired) == 1 + + def test_on_click_not_fired_by_on_changed(self): + p = _make_bar() + fired = [] + + @p.add_event_handler("pointer_down") + def cb(event): fired.append(event) + + p.callbacks.fire(Event("pointer_move", p)) + assert fired == [] + + def test_disconnect(self): + p = _make_bar() + fired = [] + + @p.add_event_handler("pointer_down") + def cb(event): fired.append(event) + + p.remove_handler(cb) + p.callbacks.fire(Event("pointer_down", p)) + assert fired == [] + + def test_multiple_on_click_handlers(self): + p = _make_bar() + log = [] + + @p.add_event_handler("pointer_down") + def cb1(event): log.append("a") + + @p.add_event_handler("pointer_down") + def cb2(event): log.append("b") + + p.callbacks.fire(Event("pointer_down", p)) + assert sorted(log) == ["a", "b"] + + +# =========================================================================== +# 10. Widgets +# =========================================================================== + +class TestPlotBarWidgets: + + def test_add_vline_widget(self): + bar = _bar(["A", "B", "C"], [10, 20, 30]) + bar.add_vline_widget(1.5, color="#ff6e40") + assert len(bar._widgets) == 1 + + def test_add_hline_widget(self): + bar = _bar([1, 2, 3], [10, 20, 30]) + bar.add_hline_widget(15.0) + assert len(bar._widgets) == 1 + + def test_add_range_widget(self): + bar = _bar([1, 2, 3], [10, 20, 30]) + bar.add_range_widget(0.5, 2.5) + assert len(bar._widgets) == 1 + + def test_add_point_widget(self): + bar = _bar([1, 2, 3], [10, 20, 30]) + bar.add_point_widget(1.0, 15.0) + assert len(bar._widgets) == 1 + + def test_get_widget_by_id(self): + bar = _bar([1, 2], [10, 20]) + w = bar.add_vline_widget(1.0) + assert bar.get_widget(w.id) is w + + def test_get_widget_missing_raises(self): + bar = _bar([1, 2], [10, 20]) + with pytest.raises(KeyError): + bar.get_widget("nope") + + def test_remove_widget(self): + bar = _bar([1, 2], [10, 20]) + w = bar.add_vline_widget(1.0) + bar.remove_widget(w) + assert len(bar._widgets) == 0 + + def test_remove_widget_missing_raises(self): + bar = _bar([1, 2], [10, 20]) + with pytest.raises(KeyError): + bar.remove_widget("bad") + + def test_list_widgets(self): + bar = _bar([1, 2], [10, 20]) + bar.add_vline_widget(1.0) + bar.add_hline_widget(5.0) + assert len(bar.list_widgets()) == 2 + + def test_clear_widgets(self): + bar = _bar([1, 2], [10, 20]) + bar.add_vline_widget(1.0) + bar.clear_widgets() + assert bar.list_widgets() == [] + + +# =========================================================================== +# 11. Edge cases +# =========================================================================== + +class TestPlotBarEdgeCases: + + def test_single_bar(self): + st = _state(_make_bar([42])) + assert len(st["values"]) == 1 + assert st["data_max"] > st["data_min"] + + def test_large_n(self): + values = list(range(200)) + st = _state(_make_bar(values)) + assert len(st["values"]) == 200 + assert len(st["x_centers"]) == 200 + + def test_all_negative_values(self): + st = _state(_make_bar([-5, -3, -1])) + assert st["data_min"] < -5.0 + assert st["data_max"] >= 0.0 + + def test_mixed_positive_negative(self): + st = _state(_make_bar([-10, 0, 10])) + assert st["data_min"] < -10.0 + assert st["data_max"] > 10.0 + + def test_float_values(self): + assert _state(_make_bar([1.1, 2.2, 3.3]))["values"] == pytest.approx( + np.array([[1.1], [2.2], [3.3]]) + ) + + def test_x_centers_float(self): + assert _state(_make_bar([1, 2, 3], x_centers=[0.5, 1.5, 2.5]))["x_centers"] == pytest.approx( + [0.5, 1.5, 2.5] + ) + + def test_bar_width_zero_boundary(self): + assert _state(_make_bar(bar_width=0.0))["bar_width"] == pytest.approx(0.0) + + def test_bar_width_one_boundary(self): + assert _state(_make_bar(bar_width=1.0))["bar_width"] == pytest.approx(1.0) + + +# =========================================================================== +# 12. Validation errors +# =========================================================================== + +class TestPlotBarValidation: + + def test_3d_values_raises(self): + with pytest.raises(ValueError, match="1-D or 2-D"): + PlotBar(np.zeros((2, 2, 2))) + + def test_invalid_orient_raises(self): + with pytest.raises(ValueError, match="orient"): + PlotBar([1, 2, 3], orient="diagonal") + + def test_x_centers_length_mismatch_raises(self): + with pytest.raises(ValueError, match="length"): + PlotBar([1, 2, 3], x_centers=[0, 1]) + + +# =========================================================================== +# 13. repr +# =========================================================================== + +class TestPlotBarRepr: + + def test_repr_contains_n(self): + assert "n=4" in repr(_make_bar([1, 2, 3, 4])) + + def test_repr_contains_orient_v(self): + assert "orient='v'" in repr(_make_bar([1, 2, 3])) + + def test_repr_contains_orient_h(self): + assert "orient='h'" in repr(_make_bar([1, 2, 3], orient="h")) + + def test_repr_is_string(self): + assert isinstance(repr(_make_bar()), str) + + def test_repr_grouped_shows_groups(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar([0, 1], [[1, 2], [3, 4]]) + assert "groups=2" in repr(p) + assert "n=2" in repr(p) + + def test_repr_contains_plotbar(self): + assert "PlotBar" in repr(_bar([1, 2, 3], [10, 20, 30])) + + +# =========================================================================== +# New state keys added in audit fix +# =========================================================================== + +class TestPlotBarNewStateKeys: + def test_title_default_empty(self): + assert _make_bar()._state["title"] == "" + + def test_x_label_in_state(self): + assert "x_label" in _make_bar()._state + + def test_y_label_in_state(self): + assert "y_label" in _make_bar()._state + + def test_axis_visible_true_by_default(self): + assert _make_bar()._state["axis_visible"] is True + + def test_x_ticks_visible_true_by_default(self): + assert _make_bar()._state["x_ticks_visible"] is True + + def test_y_ticks_visible_true_by_default(self): + assert _make_bar()._state["y_ticks_visible"] is True + + def test_align_stored(self): + assert _make_bar(align="edge")._state["align"] == "edge" + + def test_align_center_by_default(self): + assert _make_bar()._state["align"] == "center" + + def test_y_range_none_by_default(self): + p = _make_bar() + assert "y_range" in p._state + assert p._state["y_range"] is None + + def test_view_from_python_false_by_default(self): + assert _make_bar()._state["_view_from_python"] is False + + +# =========================================================================== +# New display-control methods added in audit fix +# =========================================================================== + +class TestPlotBarDisplayMethods: + def test_set_title(self): + p = _make_bar() + p.set_title("My Chart") + assert p._state["title"] == "My Chart" + + def test_set_xlabel(self): + p = _make_bar() + p.set_xlabel("Category") + assert p._state["x_label"] == "Category" + + def test_set_ylabel(self): + p = _make_bar() + p.set_ylabel("Value") + assert p._state["y_label"] == "Value" + + def test_set_axis_off(self): + p = _make_bar() + p.set_axis_off() + assert p._state["axis_visible"] is False + + def test_set_axis_on_restores(self): + p = _make_bar() + p.set_axis_off() + p.set_axis_on() + assert p._state["axis_visible"] is True + + def test_set_ticks_visible_both_false(self): + p = _make_bar() + p.set_ticks_visible(False) + assert p._state["x_ticks_visible"] is False + assert p._state["y_ticks_visible"] is False + + def test_set_ticks_visible_x_only(self): + p = _make_bar() + p.set_ticks_visible(True, x=True, y=False) + assert p._state["x_ticks_visible"] is True + assert p._state["y_ticks_visible"] is False + + def test_set_ylim(self): + p = _make_bar() + p.set_ylim(0.0, 10.0) + assert p._state["y_range"] == [0.0, 10.0] + + def test_get_ylim_default(self): + p = _make_bar() + lo, hi = p.get_ylim() + assert lo == pytest.approx(p._state["data_min"]) + assert hi == pytest.approx(p._state["data_max"]) + + def test_get_ylim_after_set_ylim(self): + p = _make_bar() + p.set_ylim(-1.0, 20.0) + lo, hi = p.get_ylim() + assert lo == pytest.approx(-1.0) + assert hi == pytest.approx(20.0) + + def test_set_xlim_changes_view(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(np.arange(10), np.ones(10)) + p.set_xlim(2.0, 7.0) + assert p._state["view_x0"] != 0.0 or p._state["view_x1"] != 1.0 + + def test_reset_view(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(np.arange(10), np.ones(10)) + p.set_xlim(2.0, 7.0) + p.set_ylim(0.0, 5.0) + p.reset_view() + assert p._state["view_x0"] == pytest.approx(0.0) + assert p._state["view_x1"] == pytest.approx(1.0) + assert p._state["y_range"] is None + + +# =========================================================================== +# _view_from_python flag on PlotBar +# =========================================================================== + +class TestPlotBarViewFromPython: + def test_set_xlim_clears_flag(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(np.arange(10), np.ones(10)) + p.set_xlim(2.0, 7.0) + assert p._state["_view_from_python"] is False + + def test_reset_view_clears_flag(self): + p = _make_bar() + p.reset_view() + assert p._state["_view_from_python"] is False + + + +# =========================================================================== +# PlotBar: get_xlim and fixed set_ticks_visible signature +# =========================================================================== + +class TestPlotBarGetXlim: + def test_get_xlim_default(self): + p = _make_bar() + x_axis = p._state["x_axis"] + lo, hi = p.get_xlim() + assert lo == pytest.approx(x_axis[0]) + assert hi == pytest.approx(x_axis[-1]) + + def test_get_xlim_after_set_xlim(self): + fig, ax = apl.subplots(1, 1) + p = ax.bar(np.arange(10), np.ones(10)) + p.set_xlim(2.0, 7.0) + lo, hi = p.get_xlim() + assert lo == pytest.approx(2.0, abs=0.5) + assert hi == pytest.approx(7.0, abs=0.5) + + +class TestPlotBarSetTicksVisibleSignature: + def test_positional_visible_both(self): + p = _make_bar() + p.set_ticks_visible(False) + assert p._state["x_ticks_visible"] is False + assert p._state["y_ticks_visible"] is False + + def test_positional_visible_true(self): + p = _make_bar() + p.set_ticks_visible(False) + p.set_ticks_visible(True) + assert p._state["x_ticks_visible"] is True + assert p._state["y_ticks_visible"] is True + + def test_keyword_x_only(self): + p = _make_bar() + p.set_ticks_visible(True, x=False) + assert p._state["x_ticks_visible"] is False + assert p._state["y_ticks_visible"] is True + + def test_keyword_y_only(self): + p = _make_bar() + p.set_ticks_visible(True, y=False) + assert p._state["x_ticks_visible"] is True + assert p._state["y_ticks_visible"] is False + + +# =========================================================================== +# M3: PlotBar constructor-only setters +# =========================================================================== + +class TestPlotBarNewSetters: + def test_set_bar_width(self): + p = _make_bar() + p.set_bar_width(0.5) + assert p._state["bar_width"] == pytest.approx(0.5) + + def test_set_align_center(self): + p = _make_bar() + p.set_align("center") + assert p._state["align"] == "center" + + def test_set_align_edge(self): + p = _make_bar() + p.set_align("edge") + assert p._state["align"] == "edge" + + def test_set_align_invalid(self): + p = _make_bar() + with pytest.raises(ValueError): + p.set_align("left") + + def test_set_orient_h(self): + p = _make_bar() + p.set_orient("h") + assert p._state["orient"] == "h" + + def test_set_orient_v(self): + p = _make_bar() + p.set_orient("v") + assert p._state["orient"] == "v" + + def test_set_orient_invalid(self): + p = _make_bar() + with pytest.raises(ValueError): + p.set_orient("diagonal") + + def test_set_group_labels(self): + p = _make_bar() + p.set_group_labels(["a", "b", "c"]) + assert p._state["group_labels"] == ["a", "b", "c"] + + +# =========================================================================== +# M1/M2: standardized parameter names +# =========================================================================== + +class TestPlotBarParameterNames: + def test_set_title_uses_label_param(self): + import inspect + p = _make_bar() + sig = inspect.signature(p.set_title) + assert "label" in sig.parameters + + def test_set_xlabel_uses_label_param(self): + import inspect + p = _make_bar() + sig = inspect.signature(p.set_xlabel) + assert "label" in sig.parameters + + def test_set_xlim_uses_xmin_xmax(self): + import inspect + p = _make_bar() + sig = inspect.signature(p.set_xlim) + params = list(sig.parameters) + assert params[0] == "xmin" + assert params[1] == "xmax" + + def test_set_title_works(self): + p = _make_bar() + p.set_title(label="My Bar Chart") + assert p._state["title"] == "My Bar Chart" + + +# =========================================================================== +# m2: configure_pointer_settled public on PlotBar +# =========================================================================== + +class TestPlotBarConfigurePointerSettled: + def test_public_method_exists(self): + p = _make_bar() + assert hasattr(p, "configure_pointer_settled") + assert callable(p.configure_pointer_settled) + + def test_sets_state(self): + p = _make_bar() + p.configure_pointer_settled(300, 6) + assert p._state["pointer_settled_ms"] == 300 + assert p._state["pointer_settled_delta"] == 6 diff --git a/anyplotlib/tests/test_plot2d/__init__.py b/anyplotlib/tests/test_plot2d/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/anyplotlib/tests/test_plot2d/test_imshow.py b/anyplotlib/tests/test_plot2d/test_imshow.py new file mode 100644 index 00000000..91c575e3 --- /dev/null +++ b/anyplotlib/tests/test_plot2d/test_imshow.py @@ -0,0 +1,633 @@ +""" +tests/test_plot2d/test_imshow.py +================================= + +Comprehensive tests for Plot2D (imshow). + +Covers: + * Construction: kind, cmap, vmin/vmax, origin, axes, validation + * Colormap: cmap kwarg, LUT building, None default, name property/setter + * vmin/vmax: defaults, overrides, raw_min/raw_max, set_clim post-construction + * Origin: upper/lower storage, y-axis reversal, data flip, set_data re-flip + * Setters: set_colormap, set_clim, set_scale_mode, set_data, data property + * Widgets: add_widget (all kinds), remove_widget, list_widgets, clear_widgets, get_widget + * Markers: add_circles, add_points (uses "circles" wire type on Plot2D) + * View: set_view (x-only, y-only, x+y), reset_view, _view_from_python flag + * Overlay mask: set_overlay_mask, clear, shape/alpha/color validation, origin-lower flip + * Insets: add_inset, minimize, maximize, restore, inset_state + * __repr__ +""" +from __future__ import annotations + +import base64 +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.plot2d import Plot2D + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _img(n=32, **kwargs) -> Plot2D: + """Create a Plot2D attached to a one-panel Figure with deterministic data.""" + fig, ax = apl.subplots(1, 1) + data = np.arange(n * n, dtype=float).reshape(n, n) + return ax.imshow(data, **kwargs) + + +# 4×4 ramp: values 0..15 (row 0 = [0,1,2,3], row 3 = [12,13,14,15]) +DATA = np.arange(16, dtype=float).reshape(4, 4) +X = np.array([1.0, 2.0, 3.0, 4.0]) +Y = np.array([10.0, 20.0, 30.0, 40.0]) + + +def _decoded(v: Plot2D) -> np.ndarray: + """Return the stored uint8 image as a (H, W) array.""" + raw = base64.b64decode(v._state["image_b64"]) + return np.frombuffer(raw, dtype=np.uint8).reshape( + v._state["image_height"], v._state["image_width"] + ) + + +# =========================================================================== +# Construction +# =========================================================================== + +class TestImshowConstruction: + + def test_kind_is_2d(self): + v = _img() + assert v._state["kind"] == "2d" + + def test_3d_data_squeezed(self): + """3-D input with one channel should be accepted (first channel used).""" + data = np.zeros((8, 8, 3)) + fig, ax = apl.subplots(1, 1) + v = ax.imshow(data) + assert v._state["image_width"] == 8 + + def test_with_physical_axes(self): + data = np.zeros((8, 8)) + x = np.linspace(0, 1, 8) + y = np.linspace(0, 1, 8) + fig, ax = apl.subplots(1, 1) + v = ax.imshow(data, axes=[x, y], units="nm") + assert v._state["has_axes"] is True + assert v._state["units"] == "nm" + + def test_bad_data_shape_1d(self): + with pytest.raises(ValueError): + fig, ax = apl.subplots(1, 1) + ax.imshow(np.zeros(16)) + + +# =========================================================================== +# Colormap +# =========================================================================== + +class TestImshowColormap: + + def test_default_cmap_is_gray(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA) + assert v._state["colormap_name"] == "gray" + + def test_cmap_kwarg(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, cmap="viridis") + assert v._state["colormap_name"] == "viridis" + + def test_cmap_builds_lut(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, cmap="inferno") + lut = v._state["colormap_data"] + assert len(lut) == 256 + assert len(lut[0]) == 3 # [r, g, b] + + def test_cmap_none_uses_gray(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, cmap=None) + assert v._state["colormap_name"] == "gray" + + def test_colormap_name_property(self): + v = _img(cmap="viridis") + assert v.colormap_name == "viridis" + + def test_colormap_name_setter(self): + v = _img() + v.colormap_name = "inferno" + assert v._state["colormap_name"] == "inferno" + + +# =========================================================================== +# vmin / vmax +# =========================================================================== + +class TestImshowVminVmax: + + def test_default_uses_data_range(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA) + assert v._state["display_min"] == pytest.approx(0.0) + assert v._state["display_max"] == pytest.approx(15.0) + + def test_vmin_sets_display_min(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, vmin=3.0) + assert v._state["display_min"] == pytest.approx(3.0) + assert v._state["display_max"] == pytest.approx(15.0) # unchanged + + def test_vmax_sets_display_max(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, vmax=12.0) + assert v._state["display_min"] == pytest.approx(0.0) # unchanged + assert v._state["display_max"] == pytest.approx(12.0) + + def test_vmin_vmax_together(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, vmin=3.0, vmax=12.0) + assert v._state["display_min"] == pytest.approx(3.0) + assert v._state["display_max"] == pytest.approx(12.0) + + def test_raw_range_unaffected_by_vmin_vmax(self): + """raw_min/raw_max always reflect the actual data range.""" + fig, ax = apl.subplots() + v = ax.imshow(DATA, vmin=3.0, vmax=12.0) + assert v._state["raw_min"] == pytest.approx(0.0) + assert v._state["raw_max"] == pytest.approx(15.0) + + def test_set_clim_still_works_after_construction(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, vmin=3.0, vmax=12.0) + v.set_clim(vmin=1.0, vmax=14.0) + assert v._state["display_min"] == pytest.approx(1.0) + assert v._state["display_max"] == pytest.approx(14.0) + + +# =========================================================================== +# Origin +# =========================================================================== + +class TestImshowOrigin: + + def test_upper_is_default(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA) + assert v._origin == "upper" + + def test_upper_keeps_y_axis_order(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, axes=[X, Y], origin="upper") + assert v._state["y_axis"][0] == pytest.approx(10.0) # top of image + assert v._state["y_axis"][-1] == pytest.approx(40.0) # bottom + + def test_upper_row0_at_top(self): + """With origin='upper', row 0 of data (min values) is stored first.""" + fig, ax = apl.subplots() + v = ax.imshow(DATA, origin="upper") + stored = _decoded(v) + assert stored[0, 0] == 0 # row 0, col 0 → value 0 → uint8 min + + def test_lower_stored(self): + v = _img(origin="lower") + assert v._origin == "lower" + + def test_lower_reverses_y_axis_with_axes(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, axes=[X, Y], origin="lower") + assert v._state["y_axis"][0] == pytest.approx(40.0) # max at top + assert v._state["y_axis"][-1] == pytest.approx(10.0) # min at bottom + + def test_lower_default_y_axis_reversed(self): + """Without explicit axes, origin='lower' still reverses default y.""" + fig, ax = apl.subplots() + v = ax.imshow(DATA, origin="lower") + assert v._state["y_axis"][0] > v._state["y_axis"][-1] + + def test_lower_flips_data(self): + """With origin='lower', row 0 of original data appears at the bottom.""" + fig, ax = apl.subplots() + v = ax.imshow(DATA, origin="lower") + stored = _decoded(v) + assert stored[0, :].max() == 255 # top row contains the global max + assert stored[-1, :].min() == 0 # bottom row contains the global min + + def test_lower_set_data_reapplies_flip(self): + """set_data() with origin='lower' automatically re-flips new data.""" + fig, ax = apl.subplots() + v = ax.imshow(DATA, origin="lower") + v.set_data(DATA) + stored = _decoded(v) + assert stored[0, :].max() == 255 + assert stored[-1, :].min() == 0 + + def test_lower_set_data_reverses_new_y_axis(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, origin="lower") + v.set_data(DATA, y_axis=Y) + assert v._state["y_axis"][0] == pytest.approx(40.0) + assert v._state["y_axis"][-1] == pytest.approx(10.0) + + def test_data_property_origin_lower(self): + """data property should undo the internal flipud for origin='lower'.""" + data = np.arange(64, dtype=float).reshape(8, 8) + fig, ax = apl.subplots(1, 1) + v = ax.imshow(data, origin="lower") + np.testing.assert_array_equal(v.data, data) + + def test_invalid_origin_raises(self): + fig, ax = apl.subplots() + with pytest.raises(ValueError, match="origin"): + ax.imshow(DATA, origin="diagonal") + + def test_combined_params(self): + fig, ax = apl.subplots() + v = ax.imshow(DATA, cmap="inferno", vmin=2.0, vmax=13.0, + origin="lower", axes=[X, Y]) + assert v._state["colormap_name"] == "inferno" + assert v._state["display_min"] == pytest.approx(2.0) + assert v._state["display_max"] == pytest.approx(13.0) + assert v._state["y_axis"][0] == pytest.approx(40.0) # reversed + stored = _decoded(v) + assert stored[0, :].max() == 255 # flipped: top row has max value + + +# =========================================================================== +# Setters and data property +# =========================================================================== + +class TestImshowSetters: + + def test_set_colormap(self): + v = _img() + v.set_colormap("plasma") + assert v._state["colormap_name"] == "plasma" + assert isinstance(v._state["colormap_data"], list) + + def test_set_clim_vmin(self): + v = _img() + v.set_clim(vmin=0.1) + assert v._state["display_min"] == pytest.approx(0.1) + + def test_set_clim_vmax(self): + v = _img() + v.set_clim(vmax=0.9) + assert v._state["display_max"] == pytest.approx(0.9) + + def test_set_clim_both(self): + v = _img() + v.set_clim(vmin=0.0, vmax=0.8) + assert v._state["display_min"] == pytest.approx(0.0) + assert v._state["display_max"] == pytest.approx(0.8) + + def test_set_scale_mode_log(self): + v = _img() + v.set_scale_mode("log") + assert v._state["scale_mode"] == "log" + + def test_set_scale_mode_invalid(self): + v = _img() + with pytest.raises(ValueError): + v.set_scale_mode("square_root") + + def test_set_data_replaces(self): + v = _img() + new = np.ones((32, 32)) + v.set_data(new) + assert v._state["image_width"] == 32 + assert v._state["image_height"] == 32 + + def test_set_data_updates_units(self): + v = _img() + v.set_data(np.zeros((32, 32)), units="Å") + assert v._state["units"] == "Å" + + def test_set_data_bad_shape(self): + v = _img() + with pytest.raises(ValueError): + v.set_data(np.zeros(16)) + + def test_data_property_readonly(self): + v = _img() + arr = v.data + assert not arr.flags.writeable + + +# =========================================================================== +# Widgets +# =========================================================================== + +class TestImshowWidgets: + + def test_add_circle_widget(self): + v = _img(n=64) + w = v.add_widget("circle", cx=32, cy=32, r=10) + assert w is not None + assert len(v._widgets) == 1 + + def test_add_rectangle_widget(self): + v = _img(n=64) + v.add_widget("rectangle") + assert len(v._widgets) == 1 + + def test_add_annular_widget(self): + v = _img(n=64) + v.add_widget("annular", r_outer=20, r_inner=10) + assert len(v._widgets) == 1 + + def test_add_polygon_widget(self): + v = _img(n=64) + v.add_widget("polygon") + assert len(v._widgets) == 1 + + def test_add_crosshair_widget(self): + v = _img(n=64) + v.add_widget("crosshair", cx=32, cy=32) + assert len(v._widgets) == 1 + + def test_add_label_widget(self): + v = _img(n=64) + v.add_widget("label", text="hello") + assert len(v._widgets) == 1 + + def test_bad_widget_kind(self): + v = _img(n=64) + with pytest.raises(ValueError): + v.add_widget("star") + + def test_remove_widget(self): + v = _img(n=64) + w = v.add_widget("circle") + v.remove_widget(w) + assert len(v._widgets) == 0 + + def test_list_widgets(self): + v = _img(n=64) + v.add_widget("circle") + v.add_widget("crosshair") + assert len(v.list_widgets()) == 2 + + def test_clear_widgets(self): + v = _img(n=64) + v.add_widget("circle") + v.clear_widgets() + assert v.list_widgets() == [] + + +# =========================================================================== +# Markers (add_circles / add_points on Plot2D) +# =========================================================================== + +class TestImshowMarkers: + + def test_add_circles_does_not_crash(self): + """add_circles on a Plot2D must not raise ValueError.""" + plot = _img() + offsets = np.array([[8.0, 8.0], [16.0, 16.0]]) + mg = plot.add_circles(offsets, name="g1", radius=3) + assert mg is not None + wire = plot.markers.to_wire_list() + assert len(wire) == 1 + assert wire[0]["type"] == "circles" + + def test_add_circles_radius_in_wire(self): + """add_circles must pass radius embedded as 'sizes' in wire format.""" + plot = _img() + offsets = np.array([[4.0, 4.0]]) + plot.add_circles(offsets, name="c1", radius=7) + wire = plot.markers.to_wire_list() + assert wire[0]["type"] == "circles" + sizes = wire[0].get("sizes") + assert sizes is not None and all(s == 7.0 for s in sizes) + + def test_add_points_uses_circles_type(self): + """add_points on a Plot2D must use the 'circles' wire type, not 'points'.""" + plot = _img() + offsets = np.array([[8.0, 8.0]]) + mg = plot.add_points(offsets, name="p1", sizes=5) + assert mg is not None + wire = plot.markers.to_wire_list() + assert wire[0]["type"] == "circles" + + +# =========================================================================== +# View: set_view / reset_view +# =========================================================================== + +class TestImshowView: + + def _make_with_x_axis(self, shape=(32, 32)): + data = np.zeros(shape) + x_axis = np.linspace(0.0, float(shape[1]), shape[1]) + fig, ax = apl.subplots(1, 1) + return ax.imshow(data, axes=[x_axis, None]) + + def test_set_view_x_only(self): + """set_view(x0, x1) must update center_x and zoom, not view_x0/view_x1.""" + plot = self._make_with_x_axis() + plot.set_view(x0=8.0, x1=24.0) + # center_x should be midpoint fraction: (8+24)/2 / 32 = 0.5 + assert abs(plot._state["center_x"] - 0.5) < 1e-6 + # zoom_x = 32 / (24-8) = 2.0 + assert abs(plot._state["zoom"] - 2.0) < 1e-6 + assert "view_x0" not in plot._state + assert "view_x1" not in plot._state + + def test_set_view_y_only(self): + """set_view(y0=..., y1=...) must update center_y and zoom.""" + data = np.zeros((32, 32)) + y_axis = np.linspace(0.0, 32.0, 32) + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(data, axes=[None, y_axis]) + plot.set_view(y0=8.0, y1=24.0) + assert abs(plot._state["center_y"] - 0.5) < 1e-6 + assert abs(plot._state["zoom"] - 2.0) < 1e-6 + + def test_set_view_xy(self): + """set_view(x0, x1, y0, y1) uses minimum zoom when both axes given.""" + data = np.zeros((32, 64)) + x_axis = np.linspace(0.0, 64.0, 64) + y_axis = np.linspace(0.0, 32.0, 32) + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(data, axes=[x_axis, y_axis]) + plot.set_view(x0=0, x1=32, y0=0, y1=16) + zoom_x = 64.0 / 32.0 # = 2.0 + zoom_y = 32.0 / 16.0 # = 2.0 + expected_zoom = min(zoom_x, zoom_y) + assert abs(plot._state["zoom"] - expected_zoom) < 1e-6 + + def test_reset_view(self): + """reset_view must restore zoom=1, center_x=0.5, center_y=0.5.""" + plot = _img() + plot.set_view(x0=4, x1=28) + plot.reset_view() + assert plot._state["zoom"] == 1.0 + assert plot._state["center_x"] == 0.5 + assert plot._state["center_y"] == 0.5 + assert "view_x0" not in plot._state + assert "view_x1" not in plot._state + + def test_view_from_python_flag_set_view(self): + """set_view() sets _view_from_python briefly; it is False after push.""" + plot = self._make_with_x_axis() + plot.set_view(x0=8.0, x1=24.0) + assert plot._state["_view_from_python"] is False + + def test_view_from_python_flag_reset_view(self): + """reset_view() sets _view_from_python briefly; it is False after push.""" + plot = _img() + plot.reset_view() + assert plot._state["_view_from_python"] is False + + +# =========================================================================== +# Overlay mask +# =========================================================================== + +class TestImshowOverlayMask: + + def test_set_overlay_mask_sets_state(self): + plot = _img(n=16) + mask = np.zeros((16, 16), dtype=bool) + mask[4:12, 4:12] = True + plot.set_overlay_mask(mask) + assert plot._state["overlay_mask_b64"] != "" + assert plot._state["overlay_mask_color"] == "#ff4444" + assert plot._state["overlay_mask_alpha"] == 0.4 + + def test_set_overlay_mask_clear(self): + plot = _img(n=16) + mask = np.ones((16, 16), dtype=bool) + plot.set_overlay_mask(mask) + assert plot._state["overlay_mask_b64"] != "" + plot.set_overlay_mask(None) + assert plot._state["overlay_mask_b64"] == "" + + def test_set_overlay_mask_shape_mismatch(self): + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(np.zeros((16, 32))) + bad_mask = np.zeros((8, 8), dtype=bool) + with pytest.raises(ValueError, match="mask shape"): + plot.set_overlay_mask(bad_mask) + + def test_set_overlay_mask_alpha_boundary(self): + plot = _img(n=16) + mask = np.zeros((16, 16), dtype=bool) + plot.set_overlay_mask(mask, alpha=0.0) + assert plot._state["overlay_mask_alpha"] == 0.0 + plot.set_overlay_mask(mask, alpha=1.0) + assert plot._state["overlay_mask_alpha"] == 1.0 + + def test_set_overlay_mask_alpha_out_of_range(self): + plot = _img(n=16) + mask = np.zeros((16, 16), dtype=bool) + with pytest.raises(ValueError, match="alpha"): + plot.set_overlay_mask(mask, alpha=1.5) + with pytest.raises(ValueError, match="alpha"): + plot.set_overlay_mask(mask, alpha=-0.1) + + def test_set_overlay_mask_valid_color(self): + plot = _img(n=16) + mask = np.zeros((16, 16), dtype=bool) + plot.set_overlay_mask(mask, color="#aabbcc") + assert plot._state["overlay_mask_color"] == "#aabbcc" + + def test_set_overlay_mask_invalid_color(self): + plot = _img(n=16) + mask = np.zeros((16, 16), dtype=bool) + with pytest.raises(ValueError, match="color"): + plot.set_overlay_mask(mask, color="red") + with pytest.raises(ValueError, match="color"): + plot.set_overlay_mask(mask, color="#fff") + with pytest.raises(ValueError, match="color"): + plot.set_overlay_mask(mask, color="#GGGGGG") + + def test_set_overlay_mask_origin_lower_flips(self): + """For origin='lower' the mask is flipped to match the internally-flipped image.""" + fig, ax = apl.subplots(1, 1) + data = np.zeros((4, 4)) + plot = ax.imshow(data, origin="lower") + mask = np.zeros((4, 4), dtype=bool) + mask[0, :] = True # only the top row + plot.set_overlay_mask(mask) + raw = base64.b64decode(plot._state["overlay_mask_b64"]) + stored = np.frombuffer(raw, dtype=np.uint8).reshape(4, 4) + # After flipud the True row should be at the last row (index 3), not row 0 + assert stored[3, 0] == 255 + assert stored[0, 0] == 0 + + +# =========================================================================== +# Insets +# =========================================================================== + +class TestImshowInsets: + + def _fig_with_inset(self, **kwargs): + fig, ax = apl.subplots(1, 1, figsize=(500, 500)) + ax.imshow(np.zeros((64, 64))) + inset = fig.add_inset(0.25, 0.25, **kwargs) + return fig, inset + + def test_add_inset_returns_axes(self): + fig, inset = self._fig_with_inset(title="Test") + assert inset is not None + + def test_inset_default_state(self): + fig, inset = self._fig_with_inset() + assert inset.inset_state == "normal" + + def test_inset_minimize(self): + fig, inset = self._fig_with_inset() + inset.minimize() + assert inset.inset_state == "minimized" + + def test_inset_maximize(self): + fig, inset = self._fig_with_inset() + inset.maximize() + assert inset.inset_state == "maximized" + + def test_inset_restore(self): + fig, inset = self._fig_with_inset() + inset.minimize() + inset.restore() + assert inset.inset_state == "normal" + + def test_inset_with_plot(self): + fig, ax = apl.subplots(1, 1, figsize=(500, 500)) + ax.imshow(np.zeros((64, 64))) + inset = fig.add_inset(0.3, 0.3, corner="top-right", title="Profile") + inset.plot(np.sin(np.linspace(0, 2 * np.pi, 64)), color="#4fc3f7") + + def test_inset_with_imshow(self): + fig, ax = apl.subplots(1, 1, figsize=(500, 500)) + ax.imshow(np.zeros((64, 64))) + inset = fig.add_inset(0.3, 0.3, corner="bottom-left") + inset.imshow(np.ones((32, 32)), cmap="hot") + + def test_multiple_insets_same_corner(self): + fig, ax = apl.subplots(1, 1, figsize=(600, 600)) + ax.imshow(np.zeros((64, 64))) + i1 = fig.add_inset(0.25, 0.25, corner="top-right", title="I1") + i2 = fig.add_inset(0.25, 0.25, corner="top-right", title="I2") + assert i1 is not i2 + + +# =========================================================================== +# __repr__ +# =========================================================================== + +class TestImshowRepr: + + def test_repr_contains_dimensions_and_cmap(self): + fig, ax = apl.subplots(1, 1) + plot = ax.imshow(np.zeros((128, 256))) + r = repr(plot) + assert "Plot2D" in r + assert "256" in r # width + assert "128" in r # height + assert "gray" in r # default colormap + diff --git a/anyplotlib/tests/test_plot2d/test_imshow_rgb.py b/anyplotlib/tests/test_plot2d/test_imshow_rgb.py new file mode 100644 index 00000000..55776aa6 --- /dev/null +++ b/anyplotlib/tests/test_plot2d/test_imshow_rgb.py @@ -0,0 +1,117 @@ +""" +Tests for true-colour (H, W, 3|4) imshow support. + +Unit tests cover state encoding and dtype handling; Playwright tests verify +actual rendered pixel colours on the canvas. +""" +from __future__ import annotations + +import base64 + +import numpy as np +import pytest + +import anyplotlib as apl + + +def _rgb_quadrants(n=32): + """Image with pure-red TL, pure-green TR, pure-blue BL, white BR.""" + img = np.zeros((n, n, 3), dtype=np.uint8) + h = n // 2 + img[:h, :h] = [255, 0, 0] + img[:h, h:] = [0, 255, 0] + img[h:, :h] = [0, 0, 255] + img[h:, h:] = [255, 255, 255] + return img + + +class TestRgbState: + def test_uint8_rgb_sets_state(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(_rgb_quadrants()) + assert v._state["is_rgb"] is True + raw = base64.b64decode(v._state["image_b64"]) + assert len(raw) == 32 * 32 * 4 # RGBA bytes + assert raw[0:4] == bytes([255, 0, 0, 255]) + + def test_float_01_rgb_scaled(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.full((4, 4, 3), 0.5)) + raw = base64.b64decode(v._state["image_b64"]) + assert raw[0] == 127 or raw[0] == 128 # 0.5 * 255 + + def test_rgba_alpha_preserved(self): + img = np.zeros((4, 4, 4), dtype=np.uint8) + img[..., 3] = 99 + fig, ax = apl.subplots(1, 1) + v = ax.imshow(img) + raw = base64.b64decode(v._state["image_b64"]) + assert raw[3] == 99 + + def test_grayscale_unchanged(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((8, 8))) + assert v._state["is_rgb"] is False + assert len(base64.b64decode(v._state["image_b64"])) == 64 # 1 byte/px + + def test_two_channel_raises(self): + fig, ax = apl.subplots(1, 1) + with pytest.raises(ValueError, match="3 .RGB. or 4"): + ax.imshow(np.zeros((8, 8, 2))) + + def test_set_data_switches_modes(self): + fig, ax = apl.subplots(1, 1) + v = ax.imshow(np.zeros((8, 8))) + v.set_data(_rgb_quadrants(8)) + assert v._state["is_rgb"] is True + v.set_data(np.zeros((8, 8))) + assert v._state["is_rgb"] is False + + def test_origin_lower_flips_rgb(self): + img = np.zeros((2, 2, 3), dtype=np.uint8) + img[0, 0] = [255, 0, 0] # red in row 0 + fig, ax = apl.subplots(1, 1) + v = ax.imshow(img, origin="lower") + raw = base64.b64decode(v._state["image_b64"]) + # flipud → red pixel is now in the LAST row, first column + last_row_first_px = raw[(2 * 1 + 0) * 4: (2 * 1 + 0) * 4 + 4] + assert last_row_first_px == bytes([255, 0, 0, 255]) + + +class TestRgbRendering: + def test_quadrant_colors_on_canvas(self, interact_page): + fig, ax = apl.subplots(1, 1, figsize=(300, 300)) + ax.imshow(_rgb_quadrants()) + page = interact_page(fig) + page.wait_for_timeout(150) + + px = page.evaluate("""() => { + const c = document.querySelector('canvas'); + const ctx = c.getContext('2d'); + const w = c.width, h = c.height; + const grab = (fx, fy) => Array.from( + ctx.getImageData(Math.round(w*fx), Math.round(h*fy), 1, 1).data); + return { tl: grab(0.25, 0.25), tr: grab(0.75, 0.25), + bl: grab(0.25, 0.75), br: grab(0.75, 0.75) }; + }""") + assert px["tl"][:3] == [255, 0, 0], f"top-left not red: {px['tl']}" + assert px["tr"][:3] == [0, 255, 0], f"top-right not green: {px['tr']}" + assert px["bl"][:3] == [0, 0, 255], f"bottom-left not blue: {px['bl']}" + assert px["br"][:3] == [255, 255, 255], f"bottom-right not white: {px['br']}" + + def test_colorbar_suppressed_for_rgb(self, interact_page): + fig, ax = apl.subplots(1, 1, figsize=(300, 300)) + q = np.linspace(0, 1, 32) + v = ax.imshow(_rgb_quadrants(), axes=[q, q]) + v.set_colorbar_visible(True) # must be ignored for RGB + page = interact_page(fig) + page.wait_for_timeout(150) + visible = page.evaluate("""() => { + for (const c of document.querySelectorAll('canvas')) { + const left = parseFloat(c.style.left || '0'); + if (c.width <= 80 && left > 150 && c.style.display !== 'none') + return true; // a visible colorbar-sized canvas + } + return false; + }""") + assert not visible, "colorbar must stay hidden for RGB images" diff --git a/anyplotlib/tests/test_plot2d/test_pcolormesh.py b/anyplotlib/tests/test_plot2d/test_pcolormesh.py new file mode 100644 index 00000000..d62dee83 --- /dev/null +++ b/anyplotlib/tests/test_plot2d/test_pcolormesh.py @@ -0,0 +1,194 @@ +""" +tests/test_plot2d/test_pcolormesh.py +===================================== + +Tests for PlotMesh (pcolormesh) mirroring Examples/plot_pcolormesh.py. + +Covers: + * Basic construction with non-uniform edges + * Edge-count validation (wrong x/y edge count) + * set_colormap() — name and LUT update + * set_data() — replacement, units, wrong ndim, wrong edge count + * Markers: add_circles, add_lines, labels, mutate via .set() + * Marker restrictions: arrows and ellipses disallowed on mesh + * to_wire_list round-trip +""" +from __future__ import annotations + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.plot2d import PlotMesh + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _mesh(M=8, N=12) -> PlotMesh: + rng = np.random.default_rng(42) + data = rng.standard_normal((M, N)) + x_edges = np.linspace(0, N, N + 1) + y_edges = np.linspace(0, M, M + 1) + fig, ax = apl.subplots(1, 1) + return ax.pcolormesh(data, x_edges=x_edges, y_edges=y_edges) + + +def _log_mesh() -> PlotMesh: + """Mesh with non-uniform (log-spaced) x edges, as in the gallery example.""" + M, N = 32, 48 + rng = np.random.default_rng(1) + data = (np.sin(np.linspace(0, 3 * np.pi, N)) + + np.cos(np.linspace(0, 2 * np.pi, M))[:, None]) + data += rng.normal(scale=0.15, size=(M, N)) + x_edges = np.logspace(-1, 2, N + 1) + y_edges = np.linspace(0, 100, M + 1) + fig, ax = apl.subplots(1, 1) + return ax.pcolormesh(data, x_edges=x_edges, y_edges=y_edges, units="arb.") + + +# =========================================================================== +# Construction +# =========================================================================== + +class TestPlotMeshConstruction: + + def test_kind_is_2d(self): + mesh = _mesh() + assert mesh._state["kind"] == "2d" + + def test_is_mesh_flag(self): + mesh = _mesh() + assert mesh._state["is_mesh"] is True + + def test_x_axis_has_edges(self): + mesh = _mesh(M=8, N=12) + # x_axis stores edges (N+1 values) + assert len(mesh._state["x_axis"]) == 13 + + def test_y_axis_has_edges(self): + mesh = _mesh(M=8, N=12) + assert len(mesh._state["y_axis"]) == 9 + + def test_units_stored(self): + mesh = _log_mesh() + assert mesh._state["units"] == "arb." + + def test_log_x_edges_accepted(self): + """Non-uniform (log-spaced) edges should be accepted without error.""" + mesh = _log_mesh() + assert mesh._state["image_width"] == 48 + + def test_default_colormap_present(self): + mesh = _mesh() + assert "colormap_name" in mesh._state + + def test_wrong_x_edge_count(self): + data = np.ones((8, 12)) + x_edges = np.linspace(0, 10, 10) # should be 13 + y_edges = np.linspace(0, 8, 9) + with pytest.raises(ValueError): + fig, ax = apl.subplots(1, 1) + ax.pcolormesh(data, x_edges=x_edges, y_edges=y_edges) + + def test_wrong_y_edge_count(self): + data = np.ones((8, 12)) + x_edges = np.linspace(0, 12, 13) + y_edges = np.linspace(0, 10, 5) # should be 9 + with pytest.raises(ValueError): + fig, ax = apl.subplots(1, 1) + ax.pcolormesh(data, x_edges=x_edges, y_edges=y_edges) + + +# =========================================================================== +# Mutations +# =========================================================================== + +class TestPlotMeshMutations: + + def test_set_colormap(self): + mesh = _mesh() + mesh.set_colormap("viridis") + assert mesh._state["colormap_name"] == "viridis" + + def test_set_colormap_updates_lut(self): + mesh = _mesh() + mesh.set_colormap("plasma") + lut = mesh._state["colormap_data"] + assert isinstance(lut, list) + assert len(lut) == 256 + + def test_set_data_same_shape(self): + mesh = _mesh(M=8, N=12) + new_data = np.ones((8, 12)) + mesh.set_data(new_data) + assert mesh._state["image_width"] == 12 + + def test_set_data_with_new_units(self): + mesh = _mesh() + mesh.set_data(np.zeros((8, 12)), units="nm") + assert mesh._state["units"] == "nm" + + def test_set_data_wrong_ndim(self): + mesh = _mesh() + with pytest.raises(ValueError): + mesh.set_data(np.zeros(12)) + + def test_set_data_wrong_x_edges(self): + mesh = _mesh(M=8, N=12) + new_data = np.zeros((8, 12)) + bad_x = np.linspace(0, 10, 5) + with pytest.raises(ValueError): + mesh.set_data(new_data, x_edges=bad_x) + + +# =========================================================================== +# Markers +# =========================================================================== + +class TestPlotMeshMarkers: + + def test_add_circles(self): + mesh = _mesh() + pts = np.array([[2.0, 2.0], [6.0, 4.0]]) + mesh.add_circles(pts, name="peaks", radius=0.5, edgecolors="#ff1744") + assert "peaks" in mesh.markers["circles"] + + def test_add_circles_with_labels(self): + mesh = _mesh() + pts = np.array([[1.0, 2.0], [5.0, 4.0], [9.0, 6.0], [11.0, 2.0]]) + mesh.add_circles(pts, name="pks", radius=0.3, + edgecolors="#ff1744", facecolors="#ff174433", + labels=["A", "B", "C", "D"]) + wl = mesh.markers.to_wire_list() + assert any(w.get("labels") == ["A", "B", "C", "D"] for w in wl) + + def test_add_lines(self): + mesh = _mesh() + segs = [[[1.0, 1.0], [5.0, 5.0]], [[5.0, 5.0], [10.0, 2.0]]] + mesh.add_lines(segs, name="path", edgecolors="#00e5ff") + assert "path" in mesh.markers["lines"] + + def test_arrows_disallowed_on_mesh(self): + mesh = _mesh() + with pytest.raises(ValueError, match="not allowed"): + mesh.add_arrows([[0.0, 0.0]], [1.0], [1.0]) + + def test_ellipses_disallowed_on_mesh(self): + mesh = _mesh() + with pytest.raises(ValueError, match="not allowed"): + mesh.add_ellipses([[0.0, 0.0]], widths=5, heights=3) + + def test_circles_mutate_via_set(self): + mesh = _mesh() + mesh.add_circles([[2.0, 2.0]], name="c", radius=1.0) + mesh.markers["circles"]["c"].set(radius=2.0) + assert mesh.markers["circles"]["c"]._data["radius"] == 2.0 + + def test_to_wire_list_contains_circles(self): + mesh = _mesh() + mesh.add_circles([[2.0, 2.0]], name="spot") + wl = mesh.markers.to_wire_list() + assert any(w["type"] == "circles" for w in wl) + diff --git a/anyplotlib/tests/test_plot2d/test_plot2d_api.py b/anyplotlib/tests/test_plot2d/test_plot2d_api.py new file mode 100644 index 00000000..5276cf76 --- /dev/null +++ b/anyplotlib/tests/test_plot2d/test_plot2d_api.py @@ -0,0 +1,506 @@ +""" +tests/test_plot2d/test_plot2d_api.py +===================================== +Cross-cutting API and regression tests for the anyplotlib plot2d module. +Covers: + * __repr__ for Plot1D, Plot2D, Plot3D, PlotBar + * Plot1D.add_circles still uses "points" wire type (regression guard) + * cividis colormap alias resolves to a valid colorcet palette + * Top-level public imports: Plot1D, Plot2D, Axes, CallbackRegistry, Event + * __all__ completeness: all names in anyplotlib.__all__ exist on the module + * No debug print in Figure._on_event +""" +from __future__ import annotations +import numpy as np +import pytest +import anyplotlib as apl +from anyplotlib.plot1d import Plot1D, PlotBar +from anyplotlib.plot2d import Plot2D +from anyplotlib.plot3d import Plot3D +from anyplotlib.callbacks import CallbackRegistry, Event +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +def _make_plot2d(shape=(32, 32)) -> Plot2D: + fig, ax = apl.subplots(1, 1) + return ax.imshow(np.zeros(shape)) +def _make_plot1d(n=64) -> Plot1D: + fig, ax = apl.subplots(1, 1) + return ax.plot(np.zeros(n)) +def _make_plot3d() -> Plot3D: + fig, ax = apl.subplots(1, 1) + x = np.linspace(0, 1, 4) + y = np.linspace(0, 1, 4) + X, Y = np.meshgrid(x, y) + Z = X + Y + return ax.plot_surface(X, Y, Z) +# =========================================================================== +# __repr__ +# =========================================================================== +class TestRepr: + def test_plot2d_repr(self): + plot = _make_plot2d((128, 256)) + r = repr(plot) + assert "Plot2D" in r + assert "256" in r + assert "128" in r + assert "gray" in r + def test_plot1d_repr(self): + plot = _make_plot1d(100) + r = repr(plot) + assert "Plot1D" in r + assert "100" in r + def test_plot3d_repr(self): + plot = _make_plot3d() + r = repr(plot) + assert "Plot3D" in r + assert "surface" in r + def test_plotbar_repr(self): + fig, ax = apl.subplots(1, 1) + plot = ax.bar([1, 2, 3]) + r = repr(plot) + assert "PlotBar" in r + assert "3" in r +# =========================================================================== +# Marker type regression +# =========================================================================== +def test_plot1d_add_circles_still_uses_points(): + """Plot1D.add_circles should continue to use the "points" wire type.""" + plot = _make_plot1d() + offsets = np.array([10.0, 20.0, 30.0]) + plot.add_circles(offsets, name="ev") + wire = plot.markers.to_wire_list() + assert wire[0]["type"] == "points" +# =========================================================================== +# Colormap alias +# =========================================================================== +def test_cividis_alias_resolves(): + from anyplotlib._utils import _build_colormap_lut, _CMAP_ALIASES + alias = _CMAP_ALIASES.get("cividis", "cividis") + assert alias != "dimgray" + import colorcet as cc + assert alias in cc.palette + lut = _build_colormap_lut("cividis") + assert len(lut) == 256 + assert lut[0] != lut[-1] +# =========================================================================== +# Top-level public API +# =========================================================================== +def test_top_level_imports(): + from anyplotlib import Plot1D, Plot2D, Axes, CallbackRegistry, Event # noqa: F401 + assert Plot1D is not None + assert Plot2D is not None + assert Axes is not None + assert CallbackRegistry is not None + assert Event is not None +def test_top_level_all(): + import anyplotlib + for name in anyplotlib.__all__: + assert hasattr(anyplotlib, name), f"anyplotlib.{name} not found" +# =========================================================================== +# No debug print in Figure._on_event +# =========================================================================== +def test_no_debug_print_in_on_event(capsys): + import json + fig, ax = apl.subplots(1, 1) + plot = ax.plot(np.zeros(16)) + payload = { + "source": "js", + "panel_id": plot._id, + "event_type": "on_changed", + "zoom": 1.5, + "center_x": 0.5, + "center_y": 0.5, + } + fig._on_event({"new": json.dumps(payload)}) + captured = capsys.readouterr() + assert captured.out == "", f"Unexpected stdout: {captured.out!r}" + + +# =========================================================================== +# Phase 2 — Plot2D state methods +# =========================================================================== + +class TestPlot2DLabels: + + def test_set_xlabel(self): + p = _make_plot2d() + p.set_xlabel("x (nm)") + assert p._state["x_label"] == "x (nm)" + + def test_set_ylabel(self): + p = _make_plot2d() + p.set_ylabel("y (nm)") + assert p._state["y_label"] == "y (nm)" + + def test_set_title(self): + p = _make_plot2d() + p.set_title("My Image") + assert p._state["title"] == "My Image" + + def test_set_colorbar_label(self): + p = _make_plot2d() + p.set_colorbar_label("Intensity") + assert p._state["colorbar_label"] == "Intensity" + + def test_default_labels_empty(self): + p = _make_plot2d() + assert p._state["x_label"] == "" + assert p._state["y_label"] == "" + assert p._state["title"] == "" + assert p._state["colorbar_label"] == "" + + +class TestPlot2DAxisLimits: + + def test_set_xlim_delegates_to_set_view(self): + p = _make_plot2d((32, 32)) + p.set_xlim(5, 20) + assert p._state["zoom"] != 1.0 or p._state["center_x"] != 0.5 + + def test_set_ylim_delegates_to_set_view(self): + p = _make_plot2d((32, 32)) + p.set_ylim(5, 20) + assert p._state["zoom"] != 1.0 or p._state["center_y"] != 0.5 + + def test_get_ylim_returns_y_axis_bounds(self): + fig, ax = apl.subplots(1, 1) + y_axis = np.linspace(0.0, 5.0, 32) + p = ax.imshow(np.zeros((32, 32)), axes=[np.arange(32), y_axis]) + lo, hi = p.get_ylim() + assert lo == pytest.approx(0.0) + assert hi == pytest.approx(5.0) + + def test_get_xbound_returns_x_axis_bounds(self): + fig, ax = apl.subplots(1, 1) + x_axis = np.linspace(-1.0, 3.0, 32) + p = ax.imshow(np.zeros((32, 32)), axes=[x_axis, np.arange(32)]) + lo, hi = p.get_xbound() + assert lo == pytest.approx(-1.0) + assert hi == pytest.approx(3.0) + + +class TestPlot2DExtent: + + def test_set_extent_updates_axes(self): + p = _make_plot2d((32, 32)) + x_new = np.linspace(0.0, 10.0, 32) + y_new = np.linspace(0.0, 20.0, 32) + p.set_extent(x_new, y_new) + assert p._state["x_axis"][0] == pytest.approx(0.0) + assert p._state["x_axis"][-1] == pytest.approx(10.0) + assert p._state["y_axis"][-1] == pytest.approx(20.0) + + def test_set_extent_updates_scale(self): + p = _make_plot2d((32, 32)) + x_new = np.linspace(0.0, 31.0, 32) + y_new = np.linspace(0.0, 62.0, 32) + p.set_extent(x_new, y_new) + assert p._state["scale_x"] == pytest.approx(1.0) + assert p._state["scale_y"] == pytest.approx(2.0) + + +class TestPlot2DColorbar: + + def test_set_colorbar_visible_true(self): + p = _make_plot2d() + p.set_colorbar_visible(True) + assert p._state["show_colorbar"] is True + + def test_set_colorbar_visible_false(self): + p = _make_plot2d() + p.set_colorbar_visible(True) + p.set_colorbar_visible(False) + assert p._state["show_colorbar"] is False + + +class TestPlot2DAspect: + + def test_set_aspect_float(self): + p = _make_plot2d() + p.set_aspect(2.0) + assert p._state["aspect"] == pytest.approx(2.0) + + def test_set_aspect_equal_string(self): + p = _make_plot2d() + p.set_aspect("equal") + assert p._state["aspect"] == pytest.approx(1.0) + + def test_set_aspect_none(self): + p = _make_plot2d() + p.set_aspect("equal") + p.set_aspect(None) + assert p._state["aspect"] is None + + +class TestPlot2DAxisVisibility: + + def test_set_axis_off(self): + p = _make_plot2d() + assert p._state["axis_visible"] is True + p.set_axis_off() + assert p._state["axis_visible"] is False + + def test_set_ticks_visible_false(self): + p = _make_plot2d() + p.set_ticks_visible(False) + assert p._state["x_ticks_visible"] is False + assert p._state["y_ticks_visible"] is False + + def test_set_ticks_visible_per_axis(self): + p = _make_plot2d() + p.set_ticks_visible(False, x=False, y=True) + assert p._state["x_ticks_visible"] is False + assert p._state["y_ticks_visible"] is True + + +class TestGetColorCycle: + + def test_get_color_cycle_returns_list(self): + import anyplotlib as apl + result = apl.get_color_cycle() + assert isinstance(result, list) + + def test_get_color_cycle_elements_are_strings(self): + import anyplotlib as apl + result = apl.get_color_cycle() + assert all(isinstance(c, str) for c in result) + + def test_get_color_cycle_returns_copy(self): + import anyplotlib as apl + a = apl.get_color_cycle() + b = apl.get_color_cycle() + a.append("extra") + assert len(b) == len(apl.get_color_cycle()) + + def test_get_color_cycle_nonempty(self): + import anyplotlib as apl + assert len(apl.get_color_cycle()) > 0 + + +# =========================================================================== +# Figure resize — Plot2D correctness +# =========================================================================== + +class TestFigureResizePlot2D: + """Figure resize correctly propagates to layout_json and Plot2D panel state. + + The _on_resize observer calls _push_layout() (which recomputes panel pixel + dimensions from the new fig_width/fig_height) then re-pushes every panel's + JSON. For Plot2D panels the panel JSON must still carry the full axis state + so the JS renderer can correctly position tick labels and scale the image. + """ + + def test_resize_updates_layout_fig_size(self): + """layout_json reflects the new fig_width and fig_height after resize.""" + import json + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.imshow(np.zeros((32, 32))) + + fig.fig_width = 800 + fig.fig_height = 600 + + layout = json.loads(fig.layout_json) + assert layout["fig_width"] == 800 + assert layout["fig_height"] == 600 + + def test_resize_updates_single_panel_dimensions(self): + """Panel width/height in layout_json match the new figure size (1×1 grid).""" + import json + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32))) + + fig.fig_width = 800 + fig.fig_height = 600 + + layout = json.loads(fig.layout_json) + spec = next(s for s in layout["panel_specs"] if s["id"] == plot._id) + assert spec["panel_width"] == 800 + assert spec["panel_height"] == 600 + + def test_resize_plot2d_with_axes_preserves_axis_state(self): + """Plot2D with physical axes keeps has_axes, x_axis, y_axis, and units after resize.""" + import json + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + x_axis = np.linspace(0.0, 10.0, 32) + y_axis = np.linspace(0.0, 20.0, 32) + plot = ax.imshow(np.zeros((32, 32)), axes=[x_axis, y_axis], units="nm") + + panel_before = json.loads(getattr(fig, f"panel_{plot._id}_json")) + + fig.fig_width = 800 + fig.fig_height = 600 + + panel_after = json.loads(getattr(fig, f"panel_{plot._id}_json")) + assert panel_after["has_axes"] is True + assert panel_after["x_axis"] == panel_before["x_axis"] + assert panel_after["y_axis"] == panel_before["y_axis"] + assert panel_after["units"] == "nm" + + def test_resize_does_not_alter_data_scale(self): + """Resizing the figure must not change Plot2D scale_x/scale_y (data-space quantities).""" + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + x_axis = np.linspace(0.0, 10.0, 32) + y_axis = np.linspace(0.0, 20.0, 32) + plot = ax.imshow(np.zeros((32, 32)), axes=[x_axis, y_axis], units="nm") + + scale_x_before = plot._state["scale_x"] + scale_y_before = plot._state["scale_y"] + + fig.fig_width = 800 + fig.fig_height = 600 + + assert plot._state["scale_x"] == pytest.approx(scale_x_before) + assert plot._state["scale_y"] == pytest.approx(scale_y_before) + + def test_resize_plot2d_with_axes_layout_kind(self): + """layout_json marks a Plot2D with axes as kind='2d' after resize.""" + import json + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + plot = ax.imshow(np.zeros((32, 32)), axes=[np.arange(32), np.arange(32)]) + + fig.fig_width = 640 + fig.fig_height = 480 + + layout = json.loads(fig.layout_json) + spec = next(s for s in layout["panel_specs"] if s["id"] == plot._id) + assert spec["kind"] == "2d" + + def test_resize_two_panel_splits_width_evenly(self): + """Both Plot2D panels in a 1×2 grid each get half the new figure width.""" + import json + fig, axs = apl.subplots(1, 2, figsize=(400, 200)) + plot_l = axs[0].imshow(np.zeros((16, 16))) + plot_r = axs[1].imshow(np.zeros((16, 16))) + + fig.fig_width = 800 + + layout = json.loads(fig.layout_json) + specs = {s["id"]: s for s in layout["panel_specs"]} + assert specs[plot_l._id]["panel_width"] == pytest.approx(400, abs=1) + assert specs[plot_r._id]["panel_width"] == pytest.approx(400, abs=1) + + def test_resize_with_height_ratios_scales_proportionally(self): + """GridSpec height_ratios [3, 1] scale correctly when fig_height changes.""" + import json + gs = apl.GridSpec(2, 1, height_ratios=[3, 1]) + fig = apl.Figure(figsize=(400, 400)) + plot_top = fig.add_subplot(gs[0, 0]).imshow(np.zeros((32, 32))) + plot_bot = fig.add_subplot(gs[1, 0]).imshow(np.zeros((16, 16))) + + fig.fig_height = 800 + + layout = json.loads(fig.layout_json) + specs = {s["id"]: s for s in layout["panel_specs"]} + # top: 3/4 × 800 = 600 px; bottom: 1/4 × 800 = 200 px + assert specs[plot_top._id]["panel_height"] == pytest.approx(600, abs=1) + assert specs[plot_bot._id]["panel_height"] == pytest.approx(200, abs=1) + + +# =========================================================================== +# Plot2D.get_xlim +# =========================================================================== + +class TestPlot2DGetXlim: + def test_get_xlim_exists(self): + p = _make_plot2d() + assert hasattr(p, "get_xlim") + + def test_get_xlim_with_physical_axes(self): + fig, ax = apl.subplots(1, 1) + x = np.linspace(0.0, 10.0, 16) + p = ax.imshow(np.zeros((16, 16)), axes=[x, np.linspace(0, 5, 16)], units="nm") + lo, hi = p.get_xlim() + assert lo == pytest.approx(0.0) + assert hi == pytest.approx(10.0) + + def test_get_xlim_and_get_ylim_match_axes(self): + fig, ax = apl.subplots(1, 1) + x = np.linspace(1.0, 5.0, 16) + y = np.linspace(2.0, 8.0, 16) + p = ax.imshow(np.zeros((16, 16)), axes=[x, y], units="m") + xlo, xhi = p.get_xlim() + ylo, yhi = p.get_ylim() + assert xlo == pytest.approx(1.0) + assert xhi == pytest.approx(5.0) + assert ylo == pytest.approx(2.0) + assert yhi == pytest.approx(8.0) + + +# =========================================================================== +# Plot2D: set_axis_on and no log_scale key +# =========================================================================== + +class TestPlot2DSetAxisOn: + def test_set_axis_on_restores(self): + fig, ax = apl.subplots(1, 1) + p = ax.imshow(np.zeros((8, 8)), units="px") + p.set_axis_off() + assert p._state["axis_visible"] is False + p.set_axis_on() + assert p._state["axis_visible"] is True + + def test_no_log_scale_key(self): + fig, ax = apl.subplots(1, 1) + p = ax.imshow(np.zeros((8, 8)), units="px") + assert "log_scale" not in p._state + + +class TestPlotMeshRepr: + def test_repr_is_plotmesh(self): + from anyplotlib.plot2d import PlotMesh + fig, ax = apl.subplots(1, 1) + p = ax.pcolormesh(np.ones((4, 6))) + r = repr(p) + assert r.startswith("PlotMesh(") + assert "4" in r + assert "6" in r + + def test_repr_not_plot2d(self): + from anyplotlib.plot2d import PlotMesh + fig, ax = apl.subplots(1, 1) + p = ax.pcolormesh(np.ones((3, 5))) + assert not repr(p).startswith("Plot2D(") + + +# =========================================================================== +# m2: configure_pointer_settled public on Plot2D +# =========================================================================== + +class TestPlot2DConfigurePointerSettled: + def test_public_method_exists(self): + fig, ax = apl.subplots(1, 1) + p = ax.imshow(np.zeros((8, 8)), units="px") + assert hasattr(p, "configure_pointer_settled") + assert callable(p.configure_pointer_settled) + + def test_sets_state(self): + fig, ax = apl.subplots(1, 1) + p = ax.imshow(np.zeros((8, 8)), units="px") + p.configure_pointer_settled(150, 3) + assert p._state["pointer_settled_ms"] == 150 + assert p._state["pointer_settled_delta"] == 3 + + +# =========================================================================== +# m3: set_title / set_xlabel / set_ylabel direct tests on Plot2D +# =========================================================================== + +class TestPlot2DDisplayMethods: + def test_set_title(self): + fig, ax = apl.subplots(1, 1) + p = ax.imshow(np.zeros((8, 8)), units="px") + p.set_title("My Image") + assert p._state["title"] == "My Image" + + def test_set_xlabel(self): + fig, ax = apl.subplots(1, 1) + p = ax.imshow(np.zeros((8, 8)), units="px") + p.set_xlabel("x (nm)") + assert p._state["x_label"] == "x (nm)" + + def test_set_ylabel(self): + fig, ax = apl.subplots(1, 1) + p = ax.imshow(np.zeros((8, 8)), units="px") + p.set_ylabel("y (nm)") + assert p._state["y_label"] == "y (nm)" diff --git a/anyplotlib/tests/test_plot3d/__init__.py b/anyplotlib/tests/test_plot3d/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/anyplotlib/tests/test_plot3d/test_colors_highlight.py b/anyplotlib/tests/test_plot3d/test_colors_highlight.py new file mode 100644 index 00000000..574539d8 --- /dev/null +++ b/anyplotlib/tests/test_plot3d/test_colors_highlight.py @@ -0,0 +1,174 @@ +""" +Tests for Plot3D per-point scatter colors, the highlight point, and the +bounds override — the capabilities behind the IPF explorer example. +""" +from __future__ import annotations + +import base64 + +import numpy as np +import pytest + +import anyplotlib as apl + + +def _scatter(**kwargs): + fig, ax = apl.subplots(1, 1, figsize=(300, 300)) + pts = np.array([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]) + return ax.scatter3d(pts[:, 0], pts[:, 1], pts[:, 2], **kwargs) + + +class TestPointColors: + def test_hex_list(self): + v = _scatter(colors=["#ff0000", "#00ff00", "#0000ff"]) + raw = base64.b64decode(v._state["point_colors_b64"]) + assert list(raw) == [255, 0, 0, 0, 255, 0, 0, 0, 255] + + def test_float_array(self): + v = _scatter(colors=np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1.0]])) + raw = base64.b64decode(v._state["point_colors_b64"]) + assert list(raw) == [255, 0, 0, 0, 255, 0, 0, 0, 255] + + def test_wrong_length_raises(self): + with pytest.raises(ValueError, match="2 colors for 3 points"): + _scatter(colors=["#ff0000", "#00ff00"]) + + def test_colors_on_surface_raises(self): + fig, ax = apl.subplots(1, 1) + g = np.linspace(0, 1, 4) + XX, YY = np.meshgrid(g, g) + with pytest.raises(ValueError, match="only supported for scatter"): + apl.Plot3D("surface", XX, YY, XX * YY, colors=["#fff"] * 16) + + def test_set_point_colors_update_and_clear(self): + v = _scatter() + assert v._state["point_colors_b64"] == "" + v.set_point_colors(["#112233"] * 3) + assert v._state["point_colors_b64"] != "" + v.set_point_colors(None) + assert v._state["point_colors_b64"] == "" + + +class TestHighlight: + def test_set_and_clear(self): + v = _scatter() + v.set_highlight(0.1, 0.2, 0.3, color="#ffffff", size=9) + hl = v._state["highlight"] + assert hl == {"x": 0.1, "y": 0.2, "z": 0.3, + "color": "#ffffff", "size": 9.0} + v.clear_highlight() + assert v._state["highlight"] is None + + +class TestSphere: + def test_set_and_clear(self): + v = _scatter(bounds=((-1, 1),) * 3) + v.set_sphere(1.0, color="#777777", alpha=0.2, wireframe=False) + assert v._state["sphere"] == {"radius": 1.0, "color": "#777777", + "alpha": 0.2, "wireframe": False} + v.clear_sphere() + assert v._state["sphere"] is None + + def test_sphere_renders_silhouette(self, interact_page): + """The shaded disk + wireframe must add substantial ink, bounded by + the silhouette circle.""" + def ink(with_sphere): + v = _scatter(bounds=((-1, 1),) * 3, point_size=2) + v.set_axis_off() + if with_sphere: + v.set_sphere(1.0) + page = interact_page(v._fig) + page.wait_for_timeout(200) + return page.evaluate("""() => { + const c = [...document.querySelectorAll('canvas')].find(x => x.style.position === 'relative' && x.style.display !== 'none'); + const d = c.getContext('2d').getImageData(0,0,c.width,c.height).data; + // count pixels that differ from the corner background + const bg = [d[0], d[1], d[2]]; + let n = 0; + for (let i = 0; i < d.length; i += 4) { + if (Math.abs(d[i]-bg[0])+Math.abs(d[i+1]-bg[1]) + +Math.abs(d[i+2]-bg[2]) > 24) n++; + } + return n; + }""") + + without = ink(False) + with_s = ink(True) + assert with_s > without + 2000, ( + f"sphere added too little ink: {without} -> {with_s}") + + +class TestBoundsOverride: + def test_bounds_fix_data_bounds(self): + v = _scatter(bounds=((-1, 1), (-1, 1), (-1, 1))) + assert v._state["data_bounds"] == { + "xmin": -1.0, "xmax": 1.0, "ymin": -1.0, "ymax": 1.0, + "zmin": -1.0, "zmax": 1.0} + + def test_set_data_preserves_bounds(self): + v = _scatter(bounds=((-1, 1),) * 3) + v.set_data([0.5], [0.5], [0.5]) + assert v._state["data_bounds"]["xmin"] == -1.0 + + def test_default_bounds_fit_data(self): + v = _scatter() + assert v._state["data_bounds"]["xmax"] == 1.0 + assert v._state["data_bounds"]["xmin"] == 0.0 + + +class TestRendering: + def test_colored_points_and_highlight_render(self, interact_page): + """Pure-coloured points and a white highlight must appear on canvas.""" + v = _scatter(colors=["#ff0000", "#00ff00", "#0000ff"], + point_size=10, bounds=((-1, 1),) * 3) + v.set_axis_off() + v.set_highlight(-0.6, -0.6, -0.6, color="#ffffff", size=9) + fig = v._fig + page = interact_page(fig) + page.wait_for_timeout(200) + + found = page.evaluate("""() => { + const c = [...document.querySelectorAll('canvas')].find(x => x.style.position === 'relative' && x.style.display !== 'none'); + const d = c.getContext('2d').getImageData(0, 0, c.width, c.height).data; + const seen = { red: false, green: false, blue: false, white: false }; + for (let i = 0; i < d.length; i += 4) { + const r = d[i], g = d[i+1], b = d[i+2]; + if (r > 220 && g < 60 && b < 60) seen.red = true; + if (g > 220 && r < 60 && b < 60) seen.green = true; + if (b > 220 && r < 60 && g < 60) seen.blue = true; + if (r > 240 && g > 240 && b > 240) seen.white = true; + } + return seen; + }""") + assert found["red"] and found["green"] and found["blue"], ( + f"per-point colours missing from canvas: {found}") + assert found["white"], f"highlight dot missing from canvas: {found}" + + def test_highlight_moves_with_set_view(self, interact_page): + """After rotate-to-face, the highlight must sit near panel centre.""" + v = _scatter(bounds=((-1, 1),) * 3, point_size=2) + v.set_axis_off() + d = np.array([0.3, 0.4, 0.866]) + d = d / np.linalg.norm(d) + v.set_highlight(*d, color="#ff00ff", size=8) + # Turntable face-camera: el = asin(vz), az = atan2(vx, -vy) + el = float(np.degrees(np.arcsin(np.clip(d[2], -1, 1)))) + az = float(np.degrees(np.arctan2(d[0], -d[1]))) + v.set_view(azimuth=az, elevation=el) + page = interact_page(v._fig) + page.wait_for_timeout(200) + + pos = page.evaluate("""() => { + const c = [...document.querySelectorAll('canvas')].find(x => x.style.position === 'relative' && x.style.display !== 'none'); + const d = c.getContext('2d').getImageData(0, 0, c.width, c.height).data; + let sx = 0, sy = 0, n = 0; + for (let y = 0; y < c.height; y++) for (let x = 0; x < c.width; x++) { + const i = (y * c.width + x) * 4; + if (d[i] > 220 && d[i+1] < 80 && d[i+2] > 220) { sx += x; sy += y; n++; } + } + return n ? { x: sx / n, y: sy / n, n, w: c.width, h: c.height } : null; + }""") + assert pos is not None, "magenta highlight not found on canvas" + # Facing the camera ⇒ projected at the panel centre (within tolerance) + assert abs(pos["x"] - pos["w"] / 2) < 6, f"highlight off-centre x: {pos}" + assert abs(pos["y"] - pos["h"] / 2) < 6, f"highlight off-centre y: {pos}" diff --git a/anyplotlib/tests/test_plot3d/test_gpu_fallback.py b/anyplotlib/tests/test_plot3d/test_gpu_fallback.py new file mode 100644 index 00000000..6ec99370 --- /dev/null +++ b/anyplotlib/tests/test_plot3d/test_gpu_fallback.py @@ -0,0 +1,240 @@ +""" +Tests for the WebGPU scatter path — focused on the FALLBACK CONTRACT. + +A real GPU adapter is rarely available in CI (headless Chromium exposes +``navigator.gpu`` but ``requestAdapter()`` returns null without Vulkan/ +lavapipe), so these tests assert the thing that must hold *everywhere*: +when the GPU is unavailable, a GPU-requesting scatter renders identically +to the Canvas2D path and ``gpu_active`` reports False. + +The actual GPU render is validated manually on a real-GPU machine; see +WEBGPU_PLAN.md Phase 1 acceptance. +""" +from __future__ import annotations + +import json + +import numpy as np +import pytest + +import anyplotlib as apl + + +def _scatter(n=100, **kwargs): + fig, ax = apl.subplots(1, 1, figsize=(320, 320)) + rng = np.random.default_rng(1) + pts = rng.uniform(-1, 1, size=(n, 3)) + return ax.scatter3d(pts[:, 0], pts[:, 1], pts[:, 2], + bounds=((-1, 1),) * 3, **kwargs) + + +class TestGpuApi: + def test_default_mode_auto(self): + assert _scatter()._state["gpu_mode"] == "auto" + + def test_gpu_true_is_always(self): + assert _scatter(gpu=True)._state["gpu_mode"] == "always" + + def test_gpu_false_is_off(self): + assert _scatter(gpu=False)._state["gpu_mode"] == "off" + + def test_gpu_active_starts_false(self): + assert _scatter()._gpu_active is False + assert _scatter().gpu_active is False + + def test_gpu_status_echo_updates_active(self): + v = _scatter() + fig = v._fig + fig._dispatch_event(json.dumps({ + "panel_id": v._id, "event_type": "gpu_status", "gpu_active": True})) + assert v.gpu_active is True + fig._dispatch_event(json.dumps({ + "panel_id": v._id, "event_type": "gpu_status", "gpu_active": False})) + assert v.gpu_active is False + + def test_gpu_only_for_scatter(self): + # voxels/surface don't carry gpu_mode into the GPU path (Phase 1 = + # points only); the kwarg simply isn't offered there. Sanity: scatter + # has the field, surface does not error. + assert "gpu_mode" in _scatter()._state + + +class TestFallbackRendersOnCanvas: + """gpu='always' with no adapter MUST render via Canvas2D, unchanged.""" + + def _red_ink(self, page): + return page.evaluate("""() => { + const cs = [...document.querySelectorAll('canvas')]; + const c = cs.find(x => !x.style.zIndex || x.style.zIndex === '1'); + const d = c.getContext('2d').getImageData(0,0,c.width,c.height).data; + let red = 0; + for (let i = 0; i < d.length; i += 4) + if (d[i] > 180 && d[i+1] < 140 && d[i+2] < 140) red++; + return red; + }""") + + def test_always_falls_back_to_canvas(self, interact_page): + v = _scatter(n=2000, gpu="always", + colors=np.tile([255, 80, 80], (2000, 1)).astype(np.uint8), + point_size=4) + v.set_axis_off() + page = interact_page(v._fig) + page.wait_for_timeout(400) # allow the async device probe to resolve + # When requestAdapter() is null the gpuCanvas stays hidden … + disp = page.evaluate("""() => { + const g = [...document.querySelectorAll('canvas')] + .find(c => c.style.zIndex === '0'); + return g ? g.style.display : 'none'; + }""") + assert disp == 'none', "gpuCanvas must stay hidden without an adapter" + # … and the points still render on the 2D canvas. + assert self._red_ink(page) > 500, "canvas fallback produced no points" + + def test_auto_small_cloud_uses_canvas(self, interact_page): + """Below the threshold, 'auto' never even probes the GPU.""" + v = _scatter(n=500, gpu="auto", + colors=np.tile([255, 80, 80], (500, 1)).astype(np.uint8), + point_size=4) + v.set_axis_off() + page = interact_page(v._fig) + page.wait_for_timeout(300) + assert self._red_ink(page) > 200 + + def test_gpu_off_renders_canvas(self, interact_page): + v = _scatter(n=1000, gpu=False, + colors=np.tile([255, 80, 80], (1000, 1)).astype(np.uint8), + point_size=4) + v.set_axis_off() + page = interact_page(v._fig) + page.wait_for_timeout(300) + assert self._red_ink(page) > 300 + + def test_no_console_errors_on_fallback(self, interact_page): + v = _scatter(n=2000, gpu="always") + v.set_axis_off() + page = interact_page(v._fig) + errors = [] + page.on("pageerror", lambda e: errors.append(str(e))) + page.wait_for_timeout(400) + assert not errors, f"GPU fallback raised page errors: {errors}" + + +def _voxels(n_side=8, **kwargs): + fig, ax = apl.subplots(1, 1, figsize=(320, 320)) + g = np.arange(0, n_side, dtype=float) + zz, yy, xx = np.meshgrid(g, g, g, indexing="ij") + return ax.voxels(xx.ravel(), yy.ravel(), zz.ravel(), + bounds=((0, n_side - 1),) * 3, **kwargs) + + +class TestVoxelGpuFallback: + """gpu='always' voxels with no adapter MUST render via Canvas2D.""" + + def test_voxel_gpu_mode_state(self): + assert _voxels(gpu=True)._state["gpu_mode"] == "always" + assert _voxels(gpu=False)._state["gpu_mode"] == "off" + assert _voxels()._state["gpu_mode"] == "auto" + + def test_voxel_always_falls_back_to_canvas(self, interact_page): + colors = np.tile([255, 60, 60], (512, 1)).astype(np.uint8) + v = _voxels(colors=colors, alpha=0.4, gpu="always") + v.set_axis_off() + page = interact_page(v._fig) + page.wait_for_timeout(400) + disp = page.evaluate("""() => { + const g = [...document.querySelectorAll('canvas')] + .find(c => c.style.zIndex === '0'); + return g ? g.style.display : 'none'; + }""") + assert disp == 'none', "voxel gpuCanvas must stay hidden without adapter" + red = page.evaluate("""() => { + const c = [...document.querySelectorAll('canvas')] + .find(x => x.style.position === 'relative' && x.style.display !== 'none'); + const d = c.getContext('2d').getImageData(0,0,c.width,c.height).data; + let r = 0; + for (let i = 0; i < d.length; i += 4) + if (d[i] > 120 && d[i+1] < 120 && d[i+2] < 120) r++; + return r; + }""") + assert red > 500, "voxel canvas fallback produced no cubes" + + def test_voxel_gpu_no_console_errors(self, interact_page): + v = _voxels(colors=np.tile([200, 80, 80], (512, 1)).astype(np.uint8), + gpu="always") + v.set_axis_off() + v.add_widget("plane", axis="z", position=4) + page = interact_page(v._fig) + errors = [] + page.on("pageerror", lambda e: errors.append(str(e))) + page.wait_for_timeout(400) + assert not errors, f"GPU voxel fallback raised errors: {errors}" + + def test_gpu_draw_failure_self_heals(self, interact_page, _pw_browser): + """A GPU device that ACTIVATES then throws mid-draw (e.g. Safari's + experimental WebGPU losing the device) must self-heal: the panel + re-renders on the canvas path in the same frame — voxels AND axes — + without the user needing to resize, and the plotCanvas background is + restored to opaque (not left transparent over a dead gpuCanvas). + """ + import pathlib, tempfile + from anyplotlib.tests.conftest import _build_interact_html + + colors = np.tile([255, 60, 60], (512, 1)).astype(np.uint8) + v = _voxels(colors=colors, alpha=0.5, gpu="always") # axes ON + html = _build_interact_html(v._fig) + with tempfile.NamedTemporaryFile( + suffix=".html", mode="w", encoding="utf-8", delete=False) as fh: + fh.write(html) + tmp = pathlib.Path(fh.name) + + # Fake navigator.gpu: adapter+device resolve (GPU ACTIVATES, plotCanvas + # goes transparent), but the first command encoder throws — the exact + # "worked beautifully then broke" Safari signature. + fake_gpu = """ + () => { + const tex = () => ({ createView:()=>({}), destroy:()=>{} }); + const buf = () => ({ destroy:()=>{} }); + const dev = { + lost: new Promise(()=>{}), + createShaderModule:()=>({}), createBuffer:()=>buf(), + createBindGroupLayout:()=>({}), createPipelineLayout:()=>({}), + createBindGroup:()=>({}), createTexture:()=>tex(), + createRenderPipeline:()=>({ getBindGroupLayout:()=>({}) }), + createCommandEncoder:()=>{ throw new Error('SIMULATED mid-draw GPU failure'); }, + queue:{ writeBuffer:()=>{}, submit:()=>{}, readTexture:()=>new Uint8Array(4) }, + }; + navigator.gpu = { + getPreferredCanvasFormat:()=>'bgra8unorm', + requestAdapter: async ()=>({ info:{}, requestDevice: async ()=>dev }), + }; + }""" + page = _pw_browser.new_page() + page.set_viewport_size({"width": 400, "height": 400}) + page.add_init_script(fake_gpu) + errors = [] + page.on("pageerror", lambda e: errors.append(str(e))) + try: + page.goto(tmp.as_uri()) + page.wait_for_function("() => window._aplReady === true", timeout=15000) + page.wait_for_timeout(600) + res = page.evaluate("""() => { + const cs = [...document.querySelectorAll('canvas')]; + const plot = cs.find(x => x.style.zIndex === '1'); + const gpu = cs.find(x => x.style.zIndex === '0'); + const d = plot.getContext('2d').getImageData(0,0,plot.width,plot.height).data; + let red = 0; + for (let i = 0; i < d.length; i += 4) + if (d[i] > 150 && d[i+1] < 130 && d[i+2] < 130) red++; + return { plotBg: plot.style.background, + gpuDisp: gpu ? gpu.style.display : null, red }; + }""") + finally: + page.close() + tmp.unlink(missing_ok=True) + + assert not errors, f"mid-draw GPU failure leaked errors: {errors}" + assert res["gpuDisp"] == "none", "dead gpuCanvas must be hidden" + assert res["plotBg"] and res["plotBg"] != "transparent", \ + f"plotCanvas bg must be restored to opaque, got {res['plotBg']!r}" + assert res["red"] > 500, \ + f"panel did not self-heal onto canvas (no voxels): {res}" diff --git a/anyplotlib/tests/test_plot3d/test_plot3d.py b/anyplotlib/tests/test_plot3d/test_plot3d.py new file mode 100644 index 00000000..5dcb8f4f --- /dev/null +++ b/anyplotlib/tests/test_plot3d/test_plot3d.py @@ -0,0 +1,361 @@ +""" +tests/test_plot3d.py +==================== + +Tests for Plot3D — surface, scatter, and line geometry types. +Mirrors the Examples/plot_3d.py gallery example. + +Covers: + * plot_surface with 2-D meshgrid arrays + * scatter3d + * plot3d (line) + * set_data() — replace geometry + * set_colormap() — change colormap + * set_view() — azimuth and elevation + * set_zoom() + * State dict keys and shape sanity checks + * Validation: bad geom_type, bad surface array shapes +""" +from __future__ import annotations + +import numpy as np +import pytest + +import anyplotlib as apl +from anyplotlib.plot3d import Plot3D + + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- + +def _surface(): + x = np.linspace(-2, 2, 10) + y = np.linspace(-2, 2, 10) + XX, YY = np.meshgrid(x, y) + ZZ = np.sin(np.sqrt(XX ** 2 + YY ** 2)) + fig, ax = apl.subplots(1, 1) + return ax.plot_surface(XX, YY, ZZ, colormap="viridis"), XX, YY, ZZ + + +def _scatter(): + rng = np.random.default_rng(1) + n = 50 + x, y, z = rng.uniform(-1, 1, n), rng.uniform(-1, 1, n), rng.uniform(-1, 1, n) + fig, ax = apl.subplots(1, 1) + return ax.scatter3d(x, y, z, color="#4fc3f7", point_size=3), x, y, z + + +def _line(): + t = np.linspace(0, 4 * np.pi, 50) + x, y, z = np.cos(t), np.sin(t), t / (4 * np.pi) + fig, ax = apl.subplots(1, 1) + return ax.plot3d(x, y, z, color="#ff7043"), x, y, z + + +# --------------------------------------------------------------------------- +# Construction +# --------------------------------------------------------------------------- + +class TestPlot3DConstruction: + + def test_surface_kind(self): + surf, *_ = _surface() + assert surf._state["kind"] == "3d" + assert surf._state["geom_type"] == "surface" + + def test_scatter_kind(self): + sc, *_ = _scatter() + assert sc._state["geom_type"] == "scatter" + + def test_line_kind(self): + ln, *_ = _line() + assert ln._state["geom_type"] == "line" + + def test_surface_has_vertices(self): + surf, *_ = _surface() + assert surf._state["vertices_count"] == 100 # 10×10 grid + + def test_surface_has_faces(self): + surf, *_ = _surface() + assert surf._state["faces_count"] > 0 + + def test_scatter_no_faces(self): + sc, *_ = _scatter() + assert sc._state["faces_count"] == 0 + + def test_colormap_name_stored(self): + surf, *_ = _surface() + assert surf._state["colormap_name"] == "viridis" + + def test_colormap_data_is_list(self): + surf, *_ = _surface() + lut = surf._state["colormap_data"] + assert isinstance(lut, list) + assert len(lut) == 256 + + def test_default_azimuth_elevation(self): + surf, *_ = _surface() + assert surf._state["azimuth"] == pytest.approx(-60.0) + assert surf._state["elevation"] == pytest.approx(30.0) + + def test_labels_stored(self): + x = np.linspace(-1, 1, 5) + y = np.linspace(-1, 1, 5) + XX, YY = np.meshgrid(x, y) + ZZ = XX * YY + fig, ax = apl.subplots(1, 1) + surf = ax.plot_surface(XX, YY, ZZ, x_label="a", y_label="b", z_label="c") + assert surf._state["x_label"] == "a" + assert surf._state["y_label"] == "b" + assert surf._state["z_label"] == "c" + + def test_bad_geom_type(self): + x = np.array([0.0, 1.0]) + with pytest.raises(ValueError): + Plot3D("cube", x, x, x) + + def test_surface_1d_xy_arrays(self): + """plot_surface also accepts 1-D x/y + 2-D z (meshgrid already done).""" + x = np.linspace(-1, 1, 5) + y = np.linspace(-1, 1, 5) + ZZ = np.ones((5, 5)) + fig, ax = apl.subplots(1, 1) + surf = ax.plot_surface(x, y, ZZ) + assert surf._state["vertices_count"] == 25 + + def test_surface_1d_xy_shape_mismatch(self): + x = np.linspace(-1, 1, 4) + y = np.linspace(-1, 1, 5) + ZZ = np.ones((5, 5)) + with pytest.raises(ValueError): + fig, ax = apl.subplots(1, 1) + ax.plot_surface(x, y, ZZ) + + def test_surface_bad_array_shape(self): + x = np.array([1.0, 2.0]) # 1-D but z is also 1-D → invalid + with pytest.raises(ValueError): + Plot3D("surface", x, x, x) + + +# --------------------------------------------------------------------------- +# Mutations +# --------------------------------------------------------------------------- + +class TestPlot3DMutations: + + def test_set_colormap(self): + surf, *_ = _surface() + surf.set_colormap("plasma") + assert surf._state["colormap_name"] == "plasma" + assert isinstance(surf._state["colormap_data"], list) + + def test_set_view_azimuth(self): + surf, *_ = _surface() + surf.set_view(azimuth=45.0) + assert surf._state["azimuth"] == pytest.approx(45.0) + + def test_set_view_elevation(self): + surf, *_ = _surface() + surf.set_view(elevation=60.0) + assert surf._state["elevation"] == pytest.approx(60.0) + + def test_set_view_both(self): + surf, *_ = _surface() + surf.set_view(azimuth=30.0, elevation=40.0) + assert surf._state["azimuth"] == pytest.approx(30.0) + assert surf._state["elevation"] == pytest.approx(40.0) + + def test_set_zoom(self): + surf, *_ = _surface() + surf.set_zoom(2.0) + assert surf._state["zoom"] == pytest.approx(2.0) + + def test_set_data_surface(self): + surf, XX, YY, ZZ = _surface() + ZZ2 = np.cos(np.sqrt(XX ** 2 + YY ** 2)) + surf.set_data(XX, YY, ZZ2) + # vertices_count should stay the same (same grid) + assert surf._state["vertices_count"] == 100 + + def test_set_data_scatter(self): + sc, x, y, z = _scatter() + sc.set_data(x * 2, y * 2, z * 2) + bounds = sc._state["data_bounds"] + assert bounds["xmax"] > bounds["xmin"] + + def test_set_data_line(self): + ln, x, y, z = _line() + ln.set_data(x[::-1], y[::-1], z[::-1]) + assert ln._state["vertices_count"] == len(x) + + def test_set_data_surface_bad_shape(self): + surf, XX, YY, ZZ = _surface() + x = np.array([1.0, 2.0]) + with pytest.raises(ValueError): + surf.set_data(x, x, x) + + def test_set_view_clears_view_from_python(self): + surf, *_ = _surface() + surf.set_view(azimuth=10.0) + assert surf._state["_view_from_python"] is False + + def test_set_zoom_clears_view_from_python(self): + surf, *_ = _surface() + surf.set_zoom(1.5) + assert surf._state["_view_from_python"] is False + + def test_reset_view_restores_defaults(self): + surf, *_ = _surface() + surf.set_view(azimuth=90.0, elevation=10.0) + surf.set_zoom(3.0) + surf.reset_view() + assert surf._state["azimuth"] == pytest.approx(-60.0) + assert surf._state["elevation"] == pytest.approx(30.0) + assert surf._state["zoom"] == pytest.approx(1.0) + assert surf._state["_view_from_python"] is False + + def test_reset_view_uses_constructor_angles(self): + x = np.linspace(-1, 1, 5) + y = np.linspace(-1, 1, 5) + XX, YY = np.meshgrid(x, y) + ZZ = XX * YY + fig, ax = apl.subplots(1, 1) + surf = ax.plot_surface(XX, YY, ZZ, azimuth=15.0, elevation=45.0, zoom=2.0) + surf.set_view(azimuth=0.0, elevation=0.0) + surf.reset_view() + assert surf._state["azimuth"] == pytest.approx(15.0) + assert surf._state["elevation"] == pytest.approx(45.0) + assert surf._state["zoom"] == pytest.approx(2.0) + + def test_set_xlabel(self): + surf, *_ = _surface() + surf.set_xlabel("time") + assert surf._state["x_label"] == "time" + + def test_set_ylabel(self): + surf, *_ = _surface() + surf.set_ylabel("depth") + assert surf._state["y_label"] == "depth" + + def test_set_zlabel(self): + surf, *_ = _surface() + surf.set_zlabel("intensity") + assert surf._state["z_label"] == "intensity" + + def test_set_title(self): + surf, *_ = _surface() + surf.set_title("My Surface") + assert surf._state["title"] == "My Surface" + + +# =========================================================================== +# repr() uses vertices_count, not len(vertices) +# =========================================================================== + +class TestPlot3DRepr: + def test_repr_uses_vertices_count(self): + """repr() must read vertices_count, not len(state['vertices']).""" + + class _FakePlot3D(Plot3D): + def __init__(self): + self._state = {"geom_type": "mesh", "vertices_count": 42} + self._id = "" + self._fig = None + + assert "n_vertices=42" in repr(_FakePlot3D()) + + def test_repr_zero_when_count_zero(self): + class _FakePlot3D(Plot3D): + def __init__(self): + self._state = {"geom_type": "scatter", "vertices_count": 0} + self._id = "" + self._fig = None + + assert "n_vertices=0" in repr(_FakePlot3D()) + + def test_repr_on_real_line(self): + _, x, y, z = _line() + # _line() creates a Plot3D via plot3d(); repr must not raise and must + # show the correct vertex count. + from anyplotlib.plot3d._plot3d import Plot3D as _P3D + # find the plot object returned by _line + ln, *_ = _line() + r = repr(ln) + assert "n_vertices=" in r + # vertex count must equal len(x), not 0 + assert f"n_vertices={len(x)}" in r + + + + +# =========================================================================== +# C1: title initialized in _state +# =========================================================================== + +class TestPlot3DTitle: + def test_title_initialized_empty(self): + surf, *_ = _surface() + assert "title" in surf._state + assert surf._state["title"] == "" + + def test_set_title_label_param(self): + surf, *_ = _surface() + surf.set_title("My Plot") + assert surf._state["title"] == "My Plot" + + def test_set_title_in_wire(self): + surf, *_ = _surface() + surf.set_title("Wire Test") + assert surf.to_state_dict()["title"] == "Wire Test" + + +# =========================================================================== +# C2: axis_on / axis_off on Plot3D +# =========================================================================== + +class TestPlot3DAxisVisibility: + def test_axis_visible_initialized_true(self): + surf, *_ = _surface() + assert surf._state["axis_visible"] is True + + def test_set_axis_off(self): + surf, *_ = _surface() + surf.set_axis_off() + assert surf._state["axis_visible"] is False + + def test_set_axis_on_restores(self): + surf, *_ = _surface() + surf.set_axis_off() + surf.set_axis_on() + assert surf._state["axis_visible"] is True + + +# =========================================================================== +# m1: data-bounds getters on Plot3D +# =========================================================================== + +class TestPlot3DLimGetters: + def test_get_xlim(self): + surf, XX, YY, ZZ = _surface() + lo, hi = surf.get_xlim() + assert lo == pytest.approx(float(XX.min())) + assert hi == pytest.approx(float(XX.max())) + + def test_get_ylim(self): + surf, XX, YY, ZZ = _surface() + lo, hi = surf.get_ylim() + assert lo == pytest.approx(float(YY.min())) + assert hi == pytest.approx(float(YY.max())) + + def test_get_zlim(self): + surf, XX, YY, ZZ = _surface() + lo, hi = surf.get_zlim() + assert lo == pytest.approx(float(ZZ.min())) + assert hi == pytest.approx(float(ZZ.max())) + + def test_get_xlim_scatter(self): + sc, x, y, z = _scatter() + lo, hi = sc.get_xlim() + assert lo == pytest.approx(float(x.min())) + assert hi == pytest.approx(float(x.max())) diff --git a/anyplotlib/tests/test_plot3d/test_voxels_planes.py b/anyplotlib/tests/test_plot3d/test_voxels_planes.py new file mode 100644 index 00000000..48a386a3 --- /dev/null +++ b/anyplotlib/tests/test_plot3d/test_voxels_planes.py @@ -0,0 +1,247 @@ +""" +Tests for the 'voxels' geometry and 3-D PlaneWidget slice selectors. +""" +from __future__ import annotations + +import json + +import numpy as np +import pytest + +import anyplotlib as apl + + +def _voxels(**kwargs): + fig, ax = apl.subplots(1, 1, figsize=(320, 320)) + g = np.arange(0, 8, dtype=float) + zz, yy, xx = np.meshgrid(g, g, g, indexing="ij") + return ax.voxels(xx.ravel(), yy.ravel(), zz.ravel(), + bounds=((0, 7),) * 3, **kwargs) + + +class TestVoxelsState: + def test_geom_and_alpha_state(self): + v = _voxels(size=1.0, alpha=0.2) + assert v._state["geom_type"] == "voxels" + assert v._state["voxel_size"] == 1.0 + assert v._state["voxel_alpha"] == 0.2 + assert v._state["voxel_slice_alpha"] == 0.95 + + def test_per_voxel_colors_allowed(self): + colors = np.zeros((512, 3), dtype=np.uint8) + v = _voxels(colors=colors) + assert v._state["point_colors_b64"] != "" + + def test_set_point_colors_after_construction(self): + """The orthoslice explorer re-cuts slab voxels each drag via + set_data + set_point_colors, so voxels must accept post-hoc + per-voxel colours (not just at construction).""" + v = _voxels() + v.set_point_colors(np.zeros((512, 3), dtype=np.uint8)) + assert v._state["point_colors_b64"] != "" + v.set_point_colors(None) + assert v._state["point_colors_b64"] == "" + + def test_set_voxel_alpha(self): + v = _voxels() + v.set_voxel_alpha(0.1, slice_alpha=0.8) + assert v._state["voxel_alpha"] == 0.1 + assert v._state["voxel_slice_alpha"] == 0.8 + + +class TestPlaneWidget: + def test_add_plane_serialises(self): + v = _voxels() + pw = v.add_widget("plane", axis="z", position=4, color="#40c4ff") + ws = v._state["overlay_widgets"] + assert len(ws) == 1 + assert ws[0]["type"] == "plane" + assert ws[0]["axis"] == "z" + assert ws[0]["position"] == 4.0 + + def test_invalid_axis_raises(self): + v = _voxels() + with pytest.raises(ValueError, match="axis must be"): + v.add_widget("plane", axis="w", position=0) + + def test_only_plane_kind(self): + v = _voxels() + with pytest.raises(ValueError, match="only 'plane'"): + v.add_widget("crosshair") + + def test_set_position_from_python(self): + v = _voxels() + pw = v.add_widget("plane", axis="x", position=2) + pw.set(position=5) + assert pw.position == 5 + + def test_remove_widget(self): + v = _voxels() + pw = v.add_widget("plane", axis="y", position=3) + v.remove_widget(pw) + v._push() + assert v._state["overlay_widgets"] == [] + + def test_js_drag_event_round_trip(self): + """A JS plane-drag message must update position and fire callbacks.""" + v = _voxels() + pw = v.add_widget("plane", axis="z", position=4) + fig = v._fig + got = [] + + @pw.add_event_handler("pointer_move") + def on_drag(event): + got.append(pw.position) + + fig._dispatch_event(json.dumps({ + "panel_id": v._id, "widget_id": pw.id, + "event_type": "pointer_move", "axis": "z", "position": 6.25, + })) + assert got == [6.25] + assert pw.position == 6.25 + + +class TestVoxelRendering: + def test_voxels_render_with_slice_emphasis(self, interact_page): + """Voxels render; an on-plane slice draws more saturated ink.""" + colors = np.full((512, 3), [255, 0, 0], dtype=np.uint8) + v = _voxels(colors=colors, alpha=0.15) + v.set_axis_off() + v.add_widget("plane", axis="z", position=3, alpha=0.0) # invisible plane + page = interact_page(v._fig) + page.wait_for_timeout(250) + + res = page.evaluate("""() => { + const c = [...document.querySelectorAll('canvas')].find(x => x.style.position === 'relative' && x.style.display !== 'none'); + const d = c.getContext('2d').getImageData(0,0,c.width,c.height).data; + let pale = 0, strong = 0; + for (let i = 0; i < d.length; i += 4) { + const r = d[i], g = d[i+1], b = d[i+2]; + if (r > 180 && g < 160 && b < 160) { + if (g > 60) pale++; else strong++; // strong = opaque red + } + } + return { pale, strong }; + }""") + assert res["pale"] > 500, f"translucent voxel ink missing: {res}" + assert res["strong"] > 200, f"opaque slice-plane voxels missing: {res}" + + def test_voxel_gpu_canvas_layering(self, interact_page): + """The 3-D voxel panel stacks a gpuCanvas (z-index 0, WebGPU voxels) + below the plotCanvas (z-index 1, decorations). In canvas mode the + plotCanvas MUST keep an opaque background; the renderer only flips it + to ``transparent`` while the GPU path is active, so the GPU-drawn + voxels beneath aren't hidden by an opaque overlay. + + Regression for: large voxel volumes rendering "empty" (only planes + + highlight visible) in PyCharm's WebGPU-enabled JCEF, because the + opaque plotCanvas painted over the gpuCanvas. The active-GPU swap is + hardware-verified via native wgpu; CI has no adapter, so here we lock + the DOM stacking + the canvas-mode opaque-background invariant. + """ + colors = np.full((512, 3), [255, 0, 0], dtype=np.uint8) + v = _voxels(colors=colors, alpha=0.4) + v.set_axis_off() + page = interact_page(v._fig) + page.wait_for_timeout(200) + + layout = page.evaluate("""() => { + const cs = [...document.querySelectorAll('canvas')]; + const gpu = cs.find(x => x.style.zIndex === '0'); + const plot = cs.find(x => x.style.zIndex === '1'); + return { + hasGpu: !!gpu, + gpuBelow: !!gpu && !!plot, + plotBg: plot ? plot.style.background : null, + gpuDisp: gpu ? gpu.style.display : null, + }; + }""") + assert layout["hasGpu"], "3-D voxel panel must create a gpuCanvas" + assert layout["gpuDisp"] == "none", \ + "gpuCanvas stays hidden in canvas mode (no WebGPU adapter in CI)" + # Canvas mode: plotCanvas keeps an opaque bg (NOT transparent), so the + # canvas-drawn voxels read against a solid panel background. + assert layout["plotBg"] and layout["plotBg"] != "transparent", \ + f"canvas-mode plotCanvas must stay opaque, got {layout['plotBg']!r}" + + def test_plane_drag_in_browser(self, interact_page): + """Dragging a plane widget must change its position in the model.""" + v = _voxels(alpha=0.1) + v.set_axis_off() + pw = v.add_widget("plane", axis="z", position=3, alpha=0.3) + fig = v._fig + page = interact_page(fig) + page.wait_for_timeout(250) + + def js_position(): + return page.evaluate(f"""() => {{ + const st = JSON.parse(window._aplModel.get('panel_{v._id}_json')); + return st.overlay_widgets[0].position; + }}""") + + assert abs(js_position() - 3) < 1e-6 + # Locate the plane via its fully-opaque cyan border pixels, then drag + # from its centroid upward (the z screen-direction at the default view) + centre = page.evaluate("""() => { + const c = [...document.querySelectorAll('canvas')].find(x => x.style.position === 'relative' && x.style.display !== 'none'); + const r = c.getBoundingClientRect(); + const d = c.getContext('2d').getImageData(0,0,c.width,c.height).data; + let sx = 0, sy = 0, n = 0; + for (let y = 0; y < c.height; y++) for (let x = 0; x < c.width; x++) { + const i = (y * c.width + x) * 4; + if (d[i] < 60 && d[i+1] > 200 && d[i+2] > 230) { + sx += x; sy += y; n++; + } + } + return n ? { x: r.left + sx / n, y: r.top + sy / n, n } : null; + }""") + assert centre is not None, "plane border pixels not found on canvas" + page.mouse.move(centre["x"], centre["y"]) + page.mouse.down() + page.mouse.move(centre["x"], centre["y"] - 50, steps=8) + page.mouse.up() + page.wait_for_timeout(250) + moved = js_position() + assert abs(moved - 3) > 0.5, ( + f"plane did not move on drag (position still {moved})") + + +class TestPlaneDragNoSnapBack: + """Regression: a view-only push (set_highlight / set_view) must NOT clobber + a plane widget's live position — the "snap-back" symptom.""" + + def _voxels_with_plane(self): + fig, ax = apl.subplots(1, 1, figsize=(320, 320)) + g = np.arange(0, 8, dtype=float) + zz, yy, xx = np.meshgrid(g, g, g, indexing="ij") + v = ax.voxels(xx.ravel(), yy.ravel(), zz.ravel(), bounds=((0, 7),) * 3) + pw = v.add_widget("plane", axis="z", position=4) + return fig, v, pw + + def test_to_state_dict_reflects_live_widget(self): + fig, v, pw = self._voxels_with_plane() + pw.set(position=2.7) + st = v.to_state_dict() + z = next(w["position"] for w in st["overlay_widgets"] + if w["type"] == "plane" and w["axis"] == "z") + assert z == 2.7, f"to_state_dict serialised a stale plane position: {z}" + + def test_set_highlight_preserves_plane_position(self): + fig, v, pw = self._voxels_with_plane() + pw.set(position=2.7) # simulate a mid-drag float position + v.set_highlight(1, 2, 3) # view-only push on the same panel + import json + st = json.loads(getattr(fig, f"panel_{v._id}_json")) + z = next(w["position"] for w in st["overlay_widgets"] + if w["type"] == "plane" and w["axis"] == "z") + assert z == 2.7, f"set_highlight snapped the plane back to {z} (want 2.7)" + + def test_set_view_preserves_plane_position(self): + fig, v, pw = self._voxels_with_plane() + pw.set(position=5.3) + v.set_view(azimuth=10, elevation=20) + import json + st = json.loads(getattr(fig, f"panel_{v._id}_json")) + z = next(w["position"] for w in st["overlay_widgets"] + if w["type"] == "plane" and w["axis"] == "z") + assert z == 5.3, f"set_view snapped the plane back to {z} (want 5.3)" diff --git a/anyplotlib/tests/test_plotxy/test_plotxy.py b/anyplotlib/tests/test_plotxy/test_plotxy.py new file mode 100644 index 00000000..5b338a8d --- /dev/null +++ b/anyplotlib/tests/test_plotxy/test_plotxy.py @@ -0,0 +1,207 @@ +""" +PlotXY — the blank data-coordinate 2-D axis (matplotlib ``transData`` + +``PathCollection`` model): ``axes2d`` + ``scatter`` / ``plot`` / ``fill`` / +``text`` in data coords, with ``set_xlim`` / ``set_ylim`` / ``set_aspect``. +""" +import numpy as np + +import anyplotlib as apl + + +def test_axes2d_creates_plotxy(): + fig, ax = apl.subplots() + xy = ax.axes2d(xlim=(-1, 1), ylim=(-0.5, 0.9), aspect="equal") + assert isinstance(xy, apl.PlotXY) + assert xy.get_xlim() == (-1.0, 1.0) + assert xy.get_ylim() == (-0.5, 0.9) + assert xy.get_aspect() == "equal" + + +def test_set_lims_and_aspect(): + fig, ax = apl.subplots() + xy = ax.axes2d() + xy.set_xlim(-2, 3) + xy.set_ylim(-1, 5) + assert xy.get_xlim() == (-2.0, 3.0) + assert xy.get_ylim() == (-1.0, 5.0) + xy.set_aspect("equal") + assert xy.get_aspect() == "equal" + xy.set_aspect("auto") + assert xy.get_aspect() is None + + +def test_artists_are_data_coord_collections(): + fig, ax = apl.subplots() + xy = ax.axes2d(xlim=(0, 1), ylim=(0, 1)) + xy.scatter([0.1, 0.9], [0.2, 0.8], c=["#ff0000", "#00ff00"], s=8) + xy.plot([0, 1, 0.5], [0, 0, 1], color="#ffffff") + xy.fill([0, 1, 0.5], [0, 0, 1], facecolor="#eeeeee") + xy.text(0.5, 0.95, r"$[111]$") + + types = {m["type"] for m in xy.list_markers()} + assert {"points", "lines", "polygons", "texts"} <= types + + d = xy.to_state_dict() + # Reuses the 1-D data→canvas transform (matplotlib transLimits→transAxes), + # so every collection is in DATA coords — not image pixels. + assert d["kind"] == "1d" + for grp in d["markers"]: + assert grp.get("transform", "data") == "data" + + +def test_scatter_returns_collection_with_offsets(): + fig, ax = apl.subplots() + xy = ax.axes2d() + xy.scatter(np.array([0.0, 0.5, 1.0]), np.array([0.0, 0.5, 1.0]), s=6) + # PathCollection-style: one collection holding all three offsets. + pts = next(m for m in xy.list_markers() if m["type"] == "points") + assert pts["n"] == 3 + + +def test_double_click_reports_data_coords(interact_page): + """A ``double_click`` on a coordinate (PlotXY) panel reports ``xdata``/``ydata`` + in DATA coords (like the 2-D image path) — needed for a data-coord pick such + as the IPF-refine mask. Clicking panel-centre ⇒ centre of the x/y range.""" + from anyplotlib.tests.test_interactive._event_test_utils import ( + _collect_events, _get_events, _plot_center_page, + ) + fig, ax = apl.subplots(1, 1, figsize=(400, 300)) + ax.axes2d(xlim=(0, 10), ylim=(0, 20), aspect="equal") + page = interact_page(fig) + _collect_events(page) + + px, py = _plot_center_page(400, 300) + page.mouse.dblclick(px, py) + page.wait_for_timeout(100) + + evs = _get_events(page, "double_click") + assert evs, "expected a double_click event" + e = evs[-1] + assert e.get("xdata") is not None and e.get("ydata") is not None + assert abs(e["xdata"] - 5.0) < 1.5 # centre of x range (0, 10) + assert abs(e["ydata"] - 10.0) < 3.0 # centre of y range (0, 20) + + +def test_render_is_chromatic(take_screenshot): + """End-to-end: a filled triangle + coloured scatter + labels in data coords + must actually draw (canvas is chromatic, not blank).""" + fig, ax = apl.subplots(figsize=(360, 320)) + xy = ax.axes2d(xlim=(-0.05, 0.4), ylim=(-0.05, 0.4), aspect="equal") + xy.fill([0.0, 0.36, 0.0, 0.0], [0.0, 0.0, 0.36, 0.0], + facecolor="#223", edgecolor="#ffffff") + xy.scatter([0.05, 0.2, 0.3], [0.05, 0.1, 0.02], + c=["#ff3030", "#30ff60", "#3060ff"], s=10) + xy.text(0.0, 0.37, "[111]", color="#ffffff") + + arr = take_screenshot(fig) # H×W×C uint8 + rgb = arr[..., :3].astype(int) + spread = int((rgb.max(axis=2) - rgb.min(axis=2)).max()) + assert spread > 60 # genuinely coloured, not greyscale + + +def _red_bbox(arr): + """(x0, x1, y0, y1) bounding box of red-ish pixels, or None.""" + rgb = arr[..., :3].astype(int) + mask = (rgb[..., 0] > 150) & (rgb[..., 1] < 90) & (rgb[..., 2] < 90) + ys, xs = np.where(mask) + if xs.size == 0: + return None + return int(xs.min()), int(xs.max()), int(ys.min()), int(ys.max()) + + +def test_aspect_equal_renders_square(take_screenshot): + """Equal x & y spans drawn into a WIDE (2:1) panel: ``aspect="equal"`` must + apply matplotlib's ``apply_aspect`` — shrink + centre the data box to a + square (one data unit equal px on x & y), NOT stretch it to the panel.""" + fig, ax = apl.subplots(figsize=(640, 320)) + xy = ax.axes2d(xlim=(0, 1), ylim=(0, 1), aspect="equal") + xy.fill([0, 1, 0], [0, 0, 1], facecolor="#ff0000", edgecolor="#ff0000", alpha=1.0) + + bb = _red_bbox(take_screenshot(fig)) + assert bb is not None + w, h = bb[1] - bb[0], bb[3] - bb[2] + assert 0.8 < (w / h) < 1.25 # ~square, not stretched 2:1 + + +def test_pcolormesh_builds_polygon_mesh(): + """``pcolormesh`` → one polygons collection, one quad per (N, M) cell; + masked / non-finite cells are dropped (so an orix sector histogram clips + itself to the fundamental sector).""" + fig, ax = apl.subplots() + xy = ax.axes2d() + xe = np.linspace(0, 1, 4) # 3 columns of cells + ye = np.linspace(0, 1, 3) # 2 rows of cells + X, Y = np.meshgrid(xe, ye, indexing="ij") # (4, 3) corners + field = np.arange(3 * 2).reshape(3, 2).astype(float) # (3, 2) cells + xy.pcolormesh(X, Y, field) + poly = next(g for g in xy.to_state_dict()["markers"] if g["type"] == "polygons") + assert len(poly["vertices_list"]) == 6 # 3*2 cells + assert isinstance(poly["fill_color"], list) # per-cell colours + assert len(poly["fill_color"]) == 6 + + masked = np.ma.array(field, mask=[[True, False], [False, False], [False, False]]) + xy2 = ax.axes2d() + xy2.pcolormesh(X, Y, masked) + poly2 = next(g for g in xy2.to_state_dict()["markers"] if g["type"] == "polygons") + assert len(poly2["vertices_list"]) == 5 # one cell masked out + + +def test_pcolormesh_renders_gradient(take_screenshot): + """A scalar field drawn as a data-coord quad mesh (matplotlib ``pcolormesh``) + must render many distinct colormap colours — the primitive an IPF / pole + density heatmap needs.""" + fig, ax = apl.subplots(figsize=(320, 300)) + xy = ax.axes2d(xlim=(0, 1), ylim=(0, 1), aspect="equal") + n = 16 + xe = ye = np.linspace(0, 1, n + 1) + X, Y = np.meshgrid(xe, ye, indexing="ij") + gx, gy = np.meshgrid(np.linspace(0, 1, n), np.linspace(0, 1, n), indexing="ij") + xy.pcolormesh(X, Y, gx + gy, cmap="viridis") # smooth ramp + + arr = take_screenshot(fig) + rgb = arr[..., :3].astype(int) + assert int((rgb.max(2) - rgb.min(2)).max()) > 60 # chromatic + cols = {tuple(c) for c in rgb.reshape(-1, 3)[::29]} + assert len(cols) > 20 # a gradient, not flat + + +def _red_pixel_count(arr): + rgb = arr[..., :3].astype(int) + return int(((rgb[..., 0] > 150) & (rgb[..., 1] < 90) & (rgb[..., 2] < 90)).sum()) + + +def test_pcolormesh_clip_path_clips_mesh(take_screenshot): + """``clip_path`` (matplotlib set_clip_path): a full square mesh clipped to a + lower-left triangle draws only ~half the cells — the primitive that keeps an + IPF density mesh inside the curved fundamental-sector boundary.""" + n = 18 + xe = ye = np.linspace(0, 1, n + 1) + X, Y = np.meshgrid(xe, ye, indexing="ij") + field = np.full((n, n), "#ff0000", dtype=object) # every cell pure red + + fig, ax = apl.subplots(figsize=(300, 300)) + xy = ax.axes2d(xlim=(0, 1), ylim=(0, 1), aspect="equal") + xy.pcolormesh(X, Y, field) # full square + full = _red_pixel_count(take_screenshot(fig)) + + fig2, ax2 = apl.subplots(figsize=(300, 300)) + xy2 = ax2.axes2d(xlim=(0, 1), ylim=(0, 1), aspect="equal") + xy2.pcolormesh(X, Y, field, clip_path=[[0, 0], [1, 0], [0, 1]]) # lower-left ½ + clipped = _red_pixel_count(take_screenshot(fig2)) + + assert clipped > 0 # mesh still drawn + assert clipped < 0.7 * full # ~half clipped away (triangle) + + +def test_aspect_auto_fills_panel(take_screenshot): + """Without ``aspect="equal"`` the same triangle stretches to fill the wide + panel (the data box follows the panel aspect) — the contrast that proves the + equal-aspect step is actually doing something.""" + fig, ax = apl.subplots(figsize=(640, 320)) + xy = ax.axes2d(xlim=(0, 1), ylim=(0, 1)) # aspect None → fill panel + xy.fill([0, 1, 0], [0, 0, 1], facecolor="#ff0000", edgecolor="#ff0000", alpha=1.0) + + bb = _red_bbox(take_screenshot(fig)) + assert bb is not None + w, h = bb[1] - bb[0], bb[3] - bb[2] + assert (w / h) > 1.4 # stretched wide (panel is 2:1) diff --git a/anyplotlib/widgets/__init__.py b/anyplotlib/widgets/__init__.py new file mode 100644 index 00000000..b8b4ff1b --- /dev/null +++ b/anyplotlib/widgets/__init__.py @@ -0,0 +1,18 @@ +"""anyplotlib.widgets — interactive overlay widget classes.""" +from anyplotlib.widgets._base import Widget +from anyplotlib.widgets._widgets2d import ( + RectangleWidget, CircleWidget, AnnularWidget, + CrosshairWidget, PolygonWidget, LabelWidget, +) +from anyplotlib.widgets._widgets1d import ( + VLineWidget, HLineWidget, RangeWidget, PointWidget, +) +from anyplotlib.widgets._widgets3d import PlaneWidget + +__all__ = [ + "Widget", + "RectangleWidget", "CircleWidget", "AnnularWidget", + "CrosshairWidget", "PolygonWidget", "LabelWidget", + "VLineWidget", "HLineWidget", "RangeWidget", "PointWidget", + "PlaneWidget", +] diff --git a/anyplotlib/widgets/_base.py b/anyplotlib/widgets/_base.py new file mode 100644 index 00000000..e73f2301 --- /dev/null +++ b/anyplotlib/widgets/_base.py @@ -0,0 +1,218 @@ +""" +widgets/_base.py +================ +Base Widget class shared by all interactive overlay widgets. +""" + +from __future__ import annotations +import uuid as _uuid +from typing import Any, Callable +from anyplotlib.callbacks import CallbackRegistry, Event, _EventMixin + + +class Widget(_EventMixin): + """Base class for all overlay widgets. + + Provides attribute-based state access, callbacks for interaction events, + and automatic synchronization with the JavaScript renderer. + + Parameters + ---------- + wtype : str + Widget type (e.g., 'rectangle', 'circle', 'crosshair'). + push_fn : Callable + Zero-arg callback to send position updates to the JavaScript renderer. + **kwargs : dict + Initial widget state (position, size, color, etc.). + + Attributes + ---------- + callbacks : CallbackRegistry + Event callback registry. Register handlers via + ``widget.add_event_handler(fn, "pointer_move")`` or as a decorator: + ``@widget.add_event_handler("pointer_move")``. + + Common event types: + + - ``"pointer_move"`` — fires on every drag frame + - ``"pointer_up"`` — fires once when drag settles + - ``"pointer_down"`` — fires on click/press event + """ + + def __init__(self, wtype: str, push_fn: Callable, **kwargs): + self._id: str = str(_uuid.uuid4())[:8] + self._type: str = wtype + self._data: dict = dict(kwargs) + self._data["id"] = self._id + self._data["type"] = wtype + self._push_fn: Callable = push_fn + self.callbacks: CallbackRegistry = CallbackRegistry() + + # ── attribute read ──────────────────────────────────────────────── + + def __getattr__(self, key: str): + """Access widget properties as attributes (read-only).""" + if key.startswith("_"): + raise AttributeError(key) + try: + return self._data[key] + except KeyError: + raise AttributeError( + f"{type(self).__name__} has no attribute {key!r}. " + f"Available: {list(self._data)}" + ) from None + + # ── attribute write — routes public assignments through set() ──── + + def __setattr__(self, key: str, value) -> None: + """Update widget properties via attribute assignment.""" + # Private attrs and 'callbacks' bypass set() + if key.startswith("_") or key == "callbacks": + super().__setattr__(key, value) + return + # During __init__ _data may not exist yet + try: + object.__getattribute__(self, "_data") + except AttributeError: + super().__setattr__(key, value) + return + self.set(**{key: value}) + + # ── set / get ───────────────────────────────────────────────────── + + def set(self, _push: bool = True, **kwargs) -> None: + """Update properties and send targeted update to JavaScript. + + Parameters + ---------- + _push : bool, optional + Whether to push update to renderer. Default True. + Set to False internally to avoid echo loops. + **kwargs : dict + Properties to update (e.g., x=100, y=50, radius=20). + + Notes + ----- + Updates are sent as targeted widget updates, not full panel re-renders. + This is more efficient for frequent updates during dragging. + """ + self._data.update(kwargs) + if _push: + self._push_fn() + self.callbacks.fire(Event("pointer_move", source=self)) + + def get(self, key: str, default=None): + """Get a widget property by name. + + Parameters + ---------- + key : str + Property name. + default : optional + Default value if property not found. + + Returns + ------- + object + The property value. + """ + return self._data.get(key, default) + + def to_dict(self) -> dict: + """Return a dict copy of the widget state. + + Returns + ------- + dict + All widget properties including id and type. + """ + return dict(self._data) + + # ── visibility ──────────────────────────────────────────────────────── + + @property + def visible(self) -> bool: + """``True`` if the widget is rendered; ``False`` if hidden.""" + return self._data.get("visible", True) + + @visible.setter + def visible(self, value: bool) -> None: + self.show() if value else self.hide() + + def show(self) -> None: + """Show the widget. Does not fire ``pointer_move`` callbacks.""" + self._data["visible"] = True + self._push_fn() + + def hide(self) -> None: + """Hide the widget without removing it or its callbacks. + + Call :meth:`show` to make it visible again. + Does not fire ``pointer_move`` callbacks. + """ + self._data["visible"] = False + self._push_fn() + + # ── JS → Python sync ────────────────────────────────────────────── + + def _update_from_js(self, msg: dict, event_type: str = "pointer_move") -> bool: + """Apply incoming JS state without pushing back (avoids echo). + + Updates widget ``_data`` with widget-specific state fields from msg, + then fires widget callbacks with a flat Event. + + Parameters + ---------- + msg : dict + Full raw event message from JS. + event_type : str + One of the pointer event types (``pointer_move``, ``pointer_up``, + ``pointer_down``). + + Returns + ------- + bool + True if any widget state changed. + """ + _envelope = { + "source", "panel_id", "event_type", "widget_id", + "time_stamp", "modifiers", "button", "buttons", + } + changed = False + for k, v in msg.items(): + if k in ("id", "type") or k in _envelope: + continue + if self._data.get(k) != v: + self._data[k] = v + changed = True + + if changed or event_type in ("pointer_up", "pointer_down"): + event = Event( + event_type=event_type, + source=self, + time_stamp=msg.get("time_stamp", 0.0), + modifiers=msg.get("modifiers", []), + x=msg.get("x"), + y=msg.get("y"), + button=msg.get("button"), + buttons=msg.get("buttons", 0), + xdata=msg.get("xdata"), + ydata=msg.get("ydata"), + ) + self.callbacks.fire(event) + return changed + + # ── repr ────────────────────────────────────────────────────────── + + def __repr__(self) -> str: + props = ", ".join( + f"{k}={v:.4g}" if isinstance(v, float) else f"{k}={v!r}" + for k, v in self._data.items() + if k not in ("id", "type", "color") + ) + return f"{type(self).__name__}({props})" + + @property + def id(self) -> str: + """Return the widget's unique identifier.""" + return self._id diff --git a/anyplotlib/widgets/_widgets1d.py b/anyplotlib/widgets/_widgets1d.py new file mode 100644 index 00000000..f25141db --- /dev/null +++ b/anyplotlib/widgets/_widgets1d.py @@ -0,0 +1,109 @@ +""" +widgets/_widgets1d.py +===================== +Interactive overlay widgets for 1-D line panels (Plot1D). +""" + +from __future__ import annotations +from anyplotlib.widgets._base import Widget + + +class VLineWidget(Widget): + """Draggable vertical line overlay widget for 1-D plots. + + Allows interactive selection of a single x-axis value. The line can be + dragged left/right to change the selected position. + + Parameters + ---------- + push_fn : Callable + Update callback. + x : float + Initial x-position in data coordinates. + color : str, optional + CSS colour for the line. Default ``"#00e5ff"``. + """ + def __init__(self, push_fn, *, x, color="#00e5ff"): + super().__init__("vline", push_fn, x=float(x), color=color) + + +class HLineWidget(Widget): + """Draggable horizontal line overlay widget for bar charts. + + Allows interactive selection of a single y-axis value. The line can be + dragged up/down to change the selected value. + + Parameters + ---------- + push_fn : Callable + Update callback. + y : float + Initial y-position in data coordinates. + color : str, optional + CSS colour for the line. Default ``"#00e5ff"``. + """ + def __init__(self, push_fn, *, y, color="#00e5ff"): + super().__init__("hline", push_fn, y=float(y), color=color) + + +class RangeWidget(Widget): + """Draggable range selection widget. + + Two display styles are available: + + ``style='band'`` (default) + Two connected vertical lines with a translucent fill band. Either + line can be dragged independently; the whole band can be dragged by + clicking inside it. + + ``style='fwhm'`` + Two circular handles joined by a dashed horizontal line drawn at + height *y* (the half-maximum level). Only the x-positions of the + handles are draggable. Use this to show/edit a FWHM interval on a + peak. + + Parameters + ---------- + push_fn : Callable + Update callback. + x0, x1 : float + Initial left and right positions in data coordinates. + color : str, optional + CSS colour. Default ``"#00e5ff"``. + style : {'band', 'fwhm'}, optional + Visual style. Default ``"band"``. + y : float, optional + Y-position (data coordinates) for the connecting line when + ``style='fwhm'``. Ignored for ``style='band'``. Default ``0.0``. + """ + def __init__(self, push_fn, *, x0, x1, color="#00e5ff", + style: str = "band", y: float = 0.0): + super().__init__("range", push_fn, + x0=float(x0), x1=float(x1), color=color, + style=str(style), y=float(y)) + + +class PointWidget(Widget): + """Draggable point (control point) overlay widget for 1-D plots. + + A free-moving handle that can be dragged to any position within the + plot area. Reports its data-space ``x`` and ``y`` coordinates back + to Python via the standard callback hooks. + + Parameters + ---------- + push_fn : Callable + Update callback. + x : float + Initial x position in data coordinates. + y : float + Initial y position in data coordinates (value axis). + color : str, optional + CSS colour for the handle. Default ``"#00e5ff"``. + show_crosshair : bool, optional + If ``True`` (default), draw dashed crosshair guide lines through the + handle. Set to ``False`` for a bare draggable dot with no guides. + """ + def __init__(self, push_fn, *, x, y, color="#00e5ff", show_crosshair=True): + super().__init__("point", push_fn, x=float(x), y=float(y), color=color, + show_crosshair=bool(show_crosshair)) diff --git a/anyplotlib/widgets/_widgets2d.py b/anyplotlib/widgets/_widgets2d.py new file mode 100644 index 00000000..5a2456dc --- /dev/null +++ b/anyplotlib/widgets/_widgets2d.py @@ -0,0 +1,141 @@ +""" +widgets/_widgets2d.py +===================== +Interactive overlay widgets for 2-D image panels (Plot2D / InsetAxes). +""" + +from __future__ import annotations +from anyplotlib.widgets._base import Widget + + +class RectangleWidget(Widget): + """Draggable rectangle overlay widget for 2-D plots. + + Parameters + ---------- + push_fn : Callable + Update callback. + x, y : float + Top-left corner position in pixel/data coordinates. + w, h : float + Width and height in pixel/data coordinates. + color : str, optional + CSS colour for the rectangle outline. Default ``"#00e5ff"``. + """ + def __init__(self, push_fn, *, x, y, w, h, color="#00e5ff"): + super().__init__("rectangle", push_fn, + x=float(x), y=float(y), + w=float(w), h=float(h), color=color) + + +class CircleWidget(Widget): + """Draggable circle overlay widget for 2-D plots. + + Parameters + ---------- + push_fn : Callable + Update callback. + cx, cy : float + Center position in pixel/data coordinates. + r : float + Radius in pixel/data coordinates. + color : str, optional + CSS colour for the circle outline. Default ``"#00e5ff"``. + """ + def __init__(self, push_fn, *, cx, cy, r, color="#00e5ff"): + super().__init__("circle", push_fn, + cx=float(cx), cy=float(cy), r=float(r), color=color) + + +class AnnularWidget(Widget): + """Draggable annular (ring) overlay widget for 2-D plots. + + Parameters + ---------- + push_fn : Callable + Update callback. + cx, cy : float + Center position in pixel/data coordinates. + r_outer, r_inner : float + Outer and inner radii in pixel/data coordinates. + Inner radius must be less than outer radius. + color : str, optional + CSS colour for the ring outline. Default ``"#00e5ff"``. + + Raises + ------ + ValueError + If r_inner >= r_outer. + """ + def __init__(self, push_fn, *, cx, cy, r_outer, r_inner, color="#00e5ff"): + if r_inner >= r_outer: + raise ValueError("r_inner must be < r_outer") + super().__init__("annular", push_fn, + cx=float(cx), cy=float(cy), + r_outer=float(r_outer), r_inner=float(r_inner), + color=color) + + +class CrosshairWidget(Widget): + """Draggable crosshair overlay widget for 2-D plots. + + Parameters + ---------- + push_fn : Callable + Update callback. + cx, cy : float + Center position in pixel/data coordinates. + color : str, optional + CSS colour for the crosshair. Default ``"#00e5ff"``. + """ + def __init__(self, push_fn, *, cx, cy, color="#00e5ff"): + super().__init__("crosshair", push_fn, + cx=float(cx), cy=float(cy), color=color) + + +class PolygonWidget(Widget): + """Draggable polygon overlay widget for 2-D plots. + + Parameters + ---------- + push_fn : Callable + Update callback. + vertices : list of tuple + Polygon vertices ``[(x0, y0), (x1, y1), ...]`` in pixel/data coordinates. + Must have at least 3 vertices. + color : str, optional + CSS colour for the polygon outline. Default ``"#00e5ff"``. + + Raises + ------ + ValueError + If fewer than 3 vertices provided. + """ + def __init__(self, push_fn, *, vertices, color="#00e5ff"): + verts = [[float(x), float(y)] for x, y in vertices] + if len(verts) < 3: + raise ValueError("polygon needs >= 3 vertices") + super().__init__("polygon", push_fn, vertices=verts, color=color) + + +class LabelWidget(Widget): + """Text label overlay widget for 2-D plots. + + Parameters + ---------- + push_fn : Callable + Update callback. + x, y : float + Label position in pixel/data coordinates. + text : str, optional + Label text. Default ``"Label"``. + fontsize : int, optional + Font size in points. Default 14. + color : str, optional + CSS colour for the text. Default ``"#00e5ff"``. + """ + def __init__(self, push_fn, *, x, y, text="Label", fontsize=14, + color="#00e5ff"): + super().__init__("label", push_fn, + x=float(x), y=float(y), + text=str(text), fontsize=int(fontsize), color=color) diff --git a/anyplotlib/widgets/_widgets3d.py b/anyplotlib/widgets/_widgets3d.py new file mode 100644 index 00000000..c3660534 --- /dev/null +++ b/anyplotlib/widgets/_widgets3d.py @@ -0,0 +1,50 @@ +""" +widgets/_widgets3d.py +===================== +Interactive overlay widgets for 3-D panels. +""" + +from __future__ import annotations + +from typing import Callable + +from anyplotlib.widgets._base import Widget + + +class PlaneWidget(Widget): + """A draggable axis-aligned plane in a 3-D panel. + + Rendered as a translucent quad spanning the panel's bounds, + perpendicular to *axis* at *position*. Drag it in the browser to slide + it along its normal — ideal as a slice selector for voxel volumes. + Voxels lying on a plane are rendered more opaque (see + :meth:`~anyplotlib.Axes.voxels`). + + Parameters + ---------- + axis : ``"x"`` | ``"y"`` | ``"z"`` + The plane's normal axis. + position : float + Position along *axis* in data coordinates. + color : str, optional + CSS colour of the plane fill and border. + alpha : float, optional + Fill opacity (0–1). Default 0.12. + + Examples + -------- + >>> pw = vol.add_widget("plane", axis="z", position=24) + >>> @pw.add_event_handler("pointer_move") + ... def on_drag(event): + ... print("slice now at", pw.position) + >>> pw.set(position=10) # move it from Python + """ + + def __init__(self, push_fn: Callable, axis: str = "z", + position: float = 0.0, color: str = "#00e5ff", + alpha: float = 0.12): + if axis not in ("x", "y", "z"): + raise ValueError(f"axis must be 'x', 'y', or 'z', got {axis!r}") + super().__init__("plane", push_fn, + axis=axis, position=float(position), + color=color, alpha=float(alpha)) diff --git a/docs/_root/index.html b/docs/_root/index.html new file mode 100644 index 00000000..c6277376 --- /dev/null +++ b/docs/_root/index.html @@ -0,0 +1,16 @@ + + + + + anyplotlib – redirecting… + + + + + +

+ Redirecting to dev documentation… +

+ + + diff --git a/docs/_root/switcher.json b/docs/_root/switcher.json new file mode 100644 index 00000000..9fb40e0d --- /dev/null +++ b/docs/_root/switcher.json @@ -0,0 +1,8 @@ +[ + { + "name": "dev (latest)", + "version": "dev", + "url": "https://cssfrancis.github.io/anyplotlib/dev/" + }, +] + diff --git a/docs/_sg_html_scraper.py b/docs/_sg_html_scraper.py index 8edf07ce..fe9c43ec 100644 --- a/docs/_sg_html_scraper.py +++ b/docs/_sg_html_scraper.py @@ -1,193 +1,21 @@ """ -Custom Sphinx Gallery scraper for anyplotlib Widgets. +_sg_html_scraper.py — compatibility shim +========================================= -Sphinx Gallery requires every scraper to write a PNG file to the path provided -by ``image_path_iterator`` — otherwise it raises ``ExtensionError``. +The canonical implementation has moved to +``anyplotlib.sphinx_anywidget._scraper``. -This scraper: -1. Finds a anyplotlib widget in ``example_globals`` (any object from the ``anyplotlib`` - package that has ``_repr_html_``). -2. Renders a **static thumbnail PNG** via matplotlib for the gallery index. -3. Writes the **full interactive HTML** (iframe + widget JS) alongside the PNG. -4. Returns rST that embeds both: the PNG as a fallback image AND an iframe for - interactive use, using a ``.. raw:: html`` block. -""" - -from __future__ import annotations - -import io -from pathlib import Path - - -# --------------------------------------------------------------------------- -# Helpers -# --------------------------------------------------------------------------- - -def _find_viewer(globals_dict: dict): - """Return the most-recently assigned anyplotlib widget, or None.""" - for val in reversed(list(globals_dict.values())): - module = getattr(type(val), "__module__", "") or "" - if module.startswith("anyplotlib") and callable(getattr(val, "_repr_html_", None)): - return val - return None - - -def _make_thumbnail_png(widget) -> bytes: - """Render a small static thumbnail PNG for the gallery index card.""" - import matplotlib - matplotlib.use("Agg") - import matplotlib.pyplot as plt - import numpy as np - - fig, ax = plt.subplots(figsize=(4, 3), dpi=72) - ax.set_facecolor("#1e1e2e") - fig.patch.set_facecolor("#1e1e2e") - ax.tick_params(colors="#cdd6f4") - for spine in ax.spines.values(): - spine.set_edgecolor("#44475a") - - kind = type(widget).__name__ - - try: - if kind == "Viewer2D": - import json - raw = widget._raw_u8 - cmap = widget.colormap_name or "gray" - ax.imshow(raw, cmap=cmap, aspect="auto", interpolation="nearest") - ax.set_title("Viewer2D", color="#cdd6f4", fontsize=9) - ax.set_xticks([]); ax.set_yticks([]) - - elif kind == "Viewer1D": - import json - data = np.array(json.loads(widget.data_json)) - x_axis = np.array(json.loads(widget.x_axis_json)) - ax.plot(x_axis, data, color="#4fc3f7", linewidth=1) - ax.set_title("Viewer1D", color="#cdd6f4", fontsize=9) - ax.set_facecolor("#181825") +This file is kept so any ``conf.py`` that still does:: - elif kind == "Figure": - from anyplotlib.figure_plots import Plot2D, Plot1D - import json - plots = list(widget._plots_map.values()) - ax.set_title(f"Figure ({widget._nrows}×{widget._ncols})", - color="#cdd6f4", fontsize=9) - if plots: - p = plots[0] - if isinstance(p, Plot2D): - ax.imshow(p._raw_u8, cmap=p._state.get("colormap_name", "gray"), - aspect="auto", interpolation="nearest") - elif isinstance(p, Plot1D): - d = np.asarray(p._state.get("data", [])) - x = np.asarray(p._state.get("x_axis", np.arange(len(d)))) - ax.plot(x, d, color=p._state.get("line_color", "#4fc3f7"), linewidth=1) - ax.set_xticks([]); ax.set_yticks([]) - else: - ax.text(0.5, 0.5, kind, ha="center", va="center", - color="#cdd6f4", transform=ax.transAxes) - ax.axis("off") + from _sg_html_scraper import ViewerScraper - except Exception: - ax.text(0.5, 0.5, kind, ha="center", va="center", - color="#cdd6f4", transform=ax.transAxes) - ax.axis("off") - - plt.tight_layout(pad=0.3) - buf = io.BytesIO() - fig.savefig(buf, format="png", dpi=72, facecolor=fig.get_facecolor()) - plt.close(fig) - buf.seek(0) - return buf.read() - - -# --------------------------------------------------------------------------- -# Scraper -# --------------------------------------------------------------------------- - -class ViewerScraper: - """Sphinx Gallery image scraper that embeds anyplotlib Widgets as live iframes.""" - - def __repr__(self) -> str: - return "ViewerScraper()" - - def __call__(self, block, block_vars, gallery_conf): - globals_dict = block_vars.get("example_globals", {}) - widget = _find_viewer(globals_dict) - if widget is None: - return "" - - # ── 1. Write the thumbnail PNG (Sphinx Gallery requires this) ────── - image_path_iterator = block_vars["image_path_iterator"] - png_path = Path(next(image_path_iterator)) - png_path.parent.mkdir(parents=True, exist_ok=True) - png_path.write_bytes(_make_thumbnail_png(widget)) - - # ── 2. Write the standalone HTML into docs/_static/viewer_widgets/ ─ - # - # WHY NOT srcdoc=: - # The srcdoc= attribute value is thousands of lines. Docutils parses - # the content of a ``.. raw:: html`` block as indented text, so a - # multi-line attribute value confuses the RST parser and the block is - # silently dropped from the output. - # - # WHY NOT src= into auto_examples/images/: - # Sphinx only copies *.png files from that directory to _build/html/. - # Any .html file referenced via src= would be a 404 in the built docs. - # - # SOLUTION: - # Write to docs/_static/viewer_widgets/ which is in html_static_path - # and is copied verbatim by Sphinx. The src= path is a single line, - # which is safe for docutils. - try: - from anyplotlib._repr_utils import build_standalone_html, _widget_px - docs_dir = Path(gallery_conf["src_dir"]) - widgets_dir = docs_dir / "_static" / "viewer_widgets" - widgets_dir.mkdir(parents=True, exist_ok=True) - - html_name = png_path.stem + ".html" # sphx_glr_plot_..._001.html - html_path = widgets_dir / html_name - - inner_html = build_standalone_html(widget, resizable=False) - html_path.write_text(inner_html, encoding="utf-8") - w, h = _widget_px(widget) - interactive = True - except Exception: - interactive = False +continues to work without changes. All public helpers that existed in the +original module are re-exported here so downstream imports keep working. +""" - # ── 3. Return rST ────────────────────────────────────────────────── - if interactive: - # Compute the relative path from the *built* HTML page back up to - # _static/viewer_widgets/. - # - # The PNG (and its sibling HTML) sits at e.g.: - # /auto_examples/Markers/images/sphx_glr_plot_circles_001.png - # The built page for this example is at: - # /auto_examples/Markers/plot_circles.html - # _static/viewer_widgets/ lives at: - # /_static/viewer_widgets/ - # - # We derive depth by counting the parts of the gallery output path - # relative to the Sphinx source dir (which mirrors the build root). - try: - src_dir = Path(gallery_conf["src_dir"]) - # png_path is inside the gallery output images/ subdir. - # The page itself is one directory above images/. - page_dir = png_path.parent.parent # strip /images - rel_parts = page_dir.relative_to(src_dir).parts - depth = len(rel_parts) # e.g. 2 for auto_examples/Markers - except Exception: - depth = 1 - prefix = "../" * depth - src = f"{prefix}_static/viewer_widgets/{html_name}" - return ( - "\n\n.. raw:: html\n\n" - f'
' - f'
\n\n' - ) - else: - rel_png = png_path.name - return ( - f"\n\n.. image:: {rel_png}\n" - f" :width: 100%\n\n" - ) +from anyplotlib.sphinx_anywidget._scraper import ( # noqa: F401 + AnywidgetScraper, + AnywidgetScraper as ViewerScraper, + _make_thumbnail_png, + _iframe_html, +) diff --git a/docs/_static/anyplotlib.svg b/docs/_static/anyplotlib.svg new file mode 100644 index 00000000..8101881a --- /dev/null +++ b/docs/_static/anyplotlib.svg @@ -0,0 +1,48 @@ + + + + + + + + + + diff --git a/docs/_static/custom.css b/docs/_static/custom.css new file mode 100644 index 00000000..0cbae114 --- /dev/null +++ b/docs/_static/custom.css @@ -0,0 +1,23 @@ +/* + * custom.css — anyplotlib docs overrides + * + * Hide empty syntax-highlight blocks that sphinx-gallery emits when an + * entire code cell is wrapped in # sphinx_gallery_start/end_ignore. + * The cell still executes (figures are scraped) but produces a visible + * blank
 element; this rule makes those invisible.
+ */
+.highlight pre:not(:has(:not(:empty))) {
+    display: none;
+}
+.highlight-Python:has(pre > span:only-child:empty) {
+    display: none;
+}
+
+/*
+ * Fallback for browsers that don't support :has() — target the pattern
+ * sphinx produces for an empty highlighted block:
+ *   
\n
+ * We can't do this in pure CSS without :has, so we rely on the rule above + * for modern browsers and accept a small blank gap on older ones. + */ + diff --git a/docs/_static/pyodide_bridge.js b/docs/_static/pyodide_bridge.js new file mode 100644 index 00000000..d0ea5b6c --- /dev/null +++ b/docs/_static/pyodide_bridge.js @@ -0,0 +1,400 @@ +/** + * pyodide_bridge.js + * + * Adds a single floating "⚡" button to any docs page that contains + * anyplotlib figure iframes. Clicking it boots ONE shared Pyodide instance + * for the entire page, runs each example's Python source exactly once, then + * wires Python ↔ JS via postMessage so on_change / on_release callbacks fire + * live in the browser — no server, no Jupyter kernel. + * + * Architecture + * ──────────── + * Parent page (this script) + * ├─ Pyodide WASM runtime (loaded once from CDN on button click) + * ├─ anyplotlib wheel built at docs-build time → _static/wheels/ + * ├─