From 5a5603043e44ada44c68136c7ec166677e21b344 Mon Sep 17 00:00:00 2001 From: deirn Date: Fri, 20 Oct 2023 18:38:15 +0700 Subject: [PATCH] read packets on netty thread, execute receiver on main closes #10 --- .../badpackets/api/config/ConfigPackets.java | 12 +++ .../bai/badpackets/api/play/PlayPackets.java | 12 +++ .../impl/handler/AbstractPacketHandler.java | 28 ++++-- .../handler/ClientConfigPacketHandler.java | 4 +- .../impl/handler/ClientPlayPacketHandler.java | 4 +- .../handler/ServerConfigPacketHandler.java | 4 +- .../impl/handler/ServerPlayPacketHandler.java | 4 +- ...ServerConfigurationPacketListenerImpl.java | 7 -- .../MixinServerGamePacketListenerImpl.java | 7 -- .../impl/registry/ChannelRegistry.java | 85 ++++++++++++++++--- 10 files changed, 123 insertions(+), 44 deletions(-) diff --git a/src/main/java/lol/bai/badpackets/api/config/ConfigPackets.java b/src/main/java/lol/bai/badpackets/api/config/ConfigPackets.java index d302469..466d6c0 100644 --- a/src/main/java/lol/bai/badpackets/api/config/ConfigPackets.java +++ b/src/main/java/lol/bai/badpackets/api/config/ConfigPackets.java @@ -7,11 +7,13 @@ import lol.bai.badpackets.impl.payload.UntypedPayload; import lol.bai.badpackets.impl.registry.CallbackRegistry; import lol.bai.badpackets.impl.registry.ChannelRegistry; +import net.minecraft.client.Minecraft; import net.minecraft.client.multiplayer.ClientConfigurationPacketListenerImpl; import net.minecraft.network.FriendlyByteBuf; import net.minecraft.network.chat.Component; import net.minecraft.network.protocol.common.custom.CustomPacketPayload; import net.minecraft.resources.ResourceLocation; +import net.minecraft.server.MinecraftServer; /** * Utility for working with configuration packets. @@ -32,6 +34,9 @@ public static void registerTask(ResourceLocation id, ConfigTaskExecutor executor /** * Register a client-to-server packet receiver. + *

+ * Raw packet receiver is run on Netty event-loop. Read the buffer on it and run + * the operation on {@linkplain MinecraftServer#execute(Runnable) server thread}. * * @param id the packet id * @param receiver the receiver @@ -45,6 +50,8 @@ public static void registerServerReceiver(ResourceLocation id, ServerConfigPacke /** * Register a client-to-server packet receiver. + *

+ * Typed packet receiver is run on the main server thread. * * @param id the {@linkplain CustomPacketPayload#id() packet id} * @param reader the payload reader @@ -71,6 +78,9 @@ public static void registerServerReadyCallback(ServerConfigPacketReadyCallback c /** * Register a server-to-client packet receiver. + *

+ * Raw packet receiver is run on Netty event-loop. Read the buffer on it and run + * the operation on {@linkplain Minecraft#execute(Runnable) client thread}. * * @param id the packet id * @param receiver the receiver @@ -86,6 +96,8 @@ public static void registerClientReceiver(ResourceLocation id, ClientConfigPacke /** * Register a server-to-client packet receiver. + *

+ * Typed packet receiver is run on the main client thread. * * @param id the {@linkplain CustomPacketPayload#id() packet id} * @param reader the payload reader diff --git a/src/main/java/lol/bai/badpackets/api/play/PlayPackets.java b/src/main/java/lol/bai/badpackets/api/play/PlayPackets.java index 41955c3..be26e83 100644 --- a/src/main/java/lol/bai/badpackets/api/play/PlayPackets.java +++ b/src/main/java/lol/bai/badpackets/api/play/PlayPackets.java @@ -5,9 +5,11 @@ import lol.bai.badpackets.impl.payload.UntypedPayload; import lol.bai.badpackets.impl.registry.CallbackRegistry; import lol.bai.badpackets.impl.registry.ChannelRegistry; +import net.minecraft.client.Minecraft; import net.minecraft.network.FriendlyByteBuf; import net.minecraft.network.protocol.common.custom.CustomPacketPayload; import net.minecraft.resources.ResourceLocation; +import net.minecraft.server.MinecraftServer; /** * Utility for working with play packets. @@ -16,6 +18,9 @@ public final class PlayPackets { /** * Register a client-to-server packet receiver. + *

+ * Raw packet receiver is run on Netty event-loop. Read the buffer on it and run + * the operation on {@linkplain MinecraftServer#execute(Runnable) server thread}. * * @param id the packet id * @param receiver the receiver @@ -27,6 +32,8 @@ public static void registerServerReceiver(ResourceLocation id, ServerPlayPacketR /** * Register a client-to-server packet receiver. + *

+ * Typed packet receiver is run on the main server thread. * * @param id the {@linkplain CustomPacketPayload#id() packet id} * @param reader the payload reader @@ -50,6 +57,9 @@ public static void registerServerReadyCallback(ServerPlayPacketReadyCallback cal /** * Register a server-to-client packet receiver. + *

+ * Raw packet receiver is run on Netty event-loop. Read the buffer on it and run + * the operation on {@linkplain Minecraft#execute(Runnable) client thread}. * * @param id the packet id * @param receiver the receiver @@ -62,6 +72,8 @@ public static void registerClientReceiver(ResourceLocation id, ClientPlayPacketR /** * Register a server-to-client packet receiver. + *

+ * Typed packet receiver is run on the main client thread. * * @param id the {@linkplain CustomPacketPayload#id() packet id} * @param reader the payload reader diff --git a/src/main/java/lol/bai/badpackets/impl/handler/AbstractPacketHandler.java b/src/main/java/lol/bai/badpackets/impl/handler/AbstractPacketHandler.java index c779fc0..7dadc9c 100644 --- a/src/main/java/lol/bai/badpackets/impl/handler/AbstractPacketHandler.java +++ b/src/main/java/lol/bai/badpackets/impl/handler/AbstractPacketHandler.java @@ -1,6 +1,7 @@ package lol.bai.badpackets.impl.handler; import java.nio.charset.StandardCharsets; +import java.util.Collections; import java.util.HashSet; import java.util.List; import java.util.Map; @@ -19,6 +20,7 @@ import net.minecraft.network.protocol.Packet; import net.minecraft.network.protocol.common.custom.CustomPacketPayload; import net.minecraft.resources.ResourceLocation; +import net.minecraft.util.thread.BlockableEventLoop; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.jetbrains.annotations.Nullable; @@ -29,16 +31,18 @@ public abstract class AbstractPacketHandler implements PacketSender { protected final Logger logger; private final Function> packetFactory; - private final Set sendableChannels = new HashSet<>(); + private final Set sendableChannels = Collections.synchronizedSet(new HashSet<>()); + private final BlockableEventLoop eventLoop; private final Connection connection; private boolean initialized = false; - protected AbstractPacketHandler(String desc, ChannelRegistry registry, Function> packetFactory, Connection connection) { + protected AbstractPacketHandler(String desc, ChannelRegistry registry, Function> packetFactory, BlockableEventLoop eventLoop, Connection connection) { this.logger = LogManager.getLogger(desc); this.registry = registry; this.packetFactory = packetFactory; + this.eventLoop = eventLoop; this.connection = connection; registry.addHandler(this); @@ -58,7 +62,8 @@ private void receiveChannelSyncPacket(FriendlyByteBuf buf) { sendableChannels.add(new ResourceLocation(namespace, path)); } } - onInitialChannelSyncPacketReceived(); + + eventLoop.execute(this::onInitialChannelSyncPacketReceived); } } } @@ -76,9 +81,15 @@ public boolean receive(CustomPacketPayload payload) { return true; } - if (registry.channels.containsKey(id)) { + if (registry.has(id)) { try { - receive(registry.channels.get(id), payload); + T receiver = registry.get(id); + + if (payload instanceof UntypedPayload || eventLoop.isSameThread()) { + receiveUnsafe(receiver, payload); + } else eventLoop.execute(() -> { + if (connection.isConnected()) receiveUnsafe(receiver, payload); + }); } catch (Throwable t) { logger.error("Error when receiving packet {}", id, t); throw t; @@ -91,17 +102,18 @@ public boolean receive(CustomPacketPayload payload) { protected abstract void onInitialChannelSyncPacketReceived(); - protected abstract void receive(T receiver, CustomPacketPayload payload); + protected abstract void receiveUnsafe(T receiver, CustomPacketPayload payload); public void sendInitialChannelSyncPacket() { if (!initialized) { initialized = true; sendVanillaChannelRegisterPacket(Set.of(Constants.CHANNEL_SYNC)); + Set channels = registry.getChannels(); FriendlyByteBuf buf = new FriendlyByteBuf(Unpooled.buffer()); buf.writeByte(Constants.CHANNEL_SYNC_INITIAL); - Map> group = registry.channels.keySet().stream().collect(Collectors.groupingBy(ResourceLocation::getNamespace)); + Map> group = channels.stream().collect(Collectors.groupingBy(ResourceLocation::getNamespace)); buf.writeVarInt(group.size()); for (Map.Entry> entry : group.entrySet()) { @@ -114,7 +126,7 @@ public void sendInitialChannelSyncPacket() { } send(Constants.CHANNEL_SYNC, buf); - sendVanillaChannelRegisterPacket(registry.channels.keySet()); + sendVanillaChannelRegisterPacket(channels); } } diff --git a/src/main/java/lol/bai/badpackets/impl/handler/ClientConfigPacketHandler.java b/src/main/java/lol/bai/badpackets/impl/handler/ClientConfigPacketHandler.java index c7e4b60..165aedb 100644 --- a/src/main/java/lol/bai/badpackets/impl/handler/ClientConfigPacketHandler.java +++ b/src/main/java/lol/bai/badpackets/impl/handler/ClientConfigPacketHandler.java @@ -19,7 +19,7 @@ public class ClientConfigPacketHandler extends AbstractPacketHandler receiver, CustomPacketPayload payload) { + protected void receiveUnsafe(ClientConfigPacketReceiver receiver, CustomPacketPayload payload) { receiver.receive(client, listener, payload, this); } diff --git a/src/main/java/lol/bai/badpackets/impl/handler/ClientPlayPacketHandler.java b/src/main/java/lol/bai/badpackets/impl/handler/ClientPlayPacketHandler.java index 80d2020..5899f8e 100644 --- a/src/main/java/lol/bai/badpackets/impl/handler/ClientPlayPacketHandler.java +++ b/src/main/java/lol/bai/badpackets/impl/handler/ClientPlayPacketHandler.java @@ -18,7 +18,7 @@ public class ClientPlayPacketHandler extends AbstractPacketHandler receiver, CustomPacketPayload payload) { + protected void receiveUnsafe(ClientPlayPacketReceiver receiver, CustomPacketPayload payload) { receiver.receive(client, listener, payload, this); } diff --git a/src/main/java/lol/bai/badpackets/impl/handler/ServerConfigPacketHandler.java b/src/main/java/lol/bai/badpackets/impl/handler/ServerConfigPacketHandler.java index 44329c9..e25af21 100644 --- a/src/main/java/lol/bai/badpackets/impl/handler/ServerConfigPacketHandler.java +++ b/src/main/java/lol/bai/badpackets/impl/handler/ServerConfigPacketHandler.java @@ -31,7 +31,7 @@ public class ServerConfigPacketHandler extends AbstractPacketHandler receiver, CustomPacketPayload payload) { + protected void receiveUnsafe(ServerConfigPacketReceiver receiver, CustomPacketPayload payload) { receiver.receive(server, listener, payload, this, this); } diff --git a/src/main/java/lol/bai/badpackets/impl/handler/ServerPlayPacketHandler.java b/src/main/java/lol/bai/badpackets/impl/handler/ServerPlayPacketHandler.java index 5053c00..d03c700 100644 --- a/src/main/java/lol/bai/badpackets/impl/handler/ServerPlayPacketHandler.java +++ b/src/main/java/lol/bai/badpackets/impl/handler/ServerPlayPacketHandler.java @@ -20,7 +20,7 @@ public class ServerPlayPacketHandler extends AbstractPacketHandler receiver, CustomPacketPayload payload) { + protected void receiveUnsafe(ServerPlayPacketReceiver receiver, CustomPacketPayload payload) { receiver.receive(server, handler.getPlayer(), handler, payload, this); } diff --git a/src/main/java/lol/bai/badpackets/impl/mixin/MixinServerConfigurationPacketListenerImpl.java b/src/main/java/lol/bai/badpackets/impl/mixin/MixinServerConfigurationPacketListenerImpl.java index b4aadd6..dd73144 100644 --- a/src/main/java/lol/bai/badpackets/impl/mixin/MixinServerConfigurationPacketListenerImpl.java +++ b/src/main/java/lol/bai/badpackets/impl/mixin/MixinServerConfigurationPacketListenerImpl.java @@ -5,8 +5,6 @@ import lol.bai.badpackets.impl.Constants; import lol.bai.badpackets.impl.handler.ServerConfigPacketHandler; import net.minecraft.network.Connection; -import net.minecraft.network.protocol.PacketUtils; -import net.minecraft.network.protocol.common.ServerCommonPacketListener; import net.minecraft.network.protocol.common.ServerboundCustomPayloadPacket; import net.minecraft.network.protocol.configuration.ServerboundFinishConfigurationPacket; import net.minecraft.server.MinecraftServer; @@ -38,15 +36,11 @@ public abstract class MixinServerConfigurationPacketListenerImpl extends MixinSe @Nullable private ConfigurationTask currentTask; - @Unique - private MinecraftServer badpackets_server; - @Unique private ServerConfigPacketHandler badpackets_packetHandler; @Inject(method = "", at = @At("TAIL")) private void badpackets_createPacketHandler(MinecraftServer server, Connection connection, CommonListenerCookie cookie, CallbackInfo ci) { - badpackets_server = server; badpackets_packetHandler = new ServerConfigPacketHandler(server, (ServerConfigurationPacketListenerImpl) (Object) this, connection); } @@ -82,7 +76,6 @@ protected void badpackets_removePacketHandler() { @Override protected boolean badpackets_handleCustomPayload(ServerboundCustomPayloadPacket packet) { - PacketUtils.ensureRunningOnSameThread(packet, (ServerCommonPacketListener) this, badpackets_server); return badpackets_packetHandler.receive(packet.payload()); } diff --git a/src/main/java/lol/bai/badpackets/impl/mixin/MixinServerGamePacketListenerImpl.java b/src/main/java/lol/bai/badpackets/impl/mixin/MixinServerGamePacketListenerImpl.java index 55e9cad..0953257 100644 --- a/src/main/java/lol/bai/badpackets/impl/mixin/MixinServerGamePacketListenerImpl.java +++ b/src/main/java/lol/bai/badpackets/impl/mixin/MixinServerGamePacketListenerImpl.java @@ -3,8 +3,6 @@ import lol.bai.badpackets.impl.handler.ServerPlayPacketHandler; import net.minecraft.network.Connection; import net.minecraft.network.chat.Component; -import net.minecraft.network.protocol.PacketUtils; -import net.minecraft.network.protocol.common.ServerCommonPacketListener; import net.minecraft.network.protocol.common.ServerboundCustomPayloadPacket; import net.minecraft.network.protocol.game.ServerboundConfigurationAcknowledgedPacket; import net.minecraft.server.MinecraftServer; @@ -12,7 +10,6 @@ import net.minecraft.server.network.CommonListenerCookie; import net.minecraft.server.network.ServerGamePacketListenerImpl; import org.spongepowered.asm.mixin.Mixin; -import org.spongepowered.asm.mixin.Shadow; import org.spongepowered.asm.mixin.Unique; import org.spongepowered.asm.mixin.injection.At; import org.spongepowered.asm.mixin.injection.Inject; @@ -21,9 +18,6 @@ @Mixin(ServerGamePacketListenerImpl.class) public class MixinServerGamePacketListenerImpl extends MixinServerCommonPacketListenerImpl implements ServerPlayPacketHandler.Holder { - @Shadow - public ServerPlayer player; - @Unique private ServerPlayPacketHandler badpacket_packetHandler; @@ -44,7 +38,6 @@ private void badpacekts_removePacketHandler(ServerboundConfigurationAcknowledged @Override protected boolean badpackets_handleCustomPayload(ServerboundCustomPayloadPacket packet) { - PacketUtils.ensureRunningOnSameThread(packet, (ServerCommonPacketListener) this, player.server); return badpacket_packetHandler.receive(packet.payload()); } diff --git a/src/main/java/lol/bai/badpackets/impl/registry/ChannelRegistry.java b/src/main/java/lol/bai/badpackets/impl/registry/ChannelRegistry.java index 47d5e0d..f161dce 100644 --- a/src/main/java/lol/bai/badpackets/impl/registry/ChannelRegistry.java +++ b/src/main/java/lol/bai/badpackets/impl/registry/ChannelRegistry.java @@ -4,6 +4,8 @@ import java.util.HashSet; import java.util.Map; import java.util.Set; +import java.util.concurrent.locks.Lock; +import java.util.concurrent.locks.ReentrantReadWriteLock; import java.util.function.Consumer; import java.util.function.Supplier; @@ -35,41 +37,96 @@ public class ChannelRegistry { public static final ChannelRegistry> PLAY_S2C = new ChannelRegistry<>(RESERVED_CHANNELS, S2C_READERS); public static final ChannelRegistry> PLAY_C2S = new ChannelRegistry<>(RESERVED_CHANNELS, C2S_READERS); - public final Map channels = new HashMap<>(); - + private final Map channels = new HashMap<>(); private final Set reservedChannels; private final Set> handlers = new HashSet<>(); private final ReaderMapHolder readersHolder; + private final ReentrantReadWriteLock locks = new ReentrantReadWriteLock(); + private ChannelRegistry(Set reservedChannels, ReaderMapHolder readersHolder) { this.reservedChannels = reservedChannels; this.readersHolder = readersHolder; } public void register(ResourceLocation id, FriendlyByteBuf.Reader reader, T receiver) { - if (reservedChannels.contains(id)) { - throw new IllegalArgumentException("Reserved channel id " + id); + Lock lock = locks.writeLock(); + lock.lock(); + + try { + if (reservedChannels.contains(id)) { + throw new IllegalArgumentException("Reserved channel id " + id); + } + + Map> readers = readersHolder.getter.get(); + if (!(readers instanceof HashMap)) { + readers = new HashMap<>(readers); + readersHolder.setter.accept(readers); + } + + readers.put(id, reader); + channels.put(id, receiver); + for (AbstractPacketHandler handler : handlers) { + handler.onRegister(id); + } + } finally { + lock.unlock(); } + } + + public boolean has(ResourceLocation id) { + Lock lock = locks.readLock(); + lock.lock(); + + try { + return channels.containsKey(id); + } finally { + lock.unlock(); + } + } - Map> readers = readersHolder.getter.get(); - if (!(readers instanceof HashMap)) { - readers = new HashMap<>(readers); - readersHolder.setter.accept(readers); + public T get(ResourceLocation id) { + Lock lock = locks.readLock(); + lock.lock(); + + try { + return channels.get(id); + } finally { + lock.unlock(); } + } - readers.put(id, reader); - channels.put(id, receiver); - for (AbstractPacketHandler handler : handlers) { - handler.onRegister(id); + public Set getChannels() { + Lock lock = locks.readLock(); + lock.lock(); + + try { + return new HashSet<>(channels.keySet()); + } finally { + lock.unlock(); } } public void addHandler(AbstractPacketHandler handler) { - handlers.add(handler); + Lock lock = locks.writeLock(); + lock.lock(); + + try { + handlers.add(handler); + } finally { + lock.unlock(); + } } public void removeHandler(AbstractPacketHandler handler) { - handlers.remove(handler); + Lock lock = locks.writeLock(); + lock.lock(); + + try { + handlers.remove(handler); + } finally { + lock.unlock(); + } } private record ReaderMapHolder(