返回介绍

Caffe 源码阅读 - DataLayer&Data Transformer

发布于 2025-02-25 23:04:58 字数 4811 浏览 0 评论 0 收藏 0

又一次回到了 Caffe 的源码阅读的环节,这一次瞄准的目标是网络的输入,现在的 CNN 网络百花齐放,各种各样的网络结构搭配各种各样的输入让人眼花缭乱,所以我们也必要研究一下输入的代码结构。

Caffe 的 DataLayer 基础版的主要目标是读入两种 DB 的训练数据作为输入,而两种 DB 内存储的格式默认是一种叫 Datum 的数据结构。

message Datum {
  optional int32 channels = 1;
  optional int32 height = 2;
  optional int32 width = 3;
  // the actual image data, in bytes
  optional bytes data = 4;
  optional int32 label = 5;
  // Optionally, the datum could also hold float data.
  repeated float float_data = 6;
  // If true data contains an encoded image that need to be decoded
  optional bool encoded = 7 [default = false];
}

可以看出,这种 Datum 的结构主要的服务对象是经典的图像分类任务。我们同时输入两部分信息:图像信息 data 和类别信息 label。对于其他的信息来说,使用这个结构进行存储就显得有些困难了。比方说 Object Detection 的任务,其中还涉及到许多 BoundingBox 的信息,存储的结构要比这个更复杂。比方说一个知名的图像物体检测的网络结构 SSD 的作者开源实现就用到了一种自定义的训练数据存储方式:

// An extension of Datum which contains "rich" annotations.
message AnnotatedDatum {
  enum AnnotationType {
    BBOX = 0;
  }
  optional Datum datum = 1;
  // If there are "rich" annotations, specify the type of annotation.
  // Currently it only supports bounding box.
  // If there are no "rich" annotations, use label in datum instead.
  optional AnnotationType type = 2;
  // Each group contains annotation for a particular class.
  repeated AnnotationGroup annotation_group = 3;
}

当然,创建一个新的 Datum 类型只是开始,我们还需要围绕着这个新的 Datum 创造相关的读取数据的 C++类。当然,在创建这些类之前,我们当然需要了解一下 Caffe 自身的数据层的机制。

为了更好地理解这部分有点绕弯的关系,我们先来上一张几个相关类的关系图:

这其中涉及到了两个线程和一个先向计算的过程,我们一一仔细看下。

DataReader Thread

DataReader 是 Caffe 封装的读取两种 DB 的数据的类,这一步仅仅是把数据从 DB 中读取出来,也是上面图中右下角那个红框所展示的内容。这部分会给每一个读入的数据源创建一个独立的线程,专门负责这个数据源的读入工作。如果我们有多个 Solver,比方说工作在多 GPU 下,而读入的数据源只有一份(比方说 Train 的 DB 只有一个),那么这一个读取数据的线程将会给这些 Solver 一并服务,这其中的原理可以详细看看 DataReader 这一部分。

最终每一个 Solver 里面的 Net 对象的 DataLayer 都会有一个自己的 DataReader 对象,其中会有一对变量:free 和 full。DataReader 线程作为生产者将读入的数据放入 full 中,而下游的 BasePrefetchingDataLayer 的线程(后面会提到)将作为消费者将 full 中的内容取走。Caffe 中继续使用 BlockingQueue 作为生产者和消费者之间同步的结构,并且设置两个队列的容量:

  1. 每一次 DataReader 将 free 中已经被消费过的对象取出,填上新的数据,然后将其塞入 full 中;
  2. 每一次 BasePrefetchingDataLayer 将 full 中的对象取出并消费,然后将其塞入 free 中。

这样就保证了两边通信没有问题。

BasePrefetchingDataLayer Thread

BasePrefetchingDataLayer 从名字上来看就是一个具有预先取出数据功能的数据层。每次前向计算时,我们并不需要在取数据这一步等待,我们完成可以把数据事先取好,等用的时候直接拿出来。这就是它们以线程的形式独立启动的原因。实际上 DataReader 的主要工作是把原始的数据从 DB 中取出,而 BasePrefetchingDataLayer 类要做的就是数据的加工了。这一部分主要完成两件事:

  1. 确定数据层最终的输出(可以不输出 label 的)
  2. 完成数据层预处理(通常要做一些白化数据的简单工作,比如减均值,乘系数)

前面已经提过,这一部分将消费 DataReader 的输入,同时这一部分将产生可以供上层计算网络直接使用的数据,这样在它的前向计算中,我们直接将 BasePrefetchingDataLayer 的输出拷贝到 top_data 就可以了,这样就节省了一定的时间。一般来说由 GPU 完成计算,由 CPU 完成数据准备,两者之间也不会出现严重的资源竞争。

写个新结构?

从上面的介绍中,我们看出:DataReader 基本上不用变,我们只要根据不用的 Datum 类型创建不同的泛型类就好了,这部分的代码逻辑是固定的;而 BasePrefetchingDataLayer 的部分是有可能发生变化的。这也是一个新的 Datum 要面对的主要部分。而 BasePrefetchingDataLayer 也采用的 Template 的设计模式把不变的流程代码准备好了,一般来说只要实现两个函数即可:

  • DataLayerSetUp
  • load_batch

一个是把类内的一些变量的维度进行初始化,一个是实现如何把一个 DataReader 返回的 raw data 转化为上层网络要的数据。

知道了这些,我们就可以看看 SSD 中的 annotated_data_layer 的实现了,它的 label 中需要存入 8 个信息:

[item_id, group_label, instance_id, xmin, ymin, xmax, ymax, diff]

所以相对应的 load_batch 部分也要做许多计算和准备。具体的计算内容我们可以以后再看,总之通过这两个部分的修改 - prototxt 中的数据结构定义和 DataLayer 部分相关位置的修改,我们就可以使得网络输入多种多样的数据,我们的 Caffe 也就可以完成更多有挑战的事情了。

Data Transformer

这一段新加入的,本来希望能单独写一篇,后来发现字数不够多,就把这部分和这篇合并起来了。我们来看一看 Data Transformer 的内容。这一部分的实现在 C++和 python 上有所不同。一般来说,我们用 C++的代码做训练,用 python 的代码做预测(当然现在用 python 做训练的也越来越多)。在 C++中这是 DataLayer 中的一个小部分,而在 python 中这是一个独立的部分。

C++中的 DataTransformer

Crop :这是在训练过程中经常用到的一种增强数据的方式。在 train 的过程中 Caffe 会进行随机 crop,在 test 的过程中只会保留中间的部分。

Mirror :做一个 x 轴的翻转

Mean : 给每个像素减去一个均值

Scale :给每个像素值乘以一个系数

Python 中的 DataTransformer

python 的 Transformer 就有些复杂了:

Resize :将输入数据缩放到指定的长宽比例

Transpose :转换输入数据的维度。因为经过 skimage 读入后数据的维度是(Height * Width * Channel),需要将数据的维度转换到(Channel*Height*Width)

Channelswap :这个主要针对彩色图的输入,不同的图像处理库对 channel 的处理顺序有所不同。像 opencv,C++的主要合作伙伴,它的默认装载顺序是 BGR - Blue,Green,Red。而 skimage 读入的是 RGB,所以为了保证和 C++训练的模型一致,所以这一步也是很必要的。

Raw scale - Mean - Input scale :这里是将每个像素乘以一个 Raw scale,减去一个 Mean,再乘以一个 Input scale。

python 的版本中还包含一个 deprocess ,用于做图像的反向处理。

python 版本的 crop 和 mirror 功能在 python/caffe/io.py 中的 oversample 函数,不过他的逻辑和 C++的逻辑不太一样了。实际使用中还可以针对自己的使用情况进行修改。

如果你对这篇内容有疑问,欢迎到本站社区发帖提问 参与讨论,获取更多帮助,或者扫码二维码加入 Web 技术交流群。

扫码二维码加入Web技术交流群

发布评论

需要 登录 才能够评论, 你可以免费 注册 一个本站的账号。
列表为空,暂无数据
    我们使用 Cookies 和其他技术来定制您的体验包括您的登录状态等。通过阅读我们的 隐私政策 了解更多相关信息。 单击 接受 或继续使用网站,即表示您同意使用 Cookies 和您的相关数据。
    原文