/*
 * Decompiled with CFR 0.152.
 */
package dev.langchain4j.rag;

import dev.langchain4j.data.message.ChatMessage;
import dev.langchain4j.data.message.UserMessage;
import dev.langchain4j.internal.Utils;
import dev.langchain4j.internal.ValidationUtils;
import dev.langchain4j.rag.AugmentationRequest;
import dev.langchain4j.rag.AugmentationResult;
import dev.langchain4j.rag.RetrievalAugmentor;
import dev.langchain4j.rag.content.Content;
import dev.langchain4j.rag.content.aggregator.ContentAggregator;
import dev.langchain4j.rag.content.aggregator.DefaultContentAggregator;
import dev.langchain4j.rag.content.injector.ContentInjector;
import dev.langchain4j.rag.content.injector.DefaultContentInjector;
import dev.langchain4j.rag.content.retriever.ContentRetriever;
import dev.langchain4j.rag.query.Metadata;
import dev.langchain4j.rag.query.Query;
import dev.langchain4j.rag.query.router.DefaultQueryRouter;
import dev.langchain4j.rag.query.router.QueryRouter;
import dev.langchain4j.rag.query.transformer.DefaultQueryTransformer;
import dev.langchain4j.rag.query.transformer.QueryTransformer;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionStage;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.Executor;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.SynchronousQueue;
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * {@link RetrievalAugmentor} that runs a fixed pipeline over a chat message:
 * <ol>
 *   <li>the message text is wrapped in a {@link Query} and passed to the {@link QueryTransformer},
 *       which may produce one or several queries;</li>
 *   <li>each query is routed by the {@link QueryRouter} to zero or more {@link ContentRetriever}s;</li>
 *   <li>everything retrieved is merged by the {@link ContentAggregator};</li>
 *   <li>the merged contents are injected into the message by the {@link ContentInjector}.</li>
 * </ol>
 * Routing and retrieval run on the configured {@link Executor} whenever there is more than one
 * query or more than one retriever; a single query routed to a single retriever is handled
 * entirely on the calling thread.
 */
public class DefaultRetrievalAugmentor implements RetrievalAugmentor {
    private static final Logger log = LoggerFactory.getLogger(DefaultRetrievalAugmentor.class);

    private final QueryTransformer queryTransformer;
    private final QueryRouter queryRouter;
    private final ContentAggregator contentAggregator;
    private final ContentInjector contentInjector;
    private final Executor executor;

    /**
     * Creates an augmentor from its collaborators.
     *
     * @param queryTransformer  optional; resolved through {@code Utils.getOrDefault} with a
     *                          {@link DefaultQueryTransformer} as the fallback
     * @param queryRouter       required; checked with {@code ValidationUtils.ensureNotNull}
     * @param contentAggregator optional; falls back to a {@link DefaultContentAggregator}
     * @param contentInjector   optional; falls back to a {@link DefaultContentInjector}
     * @param executor          optional; falls back to the pool built by {@link #createDefaultExecutor()}
     */
    public DefaultRetrievalAugmentor(QueryTransformer queryTransformer, QueryRouter queryRouter, ContentAggregator contentAggregator, ContentInjector contentInjector, Executor executor) {
        this.queryTransformer = Utils.getOrDefault(queryTransformer, DefaultQueryTransformer::new);
        this.queryRouter = ValidationUtils.ensureNotNull(queryRouter, "queryRouter");
        this.contentAggregator = Utils.getOrDefault(contentAggregator, DefaultContentAggregator::new);
        this.contentInjector = Utils.getOrDefault(contentInjector, DefaultContentInjector::new);
        this.executor = Utils.getOrDefault(executor, DefaultRetrievalAugmentor::createDefaultExecutor);
    }

    /**
     * Builds the fallback executor: no core threads, no upper bound on the number of threads,
     * idle threads kept for one second, and a {@link SynchronousQueue} so that tasks are handed
     * straight to a thread instead of being queued.
     */
    private static ExecutorService createDefaultExecutor() {
        return new ThreadPoolExecutor(0, Integer.MAX_VALUE, 1L, TimeUnit.SECONDS, new SynchronousQueue<Runnable>());
    }

    /**
     * Legacy entry point that delegates to {@link #augment(AugmentationRequest)}.
     *
     * @param userMessage the message to augment
     * @param metadata    metadata attached to the query built from the message
     * @return the augmented message; the result of the injector is cast to {@link UserMessage}
     * @deprecated use {@link #augment(AugmentationRequest)} instead
     */
    @Override
    @Deprecated
    public UserMessage augment(UserMessage userMessage, Metadata metadata) {
        AugmentationRequest augmentationRequest = new AugmentationRequest(userMessage, metadata);
        return (UserMessage) this.augment(augmentationRequest).chatMessage();
    }

    /**
     * Runs the full transform / route / retrieve / aggregate / inject pipeline.
     *
     * @param augmentationRequest carries the chat message and the query metadata
     * @return a result holding the augmented chat message and the aggregated contents
     */
    @Override
    public AugmentationResult augment(AugmentationRequest augmentationRequest) {
        ChatMessage chatMessage = augmentationRequest.chatMessage();
        Metadata metadata = augmentationRequest.metadata();

        Query originalQuery = Query.from(chatMessage.text(), metadata);
        Collection<Query> queries = this.queryTransformer.transform(originalQuery);
        logQueries(originalQuery, queries);

        Map<Query, Collection<List<Content>>> queryToContents = this.process(queries);

        List<Content> contents = this.contentAggregator.aggregate(queryToContents);
        logAggregation(queryToContents, contents);

        ChatMessage augmentedChatMessage = this.contentInjector.inject(contents, chatMessage);
        logAugmentedMessage(augmentedChatMessage);

        return AugmentationResult.builder().chatMessage(augmentedChatMessage).contents(contents).build();
    }

    /**
     * Routes every query and collects what the selected retrievers return.
     *
     * @param queries the queries produced by the query transformer
     * @return for each query, one content list per retriever it was routed to; an empty map when
     *         there are no queries, or when the only query is routed to no retriever
     */
    private Map<Query, Collection<List<Content>>> process(Collection<Query> queries) {
        if (queries.size() == 1) {
            Query query = queries.iterator().next();
            Collection<ContentRetriever> retrievers = this.queryRouter.route(query);
            // Log routing here as well, so the single-query path is as observable as the fan-out path.
            logRouting(query, retrievers);

            if (retrievers.size() == 1) {
                // One query, one retriever: nothing to parallelise, so stay on the calling thread.
                // Going through retrieve(...) keeps the retrieval logging identical to the async paths.
                ContentRetriever contentRetriever = retrievers.iterator().next();
                List<Content> contents = retrieve(contentRetriever, query);
                return Collections.singletonMap(query, Collections.singletonList(contents));
            }
            if (retrievers.size() > 1) {
                Collection<List<Content>> contents = this.retrieveFromAll(retrievers, query).join();
                return Collections.singletonMap(query, contents);
            }
            return Collections.emptyMap();
        }

        if (queries.size() > 1) {
            Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents = new ConcurrentHashMap<>();
            for (Query query : queries) {
                // Route asynchronously, then fan out to every selected retriever.
                CompletableFuture<Collection<List<Content>>> futureContents = CompletableFuture
                        .supplyAsync(() -> {
                            Collection<ContentRetriever> retrievers = this.queryRouter.route(query);
                            logRouting(query, retrievers);
                            return retrievers;
                        }, this.executor)
                        .thenCompose(retrievers -> this.retrieveFromAll(retrievers, query));
                queryToFutureContents.put(query, futureContents);
            }
            return join(queryToFutureContents);
        }

        return Collections.emptyMap();
    }

    /**
     * Starts one retrieval task per retriever on the executor.
     *
     * @param retrievers the retrievers the query was routed to
     * @param query      the query to run against each retriever
     * @return a future completing with one content list per retriever once all tasks have finished
     */
    private CompletableFuture<Collection<List<Content>>> retrieveFromAll(Collection<ContentRetriever> retrievers, Query query) {
        List<CompletableFuture<List<Content>>> futureContents = retrievers.stream()
                .map(retriever -> CompletableFuture.supplyAsync(() -> retrieve(retriever, query), this.executor))
                .toList();

        // allOf(...) only signals completion; the individual results are read back with join(),
        // which does not block at that point because every future has already completed.
        return CompletableFuture.allOf(futureContents.toArray(new CompletableFuture<?>[0]))
                .thenApply(ignored -> futureContents.stream().map(CompletableFuture::join).toList());
    }

    /**
     * Runs a single retriever and logs what it returned.
     */
    private static List<Content> retrieve(ContentRetriever retriever, Query query) {
        List<Content> contents = retriever.retrieve(query);
        logRetrieval(query, retriever, contents);
        return contents;
    }

    /**
     * Blocks until every per-query future has completed and unwraps the results.
     *
     * @param queryToFutureContents pending retrieval results keyed by query
     * @return the completed results keyed by query
     */
    private static Map<Query, Collection<List<Content>>> join(Map<Query, CompletableFuture<Collection<List<Content>>>> queryToFutureContents) {
        return CompletableFuture.allOf(queryToFutureContents.values().toArray(new CompletableFuture<?>[0]))
                .thenApply(ignored -> queryToFutureContents.entrySet().stream()
                        .collect(Collectors.toMap(Map.Entry::getKey, entry -> entry.getValue().join())))
                .join();
    }

    /**
     * Logs the outcome of query transformation at debug level. A single query is only reported
     * when it differs from the original one.
     */
    private static void logQueries(Query originalQuery, Collection<Query> queries) {
        if (queries.size() == 1) {
            Query transformedQuery = queries.iterator().next();
            if (!transformedQuery.equals(originalQuery)) {
                log.debug("Transformed original query '{}' into '{}'", originalQuery.text(), transformedQuery.text());
            }
        } else if (log.isDebugEnabled()) {
            // Guarded because building the bullet list is wasted work when debug is off.
            String queryList = queries.stream()
                    .map(Query::text)
                    .map(text -> "- '" + text + "'")
                    .collect(Collectors.joining("\n"));
            log.debug("Transformed original query '{}' into the following queries:\n{}", originalQuery.text(), queryList);
        }
    }

    /**
     * Logs at debug level which retrievers a query was routed to.
     */
    private static void logRouting(Query query, Collection<ContentRetriever> retrievers) {
        if (retrievers.size() == 1) {
            log.debug("Routing query '{}' to the following retriever: {}", query.text(), retrievers.iterator().next());
        } else if (log.isDebugEnabled()) {
            String retrieverList = retrievers.stream()
                    .map(retriever -> "- " + retriever.toString())
                    .collect(Collectors.joining("\n"));
            log.debug("Routing query '{}' to the following retrievers:\n{}", query.text(), retrieverList);
        }
    }

    /**
     * Logs the number of contents a retriever returned at debug level and, when trace is enabled,
     * additionally logs the text of each content.
     */
    private static void logRetrieval(Query query, ContentRetriever retriever, List<Content> contents) {
        log.debug("Retrieved {} contents using query '{}' and retriever '{}'", contents.size(), query.text(), retriever);
        if (!log.isTraceEnabled()) {
            return;
        }
        if (!contents.isEmpty()) {
            String contentsString = contents.stream()
                    .map(Content::textSegment)
                    .map(segment -> "- " + escapeNewlines(segment.text()))
                    .collect(Collectors.joining("\n"));
            log.trace("Retrieved {} contents using query '{}' and retriever '{}':\n{}", contents.size(), query.text(), retriever.getClass().getName(), contentsString);
        } else {
            log.trace("Retrieved 0 contents using query '{}' and retriever '{}'", query.text(), retriever.getClass().getName());
        }
    }

    /**
     * Logs the effect of aggregation, but only when the aggregator changed the number of contents.
     */
    private static void logAggregation(Map<Query, Collection<List<Content>>> queryToContents, List<Content> contents) {
        int contentCount = 0;
        for (Collection<List<Content>> contentLists : queryToContents.values()) {
            for (List<Content> contentList : contentLists) {
                contentCount += contentList.size();
            }
        }
        if (contentCount == contents.size()) {
            return;
        }

        log.debug("Aggregated {} content(s) into {}", contentCount, contents.size());
        if (log.isTraceEnabled()) {
            String contentsString = contents.stream()
                    .map(Content::textSegment)
                    .map(segment -> "- " + escapeNewlines(segment.text()))
                    .collect(Collectors.joining("\n"));
            log.trace("Aggregated {} content(s) into:\n{}", contentCount, contentsString);
        }
    }

    /**
     * Logs the text of the augmented chat message at trace level.
     */
    private static void logAugmentedMessage(ChatMessage augmentedChatMessage) {
        if (log.isTraceEnabled()) {
            log.trace("Augmented chat message: {}", escapeNewlines(augmentedChatMessage.text()));
        }
    }

    /**
     * Replaces each line feed with the two characters {@code \n} so that the text stays on a
     * single log line.
     */
    private static String escapeNewlines(String text) {
        return text.replace("\n", "\\n");
    }

    /**
     * @return a new, empty builder
     */
    public static DefaultRetrievalAugmentorBuilder builder() {
        return new DefaultRetrievalAugmentorBuilder();
    }

    /**
     * Fluent builder for {@link DefaultRetrievalAugmentor}. Values that are never set are passed
     * to the constructor as {@code null}.
     */
    public static class DefaultRetrievalAugmentorBuilder {
        private QueryTransformer queryTransformer;
        private QueryRouter queryRouter;
        private ContentAggregator contentAggregator;
        private ContentInjector contentInjector;
        private Executor executor;

        DefaultRetrievalAugmentorBuilder() {
        }

        /**
         * Convenience for the single-retriever case: wraps the retriever in a
         * {@link DefaultQueryRouter}. This writes the same field as
         * {@link #queryRouter(QueryRouter)}, so whichever of the two is called last wins.
         *
         * @param contentRetriever the retriever to use; must not be {@code null}
         */
        public DefaultRetrievalAugmentorBuilder contentRetriever(ContentRetriever contentRetriever) {
            this.queryRouter = new DefaultQueryRouter(ValidationUtils.ensureNotNull(contentRetriever, "contentRetriever"));
            return this;
        }

        public DefaultRetrievalAugmentorBuilder queryTransformer(QueryTransformer queryTransformer) {
            this.queryTransformer = queryTransformer;
            return this;
        }

        public DefaultRetrievalAugmentorBuilder queryRouter(QueryRouter queryRouter) {
            this.queryRouter = queryRouter;
            return this;
        }

        public DefaultRetrievalAugmentorBuilder contentAggregator(ContentAggregator contentAggregator) {
            this.contentAggregator = contentAggregator;
            return this;
        }

        public DefaultRetrievalAugmentorBuilder contentInjector(ContentInjector contentInjector) {
            this.contentInjector = contentInjector;
            return this;
        }

        public DefaultRetrievalAugmentorBuilder executor(Executor executor) {
            this.executor = executor;
            return this;
        }

        /**
         * @return a new augmentor configured with the values set on this builder
         */
        public DefaultRetrievalAugmentor build() {
            return new DefaultRetrievalAugmentor(this.queryTransformer, this.queryRouter, this.contentAggregator, this.contentInjector, this.executor);
        }

        @Override
        public String toString() {
            return "DefaultRetrievalAugmentor.DefaultRetrievalAugmentorBuilder(queryTransformer=" + this.queryTransformer
                    + ", queryRouter=" + this.queryRouter
                    + ", contentAggregator=" + this.contentAggregator
                    + ", contentInjector=" + this.contentInjector
                    + ", executor=" + this.executor + ")";
        }
    }
}

