001/** 002 * Copyright (C) 2006-2019 Talend Inc. - www.talend.com 003 * 004 * Licensed under the Apache License, Version 2.0 (the "License"); 005 * you may not use this file except in compliance with the License. 006 * You may obtain a copy of the License at 007 * 008 * http://www.apache.org/licenses/LICENSE-2.0 009 * 010 * Unless required by applicable law or agreed to in writing, software 011 * distributed under the License is distributed on an "AS IS" BASIS, 012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 013 * See the License for the specific language governing permissions and 014 * limitations under the License. 015 */ 016package org.talend.sdk.component.junit.http.internal.impl; 017 018import static io.netty.handler.codec.http.HttpUtil.setContentLength; 019import static io.netty.handler.codec.http.HttpUtil.setKeepAlive; 020import static java.util.Optional.ofNullable; 021import static java.util.stream.Collectors.toMap; 022import static org.talend.sdk.component.junit.http.internal.impl.Handlers.BASE; 023import static org.talend.sdk.component.junit.http.internal.impl.Handlers.closeOnFlush; 024import static org.talend.sdk.component.junit.http.internal.impl.Handlers.sendError; 025 026import java.io.ByteArrayOutputStream; 027import java.io.IOException; 028import java.io.InputStream; 029import java.net.HttpURLConnection; 030import java.net.Proxy; 031import java.net.URL; 032import java.util.HashMap; 033import java.util.List; 034import java.util.Map; 035import java.util.Objects; 036import java.util.TreeMap; 037import java.util.function.Predicate; 038import java.util.stream.Collectors; 039import java.util.stream.Stream; 040 041import javax.net.ssl.HttpsURLConnection; 042import javax.net.ssl.SSLEngine; 043 044import org.talend.sdk.component.junit.http.api.HttpApiHandler; 045import org.talend.sdk.component.junit.http.api.Response; 046 047import io.netty.buffer.ByteBuf; 048import io.netty.buffer.Unpooled; 049import io.netty.channel.ChannelHandlerContext; 050import io.netty.channel.SimpleChannelInboundHandler; 051import io.netty.handler.codec.http.DefaultFullHttpResponse; 052import io.netty.handler.codec.http.FullHttpRequest; 053import io.netty.handler.codec.http.FullHttpResponse; 054import io.netty.handler.codec.http.HttpMethod; 055import io.netty.handler.codec.http.HttpResponse; 056import io.netty.handler.codec.http.HttpResponseStatus; 057import io.netty.handler.codec.http.HttpUtil; 058import io.netty.handler.codec.http.HttpVersion; 059import io.netty.handler.ssl.SslHandler; 060import io.netty.util.Attribute; 061 062import lombok.AllArgsConstructor; 063import lombok.extern.slf4j.Slf4j; 064 065@Slf4j 066@AllArgsConstructor 067public class PassthroughHandler extends SimpleChannelInboundHandler<FullHttpRequest> { 068 069 protected final HttpApiHandler api; 070 071 @Override 072 protected void channelRead0(final ChannelHandlerContext ctx, final FullHttpRequest request) { 073 if (HttpMethod.CONNECT.name().equalsIgnoreCase(request.method().name())) { 074 final FullHttpResponse response = 075 new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.OK, Unpooled.EMPTY_BUFFER); 076 setKeepAlive(response, true); 077 setContentLength(response, 0); 078 if (api.getSslContext() != null) { 079 final SSLEngine sslEngine = api.getSslContext().createSSLEngine(); 080 sslEngine.setUseClientMode(false); 081 ctx.channel().pipeline().addFirst("ssl", new SslHandler(sslEngine, true)); 082 083 final String uri = request.uri(); 084 final String[] parts = uri.split(":"); 085 ctx 086 .channel() 087 .attr(BASE) 088 .set("https://" + parts[0] 089 + (parts.length > 1 && !"443".equals(parts[1]) ? ":" + parts[1] : "")); 090 } 091 ctx.writeAndFlush(response); 092 return; 093 } 094 final FullHttpRequest req = request.copy(); // copy to use in a separated thread 095 api.getExecutor().execute(() -> doHttpRequest(req, ctx)); 096 } 097 098 private void doHttpRequest(final FullHttpRequest request, final ChannelHandlerContext ctx) { 099 try { 100 final Attribute<String> baseAttr = ctx.channel().attr(Handlers.BASE); 101 final String requestUri = 102 (baseAttr == null || baseAttr.get() == null ? "" : baseAttr.get()) + request.uri(); 103 104 // do the remote request with all the incoming data and save it 105 // note: this request must be synchronous for now 106 final Response resp; 107 final Map<String, String> otherHeaders = new HashMap<>(); 108 try { 109 final URL url = new URL(requestUri); 110 final HttpURLConnection connection = HttpURLConnection.class.cast(url.openConnection(Proxy.NO_PROXY)); 111 connection.setConnectTimeout(30000); 112 connection.setReadTimeout(20000); 113 if (HttpsURLConnection.class.isInstance(connection) && api.getSslContext() != null) { 114 final HttpsURLConnection httpsURLConnection = HttpsURLConnection.class.cast(connection); 115 httpsURLConnection.setHostnameVerifier((h, s) -> true); 116 httpsURLConnection.setSSLSocketFactory(api.getSslContext().getSocketFactory()); 117 } 118 request.headers().forEach(e -> connection.setRequestProperty(e.getKey(), e.getValue())); 119 if (request.method() != null) { 120 final String requestMethod = request.method().name(); 121 connection.setRequestMethod(requestMethod); 122 123 if (!"HEAD".equalsIgnoreCase(requestMethod) && request.content().readableBytes() > 0) { 124 connection.setDoOutput(true); 125 request.content().readBytes(connection.getOutputStream(), request.content().readableBytes()); 126 } 127 } 128 129 final int responseCode = connection.getResponseCode(); 130 final int defaultLength = 131 ofNullable(connection.getHeaderField("content-length")).map(Integer::parseInt).orElse(8192); 132 resp = new ResponseImpl(collectHeaders(connection, api.getHeaderFilter().negate()), responseCode, 133 responseCode <= 399 ? slurp(connection.getInputStream(), defaultLength) 134 : slurp(connection.getErrorStream(), defaultLength)); 135 136 otherHeaders.putAll(collectHeaders(connection, api.getHeaderFilter())); 137 138 beforeResponse(requestUri, request, resp, 139 new TreeMap<String, List<String>>(String.CASE_INSENSITIVE_ORDER) { 140 141 { 142 connection 143 .getHeaderFields() 144 .entrySet() 145 .stream() 146 .filter(it -> it.getKey() != null && it.getValue() != null) 147 .forEach(e -> put(e.getKey(), e.getValue())); 148 } 149 }); 150 } catch (final Exception e) { 151 log.error(e.getMessage(), e); 152 sendError(ctx, HttpResponseStatus.BAD_REQUEST); 153 return; 154 } 155 156 final ByteBuf bytes = ofNullable(resp.payload()).map(Unpooled::copiedBuffer).orElse(Unpooled.EMPTY_BUFFER); 157 final HttpResponse response = 158 new DefaultFullHttpResponse(HttpVersion.HTTP_1_1, HttpResponseStatus.valueOf(resp.status()), bytes); 159 HttpUtil.setContentLength(response, bytes.array().length); 160 161 Stream 162 .of(resp.headers(), otherHeaders) 163 .filter(Objects::nonNull) 164 .forEach(h -> h.forEach((k, v) -> response.headers().set(k, v))); 165 ctx.writeAndFlush(response); 166 167 } finally { 168 request.release(); 169 } 170 } 171 172 private TreeMap<String, String> collectHeaders(final HttpURLConnection connection, final Predicate<String> filter) { 173 return connection 174 .getHeaderFields() 175 .entrySet() 176 .stream() 177 .filter(e -> e.getKey() != null) 178 .filter(h -> filter.test(h.getKey())) 179 .collect(toMap(Map.Entry::getKey, e -> e.getValue().stream().collect(Collectors.joining(",")), 180 (a, b) -> a, () -> new TreeMap<>(String.CASE_INSENSITIVE_ORDER))); 181 } 182 183 protected void beforeResponse(final String requestUri, final FullHttpRequest request, final Response resp, 184 final Map<String, List<String>> headerFields) { 185 // no-op 186 } 187 188 private byte[] slurp(final InputStream inputStream, final int defaultLen) throws IOException { 189 if (inputStream == null) { 190 return null; 191 } 192 193 final ByteArrayOutputStream byteArrayOutputStream = new ByteArrayOutputStream(defaultLen); 194 final byte[] bytes = new byte[defaultLen]; 195 int read; 196 while ((read = inputStream.read(bytes)) >= 0) { 197 byteArrayOutputStream.write(bytes, 0, read); 198 } 199 return byteArrayOutputStream.toByteArray(); 200 } 201 202 @Override 203 public void exceptionCaught(final ChannelHandlerContext ctx, final Throwable cause) { 204 log.error(cause.getMessage(), cause); 205 closeOnFlush(ctx.channel()); 206 } 207}