Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/parallel/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
//! - [`ArrayViewMut`] `.into_par_iter()`
//! - [`AxisIter`], [`AxisIterMut`] `.into_par_iter()`
//! - [`AxisChunksIter`], [`AxisChunksIterMut`] `.into_par_iter()`
//! - [`ExactChunks`], [`ExactChunksMut`] `.into_par_iter()`
//! - [`Zip`] `.into_par_iter()`
//!
//! The following other parallelized methods exist:
Expand Down Expand Up @@ -94,6 +95,23 @@
//! assert_eq!(shapes, [vec![3, 16], vec![1, 16]]);
//! ```
//!
//! ## Exact chunks
//!
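//! Use `.exact_chunks()` followed by `.into_par_iter()` to process the complete chunks of an
//! array in parallel; trailing elements that do not fill a whole chunk are not visited.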
//!
//! ```
//! use ndarray::Array;
//! use ndarray::parallel::prelude::*;
//!
//! let a = Array::linspace(0.0..=63.0, 64).into_shape_with_order((8, 8)).unwrap();
//! let sum: f64 = a.exact_chunks((2, 4))
//! .into_par_iter()
//! .map(|chunk| chunk.sum())
//! .sum();
//!
//! assert_eq!(sum, a.sum());
//! ```
//!
//! ## Zip
//!
//! Use zip for lock step function application across several arrays
Expand All @@ -118,7 +136,9 @@
//! ```

#[allow(unused_imports)] // used by rustdoc links
use crate::iter::{AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut};
use crate::iter::{
AxisChunksIter, AxisChunksIterMut, AxisIter, AxisIterMut, ExactChunks, ExactChunksMut,
};
#[allow(unused_imports)] // used by rustdoc links
use crate::{ArcArray, Array, ArrayBase, ArrayView, ArrayViewMut, Zip};

Expand Down
115 changes: 114 additions & 1 deletion src/parallel/par.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@ use crate::iter::AxisChunksIter;
use crate::iter::AxisChunksIterMut;
use crate::iter::AxisIter;
use crate::iter::AxisIterMut;
use crate::iter::ExactChunks;
use crate::iter::ExactChunksMut;
use crate::split_at::SplitPreference;
use crate::Dimension;
use crate::{ArrayView, ArrayViewMut};
use crate::{ArrayView, ArrayViewMut, Axis};

/// Parallel iterator wrapper.
#[derive(Copy, Clone, Debug)]
Expand Down Expand Up @@ -225,6 +227,117 @@ par_iter_view_wrapper!(ArrayViewMut, [Sync + Send]);

use crate::{FoldWhile, NdProducer, Zip};

/// Implements unindexed rayon parallel iteration for an n-dimensional chunk
/// producer (`$producer_name<'a, A, D>`).
///
/// For the given producer type this generates:
/// - `IntoParallelIterator`, wrapping the producer in [`Parallel`] with `DEFAULT_MIN_LEN`;
/// - `ParallelIterator` for the wrapper, driven through `bridge_unindexed`;
/// - an inherent `with_min_len` builder on the wrapper;
/// - `UnindexedProducer` for `ParallelProducer`, which halves the producer along its
///   longest axis until it holds no more than `min_len` items;
/// - `IntoIterator` for `ParallelProducer`, forwarding to the producer's own iterator.
macro_rules! par_ndproducer_wrapper {
    // thread_bounds are either Sync or Send + Sync
    ($producer_name:ident, [$($thread_bounds:tt)*]) => {
    /// Requires crate feature `rayon`.
    impl<'a, A, D> IntoParallelIterator for $producer_name<'a, A, D>
        where D: Dimension,
              A: $($thread_bounds)*,
    {
        type Item = <Self as NdProducer>::Item;
        type Iter = Parallel<Self>;
        fn into_par_iter(self) -> Self::Iter {
            Parallel {
                iter: self,
                min_len: DEFAULT_MIN_LEN,
            }
        }
    }

    impl<'a, A, D> ParallelIterator for Parallel<$producer_name<'a, A, D>>
        where D: Dimension,
              A: $($thread_bounds)*,
    {
        type Item = <$producer_name<'a, A, D> as NdProducer>::Item;
        fn drive_unindexed<C>(self, consumer: C) -> C::Result
            where C: UnindexedConsumer<Self::Item>
        {
            bridge_unindexed(ParallelProducer(self.iter, self.min_len), consumer)
        }

        fn opt_len(&self) -> Option<usize> {
            // This iterator is driven with `bridge_unindexed`, which splits the
            // consumer through the `UnindexedConsumer` methods. Rayon requires that
            // an iterator returning `Some(_)` here only uses the indexed `Consumer`
            // methods (consumers such as `collect` into a `Vec` rely on that and
            // may panic otherwise), so the length must not be advertised.
            None
        }
    }

    impl<'a, A, D> Parallel<$producer_name<'a, A, D>>
        where D: Dimension,
              A: $($thread_bounds)*,
    {
        /// Sets the minimum number of chunks desired to process in each job. This will not be
        /// split any smaller than this length, but of course a producer could already be smaller
        /// to begin with.
        ///
        /// ***Panics*** if `min_len` is zero.
        pub fn with_min_len(self, min_len: usize) -> Self {
            assert_ne!(min_len, 0, "Minimum number of elements must at least be one to avoid splitting off empty tasks.");

            Self {
                min_len,
                ..self
            }
        }
    }

    impl<'a, A, D> UnindexedProducer for ParallelProducer<$producer_name<'a, A, D>>
        where D: Dimension,
              A: $($thread_bounds)*,
    {
        type Item = <$producer_name<'a, A, D> as NdProducer>::Item;

        /// Splits the producer in two halves along its longest axis.
        ///
        /// Returns `(self, None)` when the producer holds at most `min_len`
        /// items, or when no axis is longer than one.
        fn split(self) -> (Self, Option<Self>) {
            let dim = self.0.raw_dim();
            // Small enough already: keep it as one job.
            if dim.size() <= self.1 {
                return (self, None)
            }

            // Pick the longest axis (the last one on ties, per `max_by_key`);
            // a zero-dimensional producer has no axis to split.
            let Some((axis, &len)) = dim
                .slice()
                .iter()
                .enumerate()
                .max_by_key(|&(_, len)| len)
            else {
                return (self, None)
            };
            // Splitting an axis of length 0 or 1 would create an empty half.
            if len <= 1 {
                return (self, None)
            }

            let (a, b) = self.0.split_at(Axis(axis), len / 2);
            (ParallelProducer(a, self.1), Some(ParallelProducer(b, self.1)))
        }

        /// Feeds every item of the producer to `folder`, stopping early once
        /// the folder reports that it is full.
        fn fold_with<F>(self, folder: F) -> F
            where F: Folder<Self::Item>,
        {
            Zip::from(self.0).fold_while(folder, |mut folder, elt| {
                folder = folder.consume(elt);
                if folder.full() {
                    FoldWhile::Done(folder)
                } else {
                    FoldWhile::Continue(folder)
                }
            }).into_inner()
        }
    }

    impl<'a, A, D> IntoIterator for ParallelProducer<$producer_name<'a, A, D>>
        where D: Dimension,
              A: $($thread_bounds)*,
    {
        type Item = <$producer_name<'a, A, D> as IntoIterator>::Item;
        type IntoIter = <$producer_name<'a, A, D> as IntoIterator>::IntoIter;
        fn into_iter(self) -> Self::IntoIter {
            self.0.into_iter()
        }
    }

    };
}

par_ndproducer_wrapper!(ExactChunks, [Sync]);
par_ndproducer_wrapper!(ExactChunksMut, [Send + Sync]);

macro_rules! zip_impl {
($([$($p:ident)*],)+) => {
$(
Expand Down
34 changes: 34 additions & 0 deletions tests/par_rayon.rs
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,40 @@ fn test_axis_chunks_iter()
assert_eq!(s, a.sum());
}

/// Summing every (2, 5) chunk of a 10x10 array in parallel must give the same
/// total as summing the whole array, since the chunks tile it exactly.
#[test]
fn test_exact_chunks()
{
    let arr = Array::from_iter(0..100)
        .into_shape_with_order((10, 10))
        .unwrap();

    let chunks = arr.exact_chunks((2, 5));
    let total: i32 = chunks.into_par_iter().map(|block| block.sum()).sum();

    assert_eq!(total, arr.sum());
}

/// Filling every complete (2, 3) chunk of a 7x8 array in parallel must leave
/// the last row and the last two columns (the incomplete remainder) at zero.
#[test]
fn test_exact_chunks_mut()
{
    let mut grid = Array2::<usize>::zeros((7, 8));

    let blocks = grid.exact_chunks_mut((2, 3));
    blocks.into_par_iter().for_each(|mut block| block.fill(1));

    let expected = array![
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [1, 1, 1, 1, 1, 1, 0, 0],
        [0, 0, 0, 0, 0, 0, 0, 0],
    ];
    assert_eq!(grid, expected);
}

#[test]
#[cfg(feature = "approx")]
fn test_axis_chunks_iter_mut()
Expand Down