/*
 * Decompiled with CFR 0.152.
 */
package org.opensearch.ml.engine.function_calling;

import com.jayway.jsonpath.JsonPath;
import com.jayway.jsonpath.Predicate;
import java.util.ArrayList;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.opensearch.core.common.util.CollectionUtils;
import org.opensearch.ml.common.output.model.ModelTensor;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.output.model.ModelTensors;
import org.opensearch.ml.common.utils.StringUtils;
import org.opensearch.ml.engine.algorithms.agent.AgentUtils;
import org.opensearch.ml.engine.function_calling.BedrockMessage;
import org.opensearch.ml.engine.function_calling.FunctionCalling;
import org.opensearch.ml.engine.function_calling.LLMMessage;

/**
 * {@link FunctionCalling} implementation for DeepSeek R1 invoked through the Amazon Bedrock
 * Converse API.
 *
 * <p>Tool calls are read out of the model's text output, not from a structured tool-use block:
 * the text of the first content block is parsed as JSON and is expected to carry
 * {@value #FINISH_REASON_PATH} and {@value #CALL_PATH}. Tool results are returned to the model as
 * text blocks, each containing a JSON object with {@code tool_call_id} and {@code tool_result}.
 */
public class BedrockConverseDeepseekR1FunctionCalling implements FunctionCalling {
    public static final String FINISH_REASON_PATH = "stop_reason";
    public static final String FINISH_REASON = "tool_use";
    public static final String CALL_PATH = "tool_calls";
    public static final String NAME = "tool_name";
    public static final String INPUT = "input";
    public static final String ID_PATH = "id";
    public static final String TOOL_ERROR = "tool_error";
    public static final String BEDROCK_DEEPSEEK_R1_TOOL_TEMPLATE = "{\"toolSpec\":{\"name\":\"${tool.name}\",\"description\":\"${tool.description}\",\"inputSchema\": {\"json\": ${tool.attributes.input_schema} } }}";

    /**
     * Writes the request/response templates and JSON paths used with this model into
     * {@code params}.
     *
     * <p>{@code no_escape_params} is only set when the caller has not supplied a (non-null) value;
     * every other entry written here overwrites any existing value.
     *
     * @param params mutable agent parameter map; modified in place
     */
    @Override
    public void configure(Map<String, String> params) {
        params.putIfAbsent("no_escape_params", "_chat_history,_interactions");
        params.put("llm_response_filter", "$.output.message.content[0].text");
        params.put("llm_final_response_post_filter", "$.message.content[0].text");
        params.put("tool_template", BEDROCK_DEEPSEEK_R1_TOOL_TEMPLATE);
        params.put("tool_calls_path", "_llm_response.tool_calls");
        params.put("tool_calls.tool_name", NAME);
        params.put("tool_calls.tool_input", INPUT);
        params.put("tool_calls.id_path", ID_PATH);
        params.put("interaction_template.assistant_tool_calls_path", "$.output.message");
        // Strip reasoning blocks from the assistant turn that is replayed to the model.
        params.put("interaction_template.assistant_tool_calls_exclude_path", "[ \"$.output.message.content[?(@.reasoningContent)]\" ]");
        // The text block embeds a JSON object shaped like the one supply() produces; it must end
        // with a closing brace to be well formed.
        params.put("interaction_template.tool_response", "{\"role\":\"user\",\"content\":[ {\"text\":\"{\\\"tool_call_id\\\":\\\"${_interactions.tool_call_id}\\\",\\\"tool_result\\\": \\\"${_interactions.tool_response}\\\"}\"} ]}");
        params.put("chat_history_template.user_question", "{\"role\":\"user\",\"content\":[{\"text\":\"${_chat_history.message.question}\"}]}");
        params.put("chat_history_template.ai_response", "{\"role\":\"assistant\",\"content\":[{\"text\":\"${_chat_history.message.response}\"}]}");
        params.put("llm_finish_reason_path", "_llm_response.stop_reason");
        params.put("llm_finish_reason_tool_use", FINISH_REASON);
    }

    /**
     * Extracts the tool calls requested by the model from a raw Converse response.
     *
     * <p>The value selected by {@code llm_response_filter} is parsed as JSON. Tool calls are only
     * returned when its {@value #FINISH_REASON_PATH} equals {@value #FINISH_REASON}. A missing,
     * null, or different finish reason, or a missing/empty {@value #CALL_PATH} list, yields an
     * empty list rather than an exception.
     *
     * @param tmpModelTensorOutput model output; only the first tensor of the first output is read
     * @param parameters agent parameters; reads {@code llm_response_filter} and the optional
     *     {@code llm_response_exclude_path}
     * @return one map per tool call with keys {@code tool_name}, {@code tool_input} (JSON text)
     *     and {@code tool_call_id}; empty when no tool use was requested
     */
    @Override
    public List<Map<String, String>> handle(ModelTensorOutput tmpModelTensorOutput, Map<String, String> parameters) {
        List<Map<String, String>> output = new ArrayList<>();
        Map<String, ?> dataAsMap = tmpModelTensorOutput.getMlModelOutputs().get(0).getMlModelTensors().get(0).getDataAsMap();
        String llmResponseExcludePath = parameters.get("llm_response_exclude_path");
        if (llmResponseExcludePath != null) {
            dataAsMap = AgentUtils.removeJsonPath(dataAsMap, llmResponseExcludePath, true);
        }
        Object response = JsonPath.read(dataAsMap, parameters.get("llm_response_filter"));
        // "response" is passed to fromJson as its default key (presumably used when the text is
        // not a JSON object — confirm against StringUtils.fromJson).
        Map<String, ?> llmResponse = StringUtils.fromJson(StringUtils.toJson(response), "response");

        // Both constants are plain top-level keys. Reading them with Map.get (instead of JsonPath,
        // which throws PathNotFoundException) lets an absent key mean "no tool call".
        Object llmFinishReason = llmResponse.get(FINISH_REASON_PATH);
        if (!FINISH_REASON.equals(llmFinishReason)) {
            return output;
        }
        List<?> toolCalls = (List<?>) llmResponse.get(CALL_PATH);
        if (CollectionUtils.isEmpty(toolCalls)) {
            return output;
        }
        for (Object call : toolCalls) {
            String toolName = JsonPath.read(call, NAME);
            String toolInput = StringUtils.toJson(JsonPath.read(call, INPUT));
            String toolCallId = JsonPath.read(call, ID_PATH);
            output.add(Map.of(NAME, toolName, "tool_input", toolInput, "tool_call_id", toolCallId));
        }
        return output;
    }

    /**
     * Packs tool results into a single {@link BedrockMessage}, one text block per result.
     *
     * <p>Results without a {@code tool_call_id} are skipped. The message is returned even when no
     * block was added.
     *
     * @param toolResults results keyed by {@code tool_call_id} and {@code tool_result}
     * @return a singleton list holding the message
     */
    @Override
    public List<LLMMessage> supply(List<Map<String, Object>> toolResults) {
        BedrockMessage toolMessage = new BedrockMessage();
        for (Map<String, Object> toolResult : toolResults) {
            String toolUseId = (String) toolResult.get("tool_call_id");
            if (toolUseId == null) {
                continue;
            }
            // LinkedHashMap rather than Map.of: Map.of has unspecified iteration order (so the JSON
            // key order would vary) and throws NPE on a null tool_result. How a null result is
            // rendered depends on StringUtils.toJson's serializer settings.
            Map<String, Object> payload = new LinkedHashMap<>();
            payload.put("tool_call_id", toolUseId);
            payload.put("tool_result", toolResult.get("tool_result"));
            toolMessage.getContent().add(Map.of("text", StringUtils.toJson(payload)));
        }
        return List.of(toolMessage);
    }
}

