Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
{
"comment": "Modify this file in a trivial way to cause this test suite to run!",
"modification": 2,
"modification": 1,
}
Original file line number Diff line number Diff line change
Expand Up @@ -130,114 +130,114 @@ public Result invokeProcessElement(
final Map<String, PCollectionView<?>> sideInputMapping) {
final ProcessContext processContext = new ProcessContext(element, tracker, watermarkEstimator);

DoFn.ProcessContinuation cont =
invoker.invokeProcessElement(
new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() {
@Override
public String getErrorContext() {
return OutputAndTimeBoundedSplittableProcessElementInvoker.class.getSimpleName();
}

@Override
public DoFn<InputT, OutputT>.ProcessContext processContext(
DoFn<InputT, OutputT> doFn) {
return processContext;
}

@Override
public Object sideInput(String tagId) {
PCollectionView<?> view = sideInputMapping.get(tagId);
if (view == null) {
throw new IllegalArgumentException("calling getSideInput() with unknown view");
}
return processContext.sideInput(view);
}

@Override
public Object restriction() {
return tracker.currentRestriction();
}

@Override
public InputT element(DoFn<InputT, OutputT> doFn) {
return processContext.element();
}

@Override
public Instant timestamp(DoFn<InputT, OutputT> doFn) {
return processContext.timestamp();
}

@Override
public String timerId(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException(
"Cannot access timerId as parameter outside of @OnTimer method.");
}

@Override
public TimeDomain timeDomain(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException(
"Access to time domain not supported in ProcessElement");
}

@Override
public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
return DoFnOutputReceivers.windowedReceiver(
processContext, OutputBuilderSuppliers.supplierForElement(element), null);
}

@Override
public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException("Not supported in SplittableDoFn");
}

@Override
public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, OutputT> doFn) {
return DoFnOutputReceivers.windowedMultiReceiver(
processContext, OutputBuilderSuppliers.supplierForElement(element));
}

@Override
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
return processContext.causedByDrain();
}

@Override
public RestrictionTracker<?, ?> restrictionTracker() {
return processContext.tracker;
}

@Override
public WatermarkEstimator<?> watermarkEstimator() {
return processContext.watermarkEstimator;
}

@Override
public PipelineOptions pipelineOptions() {
return pipelineOptions;
}

@Override
public BundleFinalizer bundleFinalizer() {
return bundleFinalizer.get();
}

// Unsupported methods below.

@Override
public StartBundleContext startBundleContext(DoFn<InputT, OutputT> doFn) {
throw new IllegalStateException(
"Should not access startBundleContext() from @"
+ DoFn.ProcessElement.class.getSimpleName());
}

@Override
public FinishBundleContext finishBundleContext(DoFn<InputT, OutputT> doFn) {
throw new IllegalStateException(
"Should not access finishBundleContext() from @"
+ DoFn.ProcessElement.class.getSimpleName());
}
});
DoFnInvoker.BaseArgumentProvider<InputT, OutputT> invokerArgumentProvider =
new DoFnInvoker.BaseArgumentProvider<InputT, OutputT>() {
@Override
public String getErrorContext() {
return OutputAndTimeBoundedSplittableProcessElementInvoker.class.getSimpleName();
}

@Override
public DoFn<InputT, OutputT>.ProcessContext processContext(DoFn<InputT, OutputT> doFn) {
return processContext;
}

@Override
public Object sideInput(String tagId) {
PCollectionView<?> view = sideInputMapping.get(tagId);
if (view == null) {
throw new IllegalArgumentException("calling getSideInput() with unknown view");
}
return processContext.sideInput(view);
}

@Override
public Object restriction() {
return tracker.currentRestriction();
}

@Override
public InputT element(DoFn<InputT, OutputT> doFn) {
return processContext.element();
}

@Override
public Instant timestamp(DoFn<InputT, OutputT> doFn) {
return processContext.timestamp();
}

@Override
public String timerId(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException(
"Cannot access timerId as parameter outside of @OnTimer method.");
}

@Override
public TimeDomain timeDomain(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException(
"Access to time domain not supported in ProcessElement");
}

@Override
public OutputReceiver<OutputT> outputReceiver(DoFn<InputT, OutputT> doFn) {
return DoFnOutputReceivers.windowedReceiver(
processContext, OutputBuilderSuppliers.supplierForElement(element), null);
}

@Override
public OutputReceiver<Row> outputRowReceiver(DoFn<InputT, OutputT> doFn) {
throw new UnsupportedOperationException("Not supported in SplittableDoFn");
}

@Override
public MultiOutputReceiver taggedOutputReceiver(DoFn<InputT, OutputT> doFn) {
return DoFnOutputReceivers.windowedMultiReceiver(
processContext, OutputBuilderSuppliers.supplierForElement(element));
}

@Override
public CausedByDrain causedByDrain(DoFn<InputT, OutputT> doFn) {
return processContext.causedByDrain();
}

@Override
public RestrictionTracker<?, ?> restrictionTracker() {
return processContext.tracker;
}

@Override
public WatermarkEstimator<?> watermarkEstimator() {
return processContext.watermarkEstimator;
}

@Override
public PipelineOptions pipelineOptions() {
return pipelineOptions;
}

@Override
public BundleFinalizer bundleFinalizer() {
return bundleFinalizer.get();
}

// Unsupported methods below.

@Override
public StartBundleContext startBundleContext(DoFn<InputT, OutputT> doFn) {
throw new IllegalStateException(
"Should not access startBundleContext() from @"
+ DoFn.ProcessElement.class.getSimpleName());
}

@Override
public FinishBundleContext finishBundleContext(DoFn<InputT, OutputT> doFn) {
throw new IllegalStateException(
"Should not access finishBundleContext() from @"
+ DoFn.ProcessElement.class.getSimpleName());
}
};

DoFn.ProcessContinuation cont = invoker.invokeProcessElement(invokerArgumentProvider);
processContext.cancelScheduledCheckpoint();
@Nullable
KV<RestrictionT, KV<Instant, WatermarkEstimatorStateT>> residual =
Expand Down Expand Up @@ -278,8 +278,37 @@ public FinishBundleContext finishBundleContext(DoFn<InputT, OutputT> doFn) {
if (residual == null) {
return new Result(null, cont, null, null);
}
final KV<RestrictionT, KV<Instant, WatermarkEstimatorStateT>> residualForGetSize = residual;
// For a list of all DoFnInvoker arguments, see DoFn.java.
double backlogBytes =
invoker.invokeGetSize(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How about creating a util to get information from the residual instead of creating an inline class here. Probably also refactor other similar places if any.

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I replaced this with a call to DelegatingArgumentProvider instead, which allows us to reuse most of the original argument provider used for ProcessElement.

new DoFnInvoker.DelegatingArgumentProvider<InputT, OutputT>(
invokerArgumentProvider, invokerArgumentProvider.getErrorContext() + "/GetSize") {
@Override
public Object restriction() {
return residualForGetSize.getKey();
}

@Override
public RestrictionTracker<?, ?> restrictionTracker() {
return invoker.invokeNewTracker(
new DoFnInvoker.DelegatingArgumentProvider<InputT, OutputT>(
invokerArgumentProvider,
invokerArgumentProvider.getErrorContext() + "/NewTracker") {

@Override
public Object restriction() {
return residualForGetSize.getKey();
}
});
}
});
return new Result(
residual.getKey(), cont, residual.getValue().getKey(), residual.getValue().getValue());
residual.getKey(),
cont,
residual.getValue().getKey(),
residual.getValue().getValue(),
backlogBytes);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We expect to get this information only during finishBundle ?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not during finishBundle, but rather in processElement after we've finished processing the restriction. The idea is that we want to get the backlog (work remaining) of the residual restriction after we've finished processing the bundle.

I don't think you can call tryClaim or otherwise change the restriction in finishBundle (since it requires keyed state and whatnot).

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we do any form of validation of the value returned here before sending it to the runner ? For example, ignore if negative or zero(wrong implementation but we probably don't want to pass that to the runner).

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided to put the validation in the Dataflow execution context that way other runners could decide how they wanted to handle these edge cases. Thoughts?

}

private class ProcessContext extends DoFn<InputT, OutputT>.ProcessContext
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import com.google.auto.service.AutoService;
import java.util.List;
import java.util.Map;
import java.util.function.Consumer;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.coders.ByteArrayCoder;
import org.apache.beam.sdk.coders.Coder;
Expand Down Expand Up @@ -281,6 +282,7 @@ public static class ProcessFn<InputT, OutputT, RestrictionT, PositionT, Watermar
processElementInvoker;

private transient @Nullable DoFnInvoker<InputT, OutputT> invoker;
private transient @Nullable Consumer<Double> backlogBytesCallback;

public ProcessFn(
DoFn<InputT, OutputT> fn,
Expand Down Expand Up @@ -323,6 +325,10 @@ public void setProcessElementInvoker(
this.processElementInvoker = invoker;
}

public void setBacklogBytesCallback(Consumer<Double> backlogBytesCallback) {
this.backlogBytesCallback = backlogBytesCallback;
}

public DoFn<InputT, OutputT> getFn() {
return fn;
}
Expand Down Expand Up @@ -622,6 +628,9 @@ public String getErrorContext() {
} else {
holdState.clear();
}
if (backlogBytesCallback != null && result.getBacklogBytes() >= 0) {
backlogBytesCallback.accept(result.getBacklogBytes());
}
}

private DoFnInvoker.ArgumentProvider<InputT, OutputT> wrapOptionsAsSetup(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,10 @@ public class Result {
private final DoFn.ProcessContinuation continuation;
private final @Nullable Instant futureOutputWatermark;
private final @Nullable WatermarkEstimatorStateT futureWatermarkEstimatorState;
private final double backlogBytes;

/* Constant representing an unknown amount of backlog. */
public static final double BACKLOG_UNKNOWN = -1.0;

@SuppressFBWarnings(
value = "NP_PARAMETER_MUST_BE_NONNULL_BUT_MARKED_AS_NULLABLE",
Expand All @@ -50,12 +54,27 @@ public Result(
@Nullable RestrictionT residualRestriction,
DoFn.ProcessContinuation continuation,
@Nullable Instant futureOutputWatermark,
@Nullable WatermarkEstimatorStateT futureWatermarkEstimatorState) {
@Nullable WatermarkEstimatorStateT futureWatermarkEstimatorState,
double backlogBytes) {
checkArgument(continuation != null, "continuation must not be null");
this.continuation = continuation;
this.residualRestriction = residualRestriction;
this.futureOutputWatermark = futureOutputWatermark;
this.futureWatermarkEstimatorState = futureWatermarkEstimatorState;
this.backlogBytes = backlogBytes;
}

public Result(
@Nullable RestrictionT residualRestriction,
DoFn.ProcessContinuation continuation,
@Nullable Instant futureOutputWatermark,
@Nullable WatermarkEstimatorStateT futureWatermarkEstimatorState) {
this(
residualRestriction,
continuation,
futureOutputWatermark,
futureWatermarkEstimatorState,
BACKLOG_UNKNOWN);
}

/**
Expand All @@ -76,6 +95,10 @@ public DoFn.ProcessContinuation getContinuation() {
public @Nullable WatermarkEstimatorStateT getFutureWatermarkEstimatorState() {
return futureWatermarkEstimatorState;
}

public double getBacklogBytes() {
return backlogBytes;
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -36,4 +36,10 @@ public interface StepContext {
default BundleFinalizer bundleFinalizer() {
throw new UnsupportedOperationException("BundleFinalizer is unsupported.");
}

/**
* Set the current backlog bytes for this step. This is mainly used by splittable DoFn to report
* the size of the residual restriction.
*/
default void setBacklogBytes(double backlogBytes) {}
}
Loading
Loading