/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ad.ratelimit;

import java.io.IOException;
import java.time.Clock;
import java.time.Duration;
import java.time.Instant;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.Random;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.logging.log4j.message.Message;
import org.apache.logging.log4j.message.ParameterizedMessage;
import org.opensearch.action.bulk.BulkItemResponse;
import org.opensearch.action.bulk.BulkRequest;
import org.opensearch.action.bulk.BulkResponse;
import org.opensearch.action.update.UpdateRequest;
import org.opensearch.ad.NodeStateManager;
import org.opensearch.ad.breaker.ADCircuitBreakerService;
import org.opensearch.ad.ml.CheckpointDao;
import org.opensearch.ad.ml.EntityModel;
import org.opensearch.ad.ml.ModelState;
import org.opensearch.ad.model.AnomalyDetector;
import org.opensearch.ad.ratelimit.BatchWorker;
import org.opensearch.ad.ratelimit.CheckpointWriteRequest;
import org.opensearch.ad.ratelimit.RequestPriority;
import org.opensearch.ad.settings.AnomalyDetectorSettings;
import org.opensearch.ad.util.ExceptionUtil;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.threadpool.ThreadPool;

/**
 * Queues entity-model checkpoints and persists them as upserts into the checkpoint index.
 *
 * <p>Checkpoints enter the queue through {@link #write} and {@link #writeAll}. The
 * {@link BatchWorker} hooks overridden below turn a batch of queued
 * {@link CheckpointWriteRequest}s into a single {@link BulkRequest} and send it through
 * {@link CheckpointDao#batchWrite}.
 */
public class CheckpointWriteWorker extends BatchWorker<CheckpointWriteRequest, BulkRequest, BulkResponse> {
    private static final Logger LOG = LogManager.getLogger(CheckpointWriteWorker.class);
    public static final String WORKER_NAME = "checkpoint-write";

    private final CheckpointDao checkpoint;
    /** Index that checkpoint documents are upserted into. */
    private final String indexName;
    /** Minimum interval between two checkpoints of the same model, unless a write is forced. */
    private final Duration checkpointInterval;

    public CheckpointWriteWorker(
        long heapSizeInBytes,
        int singleRequestSizeInBytes,
        Setting<Float> maxHeapPercentForQueueSetting,
        ClusterService clusterService,
        Random random,
        ADCircuitBreakerService adCircuitBreakerService,
        ThreadPool threadPool,
        Settings settings,
        float maxQueuedTaskRatio,
        Clock clock,
        float mediumSegmentPruneRatio,
        float lowSegmentPruneRatio,
        int maintenanceFreqConstant,
        Duration executionTtl,
        CheckpointDao checkpoint,
        String indexName,
        Duration checkpointInterval,
        NodeStateManager stateManager,
        Duration stateTtl
    ) {
        super(
            WORKER_NAME,
            heapSizeInBytes,
            singleRequestSizeInBytes,
            maxHeapPercentForQueueSetting,
            clusterService,
            random,
            adCircuitBreakerService,
            threadPool,
            settings,
            maxQueuedTaskRatio,
            clock,
            mediumSegmentPruneRatio,
            lowSegmentPruneRatio,
            maintenanceFreqConstant,
            AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_CONCURRENCY,
            executionTtl,
            AnomalyDetectorSettings.CHECKPOINT_WRITE_QUEUE_BATCH_SIZE,
            stateTtl,
            stateManager
        );
        this.checkpoint = checkpoint;
        this.indexName = indexName;
        this.checkpointInterval = checkpointInterval;
    }

    /** Sends the assembled bulk request through the checkpoint DAO. */
    @Override
    protected void executeBatchRequest(BulkRequest request, ActionListener<BulkResponse> listener) {
        checkpoint.batchWrite(request, listener);
    }

    /** Collects the update request carried by each queued item into one {@link BulkRequest}. */
    @Override
    protected BulkRequest toBatchRequest(List<CheckpointWriteRequest> toProcess) {
        BulkRequest bulkRequest = new BulkRequest();
        for (CheckpointWriteRequest request : toProcess) {
            bulkRequest.add(request.getUpdateRequest());
        }
        return bulkRequest;
    }

    /**
     * Builds the listener for one bulk write.
     *
     * <p>On a response, per-item failures are only logged; they are not re-queued here. On a
     * failed request, every detector in the batch has the exception recorded in
     * {@code nodeStateManager}. If the error signals overload, a cool-down is also started.
     */
    @Override
    protected ActionListener<BulkResponse> getResponseListener(List<CheckpointWriteRequest> toProcess, BulkRequest batchRequest) {
        return ActionListener.wrap(response -> {
            for (BulkItemResponse item : response.getItems()) {
                if (item.getFailureMessage() != null) {
                    LOG.error(item.getFailureMessage());
                }
            }
        }, exception -> {
            if (ExceptionUtil.isOverloaded(exception)) {
                LOG.error("too many AD model checkpoint write requests or shard not available");
                setCoolDownStart();
            }
            for (CheckpointWriteRequest request : toProcess) {
                nodeStateManager.setException(request.getDetectorId(), exception);
            }
            LOG.error("Fail to save models", exception);
        });
    }

    /**
     * Queues a checkpoint of a single model state when a save is due.
     *
     * <p>Nothing is queued in any of these cases:
     * <ul>
     *   <li>{@link CheckpointDao#shouldSave} says no save is due;</li>
     *   <li>the state holds no model;</li>
     *   <li>the detector id or model id is null.</li>
     * </ul>
     * The detector is looked up asynchronously, and the request is built and queued in
     * that callback.
     *
     * @param modelState the model state to checkpoint
     * @param forceWrite passed to {@code shouldSave} to bypass the checkpoint interval
     * @param priority   queue priority of the resulting request
     */
    public void write(ModelState<EntityModel> modelState, boolean forceWrite, RequestPriority priority) {
        if (!checkpoint.shouldSave(modelState.getLastCheckpointTime(), forceWrite, checkpointInterval, clock)) {
            return;
        }
        if (modelState.getModel() == null) {
            return;
        }
        String detectorId = modelState.getDetectorId();
        String modelId = modelState.getModelId();
        if (modelId == null || detectorId == null) {
            return;
        }
        nodeStateManager.getAnomalyDetector(detectorId, onGetDetector(detectorId, modelId, modelState, priority));
    }

    /** Detector-lookup callback for {@link #write}: serializes the state and queues one request. */
    private ActionListener<Optional<AnomalyDetector>> onGetDetector(
        String detectorId,
        String modelId,
        ModelState<EntityModel> modelState,
        RequestPriority priority
    ) {
        return ActionListener.wrap(detectorOptional -> {
            if (!detectorOptional.isPresent()) {
                LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId));
                return;
            }
            AnomalyDetector detector = detectorOptional.get();
            try {
                prepareWriteRequest(modelState, modelId, detector, detectorId, priority).ifPresent(this::put);
            } catch (Exception e) {
                LOG.error(new ParameterizedMessage("Exception while serializing models for [{}]", modelId), e);
            }
        }, exception -> LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception));
    }

    /**
     * Queues checkpoints for every model state of one detector that is due for a save.
     *
     * <p>Each state is serialized independently. A failure on one model is logged and skips
     * only that model, so the remaining checkpoints are still queued. Nothing is done for a
     * null or empty list.
     *
     * @param modelStates states to checkpoint, all belonging to {@code detectorId}
     * @param detectorId  detector the states belong to
     * @param forceWrite  passed to {@code shouldSave} to bypass the checkpoint interval
     * @param priority    queue priority of the resulting requests
     */
    public void writeAll(List<ModelState<EntityModel>> modelStates, String detectorId, boolean forceWrite, RequestPriority priority) {
        if (modelStates == null || modelStates.isEmpty()) {
            return;
        }
        ActionListener<Optional<AnomalyDetector>> onGetForAll = ActionListener.wrap(detectorOptional -> {
            if (!detectorOptional.isPresent()) {
                LOG.warn(new ParameterizedMessage("AnomalyDetector [{}] is not available.", detectorId));
                return;
            }
            AnomalyDetector detector = detectorOptional.get();
            List<CheckpointWriteRequest> allRequests = new ArrayList<>(modelStates.size());
            for (ModelState<EntityModel> state : modelStates) {
                // Per-model try: one failed serialization must not drop the others, nor leave
                // already-stamped states un-queued.
                try {
                    if (!checkpoint.shouldSave(state.getLastCheckpointTime(), forceWrite, checkpointInterval, clock)) {
                        continue;
                    }
                    prepareWriteRequest(state, state.getModelId(), detector, detectorId, priority).ifPresent(allRequests::add);
                } catch (Exception e) {
                    LOG.error(new ParameterizedMessage("Exception while serializing models for [{}]", detectorId), e);
                }
            }
            putAll(allRequests);
        }, exception -> LOG.error(new ParameterizedMessage("fail to get detector [{}]", detectorId), exception));
        nodeStateManager.getAnomalyDetector(detectorId, onGetForAll);
    }

    /**
     * Serializes {@code modelState} and wraps it in an upsert request for {@code indexName}.
     *
     * <p>When a request is produced, the state's last checkpoint time is set to
     * {@code clock.instant()} first.
     *
     * @return the request, or empty when the serialized source is null or empty, or
     *         {@code modelId} is empty
     * @throws IOException if serializing the model state fails
     */
    private Optional<CheckpointWriteRequest> prepareWriteRequest(
        ModelState<EntityModel> modelState,
        String modelId,
        AnomalyDetector detector,
        String detectorId,
        RequestPriority priority
    ) throws IOException {
        Map<String, Object> source = checkpoint.toIndexSource(modelState);
        if (source == null || source.isEmpty() || Strings.isEmpty(modelId)) {
            return Optional.empty();
        }
        modelState.setLastCheckpointTime(clock.instant());
        // First CheckpointWriteRequest argument: presumably the request's expiration epoch
        // millis, one detector interval from now. TODO confirm against CheckpointWriteRequest.
        long expirationEpochMs = System.currentTimeMillis() + detector.getDetectorIntervalInMilliseconds();
        UpdateRequest upsert = new UpdateRequest(indexName, modelId).docAsUpsert(true).doc(source);
        return Optional.of(new CheckpointWriteRequest(expirationEpochMs, detectorId, priority, upsert));
    }
}

