Don't search through protocol versions all the time

This commit is contained in:
kashike
2018-07-27 20:13:23 -07:00
parent 39def54e96
commit b95f076562
8 changed files with 100 additions and 102 deletions

View File

@@ -96,8 +96,8 @@ public final class ConnectionManager {
.addLast(FRAME_DECODER, new MinecraftVarintFrameDecoder())
.addLast(LEGACY_PING_ENCODER, LegacyPingEncoder.INSTANCE)
.addLast(FRAME_ENCODER, MinecraftVarintLengthEncoder.INSTANCE)
.addLast(MINECRAFT_DECODER, new MinecraftDecoder(ProtocolConstants.Direction.TO_SERVER))
.addLast(MINECRAFT_ENCODER, new MinecraftEncoder(ProtocolConstants.Direction.TO_CLIENT));
.addLast(MINECRAFT_DECODER, new MinecraftDecoder(ProtocolConstants.Direction.SERVERBOUND))
.addLast(MINECRAFT_ENCODER, new MinecraftEncoder(ProtocolConstants.Direction.CLIENTBOUND));
final MinecraftConnection connection = new MinecraftConnection(ch);
connection.setState(StateRegistry.HANDSHAKE);

View File

@@ -48,8 +48,8 @@ public class ServerConnection implements MinecraftConnectionAssociation {
.addLast(READ_TIMEOUT, new ReadTimeoutHandler(SERVER_READ_TIMEOUT_SECONDS, TimeUnit.SECONDS))
.addLast(FRAME_DECODER, new MinecraftVarintFrameDecoder())
.addLast(FRAME_ENCODER, MinecraftVarintLengthEncoder.INSTANCE)
.addLast(MINECRAFT_DECODER, new MinecraftDecoder(ProtocolConstants.Direction.TO_CLIENT))
.addLast(MINECRAFT_ENCODER, new MinecraftEncoder(ProtocolConstants.Direction.TO_SERVER));
.addLast(MINECRAFT_DECODER, new MinecraftDecoder(ProtocolConstants.Direction.CLIENTBOUND))
.addLast(MINECRAFT_ENCODER, new MinecraftEncoder(ProtocolConstants.Direction.SERVERBOUND));
MinecraftConnection connection = new MinecraftConnection(ch);
connection.setState(StateRegistry.HANDSHAKE);

View File

@@ -14,7 +14,7 @@ public enum ProtocolConstants { ;
}
public enum Direction {
TO_SERVER,
TO_CLIENT
SERVERBOUND,
CLIENTBOUND
}
}

View File

@@ -1,13 +1,10 @@
package com.velocitypowered.proxy.protocol;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableList;
import com.velocitypowered.proxy.protocol.packets.*;
import io.netty.util.collection.IntObjectHashMap;
import io.netty.util.collection.IntObjectMap;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Supplier;
@@ -17,124 +14,127 @@ import static com.velocitypowered.proxy.protocol.ProtocolConstants.MINECRAFT_1_1
public enum StateRegistry {
HANDSHAKE {
{
TO_SERVER.register(Handshake.class, Handshake::new,
SERVERBOUND.register(Handshake.class, Handshake::new,
generic(0x00));
}
},
STATUS {
{
TO_SERVER.register(StatusRequest.class, StatusRequest::new,
SERVERBOUND.register(StatusRequest.class, StatusRequest::new,
generic(0x00));
TO_SERVER.register(Ping.class, Ping::new,
SERVERBOUND.register(Ping.class, Ping::new,
generic(0x01));
TO_CLIENT.register(StatusResponse.class, StatusResponse::new,
CLIENTBOUND.register(StatusResponse.class, StatusResponse::new,
generic(0x00));
TO_CLIENT.register(Ping.class, Ping::new,
CLIENTBOUND.register(Ping.class, Ping::new,
generic(0x01));
}
},
PLAY {
{
TO_SERVER.register(Chat.class, Chat::new,
SERVERBOUND.register(Chat.class, Chat::new,
map(0x02, MINECRAFT_1_12));
TO_SERVER.register(Ping.class, Ping::new,
SERVERBOUND.register(Ping.class, Ping::new,
map(0x0b, MINECRAFT_1_12));
TO_CLIENT.register(Chat.class, Chat::new,
CLIENTBOUND.register(Chat.class, Chat::new,
map(0x0F, MINECRAFT_1_12));
TO_CLIENT.register(Disconnect.class, Disconnect::new,
CLIENTBOUND.register(Disconnect.class, Disconnect::new,
map(0x1A, MINECRAFT_1_12));
TO_CLIENT.register(Ping.class, Ping::new,
CLIENTBOUND.register(Ping.class, Ping::new,
map(0x1F, MINECRAFT_1_12));
TO_CLIENT.register(JoinGame.class, JoinGame::new,
CLIENTBOUND.register(JoinGame.class, JoinGame::new,
map(0x23, MINECRAFT_1_12));
TO_CLIENT.register(Respawn.class, Respawn::new,
CLIENTBOUND.register(Respawn.class, Respawn::new,
map(0x35, MINECRAFT_1_12));
}
},
LOGIN {
{
TO_SERVER.register(ServerLogin.class, ServerLogin::new,
SERVERBOUND.register(ServerLogin.class, ServerLogin::new,
generic(0x00));
TO_SERVER.register(EncryptionResponse.class, EncryptionResponse::new,
SERVERBOUND.register(EncryptionResponse.class, EncryptionResponse::new,
generic(0x01));
TO_CLIENT.register(Disconnect.class, Disconnect::new,
CLIENTBOUND.register(Disconnect.class, Disconnect::new,
generic(0x00));
TO_CLIENT.register(EncryptionRequest.class, EncryptionRequest::new,
CLIENTBOUND.register(EncryptionRequest.class, EncryptionRequest::new,
generic(0x01));
TO_CLIENT.register(ServerLoginSuccess.class, ServerLoginSuccess::new,
CLIENTBOUND.register(ServerLoginSuccess.class, ServerLoginSuccess::new,
generic(0x02));
TO_CLIENT.register(SetCompression.class, SetCompression::new,
CLIENTBOUND.register(SetCompression.class, SetCompression::new,
generic(0x03));
}
};
public final PacketRegistry TO_CLIENT = new PacketRegistry(ProtocolConstants.Direction.TO_CLIENT, this);
public final PacketRegistry TO_SERVER = new PacketRegistry(ProtocolConstants.Direction.TO_SERVER, this);
public final PacketRegistry CLIENTBOUND = new PacketRegistry(ProtocolConstants.Direction.CLIENTBOUND);
public final PacketRegistry SERVERBOUND = new PacketRegistry(ProtocolConstants.Direction.SERVERBOUND);
public static class PacketRegistry {
private final ProtocolConstants.Direction direction;
private final StateRegistry state;
private final IntObjectMap<IntObjectMap<Supplier<? extends MinecraftPacket>>> byProtocolVersionToProtocolIds = new IntObjectHashMap<>();
private final Map<Class<? extends MinecraftPacket>, List<PacketMapping>> idMappers = new HashMap<>();
private final IntObjectMap<ProtocolVersion> versions = new IntObjectHashMap<>();
public PacketRegistry(ProtocolConstants.Direction direction, StateRegistry state) {
public PacketRegistry(ProtocolConstants.Direction direction) {
this.direction = direction;
this.state = state;
}
public ProtocolVersion getVersion(final int version) {
ProtocolVersion result = null;
for (final IntObjectMap.PrimitiveEntry<ProtocolVersion> entry : this.versions.entries()) {
if (entry.key() <= version) {
result = entry.value();
}
}
if (result == null) {
throw new IllegalArgumentException("Could not find data for protocol version " + version);
}
return result;
}
public <P extends MinecraftPacket> void register(Class<P> clazz, Supplier<P> packetSupplier, PacketMapping... mappings) {
if (mappings.length == 0) {
throw new IllegalArgumentException("At least one mapping must be provided.");
}
for (PacketMapping mapping : mappings) {
IntObjectMap<Supplier<? extends MinecraftPacket>> ids = byProtocolVersionToProtocolIds.get(mapping.protocolVersion);
if (ids == null) {
byProtocolVersionToProtocolIds.put(mapping.protocolVersion, ids = new IntObjectHashMap<>());
for (final PacketMapping mapping : mappings) {
ProtocolVersion version = this.versions.get(mapping.protocolVersion);
if (version == null) {
version = new ProtocolVersion(mapping.protocolVersion);
this.versions.put(mapping.protocolVersion, version);
}
ids.put(mapping.id, packetSupplier);
version.packetIdToSupplier.put(mapping.id, packetSupplier);
version.packetClassToId.put(clazz, mapping.id);
}
idMappers.put(clazz, ImmutableList.copyOf(mappings));
}
public MinecraftPacket createPacket(int id, int protocolVersion) {
IntObjectMap<Supplier<? extends MinecraftPacket>> bestLookup = null;
for (IntObjectMap.PrimitiveEntry<IntObjectMap<Supplier<? extends MinecraftPacket>>> entry : byProtocolVersionToProtocolIds.entries()) {
if (entry.key() <= protocolVersion) {
bestLookup = entry.value();
}
}
if (bestLookup == null) {
return null;
}
Supplier<? extends MinecraftPacket> supplier = bestLookup.get(id);
if (supplier == null) {
return null;
}
return supplier.get();
}
public class ProtocolVersion {
public final int id;
final IntObjectMap<Supplier<? extends MinecraftPacket>> packetIdToSupplier = new IntObjectHashMap<>();
final Map<Class<? extends MinecraftPacket>, Integer> packetClassToId = new HashMap<>();
public int getId(MinecraftPacket packet, int protocolVersion) {
Preconditions.checkNotNull(packet, "packet");
ProtocolVersion(final int id) {
this.id = id;
}
List<PacketMapping> mappings = idMappers.get(packet.getClass());
if (mappings == null || mappings.isEmpty()) {
throw new IllegalArgumentException("Supplied packet " + packet.getClass().getName() +
" doesn't have any mappings. Direction " + direction + " State " + state);
}
int useId = -1;
for (PacketMapping mapping : mappings) {
if (mapping.protocolVersion <= protocolVersion) {
useId = mapping.id;
public MinecraftPacket createPacket(final int id) {
final Supplier<? extends MinecraftPacket> supplier = this.packetIdToSupplier.get(id);
if (supplier == null) {
return null;
}
return supplier.get();
}
if (useId == -1) {
throw new IllegalArgumentException("Unable to find a mapping for " + packet.getClass().getName()
+ " Version " + protocolVersion + " Direction " + direction + " State " + state);
public int getPacketId(final MinecraftPacket packet) {
final Integer id = this.packetClassToId.get(packet.getClass());
if (id == null) {
throw new IllegalArgumentException(String.format(
"Unable to find id for packet of type %s in %s protocol %s",
packet.getClass().getName(), PacketRegistry.this.direction, this.id
));
}
return id;
}
return useId;
}
}

View File

@@ -12,7 +12,7 @@ import java.util.List;
public class MinecraftDecoder extends MessageToMessageDecoder<ByteBuf> {
private StateRegistry state;
private final ProtocolConstants.Direction direction;
private int protocolVersion;
private StateRegistry.PacketRegistry.ProtocolVersion protocolVersion;
public MinecraftDecoder(ProtocolConstants.Direction direction) {
this.state = StateRegistry.HANDSHAKE;
@@ -28,14 +28,13 @@ public class MinecraftDecoder extends MessageToMessageDecoder<ByteBuf> {
ByteBuf slice = msg.slice().retain();
int packetId = ProtocolUtils.readVarInt(msg);
StateRegistry.PacketRegistry mappings = direction == ProtocolConstants.Direction.TO_CLIENT ? state.TO_CLIENT : state.TO_SERVER;
MinecraftPacket packet = mappings.createPacket(packetId, protocolVersion);
MinecraftPacket packet = this.protocolVersion.createPacket(packetId);
if (packet == null) {
msg.skipBytes(msg.readableBytes());
out.add(new PacketWrapper(null, slice));
} else {
try {
packet.decode(msg, direction, protocolVersion);
packet.decode(msg, direction, protocolVersion.id);
} catch (Exception e) {
throw new CorruptedFrameException("Error decoding " + packet.getClass() + " Direction " + direction
+ " Protocol " + protocolVersion + " State " + state + " ID " + Integer.toHexString(packetId), e);
@@ -44,12 +43,12 @@ public class MinecraftDecoder extends MessageToMessageDecoder<ByteBuf> {
}
}
public int getProtocolVersion() {
public StateRegistry.PacketRegistry.ProtocolVersion getProtocolVersion() {
return protocolVersion;
}
public void setProtocolVersion(int protocolVersion) {
this.protocolVersion = protocolVersion;
this.protocolVersion = (this.direction == ProtocolConstants.Direction.CLIENTBOUND ? this.state.CLIENTBOUND : this.state.SERVERBOUND).getVersion(protocolVersion);
}
public StateRegistry getState() {

View File

@@ -12,7 +12,7 @@ import io.netty.handler.codec.MessageToByteEncoder;
public class MinecraftEncoder extends MessageToByteEncoder<MinecraftPacket> {
private StateRegistry state;
private final ProtocolConstants.Direction direction;
private int protocolVersion;
private StateRegistry.PacketRegistry.ProtocolVersion protocolVersion;
public MinecraftEncoder(ProtocolConstants.Direction direction) {
this.state = StateRegistry.HANDSHAKE;
@@ -20,19 +20,18 @@ public class MinecraftEncoder extends MessageToByteEncoder<MinecraftPacket> {
}
@Override
protected void encode(ChannelHandlerContext ctx, MinecraftPacket msg, ByteBuf out) throws Exception {
StateRegistry.PacketRegistry mappings = direction == ProtocolConstants.Direction.TO_CLIENT ? state.TO_CLIENT : state.TO_SERVER;
int packetId = mappings.getId(msg, protocolVersion);
protected void encode(ChannelHandlerContext ctx, MinecraftPacket msg, ByteBuf out) {
int packetId = this.protocolVersion.getPacketId(msg);
ProtocolUtils.writeVarInt(out, packetId);
msg.encode(out, direction, protocolVersion);
msg.encode(out, direction, protocolVersion.id);
}
public int getProtocolVersion() {
public StateRegistry.PacketRegistry.ProtocolVersion getProtocolVersion() {
return protocolVersion;
}
public void setProtocolVersion(int protocolVersion) {
this.protocolVersion = protocolVersion;
public void setProtocolVersion(final int protocolVersion) {
this.protocolVersion = (this.direction == ProtocolConstants.Direction.CLIENTBOUND ? this.state.CLIENTBOUND : this.state.SERVERBOUND).getVersion(protocolVersion);
}
public StateRegistry getState() {

View File

@@ -47,7 +47,7 @@ public class Chat implements MinecraftPacket {
@Override
public void decode(ByteBuf buf, ProtocolConstants.Direction direction, int protocolVersion) {
message = ProtocolUtils.readString(buf);
if (direction == ProtocolConstants.Direction.TO_CLIENT) {
if (direction == ProtocolConstants.Direction.CLIENTBOUND) {
position = buf.readByte();
}
}
@@ -55,7 +55,7 @@ public class Chat implements MinecraftPacket {
@Override
public void encode(ByteBuf buf, ProtocolConstants.Direction direction, int protocolVersion) {
ProtocolUtils.writeString(buf, message);
if (direction == ProtocolConstants.Direction.TO_CLIENT) {
if (direction == ProtocolConstants.Direction.CLIENTBOUND) {
buf.writeByte(position);
}
}