numpy.dot-形状错误 - 神经网络

发布于 2025-02-03 00:20:58 字数 451 浏览 1 评论 0 原文

我正在尝试将这些 a1 w2 矩阵( z2 = w2.dot(a1))倍增:

A1 : [[0.42940542]
 [0.55013895]]
W2 : [[-0.4734037  -0.39642393 -0.05440914 -0.24011293 -0.03670913 -0.37523234]
 [-0.45501004  0.23881832  0.21831658  0.32237388  0.25674681  0.27956714]]

但是我得到了此错误形状(2,6)和(2,1)未对齐:6(DIM 1)!= 2(DIM 0),为什么?用(2,6)矩阵将(2,1)乘不正常?

因为我有一个带有2个节点的隐藏图层,并带有 6节点的输出图层

I am trying to multiply these A1 and W2 matrices (Z2 = W2.dot(A1)):

A1 : [[0.42940542]
 [0.55013895]]
W2 : [[-0.4734037  -0.39642393 -0.05440914 -0.24011293 -0.03670913 -0.37523234]
 [-0.45501004  0.23881832  0.21831658  0.32237388  0.25674681  0.27956714]]

But I am getting this error shapes (2,6) and (2,1) not aligned: 6 (dim 1) != 2 (dim 0), why? Isn't it normal to multiply a (2,1) with a (2,6) matrix?

Because I have a hidden layer with 2 nodes and output layer with 6 nodes

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。

评论(1

嘴硬脾气大 2025-02-10 00:20:58

从数学上讲,这是不可能的,因为您将A(2,6)矩阵乘以(2,1)。您需要做的就是转置W2。

ps:请注意,在线性代数np.dot(W2.T,a1)中与np.dot(a1.t,w2)不同,

import numpy as np

A1 = np.asarray([[0.42940542], [0.55013895]])
W2 = np.asarray([[
    -0.4734037, -0.39642393, -0.05440914, -0.24011293, -0.03670913, -0.37523234
], [-0.45501004, 0.23881832, 0.21831658, 0.32237388, 0.25674681, 0.27956714]])
print(W2.shape, A1.shape)  # (2, 6), (2, 1)
Z2 = W2.T @ A1
print(Z2)

结果是:[[-0.45360086] [[-0.45360086] [[ -0.03884332] [0.09674087] [0.07424463] [0.12548332] [-0.00732603]]

Mathematically this is impossible because your multiplying a (2, 6) matrix by (2, 1). All you need to do is to transpose W2.

P.S: Note that in linear algebra np.dot(W2.T, A1) is not the same as np.dot(A1.T, W2)

import numpy as np

A1 = np.asarray([[0.42940542], [0.55013895]])
W2 = np.asarray([[
    -0.4734037, -0.39642393, -0.05440914, -0.24011293, -0.03670913, -0.37523234
], [-0.45501004, 0.23881832, 0.21831658, 0.32237388, 0.25674681, 0.27956714]])
print(W2.shape, A1.shape)  # (2, 6), (2, 1)
Z2 = W2.T @ A1
print(Z2)

The result would be: [[-0.45360086] [-0.03884332] [ 0.09674087] [ 0.07424463] [ 0.12548332] [-0.00732603]]

~没有更多了~
我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
原文