001/** 002 * Copyright (C) 2006-2018 Talend Inc. - www.talend.com 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016package org.talend.sdk.component.junit.http.internal.impl; 017 018import static io.netty.handler.codec.http.HttpUtil.setContentLength; 019import static io.netty.handler.codec.http.HttpUtil.setKeepAlive; 020import static java.util.Optional.ofNullable; 021import static java.util.stream.Collectors.toMap; 022import static org.talend.sdk.component.junit.http.internal.impl.Handlers.BASE; 023import static org.talend.sdk.component.junit.http.internal.impl.Handlers.closeOnFlush; 024import static org.talend.sdk.component.junit.http.internal.impl.Handlers.sendError; 025 026import java.io.ByteArrayOutputStream; 027import java.io.IOException; 028import java.io.InputStream; 029import java.net.HttpURLConnection; 030import java.net.Proxy; 031import java.net.URL; 032import java.util.HashMap; 033import java.util.List; 034import java.util.Map; 035import java.util.Objects; 036import java.util.TreeMap; 037import java.util.function.Predicate; 038import java.util.stream.Collectors; 039import java.util.stream.Stream; 040 041import javax.net.ssl.HttpsURLConnection; 042import javax.net.ssl.SSLEngine; 043 044import org.talend.sdk.component.junit.http.api.HttpApiHandler; 045import org.talend.sdk.component.junit.http.api.Response; 046 047import io.netty.buffer.ByteBuf; 048import io.netty.buffer.Unpooled; 049import io.netty.channel.ChannelHandlerContext; 050import io.netty.channel.SimpleChannelInboundHandler; 051import io.netty.handler.codec.http.DefaultFullHttpResponse; 052import io.netty.handler.codec.http.FullHttpRequest; 053import io.netty.handler.codec.http.FullHttpResponse; 054import io.netty.handler.codec.http.HttpMethod; 055import io.netty.handler.codec.http.HttpResponse; 056import io.netty.handler.codec.http.HttpResponseStatus; 057import io.netty.handler.codec.http.HttpUtil; 058import io.netty.handler.codec.http.HttpVersion; 059import io.netty.handler.ssl.SslHandler; 060import io.netty.util.Attribute; 061 062import lombok.AllArgsConstructor; 063import lombok.extern.slf4j.Slf4j; 064 065@Slf4j 066@AllArgsConstructor 067public class PassthroughHandler extends SimpleChannelInboundHandler<FullHttpRequest> { 068 069 protected final HttpApiHandler api; 070 071 @Override 072 protected void channelRead0(final ChannelHandlerContext ctx, final FullHttpRequest request) { 073 if (HttpMethod.CONNECT.name().equalsIgnoreCase(request.method().name())) { 074 final FullHttpResponse response = 075 new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER); 076 setKeepAlive(response, true); 077 setContentLength(response, 0); 078 if (api.getSslContext() != null) { 079 final SSLEngine sslEngine = api.getSslContext().createSSLEngine(); 080 sslEngine.setUseClientMode(false); 081 ctx.channel().pipeline().addFirst("ssl", new SslHandler(sslEngine, true)); 082 083 final String uri = request.uri(); 084 final String[] parts = uri.split(":"); 085 ctx.channel().attr(BASE).set( 086 "https://" + parts[0] + (parts.length > 1 && !"443".equals(parts[1]) ? ":" + parts[1] : "")); 087 } 088 ctx.writeAndFlush(response); 089 return; 090 } 091 final FullHttpRequest req = request.copy(); // copy to use in a separated thread 092 api.getExecutor().execute(() -> doHttpRequest(req, ctx)); 093 } 094 095 private void doHttpRequest(final FullHttpRequest request, final ChannelHandlerContext ctx) { 096 try { 097 final Attribute<String> baseAttr = ctx.channel().attr(Handlers.BASE); 098 final String requestUri = 099 (baseAttr == null || baseAttr.get() == null ? "" : baseAttr.get()) + request.uri(); 100 101 // do the remote request with all the incoming data and save it 102 // note: this request must be synchronous for now 103 final Response resp; 104 final Map<String, String> otherHeaders = new HashMap<>(); 105 try { 106 final URL url = new URL(requestUri); 107 final HttpURLConnection connection = HttpURLConnection.class.cast(url.openConnection(Proxy.NO_PROXY)); 108 connection.setConnectTimeout(30000); 109 connection.setReadTimeout(20000); 110 if (HttpsURLConnection.class.isInstance(connection) && api.getSslContext() != null) { 111 final HttpsURLConnection httpsURLConnection = HttpsURLConnection.class.cast(connection); 112 httpsURLConnection.setHostnameVerifier((h, s) -> true); 113 httpsURLConnection.setSSLSocketFactory(api.getSslContext().getSocketFactory()); 114 } 115 request.headers().forEach(e -> connection.setRequestProperty(e.getKey(), e.getValue())); 116 if (request.method() != null) { 117 final String requestMethod = request.method().name(); 118 connection.setRequestMethod(requestMethod); 119 120 if (!"HEAD".equalsIgnoreCase(requestMethod) && request.content().readableBytes() > 0) { 121 connection.setDoOutput(true); 122 request.content().readBytes(connection.getOutputStream(), request.content().readableBytes()); 123 } 124 } 125 126 final int responseCode = connection.getResponseCode(); 127 final int defaultLength = 128 ofNullable(connection.getHeaderField("content-length")).map(Integer::parseInt).orElse(8192); 129 resp = new ResponseImpl(collectHeaders(connection, api.getHeaderFilter().negate()), responseCode, 130 responseCode <= 399 ? slurp(connection.getInputStream(), defaultLength) 131 : slurp(connection.getErrorStream(), defaultLength)); 132 133 otherHeaders.putAll(collectHeaders(connection, api.getHeaderFilter())); 134 135 beforeResponse(requestUri, request, resp, 136 new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER) { 137 138 { 139 connection 140 .getHeaderFields() 141 .entrySet() 142 .stream() 143 .filter(it -> it.getKey() != null && it.getValue() != null) 144 .forEach(e -> put(e.getKey(), e.getValue())); 145 } 146 }); 147 } catch (final Exception e) { 148 log.error(e.getMessage(), e); 149 sendError(ctx, HttpResponseStatus.BAD_REQUEST); 150 return; 151 } 152 153 final ByteBuf bytes = ofNullable(resp.payload()).map(Unpooled::copiedBuffer).orElse(Unpooled.EMPTY_BUFFER); 154 final HttpResponse response = 155 new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(resp.status()), bytes); 156 HttpUtil.setContentLength(response, bytes.array().length); 157 158 Stream.of(resp.headers(), otherHeaders).filter(Objects::nonNull).forEach( 159 h -> h.forEach((k, v) -> response.headers().set(k, v))); 160 ctx.writeAndFlush(response); 161 162 } finally { 163 request.release(); 164 } 165 } 166 167 private TreeMap<String, String> collectHeaders(final HttpURLConnection connection, final Predicate<String> filter) { 168 return connection 169 .getHeaderFields() 170 .entrySet() 171 .stream() 172 .filter(e -> e.getKey() != null) 173 .filter(h -> filter.test(h.getKey())) 174 .collect(toMap(Map.Entry::getKey, e -> e.getValue().stream().collect(Collectors.joining(",")), 175 (a, b) -> a, () -> new TreeMap<>(String.CASE_INSENSITIVE_ORDER))); 176 } 177 178 protected void beforeResponse(final String requestUri, final FullHttpRequest request, final Response resp, 179 final Map<String, List<String>> headerFields) { 180 // no-op 181 } 182 183 private byte[] slurp(final InputStream inputStream, final int defaultLen) throws IOException { 184 if (inputStream == null) { 185 return null; 186 } 187 188 final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(defaultLen); 189 final byte[] bytes = new byte[defaultLen]; 190 int read; 191 while ((read = inputStream.read(bytes)) >= 0) { 192 byteArrayOutputStream.write(bytes, 0, read); 193 } 194 return byteArrayOutputStream.toByteArray(); 195 } 196 197 @Override 198 public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) { 199 log.error(cause.getMessage(), cause); 200 closeOnFlush(ctx.channel()); 201 } 202}