sklearn train_test_split() stratification does not work with 2D labels

Posted 2025-01-12 09:23:02

I'm training a seq-to-seq model with sparse 2D labels: a class index is defined for each timestep individually. It's a multi-class, single-label task (softmax).
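To make the label format concrete: "sparse" here means integer class indices of shape (n_windows, timesteps), as opposed to one-hot labels with an extra class axis. The toy sketch below uses hypothetical random data to show both shapes, and that `np.argmax` over the last (class) axis of a one-hot array recovers the sparse indices.

```python
import numpy as np

# Hypothetical toy data: 4 windows x 5 timesteps, 3 classes, sparse integer labels
rng = np.random.default_rng(0)
sparse = rng.integers(0, 3, size=(4, 5))  # shape (n_windows, timesteps)

# One-hot encoding adds a trailing class axis: (n_windows, timesteps, n_classes)
one_hot = np.eye(3, dtype=int)[sparse]

# argmax over the last (class) axis turns one-hot labels back into indices
recovered = np.argmax(one_hot, axis=-1)
print(sparse.shape, one_hot.shape)
```

Note that `argmax` only makes sense on one-hot data and only along the class axis; applied to sparse indices it returns positions, not classes.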

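Here, stratification is useful for balancing the labels in the newly generated splits according to the label distribution of the original dataset.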
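For reference, this is what `stratify` does in the ordinary 1D case it is designed for: each split keeps the class proportions of the full label array. The sketch uses made-up labels with an 80/20 imbalance.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy 1D labels: 80 samples of class 0, 20 of class 1 (an 80/20 imbalance)
y = np.array([0] * 80 + [1] * 20)
X = np.arange(len(y)).reshape(-1, 1)

X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.25, random_state=0, stratify=y
)

# With stratify, both splits keep the 80/20 ratio: 60/15 in train, 20/5 in test
print(np.bincount(y_train), np.bincount(y_test))
```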

import numpy as np
from sklearn.model_selection import train_test_split

# load dataset
f = np.load('./new_dataset.npz')
signals = f['signals']
labels = f['labels']

# downsample to 50 Hz (6 sec windows)
# trim to a whole number of 600-sample windows so the reshapes below succeed
n = (signals.shape[0] // 600) * 600
signals = signals[:n]
labels = labels[:n]

signals = np.reshape(signals, (-1, 600, signals.shape[-1]))
labels = np.reshape(labels, (-1, 600))

# keep every second sample: 600 -> 300 timesteps per window
signals = signals[:, ::2]
labels = labels[:, ::2]

print(f"signals: {signals.shape}")
print(f"labels: {labels.shape}")

# split to train-test (stratify=labels is 2D here; this call raises the error below)
X_train, X_test, y_train, y_test = train_test_split(
    signals, labels, test_size=0.15, random_state=9, stratify=labels
)
X_train, X_val, y_train, y_val = train_test_split(
    X_train, y_train, test_size=0.15, random_state=9, stratify=y_train
)
print(X_train.shape, y_train.shape)
print(X_test.shape, y_test.shape)
print(X_val.shape, y_val.shape)

Result

signals: (41564, 300, 6)
labels: (41564, 300)
---------------------------------------------------------------------------
ValueError                                Traceback (most recent call last)
/var/folders/7v/fqqcktvs23qc8fwgftjpz_gh0000gn/T/ipykernel_15879/1199612105.py in <module>
     42 
     43 # split to train-test
---> 44 X_train, X_test, y_train, y_test = train_test_split(
     45     signals, labels, test_size=0.15, random_state=9, stratify=labels
     46 )

~/miniforge3/lib/python3.9/site-packages/sklearn/model_selection/_split.py in train_test_split(test_size, train_size, random_state, shuffle, stratify, *arrays)
   2439         cv = CVClass(test_size=n_test, train_size=n_train, random_state=random_state)
   2440 
-> 2441         train, test = next(cv.split(X=arrays[0], y=stratify))
   2442 
   2443     return list(

~/miniforge3/lib/python3.9/site-packages/sklearn/model_selection/_split.py in split(self, X, y, groups)
   1598         """
   1599         X, y, groups = indexable(X, y, groups)
-> 1600         for train, test in self._iter_indices(X, y, groups):
   1601             yield train, test
   1602 

~/miniforge3/lib/python3.9/site-packages/sklearn/model_selection/_split.py in _iter_indices(self, X, y, groups)
   1938         class_counts = np.bincount(y_indices)
   1939         if np.min(class_counts) < 2:
-> 1940             raise ValueError(
   1941                 "The least populated class in y has only 1"
   1942                 " member, which is too few. The minimum"

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.
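A likely reason for this error: when `stratify` is 2D, scikit-learn's `StratifiedShuffleSplit` joins each row into a single string label, so every distinct 300-step label sequence becomes its own "class". With long sequences almost every row is unique, so the least populated "class" has one member. The sketch below reproduces this on hypothetical toy data (6 windows x 4 timesteps), using `np.unique(..., axis=0)` to count distinct rows as a stand-in for what stratification sees.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Hypothetical toy sparse labels: 6 windows x 4 timesteps
labels = np.array([
    [0, 0, 1, 1],
    [0, 0, 1, 1],
    [2, 2, 2, 0],
    [1, 1, 1, 1],
    [0, 2, 2, 2],
    [0, 0, 0, 0],
])
signals = np.zeros((len(labels), 4, 6))  # dummy features; values are irrelevant

# Each distinct row acts as one stratum when y is 2D
rows, counts = np.unique(labels, axis=0, return_counts=True)
print(f"{len(rows)} distinct label rows, smallest group: {counts.min()}")

raised = False
try:
    train_test_split(signals, labels, test_size=0.5, random_state=9, stratify=labels)
except ValueError as err:
    raised = True  # same "least populated class" error as in the question
    print("ValueError:", err)
```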



Answer by 帅气称霸, 2025-01-19 09:23:02

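As the error says, one of your classes has only 1 member. The minimum is 2.
Consider removing that class or extending it with more samples.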

ValueError: The least populated class in y has only 1 member, which is too few. The minimum number of groups for any class cannot be less than 2.
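One way to apply "remove or extend" to 2D labels is to stratify on a 1D key derived per window instead of on the full label sequences. The sketch below is an assumption, not the answerer's method: it uses a hypothetical `dominant_class` helper (most frequent class per window) as the key, merges any key with fewer than 2 windows into the most common key, and runs on random stand-in data. This balances windows by their dominant class only, which approximates, but does not guarantee, per-timestep balance.

```python
import numpy as np
from sklearn.model_selection import train_test_split

def dominant_class(labels: np.ndarray, n_classes: int) -> np.ndarray:
    """Hypothetical helper: most frequent class index in each window (row)."""
    counts = np.apply_along_axis(np.bincount, 1, labels, minlength=n_classes)
    return counts.argmax(axis=1)

rng = np.random.default_rng(9)
n_classes = 4
signals = rng.normal(size=(200, 300, 6))               # stand-in for real signals
labels = rng.integers(0, n_classes, size=(200, 300))   # stand-in sparse labels

strata = dominant_class(labels, n_classes)  # 1D key, shape (200,)

# Merge strata with fewer than 2 windows into the most common stratum
values, counts = np.unique(strata, return_counts=True)
rare = values[counts < 2]
most_common = values[counts.argmax()]
strata = np.where(np.isin(strata, rare), most_common, strata)

X_train, X_test, y_train, y_test = train_test_split(
    signals, labels, test_size=0.15, random_state=9, stratify=strata
)
print(X_train.shape, X_test.shape)
```

Using the rarest class present in each window as the key is an alternative choice when rare classes matter more than the majority class.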

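Use this to see the per-class count before changing the shape of your array:

import collections
import numpy as np

# labels holds sparse class indices (not one-hot), so count the raw values;
# np.argmax(labels, axis=1) would only be meaningful for one-hot rows
print(collections.Counter(labels.ravel().tolist()))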
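Per-timestep counts can look healthy while a class still occurs in very few windows, and it is windows, not timesteps, that the split distributes. The sketch below, on hypothetical toy labels, counts both: how many timesteps carry each class and in how many windows each class appears at least once.

```python
import collections
import numpy as np

# Hypothetical toy sparse labels: 3 windows x 4 timesteps of class indices
labels = np.array([
    [0, 0, 1, 1],
    [0, 2, 2, 2],
    [0, 0, 0, 1],
])

# Per-timestep counts: every labelled timestep contributes once
timestep_counts = collections.Counter(labels.ravel().tolist())

# Per-window counts: number of windows in which each class appears at least once
window_counts = collections.Counter(
    c for row in labels for c in np.unique(row).tolist()
)
print(timestep_counts, window_counts)
```

Here class 2 has 3 timesteps but occurs in only one window, so any window-level stratification would still see it as a single-member group.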
