Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
83 changes: 45 additions & 38 deletions training/input_processing.py
Original file line number Diff line number Diff line change
Expand Up @@ -394,44 +394,51 @@ def cutout(shot,

@tf.function
def pil_equalize(shot):
# Implements Equalize function from PIL using TF ops.
# https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
def scale_channel(im, c):
im = tf.cast(im[:, :, c], tf.int32)
# Compute the histogram of the image channel.
histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)

# For the purposes of computing the step, filter out the nonzeros.
nonzero = tf.where(tf.not_equal(histo, 0))
nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255

def build_lut(histo, step):
# Compute the cumulative sum, shifting by step // 2 and then normalization by step.
lut = (tf.cumsum(histo) + (step // 2)) // step
# Shift lut, prepending with 0.
lut = tf.concat([[0], lut[:-1]], 0)
# Clip the counts to be in range. This is done in the C code for image.point.
return tf.clip_by_value(lut, 0, 255)

# If step is zero, return the original image.
# Otherwise, build lut from the full histogram and step and then index from it.
result = tf.cond(tf.equal(step, 0),
lambda: im,
lambda: tf.gather(build_lut(histo, step), im))

return tf.cast(result, tf.uint8)

# Assumes RGB for now. Scales each channel independently and then stacks the result.
l, h, w, c = tf.shape(shot)[0], tf.shape(shot)[1], tf.shape(shot)[2], tf.shape(shot)[3]

shot = tf.reshape(shot, [l * h, w, c])
s1 = scale_channel(shot, 0)
s2 = scale_channel(shot, 1)
s3 = scale_channel(shot, 2)
shot = tf.stack([s1, s2, s3], 2)
shot = tf.reshape(shot, [l, h, w, c])
return shot
with tf.name_scope("pil_equalize"):
# Implements Equalize function from PIL using TF ops.
# https://github.com/tensorflow/tpu/blob/master/models/official/efficientnet/autoaugment.py
def scale_channel(im, c):
with tf.name_scope("scale_channel"):
im = tf.cast(im[:, :, c], tf.int32)
# Compute the histogram of the image channel.
histo = tf.histogram_fixed_width(im, [0, 255], nbins=256)

# For the purposes of computing the step, filter out the nonzeros.
nonzero = tf.where(tf.not_equal(histo, 0))
nonzero_histo = tf.reshape(tf.gather(histo, nonzero), [-1])
if len(nonzero_histo) <= 1:
return tf.cast(im, tf.uint8)
step = (tf.reduce_sum(nonzero_histo) - nonzero_histo[-1]) // 255


def build_lut(histo, step):
with tf.name_scope("build_lut"):
# Compute the cumulative sum, shifting by step // 2 and then normalization by step.
lut = (tf.cumsum(histo) + (step // 2)) // step
# Shift lut, prepending with 0.
lut = tf.concat([[0], lut[:-1]], 0)
# Clip the counts to be in range. This is done in the C code for image.point.
return tf.clip_by_value(lut, 0, 255)

# If step is zero, return the original image.
# Otherwise, build lut from the full histogram and step and then index from it.
result = tf.cond(tf.equal(step, 0),
lambda: im,
lambda: tf.gather(build_lut(histo, step), im))

return tf.cast(result, tf.uint8)

# Assumes RGB for now. Scales each channel independently and then stacks the result.
l, h, w, c = tf.shape(shot)[0], tf.shape(shot)[1], tf.shape(shot)[2], tf.shape(shot)[3]

shot = tf.reshape(shot, [l * h, w, c])
s1 = scale_channel(shot, 0)
s2 = scale_channel(shot, 1)
s3 = scale_channel(shot, 2)
shot = tf.stack([s1, s2, s3], 2)
shot = tf.reshape(shot, [l, h, w, c])
return shot



def pil_posterize(image, bits):
Expand Down