From 4c35e061d4ba6ee159599f3df1657ef7993844fd Mon Sep 17 00:00:00 2001 From: Yu Feng Date: Sat, 25 Dec 2021 23:41:23 +0800 Subject: [PATCH] modify pil_equalize: fix black image error --- training/input_processing.py | 83 +++++++++++++++++++----------------- 1 file changed, 45 insertions(+), 38 deletions(-) diff --git a/training/input_processing.py b/training/input_processing.py index 277a8cb..3511bbd 100644 --- a/training/input_processing.py +++ b/training/input_processing.py @@ -394,44 +394,51 @@ def cutout(shot, @tf.function def pil_equalize(shot): - # Implements Equalize function from PIL using TF ops. - # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py - def scale_channel(im, c): - im = tf.cast(im[:, :, c], tf.int32) - # Compute the histogram of the image channel. - histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) - - # For the purposes of computing the step, filter out the nonzeros. - nonzero = tf.where(tf.not_equal(histo, 0)) - nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) - step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 - - def build_lut(histo, step): - # Compute the cumulative sum, shifting by step // 2 and then normalization by step. - lut = (tf.cumsum(histo) + (step // 2)) // step - # Shift lut, prepending with 0. - lut = tf.concat([[0], lut[:-1]], 0) - # Clip the counts to be in range. This is done in the C code for image.point. - return tf.clip_by_value(lut, 0, 255) - - # If step is zero, return the original image. - # Otherwise, build lut from the full histogram and step and then index from it. - result = tf.cond(tf.equal(step, 0), - lambda: im, - lambda: tf.gather(build_lut(histo, step), im)) - - return tf.cast(result, tf.uint8) - - # Assumes RGB for now. Scales each channel independently and then stacks the result. - l, h, w, c = tf.shape(shot)[0], tf.shape(shot)[1], tf.shape(shot)[2], tf.shape(shot)[3] - - shot = tf.reshape(shot, [l * h, w, c]) - s1 = scale_channel(shot, 0) - s2 = scale_channel(shot, 1) - s3 = scale_channel(shot, 2) - shot = tf.stack([s1, s2, s3], 2) - shot = tf.reshape(shot, [l, h, w, c]) - return shot + with tf.name_scope("pil_equalize"): + # Implements Equalize function from PIL using TF ops. + # https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py + def scale_channel(im, c): + with tf.name_scope("scale_channel"): + im = tf.cast(im[:, :, c], tf.int32) + # Compute the histogram of the image channel. + histo = tf.histogram_fixed_width(im, [0, 255], nbins=256) + + # For the purposes of computing the step, filter out the nonzeros. + nonzero = tf.where(tf.not_equal(histo, 0)) + nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1]) + if len(nonzero_histo) <= 1: + return tf.cast(im, tf.uint8) + step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255 + + + def build_lut(histo, step): + with tf.name_scope("build_lut"): + # Compute the cumulative sum, shifting by step // 2 and then normalization by step. + lut = (tf.cumsum(histo) + (step // 2)) // step + # Shift lut, prepending with 0. + lut = tf.concat([[0], lut[:-1]], 0) + # Clip the counts to be in range. This is done in the C code for image.point. + return tf.clip_by_value(lut, 0, 255) + + # If step is zero, return the original image. + # Otherwise, build lut from the full histogram and step and then index from it. + result = tf.cond(tf.equal(step, 0), + lambda: im, + lambda: tf.gather(build_lut(histo, step), im)) + + return tf.cast(result, tf.uint8) + + # Assumes RGB for now. Scales each channel independently and then stacks the result. + l, h, w, c = tf.shape(shot)[0], tf.shape(shot)[1], tf.shape(shot)[2], tf.shape(shot)[3] + + shot = tf.reshape(shot, [l * h, w, c]) + s1 = scale_channel(shot, 0) + s2 = scale_channel(shot, 1) + s3 = scale_channel(shot, 2) + shot = tf.stack([s1, s2, s3], 2) + shot = tf.reshape(shot, [l, h, w, c]) + return shot + def pil_posterize(image, bits):