/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.neuralsearch.search.query.util;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Objects;
import lombok.Generated;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.apache.lucene.search.FieldDoc;
import org.apache.lucene.search.IndexSearcher;
import org.apache.lucene.search.ScoreDoc;
import org.apache.lucene.search.SortField;
import org.apache.lucene.search.TopDocs;
import org.apache.lucene.search.TopFieldDocs;
import org.apache.lucene.search.TotalHits;
import org.apache.lucene.search.grouping.CollapseTopFieldDocs;
import org.apache.lucene.util.BytesRef;
import org.opensearch.common.lucene.search.TopDocsAndMaxScore;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.neuralsearch.search.collector.HybridSearchCollector;
import org.opensearch.neuralsearch.search.query.HybridCollectorResultsUtilParams;
import org.opensearch.neuralsearch.search.query.exception.HybridSearchRescoreQueryException;
import org.opensearch.neuralsearch.search.util.HybridSearchResultFormatUtil;
import org.opensearch.search.query.QuerySearchResult;
import org.opensearch.search.rescore.RescoreContext;

public class HybridSearchCollectorResultUtil {
    @Generated
    private static final Logger log = LogManager.getLogger(HybridSearchCollectorResultUtil.class);
    private final HybridCollectorResultsUtilParams hybridSearchCollectorResultsDTO;
    private final HybridSearchCollector hybridSearchCollector;

    public void reduceCollectorResults(QuerySearchResult result, TopDocsAndMaxScore topDocsAndMaxScore) {
        if (result.hasConsumedTopDocs()) {
            result.topDocs(topDocsAndMaxScore, this.hybridSearchCollectorResultsDTO.getDocValueFormats());
            return;
        }
        if (topDocsAndMaxScore.topDocs.totalHits.value() == 0L) {
            return;
        }
        TopDocsAndMaxScore originalTotalDocsAndHits = result.topDocs();
        TopDocsAndMaxScore mergeTopDocsAndMaxScores = this.hybridSearchCollectorResultsDTO.getTopDocsMerger().merge(originalTotalDocsAndHits, topDocsAndMaxScore);
        result.topDocs(mergeTopDocsAndMaxScores, this.hybridSearchCollectorResultsDTO.getDocValueFormats());
    }

    public TopDocsAndMaxScore getTopDocsAndMaxScore() throws IOException {
        List<TopDocs> topDocs = this.hybridSearchCollector.topDocs();
        if (this.hybridSearchCollectorResultsDTO.isSortEnabled() || this.hybridSearchCollectorResultsDTO.isCollapseEnabled()) {
            return this.getSortedTopDocsAndMaxScore(topDocs);
        }
        return this.getTopDocsAndMaxScore(topDocs);
    }

    private TopDocsAndMaxScore getSortedTopDocsAndMaxScore(List<TopFieldDocs> topDocs) {
        SortField[] sortFields = this.hybridSearchCollectorResultsDTO.getSortFields();
        TopDocs sortedTopDocs = this.hybridSearchCollectorResultsDTO.isCollapseEnabled() ? this.getCollapseTopFieldDocs(this.getTotalHits(topDocs), topDocs, sortFields) : this.getNewTopFieldDocs(this.getTotalHits(topDocs), topDocs, sortFields);
        return new TopDocsAndMaxScore(sortedTopDocs, this.hybridSearchCollector.getMaxScore());
    }

    private TopDocs getCollapseTopFieldDocs(TotalHits totalHits, List<TopFieldDocs> collapseTopFieldDocs, SortField[] sortFields) {
        String collapseField = this.hybridSearchCollectorResultsDTO.getSearchContext().collapse().getFieldName();
        if (Objects.isNull(collapseTopFieldDocs)) {
            return new CollapseTopFieldDocs(collapseField, totalHits, (ScoreDoc[])new FieldDoc[0], sortFields, new Object[0]);
        }
        int delimiterDocId = this.findDelimiterDocId(collapseTopFieldDocs);
        if (delimiterDocId == -1) {
            return new CollapseTopFieldDocs(collapseField, totalHits, (ScoreDoc[])new FieldDoc[0], sortFields, new Object[0]);
        }
        ArrayList<Object> collapseValues = new ArrayList<Object>();
        ArrayList fieldDocs = new ArrayList();
        ArrayList<Object> result = new ArrayList<Object>();
        Object[] fields = HybridSearchResultFormatUtil.createSortFieldsForDelimiterResults(collapseTopFieldDocs.getFirst().fields);
        result.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, fields));
        collapseValues.add(new BytesRef(HybridSearchResultFormatUtil.createCollapseValueStartStopElementForHybridSearchResults()));
        for (TopDocs topDocs : collapseTopFieldDocs) {
            CollapseTopFieldDocs collapseTopFieldDoc = (CollapseTopFieldDocs)topDocs;
            collapseField = collapseTopFieldDoc.field;
            if (Objects.isNull(topDocs) || Objects.isNull(topDocs.scoreDocs)) {
                result.add(HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults(delimiterDocId, fields));
                continue;
            }
            ArrayList<FieldDoc> fieldDocsPerQuery = new ArrayList<FieldDoc>();
            for (ScoreDoc scoreDoc : collapseTopFieldDoc.scoreDocs) {
                fieldDocsPerQuery.add((FieldDoc)scoreDoc);
            }
            result.add(HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults(delimiterDocId, fields));
            result.addAll(fieldDocsPerQuery);
            collapseValues.add(new BytesRef(HybridSearchResultFormatUtil.createCollapseValueDelimiterElementForHybridSearchResults()));
            collapseValues.addAll(Arrays.asList(collapseTopFieldDoc.collapseValues));
        }
        result.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, fields));
        collapseValues.add(new BytesRef(HybridSearchResultFormatUtil.createCollapseValueStartStopElementForHybridSearchResults()));
        fieldDocs.addAll(result);
        return new CollapseTopFieldDocs(collapseField, totalHits, (ScoreDoc[])fieldDocs.toArray(new FieldDoc[0]), collapseTopFieldDocs.getFirst().fields, collapseValues.toArray(new Object[0]));
    }

    private TopDocs getNewTopFieldDocs(TotalHits totalHits, List<TopFieldDocs> topFieldDocs, SortField[] sortFields) {
        if (Objects.isNull(topFieldDocs)) {
            return new TopFieldDocs(totalHits, (ScoreDoc[])new FieldDoc[0], sortFields);
        }
        int delimiterDocId = this.findDelimiterDocId(topFieldDocs);
        if (delimiterDocId == -1) {
            return new TopFieldDocs(totalHits, (ScoreDoc[])new FieldDoc[0], sortFields);
        }
        Object[] sortFieldsForDelimiterResults = HybridSearchResultFormatUtil.createSortFieldsForDelimiterResults(sortFields);
        ArrayList<Object> result = new ArrayList<Object>();
        result.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults));
        for (TopFieldDocs topFieldDoc : topFieldDocs) {
            if (Objects.isNull(topFieldDoc) || Objects.isNull(topFieldDoc.scoreDocs)) {
                result.add(HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults));
                continue;
            }
            ArrayList<FieldDoc> fieldDocsPerQuery = new ArrayList<FieldDoc>();
            for (ScoreDoc scoreDoc : topFieldDoc.scoreDocs) {
                fieldDocsPerQuery.add((FieldDoc)scoreDoc);
            }
            result.add(HybridSearchResultFormatUtil.createFieldDocDelimiterElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults));
            result.addAll(fieldDocsPerQuery);
        }
        result.add(HybridSearchResultFormatUtil.createFieldDocStartStopElementForHybridSearchResults(delimiterDocId, sortFieldsForDelimiterResults));
        FieldDoc[] fieldDocs = result.toArray(new FieldDoc[0]);
        return new TopFieldDocs(totalHits, (ScoreDoc[])fieldDocs, sortFields);
    }

    private TotalHits getTotalHits(List<?> topDocs) {
        TotalHits.Relation relation;
        TotalHits.Relation relation2 = relation = this.hybridSearchCollectorResultsDTO.getTrackTotalHitsUpTo() == -1 ? TotalHits.Relation.GREATER_THAN_OR_EQUAL_TO : TotalHits.Relation.EQUAL_TO;
        if (topDocs == null || topDocs.isEmpty()) {
            return new TotalHits(0L, relation);
        }
        return new TotalHits((long)this.hybridSearchCollector.getTotalHits(), relation);
    }

    private TopDocsAndMaxScore getTopDocsAndMaxScore(List<TopDocs> topDocs) {
        if (this.shouldRescore()) {
            topDocs = this.rescore(topDocs);
        }
        float maxScore = this.calculateMaxScore(topDocs, this.hybridSearchCollector.getMaxScore());
        TopDocs finalTopDocs = this.getNewTopDocs(this.getTotalHits(topDocs), topDocs);
        return new TopDocsAndMaxScore(finalTopDocs, maxScore);
    }

    private TopDocs getNewTopDocs(TotalHits totalHits, List<TopDocs> topDocs) {
        ScoreDoc[] scoreDocs = new ScoreDoc[]{};
        if (Objects.isNull(topDocs)) {
            return new TopDocs(totalHits, scoreDocs);
        }
        int delimiterDocId = this.findDelimiterDocId(topDocs);
        if (delimiterDocId == -1) {
            return new TopDocs(totalHits, scoreDocs);
        }
        ArrayList<ScoreDoc> result = new ArrayList<ScoreDoc>();
        result.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(delimiterDocId));
        for (TopDocs topDoc : topDocs) {
            if (Objects.isNull(topDoc) || Objects.isNull(topDoc.scoreDocs)) {
                result.add(HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults(delimiterDocId));
                continue;
            }
            result.add(HybridSearchResultFormatUtil.createDelimiterElementForHybridSearchResults(delimiterDocId));
            result.addAll(Arrays.asList(topDoc.scoreDocs));
        }
        result.add(HybridSearchResultFormatUtil.createStartStopElementForHybridSearchResults(delimiterDocId));
        scoreDocs = (ScoreDoc[])result.stream().map(doc -> new ScoreDoc(doc.doc, doc.score, doc.shardIndex)).toArray(ScoreDoc[]::new);
        return new TopDocs(totalHits, scoreDocs);
    }

    private int findDelimiterDocId(List<? extends TopDocs> topDocs) {
        return topDocs.stream().filter(Objects::nonNull).filter(topDoc -> Objects.nonNull(topDoc.scoreDocs)).map(topDoc -> topDoc.scoreDocs).filter(scoreDoc -> ((ScoreDoc[])scoreDoc).length > 0).map(scoreDoc -> scoreDoc[0].doc).findFirst().orElse(-1);
    }

    private float calculateMaxScore(List<TopDocs> topDocsList, float initialMaxScore) {
        List<RescoreContext> rescoreContexts = this.hybridSearchCollectorResultsDTO.getRescoreContexts();
        if (Objects.nonNull(rescoreContexts) && !rescoreContexts.isEmpty()) {
            for (TopDocs topDocs : topDocsList) {
                if (!Objects.nonNull(topDocs.scoreDocs) || topDocs.scoreDocs.length <= 0) continue;
                initialMaxScore = Math.max(initialMaxScore, topDocs.scoreDocs[0].score);
            }
        }
        return initialMaxScore;
    }

    private boolean shouldRescore() {
        List<RescoreContext> rescoreContexts = this.hybridSearchCollectorResultsDTO.getRescoreContexts();
        return !CollectionUtils.isEmpty(rescoreContexts);
    }

    private List<TopDocs> rescore(List<TopDocs> topDocs) {
        List<TopDocs> rescoredTopDocs = topDocs;
        for (RescoreContext ctx : this.hybridSearchCollectorResultsDTO.getRescoreContexts()) {
            rescoredTopDocs = this.rescoredTopDocs(ctx, rescoredTopDocs);
        }
        return rescoredTopDocs;
    }

    private List<TopDocs> rescoredTopDocs(RescoreContext ctx, List<TopDocs> topDocs) {
        ArrayList<TopDocs> result = new ArrayList<TopDocs>(topDocs.size());
        for (TopDocs topDoc : topDocs) {
            try {
                result.add(ctx.rescorer().rescore(topDoc, (IndexSearcher)this.hybridSearchCollectorResultsDTO.getSearchContext().searcher(), ctx));
            }
            catch (IOException exception) {
                log.error("rescore failed for hybrid query in collector_manager.reduce call", (Throwable)exception);
                throw new HybridSearchRescoreQueryException(exception);
            }
        }
        return result;
    }

    @Generated
    public HybridSearchCollectorResultUtil(HybridCollectorResultsUtilParams hybridSearchCollectorResultsDTO, HybridSearchCollector hybridSearchCollector) {
        this.hybridSearchCollectorResultsDTO = hybridSearchCollectorResultsDTO;
        this.hybridSearchCollector = hybridSearchCollector;
    }
}

