@@ -42,9 +42,14 @@ static void DeconvolutionComputeExCPU(const nnvm::NodeAttrs& attrs,
4242 const std::vector<NDArray>& outputs) {
4343 const DeconvolutionParam& params = nnvm::get<DeconvolutionParam>(attrs.parsed );
4444 if (SupportMKLDNNDeconv (params, inputs[0 ])) {
45- MKLDNN_OPCHECK_INIT (false , outputs.size (), inputs, outputs);
46- MKLDNNRun (MKLDNNDeconvolutionForward, attrs, ctx, inputs, req, outputs);
47- MKLDNN_OPCHECK_RUN (DeconvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
45+ if (params.kernel .ndim () == 3 ) {
46+ // we cannot check the output, as 3D deconvolution is not natively supported yet
47+ MKLDNNRun (MKLDNNDeconvolutionForward, attrs, ctx, inputs, req, outputs);
48+ } else {
49+ MKLDNN_OPCHECK_INIT (false , outputs.size (), inputs, outputs);
50+ MKLDNNRun (MKLDNNDeconvolutionForward, attrs, ctx, inputs, req, outputs);
51+ MKLDNN_OPCHECK_RUN (DeconvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
52+ }
4853 return ;
4954 }
5055 FallBackCompute (DeconvolutionCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -57,9 +62,14 @@ static void DeconvolutionGradComputeExCPU(const nnvm::NodeAttrs& attrs,
5762 const std::vector<NDArray>& outputs) {
5863 const DeconvolutionParam& params = nnvm::get<DeconvolutionParam>(attrs.parsed );
5964 if (SupportMKLDNNDeconv (params, inputs[0 ])) {
60- MKLDNN_OPCHECK_INIT (true , outputs.size (), inputs, outputs);
61- MKLDNNRun (MKLDNNDeconvolutionBackward, attrs, ctx, inputs, req, outputs);
62- MKLDNN_OPCHECK_RUN (DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
65+ if (params.kernel .ndim () == 3 ) {
66+ // we cannot check the output, as 3D deconvolution is not natively supported yet
67+ MKLDNNRun (MKLDNNDeconvolutionBackward, attrs, ctx, inputs, req, outputs);
68+ } else {
69+ MKLDNN_OPCHECK_INIT (true , outputs.size (), inputs, outputs);
70+ MKLDNNRun (MKLDNNDeconvolutionBackward, attrs, ctx, inputs, req, outputs);
71+ MKLDNN_OPCHECK_RUN (DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
72+ }
6373 return ;
6474 }
6575 FallBackCompute (DeconvolutionGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
@@ -99,12 +109,12 @@ static bool DeconvolutionShape(const nnvm::NodeAttrs& attrs,
99109 mxnet::ShapeVector *in_shape,
100110 mxnet::ShapeVector *out_shape) {
101111 const DeconvolutionParam& param_ = nnvm::get<DeconvolutionParam>(attrs.parsed );
102- #if MXNET_USE_CUDNN == 0
112+ #if MXNET_USE_CUDNN == 0 && MXNET_USE_MKLDNN == 0
103113 if (param_.kernel .ndim () > 2 ) {
104- LOG (FATAL) << " If not using CUDNN, only 1D or 2D Deconvolution is supported" ;
114+ LOG (FATAL) << " If not using CUDNN or MKLDNN , only 1D or 2D Deconvolution is supported" ;
105115 return false ;
106116 }
107- #endif // CUDNN
117+ #endif
108118
109119 using namespace mshadow ;
110120 if (!param_.no_bias ) {
0 commit comments