CapabilityMultiplexer.java
/*
* Copyright ConsenSys AG.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.hyperledger.besu.ethereum.p2p.rlpx.wire;
import static java.util.Comparator.comparing;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableRangeMap;
import com.google.common.collect.Range;
import org.apache.tuweni.bytes.Bytes;
/**
 * Translates message codes between the code space of an individual capability and the single,
 * shared code space used for messages exchanged with a peer.
 *
 * <p>The first {@link #WIRE_PROTOCOL_MESSAGE_SPACE} codes are reserved for the wire protocol.
 * After that, every agreed capability is assigned a contiguous range of codes whose size is
 * reported by the matching {@link SubProtocol}. Ranges are laid out in capability-name order.
 */
public class CapabilityMultiplexer {

  /** Number of message codes reserved for the wire protocol before any capability range. */
  static final int WIRE_PROTOCOL_MESSAGE_SPACE = 16;

  /** Orders capabilities by name, and within the same name by version, highest version first. */
  private static final Comparator<Capability> CAPABILITY_COMPARATOR =
      comparing(Capability::getName).thenComparing(comparing(Capability::getVersion).reversed());

  /** Shared message-code range of each agreed capability. */
  private final ImmutableRangeMap<Integer, Capability> agreedCaps;

  /** First shared message code (the offset) of each agreed capability. */
  private final ImmutableMap<Capability, Integer> capabilityOffsets;

  /** Known sub-protocols, keyed by sub-protocol name. */
  private final Map<String, SubProtocol> subProtocols = new HashMap<>();

  /**
   * Creates a multiplexer for the capabilities present in both supplied lists.
   *
   * @param subProtocols The sub-protocols used to size the message space of each capability.
   * @param a The first list of capabilities.
   * @param b The second list of capabilities.
   */
  public CapabilityMultiplexer(
      final List<SubProtocol> subProtocols, final List<Capability> a, final List<Capability> b) {
    // The sub-protocol lookup must be populated first: calculateAgreedCapabilities reads it.
    for (final SubProtocol subProtocol : subProtocols) {
      this.subProtocols.put(subProtocol.getName(), subProtocol);
    }
    agreedCaps = calculateAgreedCapabilities(a, b);
    capabilityOffsets = calculateCapabilityOffsets(agreedCaps);
  }

  /**
   * Returns the capabilities that were assigned a message-code range.
   *
   * @return The set of agreed capabilities.
   */
  public Set<Capability> getAgreedCapabilities() {
    return capabilityOffsets.keySet();
  }

  /**
   * Looks up the sub-protocol registered under the name of the given capability.
   *
   * @param cap The capability whose sub-protocol is requested.
   * @return The matching sub-protocol, or {@code null} if none is registered under that name.
   */
  public SubProtocol subProtocol(final Capability cap) {
    return subProtocols.get(cap.getName());
  }

  /**
   * Prepares a message to send by offsetting its code based on the agreed capabilities.
   *
   * @param cap The capability (protocol) associated with this message. When {@code null}, the
   *     message code is left unchanged.
   * @param messageToSend The message to send.
   * @return Returns message with the correctly offset code.
   * @throws IllegalArgumentException If {@code cap} is non-null but is not one of the agreed
   *     capabilities, so no offset exists for it.
   */
  public MessageData multiplex(final Capability cap, final MessageData messageToSend) {
    if (cap == null) {
      return offsetMessageCode(messageToSend, 0);
    }
    // Look the offset up as an Integer: unboxing it directly would fail with a message-less
    // NullPointerException when the capability has no assigned range.
    final Integer offset = capabilityOffsets.get(cap);
    if (offset == null) {
      throw new IllegalArgumentException(
          "No message offset available for capability "
              + cap
              + "; it is not among the agreed capabilities "
              + capabilityOffsets.keySet());
    }
    return offsetMessageCode(messageToSend, offset);
  }

  /**
   * Given a message from a peer, determine which capability the message corresponds to and maps
   * the message code to the appropriate value.
   *
   * @param receivedMessage The message received from a peer.
   * @return The interpreted message. If the code falls outside every agreed capability range, the
   *     result carries a {@code null} capability and the received message unmodified.
   */
  public ProtocolMessage demultiplex(final MessageData receivedMessage) {
    final Map.Entry<Range<Integer>, Capability> agreedCap =
        agreedCaps.getEntry(receivedMessage.getCode());
    if (agreedCap == null) {
      return new ProtocolMessage(null, receivedMessage);
    }
    // Subtract the start of the capability's range to recover the capability-local code.
    final int offset = -agreedCap.getKey().lowerEndpoint();
    final Capability cap = agreedCap.getValue();
    final MessageData demultiplexedMessage = offsetMessageCode(receivedMessage, offset);
    return new ProtocolMessage(cap, demultiplexedMessage);
  }

  /**
   * Wraps a message so that its code is shifted by {@code offset}; size and data are delegated to
   * the original message.
   *
   * <p>This method is static so that the returned wrapper does not hold a reference to the
   * multiplexer.
   *
   * @param originalMessage The message to wrap.
   * @param offset The value added to the original message code (may be negative).
   * @return A view of the original message reporting the shifted code.
   */
  private static MessageData offsetMessageCode(
      final MessageData originalMessage, final int offset) {
    // Return wrapped message with modified offset
    return new MessageData() {
      @Override
      public int getSize() {
        return originalMessage.getSize();
      }

      @Override
      public int getCode() {
        return originalMessage.getCode() + offset;
      }

      @Override
      public Bytes getData() {
        return originalMessage.getData();
      }

      @Override
      public String toString() {
        return "Message{ code=" + getCode() + ", size=" + getSize() + "}";
      }
    };
  }

  /**
   * Computes the shared message-code range of every capability present in both lists.
   *
   * <p>Only the first capability encountered for each protocol name receives a range. Because of
   * the sort order, that is the highest version of that name.
   *
   * @param a The first list of capabilities.
   * @param b The second list of capabilities.
   * @return The ranges of shared message codes mapped to the capability owning each range.
   */
  private ImmutableRangeMap<Integer, Capability> calculateAgreedCapabilities(
      final List<Capability> a, final List<Capability> b) {
    final List<Capability> caps = new ArrayList<>(a);
    caps.sort(CAPABILITY_COMPARATOR);
    caps.retainAll(b);

    final ImmutableRangeMap.Builder<Integer, Capability> builder = ImmutableRangeMap.builder();
    // Reserve some messages for WireProtocol
    int offset = WIRE_PROTOCOL_MESSAGE_SPACE;
    String prevProtocol = null;
    for (final Capability cap : caps) {
      final String curProtocol = cap.getName();
      if (curProtocol.equalsIgnoreCase(prevProtocol)) {
        // A later version of this protocol is already being used, so ignore this version
        continue;
      }
      prevProtocol = curProtocol;
      final SubProtocol subProtocol = subProtocols.get(cap.getName());
      // A capability without a registered sub-protocol occupies no message space.
      final int messageSpace =
          subProtocol == null ? 0 : subProtocol.messageSpace(cap.getVersion());
      if (messageSpace > 0) {
        builder.put(Range.closedOpen(offset, offset + messageSpace), cap);
      }
      offset += messageSpace;
    }
    return builder.build();
  }

  /**
   * Builds the capability-to-offset lookup from the agreed ranges.
   *
   * @param agreedCaps The shared message-code range of each agreed capability.
   * @return Each agreed capability mapped to the lower endpoint of its range.
   */
  private static ImmutableMap<Capability, Integer> calculateCapabilityOffsets(
      final ImmutableRangeMap<Integer, Capability> agreedCaps) {
    final ImmutableMap.Builder<Capability, Integer> capToOffset = ImmutableMap.builder();
    for (final Map.Entry<Range<Integer>, Capability> entry :
        agreedCaps.asMapOfRanges().entrySet()) {
      capToOffset.put(entry.getValue(), entry.getKey().lowerEndpoint());
    }
    return capToOffset.build();
  }

  /** A received message paired with the capability its code was resolved to. */
  public static class ProtocolMessage {
    private final Capability capability;
    private final MessageData message;

    /**
     * @param capability The capability the message belongs to; {@code null} when the message code
     *     matched no agreed capability range.
     * @param message The message, with its code already mapped for that capability.
     */
    ProtocolMessage(final Capability capability, final MessageData message) {
      this.capability = capability;
      this.message = message;
    }

    /**
     * @return The capability the message belongs to, or {@code null} if its code matched no agreed
     *     capability range.
     */
    public Capability getCapability() {
      return capability;
    }

    /**
     * @return The message held by this instance.
     */
    public MessageData getMessage() {
      return message;
    }

    @Override
    public String toString() {
      return "ProtocolMessage{"
          + "capability="
          + capability
          + ", messageCode="
          + message.getCode()
          + '}';
    }
  }
}