/*
 * Decompiled with CFR 0.152.
 */
package org.apache.beam.sdk.util.construction.graph;

import com.google.auto.value.AutoValue;
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.HashMap;
import java.util.HashSet;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.NavigableSet;
import java.util.Queue;
import java.util.Set;
import java.util.SortedSet;
import java.util.TreeSet;
import java.util.stream.Collectors;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.sdk.util.construction.graph.AutoValue_GreedyPipelineFuser_CollectionConsumer;
import org.apache.beam.sdk.util.construction.graph.AutoValue_GreedyPipelineFuser_DescendantConsumers;
import org.apache.beam.sdk.util.construction.graph.AutoValue_GreedyPipelineFuser_SiblingKey;
import org.apache.beam.sdk.util.construction.graph.ExecutableStage;
import org.apache.beam.sdk.util.construction.graph.FusedPipeline;
import org.apache.beam.sdk.util.construction.graph.GreedyPCollectionFusers;
import org.apache.beam.sdk.util.construction.graph.GreedyStageFuser;
import org.apache.beam.sdk.util.construction.graph.ImmutableExecutableStage;
import org.apache.beam.sdk.util.construction.graph.OutputDeduplicator;
import org.apache.beam.sdk.util.construction.graph.PipelineNode;
import org.apache.beam.sdk.util.construction.graph.PipelineValidator;
import org.apache.beam.sdk.util.construction.graph.QueryablePipeline;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.Preconditions;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ComparisonChain;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.HashMultimap;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableList;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.ImmutableSet;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Sets;
import org.checkerframework.checker.initialization.qual.Initialized;
import org.checkerframework.checker.nullness.qual.NonNull;
import org.checkerframework.checker.nullness.qual.UnknownKeyFor;
import org.checkerframework.dataflow.qual.Pure;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

public class GreedyPipelineFuser {
    private static final @UnknownKeyFor @NonNull @Initialized Logger LOG = LoggerFactory.getLogger(GreedyPipelineFuser.class);
    private final @UnknownKeyFor @NonNull @Initialized QueryablePipeline pipeline;
    private final @UnknownKeyFor @NonNull @Initialized FusedPipeline fusedPipeline;

    private GreedyPipelineFuser( @UnknownKeyFor @NonNull @Initialized RunnerApi.Pipeline p) {
        PipelineValidator.validate(p);
        this.pipeline = QueryablePipeline.forPrimitivesIn(p.getComponents());
        LinkedHashSet<PipelineNode.PTransformNode> unfusedRootNodes = new LinkedHashSet<PipelineNode.PTransformNode>();
        TreeSet<CollectionConsumer> rootConsumers = new TreeSet<CollectionConsumer>();
        for (PipelineNode.PTransformNode pTransformNode : this.pipeline.getRootTransforms()) {
            DescendantConsumers descendants = this.getRootConsumers(pTransformNode);
            unfusedRootNodes.addAll(descendants.getUnfusedNodes());
            rootConsumers.addAll(descendants.getFusibleConsumers());
        }
        this.fusedPipeline = this.fusePipeline(unfusedRootNodes, this.groupSiblings(rootConsumers), (Set<String>)ImmutableSet.copyOf((Collection)p.getRequirementsList()));
    }

    public static @UnknownKeyFor @NonNull @Initialized FusedPipeline fuse( @UnknownKeyFor @NonNull @Initialized RunnerApi.Pipeline p) {
        return new GreedyPipelineFuser((RunnerApi.Pipeline)p).fusedPipeline;
    }

    private @UnknownKeyFor @NonNull @Initialized FusedPipeline fusePipeline(@UnknownKeyFor @NonNull @Initialized Collection<@UnknownKeyFor @NonNull @Initialized PipelineNode.PTransformNode> initialUnfusedTransforms, @UnknownKeyFor @NonNull @Initialized NavigableSet<@UnknownKeyFor @NonNull @Initialized NavigableSet<@UnknownKeyFor @NonNull @Initialized CollectionConsumer>> initialConsumers, @UnknownKeyFor @NonNull @Initialized Set<@UnknownKeyFor @NonNull @Initialized String> requirements) {
        HashMap<CollectionConsumer, ExecutableStage> consumedCollectionsAndTransforms = new HashMap<CollectionConsumer, ExecutableStage>();
        LinkedHashSet<ExecutableStage> stages = new LinkedHashSet<ExecutableStage>();
        LinkedHashSet<PipelineNode.PTransformNode> unfusedTransforms = new LinkedHashSet<PipelineNode.PTransformNode>(initialUnfusedTransforms);
        ArrayDeque<NavigableSet<CollectionConsumer>> pendingSiblingSets = new ArrayDeque<NavigableSet<CollectionConsumer>>(initialConsumers);
        while (!pendingSiblingSets.isEmpty()) {
            Set candidateSiblings = (Set)pendingSiblingSets.poll();
            Sets.SetView siblingSet = Sets.difference((Set)candidateSiblings, consumedCollectionsAndTransforms.keySet());
            Preconditions.checkState((siblingSet.equals(candidateSiblings) || siblingSet.isEmpty() ? 1 : 0) != 0, (String)"Inconsistent collection of siblings reported for a %s. Initial attempt missed %s", (Object)PipelineNode.PCollectionNode.class.getSimpleName(), (Object)siblingSet);
            if (siblingSet.isEmpty()) {
                LOG.debug("Filtered out duplicate stage root {}", (Object)candidateSiblings);
                continue;
            }
            ExecutableStage stage2 = this.fuseSiblings((Set<CollectionConsumer>)siblingSet);
            for (CollectionConsumer sibling : siblingSet) {
                consumedCollectionsAndTransforms.put(sibling, stage2);
            }
            stages.add(stage2);
            for (PipelineNode.PCollectionNode materializedOutput : stage2.getOutputPCollections()) {
                DescendantConsumers descendantConsumers = this.getDescendantConsumers(materializedOutput);
                unfusedTransforms.addAll(descendantConsumers.getUnfusedNodes());
                NavigableSet<NavigableSet<CollectionConsumer>> siblings = this.groupSiblings(descendantConsumers.getFusibleConsumers());
                pendingSiblingSets.addAll(siblings);
            }
        }
        OutputDeduplicator.DeduplicationResult deduplicated = OutputDeduplicator.ensureSingleProducer(this.pipeline, stages, unfusedTransforms);
        return FusedPipeline.of(deduplicated.getDeduplicatedComponents(), stages.stream().map(stage -> deduplicated.getDeduplicatedStages().getOrDefault(stage, (ExecutableStage)stage)).map(GreedyPipelineFuser::sanitizeDanglingPTransformInputs).collect(Collectors.toSet()), (Set<PipelineNode.PTransformNode>)Sets.union(deduplicated.getIntroducedTransforms(), unfusedTransforms.stream().map(transform -> deduplicated.getDeduplicatedTransforms().getOrDefault(transform.getId(), (PipelineNode.PTransformNode)transform)).collect(Collectors.toSet())), requirements);
    }

    private @UnknownKeyFor @NonNull @Initialized DescendantConsumers getRootConsumers(@UnknownKeyFor @NonNull @Initialized PipelineNode.PTransformNode rootNode) {
        Preconditions.checkArgument((rootNode.getTransform().getInputsCount() == 0 ? 1 : 0) != 0, (String)"Transform %s is not at the root of the graph (consumes %s)", (Object)rootNode.getId(), rootNode.getTransform().getInputsMap());
        Preconditions.checkArgument((!this.pipeline.getEnvironment(rootNode).isPresent() ? 1 : 0) != 0, (String)"%s requires all root nodes to be runner-implemented %s or %s primitives, but transform %s executes in environment %s", (Object[])new Object[]{GreedyPipelineFuser.class.getSimpleName(), "beam:transform:impulse:v1", "beam:transform:read:v1", rootNode.getId(), this.pipeline.getEnvironment(rootNode)});
        HashSet<PipelineNode.PTransformNode> unfused = new HashSet<PipelineNode.PTransformNode>();
        unfused.add(rootNode);
        TreeSet<CollectionConsumer> environmentNodes = new TreeSet<CollectionConsumer>();
        for (PipelineNode.PCollectionNode output : this.pipeline.getOutputPCollections(rootNode)) {
            DescendantConsumers descendants = this.getDescendantConsumers(output);
            unfused.addAll(descendants.getUnfusedNodes());
            environmentNodes.addAll(descendants.getFusibleConsumers());
        }
        return DescendantConsumers.of(unfused, environmentNodes);
    }

    private @UnknownKeyFor @NonNull @Initialized DescendantConsumers getDescendantConsumers(@UnknownKeyFor @NonNull @Initialized PipelineNode.PCollectionNode inputPCollection) {
        HashSet<PipelineNode.PTransformNode> unfused = new HashSet<PipelineNode.PTransformNode>();
        TreeSet<CollectionConsumer> downstreamConsumers = new TreeSet<CollectionConsumer>();
        for (PipelineNode.PTransformNode consumer : this.pipeline.getPerElementConsumers(inputPCollection)) {
            if (this.pipeline.getEnvironment(consumer).isPresent()) {
                downstreamConsumers.add(CollectionConsumer.of(inputPCollection, consumer));
                continue;
            }
            LOG.debug("Adding {} {} to the set of runner-executed transforms", (Object)PipelineNode.PTransformNode.class.getSimpleName(), (Object)consumer.getId());
            unfused.add(consumer);
            for (PipelineNode.PCollectionNode output : this.pipeline.getOutputPCollections(consumer)) {
                DescendantConsumers descendants = this.getDescendantConsumers(output);
                unfused.addAll(descendants.getUnfusedNodes());
                downstreamConsumers.addAll(descendants.getFusibleConsumers());
            }
        }
        return DescendantConsumers.of(unfused, downstreamConsumers);
    }

    private @UnknownKeyFor @NonNull @Initialized NavigableSet<@UnknownKeyFor @NonNull @Initialized NavigableSet<@UnknownKeyFor @NonNull @Initialized CollectionConsumer>> groupSiblings(@UnknownKeyFor @NonNull @Initialized NavigableSet<@UnknownKeyFor @NonNull @Initialized CollectionConsumer> newConsumers) {
        HashMultimap compatibleConsumers = HashMultimap.create();
        for (CollectionConsumer newConsumer : newConsumers) {
            AutoValue_GreedyPipelineFuser_SiblingKey key = new AutoValue_GreedyPipelineFuser_SiblingKey(newConsumer.consumedCollection(), this.pipeline.getEnvironment(newConsumer.consumingTransform()).get());
            boolean foundSiblings = false;
            for (Set existingConsumers : compatibleConsumers.get((Object)key)) {
                if (!existingConsumers.stream().allMatch(collectionConsumer -> GreedyPCollectionFusers.isCompatible(collectionConsumer.consumingTransform(), newConsumer.consumingTransform(), this.pipeline))) continue;
                existingConsumers.add(newConsumer);
                foundSiblings = true;
                break;
            }
            if (foundSiblings) continue;
            TreeSet<CollectionConsumer> newConsumerSet = new TreeSet<CollectionConsumer>();
            newConsumerSet.add(newConsumer);
            compatibleConsumers.put((Object)key, newConsumerSet);
        }
        TreeSet<NavigableSet<CollectionConsumer>> orderedSiblings = new TreeSet<NavigableSet<CollectionConsumer>>(Comparator.comparing(SortedSet::first));
        orderedSiblings.addAll(compatibleConsumers.values());
        return orderedSiblings;
    }

    private @UnknownKeyFor @NonNull @Initialized ExecutableStage fuseSiblings(@UnknownKeyFor @NonNull @Initialized Set<@UnknownKeyFor @NonNull @Initialized CollectionConsumer> mutuallyCompatible) {
        PipelineNode.PCollectionNode rootCollection = mutuallyCompatible.iterator().next().consumedCollection();
        return GreedyStageFuser.forGrpcPortRead(this.pipeline, rootCollection, mutuallyCompatible.stream().map(CollectionConsumer::consumingTransform).collect(Collectors.toSet()));
    }

    private static @UnknownKeyFor @NonNull @Initialized ExecutableStage sanitizeDanglingPTransformInputs(@UnknownKeyFor @NonNull @Initialized ExecutableStage stage) {
        HashSet<String> possibleInputs = new HashSet<String>();
        possibleInputs.add(stage.getInputPCollection().getId());
        possibleInputs.addAll(stage.getOutputPCollections().stream().map(PipelineNode.PCollectionNode::getId).collect(Collectors.toSet()));
        possibleInputs.addAll(stage.getSideInputs().stream().map(s -> s.collection().getId()).collect(Collectors.toSet()));
        possibleInputs.addAll(stage.getTransforms().stream().flatMap(t -> t.getTransform().getOutputsMap().values().stream()).collect(Collectors.toSet()));
        Set danglingInputs = stage.getTransforms().stream().flatMap(t -> t.getTransform().getInputsMap().values().stream()).filter(in -> !possibleInputs.contains(in)).collect(Collectors.toSet());
        ImmutableList.Builder pTransformNodesBuilder = ImmutableList.builder();
        for (PipelineNode.PTransformNode transformNode : stage.getTransforms()) {
            RunnerApi.PTransform transform = transformNode.getTransform();
            Map<String, String> validInputs = transform.getInputsMap().entrySet().stream().filter(e -> !danglingInputs.contains(e.getValue())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
            if (!validInputs.equals(transform.getInputsMap())) {
                transformNode = PipelineNode.pTransform(transformNode.getId(), transform.toBuilder().clearInputs().putAllInputs(validInputs).build());
            }
            pTransformNodesBuilder.add((Object)transformNode);
        }
        ImmutableList pTransformNodes = pTransformNodesBuilder.build();
        RunnerApi.Components.Builder componentBuilder = stage.getComponents().toBuilder();
        componentBuilder.clearTransforms().putAllTransforms(pTransformNodes.stream().collect(Collectors.toMap(PipelineNode.PTransformNode::getId, PipelineNode.PTransformNode::getTransform)));
        Map<String, RunnerApi.PCollection> validPCollectionMap = stage.getComponents().getPcollectionsMap().entrySet().stream().filter(e -> !danglingInputs.contains(e.getKey())).collect(Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue));
        componentBuilder.clearPcollections().putAllPcollections(validPCollectionMap);
        return ImmutableExecutableStage.of(componentBuilder.build(), stage.getEnvironment(), stage.getInputPCollection(), stage.getSideInputs(), stage.getUserStates(), stage.getTimers(), (Collection<PipelineNode.PTransformNode>)pTransformNodes, stage.getOutputPCollections(), stage.getWireCoderSettings());
    }

    @AutoValue
    static abstract class CollectionConsumer
    implements Comparable<CollectionConsumer> {
        CollectionConsumer() {
        }

        static @UnknownKeyFor @NonNull @Initialized CollectionConsumer of(@UnknownKeyFor @NonNull @Initialized PipelineNode.PCollectionNode collection, @UnknownKeyFor @NonNull @Initialized PipelineNode.PTransformNode consumer) {
            return new AutoValue_GreedyPipelineFuser_CollectionConsumer(collection, consumer);
        }

        abstract @UnknownKeyFor @NonNull @Initialized PipelineNode.PCollectionNode consumedCollection();

        abstract @UnknownKeyFor @NonNull @Initialized PipelineNode.PTransformNode consumingTransform();

        @Override
        @Pure
        public @UnknownKeyFor @NonNull @Initialized int compareTo(@UnknownKeyFor @NonNull @Initialized CollectionConsumer that) {
            return ComparisonChain.start().compare((Comparable)((Object)this.consumedCollection().getId()), (Comparable)((Object)that.consumedCollection().getId())).compare((Comparable)((Object)this.consumingTransform().getId()), (Comparable)((Object)that.consumingTransform().getId())).result();
        }
    }

    @AutoValue
    static abstract class SiblingKey {
        SiblingKey() {
        }

        abstract @UnknownKeyFor @NonNull @Initialized PipelineNode.PCollectionNode getInputCollection();

        abstract  @UnknownKeyFor @NonNull @Initialized RunnerApi.Environment getEnv();
    }

    @AutoValue
    static abstract class DescendantConsumers {
        DescendantConsumers() {
        }

        static @UnknownKeyFor @NonNull @Initialized DescendantConsumers of(@UnknownKeyFor @NonNull @Initialized Set<@UnknownKeyFor @NonNull @Initialized PipelineNode.PTransformNode> unfusible, @UnknownKeyFor @NonNull @Initialized NavigableSet<@UnknownKeyFor @NonNull @Initialized CollectionConsumer> fusible) {
            return new AutoValue_GreedyPipelineFuser_DescendantConsumers(unfusible, fusible);
        }

        abstract @UnknownKeyFor @NonNull @Initialized Set<@UnknownKeyFor @NonNull @Initialized PipelineNode.PTransformNode> getUnfusedNodes();

        abstract @UnknownKeyFor @NonNull @Initialized NavigableSet<@UnknownKeyFor @NonNull @Initialized CollectionConsumer> getFusibleConsumers();
    }
}

