001/**
002 * Copyright (C) 2006-2018 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.junit.http.internal.impl;
017
018import static io.netty.handler.codec.http.HttpUtil.setContentLength;
019import static io.netty.handler.codec.http.HttpUtil.setKeepAlive;
020import static java.util.Optional.ofNullable;
021import static java.util.stream.Collectors.toMap;
022import static org.talend.sdk.component.junit.http.internal.impl.Handlers.BASE;
023import static org.talend.sdk.component.junit.http.internal.impl.Handlers.closeOnFlush;
024import static org.talend.sdk.component.junit.http.internal.impl.Handlers.sendError;
025
026import java.io.ByteArrayOutputStream;
027import java.io.IOException;
028import java.io.InputStream;
029import java.net.HttpURLConnection;
030import java.net.Proxy;
031import java.net.URL;
032import java.util.HashMap;
033import java.util.List;
034import java.util.Map;
035import java.util.Objects;
036import java.util.TreeMap;
037import java.util.function.Predicate;
038import java.util.stream.Collectors;
039import java.util.stream.Stream;
040
041import javax.net.ssl.HttpsURLConnection;
042import javax.net.ssl.SSLEngine;
043
044import org.talend.sdk.component.junit.http.api.HttpApiHandler;
045import org.talend.sdk.component.junit.http.api.Response;
046
047import io.netty.buffer.ByteBuf;
048import io.netty.buffer.Unpooled;
049import io.netty.channel.ChannelHandlerContext;
050import io.netty.channel.SimpleChannelInboundHandler;
051import io.netty.handler.codec.http.DefaultFullHttpResponse;
052import io.netty.handler.codec.http.FullHttpRequest;
053import io.netty.handler.codec.http.FullHttpResponse;
054import io.netty.handler.codec.http.HttpMethod;
055import io.netty.handler.codec.http.HttpResponse;
056import io.netty.handler.codec.http.HttpResponseStatus;
057import io.netty.handler.codec.http.HttpUtil;
058import io.netty.handler.codec.http.HttpVersion;
059import io.netty.handler.ssl.SslHandler;
060import io.netty.util.Attribute;
061
062import lombok.AllArgsConstructor;
063import lombok.extern.slf4j.Slf4j;
064
065@Slf4j
066@AllArgsConstructor
067public class PassthroughHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
068
069    protected final HttpApiHandler api;
070
071    @Override
072    protected void channelRead0(final ChannelHandlerContext ctx, final FullHttpRequest request) {
073        if (HttpMethod.CONNECT.name().equalsIgnoreCase(request.method().name())) {
074            final FullHttpResponse response =
075                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER);
076            setKeepAlive(response, true);
077            setContentLength(response, 0);
078            if (api.getSslContext() != null) {
079                final SSLEngine sslEngine = api.getSslContext().createSSLEngine();
080                sslEngine.setUseClientMode(false);
081                ctx.channel().pipeline().addFirst("ssl", new SslHandler(sslEngine, true));
082
083                final String uri = request.uri();
084                final String[] parts = uri.split(":");
085                ctx.channel().attr(BASE).set(
086                        "https://" + parts[0] + (parts.length > 1 && !"443".equals(parts[1]) ? ":" + parts[1] : ""));
087            }
088            ctx.writeAndFlush(response);
089            return;
090        }
091        final FullHttpRequest req = request.copy(); // copy to use in a separated thread
092        api.getExecutor().execute(() -> doHttpRequest(req, ctx));
093    }
094
095    private void doHttpRequest(final FullHttpRequest request, final ChannelHandlerContext ctx) {
096        try {
097            final Attribute<String> baseAttr = ctx.channel().attr(Handlers.BASE);
098            final String requestUri =
099                    (baseAttr == null || baseAttr.get() == null ? "" : baseAttr.get()) + request.uri();
100
101            // do the remote request with all the incoming data and save it
102            // note: this request must be synchronous for now
103            final Response resp;
104            final Map<String, String> otherHeaders = new HashMap<>();
105            try {
106                final URL url = new URL(requestUri);
107                final HttpURLConnection connection = HttpURLConnection.class.cast(url.openConnection(Proxy.NO_PROXY));
108                connection.setConnectTimeout(30000);
109                connection.setReadTimeout(20000);
110                if (HttpsURLConnection.class.isInstance(connection) && api.getSslContext() != null) {
111                    final HttpsURLConnection httpsURLConnection = HttpsURLConnection.class.cast(connection);
112                    httpsURLConnection.setHostnameVerifier((h, s) -> true);
113                    httpsURLConnection.setSSLSocketFactory(api.getSslContext().getSocketFactory());
114                }
115                request.headers().forEach(e -> connection.setRequestProperty(e.getKey(), e.getValue()));
116                if (request.method() != null) {
117                    final String requestMethod = request.method().name();
118                    connection.setRequestMethod(requestMethod);
119
120                    if (!"HEAD".equalsIgnoreCase(requestMethod) && request.content().readableBytes() > 0) {
121                        connection.setDoOutput(true);
122                        request.content().readBytes(connection.getOutputStream(), request.content().readableBytes());
123                    }
124                }
125
126                final int responseCode = connection.getResponseCode();
127                final int defaultLength =
128                        ofNullable(connection.getHeaderField("content-length")).map(Integer::parseInt).orElse(8192);
129                resp = new ResponseImpl(collectHeaders(connection, api.getHeaderFilter().negate()), responseCode,
130                        responseCode <= 399 ? slurp(connection.getInputStream(), defaultLength)
131                                : slurp(connection.getErrorStream(), defaultLength));
132
133                otherHeaders.putAll(collectHeaders(connection, api.getHeaderFilter()));
134
135                beforeResponse(requestUri, request, resp,
136                        new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER) {
137
138                            {
139                                connection
140                                        .getHeaderFields()
141                                        .entrySet()
142                                        .stream()
143                                        .filter(it -> it.getKey() != null && it.getValue() != null)
144                                        .forEach(e -> put(e.getKey(), e.getValue()));
145                            }
146                        });
147            } catch (final Exception e) {
148                log.error(e.getMessage(), e);
149                sendError(ctx, HttpResponseStatus.BAD_REQUEST);
150                return;
151            }
152
153            final ByteBuf bytes = ofNullable(resp.payload()).map(Unpooled::copiedBuffer).orElse(Unpooled.EMPTY_BUFFER);
154            final HttpResponse response =
155                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(resp.status()), bytes);
156            HttpUtil.setContentLength(response, bytes.array().length);
157
158            Stream.of(resp.headers(), otherHeaders).filter(Objects::nonNull).forEach(
159                    h -> h.forEach((k, v) -> response.headers().set(k, v)));
160            ctx.writeAndFlush(response);
161
162        } finally {
163            request.release();
164        }
165    }
166
167    private TreeMap<String, String> collectHeaders(final HttpURLConnection connection, final Predicate<String> filter) {
168        return connection
169                .getHeaderFields()
170                .entrySet()
171                .stream()
172                .filter(e -> e.getKey() != null)
173                .filter(h -> filter.test(h.getKey()))
174                .collect(toMap(Map.Entry::getKey, e -> e.getValue().stream().collect(Collectors.joining(",")),
175                        (a, b) -> a, () -> new TreeMap<>(String.CASE_INSENSITIVE_ORDER)));
176    }
177
178    protected void beforeResponse(final String requestUri, final FullHttpRequest request, final Response resp,
179            final Map<String, List<String>> headerFields) {
180        // no-op
181    }
182
183    private byte[] slurp(final InputStream inputStream, final int defaultLen) throws IOException {
184        if (inputStream == null) {
185            return null;
186        }
187
188        final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(defaultLen);
189        final byte[] bytes = new byte[defaultLen];
190        int read;
191        while ((read = inputStream.read(bytes)) >= 0) {
192            byteArrayOutputStream.write(bytes, 0, read);
193        }
194        return byteArrayOutputStream.toByteArray();
195    }
196
197    @Override
198    public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
199        log.error(cause.getMessage(), cause);
200        closeOnFlush(ctx.channel());
201    }
202}