TCP Proxy


TCP Proxy in Java

I’ve been using this as a general-purpose traffic snooper when I need custom formatting of the captured data, which is too cumbersome to do in Wireshark (formerly Ethereal).

package com.gent00.proxy;

import java.io.Closeable;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.net.Socket;
import java.nio.charset.StandardCharsets;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.RejectedExecutionException;

/**
 * A logging TCP relay. Every connection accepted on {@code listenPort} is forwarded, in both
 * directions, to {@code targetHost:targetPort}, and each forwarded chunk is printed to
 * {@code System.out}.
 *
 * <p>{@link #start()} runs the accept loop on the calling thread. {@link #stop()} / {@link #close()}
 * may be called from another thread to close the listener and every active connection.
 */
public class TcpProxy implements AutoCloseable {
    private static final DateTimeFormatter TS = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss.SSS");

    /** Bytes read per chunk in each relay direction. */
    private static final int BUFFER_SIZE = 8192;

    private final String targetHost;
    private final int targetPort;
    private final int listenPort;

    private final ExecutorService executor = Executors.newCachedThreadPool();

    /**
     * Sockets currently owned by a relay. {@link #stop()} closes them, because closing the socket
     * is the only way to wake a thread blocked in a socket read (interrupts do not affect it).
     */
    private final Set<Socket> openSockets = ConcurrentHashMap.newKeySet();

    private volatile boolean running;
    // volatile: written by start() and read by stop(), which normally runs on another thread.
    private volatile ServerSocket serverSocket;

    /**
     * Creates a proxy. Nothing is bound until {@link #start()} is called.
     *
     * @param listenPort local port on which client connections are accepted
     * @param targetHost host that every accepted connection is forwarded to
     * @param targetPort port on {@code targetHost} to connect to
     */
    public TcpProxy(int listenPort, String targetHost, int targetPort) {
        this.listenPort = listenPort;
        this.targetHost = targetHost;
        this.targetPort = targetPort;
    }

    /**
     * Binds the listening socket and runs the accept loop on the calling thread until
     * {@link #stop()} is called. Returns immediately if the proxy is already running.
     *
     * @throws IOException if the port cannot be bound, or if {@code accept()} fails while the
     *     proxy is still running
     */
    public void start() throws IOException {
        if (running) {
            return;
        }

        ServerSocket listener = new ServerSocket();
        try {
            listener.bind(new InetSocketAddress(listenPort));
        } catch (IOException e) {
            // Don't leak the unbound socket when, e.g., the port is already in use.
            closeQuietly(listener);
            throw e;
        }
        serverSocket = listener;
        running = true;
        log("Listening on port %d, forwarding to %s:%d", listenPort, targetHost, targetPort);

        while (running) {
            Socket client;
            try {
                client = listener.accept();
            } catch (IOException e) {
                if (running) {
                    throw e;
                }
                // stop() closed the listener; the loop condition ends the loop.
                continue;
            }
            log("Accepted client %s", client.getRemoteSocketAddress());
            try {
                executor.submit(() -> handleClient(client));
            } catch (RejectedExecutionException e) {
                // stop() shut the executor down between accept() and submit().
                closeQuietly(client);
            }
        }
    }

    /**
     * Connects to the target on behalf of {@code client} and relays traffic both ways until both
     * directions have finished, then closes both sockets.
     *
     * @param client freshly accepted client connection; owned (and closed) by this method
     */
    private void handleClient(Socket client) {
        String clientAddress = String.valueOf(client.getRemoteSocketAddress());
        track(client);
        try (client;
             Socket target = new Socket(targetHost, targetPort)) {
            track(target);
            try {
                log("Connected %s -> %s:%d", clientAddress, targetHost, targetPort);

                Future<?> upstream =
                        executor.submit(() -> pipe(client, target, "client->target", clientAddress));
                pipe(target, client, "target->client", clientAddress);
                // Wait for the other direction before try-with-resources closes the sockets;
                // closing now would cut off data the client may still be sending.
                awaitRelay(upstream, clientAddress);
            } finally {
                openSockets.remove(target);
            }
        } catch (IOException e) {
            log("Connection failed for %s: %s", clientAddress, e.getMessage());
        } finally {
            openSockets.remove(client);
            log("Closed connection for %s", clientAddress);
        }
    }

    /**
     * Blocks until {@code relay} completes.
     *
     * @param relay         the background relay task to wait for
     * @param clientAddress client address, used only for logging
     */
    private void awaitRelay(Future<?> relay, String clientAddress) {
        try {
            relay.get();
        } catch (InterruptedException e) {
            // stop() interrupts pool threads: keep the flag set and let the caller close the sockets.
            Thread.currentThread().interrupt();
        } catch (ExecutionException e) {
            log("client->target failed for %s: %s", clientAddress, e.getCause());
        }
    }

    /**
     * Copies bytes from {@code from} to {@code to} until end-of-stream, an I/O error, or shutdown,
     * logging each chunk.
     *
     * <p>On a clean end-of-stream only the output half of {@code to} is shut down. The peer sees
     * EOF, but the opposite direction can keep delivering data (TCP half-close). The streams are
     * deliberately NOT closed: closing a socket stream closes the whole socket. On an I/O error,
     * both sockets are closed so the opposite relay, which may be blocked in a read, terminates too.
     *
     * @param from          socket to read from
     * @param to            socket to write to
     * @param direction     label used in log lines, e.g. {@code "client->target"}
     * @param clientAddress client address, used only for logging
     */
    private void pipe(Socket from, Socket to, String direction, String clientAddress) {
        try {
            InputStream in = from.getInputStream();
            OutputStream out = to.getOutputStream();

            byte[] buffer = new byte[BUFFER_SIZE];
            int read;
            // Check `running` before reading, so a chunk is never read and then discarded.
            while (running && (read = in.read(buffer)) != -1) {
                out.write(buffer, 0, read);
                out.flush();
                log("%s %s bytes for %s", direction, read, clientAddress);
                // NOTE: each chunk is decoded independently, so a multi-byte UTF-8 character split
                // across two reads, or non-text payloads, print as replacement characters.
                String payload = new String(buffer, 0, read, StandardCharsets.UTF_8);
                log("%s payload: %s", direction, payload);
            }
            to.shutdownOutput();
        } catch (IOException e) {
            log("%s ended for %s: %s", direction, clientAddress, e.getMessage());
            closeQuietly(from);
            closeQuietly(to);
        }
    }

    /**
     * Registers {@code socket} so {@link #stop()} can close it. If the proxy has already been
     * stopped (stop() may have swept the set before this call), the socket is closed at once.
     */
    private void track(Socket socket) {
        openSockets.add(socket);
        if (!running) {
            closeQuietly(socket);
        }
    }

    /** Closes {@code resource}, ignoring failures: used only for best-effort cleanup. */
    private static void closeQuietly(Closeable resource) {
        try {
            resource.close();
        } catch (IOException ignored) {
            // Nothing useful to do if closing itself fails.
        }
    }

    /** Prints a timestamped, {@link String#format}-style message to {@code System.out}. */
    private void log(String format, Object... args) {
        System.out.printf("[%s] %s%n", LocalDateTime.now().format(TS), String.format(format, args));
    }

    /**
     * Stops the proxy. It closes the listener, closes every active relay socket (which unblocks
     * threads stuck in socket reads), and shuts down the worker pool. Safe to call more than once.
     */
    public void stop() {
        running = false;
        ServerSocket listener = serverSocket;
        if (listener != null) {
            closeQuietly(listener);
        }
        for (Socket socket : openSockets) {
            closeQuietly(socket);
        }
        executor.shutdownNow();
        log("Proxy stopped");
    }

    /** Equivalent to {@link #stop()}, so the proxy can be used in try-with-resources. */
    @Override
    public void close() {
        stop();
    }

    /**
     * Command-line entry point: {@code java TcpProxy <listenPort> <targetHost> <targetPort>}.
     * Blocks in the accept loop until the process is terminated or accepting fails.
     */
    public static void main(String[] args) throws Exception {
        if (args.length != 3) {
            System.out.println("Usage: java TcpProxy <listenPort> <targetHost> <targetPort>");
            return;
        }

        int listenPort = Integer.parseInt(args[0]);
        String targetHost = args[1];
        int targetPort = Integer.parseInt(args[2]);

        try (TcpProxy proxy = new TcpProxy(listenPort, targetHost, targetPort)) {
            proxy.start();
        }
    }
}