Skip to content

Commit

Permalink
read packets on netty thread, execute receiver on main
Browse files Browse the repository at this point in the history
closes #10
  • Loading branch information
deirn committed Oct 20, 2023
1 parent af21dd5 commit 5a56030
Show file tree
Hide file tree
Showing 10 changed files with 123 additions and 44 deletions.
12 changes: 12 additions & 0 deletions src/main/java/lol/bai/badpackets/api/config/ConfigPackets.java
Original file line number Diff line number Diff line change
Expand Up @@ -7,11 +7,13 @@
import lol.bai.badpackets.impl.payload.UntypedPayload;
import lol.bai.badpackets.impl.registry.CallbackRegistry;
import lol.bai.badpackets.impl.registry.ChannelRegistry;
import net.minecraft.client.Minecraft;
import net.minecraft.client.multiplayer.ClientConfigurationPacketListenerImpl;
import net.minecraft.network.FriendlyByteBuf;
import net.minecraft.network.chat.Component;
import net.minecraft.network.protocol.common.custom.CustomPacketPayload;
import net.minecraft.resources.ResourceLocation;
import net.minecraft.server.MinecraftServer;

/**
* Utility for working with configuration packets.
Expand All @@ -32,6 +34,9 @@ public static void registerTask(ResourceLocation id, ConfigTaskExecutor executor

/**
* Register a client-to-server packet receiver.
* <p>
* Raw packet receiver is run on Netty event-loop. Read the buffer on it and run
* the operation on {@linkplain MinecraftServer#execute(Runnable) server thread}.
*
* @param id the packet id
* @param receiver the receiver
Expand All @@ -45,6 +50,8 @@ public static void registerServerReceiver(ResourceLocation id, ServerConfigPacke

/**
* Register a client-to-server packet receiver.
* <p>
* Typed packet receiver is run on the main server thread.
*
* @param id the {@linkplain CustomPacketPayload#id() packet id}
* @param reader the payload reader
Expand All @@ -71,6 +78,9 @@ public static void registerServerReadyCallback(ServerConfigPacketReadyCallback c

/**
* Register a server-to-client packet receiver.
* <p>
* Raw packet receiver is run on Netty event-loop. Read the buffer on it and run
* the operation on {@linkplain Minecraft#execute(Runnable) client thread}.
*
* @param id the packet id
* @param receiver the receiver
Expand All @@ -86,6 +96,8 @@ public static void registerClientReceiver(ResourceLocation id, ClientConfigPacke

/**
* Register a server-to-client packet receiver.
* <p>
* Typed packet receiver is run on the main client thread.
*
* @param id the {@linkplain CustomPacketPayload#id() packet id}
* @param reader the payload reader
Expand Down
12 changes: 12 additions & 0 deletions src/main/java/lol/bai/badpackets/api/play/PlayPackets.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@
import lol.bai.badpackets.impl.payload.UntypedPayload;
import lol.bai.badpackets.impl.registry.CallbackRegistry;
import lol.bai.badpackets.impl.registry.ChannelRegistry;
import net.minecraft.client.Minecraft;
import net.minecraft.network.FriendlyByteBuf;
import net.minecraft.network.protocol.common.custom.CustomPacketPayload;
import net.minecraft.resources.ResourceLocation;
import net.minecraft.server.MinecraftServer;

/**
* Utility for working with play packets.
Expand All @@ -16,6 +18,9 @@ public final class PlayPackets {

/**
* Register a client-to-server packet receiver.
* <p>
* Raw packet receiver is run on Netty event-loop. Read the buffer on it and run
* the operation on {@linkplain MinecraftServer#execute(Runnable) server thread}.
*
* @param id the packet id
* @param receiver the receiver
Expand All @@ -27,6 +32,8 @@ public static void registerServerReceiver(ResourceLocation id, ServerPlayPacketR

/**
* Register a client-to-server packet receiver.
* <p>
* Typed packet receiver is run on the main server thread.
*
* @param id the {@linkplain CustomPacketPayload#id() packet id}
* @param reader the payload reader
Expand All @@ -50,6 +57,9 @@ public static void registerServerReadyCallback(ServerPlayPacketReadyCallback cal

/**
* Register a server-to-client packet receiver.
* <p>
* Raw packet receiver is run on Netty event-loop. Read the buffer on it and run
* the operation on {@linkplain Minecraft#execute(Runnable) client thread}.
*
* @param id the packet id
* @param receiver the receiver
Expand All @@ -62,6 +72,8 @@ public static void registerClientReceiver(ResourceLocation id, ClientPlayPacketR

/**
* Register a server-to-client packet receiver.
* <p>
* Typed packet receiver is run on the main client thread.
*
* @param id the {@linkplain CustomPacketPayload#id() packet id}
* @param reader the payload reader
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package lol.bai.badpackets.impl.handler;

import java.nio.charset.StandardCharsets;
import java.util.Collections;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
Expand All @@ -19,6 +20,7 @@
import net.minecraft.network.protocol.Packet;
import net.minecraft.network.protocol.common.custom.CustomPacketPayload;
import net.minecraft.resources.ResourceLocation;
import net.minecraft.util.thread.BlockableEventLoop;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.jetbrains.annotations.Nullable;
Expand All @@ -29,16 +31,18 @@ public abstract class AbstractPacketHandler<T> implements PacketSender {
protected final Logger logger;

private final Function<CustomPacketPayload, Packet<?>> packetFactory;
private final Set<ResourceLocation> sendableChannels = new HashSet<>();
private final Set<ResourceLocation> sendableChannels = Collections.synchronizedSet(new HashSet<>());

private final BlockableEventLoop<?> eventLoop;
private final Connection connection;

private boolean initialized = false;

protected AbstractPacketHandler(String desc, ChannelRegistry<T> registry, Function<CustomPacketPayload, Packet<?>> packetFactory, Connection connection) {
protected AbstractPacketHandler(String desc, ChannelRegistry<T> registry, Function<CustomPacketPayload, Packet<?>> packetFactory, BlockableEventLoop<?> eventLoop, Connection connection) {
this.logger = LogManager.getLogger(desc);
this.registry = registry;
this.packetFactory = packetFactory;
this.eventLoop = eventLoop;
this.connection = connection;

registry.addHandler(this);
Expand All @@ -58,7 +62,8 @@ private void receiveChannelSyncPacket(FriendlyByteBuf buf) {
sendableChannels.add(new ResourceLocation(namespace, path));
}
}
onInitialChannelSyncPacketReceived();

eventLoop.execute(this::onInitialChannelSyncPacketReceived);
}
}
}
Expand All @@ -76,9 +81,15 @@ public boolean receive(CustomPacketPayload payload) {
return true;
}

if (registry.channels.containsKey(id)) {
if (registry.has(id)) {
try {
receive(registry.channels.get(id), payload);
T receiver = registry.get(id);

if (payload instanceof UntypedPayload || eventLoop.isSameThread()) {
receiveUnsafe(receiver, payload);
} else eventLoop.execute(() -> {
if (connection.isConnected()) receiveUnsafe(receiver, payload);
});
} catch (Throwable t) {
logger.error("Error when receiving packet {}", id, t);
throw t;
Expand All @@ -91,17 +102,18 @@ public boolean receive(CustomPacketPayload payload) {

protected abstract void onInitialChannelSyncPacketReceived();

protected abstract void receive(T receiver, CustomPacketPayload payload);
protected abstract void receiveUnsafe(T receiver, CustomPacketPayload payload);

public void sendInitialChannelSyncPacket() {
if (!initialized) {
initialized = true;
sendVanillaChannelRegisterPacket(Set.of(Constants.CHANNEL_SYNC));

Set<ResourceLocation> channels = registry.getChannels();
FriendlyByteBuf buf = new FriendlyByteBuf(Unpooled.buffer());
buf.writeByte(Constants.CHANNEL_SYNC_INITIAL);

Map<String, List<ResourceLocation>> group = registry.channels.keySet().stream().collect(Collectors.groupingBy(ResourceLocation::getNamespace));
Map<String, List<ResourceLocation>> group = channels.stream().collect(Collectors.groupingBy(ResourceLocation::getNamespace));
buf.writeVarInt(group.size());

for (Map.Entry<String, List<ResourceLocation>> entry : group.entrySet()) {
Expand All @@ -114,7 +126,7 @@ public void sendInitialChannelSyncPacket() {
}

send(Constants.CHANNEL_SYNC, buf);
sendVanillaChannelRegisterPacket(registry.channels.keySet());
sendVanillaChannelRegisterPacket(channels);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ public class ClientConfigPacketHandler extends AbstractPacketHandler<ClientConfi
private final ClientConfigurationPacketListenerImpl listener;

public ClientConfigPacketHandler(Minecraft client, ClientConfigurationPacketListenerImpl listener, Connection connection) {
super("ClientConfigPacketHandler", ChannelRegistry.CONFIG_S2C, ServerboundCustomPayloadPacket::new, connection);
super("ClientConfigPacketHandler", ChannelRegistry.CONFIG_S2C, ServerboundCustomPayloadPacket::new, client, connection);

this.client = client;
this.listener = listener;
Expand All @@ -40,7 +40,7 @@ protected void onInitialChannelSyncPacketReceived() {
}

@Override
protected void receive(ClientConfigPacketReceiver<CustomPacketPayload> receiver, CustomPacketPayload payload) {
protected void receiveUnsafe(ClientConfigPacketReceiver<CustomPacketPayload> receiver, CustomPacketPayload payload) {
receiver.receive(client, listener, payload, this);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ public class ClientPlayPacketHandler extends AbstractPacketHandler<ClientPlayPac
private final ClientPacketListener listener;

public ClientPlayPacketHandler(Minecraft client, ClientPacketListener listener) {
super("ClientPlayPacketHandler", ChannelRegistry.PLAY_S2C, ServerboundCustomPayloadPacket::new, listener.getConnection());
super("ClientPlayPacketHandler", ChannelRegistry.PLAY_S2C, ServerboundCustomPayloadPacket::new, client, listener.getConnection());

this.client = client;
this.listener = listener;
Expand Down Expand Up @@ -47,7 +47,7 @@ protected void onInitialChannelSyncPacketReceived() {
}

@Override
protected void receive(ClientPlayPacketReceiver<CustomPacketPayload> receiver, CustomPacketPayload payload) {
protected void receiveUnsafe(ClientPlayPacketReceiver<CustomPacketPayload> receiver, CustomPacketPayload payload) {
receiver.receive(client, listener, payload, this);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public class ServerConfigPacketHandler extends AbstractPacketHandler<ServerConfi
private final ServerConfigurationPacketListenerImpl listener;

public ServerConfigPacketHandler(MinecraftServer server, ServerConfigurationPacketListenerImpl listener, Connection connection) {
super("ServerConfigPacketHandler for " + listener.getOwner().getName(), ChannelRegistry.CONFIG_C2S, ClientboundCustomPayloadPacket::new, connection);
super("ServerConfigPacketHandler for " + listener.getOwner().getName(), ChannelRegistry.CONFIG_C2S, ClientboundCustomPayloadPacket::new, server, connection);

this.server = server;
this.listener = listener;
Expand Down Expand Up @@ -60,7 +60,7 @@ protected void onInitialChannelSyncPacketReceived() {
}

@Override
protected void receive(ServerConfigPacketReceiver<CustomPacketPayload> receiver, CustomPacketPayload payload) {
protected void receiveUnsafe(ServerConfigPacketReceiver<CustomPacketPayload> receiver, CustomPacketPayload payload) {
receiver.receive(server, listener, payload, this, this);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ public class ServerPlayPacketHandler extends AbstractPacketHandler<ServerPlayPac
private final ServerGamePacketListenerImpl handler;

public ServerPlayPacketHandler(MinecraftServer server, ServerGamePacketListenerImpl handler, Connection connection) {
super("ServerPlayPacketHandler for " + handler.getPlayer().getScoreboardName(), ChannelRegistry.PLAY_C2S, ClientboundCustomPayloadPacket::new, connection);
super("ServerPlayPacketHandler for " + handler.getPlayer().getScoreboardName(), ChannelRegistry.PLAY_C2S, ClientboundCustomPayloadPacket::new, server, connection);
this.server = server;
this.handler = handler;
}
Expand All @@ -42,7 +42,7 @@ protected void onInitialChannelSyncPacketReceived() {
}

@Override
protected void receive(ServerPlayPacketReceiver<CustomPacketPayload> receiver, CustomPacketPayload payload) {
protected void receiveUnsafe(ServerPlayPacketReceiver<CustomPacketPayload> receiver, CustomPacketPayload payload) {
receiver.receive(server, handler.getPlayer(), handler, payload, this);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,6 @@
import lol.bai.badpackets.impl.Constants;
import lol.bai.badpackets.impl.handler.ServerConfigPacketHandler;
import net.minecraft.network.Connection;
import net.minecraft.network.protocol.PacketUtils;
import net.minecraft.network.protocol.common.ServerCommonPacketListener;
import net.minecraft.network.protocol.common.ServerboundCustomPayloadPacket;
import net.minecraft.network.protocol.configuration.ServerboundFinishConfigurationPacket;
import net.minecraft.server.MinecraftServer;
Expand Down Expand Up @@ -38,15 +36,11 @@ public abstract class MixinServerConfigurationPacketListenerImpl extends MixinSe
@Nullable
private ConfigurationTask currentTask;

@Unique
private MinecraftServer badpackets_server;

@Unique
private ServerConfigPacketHandler badpackets_packetHandler;

@Inject(method = "<init>", at = @At("TAIL"))
private void badpackets_createPacketHandler(MinecraftServer server, Connection connection, CommonListenerCookie cookie, CallbackInfo ci) {
badpackets_server = server;
badpackets_packetHandler = new ServerConfigPacketHandler(server, (ServerConfigurationPacketListenerImpl) (Object) this, connection);
}

Expand Down Expand Up @@ -82,7 +76,6 @@ protected void badpackets_removePacketHandler() {

@Override
protected boolean badpackets_handleCustomPayload(ServerboundCustomPayloadPacket packet) {
PacketUtils.ensureRunningOnSameThread(packet, (ServerCommonPacketListener) this, badpackets_server);
return badpackets_packetHandler.receive(packet.payload());
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,16 +3,13 @@
import lol.bai.badpackets.impl.handler.ServerPlayPacketHandler;
import net.minecraft.network.Connection;
import net.minecraft.network.chat.Component;
import net.minecraft.network.protocol.PacketUtils;
import net.minecraft.network.protocol.common.ServerCommonPacketListener;
import net.minecraft.network.protocol.common.ServerboundCustomPayloadPacket;
import net.minecraft.network.protocol.game.ServerboundConfigurationAcknowledgedPacket;
import net.minecraft.server.MinecraftServer;
import net.minecraft.server.level.ServerPlayer;
import net.minecraft.server.network.CommonListenerCookie;
import net.minecraft.server.network.ServerGamePacketListenerImpl;
import org.spongepowered.asm.mixin.Mixin;
import org.spongepowered.asm.mixin.Shadow;
import org.spongepowered.asm.mixin.Unique;
import org.spongepowered.asm.mixin.injection.At;
import org.spongepowered.asm.mixin.injection.Inject;
Expand All @@ -21,9 +18,6 @@
@Mixin(ServerGamePacketListenerImpl.class)
public class MixinServerGamePacketListenerImpl extends MixinServerCommonPacketListenerImpl implements ServerPlayPacketHandler.Holder {

@Shadow
public ServerPlayer player;

@Unique
private ServerPlayPacketHandler badpacket_packetHandler;

Expand All @@ -44,7 +38,6 @@ private void badpacekts_removePacketHandler(ServerboundConfigurationAcknowledged

@Override
protected boolean badpackets_handleCustomPayload(ServerboundCustomPayloadPacket packet) {
PacketUtils.ensureRunningOnSameThread(packet, (ServerCommonPacketListener) this, player.server);
return badpacket_packetHandler.receive(packet.payload());
}

Expand Down
Loading

0 comments on commit 5a56030

Please sign in to comment.