diff --git a/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java b/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java index f25e578f..5c7043e9 100644 --- a/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java +++ b/src/main/java/com/velocitypowered/proxy/protocol/StateRegistry.java @@ -1,83 +1,176 @@ package com.velocitypowered.proxy.protocol; +import com.google.common.base.Preconditions; +import com.google.common.collect.ImmutableList; import com.velocitypowered.proxy.protocol.packets.*; import io.netty.util.collection.IntObjectHashMap; import io.netty.util.collection.IntObjectMap; import java.util.HashMap; +import java.util.List; import java.util.Map; +import java.util.Objects; import java.util.function.Supplier; +import static com.velocitypowered.proxy.protocol.ProtocolConstants.MINECRAFT_1_12; + public enum StateRegistry { HANDSHAKE { { - TO_SERVER.register(0x00, Handshake.class, Handshake::new); + TO_SERVER.register(Handshake.class, Handshake::new, + generic(0x00)); } }, STATUS { { - TO_SERVER.register(0x00, StatusRequest.class, StatusRequest::new); - TO_SERVER.register(0x01, Ping.class, Ping::new); + TO_SERVER.register(StatusRequest.class, StatusRequest::new, + generic(0x00)); + TO_SERVER.register(Ping.class, Ping::new, + generic(0x01)); - TO_CLIENT.register(0x00, StatusResponse.class, StatusResponse::new); - TO_CLIENT.register(0x01, Ping.class, Ping::new); + TO_CLIENT.register(StatusResponse.class, StatusResponse::new, + generic(0x00)); + TO_CLIENT.register(Ping.class, Ping::new, + generic(0x01)); } }, PLAY { { - TO_SERVER.register(0x02, Chat.class, Chat::new); - TO_SERVER.register(0x0b, Ping.class, Ping::new); + TO_SERVER.register(Chat.class, Chat::new, + map(0x02, MINECRAFT_1_12)); + TO_SERVER.register(Ping.class, Ping::new, + map(0x0b, MINECRAFT_1_12)); - TO_CLIENT.register(0x0F, Chat.class, Chat::new); - TO_CLIENT.register(0x1A, Disconnect.class, Disconnect::new); - TO_CLIENT.register(0x1F, Ping.class, Ping::new); + TO_CLIENT.register(Chat.class, Chat::new, + map(0x0F, MINECRAFT_1_12)); + TO_CLIENT.register(Disconnect.class, Disconnect::new, + map(0x1A, MINECRAFT_1_12)); + TO_CLIENT.register(Ping.class, Ping::new, + map(0x1F, MINECRAFT_1_12)); } }, LOGIN { { - TO_SERVER.register(0x00, ServerLogin.class, ServerLogin::new); - TO_SERVER.register(0x01, EncryptionResponse.class, EncryptionResponse::new); + TO_SERVER.register(ServerLogin.class, ServerLogin::new, + generic(0x00)); + TO_SERVER.register(EncryptionResponse.class, EncryptionResponse::new, + generic(0x01)); - TO_CLIENT.register(0x00, Disconnect.class, Disconnect::new); - TO_CLIENT.register(0x01, EncryptionRequest.class, EncryptionRequest::new); - TO_CLIENT.register(0x02, ServerLoginSuccess.class, ServerLoginSuccess::new); - TO_CLIENT.register(0x03, SetCompression.class, SetCompression::new); + TO_CLIENT.register(Disconnect.class, Disconnect::new, + generic(0x00)); + TO_CLIENT.register(EncryptionRequest.class, EncryptionRequest::new, + generic(0x01)); + TO_CLIENT.register(ServerLoginSuccess.class, ServerLoginSuccess::new, + generic(0x02)); + TO_CLIENT.register(SetCompression.class, SetCompression::new, + generic(0x03)); } }; - public final ProtocolMappings TO_CLIENT = new ProtocolMappings(ProtocolConstants.Direction.TO_CLIENT, this); - public final ProtocolMappings TO_SERVER = new ProtocolMappings(ProtocolConstants.Direction.TO_SERVER, this); + public final PacketRegistry TO_CLIENT = new PacketRegistry(ProtocolConstants.Direction.TO_CLIENT, this); + public final PacketRegistry TO_SERVER = new PacketRegistry(ProtocolConstants.Direction.TO_SERVER, this); - public static class ProtocolMappings { + public static class PacketRegistry { private final ProtocolConstants.Direction direction; private final StateRegistry state; - private final IntObjectMap> idsToSuppliers = new IntObjectHashMap<>(); - private final Map, Integer> packetClassesToIds = new HashMap<>(); + private final IntObjectMap>> byProtocolVersionToProtocolIds = new IntObjectHashMap<>(); + private final Map, List> idMappers = new HashMap<>(); - public ProtocolMappings(ProtocolConstants.Direction direction, StateRegistry state) { + public PacketRegistry(ProtocolConstants.Direction direction, StateRegistry state) { this.direction = direction; this.state = state; } - public

void register(int id, Class

clazz, Supplier

packetSupplier) { - idsToSuppliers.put(id, packetSupplier); - packetClassesToIds.put(clazz, id); + public

void register(Class

clazz, Supplier

packetSupplier, PacketMapping... mappings) { + if (mappings.length == 0) { + throw new IllegalArgumentException("At least one mapping must be provided."); + } + for (PacketMapping mapping : mappings) { + IntObjectMap> ids = byProtocolVersionToProtocolIds.get(mapping.protocolVersion); + if (ids == null) { + byProtocolVersionToProtocolIds.put(mapping.protocolVersion, ids = new IntObjectHashMap<>()); + } + ids.put(mapping.id, packetSupplier); + } + idMappers.put(clazz, ImmutableList.copyOf(mappings)); } - public MinecraftPacket createPacket(int id) { - Supplier supplier = idsToSuppliers.get(id); + public MinecraftPacket createPacket(int id, int protocolVersion) { + IntObjectMap> bestLookup = null; + for (IntObjectMap.PrimitiveEntry>> entry : byProtocolVersionToProtocolIds.entries()) { + if (entry.key() <= protocolVersion) { + bestLookup = entry.value(); + } + } + if (bestLookup == null) { + return null; + } + Supplier supplier = bestLookup.get(id); if (supplier == null) { return null; } return supplier.get(); } - public int getId(MinecraftPacket packet) { - Integer id = packetClassesToIds.get(packet.getClass()); - if (id == null) { - throw new IllegalArgumentException("Supplied packet " + packet.getClass().getName() + " doesn't have a mapping. Direction " + direction + " State " + state); + public int getId(MinecraftPacket packet, int protocolVersion) { + Preconditions.checkNotNull(packet, "packet"); + + List mappings = idMappers.get(packet.getClass()); + if (mappings == null || mappings.isEmpty()) { + throw new IllegalArgumentException("Supplied packet " + packet.getClass().getName() + + " doesn't have any mappings. Direction " + direction + " State " + state); } - return id; + int useId = -1; + for (PacketMapping mapping : mappings) { + if (mapping.protocolVersion <= protocolVersion) { + useId = mapping.id; + } + } + if (useId == -1) { + throw new IllegalArgumentException("Unable to find a mapping for " + packet.getClass().getName() + + " Version " + protocolVersion + " Direction " + direction + " State " + state); + } + return useId; } } + + public static class PacketMapping { + private final int id; + private final int protocolVersion; + + public PacketMapping(int id, int protocolVersion) { + this.id = id; + this.protocolVersion = protocolVersion; + } + + @Override + public String toString() { + return "PacketMapping{" + + "id=" + id + + ", protocolVersion=" + protocolVersion + + '}'; + } + + @Override + public boolean equals(Object o) { + if (this == o) return true; + if (o == null || getClass() != o.getClass()) return false; + PacketMapping that = (PacketMapping) o; + return id == that.id && + protocolVersion == that.protocolVersion; + } + + @Override + public int hashCode() { + return Objects.hash(id, protocolVersion); + } + } + + private static PacketMapping map(int id, int version) { + return new PacketMapping(id, version); + } + + private static PacketMapping generic(int id) { + return new PacketMapping(id, 0); + } } diff --git a/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java b/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java index 137d85bd..8f0f1dfa 100644 --- a/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java +++ b/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftDecoder.java @@ -28,8 +28,8 @@ public class MinecraftDecoder extends MessageToMessageDecoder { ByteBuf slice = msg.slice().retain(); int packetId = ProtocolUtils.readVarInt(msg); - StateRegistry.ProtocolMappings mappings = direction == ProtocolConstants.Direction.TO_CLIENT ? state.TO_CLIENT : state.TO_SERVER; - MinecraftPacket packet = mappings.createPacket(packetId); + StateRegistry.PacketRegistry mappings = direction == ProtocolConstants.Direction.TO_CLIENT ? state.TO_CLIENT : state.TO_SERVER; + MinecraftPacket packet = mappings.createPacket(packetId, protocolVersion); if (packet == null) { msg.skipBytes(msg.readableBytes()); out.add(new PacketWrapper(null, slice)); diff --git a/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftEncoder.java b/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftEncoder.java index 35130e38..f90662a1 100644 --- a/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftEncoder.java +++ b/src/main/java/com/velocitypowered/proxy/protocol/netty/MinecraftEncoder.java @@ -21,8 +21,8 @@ public class MinecraftEncoder extends MessageToByteEncoder { @Override protected void encode(ChannelHandlerContext ctx, MinecraftPacket msg, ByteBuf out) throws Exception { - StateRegistry.ProtocolMappings mappings = direction == ProtocolConstants.Direction.TO_CLIENT ? state.TO_CLIENT : state.TO_SERVER; - int packetId = mappings.getId(msg); + StateRegistry.PacketRegistry mappings = direction == ProtocolConstants.Direction.TO_CLIENT ? state.TO_CLIENT : state.TO_SERVER; + int packetId = mappings.getId(msg, protocolVersion); ProtocolUtils.writeVarInt(out, packetId); msg.encode(out, direction, protocolVersion); } diff --git a/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java b/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java new file mode 100644 index 00000000..a5c3ff2b --- /dev/null +++ b/src/test/java/com/velocitypowered/proxy/protocol/PacketRegistryTest.java @@ -0,0 +1,48 @@ +package com.velocitypowered.proxy.protocol; + +import com.velocitypowered.proxy.protocol.packets.Handshake; +import org.junit.jupiter.api.Test; + +import static org.junit.jupiter.api.Assertions.*; + +class PacketRegistryTest { + private StateRegistry.PacketRegistry setupRegistry() { + StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry(ProtocolConstants.Direction.TO_CLIENT, StateRegistry.HANDSHAKE); + registry.register(Handshake.class, Handshake::new, new StateRegistry.PacketMapping(0x00, 1)); + return registry; + } + + @Test + void packetRegistryWorks() { + StateRegistry.PacketRegistry registry = setupRegistry(); + MinecraftPacket packet = registry.createPacket(0, 1); + assertNotNull(packet, "Packet was not found in registry"); + assertEquals(Handshake.class, packet.getClass(), "Registry returned wrong class"); + + assertEquals(0, registry.getId(packet, 1), "Registry did not return the correct packet ID"); + } + + @Test + void packetRegistryRevertsToBestOldVersion() { + StateRegistry.PacketRegistry registry = setupRegistry(); + MinecraftPacket packet = registry.createPacket(0, 2); + assertNotNull(packet, "Packet was not found in registry"); + assertEquals(Handshake.class, packet.getClass(), "Registry returned wrong class"); + + assertEquals(0, registry.getId(packet, 2), "Registry did not return the correct packet ID"); + } + + @Test + void packetRegistryDoesntProvideNewPacketsForOld() { + StateRegistry.PacketRegistry registry = setupRegistry(); + assertNull(registry.createPacket(0, 0), "Packet was found in registry despite being too new"); + + assertThrows(IllegalArgumentException.class, () -> registry.getId(new Handshake(), 0), "Registry provided new packets for an old protocol version"); + } + + @Test + void failOnNoMappings() { + StateRegistry.PacketRegistry registry = new StateRegistry.PacketRegistry(ProtocolConstants.Direction.TO_CLIENT, StateRegistry.HANDSHAKE); + assertThrows(IllegalArgumentException.class, () -> registry.register(Handshake.class, Handshake::new)); + } +} \ No newline at end of file