如何以简洁的方式比较ndarrays的形状?

发布于 2025-02-03 00:22:49 字数 850 浏览 4 评论 0 原文

我是新来的生锈。

假设矩阵 a 具有Shape (N1,N2) B 具有(M1,M2)和<代码> C 具有(K1,K2)。我想检查 a b 可以乘以(作为矩阵),并且 a * b 的形状等于> C 。换句话说,(n2 == m1)&amp;&amp; (n1 == k1)&amp;&amp; (m2 == k2)

use ndarray::Array2;

// a : Array2<i64>
// b : Array2<i64>
// c : Array2<i64>

a>返回阵列的形状作为切片。 什么是简洁的方法?

是否从 .shape()保证具有长度2的返回数组,还是应该检查它?如果保证它,是否可以跳过检查?

let n1 = a.shape().get(0);  // this is Optional<i64>

I'm new to Rust.

Suppose a matrix a has shape (n1, n2), b has (m1, m2), and c has (k1, k2). I would like to check that a and b can be multiplied (as matrices) and the shape of a * b is equal to c. In other words, (n2 == m1) && (n1 == k1) && (m2 == k2).

use ndarray::Array2;

// a : Array2<i64>
// b : Array2<i64>
// c : Array2<i64>

.shape method returns the shape of the array as a slice.
What is the concise way to do it?

Is the returned array from .shape() guaranteed to have length 2, or should I check it? If it guaranteed, is there a way to skip the None checking?

let n1 = a.shape().get(0);  // this is Optional<i64>

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

何处潇湘 2025-02-10 00:22:49

对于array2,具体是。如果您仅使用2D阵列,那么这可能是最佳选择。他们返回使用情况,因此不需要无需检查。

use ndarray::prelude::*;

fn is_valid_matmul(a: &Array2<i64>, b: &Array2<i64>, c: &Array2<i64>) -> bool {
    //nrows() and ncols() are only valid for Array2, 
    //[arr.nrows(), arr.ncols()] = [arr.shape()[0], arr.shape()[1]]
    return a.ncols() == b.nrows() && b.ncols() == c.ncols() && a.nrows() == c.nrows();
}
fn main() {
    let a = Array2::<i64>::zeros((3, 5));
    let b = Array2::<i64>::zeros((5, 6));
    let c_valid = Array2::<i64>::zeros((3, 6));
    let c_invalid = Array2::<i64>::zeros((8, 6));

    println!("is_valid_matmul(&a, &b, &c_valid) = {}", is_valid_matmul(&a, &b, &c_valid));
    println!("is_valid_matmul(&a, &b, &c_invalid) = {}", is_valid_matmul(&a, &b, &c_invalid));
}
/*
output:
is_valid_matmul(&a, &b, &c_valid) = true
is_valid_matmul(&a, &b, &c_invalid) = false
*/

For Array2 specifically there are .ncols() and .nrows() methods. If you are only working with 2d arrays then this is probably the best choice. They return usize, so no None checking is required.

use ndarray::prelude::*;

fn is_valid_matmul(a: &Array2<i64>, b: &Array2<i64>, c: &Array2<i64>) -> bool {
    //nrows() and ncols() are only valid for Array2, 
    //[arr.nrows(), arr.ncols()] = [arr.shape()[0], arr.shape()[1]]
    return a.ncols() == b.nrows() && b.ncols() == c.ncols() && a.nrows() == c.nrows();
}
fn main() {
    let a = Array2::<i64>::zeros((3, 5));
    let b = Array2::<i64>::zeros((5, 6));
    let c_valid = Array2::<i64>::zeros((3, 6));
    let c_invalid = Array2::<i64>::zeros((8, 6));

    println!("is_valid_matmul(&a, &b, &c_valid) = {}", is_valid_matmul(&a, &b, &c_valid));
    println!("is_valid_matmul(&a, &b, &c_invalid) = {}", is_valid_matmul(&a, &b, &c_invalid));
}
/*
output:
is_valid_matmul(&a, &b, &c_valid) = true
is_valid_matmul(&a, &b, &c_invalid) = false
*/
~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文