001/**
002 * Copyright (C) 2006-2022 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.junit.http.internal.impl;
017
018import static java.util.Optional.ofNullable;
019import static java.util.stream.Collectors.toMap;
020import static org.talend.sdk.component.junit.http.internal.impl.Handlers.BASE;
021import static org.talend.sdk.component.junit.http.internal.impl.Handlers.closeOnFlush;
022import static org.talend.sdk.component.junit.http.internal.impl.Handlers.sendError;
023
024import java.io.ByteArrayOutputStream;
025import java.io.IOException;
026import java.io.InputStream;
027import java.net.HttpURLConnection;
028import java.net.Proxy;
029import java.net.URL;
030import java.util.HashMap;
031import java.util.List;
032import java.util.Map;
033import java.util.Objects;
034import java.util.TreeMap;
035import java.util.function.Predicate;
036import java.util.stream.Collectors;
037import java.util.stream.Stream;
038
039import javax.net.ssl.HttpsURLConnection;
040import javax.net.ssl.SSLEngine;
041
042import org.talend.sdk.component.junit.http.api.HttpApiHandler;
043import org.talend.sdk.component.junit.http.api.Response;
044
045import io.netty.buffer.ByteBuf;
046import io.netty.buffer.Unpooled;
047import io.netty.channel.ChannelHandlerContext;
048import io.netty.channel.SimpleChannelInboundHandler;
049import io.netty.handler.codec.http.DefaultFullHttpResponse;
050import io.netty.handler.codec.http.FullHttpRequest;
051import io.netty.handler.codec.http.FullHttpResponse;
052import io.netty.handler.codec.http.HttpMethod;
053import io.netty.handler.codec.http.HttpResponse;
054import io.netty.handler.codec.http.HttpResponseStatus;
055import io.netty.handler.codec.http.HttpUtil;
056import io.netty.handler.codec.http.HttpVersion;
057import io.netty.handler.ssl.SslHandler;
058import io.netty.util.Attribute;
059
060import lombok.AllArgsConstructor;
061import lombok.extern.slf4j.Slf4j;
062
063@Slf4j
064@AllArgsConstructor
065public class PassthroughHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
066
067    protected final HttpApiHandler api;
068
069    @Override
070    protected void channelRead0(final ChannelHandlerContext ctx, final FullHttpRequest request) {
071        if (HttpMethod.CONNECT.name().equalsIgnoreCase(request.method().name())) {
072            final FullHttpResponse response =
073                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER);
074            HttpUtil.setKeepAlive(response, true);
075            HttpUtil.setContentLength(response, 0);
076            if (api.getSslContext() != null) {
077                final SSLEngine sslEngine = api.getSslContext().createSSLEngine();
078                sslEngine.setUseClientMode(false);
079                ctx.channel().pipeline().addFirst("ssl", new SslHandler(sslEngine, true));
080
081                final String uri = request.uri();
082                final String[] parts = uri.split(":");
083                ctx
084                        .channel()
085                        .attr(BASE)
086                        .set("https://" + parts[0]
087                                + (parts.length > 1 && !"443".equals(parts[1]) ? ":" + parts[1] : ""));
088            }
089            ctx.writeAndFlush(response);
090            return;
091        }
092        final FullHttpRequest req = request.copy(); // copy to use in a separated thread
093        api.getExecutor().execute(() -> doHttpRequest(req, ctx));
094    }
095
096    private void doHttpRequest(final FullHttpRequest request, final ChannelHandlerContext ctx) {
097        try {
098            final Attribute<String> baseAttr = ctx.channel().attr(Handlers.BASE);
099            final String requestUri =
100                    (baseAttr == null || baseAttr.get() == null ? "" : baseAttr.get()) + request.uri();
101
102            // do the remote request with all the incoming data and save it
103            // note: this request must be synchronous for now
104            final Response resp;
105            final Map<String, String> otherHeaders = new HashMap<>();
106            try {
107                final URL url = new URL(requestUri);
108                final HttpURLConnection connection = HttpURLConnection.class.cast(url.openConnection(Proxy.NO_PROXY));
109                connection.setConnectTimeout(30000);
110                connection.setReadTimeout(20000);
111                if (HttpsURLConnection.class.isInstance(connection) && api.getSslContext() != null) {
112                    final HttpsURLConnection httpsURLConnection = HttpsURLConnection.class.cast(connection);
113                    httpsURLConnection.setHostnameVerifier((h, s) -> true);
114                    httpsURLConnection.setSSLSocketFactory(api.getSslContext().getSocketFactory());
115                }
116                request.headers().entries().forEach(e -> connection.setRequestProperty(e.getKey(), e.getValue()));
117                if (request.method() != null) {
118                    final String requestMethod = request.method().name().toString();
119                    connection.setRequestMethod(requestMethod);
120
121                    if (!"HEAD".equalsIgnoreCase(requestMethod) && request.content().readableBytes() > 0) {
122                        connection.setDoOutput(true);
123                        request.content().readBytes(connection.getOutputStream(), request.content().readableBytes());
124                    }
125                }
126
127                final int responseCode = connection.getResponseCode();
128                final int defaultLength =
129                        ofNullable(connection.getHeaderField("content-length")).map(Integer::parseInt).orElse(8192);
130                resp = new ResponseImpl(collectHeaders(connection, api.getHeaderFilter().negate()), responseCode,
131                        responseCode <= 399 ? slurp(connection.getInputStream(), defaultLength)
132                                : slurp(connection.getErrorStream(), defaultLength));
133
134                otherHeaders.putAll(collectHeaders(connection, api.getHeaderFilter()));
135
136                beforeResponse(requestUri, request, resp,
137                        new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER) {
138
139                            {
140                                connection
141                                        .getHeaderFields()
142                                        .entrySet()
143                                        .stream()
144                                        .filter(it -> it.getKey() != null && it.getValue() != null)
145                                        .forEach(e -> put(e.getKey(), e.getValue()));
146                            }
147                        });
148            } catch (final Exception e) {
149                log.error(e.getMessage(), e);
150                sendError(ctx, HttpResponseStatus.BAD_REQUEST);
151                return;
152            }
153
154            final ByteBuf bytes = ofNullable(resp.payload()).map(Unpooled::copiedBuffer).orElse(Unpooled.EMPTY_BUFFER);
155            final HttpResponse response =
156                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(resp.status()), bytes);
157            HttpUtil.setContentLength(response, bytes.array().length);
158
159            Stream
160                    .of(resp.headers(), otherHeaders)
161                    .filter(Objects::nonNull)
162                    .forEach(h -> h.forEach((k, v) -> response.headers().set(k, v)));
163            ctx.writeAndFlush(response);
164
165        } finally {
166            request.release();
167        }
168    }
169
170    private TreeMap<String, String> collectHeaders(final HttpURLConnection connection, final Predicate<String> filter) {
171        return connection
172                .getHeaderFields()
173                .entrySet()
174                .stream()
175                .filter(e -> e.getKey() != null)
176                .filter(h -> filter.test(h.getKey()))
177                .collect(toMap(Map.Entry::getKey, e -> e.getValue().stream().collect(Collectors.joining(",")),
178                        (a, b) -> a, () -> new TreeMap<>(String.CASE_INSENSITIVE_ORDER)));
179    }
180
181    protected void beforeResponse(final String requestUri, final FullHttpRequest request, final Response resp,
182            final Map<String, List<String>> headerFields) {
183        // no-op
184    }
185
186    private byte[] slurp(final InputStream inputStream, final int defaultLen) throws IOException {
187        if (inputStream == null) {
188            return null;
189        }
190
191        final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(defaultLen);
192        final byte[] bytes = new byte[defaultLen];
193        int read;
194        while ((read = inputStream.read(bytes)) >= 0) {
195            byteArrayOutputStream.write(bytes, 0, read);
196        }
197        return byteArrayOutputStream.toByteArray();
198    }
199
200    @Override
201    public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
202        log.error(cause.getMessage(), cause);
203        closeOnFlush(ctx.channel());
204    }
205}