-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathim2col.cpp
More file actions
67 lines (61 loc) · 2.1 KB
/
im2col.cpp
File metadata and controls
67 lines (61 loc) · 2.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
//========================================================================
// Im2col
//========================================================================
// @brief: pre-processing image data
#include "gemm.h"
#include "im2col.h"
// image to column : filters%batch == 0
//
// Zero-pads a planar (channel-major) image by `pad` pixels on every side and
// writes it to data_col in a row-major, column-grouped layout:
//   * Padded row h starts at data_col[h * width_col * channels].
//   * The first two padded columns (w = 0, 1) are stored pixel-interleaved,
//     i.e. at offset w * channels + c inside the row.
//   * The remaining columns are split into groups of SystolicKernelSize
//     columns (the last group may be shorter). Inside a group the data is
//     stored channel by channel, each channel contributing `step`
//     consecutive values, where `step` is the number of columns in the group.
//
// Parameters:
//   data_im  - input, channels * height * width floats, indexed
//              c*width*height + h*width + w.
//   channels, height, width - input dimensions.
//   ksize, stride - not used by this routine; kept for interface compatibility.
//   pad      - width of the zero border added on each side.
//   data_col - output, must hold channels * (height+2*pad) * (width+2*pad)
//              floats.
//
// NOTE(review): SystolicKernelSize is not defined in this file; presumably it
// comes from gemm.h. The constant 2 below (columns stored before the grouped
// region) is kept from the original code - confirm against the hardware
// buffer design if the kernel size ever changes.
void im2col(float *data_im,int channels, int height, int width, int ksize, int stride, int pad, float* data_col)
{
    const int height_col = height + 2*pad;
    const int width_col = width + 2*pad;
    const int plane = width * height;
    // Column count of the trailing group when (width_col - 2) is not a
    // multiple of SystolicKernelSize.
    const int tail = (width_col - 2) % SystolicKernelSize;

    for (int c = 0; c < channels; c++)
    {
        for (int h = 0; h < height_col; h++)
        {
            const int row_base = h * width_col * channels;
            const int src_h = h - pad;
            for (int w = 0; w < width_col; w++)
            {
                const int src_w = w - pad;

                // Anything that maps outside the source image is padding.
                // The original code only recognised a one-pixel border and
                // always subtracted 1, which was wrong for pad != 1 (shifted
                // output for pad == 0, out-of-bounds reads for pad >= 2).
                float value = 0.0f;
                if (src_h >= 0 && src_h < height && src_w >= 0 && src_w < width)
                {
                    value = data_im[c * plane + src_h * width + src_w];
                }

                // The first two columns are stored pixel-interleaved.
                if (w < 2)
                {
                    data_col[row_base + w * channels + c] = value;
                    continue;
                }

                // Remaining columns: locate the group this column belongs to
                // and its position inside that group.
                const int offset = (w - 2) % SystolicKernelSize;
                const int group_start = w - offset;
                // A group that would run past the row end is the short tail group.
                const int step = (group_start + SystolicKernelSize > width_col) ? tail : SystolicKernelSize;

                data_col[row_base + group_start * channels + step * c + offset] = value;
            }
        }
    }
}
// image to column : filters%batch != 0
//
// Repacks a planar image (indexed c*width*height + h*width + w) into a
// pixel-interleaved buffer (indexed (h*width + w)*channels + c). No padding
// is applied; ksize, stride and pad are accepted but not used here.
// data_col must hold channels * height * width floats.
void im2col_extra(float *data_im,int channels, int height, int width, int ksize, int stride, int pad, float* data_col)
{
    const int plane = width * height;

    for (int col = 0; col < width; ++col)
    {
        for (int row = 0; row < height; ++row)
        {
            // Position of this pixel inside one channel plane.
            const int pixel = row * width + col;
            const int dst_base = pixel * channels;

            for (int ch = 0; ch < channels; ++ch)
            {
                data_col[dst_base + ch] = data_im[ch * plane + pixel];
            }
        }
    }
}