001/**
002 * Copyright (C) 2006-2018 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.junit.http.internal.impl;
017
018import static java.util.Optional.of;
019import static java.util.Optional.ofNullable;
020import static java.util.stream.Collectors.toMap;
021import static org.talend.sdk.component.junit.http.internal.impl.Handlers.closeOnFlush;
022import static org.talend.sdk.component.junit.http.internal.impl.Handlers.sendError;
023
024import java.net.HttpURLConnection;
025import java.nio.charset.StandardCharsets;
026import java.util.HashMap;
027import java.util.Map;
028import java.util.Optional;
029import java.util.Spliterator;
030import java.util.Spliterators;
031import java.util.stream.StreamSupport;
032
033import javax.net.ssl.SSLEngine;
034
035import org.talend.sdk.component.junit.http.api.HttpApiHandler;
036import org.talend.sdk.component.junit.http.api.Response;
037
038import io.netty.buffer.ByteBuf;
039import io.netty.buffer.Unpooled;
040import io.netty.channel.ChannelHandler;
041import io.netty.channel.ChannelHandlerContext;
042import io.netty.channel.SimpleChannelInboundHandler;
043import io.netty.handler.codec.http.DefaultFullHttpResponse;
044import io.netty.handler.codec.http.FullHttpRequest;
045import io.netty.handler.codec.http.HttpHeaderNames;
046import io.netty.handler.codec.http.HttpHeaderValues;
047import io.netty.handler.codec.http.HttpMethod;
048import io.netty.handler.codec.http.HttpResponse;
049import io.netty.handler.codec.http.HttpResponseStatus;
050import io.netty.handler.codec.http.HttpUtil;
051import io.netty.handler.codec.http.HttpVersion;
052import io.netty.handler.ssl.SslHandler;
053import io.netty.util.Attribute;
054
055import lombok.AllArgsConstructor;
056import lombok.extern.slf4j.Slf4j;
057
058@Slf4j
059@AllArgsConstructor
060@ChannelHandler.Sharable
061public class ServingProxyHandler extends SimpleChannelInboundHandler<FullHttpRequest> {
062
063    private final HttpApiHandler api;
064
065    @Override
066    protected void channelRead0(final ChannelHandlerContext ctx, final FullHttpRequest request) {
067        if (!request.decoderResult().isSuccess()) {
068            sendError(ctx, HttpResponseStatus.BAD_REQUEST);
069            return;
070        }
071
072        final String payload = request.content().toString(StandardCharsets.UTF_8);
073
074        api.getExecutor().execute(() -> {
075            final Map<String, String> headers = StreamSupport
076                    .stream(Spliterators.spliteratorUnknownSize(request.headers().iteratorAsString(),
077                            Spliterator.IMMUTABLE), false)
078                    .collect(toMap(Map.Entry::getKey, Map.Entry::getValue));
079            final Attribute<String> baseAttr = ctx.channel().attr(Handlers.BASE);
080            Optional<Response> matching = api.getResponseLocator().findMatching(
081                    new RequestImpl((baseAttr == null || baseAttr.get() == null ? "" : baseAttr.get()) + request.uri(),
082                            request.method().name(), payload, headers),
083                    api.getHeaderFilter());
084            if (!matching.isPresent()) {
085                if (HttpMethod.CONNECT.name().equalsIgnoreCase(request.method().name())) {
086                    final Map<String, String> responseHeaders = new HashMap<>();
087                    responseHeaders.put(HttpHeaderNames.CONNECTION.toString(), HttpHeaderValues.KEEP_ALIVE.toString());
088                    responseHeaders.put(HttpHeaderNames.CONTENT_LENGTH.toString(), "0");
089                    matching = of(new ResponseImpl(responseHeaders, HttpResponseStatus.OK.code(),
090                            Unpooled.EMPTY_BUFFER.array()));
091                    if (api.getSslContext() != null) {
092                        final SSLEngine sslEngine = api.getSslContext().createSSLEngine();
093                        sslEngine.setUseClientMode(false);
094                        ctx.channel().pipeline().addFirst("ssl", new SslHandler(sslEngine, true));
095
096                        final String uri = request.uri();
097                        final String[] parts = uri.split(":");
098                        ctx.channel().attr(Handlers.BASE).set("https://" + parts[0]
099                                + (parts.length > 1 && !"443".equals(parts[1]) ? ":" + parts[1] : ""));
100                    }
101                } else {
102                    sendError(ctx, new HttpResponseStatus(HttpURLConnection.HTTP_BAD_REQUEST,
103                            "You are in proxy mode. No response was found for the simulated request. Please ensure to capture it for next executions. "
104                                    + request.method().name() + " " + request.uri()));
105                    return;
106                }
107            }
108
109            final Response resp = matching.get();
110            final ByteBuf bytes = ofNullable(resp.payload()).map(Unpooled::copiedBuffer).orElse(Unpooled.EMPTY_BUFFER);
111            final HttpResponse response =
112                    new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(resp.status()), bytes);
113            HttpUtil.setContentLength(response, bytes.array().length);
114
115            if (!api.isSkipProxyHeaders()) {
116                response.headers().set("X-Talend-Proxy-JUnit", "true");
117            }
118
119            ofNullable(resp.headers()).ifPresent(h -> h.forEach((k, v) -> response.headers().set(k, v)));
120            ctx.writeAndFlush(response);
121        });
122    }
123
124    @Override
125    public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) {
126        log.error(cause.getMessage(), cause);
127        closeOnFlush(ctx.channel());
128    }
129}