WebSocketService.java
/*
* Copyright ConsenSys AG.
*
* Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on
* an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the
* specific language governing permissions and limitations under the License.
*
* SPDX-License-Identifier: Apache-2.0
*/
package org.hyperledger.besu.ethereum.api.jsonrpc.websocket;
import static org.hyperledger.besu.ethereum.api.jsonrpc.authentication.AuthenticationUtils.truncToken;
import org.hyperledger.besu.ethereum.api.jsonrpc.authentication.AuthenticationService;
import org.hyperledger.besu.ethereum.api.jsonrpc.authentication.AuthenticationUtils;
import org.hyperledger.besu.ethereum.api.jsonrpc.authentication.DefaultAuthenticationService;
import org.hyperledger.besu.ethereum.api.jsonrpc.internal.exception.Logging403ErrorHandler;
import org.hyperledger.besu.ethereum.api.jsonrpc.websocket.subscription.SubscriptionManager;
import org.hyperledger.besu.metrics.BesuMetricCategory;
import org.hyperledger.besu.plugin.services.MetricsSystem;
import java.net.InetSocketAddress;
import java.util.Optional;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
import com.google.common.annotations.VisibleForTesting;
import io.vertx.core.AsyncResult;
import io.vertx.core.Handler;
import io.vertx.core.Vertx;
import io.vertx.core.buffer.Buffer;
import io.vertx.core.http.HttpConnection;
import io.vertx.core.http.HttpServer;
import io.vertx.core.http.HttpServerOptions;
import io.vertx.core.http.HttpServerRequest;
import io.vertx.core.http.HttpServerResponse;
import io.vertx.core.http.ServerWebSocket;
import io.vertx.core.net.HostAndPort;
import io.vertx.core.net.SocketAddress;
import io.vertx.ext.web.Router;
import io.vertx.ext.web.RoutingContext;
import io.vertx.ext.web.handler.BodyHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
 * Serves the JSON-RPC WebSocket endpoint.
 *
 * <p>Accepts WebSocket upgrades on the configured host/port, rejects upgrades whose host is not in
 * the configured allowlist, optionally authenticates every incoming message through an {@link
 * AuthenticationService}, and forwards message payloads to a {@link WebSocketMessageHandler}. Plain
 * HTTP requests are host-checked as well; only {@code POST /login} is served, any other HTTP
 * request receives a 400 response. The number of open connections is capped by {@link
 * WebSocketConfiguration#getMaxActiveConnections()} and exported through the {@code
 * active_ws_connection_count} RPC gauge.
 */
public class WebSocketService {

  private static final Logger LOG = LoggerFactory.getLogger(WebSocketService.class);

  /** Returned by {@link #socketAddress()} while the server is not running. */
  private static final InetSocketAddress EMPTY_SOCKET_ADDRESS = new InetSocketAddress("0.0.0.0", 0);

  private static final String APPLICATION_JSON = "application/json";

  private final int maxActiveConnections;
  // Incremented once per accepted TCP connection (including ones immediately rejected for
  // exceeding the limit) and decremented once when that connection closes.
  private final AtomicInteger activeConnectionsCount = new AtomicInteger();
  private final Vertx vertx;
  private final WebSocketConfiguration configuration;
  private final WebSocketMessageHandler websocketMessageHandler;

  // Set by start(), cleared by a successful stop().
  private HttpServer httpServer;

  @VisibleForTesting public final Optional<AuthenticationService> authenticationService;

  /**
   * Creates the service with the authentication service derived from {@code configuration} via
   * {@link DefaultAuthenticationService#create}.
   *
   * @param vertx the Vert.x instance used to create the server and router
   * @param configuration host, port, limits and allowlist settings
   * @param websocketMessageHandler receives every inbound message payload
   * @param metricsSystem used to register the active-connection gauge
   */
  public WebSocketService(
      final Vertx vertx,
      final WebSocketConfiguration configuration,
      final WebSocketMessageHandler websocketMessageHandler,
      final MetricsSystem metricsSystem) {
    this(
        vertx,
        configuration,
        websocketMessageHandler,
        DefaultAuthenticationService.create(vertx, configuration),
        metricsSystem);
  }

  /**
   * Creates the service with an explicit (possibly empty) authentication service.
   *
   * @param vertx the Vert.x instance used to create the server and router
   * @param configuration host, port, limits and allowlist settings
   * @param websocketMessageHandler receives every inbound message payload
   * @param authenticationService when present, each message is authenticated before dispatch and
   *     {@code POST /login} is handled by it; when empty, messages are dispatched with no user
   * @param metricsSystem used to register the active-connection gauge
   */
  public WebSocketService(
      final Vertx vertx,
      final WebSocketConfiguration configuration,
      final WebSocketMessageHandler websocketMessageHandler,
      final Optional<AuthenticationService> authenticationService,
      final MetricsSystem metricsSystem) {
    this.vertx = vertx;
    this.configuration = configuration;
    this.websocketMessageHandler = websocketMessageHandler;
    this.authenticationService = authenticationService;
    this.maxActiveConnections = configuration.getMaxActiveConnections();

    metricsSystem.createIntegerGauge(
        BesuMetricCategory.RPC,
        "active_ws_connection_count",
        "Total no of active rpc ws connections",
        activeConnectionsCount::intValue);
  }

  /**
   * Creates the HTTP server and starts listening.
   *
   * <p>On success the configuration's port is overwritten with the port actually bound, so callers
   * can discover it afterwards.
   *
   * @return a future completed once the server is listening, or completed exceptionally with the
   *     listen failure
   */
  public CompletableFuture<?> start() {
    LOG.info(
        "Starting Websocket service on {}:{}", configuration.getHost(), configuration.getPort());

    final CompletableFuture<?> resultFuture = new CompletableFuture<>();

    httpServer =
        vertx
            .createHttpServer(
                new HttpServerOptions()
                    .setHost(configuration.getHost())
                    .setPort(configuration.getPort())
                    .setHandle100ContinueAutomatically(true)
                    .setCompressionSupported(true)
                    .addWebSocketSubProtocol("undefined")
                    .setMaxWebSocketFrameSize(configuration.getMaxFrameSize())
                    // A message may span several frames; allow up to four max-size frames.
                    .setMaxWebSocketMessageSize(configuration.getMaxFrameSize() * 4)
                    .setRegisterWebSocketWriteHandlers(true))
            .webSocketHandler(websocketHandler())
            .connectionHandler(connectionHandler())
            .requestHandler(httpHandler())
            .listen(startHandler(resultFuture));

    return resultFuture;
  }

  /**
   * Builds the handler for WebSocket upgrades: checks the host allowlist, then wires message,
   * close and exception handlers onto the socket.
   */
  private Handler<ServerWebSocket> websocketHandler() {
    return websocket -> {
      final SocketAddress socketAddress = websocket.remoteAddress();
      final String connectionId = websocket.textHandlerID();

      final String token = getAuthToken(websocket);
      if (token != null) {
        LOG.atTrace()
            .setMessage("Websocket authentication token {}")
            .addArgument(() -> truncToken(token))
            .log();
      }

      if (!checkHostInAllowlist(
          Optional.ofNullable(websocket.authority()).map(HostAndPort::host))) {
        // Stop here: a rejected upgrade must not be logged as connected nor get message handlers.
        websocket.reject(403);
        return;
      }

      LOG.debug("Websocket Connected ({})", socketAddressAsString(socketAddress));

      final Handler<Buffer> socketHandler =
          buffer -> {
            // Lazy arguments: the payload can be large and is only rendered if debug is enabled.
            LOG.atDebug()
                .setMessage("Received Websocket request (binary frame) {} ({})")
                .addArgument(buffer::toString)
                .addArgument(() -> socketAddressAsString(socketAddress))
                .log();

            if (authenticationService.isPresent()) {
              authenticationService
                  .get()
                  .authenticate(
                      token, user -> websocketMessageHandler.handle(websocket, buffer, user));
            } else {
              websocketMessageHandler.handle(websocket, buffer, Optional.empty());
            }
          };

      // Text frames are converted to a Buffer so both frame types share one dispatch path.
      websocket.textMessageHandler(text -> socketHandler.handle(Buffer.buffer(text)));
      websocket.binaryMessageHandler(socketHandler);

      websocket.closeHandler(
          v -> {
            LOG.debug("Websocket Disconnected ({})", socketAddressAsString(socketAddress));
            // Ask the subscription layer to drop everything registered for this connection.
            vertx
                .eventBus()
                .publish(SubscriptionManager.EVENTBUS_REMOVE_SUBSCRIPTIONS_ADDRESS, connectionId);
          });

      websocket.exceptionHandler(
          t -> {
            LOG.debug(
                "Unrecoverable error on Websocket: {} ({})",
                t.getMessage(),
                socketAddressAsString(socketAddress));
            websocket.close();
          });
    };
  }

  /**
   * Builds the handler that enforces the active-connection limit.
   *
   * <p>Every connection is counted once on arrival and released once by its close handler;
   * connections beyond the limit are closed immediately.
   */
  private Handler<HttpConnection> connectionHandler() {
    return connection -> {
      // Register the release first so the decrement is in place before any close() is requested.
      connection.closeHandler(
          c ->
              LOG.debug(
                  "Connection closed from {}. Total of active connections: {}/{}",
                  connection.remoteAddress(),
                  activeConnectionsCount.decrementAndGet(),
                  maxActiveConnections));

      // Reserve the slot with a single atomic read-modify-write instead of check-then-increment.
      final int activeConnections = activeConnectionsCount.incrementAndGet();
      if (activeConnections > maxActiveConnections) {
        // disallow new connections to prevent DoS
        LOG.warn(
            "Rejecting new connection from {}. {}/{} max active connections limit reached.",
            connection.remoteAddress(),
            activeConnections - 1,
            maxActiveConnections);
        connection.close();
      } else {
        LOG.debug(
            "Opened connection from {}. Total of active connections: {}/{}",
            connection.remoteAddress(),
            activeConnections,
            maxActiveConnections);
      }
    };
  }

  /**
   * Builds the router for plain HTTP requests: host allowlist check first, then {@code POST
   * /login}, and a 400 for everything else.
   */
  private Handler<HttpServerRequest> httpHandler() {
    final Router router = Router.router(vertx);

    // Verify Host header to avoid rebind attack.
    router.route().handler(checkAllowlistHostHeader());

    if (authenticationService.isPresent()) {
      router.route("/login").handler(BodyHandler.create());
      router
          .post("/login")
          .produces(APPLICATION_JSON)
          .handler(authenticationService.get()::handleLogin);
    } else {
      router
          .post("/login")
          .produces(APPLICATION_JSON)
          .handler(DefaultAuthenticationService::handleDisabledLogin);
    }

    router.errorHandler(403, new Logging403ErrorHandler());
    router.route().handler(WebSocketService::handleHttpNotSupported);
    return router;
  }

  /** Answers any non-login HTTP request with a 400, unless the response is already closed. */
  private static void handleHttpNotSupported(final RoutingContext http) {
    final HttpServerResponse response = http.response();
    if (!response.closed()) {
      response.setStatusCode(400).end("Websocket endpoint can't handle HTTP requests");
    }
  }

  /**
   * Completes {@code resultFuture} from the listen result and records the bound port in the
   * configuration on success.
   */
  private Handler<AsyncResult<HttpServer>> startHandler(final CompletableFuture<?> resultFuture) {
    return res -> {
      if (res.succeeded()) {
        final int actualPort = res.result().actualPort();
        LOG.info(
            "Websocket service started and listening on {}:{}",
            configuration.getHost(),
            actualPort);
        configuration.setPort(actualPort);
        resultFuture.complete(null);
      } else {
        resultFuture.completeExceptionally(res.cause());
      }
    };
  }

  /**
   * Closes the HTTP server if it was started.
   *
   * @return a future completed when the server has closed (immediately if it was never started),
   *     or completed exceptionally with the close failure
   */
  public CompletableFuture<?> stop() {
    if (httpServer == null) {
      return CompletableFuture.completedFuture(null);
    }

    final CompletableFuture<?> resultFuture = new CompletableFuture<>();
    httpServer.close(
        res -> {
          if (res.succeeded()) {
            httpServer = null;
            resultFuture.complete(null);
          } else {
            resultFuture.completeExceptionally(res.cause());
          }
        });
    return resultFuture;
  }

  /**
   * Returns the address the server listens on.
   *
   * @return the configured host with the actually bound port, or {@code 0.0.0.0:0} when no server
   *     is running
   */
  public InetSocketAddress socketAddress() {
    if (httpServer == null) {
      return EMPTY_SOCKET_ADDRESS;
    }
    return new InetSocketAddress(configuration.getHost(), httpServer.actualPort());
  }

  /** Formats a remote address for log messages. */
  private String socketAddressAsString(final SocketAddress socketAddress) {
    return String.format("host=%s, port=%d", socketAddress.host(), socketAddress.port());
  }

  /** Extracts the JWT from the upgrade request's {@code Authorization} header; may be null. */
  private String getAuthToken(final ServerWebSocket websocket) {
    return AuthenticationUtils.getJwtTokenFromAuthorizationHeaderValue(
        websocket.headers().get("Authorization"));
  }

  /**
   * Builds a route handler that passes requests whose host is allowlisted and answers all others
   * with a 403 JSON body.
   */
  private Handler<RoutingContext> checkAllowlistHostHeader() {
    return event -> {
      if (checkHostInAllowlist(
          Optional.ofNullable(event.request().authority()).map(HostAndPort::host))) {
        event.next();
      } else {
        final HttpServerResponse response = event.response();
        if (!response.closed()) {
          response
              .setStatusCode(403)
              .putHeader("Content-Type", "application/json; charset=utf-8")
              .end("{\"message\":\"Host not authorized.\"}");
        }
      }
    };
  }

  /**
   * Checks a request host against the configured allowlist.
   *
   * @param host the host part of the request authority, if any
   * @return true if the allowlist contains {@code "*"}, or if the host is present and equals an
   *     allowlist entry ignoring case; false otherwise (including an absent host)
   */
  @VisibleForTesting
  boolean checkHostInAllowlist(final Optional<String> host) {
    return configuration.getHostsAllowlist().contains("*")
        || host.map(
                header ->
                    configuration.getHostsAllowlist().stream()
                        .anyMatch(allowListEntry -> allowListEntry.equalsIgnoreCase(header)))
            .orElse(false);
  }
}