Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for multiple non-matching sets #28

Merged
merged 31 commits into from
Jan 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
d04bdf9
WIP: js renames
clbarnes Mar 31, 2023
67644c4
Remove commented out dep
clbarnes Mar 31, 2023
174fc0b
Utility struct for sampling sets of pairs
clbarnes Mar 31, 2023
5171a74
lin, log, and quant binlookup creation
clbarnes Apr 5, 2023
c731b5c
Refactor smat creation logic into separate structs
clbarnes Apr 5, 2023
438be5b
BinLookup::new_n_quantiles
clbarnes Apr 5, 2023
135e61a
Major smat refactor
clbarnes Apr 5, 2023
0ac5abe
Make better use of thiserror
clbarnes Apr 5, 2023
b5d851e
Various tidying up, docs etc
clbarnes Apr 5, 2023
37ae8df
unused import in benchmarks
clbarnes Apr 5, 2023
d404175
Update python wrapper
clbarnes Apr 5, 2023
fd6a1cb
remove conditional compilation in py
clbarnes Apr 5, 2023
673a9a4
fmt
clbarnes Apr 5, 2023
c4b9dd4
Update py and js wrappers
clbarnes Apr 5, 2023
a201da8
Remove vestigial use_alpha references
clbarnes Apr 5, 2023
0109fa2
score matrix builder for python
clbarnes Apr 5, 2023
edc15b5
Format and lint
clbarnes Apr 5, 2023
19cd71d
python test refactors
clbarnes Apr 5, 2023
523fcd7
Update fib250 benchmark
clbarnes Apr 5, 2023
b411e61
remove unused import
clbarnes Apr 12, 2023
e72de76
update maturin config
clbarnes Apr 12, 2023
8c4d7d3
bump pyo3
clbarnes Apr 6, 2023
2540db6
query_target_pairs method, some refactors
clbarnes Apr 6, 2023
bbf0e9e
use numpy interface for nblastarena
clbarnes Apr 6, 2023
63d1005
refactor binlookup creation
clbarnes Apr 14, 2023
0bdb41c
Fix python interface and tests
clbarnes Jan 18, 2024
b25163f
fix more tests
clbarnes Jan 18, 2024
560c2d0
clippy fix
clbarnes Jan 18, 2024
7bb2855
more clippy fixes
clbarnes Jan 18, 2024
61aa086
fmt
clbarnes Jan 18, 2024
320850a
fix dep installation on CI
clbarnes Jan 18, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 7 additions & 5 deletions .github/workflows/nblast-py.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,22 +7,24 @@ defaults:
working-directory: nblast-py

jobs:

lint:
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
- uses: actions/setup-python@v3
with:
python-version: '3.x'
- run: pip install $(grep -E '^(black|flake8|mypy|isort)' requirements.txt)
python-version: "3.x"
- run: pip install $(grep -E '^(ruff|mypy)' requirements.txt)
- run: make lint

test:
strategy:
fail-fast: false
matrix:
python-version: ['3.9', '3.10', '3.11']
python-version:
- "3.9"
- "3.10"
- "3.11"
runs-on: ubuntu-22.04
steps:
- uses: actions/checkout@v3
Expand Down Expand Up @@ -71,7 +73,7 @@ jobs:
toolchain: stable
- uses: actions/setup-python@v3
with:
python-version: '3.10'
python-version: "3.10"
- uses: PyO3/maturin-action@v1
with:
manylinux: auto
Expand Down
15 changes: 8 additions & 7 deletions examples/bench_fib250.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from tqdm import tqdm
import pandas as pd
import numpy as np
import pooch

from pynblast import NblastArena, ScoreMatrix

Expand All @@ -19,7 +20,7 @@


def get_threads():
val = os.environ.get("NBLAST_THREADS", 0) or 0
val = os.environ.get("NBLAST_THREADS", "0")
try:
return int(val)
except ValueError:
Expand All @@ -32,11 +33,13 @@ def get_threads():
logger.warning("Running with THREADS = %s", THREADS)

URL_PREFIX = "https://github.com/clbarnes/nblast-rs/files"
SCORES_NAME = "fib250.aba.csv.zip"

SCORES_URL = f"{URL_PREFIX}/4567582/{SCORES_NAME}"

SCORES_FPATH = here / SCORES_NAME
def get_scores_path() -> Path:
SCORES_NAME = "fib250.aba.csv.zip"

SCORES_URL = f"{URL_PREFIX}/4567582/{SCORES_NAME}"
return Path(pooch.retrieve(SCORES_URL, None))


def df_to_pt_tan_a(df):
Expand All @@ -49,10 +52,8 @@ def df_to_pt_tan_a(df):
def ingest_dotprops():
DOTPROPS_NAME = "fib250.csv.zip"
DOTPROPS_URL = f"{URL_PREFIX}/4567531/{DOTPROPS_NAME}"
DOTPROPS_FPATH = here / DOTPROPS_NAME

if not DOTPROPS_FPATH.is_file():
raise ValueError(f"Download necessary data from\n\t{DOTPROPS_URL}")
DOTPROPS_FPATH = Path(pooch.retrieve(url=DOTPROPS_URL, known_hash=None))

with zf.ZipFile(DOTPROPS_FPATH) as z:
with z.open(DOTPROPS_FPATH.name[:-4]) as f:
Expand Down
2 changes: 1 addition & 1 deletion examples/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ jupyter
pynblast
numpy
tqdm

pooch
60 changes: 59 additions & 1 deletion nblast-js/examples/nblast-app/index.js
Original file line number Diff line number Diff line change
@@ -1,7 +1,65 @@
import init, { NblastArena } from "./node_modules/nblast-js/nblast_js.js";
import init, { NblastArena, makeFlatTangentsAlphas } from "./node_modules/nblast-js/nblast_js.js";

const CACHE = {};

/**
* Calculate tangents and alpha values for given points
* (as an array of 3-length arrays of numbers),
* and returned in flat typed array form as can be passed into the NblastArena.
*
* Tangents and alphas can be given, in which case they will be returned as
* flat typed arrays,
* but it's better to do this elsewhere to save copies between the main and worker threads.
*
* @param {number[][]} points
* @param {number[][]} [tangents]
* @param {number[]} [alphas]
* @returns {Object.<string, Float64Array>} { points, tangents, alphas }, in flattened form which can be passed straight into wasm.
*/
function makeFlatPointsTangentsAlphas(points, tangents, alphas) {
const pointsFlat = flatArray64(points);
let tangentsFlat;
let alphasFlat;

if (tangents != null) {
tangentsFlat = flatArray64(tangents);
alphasFlat = flatArray64(alphas, points.length, 1);
} else {
const tangentsAlphas = makeFlatTangentsAlphas(pointsFlat);
tangentsFlat = tangentsAlphas.slice(0, pointsFlat.length);
alphasFlat = tangentsAlphas.slice(pointsFlat.length);
}

return {
points: pointsFlat,
tangents: tangentsFlat,
alphas: alphasFlat
}
}

/**
* Return a Float64Array, which may contain the contents of (flattened) arr,
* or an array with a particular length and fill value.
*
* @param {(number[]|number[][]|Float64Array)} [arr]
* @param {number} [lengthIfNull]
* @param {number} [fillIfNull]
* @returns {Float64Array}
*/
function flatArray64(arr, lengthIfNull, fillIfNull) {
if (arr == null) {
return new Float64Array(lengthIfNull).fill(fillIfNull);
}
if (arr instanceof Float64Array) {
return arr;
}
if (Array.isArray(arr[0])) {
return new Float64Array(arr.flat());
} else {
return new Float64Array(arr);
}
}

/**
* Class containing NBLASTable neurons and a score matrix.
*
Expand Down
33 changes: 11 additions & 22 deletions nblast-js/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@ fn convert_multi_output(
) -> HashMap<NeuronIdx, HashMap<NeuronIdx, Precision>> {
let mut out: HashMap<NeuronIdx, HashMap<NeuronIdx, f64>> = HashMap::default();
for ((q, t), v) in result.drain() {
out.entry(q).or_insert_with(HashMap::default).insert(t, v);
out.entry(q).or_default().insert(t, v);
}
out
}
Expand All @@ -69,6 +69,7 @@ impl NblastArena {
dot_thresholds: &[f64],
cells: &[f64],
k: usize,
use_alpha: bool,
) -> JsResult<NblastArena> {
let rtable = RangeTable::new_from_bins(
vec![dist_thresholds.to_vec(), dot_thresholds.to_vec()],
Expand All @@ -77,19 +78,19 @@ impl NblastArena {
.map_err(to_js_err)?;
let score_calc = ScoreCalc::Table(rtable);
Ok(Self {
arena: nblast::NblastArena::new(score_calc),
arena: nblast::NblastArena::new(score_calc, use_alpha),
k,
})
}

#[wasm_bindgen(js_name="addPoints")]
#[wasm_bindgen(js_name = "addPoints")]
pub fn add_points(&mut self, flat_points: &[f64]) -> JsResult<usize> {
let points = flat_to_array3(flat_points);
let neuron = RStarTangentsAlphas::new(points, self.k).map_err(JsError::new)?;
Ok(self.arena.add_neuron(neuron))
}

#[wasm_bindgen(js_name="addPointsTangentsAlphas")]
#[wasm_bindgen(js_name = "addPointsTangentsAlphas")]
pub fn add_points_tangents_alphas(
&mut self,
flat_points: &[f64],
Expand All @@ -110,29 +111,27 @@ impl NblastArena {
Ok(self.arena.add_neuron(neuron))
}

#[wasm_bindgen(js_name="queryTarget")]
#[wasm_bindgen(js_name = "queryTarget")]
pub fn query_target(
&self,
query_idx: NeuronIdx,
target_idx: NeuronIdx,
normalize: bool,
symmetry: Option<JsString>,
use_alpha: bool,
) -> JsResult<Option<f64>> {
let sym = parse_symmetry(symmetry)?;
Ok(self
.arena
.query_target(query_idx, target_idx, normalize, &sym, use_alpha))
.query_target(query_idx, target_idx, normalize, &sym))
}

#[wasm_bindgen(js_name="queriesTargets")]
#[wasm_bindgen(js_name = "queriesTargets")]
pub fn queries_targets(
&self,
query_idxs: &[NeuronIdx],
target_idxs: &[NeuronIdx],
normalize: bool,
symmetry: Option<JsString>,
use_alpha: bool,
max_centroid_dist: Option<Precision>,
) -> JsResult<JsValue> {
let sym = parse_symmetry(symmetry)?;
Expand All @@ -141,29 +140,20 @@ impl NblastArena {
target_idxs,
normalize,
&sym,
use_alpha,
None,
max_centroid_dist,
));
Ok(serde_wasm_bindgen::to_value(&out)?)
}

#[wasm_bindgen(js_name="allVAll")]
#[wasm_bindgen(js_name = "allVAll")]
pub fn all_v_all(
&self,
normalize: bool,
symmetry: Option<JsString>,
use_alpha: bool,
max_centroid_dist: Option<Precision>,
) -> JsResult<JsValue> {
let sym = parse_symmetry(symmetry)?;
let out = convert_multi_output(self.arena.all_v_all(
normalize,
&sym,
use_alpha,
None,
max_centroid_dist,
));
let out = convert_multi_output(self.arena.all_v_all(normalize, &sym, max_centroid_dist));
Ok(serde_wasm_bindgen::to_value(&out)?)
}
}
Expand All @@ -181,8 +171,7 @@ pub fn make_flat_tangents_alphas(flat_points: &[f64], k: usize) -> JsResult<Floa
for (idx, val) in neuron
.tangents()
.into_iter()
.map(|n| [n[0], n[1], n[2]])
.flatten()
.flat_map(|n| [n[0], n[1], n[2]])
.chain(neuron.alphas().into_iter())
.enumerate()
{
Expand Down
3 changes: 2 additions & 1 deletion nblast-py/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,10 @@ readme = "README.rst"
edition = "2018"

[dependencies]
pyo3 = { version = "0.17.2", features = ["extension-module"] }
pyo3 = { version = "0.18.2", features = ["extension-module"] }
neurarbor = "0.2.0"
nblast = { path = "../nblast-rs", version = "^0.5.0", features = ["parallel"] }
numpy = "0.18"

[lib]
name = "pynblast"
Expand Down
6 changes: 3 additions & 3 deletions nblast-py/Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -52,12 +52,12 @@ clean-test: ## remove test and coverage artifacts
rm -fr .pytest_cache

fmt:
black $(PY_PATHS)
ruff format $(PY_PATHS)
cargo fmt

lint: ## check style with flake8
flake8 $(PY_PATHS)
black --check $(PY_PATHS)
ruff check $(PY_PATHS)
ruff format --check $(PY_PATHS)

test: ## run tests quickly with the default Python
maturin develop
Expand Down
27 changes: 24 additions & 3 deletions nblast-py/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,10 @@
[project]
name = "pynblast"
dependencies = [
"numpy"
description = "NBLAST neuron morphology comparison in python (over rust)"
readme = "README.rst"
requires-python = ">=3.9"
authors = [
{name = "Chris L. Barnes", email = "[email protected]"}
]
license = { text = "MIT" }

Expand All @@ -12,11 +15,29 @@ classifiers = [
"Natural Language :: English",
"Programming Language :: Python :: 3.9",
"Programming Language :: Python :: 3.10",
"Programming Language :: Python :: 3.11",
]
dependencies = [
"numpy>=1.21.0"
]

[project.urls]
homepage = "https://pypi.org/project/pynblast/"
documentation = "https://pynblast.readthedocs.io/"
repository = "https://github.com/clbarnes/nblast-rs/nblast-py"

[build-system]
requires = ["maturin>=0.13,<0.14"]
requires = ["maturin==0.14", "numpy>=1.21"]
build-backend = "maturin"

[tool.maturin]
python-source = "python"

[tool.mypy]
ignore_missing_imports = true
plugins = ["numpy.typing.mypy_plugin"]
python_version = "3.9"

[tool.ruff]
extend-exclude = ["docs"]
target-version = "py39"
Loading
Loading