/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.indices.replication;

import java.io.IOException;
import java.util.Map;
import java.util.concurrent.ConcurrentMap;
import java.util.concurrent.ConcurrentNavigableMap;
import java.util.concurrent.ConcurrentSkipListMap;
import java.util.concurrent.TimeUnit;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.apache.lucene.index.CorruptIndexException;
import org.opensearch.OpenSearchCorruptionException;
import org.opensearch.common.SetOnce;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.index.shard.ShardId;
import org.opensearch.index.ReplicationStats;
import org.opensearch.index.shard.IndexShard;
import org.opensearch.index.store.Store;
import org.opensearch.index.store.StoreFileMetadata;
import org.opensearch.indices.replication.SegmentReplicationSource;
import org.opensearch.indices.replication.SegmentReplicationSourceFactory;
import org.opensearch.indices.replication.SegmentReplicationState;
import org.opensearch.indices.replication.SegmentReplicationTarget;
import org.opensearch.indices.replication.SegmentReplicationTargetService;
import org.opensearch.indices.replication.checkpoint.ReplicationCheckpoint;
import org.opensearch.indices.replication.common.ReplicationCollection;
import org.opensearch.indices.replication.common.ReplicationFailedException;
import org.opensearch.indices.replication.common.ReplicationListener;
import org.opensearch.threadpool.ThreadPool;

/**
 * Drives segment replication events for replica shards on this node.
 *
 * <p>Tracks in-flight replication events in a {@link ReplicationCollection}, remembers the state of the last
 * replication round per shard that actually copied data, and maintains per-shard checkpoint statistics
 * (bytes behind and replication lag) derived from checkpoints received from the primary.
 */
public class SegmentReplicator {
    private static final Logger logger = LogManager.getLogger(SegmentReplicator.class);

    /** Replication events that have been started and not yet marked done, failed, or cancelled. */
    private final ReplicationCollection<SegmentReplicationTarget> onGoingReplications;

    /** Last completed replication state per shard; only recorded for rounds that copied files and bytes. */
    private final Map<ShardId, SegmentReplicationState> completedReplications = ConcurrentCollections.newConcurrentMap();

    /**
     * Per-shard checkpoint statistics keyed by segment infos version. The sorted map lets the lowest
     * (oldest outstanding) and highest (newest) entries be read directly. A shard only accumulates entries
     * once {@link #initializeStats(ShardId)} has created its map.
     */
    private final ConcurrentMap<ShardId, ConcurrentNavigableMap<Long, ReplicationCheckpointStats>> replicationCheckpointStats =
        ConcurrentCollections.newConcurrentMap();

    /** Most advanced checkpoint received per shard via {@link #updateReplicationCheckpointStats}. */
    private final ConcurrentMap<ShardId, ReplicationCheckpoint> primaryCheckpoint = ConcurrentCollections.newConcurrentMap();

    private final ThreadPool threadPool;

    /** Source factory; until it is set, {@link #startReplication(IndexShard)} returns without doing anything. */
    private final SetOnce<SegmentReplicationSourceFactory> sourceFactory;

    public SegmentReplicator(ThreadPool threadPool) {
        this.onGoingReplications = new ReplicationCollection<>(logger, threadPool);
        this.threadPool = threadPool;
        this.sourceFactory = new SetOnce<>();
    }

    /**
     * Starts a replication event for the given replica shard, targeting the shard's latest replication checkpoint.
     * Does nothing if no source factory has been set yet. On failure the error is logged, and the shard is failed
     * when the failure is flagged as requiring it.
     *
     * @param shard the replica shard to replicate into
     */
    public void startReplication(final IndexShard shard) {
        final SegmentReplicationSourceFactory factory = this.sourceFactory.get();
        if (factory == null) {
            return;
        }
        startReplication(
            shard,
            shard.getLatestReplicationCheckpoint(),
            factory.get(shard),
            new SegmentReplicationTargetService.SegmentReplicationListener() {
                @Override
                public void onReplicationDone(SegmentReplicationState state) {
                    logger.trace("Completed replication for {}", shard.shardId());
                }

                @Override
                public void onReplicationFailure(SegmentReplicationState state, ReplicationFailedException e, boolean sendShardFailure) {
                    logger.error(() -> new ParameterizedMessage("Failed segment replication for {}", shard.shardId()), e);
                    if (sendShardFailure) {
                        shard.failShard("unrecoverable replication failure", e);
                    }
                }
            }
        );
    }

    /**
     * Sets the factory used to create replication sources. Backed by a {@link SetOnce}, so it is intended to be
     * set a single time.
     *
     * @param sourceFactory the factory to use
     */
    void setSourceFactory(SegmentReplicationSourceFactory sourceFactory) {
        this.sourceFactory.set(sourceFactory);
    }

    /**
     * Creates a replication target for the shard and starts it with the shard's recovery activity timeout.
     *
     * @param indexShard the replica shard
     * @param checkpoint the checkpoint to replicate to
     * @param source     where segments are fetched from
     * @param listener   notified when the replication event completes or fails
     * @return the created target
     */
    SegmentReplicationTarget startReplication(
        final IndexShard indexShard,
        final ReplicationCheckpoint checkpoint,
        final SegmentReplicationSource source,
        final SegmentReplicationTargetService.SegmentReplicationListener listener
    ) {
        final SegmentReplicationTarget target = new SegmentReplicationTarget(indexShard, checkpoint, source, listener);
        startReplication(target, indexShard.getRecoverySettings().activityTimeout());
        return target;
    }

    /**
     * Returns replication statistics for a shard, derived from its tracked checkpoint statistics.
     * Bytes behind comes from the newest (highest version) entry. Lag is measured from the oldest (lowest version)
     * entry's timestamp and is only reported while bytes behind is positive.
     *
     * @param shardId the shard to report on
     * @return the statistics, or {@link ReplicationStats#empty()} if nothing is tracked for the shard
     */
    public ReplicationStats getSegmentReplicationStats(final ShardId shardId) {
        final ConcurrentNavigableMap<Long, ReplicationCheckpointStats> existingCheckpointStats = replicationCheckpointStats.get(shardId);
        if (existingCheckpointStats == null) {
            return ReplicationStats.empty();
        }
        // Read the entries directly instead of checking isEmpty() first: pruneCheckpointsUpToLastSync may remove
        // entries concurrently, so firstEntry()/lastEntry() can return null even after a non-empty check.
        final Map.Entry<Long, ReplicationCheckpointStats> lowestEntry = existingCheckpointStats.firstEntry();
        final Map.Entry<Long, ReplicationCheckpointStats> highestEntry = existingCheckpointStats.lastEntry();
        if (lowestEntry == null || highestEntry == null) {
            return ReplicationStats.empty();
        }
        final long bytesBehind = highestEntry.getValue().getBytesBehind();
        // NOTE(review): assumes the stored timestamp (ReplicationCheckpoint#getCreatedTimeStamp) is on the same
        // System.nanoTime() scale as this JVM; confirm, since the checkpoint is received from the primary.
        final long replicationLag = bytesBehind > 0L
            ? TimeUnit.NANOSECONDS.toMillis(System.nanoTime() - lowestEntry.getValue().getTimestamp())
            : 0L;
        return new ReplicationStats(bytesBehind, bytesBehind, replicationLag);
    }

    /**
     * Records a checkpoint received from the primary if it is the first one for the shard or is ahead of the
     * stored one, and then updates the shard's checkpoint statistics.
     *
     * @param latestReceivedCheckPoint the checkpoint received from the primary
     * @param indexShard               the replica shard
     */
    public void updateReplicationCheckpointStats(final ReplicationCheckpoint latestReceivedCheckPoint, final IndexShard indexShard) {
        final ReplicationCheckpoint primaryCheckPoint = primaryCheckpoint.get(indexShard.shardId());
        if (primaryCheckPoint == null || latestReceivedCheckPoint.isAheadOf(primaryCheckPoint)) {
            primaryCheckpoint.put(indexShard.shardId(), latestReceivedCheckPoint);
            calculateReplicationCheckpointStats(latestReceivedCheckPoint, indexShard);
        }
    }

    /**
     * Drops checkpoint statistics whose segment infos version is strictly lower than the shard's current one.
     * It then recomputes bytes behind for the newest remaining entry against the latest primary checkpoint.
     * Does nothing if no primary checkpoint or local checkpoint is known, or if no statistics are tracked.
     *
     * @param indexShard the replica shard to prune statistics for
     */
    protected void pruneCheckpointsUpToLastSync(final IndexShard indexShard) {
        final ReplicationCheckpoint latestCheckpoint = primaryCheckpoint.get(indexShard.shardId());
        if (latestCheckpoint == null) {
            return;
        }
        final ReplicationCheckpoint indexReplicationCheckPoint = indexShard.getLatestReplicationCheckpoint();
        if (indexReplicationCheckPoint == null) {
            // Same guard as calculateReplicationCheckpointStats: nothing local to compare against yet.
            return;
        }
        final ConcurrentNavigableMap<Long, ReplicationCheckpointStats> existingCheckpointStats = replicationCheckpointStats.get(
            indexShard.shardId()
        );
        if (existingCheckpointStats == null || existingCheckpointStats.isEmpty()) {
            return;
        }
        final long segmentInfoVersion = indexReplicationCheckPoint.getSegmentInfosVersion();
        existingCheckpointStats.keySet().removeIf(key -> key < segmentInfoVersion);
        final Map.Entry<Long, ReplicationCheckpointStats> lastEntry = existingCheckpointStats.lastEntry();
        if (lastEntry != null) {
            lastEntry.getValue().setBytesBehind(calculateBytesBehind(latestCheckpoint, indexReplicationCheckPoint));
        }
    }

    /**
     * Adds a statistics entry for the received checkpoint, keyed by its segment infos version. This only happens
     * when the replica is behind it (positive bytes behind) and the shard's statistics map has been initialized.
     * An existing entry for the same version is left untouched.
     */
    private void calculateReplicationCheckpointStats(final ReplicationCheckpoint latestReceivedCheckPoint, final IndexShard indexShard) {
        final ReplicationCheckpoint indexShardReplicationCheckpoint = indexShard.getLatestReplicationCheckpoint();
        if (indexShardReplicationCheckpoint == null) {
            return;
        }
        final long segmentInfosVersion = latestReceivedCheckPoint.getSegmentInfosVersion();
        final long bytesBehind = calculateBytesBehind(latestReceivedCheckPoint, indexShardReplicationCheckpoint);
        if (bytesBehind <= 0L) {
            return;
        }
        final ConcurrentNavigableMap<Long, ReplicationCheckpointStats> existingCheckpointStats = replicationCheckpointStats.get(
            indexShard.shardId()
        );
        if (existingCheckpointStats != null) {
            existingCheckpointStats.computeIfAbsent(
                segmentInfosVersion,
                k -> new ReplicationCheckpointStats(bytesBehind, latestReceivedCheckPoint.getCreatedTimeStamp())
            );
        }
    }

    /**
     * Sums the lengths of the files that {@link Store#segmentReplicationDiff} reports as {@code missing} when
     * comparing the two checkpoints' metadata. Presumably these are the files in {@code latestCheckPoint} that
     * {@code replicationCheckpoint} lacks; confirm against Store.
     */
    private long calculateBytesBehind(final ReplicationCheckpoint latestCheckPoint, final ReplicationCheckpoint replicationCheckpoint) {
        final Store.RecoveryDiff diff = Store.segmentReplicationDiff(latestCheckPoint.getMetadataMap(), replicationCheckpoint.getMetadataMap());
        return diff.missing.stream().mapToLong(StoreFileMetadata::length).sum();
    }

    /**
     * Creates an empty statistics map for the shard if none exists. Checkpoint statistics are only recorded for
     * shards that have been initialized here.
     *
     * @param shardId the shard to initialize
     */
    public void initializeStats(ShardId shardId) {
        replicationCheckpointStats.computeIfAbsent(shardId, k -> new ConcurrentSkipListMap<>());
    }

    /**
     * Runs the replication event with the given id, if it is still tracked. On success, statistics are pruned,
     * the event is marked done, and the state is remembered when data was actually copied. On failure, the event
     * is failed and flagged for shard failure only when corruption is detected.
     */
    private void start(final long replicationId) {
        final SegmentReplicationTarget target;
        try (ReplicationCollection.ReplicationRef<SegmentReplicationTarget> replicationRef = onGoingReplications.get(replicationId)) {
            // The event may already have been removed (e.g. cancelled) before this runner was scheduled.
            if (replicationRef == null) {
                return;
            }
            target = replicationRef.get();
        }
        target.startReplication(new ActionListener<Void>() {
            @Override
            public void onResponse(Void o) {
                logger.debug(() -> new ParameterizedMessage("Finished replicating {} marking as done.", target.description()));
                pruneCheckpointsUpToLastSync(target.indexShard());
                onGoingReplications.markAsDone(replicationId);
                if (target.state().getIndex().recoveredFileCount() != 0 && target.state().getIndex().recoveredBytes() != 0L) {
                    completedReplications.put(target.shardId(), target.state());
                }
            }

            @Override
            public void onFailure(Exception e) {
                logger.debug("Replication failed {}", target.description());
                if (isStoreCorrupt(target) || e instanceof CorruptIndexException || e instanceof OpenSearchCorruptionException) {
                    onGoingReplications.fail(replicationId, new ReplicationFailedException("Store corruption during replication", e), true);
                    return;
                }
                onGoingReplications.fail(replicationId, new ReplicationFailedException("Segment Replication failed", e), false);
            }
        }, this::updateReplicationCheckpointStats);
    }

    /**
     * Registers the target with the ongoing-replication collection and schedules it on the generic thread pool.
     * If registration fails, the target is failed (without shard failure) and nothing is scheduled.
     *
     * @param target  the replication target to run
     * @param timeout activity timeout passed to the collection
     */
    void startReplication(final SegmentReplicationTarget target, TimeValue timeout) {
        final long replicationId;
        try {
            replicationId = onGoingReplications.startSafe(target, timeout);
        } catch (ReplicationFailedException e) {
            // presumably a replication is already running for this shard; confirm against ReplicationCollection#startSafe
            target.fail(e, false);
            return;
        }
        logger.trace(() -> new ParameterizedMessage("Added new replication to collection {}", target.description()));
        threadPool.generic().execute(new ReplicationRunner(replicationId));
    }

    /**
     * Checks whether the target's store is marked corrupted. Returns {@code false} if the target is already
     * released (ref count is zero), if the store reference cannot be acquired, or if the check itself fails
     * with an {@link IOException} (which is logged).
     */
    private boolean isStoreCorrupt(SegmentReplicationTarget target) {
        if (target.refCount() <= 0) {
            return false;
        }
        final Store store = target.store();
        if (!store.tryIncRef()) {
            return false;
        }
        try {
            return store.isMarkedCorrupted();
        } catch (IOException ex) {
            logger.warn("Unable to determine if store is corrupt", ex);
            return false;
        } finally {
            store.decRef();
        }
    }

    /** @return the number of ongoing replication events */
    int size() {
        return onGoingReplications.size();
    }

    /**
     * Cancels ongoing replication for the shard and drops its checkpoint statistics and primary checkpoint.
     *
     * @param shardId the shard to cancel
     * @param reason  the cancellation reason
     */
    void cancel(ShardId shardId, String reason) {
        onGoingReplications.cancelForShard(shardId, reason);
        replicationCheckpointStats.remove(shardId);
        primaryCheckpoint.remove(shardId);
    }

    /** @return the ongoing replication target for the shard, as reported by the collection */
    SegmentReplicationTarget get(ShardId shardId) {
        return onGoingReplications.getOngoingReplicationTarget(shardId);
    }

    /** @return the latest primary checkpoint recorded for the shard, or {@code null} if none */
    ReplicationCheckpoint getPrimaryCheckpoint(ShardId shardId) {
        return primaryCheckpoint.get(shardId);
    }

    /** @return a reference to the ongoing replication with the given id (may be {@code null}); callers should close it */
    ReplicationCollection.ReplicationRef<SegmentReplicationTarget> get(long id) {
        return onGoingReplications.get(id);
    }

    /** @return the last recorded completed replication state for the shard, or {@code null} if none */
    SegmentReplicationState getCompleted(ShardId shardId) {
        return completedReplications.get(shardId);
    }

    /** @return a reference to the ongoing replication with the given id, obtained via the collection's getSafe */
    ReplicationCollection.ReplicationRef<SegmentReplicationTarget> get(long id, ShardId shardId) {
        return onGoingReplications.getSafe(id, shardId);
    }

    /** Bytes-behind and timestamp recorded for one received checkpoint. */
    private static class ReplicationCheckpointStats {
        // volatile: updated via setBytesBehind after the entry is published in a concurrent map and read by
        // getSegmentReplicationStats, possibly from another thread.
        private volatile long bytesBehind;
        private final long timestamp;

        public ReplicationCheckpointStats(long bytesBehind, long timestamp) {
            this.bytesBehind = bytesBehind;
            this.timestamp = timestamp;
        }

        public long getBytesBehind() {
            return bytesBehind;
        }

        public void setBytesBehind(long bytesBehind) {
            this.bytesBehind = bytesBehind;
        }

        public long getTimestamp() {
            return timestamp;
        }
    }

    /** Runs a replication event on a thread pool; unexpected errors fail the event without failing the shard. */
    private class ReplicationRunner extends AbstractRunnable {
        final long replicationId;

        public ReplicationRunner(long replicationId) {
            this.replicationId = replicationId;
        }

        @Override
        public void onFailure(Exception e) {
            onGoingReplications.fail(replicationId, new ReplicationFailedException("Unexpected Error during replication", e), false);
        }

        @Override
        public void doRun() {
            start(replicationId);
        }
    }
}

