001/**
002 * Copyright (C) 2006-2019 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.junit;
017
018import static java.lang.Math.abs;
019import static java.util.Collections.emptyIterator;
020import static java.util.Collections.emptyMap;
021import static java.util.Locale.ROOT;
022import static java.util.concurrent.TimeUnit.MINUTES;
023import static java.util.concurrent.TimeUnit.SECONDS;
024import static java.util.stream.Collectors.joining;
025import static java.util.stream.Collectors.toList;
026import static org.apache.ziplock.JarLocation.jarLocation;
027import static org.junit.Assert.fail;
028import static org.talend.sdk.component.junit.SimpleFactory.configurationByExample;
029
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
055
056import javax.json.JsonBuilderFactory;
057import javax.json.JsonObject;
058import javax.json.bind.Jsonb;
059import javax.json.bind.JsonbConfig;
060import javax.json.spi.JsonProvider;
061
062import org.apache.xbean.finder.filter.Filter;
063import org.talend.sdk.component.api.record.Record;
064import org.talend.sdk.component.api.service.injector.Injector;
065import org.talend.sdk.component.api.service.record.RecordBuilderFactory;
066import org.talend.sdk.component.junit.lang.StreamDecorator;
067import org.talend.sdk.component.runtime.base.Lifecycle;
068import org.talend.sdk.component.runtime.input.Input;
069import org.talend.sdk.component.runtime.input.Mapper;
070import org.talend.sdk.component.runtime.manager.ComponentFamilyMeta;
071import org.talend.sdk.component.runtime.manager.ComponentManager;
072import org.talend.sdk.component.runtime.manager.ContainerComponentRegistry;
073import org.talend.sdk.component.runtime.manager.chain.AutoChunkProcessor;
074import org.talend.sdk.component.runtime.manager.chain.Job;
075import org.talend.sdk.component.runtime.manager.json.PreComputedJsonpProvider;
076import org.talend.sdk.component.runtime.output.OutputFactory;
077import org.talend.sdk.component.runtime.output.Processor;
078import org.talend.sdk.component.runtime.record.RecordConverters;
079
080import lombok.AllArgsConstructor;
081import lombok.extern.slf4j.Slf4j;
082
083@Slf4j
084public class BaseComponentsHandler implements ComponentsHandler {
085
086    protected static final Local<State> STATE = loadStateHolder();
087
088    private static Local<State> loadStateHolder() {
089        switch (System.getProperty("talend.component.junit.handler.state", "thread").toLowerCase(ROOT)) {
090        case "static":
091            return new Local.StaticImpl<>();
092        default:
093            return new Local.ThreadLocalImpl<>();
094        }
095    }
096
097    private final ThreadLocal<PreState> initState = ThreadLocal.withInitial(PreState::new);
098
099    protected String packageName;
100
101    protected Collection<String> isolatedPackages;
102
103    public <T> T injectServices(final T instance) {
104        if (instance == null) {
105            return null;
106        }
107        final String plugin = getSinglePlugin();
108        final Map<Class<?>, Object> services = asManager()
109                .findPlugin(plugin)
110                .orElseThrow(() -> new IllegalArgumentException("cant find plugin '" + plugin + "'"))
111                .get(ComponentManager.AllServices.class)
112                .getServices();
113        Injector.class.cast(services.get(Injector.class)).inject(instance);
114        return instance;
115    }
116
117    public BaseComponentsHandler withIsolatedPackage(final String packageName, final String... packages) {
118        isolatedPackages =
119                Stream.concat(Stream.of(packageName), Stream.of(packages)).filter(Objects::nonNull).collect(toList());
120        if (isolatedPackages.isEmpty()) {
121            isolatedPackages = null;
122        }
123        return this;
124    }
125
126    public EmbeddedComponentManager start() {
127        final EmbeddedComponentManager embeddedComponentManager = new EmbeddedComponentManager(packageName) {
128
129            @Override
130            protected boolean isContainerClass(final Filter filter, final String name) {
131                if (name == null) {
132                    return super.isContainerClass(filter, null);
133                }
134                return (isolatedPackages == null || isolatedPackages.stream().noneMatch(name::startsWith))
135                        && super.isContainerClass(filter, name);
136            }
137
138            @Override
139            public void close() {
140                try {
141                    final State state = STATE.get();
142                    if (state.jsonb != null) {
143                        try {
144                            state.jsonb.close();
145                        } catch (final Exception e) {
146                            // no-op: not important
147                        }
148                    }
149                    STATE.remove();
150                    initState.remove();
151                } finally {
152                    super.close();
153                }
154            }
155        };
156
157        STATE
158                .set(new State(embeddedComponentManager, new ArrayList<>(), initState.get().emitter, null, null, null,
159                        null));
160        return embeddedComponentManager;
161    }
162
163    @Override
164    public Outputs collect(final Processor processor, final ControllableInputFactory inputs) {
165        return collect(processor, inputs, 10);
166    }
167
168    /**
169     * Collects all outputs of a processor.
170     *
171     * @param processor the processor to run while there are inputs.
172     * @param inputs the input factory, when an input will return null it will stop the
173     * processing.
174     * @param bundleSize the bundle size to use.
175     * @return a map where the key is the output name and the value a stream of the
176     * output values.
177     */
178    @Override
179    public Outputs collect(final Processor processor, final ControllableInputFactory inputs, final int bundleSize) {
180        final AutoChunkProcessor autoChunkProcessor = new AutoChunkProcessor(bundleSize, processor);
181        autoChunkProcessor.start();
182        final Outputs outputs = new Outputs();
183        final OutputFactory outputFactory = name -> value -> {
184            final List aggregator = outputs.data.computeIfAbsent(name, n -> new ArrayList<>());
185            aggregator.add(value);
186        };
187        try {
188            while (inputs.hasMoreData()) {
189                autoChunkProcessor.onElement(inputs, outputFactory);
190            }
191            autoChunkProcessor.flush(outputFactory);
192        } finally {
193            autoChunkProcessor.stop();
194        }
195        return outputs;
196    }
197
198    @Override
199    public <T> Stream<T> collect(final Class<T> recordType, final Mapper mapper, final int maxRecords) {
200        return collect(recordType, mapper, maxRecords, Runtime.getRuntime().availableProcessors());
201    }
202
203    /**
204     * Collects data emitted from this mapper. If the split creates more than one
205     * mapper, it will create as much threads as mappers otherwise it will use the
206     * caller thread.
207     *
208     * IMPORTANT: don't forget to consume all the stream to ensure the underlying
209     * { @see org.talend.sdk.component.runtime.input.Input} is closed.
210     *
211     * @param recordType the record type to use to type the returned type.
212     * @param mapper the mapper to go through.
213     * @param maxRecords maximum number of records, allows to stop the source when
214     * infinite.
215     * @param concurrency requested (1 can be used instead if &lt;= 0) concurrency for the reader execution.
216     * @param <T> the returned type of the records of the mapper.
217     * @return all the records emitted by the mapper.
218     */
219    @Override
220    public <T> Stream<T> collect(final Class<T> recordType, final Mapper mapper, final int maxRecords,
221            final int concurrency) {
222        mapper.start();
223
224        final State state = STATE.get();
225        final long assess = mapper.assess();
226        final int proc = Math.max(1, concurrency);
227        final List<Mapper> mappers = mapper.split(Math.max(assess / proc, 1));
228        switch (mappers.size()) {
229        case 0:
230            return Stream.empty();
231        case 1:
232            return StreamDecorator
233                    .decorate(asStream(asIterator(mappers.iterator().next().create(), new AtomicInteger(maxRecords))),
234                            collect -> {
235                                try {
236                                    collect.run();
237                                } finally {
238                                    mapper.stop();
239                                }
240                            });
241        default: // N producers-1 consumer pattern
242            final AtomicInteger threadCounter = new AtomicInteger(0);
243            final ExecutorService es = Executors.newFixedThreadPool(mappers.size(), r -> new Thread(r) {
244
245                {
246                    setName(BaseComponentsHandler.this.getClass().getSimpleName() + "-pool-" + abs(mapper.hashCode())
247                            + "-" + threadCounter.incrementAndGet());
248                }
249            });
250            final AtomicInteger recordCounter = new AtomicInteger(maxRecords);
251            final Semaphore permissions = new Semaphore(0);
252            final Queue<T> records = new ConcurrentLinkedQueue<>();
253            final CountDownLatch latch = new CountDownLatch(mappers.size());
254            final List<? extends Future<?>> tasks = mappers
255                    .stream()
256                    .map(Mapper::create)
257                    .map(input -> (Iterator<T>) asIterator(input, recordCounter))
258                    .map(it -> es.submit(() -> {
259                        try {
260                            while (it.hasNext()) {
261                                final T next = it.next();
262                                records.add(next);
263                                permissions.release();
264                            }
265                        } finally {
266                            latch.countDown();
267                        }
268                    }))
269                    .collect(toList());
270            es.shutdown();
271
272            final int timeout = Integer.getInteger("talend.component.junit.timeout", 5);
273            new Thread() {
274
275                {
276                    setName(BaseComponentsHandler.class.getSimpleName() + "-monitor_" + abs(mapper.hashCode()));
277                }
278
279                @Override
280                public void run() {
281                    try {
282                        latch.await(timeout, MINUTES);
283                    } catch (final InterruptedException e) {
284                        Thread.currentThread().interrupt();
285                    } finally {
286                        permissions.release();
287                    }
288                }
289            }.start();
290            return StreamDecorator.decorate(asStream(new Iterator<T>() {
291
292                @Override
293                public boolean hasNext() {
294                    try {
295                        permissions.acquire();
296                    } catch (final InterruptedException e) {
297                        Thread.currentThread().interrupt();
298                        fail(e.getMessage());
299                    }
300                    return !records.isEmpty();
301                }
302
303                @Override
304                public T next() {
305                    T poll = records.poll();
306                    if (poll != null) {
307                        return mapRecord(state, recordType, poll);
308                    }
309                    return null;
310                }
311            }), task -> {
312                try {
313                    task.run();
314                } finally {
315                    tasks.forEach(f -> {
316                        try {
317                            f.get(5, SECONDS);
318                        } catch (final InterruptedException e) {
319                            Thread.currentThread().interrupt();
320                        } catch (final ExecutionException | TimeoutException e) {
321                            // no-op
322                        } finally {
323                            if (!f.isDone() && !f.isCancelled()) {
324                                f.cancel(true);
325                            }
326                        }
327                    });
328                }
329            });
330        }
331    }
332
333    private <T> Stream<T> asStream(final Iterator<T> iterator) {
334        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.IMMUTABLE), false);
335    }
336
337    private <T> Iterator<T> asIterator(final Input input, final AtomicInteger counter) {
338        input.start();
339        return new Iterator<T>() {
340
341            private boolean closed;
342
343            private Object next;
344
345            @Override
346            public boolean hasNext() {
347                final int remaining = counter.get();
348                if (remaining <= 0) {
349                    return false;
350                }
351
352                final boolean hasNext = (next = input.next()) != null;
353                if (!hasNext && !closed) {
354                    closed = true;
355                    input.stop();
356                }
357                if (hasNext) {
358                    counter.decrementAndGet();
359                }
360                return hasNext;
361            }
362
363            @Override
364            public T next() {
365                return (T) next;
366            }
367        };
368    }
369
370    @Override
371    public <T> List<T> collectAsList(final Class<T> recordType, final Mapper mapper) {
372        return collectAsList(recordType, mapper, 1000);
373    }
374
375    @Override
376    public <T> List<T> collectAsList(final Class<T> recordType, final Mapper mapper, final int maxRecords) {
377        return collect(recordType, mapper, maxRecords).collect(toList());
378    }
379
380    @Override
381    public Mapper createMapper(final Class<?> componentType, final Object configuration) {
382        return create(Mapper.class, componentType, configuration);
383    }
384
385    @Override
386    public Processor createProcessor(final Class<?> componentType, final Object configuration) {
387        return create(Processor.class, componentType, configuration);
388    }
389
390    private <C, T, A> A create(final Class<A> api, final Class<T> componentType, final C configuration) {
391        final ComponentFamilyMeta.BaseMeta<? extends Lifecycle> meta = findMeta(componentType);
392        return api
393                .cast(meta
394                        .getInstantiator()
395                        .apply(configuration == null || meta.getParameterMetas().isEmpty() ? emptyMap()
396                                : configurationByExample(configuration, meta
397                                        .getParameterMetas()
398                                        .stream()
399                                        .filter(p -> p.getName().equals(p.getPath()))
400                                        .findFirst()
401                                        .map(p -> p.getName() + '.')
402                                        .orElseThrow(() -> new IllegalArgumentException(
403                                                "Didn't find any option and therefore "
404                                                        + "can't convert the configuration instance to a configuration")))));
405    }
406
407    private <T> ComponentFamilyMeta.BaseMeta<? extends Lifecycle> findMeta(final Class<T> componentType) {
408        return asManager()
409                .find(c -> c.get(ContainerComponentRegistry.class).getComponents().values().stream())
410                .flatMap(f -> Stream
411                        .concat(f.getProcessors().values().stream(), f.getPartitionMappers().values().stream()))
412                .filter(m -> m.getType().getName().equals(componentType.getName()))
413                .findFirst()
414                .orElseThrow(() -> new IllegalArgumentException("No component " + componentType));
415    }
416
417    @Override
418    public <T> List<T> collect(final Class<T> recordType, final String family, final String component,
419            final int version, final Map<String, String> configuration) {
420        Job
421                .components()
422                .component("in",
423                        family + "://" + component + "?__version=" + version
424                                + configuration
425                                        .entrySet()
426                                        .stream()
427                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
428                                        .collect(joining("&", "&", "")))
429                .component("collector", "test://collector")
430                .connections()
431                .from("in")
432                .to("collector")
433                .build()
434                .run();
435
436        return getCollectedData(recordType);
437    }
438
439    @Override
440    public <T> void process(final Iterable<T> inputs, final String family, final String component, final int version,
441            final Map<String, String> configuration) {
442        setInputData(inputs);
443
444        Job
445                .components()
446                .component("emitter", "test://emitter")
447                .component("out",
448                        family + "://" + component + "?__version=" + version
449                                + configuration
450                                        .entrySet()
451                                        .stream()
452                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
453                                        .collect(joining("&", "&", "")))
454                .connections()
455                .from("emitter")
456                .to("out")
457                .build()
458                .run();
459
460    }
461
462    @Override
463    public ComponentManager asManager() {
464        return STATE.get().manager;
465    }
466
467    @Override
468    public <T> T findService(final String plugin, final Class<T> serviceClass) {
469        return serviceClass
470                .cast(asManager()
471                        .findPlugin(plugin)
472                        .orElseThrow(() -> new IllegalArgumentException("cant find plugin '" + plugin + "'"))
473                        .get(ComponentManager.AllServices.class)
474                        .getServices()
475                        .get(serviceClass));
476    }
477
478    @Override
479    public <T> T findService(final Class<T> serviceClass) {
480        return findService(getSinglePlugin(), serviceClass);
481    }
482
483    public Set<String> getTestPlugins() {
484        return new HashSet<>(EmbeddedComponentManager.class.cast(asManager()).testPlugins);
485    }
486
487    @Override
488    public <T> void setInputData(final Iterable<T> data) {
489        final State state = STATE.get();
490        if (state == null) {
491            initState.get().emitter = data.iterator();
492        } else {
493            state.emitter = data.iterator();
494        }
495    }
496
497    @Override
498    public <T> List<T> getCollectedData(final Class<T> recordType) {
499        final State state = STATE.get();
500        return state.collector
501                .stream()
502                .filter(r -> recordType.isInstance(r) || JsonObject.class.isInstance(r) || Record.class.isInstance(r))
503                .map(r -> mapRecord(state, recordType, r))
504                .collect(toList());
505    }
506
507    public void resetState() {
508        final State state = STATE.get();
509        if (state == null) {
510            STATE.remove();
511        } else {
512            state.collector.clear();
513            state.emitter = emptyIterator();
514        }
515    }
516
517    private String getSinglePlugin() {
518        return Optional
519                .of(EmbeddedComponentManager.class.cast(asManager()).testPlugins/* sorted */)
520                .filter(c -> !c.isEmpty())
521                .map(c -> c.iterator().next())
522                .orElseThrow(() -> new IllegalStateException("No component plugin found"));
523    }
524
525    private <T> T mapRecord(final State state, final Class<T> recordType, final Object r) {
526        if (recordType.isInstance(r)) {
527            return recordType.cast(r);
528        }
529        if (Record.class == recordType) {
530            return recordType
531                    .cast(new RecordConverters()
532                            .toRecord(state.registry, r, state::jsonb, state::recordBuilderFactory));
533        }
534        return recordType
535                .cast(new RecordConverters()
536                        .toType(state.registry, r, recordType, state::jsonBuilderFactory, state::jsonProvider,
537                                state::jsonb, state::recordBuilderFactory));
538    }
539
540    static class PreState {
541
542        Iterator<?> emitter;
543    }
544
545    @AllArgsConstructor
546    protected static class State {
547
548        final ComponentManager manager;
549
550        final Collection<Object> collector;
551
552        final RecordConverters.MappingMetaRegistry registry = new RecordConverters.MappingMetaRegistry();
553
554        Iterator<?> emitter;
555
556        volatile Jsonb jsonb;
557
558        volatile JsonProvider jsonProvider;
559
560        volatile JsonBuilderFactory jsonBuilderFactory;
561
562        volatile RecordBuilderFactory recordBuilderFactory;
563
564        synchronized Jsonb jsonb() {
565            if (jsonb == null) {
566                jsonb = manager
567                        .getJsonbProvider()
568                        .create()
569                        .withProvider(new PreComputedJsonpProvider("test", manager.getJsonpProvider(),
570                                manager.getJsonpParserFactory(), manager.getJsonpWriterFactory(),
571                                manager.getJsonpBuilderFactory(), manager.getJsonpGeneratorFactory(),
572                                manager.getJsonpReaderFactory())) // reuses the same memory buffers
573                        .withConfig(new JsonbConfig().setProperty("johnzon.cdi.activated", false))
574                        .build();
575            }
576            return jsonb;
577        }
578
579        synchronized JsonProvider jsonProvider() {
580            if (jsonProvider == null) {
581                jsonProvider = manager.getJsonpProvider();
582            }
583            return jsonProvider;
584        }
585
586        synchronized JsonBuilderFactory jsonBuilderFactory() {
587            if (jsonBuilderFactory == null) {
588                jsonBuilderFactory = manager.getJsonpBuilderFactory();
589            }
590            return jsonBuilderFactory;
591        }
592
593        synchronized RecordBuilderFactory recordBuilderFactory() {
594            if (recordBuilderFactory == null) {
595                recordBuilderFactory = manager.getRecordBuilderFactoryProvider().apply("test");
596            }
597            return recordBuilderFactory;
598        }
599    }
600
601    public static class EmbeddedComponentManager extends ComponentManager {
602
603        private final ComponentManager oldInstance;
604
605        private final List<String> testPlugins;
606
607        private EmbeddedComponentManager(final String componentPackage) {
608            super(findM2(), "TALEND-INF/dependencies.txt", "org.talend.sdk.component:type=component,value=%s");
609            testPlugins = addJarContaining(Thread.currentThread().getContextClassLoader(),
610                    componentPackage.replace('.', '/'));
611            container
612                    .builder("component-runtime-junit.jar", jarLocation(SimpleCollector.class).getAbsolutePath())
613                    .create();
614            oldInstance = CONTEXTUAL_INSTANCE.get();
615            CONTEXTUAL_INSTANCE.set(this);
616        }
617
618        @Override
619        public void close() {
620            try {
621                super.close();
622            } finally {
623                CONTEXTUAL_INSTANCE.compareAndSet(this, oldInstance);
624            }
625        }
626
627        @Override
628        protected boolean isContainerClass(final Filter filter, final String name) {
629            // embedded mode (no plugin structure) so just run with all classes in parent classloader
630            return true;
631        }
632    }
633
634    public static class Outputs {
635
636        private final Map<String, List<?>> data = new HashMap<>();
637
638        public int size() {
639            return data.size();
640        }
641
642        public Set<String> keys() {
643            return data.keySet();
644        }
645
646        public <T> List<T> get(final Class<T> type, final String name) {
647            return (List<T>) data.get(name);
648        }
649    }
650
651    interface Local<T> {
652
653        void set(T value);
654
655        T get();
656
657        void remove();
658
659        class StaticImpl<T> implements Local<T> {
660
661            private final AtomicReference<T> state = new AtomicReference<>();
662
663            @Override
664            public void set(final T value) {
665                state.set(value);
666            }
667
668            @Override
669            public T get() {
670                return state.get();
671            }
672
673            @Override
674            public void remove() {
675                state.set(null);
676            }
677        }
678
679        class ThreadLocalImpl<T> implements Local<T> {
680
681            private final ThreadLocal<T> threadLocal = new ThreadLocal<>();
682
683            @Override
684            public void set(final T value) {
685                threadLocal.set(value);
686            }
687
688            @Override
689            public T get() {
690                return threadLocal.get();
691            }
692
693            @Override
694            public void remove() {
695                threadLocal.remove();
696            }
697        }
698    }
699}