/*
 * This file is licensed under the MIT License, part of Roughly Enough Items.
 * Copyright (c) 2018, 2019, 2020, 2021, 2022, 2023 shedaniel
 *
 * Permission is hereby granted, free of charge, to any person obtaining a copy
 * of this software and associated documentation files (the "Software"), to deal
 * in the Software without restriction, including without limitation the rights
 * to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
 * copies of the Software, and to permit persons to whom the Software is
 * furnished to do so, subject to the following conditions:
 *
 * The above copyright notice and this permission notice shall be included in all
 * copies or substantial portions of the Software.
 *
 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 * SOFTWARE.
 */

package me.shedaniel.rei.plugin.common.displays.tag;

import com.mojang.serialization.DataResult;
import dev.architectury.event.events.client.ClientLifecycleEvent;
import dev.architectury.impl.NetworkAggregator;
import dev.architectury.networking.NetworkManager;
import dev.architectury.networking.transformers.SplitPacketTransformer;
import dev.architectury.platform.Platform;
import dev.architectury.utils.Env;
import dev.architectury.utils.EnvExecutor;
import io.netty.buffer.Unpooled;
import it.unimi.dsi.fastutil.ints.IntArrayList;
import it.unimi.dsi.fastutil.ints.IntList;
import me.shedaniel.rei.api.common.display.basic.BasicDisplay;
import net.fabricmc.api.EnvType;
import net.fabricmc.api.Environment;
import net.minecraft.class_2378;
import net.minecraft.class_2540;
import net.minecraft.class_2960;
import net.minecraft.class_310;
import net.minecraft.class_3222;
import net.minecraft.class_5321;
import net.minecraft.class_6862;
import net.minecraft.class_6880;
import net.minecraft.class_6885;
import net.minecraft.class_7923;
import net.minecraft.class_9129;
import org.jetbrains.annotations.ApiStatus;

import java.util.*;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.Consumer;

@ApiStatus.Internal
public class TagNodes {
    // Id of the packet a client sends to ask the server for the tag data of one registry.
    public static final class_2960 REQUEST_TAGS_PACKET_C2S = class_2960.method_60655("roughlyenoughitems", "request_tags_c2s");
    // Id of the packet the server sends back with the requested tag data.
    public static final class_2960 REQUEST_TAGS_PACKET_S2C = class_2960.method_60655("roughlyenoughitems", "request_tags_s2c");
    
    // Not touched by the code in this class; presumably maps a tag directory name to its registry key — confirm against the code that populates it.
    public static final Map<String, class_5321<? extends class_2378<?>>> TAG_DIR_MAP = new HashMap<>();
    // Not touched by the code in this class; presumably the tag directory being processed on the current thread — confirm against its writers.
    public static final ThreadLocal<String> CURRENT_TAG_DIR = new ThreadLocal<>();
    // Not touched by the code in this class; keyed by a string and then by an identity-compared collection wrapper.
    public static final Map<String, Map<CollectionWrapper<?>, RawTagData>> RAW_TAG_DATA_MAP = new ConcurrentHashMap<>();
    // Tag data per registry key. Read when answering a request and when resolving locally; overwritten on the client with data received from the server.
    public static final Map<class_5321<? extends class_2378<?>>, Map<class_2960, TagData>> TAG_DATA_MAP = new HashMap<>();
    // Per registry key, how to hand a result to a callback: while a request is in flight the callback is queued,
    // once the reply has arrived the callback is invoked immediately with it. Cleared on every client level load.
    public static Map<class_5321<? extends class_2378<?>>, Consumer<Consumer<DataResult<Map<class_2960, TagData>>>>> requestedTags = new HashMap<>();
    
    /**
     * Wraps a collection so that it can be used as a map key compared by reference:
     * two wrappers are equal only when they hold the very same collection instance,
     * and the hash code is the identity hash code of that instance.
     *
     * @param <T> the element type of the wrapped collection
     */
    public static class CollectionWrapper<T> {
        private final Collection<T> collection;
        
        public CollectionWrapper(Collection<T> collection) {
            this.collection = collection;
        }
        
        @Override
        public boolean equals(Object obj) {
            if (!(obj instanceof CollectionWrapper<?> other)) {
                return false;
            }
            // Reference comparison on purpose: content equality is irrelevant here.
            return other.collection == this.collection;
        }
        
        @Override
        public int hashCode() {
            return System.identityHashCode(this.collection);
        }
    }
    
    /**
     * Two lists of identifiers describing one tag, used as the value type of {@link #RAW_TAG_DATA_MAP}.
     * NOTE(review): presumably the not-yet-resolved counterpart of {@link TagData}, with elements still
     * referenced by id instead of by integer — confirm against the code that creates instances.
     *
     * @param otherElements identifiers of the entries listed directly
     * @param otherTags     identifiers of the tags that are referenced
     */
    public record RawTagData(List<class_2960> otherElements, List<class_2960> otherTags) {
    }
    
    /**
     * The contents of one tag in network form: integer ids of the directly listed elements
     * and the identifiers of the tags it references.
     * <p>
     * Wire layout: element count, each element id, tag count, each tag identifier.
     * Counts and element ids are written with {@code method_10804} and read back with {@code method_10816}.
     *
     * @param otherElements integer ids of the directly listed elements
     * @param otherTags     identifiers of the referenced tags
     */
    public record TagData(IntList otherElements, List<class_2960> otherTags) {
        /**
         * Decodes one instance from the buffer, consuming exactly what {@link #toNetwork} wrote.
         */
        private static TagData fromNetwork(class_2540 buf) {
            int elementCount = buf.method_10816();
            IntList elements = new IntArrayList(elementCount + 1);
            for (int remaining = elementCount; remaining > 0; remaining--) {
                elements.add(buf.method_10816());
            }
            
            int tagCount = buf.method_10816();
            List<class_2960> tags = new ArrayList<>(tagCount + 1);
            for (int remaining = tagCount; remaining > 0; remaining--) {
                tags.add(buf.method_10810());
            }
            
            return new TagData(elements, tags);
        }
        
        /**
         * Encodes this instance into the buffer; tag identifiers go through
         * {@link TagNodes#writeResourceLocation}.
         */
        private void toNetwork(class_2540 buf) {
            int elementCount = otherElements.size();
            buf.method_10804(elementCount);
            for (int index = 0; index < elementCount; index++) {
                buf.method_10804(otherElements.getInt(index));
            }
            
            buf.method_10804(otherTags.size());
            otherTags.forEach(tag -> writeResourceLocation(buf, tag));
        }
    }
    
    /**
     * Writes an identifier as a string, in its short form when possible: if
     * {@code method_12836()} equals {@code "minecraft"} only {@code method_12832()} is written,
     * otherwise the full {@code toString()} form.
     *
     * @param buf      the buffer to write into
     * @param location the identifier to write
     */
    private static void writeResourceLocation(class_2540 buf, class_2960 location) {
        boolean vanilla = location.method_12836().equals("minecraft");
        buf.method_10814(vanilla ? location.method_12832() : location.toString());
    }
    
    /**
     * Sets up the networking for tag requests.
     * <p>
     * On the client this runs {@link Client#init()}. Outside the client environment the reply
     * packet type is registered explicitly. In every environment a receiver for the request packet
     * is registered; it answers with the request id followed by every entry of
     * {@link #TAG_DATA_MAP} stored for the requested registry key (none if there is no entry).
     */
    public static void init() {
        EnvExecutor.runInEnv(Env.CLIENT, () -> Client::init);

        // Fix for TagNodes not being loaded on the server
        // A bit hacky as it uses Architectury's internal API, but this class needs rewriting to use codecs due to the deprecation of the old serialization system anyway.
        if (Platform.getEnvironment() != Env.CLIENT) {
            NetworkAggregator.registerS2CType(REQUEST_TAGS_PACKET_S2C, Collections.singletonList(new SplitPacketTransformer()));
        }
        
        NetworkManager.registerReceiver(NetworkManager.c2s(), REQUEST_TAGS_PACKET_C2S, Collections.singletonList(new SplitPacketTransformer()), (request, context) -> {
            // Request layout: request id, then the id of the registry whose tags are wanted.
            UUID requestId = request.method_10790();
            class_5321<? extends class_2378<?>> registryKey = class_5321.method_29180(request.method_10810());
            
            // Reply layout: request id, entry count, then (tag id, tag data) per entry.
            class_9129 response = new class_9129(Unpooled.buffer(), context.registryAccess());
            response.method_10797(requestId);
            Map<class_2960, TagData> tags = TAG_DATA_MAP.getOrDefault(registryKey, Collections.emptyMap());
            response.method_53002(tags.size());
            tags.forEach((tagId, data) -> {
                writeResourceLocation(response, tagId);
                data.toNetwork(response);
            });
            
            NetworkManager.sendToPlayer((class_3222) context.getPlayer(), REQUEST_TAGS_PACKET_S2C, response);
        });
    }
    
    /**
     * Obtains the tag data of a registry and hands it to {@code callback}.
     * <ul>
     *     <li>If {@code class_310.method_1551().method_1576()} is non-null (presumably an integrated
     *     server is running — confirm against mappings), the local {@link #TAG_DATA_MAP} entry is
     *     passed on directly; it may be {@code null}.</li>
     *     <li>If the server cannot receive the request packet, an error result is passed on.</li>
     *     <li>If this registry key was requested before, the callback is given to the stored
     *     {@link #requestedTags} handler and additionally invoked with the current
     *     {@link #TAG_DATA_MAP} entry (or an empty map).</li>
     *     <li>Otherwise a request is sent to the server and the callback is invoked once the reply
     *     has been received.</li>
     * </ul>
     *
     * @param resourceKey the key of the registry whose tags are wanted
     * @param callback    receives the tag data or an error
     */
    @Environment(EnvType.CLIENT)
    public static void requestTagData(class_5321<? extends class_2378<?>> resourceKey, Consumer<DataResult<Map<class_2960, TagData>>> callback) {
        if (class_310.method_1551().method_1576() != null) {
            callback.accept(DataResult.success(TAG_DATA_MAP.get(resourceKey)));
            return;
        }
        
        if (!NetworkManager.canServerReceive(REQUEST_TAGS_PACKET_C2S)) {
            callback.accept(DataResult.error(() -> "Cannot request tags from server"));
            return;
        }
        
        if (requestedTags.containsKey(resourceKey)) {
            // NOTE(review): the callback is reached twice on this path — once through the stored
            // handler (queued or invoked right away) and once directly below. Kept as is; confirm
            // whether callers rely on the immediate second invocation.
            requestedTags.get(resourceKey).accept(callback);
            Map<class_2960, TagData> known = TAG_DATA_MAP.getOrDefault(resourceKey, Collections.emptyMap());
            callback.accept(DataResult.success(known));
            return;
        }
        
        // First request for this registry key: build the packet (request id + registry id).
        class_9129 payload = new class_9129(Unpooled.buffer(), BasicDisplay.registryAccess());
        UUID requestId = UUID.randomUUID();
        payload.method_10797(requestId);
        payload.method_10812(resourceKey.method_29177());
        
        // Callbacks arriving while the request is in flight are collected here.
        List<Consumer<DataResult<Map<class_2960, TagData>>>> waiting = new CopyOnWriteArrayList<>();
        waiting.add(callback);
        
        Client.nextUUID = requestId;
        Client.nextResourceKey = resourceKey;
        Client.nextCallback = response -> {
            // From now on later callers get the received result immediately.
            requestedTags.put(resourceKey, late -> late.accept(response));
            waiting.forEach(listener -> listener.accept(response));
        };
        
        requestedTags.put(resourceKey, waiting::add);
        NetworkManager.sendToServer(REQUEST_TAGS_PACKET_C2S, payload);
    }
    
    /**
     * Client-side half of the tag request protocol. Holds the single outstanding request
     * (id, registry key and completion callback) and registers the receiver for the server's reply.
     */
    private static class Client {
        /** Id of the request whose reply is awaited; {@code null} while no request is outstanding. */
        public static UUID nextUUID;
        /** Registry key the outstanding request was made for. */
        public static class_5321<? extends class_2378<?>> nextResourceKey;
        /** Invoked with the decoded reply of the outstanding request. */
        public static Consumer<DataResult<Map<class_2960, TagData>>> nextCallback;
        
        /**
         * Registers the level-load hook that forgets all previously requested tags, and the
         * receiver that decodes the server's reply into {@link #TAG_DATA_MAP}.
         */
        private static void init() {
            ClientLifecycleEvent.CLIENT_LEVEL_LOAD.register(world -> {
                requestedTags.clear();
            });
            // The reply type is registered with a SplitPacketTransformer outside the client
            // environment (see TagNodes.init), so the receiving end has to use the same transformer
            // for both sides to agree on the framing of the packet.
            NetworkManager.registerReceiver(NetworkManager.s2c(), REQUEST_TAGS_PACKET_S2C, Collections.singletonList(new SplitPacketTransformer()), (buf, context) -> {
                UUID uuid = buf.method_10790();
                // nextUUID is null whenever nothing is outstanding (it is reset after every handled
                // reply), so compare from the received id to ignore stray replies instead of throwing.
                if (!uuid.equals(nextUUID)) {
                    return;
                }
                
                // Reply layout: entry count, then (tag id, tag data) per entry.
                Map<class_2960, TagData> map = new HashMap<>();
                int count = buf.readInt();
                for (int i = 0; i < count; i++) {
                    map.put(buf.method_10810(), TagData.fromNetwork(buf));
                }
                
                // Take the pending request out of the slot BEFORE running the callback: a callback
                // that issues a new request must not have its freshly stored state wiped afterwards.
                class_5321<? extends class_2378<?>> resourceKey = nextResourceKey;
                Consumer<DataResult<Map<class_2960, TagData>>> callback = nextCallback;
                nextUUID = null;
                nextResourceKey = null;
                nextCallback = null;
                
                TAG_DATA_MAP.put(resourceKey, map);
                callback.accept(DataResult.success(map));
            });
        }
    }
    
    /**
     * Builds the node tree of a tag and passes it to {@code callback}.
     * <p>
     * The registry of the tag is looked up first, then the tag data of that registry is requested
     * through {@link #requestTagData}. The callback receives an error result with the message
     * {@code "No tag data"} when the data map is {@code null} or has no entry for the tag.
     *
     * @param tagKey   the tag to resolve
     * @param callback receives the resolved node or an error
     * @param <T>      the element type of the tag's registry
     */
    public static <T> void create(class_6862<T> tagKey, Consumer<DataResult<TagNode<T>>> callback) {
        class_2378<class_2378<T>> rootRegistry = (class_2378<class_2378<T>>) class_7923.field_41167;
        class_5321<class_2378<T>> registryKey = (class_5321<class_2378<T>>) tagKey.comp_326();
        class_2378<T> registry = rootRegistry.method_31140(registryKey);
        
        requestTagData(tagKey.comp_326(), result -> callback.accept(result.flatMap(dataMap -> {
            if (dataMap == null) {
                return DataResult.error(() -> "No tag data");
            }
            return resolveTag(tagKey, registry, dataMap).orElse(DataResult.error(() -> "No tag data"));
        })));
    }
    
    /**
     * Recursively turns the flat tag data of one tag into a {@link TagNode}.
     * <p>
     * Directly listed element ids that the registry can resolve are added as one values child.
     * Each referenced tag is resolved recursively and added as a child, but only if the registry
     * reports it through {@code method_46733} and the data map has an entry for it.
     *
     * @param tagKey     the tag to resolve
     * @param registry   the registry used to look up elements and tags
     * @param tagDataMap the tag data of the whole registry, keyed by tag id
     * @param <T>        the element type of the registry
     * @return empty if {@code tagDataMap} has no entry for the tag; otherwise the resolved node,
     *         or the first error produced while resolving a referenced tag
     */
    private static <T> Optional<DataResult<TagNode<T>>> resolveTag(class_6862<T> tagKey, class_2378<T> registry, Map<class_2960, TagData> tagDataMap) {
        TagData data = tagDataMap.get(tagKey.comp_327());
        if (data == null) {
            return Optional.empty();
        }
        
        TagNode<T> node = TagNode.ofReference(tagKey);
        
        // Elements listed directly in this tag; ids unknown to the registry are dropped.
        List<class_6880<T>> directValues = new ArrayList<>();
        for (int rawId : data.otherElements()) {
            registry.method_40265(rawId).ifPresent(directValues::add);
        }
        if (!directValues.isEmpty()) {
            node.addValuesChild(class_6885.method_40242(directValues));
        }
        
        // Referenced tags become child nodes.
        for (class_2960 nestedId : data.otherTags()) {
            class_6862<T> nestedKey = class_6862.method_40092(tagKey.comp_326(), nestedId);
            if (registry.method_46733(nestedKey).isEmpty()) {
                continue;
            }
            
            Optional<DataResult<TagNode<T>>> nested = resolveTag(nestedKey, registry, tagDataMap);
            if (nested.isEmpty()) {
                continue;
            }
            
            DataResult<TagNode<T>> nestedResult = nested.get();
            if (nestedResult.error().isPresent()) {
                // Propagate the failure of a referenced tag with its original message.
                return Optional.of(DataResult.error(() -> nestedResult.error().get().message()));
            }
            node.addChild(nestedResult.result().get());
        }
        
        return Optional.of(DataResult.success(node));
    }
}
