001/**
002 * Copyright (C) 2006-2019 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.junit;
017
018import static java.lang.Math.abs;
019import static java.util.Collections.emptyIterator;
020import static java.util.Collections.emptyMap;
021import static java.util.Locale.ROOT;
022import static java.util.concurrent.TimeUnit.MINUTES;
023import static java.util.concurrent.TimeUnit.SECONDS;
024import static java.util.stream.Collectors.joining;
025import static java.util.stream.Collectors.toList;
026import static org.apache.ziplock.JarLocation.jarLocation;
027import static org.junit.Assert.fail;
028import static org.talend.sdk.component.junit.SimpleFactory.configurationByExample;
029
030import java.util.ArrayList;
031import java.util.Collection;
032import java.util.HashMap;
033import java.util.HashSet;
034import java.util.Iterator;
035import java.util.List;
036import java.util.Map;
037import java.util.Objects;
038import java.util.Optional;
039import java.util.Queue;
040import java.util.Set;
041import java.util.Spliterator;
042import java.util.Spliterators;
043import java.util.concurrent.ConcurrentLinkedQueue;
044import java.util.concurrent.CountDownLatch;
045import java.util.concurrent.ExecutionException;
046import java.util.concurrent.ExecutorService;
047import java.util.concurrent.Executors;
048import java.util.concurrent.Future;
049import java.util.concurrent.Semaphore;
050import java.util.concurrent.TimeoutException;
051import java.util.concurrent.atomic.AtomicInteger;
052import java.util.concurrent.atomic.AtomicReference;
053import java.util.stream.Stream;
054import java.util.stream.StreamSupport;
055
056import javax.json.JsonBuilderFactory;
057import javax.json.JsonObject;
058import javax.json.bind.Jsonb;
059import javax.json.bind.JsonbConfig;
060import javax.json.spi.JsonProvider;
061
062import org.apache.xbean.finder.filter.Filter;
063import org.talend.sdk.component.api.record.Record;
064import org.talend.sdk.component.api.service.injector.Injector;
065import org.talend.sdk.component.api.service.record.RecordBuilderFactory;
066import org.talend.sdk.component.junit.lang.StreamDecorator;
067import org.talend.sdk.component.runtime.base.Lifecycle;
068import org.talend.sdk.component.runtime.input.Input;
069import org.talend.sdk.component.runtime.input.Mapper;
070import org.talend.sdk.component.runtime.manager.ComponentFamilyMeta;
071import org.talend.sdk.component.runtime.manager.ComponentManager;
072import org.talend.sdk.component.runtime.manager.ContainerComponentRegistry;
073import org.talend.sdk.component.runtime.manager.chain.AutoChunkProcessor;
074import org.talend.sdk.component.runtime.manager.chain.Job;
075import org.talend.sdk.component.runtime.manager.json.PreComputedJsonpProvider;
076import org.talend.sdk.component.runtime.output.OutputFactory;
077import org.talend.sdk.component.runtime.output.Processor;
078import org.talend.sdk.component.runtime.record.RecordConverters;
079
080import lombok.AllArgsConstructor;
081import lombok.extern.slf4j.Slf4j;
082
083@Slf4j
084public class BaseComponentsHandler implements ComponentsHandler {
085
086    protected static final Local<State> STATE = loadStateHolder();
087
088    private static Local<State> loadStateHolder() {
089        switch (System.getProperty("talend.component.junit.handler.state", "thread").toLowerCase(ROOT)) {
090        case "static":
091            return new Local.StaticImpl<>();
092        default:
093            return new Local.ThreadLocalImpl<>();
094        }
095    }
096
    // per-thread holder used by setInputData() while no State exists yet; start() reads its emitter
    // to seed the new State and the manager close() clears it
    private final ThreadLocal<PreState> initState = ThreadLocal.withInitial(PreState::new);

    // package passed to the EmbeddedComponentManager constructor by start()
    protected String packageName;

    // class name prefixes the manager created by start() never treats as container classes;
    // null when no isolated package was configured
    protected Collection<String> isolatedPackages;
102
103    public <T> T injectServices(final T instance) {
104        if (instance == null) {
105            return null;
106        }
107        final String plugin = getSinglePlugin();
108        final Map<Class<?>, Object> services = asManager()
109                .findPlugin(plugin)
110                .orElseThrow(() -> new IllegalArgumentException("cant find plugin '" + plugin + "'"))
111                .get(ComponentManager.AllServices.class)
112                .getServices();
113        Injector.class.cast(services.get(Injector.class)).inject(instance);
114        return instance;
115    }
116
117    public BaseComponentsHandler withIsolatedPackage(final String packageName, final String... packages) {
118        isolatedPackages =
119                Stream.concat(Stream.of(packageName), Stream.of(packages)).filter(Objects::nonNull).collect(toList());
120        if (isolatedPackages.isEmpty()) {
121            isolatedPackages = null;
122        }
123        return this;
124    }
125
126    public EmbeddedComponentManager start() {
127        final EmbeddedComponentManager embeddedComponentManager = new EmbeddedComponentManager(packageName) {
128
129            @Override
130            protected boolean isContainerClass(final Filter filter, final String name) {
131                if (name == null) {
132                    return super.isContainerClass(filter, null);
133                }
134                return (isolatedPackages == null || isolatedPackages.stream().noneMatch(name::startsWith))
135                        && super.isContainerClass(filter, name);
136            }
137
138            @Override
139            public void close() {
140                try {
141                    final State state = STATE.get();
142                    if (state.jsonb != null) {
143                        try {
144                            state.jsonb.close();
145                        } catch (final Exception e) {
146                            // no-op: not important
147                        }
148                    }
149                    STATE.remove();
150                    initState.remove();
151                } finally {
152                    super.close();
153                }
154            }
155        };
156
157        STATE
158                .set(new State(embeddedComponentManager, new ArrayList<>(), initState.get().emitter, null, null, null,
159                        null));
160        return embeddedComponentManager;
161    }
162
163    @Override
164    public Outputs collect(final Processor processor, final ControllableInputFactory inputs) {
165        return collect(processor, inputs, 10);
166    }
167
168    /**
169     * Collects all outputs of a processor.
170     *
171     * @param processor the processor to run while there are inputs.
172     * @param inputs the input factory, when an input will return null it will stop the
173     * processing.
174     * @param bundleSize the bundle size to use.
175     * @return a map where the key is the output name and the value a stream of the
176     * output values.
177     */
178    @Override
179    public Outputs collect(final Processor processor, final ControllableInputFactory inputs, final int bundleSize) {
180        final AutoChunkProcessor autoChunkProcessor = new AutoChunkProcessor(bundleSize, processor);
181        autoChunkProcessor.start();
182        final Outputs outputs = new Outputs();
183        final OutputFactory outputFactory = name -> value -> {
184            final List aggregator = outputs.data.computeIfAbsent(name, n -> new ArrayList<>());
185            aggregator.add(value);
186        };
187        try {
188            while (inputs.hasMoreData()) {
189                autoChunkProcessor.onElement(inputs, outputFactory);
190            }
191            autoChunkProcessor.flush(outputFactory);
192        } finally {
193            autoChunkProcessor.stop();
194        }
195        return outputs;
196    }
197
198    @Override
199    public <T> Stream<T> collect(final Class<T> recordType, final Mapper mapper, final int maxRecords) {
200        return collect(recordType, mapper, maxRecords, Runtime.getRuntime().availableProcessors());
201    }
202
203    /**
204     * Collects data emitted from this mapper. If the split creates more than one
205     * mapper, it will create as much threads as mappers otherwise it will use the
206     * caller thread.
207     *
208     * IMPORTANT: don't forget to consume all the stream to ensure the underlying
209     * { @see org.talend.sdk.component.runtime.input.Input} is closed.
210     *
211     * @param recordType the record type to use to type the returned type.
212     * @param mapper the mapper to go through.
213     * @param maxRecords maximum number of records, allows to stop the source when
214     * infinite.
215     * @param concurrency requested (1 can be used instead if &lt;= 0) concurrency for the reader execution.
216     * @param <T> the returned type of the records of the mapper.
217     * @return all the records emitted by the mapper.
218     */
219    @Override
220    public <T> Stream<T> collect(final Class<T> recordType, final Mapper mapper, final int maxRecords,
221            final int concurrency) {
222        mapper.start();
223
224        final State state = STATE.get();
225        final long assess = mapper.assess();
226        final int proc = Math.max(1, concurrency);
227        final List<Mapper> mappers = mapper.split(Math.max(assess / proc, 1));
228        switch (mappers.size()) {
229        case 0:
230            return Stream.empty();
231        case 1:
232            return StreamDecorator
233                    .decorate(asStream(asIterator(mappers.iterator().next().create(), new AtomicInteger(maxRecords))),
234                            collect -> {
235                                try {
236                                    collect.run();
237                                } finally {
238                                    mapper.stop();
239                                }
240                            });
241        default: // N producers-1 consumer pattern
242            final AtomicInteger threadCounter = new AtomicInteger(0);
243            final ExecutorService es = Executors.newFixedThreadPool(mappers.size(), r -> new Thread(r) {
244
245                {
246                    setName(BaseComponentsHandler.this.getClass().getSimpleName() + "-pool-" + abs(mapper.hashCode())
247                            + "-" + threadCounter.incrementAndGet());
248                }
249            });
250            final AtomicInteger recordCounter = new AtomicInteger(maxRecords);
251            final Semaphore permissions = new Semaphore(0);
252            final Queue<T> records = new ConcurrentLinkedQueue<>();
253            final CountDownLatch latch = new CountDownLatch(mappers.size());
254            final List<? extends Future<?>> tasks = mappers
255                    .stream()
256                    .map(Mapper::create)
257                    .map(input -> (Iterator<T>) asIterator(input, recordCounter))
258                    .map(it -> es.submit(() -> {
259                        try {
260                            while (it.hasNext()) {
261                                final T next = it.next();
262                                records.add(next);
263                                permissions.release();
264                            }
265                        } finally {
266                            latch.countDown();
267                        }
268                    }))
269                    .collect(toList());
270            es.shutdown();
271
272            final int timeout = Integer.getInteger("talend.component.junit.timeout", 5);
273            new Thread() {
274
275                {
276                    setName(BaseComponentsHandler.class.getSimpleName() + "-monitor_" + abs(mapper.hashCode()));
277                }
278
279                @Override
280                public void run() {
281                    try {
282                        latch.await(timeout, MINUTES);
283                    } catch (final InterruptedException e) {
284                        Thread.currentThread().interrupt();
285                    } finally {
286                        permissions.release();
287                    }
288                }
289            }.start();
290            return StreamDecorator.decorate(asStream(new Iterator<T>() {
291
292                @Override
293                public boolean hasNext() {
294                    try {
295                        permissions.acquire();
296                    } catch (final InterruptedException e) {
297                        Thread.currentThread().interrupt();
298                        fail(e.getMessage());
299                    }
300                    return !records.isEmpty();
301                }
302
303                @Override
304                public T next() {
305                    T poll = records.poll();
306                    if (poll != null) {
307                        return mapRecord(state, recordType, poll);
308                    }
309                    return null;
310                }
311            }), task -> {
312                try {
313                    task.run();
314                } finally {
315                    tasks.forEach(f -> {
316                        try {
317                            f.get(5, SECONDS);
318                        } catch (final InterruptedException e) {
319                            Thread.currentThread().interrupt();
320                        } catch (final ExecutionException | TimeoutException e) {
321                            // no-op
322                        } finally {
323                            if (!f.isDone() && !f.isCancelled()) {
324                                f.cancel(true);
325                            }
326                        }
327                    });
328                }
329            });
330        }
331    }
332
333    private <T> Stream<T> asStream(final Iterator<T> iterator) {
334        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.IMMUTABLE), false);
335    }
336
337    private <T> Iterator<T> asIterator(final Input input, final AtomicInteger counter) {
338        input.start();
339        return new Iterator<T>() {
340
341            private boolean closed;
342
343            private Object next;
344
345            @Override
346            public boolean hasNext() {
347                final int remaining = counter.get();
348                if (remaining <= 0) {
349                    return false;
350                }
351
352                final boolean hasNext = (next = input.next()) != null;
353                if (!hasNext && !closed) {
354                    closed = true;
355                    input.stop();
356                }
357                if (hasNext) {
358                    counter.decrementAndGet();
359                }
360                return hasNext;
361            }
362
363            @Override
364            public T next() {
365                return (T) next;
366            }
367        };
368    }
369
370    @Override
371    public <T> List<T> collectAsList(final Class<T> recordType, final Mapper mapper) {
372        return collectAsList(recordType, mapper, 1000);
373    }
374
375    @Override
376    public <T> List<T> collectAsList(final Class<T> recordType, final Mapper mapper, final int maxRecords) {
377        return collect(recordType, mapper, maxRecords).collect(toList());
378    }
379
380    @Override
381    public Mapper createMapper(final Class<?> componentType, final Object configuration) {
382        return create(Mapper.class, componentType, configuration);
383    }
384
385    @Override
386    public Processor createProcessor(final Class<?> componentType, final Object configuration) {
387        return create(Processor.class, componentType, configuration);
388    }
389
390    private <C, T, A> A create(final Class<A> api, final Class<T> componentType, final C configuration) {
391        final ComponentFamilyMeta.BaseMeta<? extends Lifecycle> meta = findMeta(componentType);
392        return api
393                .cast(meta
394                        .getInstantiator()
395                        .apply(configuration == null || meta.getParameterMetas().get().isEmpty() ? emptyMap()
396                                : configurationByExample(configuration, meta
397                                        .getParameterMetas()
398                                        .get()
399                                        .stream()
400                                        .filter(p -> p.getName().equals(p.getPath()))
401                                        .findFirst()
402                                        .map(p -> p.getName() + '.')
403                                        .orElseThrow(() -> new IllegalArgumentException(
404                                                "Didn't find any option and therefore "
405                                                        + "can't convert the configuration instance to a configuration")))));
406    }
407
408    private <T> ComponentFamilyMeta.BaseMeta<? extends Lifecycle> findMeta(final Class<T> componentType) {
409        return asManager()
410                .find(c -> c.get(ContainerComponentRegistry.class).getComponents().values().stream())
411                .flatMap(f -> Stream
412                        .concat(f.getProcessors().values().stream(), f.getPartitionMappers().values().stream()))
413                .filter(m -> m.getType().getName().equals(componentType.getName()))
414                .findFirst()
415                .orElseThrow(() -> new IllegalArgumentException("No component " + componentType));
416    }
417
418    @Override
419    public <T> List<T> collect(final Class<T> recordType, final String family, final String component,
420            final int version, final Map<String, String> configuration) {
421        Job
422                .components()
423                .component("in",
424                        family + "://" + component + "?__version=" + version
425                                + configuration
426                                        .entrySet()
427                                        .stream()
428                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
429                                        .collect(joining("&", "&", "")))
430                .component("collector", "test://collector")
431                .connections()
432                .from("in")
433                .to("collector")
434                .build()
435                .run();
436
437        return getCollectedData(recordType);
438    }
439
440    @Override
441    public <T> void process(final Iterable<T> inputs, final String family, final String component, final int version,
442            final Map<String, String> configuration) {
443        setInputData(inputs);
444
445        Job
446                .components()
447                .component("emitter", "test://emitter")
448                .component("out",
449                        family + "://" + component + "?__version=" + version
450                                + configuration
451                                        .entrySet()
452                                        .stream()
453                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
454                                        .collect(joining("&", "&", "")))
455                .connections()
456                .from("emitter")
457                .to("out")
458                .build()
459                .run();
460
461    }
462
463    @Override
464    public ComponentManager asManager() {
465        return STATE.get().manager;
466    }
467
468    @Override
469    public <T> T findService(final String plugin, final Class<T> serviceClass) {
470        return serviceClass
471                .cast(asManager()
472                        .findPlugin(plugin)
473                        .orElseThrow(() -> new IllegalArgumentException("cant find plugin '" + plugin + "'"))
474                        .get(ComponentManager.AllServices.class)
475                        .getServices()
476                        .get(serviceClass));
477    }
478
479    @Override
480    public <T> T findService(final Class<T> serviceClass) {
481        return findService(getSinglePlugin(), serviceClass);
482    }
483
484    public Set<String> getTestPlugins() {
485        return new HashSet<>(EmbeddedComponentManager.class.cast(asManager()).testPlugins);
486    }
487
488    @Override
489    public <T> void setInputData(final Iterable<T> data) {
490        final State state = STATE.get();
491        if (state == null) {
492            initState.get().emitter = data.iterator();
493        } else {
494            state.emitter = data.iterator();
495        }
496    }
497
498    @Override
499    public <T> List<T> getCollectedData(final Class<T> recordType) {
500        final State state = STATE.get();
501        return state.collector
502                .stream()
503                .filter(r -> recordType.isInstance(r) || JsonObject.class.isInstance(r) || Record.class.isInstance(r))
504                .map(r -> mapRecord(state, recordType, r))
505                .collect(toList());
506    }
507
508    public void resetState() {
509        final State state = STATE.get();
510        if (state == null) {
511            STATE.remove();
512        } else {
513            state.collector.clear();
514            state.emitter = emptyIterator();
515        }
516    }
517
518    private String getSinglePlugin() {
519        return Optional
520                .of(EmbeddedComponentManager.class.cast(asManager()).testPlugins/* sorted */)
521                .filter(c -> !c.isEmpty())
522                .map(c -> c.iterator().next())
523                .orElseThrow(() -> new IllegalStateException("No component plugin found"));
524    }
525
526    private <T> T mapRecord(final State state, final Class<T> recordType, final Object r) {
527        if (recordType.isInstance(r)) {
528            return recordType.cast(r);
529        }
530        if (Record.class == recordType) {
531            return recordType
532                    .cast(new RecordConverters()
533                            .toRecord(state.registry, r, state::jsonb, state::recordBuilderFactory));
534        }
535        return recordType
536                .cast(new RecordConverters()
537                        .toType(state.registry, r, recordType, state::jsonBuilderFactory, state::jsonProvider,
538                                state::jsonb, state::recordBuilderFactory));
539    }
540
    /**
     * Data registered before the {@link State} exists: {@code setInputData} fills it when no
     * state is available and {@code start()} reads it to initialize the state.
     */
    static class PreState {

        // source of the records to emit, stays null until setInputData() is called
        Iterator<?> emitter;
    }
545
546    @AllArgsConstructor
547    protected static class State {
548
549        final ComponentManager manager;
550
551        final Collection<Object> collector;
552
553        final RecordConverters.MappingMetaRegistry registry = new RecordConverters.MappingMetaRegistry();
554
555        Iterator<?> emitter;
556
557        volatile Jsonb jsonb;
558
559        volatile JsonProvider jsonProvider;
560
561        volatile JsonBuilderFactory jsonBuilderFactory;
562
563        volatile RecordBuilderFactory recordBuilderFactory;
564
565        synchronized Jsonb jsonb() {
566            if (jsonb == null) {
567                jsonb = manager
568                        .getJsonbProvider()
569                        .create()
570                        .withProvider(new PreComputedJsonpProvider("test", manager.getJsonpProvider(),
571                                manager.getJsonpParserFactory(), manager.getJsonpWriterFactory(),
572                                manager.getJsonpBuilderFactory(), manager.getJsonpGeneratorFactory(),
573                                manager.getJsonpReaderFactory())) // reuses the same memory buffers
574                        .withConfig(new JsonbConfig().setProperty("johnzon.cdi.activated", false))
575                        .build();
576            }
577            return jsonb;
578        }
579
580        synchronized JsonProvider jsonProvider() {
581            if (jsonProvider == null) {
582                jsonProvider = manager.getJsonpProvider();
583            }
584            return jsonProvider;
585        }
586
587        synchronized JsonBuilderFactory jsonBuilderFactory() {
588            if (jsonBuilderFactory == null) {
589                jsonBuilderFactory = manager.getJsonpBuilderFactory();
590            }
591            return jsonBuilderFactory;
592        }
593
594        synchronized RecordBuilderFactory recordBuilderFactory() {
595            if (recordBuilderFactory == null) {
596                recordBuilderFactory = manager.getRecordBuilderFactoryProvider().apply("test");
597            }
598            return recordBuilderFactory;
599        }
600    }
601
    /**
     * Component manager used by the tests: it registers the jar containing the tested package
     * and the junit module jar, and installs itself as {@code CONTEXTUAL_INSTANCE} until closed.
     */
    public static class EmbeddedComponentManager extends ComponentManager {

        // value of CONTEXTUAL_INSTANCE when this manager was created, restored by close()
        private final ComponentManager oldInstance;

        // result of addJarContaining() for the tested package
        private final List<String> testPlugins;

        /**
         * @param componentPackage the package of the tested components, its dots are replaced by
         * slashes before it is passed to {@code addJarContaining}.
         */
        private EmbeddedComponentManager(final String componentPackage) {
            super(findM2(), "TALEND-INF/dependencies.txt", "org.talend.sdk.component:type=component,value=%s");
            testPlugins = addJarContaining(Thread.currentThread().getContextClassLoader(),
                    componentPackage.replace('.', '/'));
            // registers the jar containing SimpleCollector under the "component-runtime-junit.jar" id
            container
                    .builder("component-runtime-junit.jar", jarLocation(SimpleCollector.class).getAbsolutePath())
                    .create();
            // keep the previous contextual manager to be able to restore it in close()
            oldInstance = CONTEXTUAL_INSTANCE.get();
            CONTEXTUAL_INSTANCE.set(this);
        }

        /**
         * Closes the manager and puts back the previous contextual instance, only if this
         * manager is still the contextual one.
         */
        @Override
        public void close() {
            try {
                super.close();
            } finally {
                CONTEXTUAL_INSTANCE.compareAndSet(this, oldInstance);
            }
        }

        @Override
        protected boolean isContainerClass(final Filter filter, final String name) {
            // embedded mode (no plugin structure) so just run with all classes in parent classloader
            return true;
        }
    }
634
635    public static class Outputs {
636
637        private final Map<String, List<?>> data = new HashMap<>();
638
639        public int size() {
640            return data.size();
641        }
642
643        public Set<String> keys() {
644            return data.keySet();
645        }
646
647        public <T> List<T> get(final Class<T> type, final String name) {
648            return (List<T>) data.get(name);
649        }
650    }
651
652    interface Local<T> {
653
654        void set(T value);
655
656        T get();
657
658        void remove();
659
660        class StaticImpl<T> implements Local<T> {
661
662            private final AtomicReference<T> state = new AtomicReference<>();
663
664            @Override
665            public void set(final T value) {
666                state.set(value);
667            }
668
669            @Override
670            public T get() {
671                return state.get();
672            }
673
674            @Override
675            public void remove() {
676                state.set(null);
677            }
678        }
679
680        class ThreadLocalImpl<T> implements Local<T> {
681
682            private final ThreadLocal<T> threadLocal = new ThreadLocal<>();
683
684            @Override
685            public void set(final T value) {
686                threadLocal.set(value);
687            }
688
689            @Override
690            public T get() {
691                return threadLocal.get();
692            }
693
694            @Override
695            public void remove() {
696                threadLocal.remove();
697            }
698        }
699    }
700}