001/**
002 * Copyright (C) 2006-2019 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.junit.http.internal.impl;
017
import static io.netty.handler.codec.http.HttpUtil.setContentLength;
import static io.netty.handler.codec.http.HttpUtil.setKeepAlive;
import static java.util.Optional.ofNullable;
import static java.util.stream.Collectors.toMap;
import static org.talend.sdk.component.junit.http.internal.impl.Handlers.BASE;
import static org.talend.sdk.component.junit.http.internal.impl.Handlers.closeOnFlush;
import static org.talend.sdk.component.junit.http.internal.impl.Handlers.sendError;

import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.net.HttpURLConnection;
import java.net.Proxy;
import java.net.URL;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.TreeMap;
import java.util.function.Predicate;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import javax.net.ssl.HttpsURLConnection;
import javax.net.ssl.SSLEngine;

import org.talend.sdk.component.junit.http.api.HttpApiHandler;
import org.talend.sdk.component.junit.http.api.Response;

import io.netty.buffer.ByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.SimpleChannelInboundHandler;
import io.netty.handler.codec.http.DefaultFullHttpResponse;
import io.netty.handler.codec.http.FullHttpRequest;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpResponse;
import io.netty.handler.codec.http.HttpResponseStatus;
import io.netty.handler.codec.http.HttpUtil;
import io.netty.handler.codec.http.HttpVersion;
import io.netty.handler.ssl.SslHandler;
import io.netty.util.Attribute;

import lombok.AllArgsConstructor;
import lombok.extern.slf4j.Slf4j;
064
065@Slf4j
066@AllArgsConstructor
067public class PassthroughHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
068
069    protected final HttpApiHandler api;
070
071    @Override
072    protected void channelRead0(final ChannelHandlerContext ctx, final FullHttpRequest request) {
073        if (HttpMethod.CONNECT.name().equalsIgnoreCase(request.method().name())) {
074            final FullHttpResponse response =
075                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER);
076            setKeepAlive(response, true);
077            setContentLength(response, 0);
078            if (api.getSslContext() != null) {
079                final SSLEngine sslEngine = api.getSslContext().createSSLEngine();
080                sslEngine.setUseClientMode(false);
081                ctx.channel().pipeline().addFirst("ssl", new SslHandler(sslEngine, true));
082
083                final String uri = request.uri();
084                final String[] parts = uri.split(":");
085                ctx
086                        .channel()
087                        .attr(BASE)
088                        .set("https://" + parts[0]
089                                + (parts.length > 1 && !"443".equals(parts[1]) ? ":" + parts[1] : ""));
090            }
091            ctx.writeAndFlush(response);
092            return;
093        }
094        final FullHttpRequest req = request.copy(); // copy to use in a separated thread
095        api.getExecutor().execute(() -> doHttpRequest(req, ctx));
096    }
097
098    private void doHttpRequest(final FullHttpRequest request, final ChannelHandlerContext ctx) {
099        try {
100            final Attribute<String> baseAttr = ctx.channel().attr(Handlers.BASE);
101            final String requestUri =
102                    (baseAttr == null || baseAttr.get() == null ? "" : baseAttr.get()) + request.uri();
103
104            // do the remote request with all the incoming data and save it
105            // note: this request must be synchronous for now
106            final Response resp;
107            final Map<String, String> otherHeaders = new HashMap<>();
108            try {
109                final URL url = new URL(requestUri);
110                final HttpURLConnection connection = HttpURLConnection.class.cast(url.openConnection(Proxy.NO_PROXY));
111                connection.setConnectTimeout(30000);
112                connection.setReadTimeout(20000);
113                if (HttpsURLConnection.class.isInstance(connection) && api.getSslContext() != null) {
114                    final HttpsURLConnection httpsURLConnection = HttpsURLConnection.class.cast(connection);
115                    httpsURLConnection.setHostnameVerifier((h, s) -> true);
116                    httpsURLConnection.setSSLSocketFactory(api.getSslContext().getSocketFactory());
117                }
118                request.headers().forEach(e -> connection.setRequestProperty(e.getKey(), e.getValue()));
119                if (request.method() != null) {
120                    final String requestMethod = request.method().name();
121                    connection.setRequestMethod(requestMethod);
122
123                    if (!"HEAD".equalsIgnoreCase(requestMethod) && request.content().readableBytes() > 0) {
124                        connection.setDoOutput(true);
125                        request.content().readBytes(connection.getOutputStream(), request.content().readableBytes());
126                    }
127                }
128
129                final int responseCode = connection.getResponseCode();
130                final int defaultLength =
131                        ofNullable(connection.getHeaderField("content-length")).map(Integer::parseInt).orElse(8192);
132                resp = new ResponseImpl(collectHeaders(connection, api.getHeaderFilter().negate()), responseCode,
133                        responseCode <= 399 ? slurp(connection.getInputStream(), defaultLength)
134                                : slurp(connection.getErrorStream(), defaultLength));
135
136                otherHeaders.putAll(collectHeaders(connection, api.getHeaderFilter()));
137
138                beforeResponse(requestUri, request, resp,
139                        new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER) {
140
141                            {
142                                connection
143                                        .getHeaderFields()
144                                        .entrySet()
145                                        .stream()
146                                        .filter(it -> it.getKey() != null && it.getValue() != null)
147                                        .forEach(e -> put(e.getKey(), e.getValue()));
148                            }
149                        });
150            } catch (final Exception e) {
151                log.error(e.getMessage(), e);
152                sendError(ctx, HttpResponseStatus.BAD_REQUEST);
153                return;
154            }
155
156            final ByteBuf bytes = ofNullable(resp.payload()).map(Unpooled::copiedBuffer).orElse(Unpooled.EMPTY_BUFFER);
157            final HttpResponse response =
158                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(resp.status()), bytes);
159            HttpUtil.setContentLength(response, bytes.array().length);
160
161            Stream
162                    .of(resp.headers(), otherHeaders)
163                    .filter(Objects::nonNull)
164                    .forEach(h -> h.forEach((k, v) -> response.headers().set(k, v)));
165            ctx.writeAndFlush(response);
166
167        } finally {
168            request.release();
169        }
170    }
171
172    private TreeMap<String, String> collectHeaders(final HttpURLConnection connection, final Predicate<String> filter) {
173        return connection
174                .getHeaderFields()
175                .entrySet()
176                .stream()
177                .filter(e -> e.getKey() != null)
178                .filter(h -> filter.test(h.getKey()))
179                .collect(toMap(Map.Entry::getKey, e -> e.getValue().stream().collect(Collectors.joining(",")),
180                        (a, b) -> a, () -> new TreeMap<>(String.CASE_INSENSITIVE_ORDER)));
181    }
182
183    protected void beforeResponse(final String requestUri, final FullHttpRequest request, final Response resp,
184            final Map<String, List<String>> headerFields) {
185        // no-op
186    }
187
188    private byte[] slurp(final InputStream inputStream, final int defaultLen) throws IOException {
189        if (inputStream == null) {
190            return null;
191        }
192
193        final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(defaultLen);
194        final byte[] bytes = new byte[defaultLen];
195        int read;
196        while ((read = inputStream.read(bytes)) >= 0) {
197            byteArrayOutputStream.write(bytes, 0, read);
198        }
199        return byteArrayOutputStream.toByteArray();
200    }
201
202    @Override
203    public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
204        log.error(cause.getMessage(), cause);
205        closeOnFlush(ctx.channel());
206    }
207}