001/**
002 * Copyright (C) 2006-2022 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.junit;
017
018import static java.lang.Math.abs;
019import static java.util.Collections.emptyIterator;
020import static java.util.Collections.emptyMap;
021import static java.util.Locale.ROOT;
022import static java.util.concurrent.TimeUnit.MINUTES;
023import static java.util.concurrent.TimeUnit.SECONDS;
024import static java.util.stream.Collectors.joining;
025import static java.util.stream.Collectors.toList;
026import static org.apache.ziplock.JarLocation.jarLocation;
027import static org.junit.Assert.fail;
028import static org.talend.sdk.component.junit.SimpleFactory.configurationByExample;
029
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
056
057import javax.json.JsonBuilderFactory;
058import javax.json.JsonObject;
059import javax.json.bind.Jsonb;
060import javax.json.bind.JsonbConfig;
061import javax.json.spi.JsonProvider;
062
063import org.apache.xbean.finder.filter.Filter;
064import org.talend.sdk.component.api.record.Record;
065import org.talend.sdk.component.api.service.injector.Injector;
066import org.talend.sdk.component.api.service.record.RecordBuilderFactory;
067import org.talend.sdk.component.junit.lang.StreamDecorator;
068import org.talend.sdk.component.runtime.base.Lifecycle;
069import org.talend.sdk.component.runtime.input.Input;
070import org.talend.sdk.component.runtime.input.Mapper;
071import org.talend.sdk.component.runtime.manager.ComponentFamilyMeta;
072import org.talend.sdk.component.runtime.manager.ComponentManager;
073import org.talend.sdk.component.runtime.manager.ContainerComponentRegistry;
074import org.talend.sdk.component.runtime.manager.chain.AutoChunkProcessor;
075import org.talend.sdk.component.runtime.manager.chain.Job;
076import org.talend.sdk.component.runtime.manager.json.PreComputedJsonpProvider;
077import org.talend.sdk.component.runtime.output.OutputFactory;
078import org.talend.sdk.component.runtime.output.Processor;
079import org.talend.sdk.component.runtime.record.RecordConverters;
080
081import lombok.AllArgsConstructor;
082import lombok.extern.slf4j.Slf4j;
083
084@Slf4j
085public class BaseComponentsHandler implements ComponentsHandler {
086
087    protected static final Local<State> STATE = loadStateHolder();
088
089    private static Local<State> loadStateHolder() {
090        switch (System.getProperty("talend.component.junit.handler.state", "thread").toLowerCase(ROOT)) {
091        case "static":
092            return new Local.StaticImpl<>();
093        default:
094            return new Local.ThreadLocalImpl<>();
095        }
096    }
097
098    private final ThreadLocal<PreState> initState = ThreadLocal.withInitial(PreState::new);
099
100    protected String packageName;
101
102    protected Collection<String> isolatedPackages;
103
104    public <T> T injectServices(final T instance) {
105        if (instance == null) {
106            return null;
107        }
108        final String plugin = getSinglePlugin();
109        final Map<Class<?>, Object> services = asManager()
110                .findPlugin(plugin)
111                .orElseThrow(() -> new IllegalArgumentException("cant find plugin '" + plugin + "'"))
112                .get(ComponentManager.AllServices.class)
113                .getServices();
114        Injector.class.cast(services.get(Injector.class)).inject(instance);
115        return instance;
116    }
117
118    public BaseComponentsHandler withIsolatedPackage(final String packageName, final String... packages) {
119        isolatedPackages =
120                Stream.concat(Stream.of(packageName), Stream.of(packages)).filter(Objects::nonNull).collect(toList());
121        if (isolatedPackages.isEmpty()) {
122            isolatedPackages = null;
123        }
124        return this;
125    }
126
127    public EmbeddedComponentManager start() {
128        final EmbeddedComponentManager embeddedComponentManager = new EmbeddedComponentManager(packageName) {
129
130            @Override
131            protected boolean isContainerClass(final Filter filter, final String name) {
132                if (name == null) {
133                    return super.isContainerClass(filter, null);
134                }
135                return (isolatedPackages == null || isolatedPackages.stream().noneMatch(name::startsWith))
136                        && super.isContainerClass(filter, name);
137            }
138
139            @Override
140            public void close() {
141                try {
142                    final State state = STATE.get();
143                    if (state.jsonb != null) {
144                        try {
145                            state.jsonb.close();
146                        } catch (final Exception e) {
147                            // no-op: not important
148                        }
149                    }
150                    STATE.remove();
151                    initState.remove();
152                } finally {
153                    super.close();
154                }
155            }
156        };
157
158        STATE
159                .set(new State(embeddedComponentManager, new CopyOnWriteArrayList<>(), initState.get().emitter, null,
160                        null, null, null));
161        return embeddedComponentManager;
162    }
163
164    @Override
165    public Outputs collect(final Processor processor, final ControllableInputFactory inputs) {
166        return collect(processor, inputs, 10);
167    }
168
169    /**
170     * Collects all outputs of a processor.
171     *
172     * @param processor the processor to run while there are inputs.
173     * @param inputs the input factory, when an input will return null it will stop the
174     * processing.
175     * @param bundleSize the bundle size to use.
176     * @return a map where the key is the output name and the value a stream of the
177     * output values.
178     */
179    @Override
180    public Outputs collect(final Processor processor, final ControllableInputFactory inputs, final int bundleSize) {
181        final AutoChunkProcessor autoChunkProcessor = new AutoChunkProcessor(bundleSize, processor);
182        autoChunkProcessor.start();
183        final Outputs outputs = new Outputs();
184        final OutputFactory outputFactory = name -> value -> {
185            final List aggregator = outputs.data.computeIfAbsent(name, n -> new ArrayList<>());
186            aggregator.add(value);
187        };
188        try {
189            while (inputs.hasMoreData()) {
190                autoChunkProcessor.onElement(inputs, outputFactory);
191            }
192            autoChunkProcessor.flush(outputFactory);
193        } finally {
194            autoChunkProcessor.stop();
195        }
196        return outputs;
197    }
198
199    @Override
200    public <T> Stream<T> collect(final Class<T> recordType, final Mapper mapper, final int maxRecords) {
201        return collect(recordType, mapper, maxRecords, Runtime.getRuntime().availableProcessors());
202    }
203
204    /**
205     * Collects data emitted from this mapper. If the split creates more than one
206     * mapper, it will create as much threads as mappers otherwise it will use the
207     * caller thread.
208     *
209     * IMPORTANT: don't forget to consume all the stream to ensure the underlying
210     * { @see org.talend.sdk.component.runtime.input.Input} is closed.
211     *
212     * @param recordType the record type to use to type the returned type.
213     * @param mapper the mapper to go through.
214     * @param maxRecords maximum number of records, allows to stop the source when
215     * infinite.
216     * @param concurrency requested (1 can be used instead if &lt;= 0) concurrency for the reader execution.
217     * @param <T> the returned type of the records of the mapper.
218     * @return all the records emitted by the mapper.
219     */
220    @Override
221    public <T> Stream<T> collect(final Class<T> recordType, final Mapper mapper, final int maxRecords,
222            final int concurrency) {
223        mapper.start();
224
225        final State state = STATE.get();
226        final long assess = mapper.assess();
227        final int proc = Math.max(1, concurrency);
228        final List<Mapper> mappers = mapper.split(Math.max(assess / proc, 1));
229        switch (mappers.size()) {
230        case 0:
231            return Stream.empty();
232        case 1:
233            return StreamDecorator
234                    .decorate(asStream(asIterator(mappers.iterator().next().create(), new AtomicInteger(maxRecords))),
235                            collect -> {
236                                try {
237                                    collect.run();
238                                } finally {
239                                    mapper.stop();
240                                }
241                            });
242        default: // N producers-1 consumer pattern
243            final AtomicInteger threadCounter = new AtomicInteger(0);
244            final ExecutorService es = Executors.newFixedThreadPool(mappers.size(), r -> new Thread(r) {
245
246                {
247                    setName(BaseComponentsHandler.this.getClass().getSimpleName() + "-pool-" + abs(mapper.hashCode())
248                            + "-" + threadCounter.incrementAndGet());
249                }
250            });
251            final AtomicInteger recordCounter = new AtomicInteger(maxRecords);
252            final Semaphore permissions = new Semaphore(0);
253            final Queue<T> records = new ConcurrentLinkedQueue<>();
254            final CountDownLatch latch = new CountDownLatch(mappers.size());
255            final List<? extends Future<?>> tasks = mappers
256                    .stream()
257                    .map(Mapper::create)
258                    .map(input -> (Iterator<T>) asIterator(input, recordCounter))
259                    .map(it -> es.submit(() -> {
260                        try {
261                            while (it.hasNext()) {
262                                final T next = it.next();
263                                records.add(next);
264                                permissions.release();
265                            }
266                        } finally {
267                            latch.countDown();
268                        }
269                    }))
270                    .collect(toList());
271            es.shutdown();
272
273            final int timeout = Integer.getInteger("talend.component.junit.timeout", 5);
274            new Thread() {
275
276                {
277                    setName(BaseComponentsHandler.class.getSimpleName() + "-monitor_" + abs(mapper.hashCode()));
278                }
279
280                @Override
281                public void run() {
282                    try {
283                        latch.await(timeout, MINUTES);
284                    } catch (final InterruptedException e) {
285                        Thread.currentThread().interrupt();
286                    } finally {
287                        permissions.release();
288                    }
289                }
290            }.start();
291            return StreamDecorator.decorate(asStream(new Iterator<T>() {
292
293                @Override
294                public boolean hasNext() {
295                    try {
296                        permissions.acquire();
297                    } catch (final InterruptedException e) {
298                        Thread.currentThread().interrupt();
299                        fail(e.getMessage());
300                    }
301                    return !records.isEmpty();
302                }
303
304                @Override
305                public T next() {
306                    final T poll = records.poll();
307                    if (poll != null) {
308                        return mapRecord(state, recordType, poll);
309                    }
310                    return null;
311                }
312            }), task -> {
313                try {
314                    task.run();
315                } finally {
316                    tasks.forEach(f -> {
317                        try {
318                            f.get(5, SECONDS);
319                        } catch (final InterruptedException e) {
320                            Thread.currentThread().interrupt();
321                        } catch (final ExecutionException | TimeoutException e) {
322                            // no-op
323                        } finally {
324                            if (!f.isDone() && !f.isCancelled()) {
325                                f.cancel(true);
326                            }
327                        }
328                    });
329                }
330            });
331        }
332    }
333
334    private <T> Stream<T> asStream(final Iterator<T> iterator) {
335        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.IMMUTABLE), false);
336    }
337
338    private <T> Iterator<T> asIterator(final Input input, final AtomicInteger counter) {
339        input.start();
340        return new Iterator<T>() {
341
342            private boolean closed;
343
344            private Object next;
345
346            @Override
347            public boolean hasNext() {
348                final int remaining = counter.get();
349                if (remaining <= 0) {
350                    return false;
351                }
352
353                final boolean hasNext = (next = input.next()) != null;
354                if (!hasNext && !closed) {
355                    closed = true;
356                    input.stop();
357                }
358                if (hasNext) {
359                    counter.decrementAndGet();
360                }
361                return hasNext;
362            }
363
364            @Override
365            public T next() {
366                return (T) next;
367            }
368        };
369    }
370
371    @Override
372    public <T> List<T> collectAsList(final Class<T> recordType, final Mapper mapper) {
373        return collectAsList(recordType, mapper, 1000);
374    }
375
376    @Override
377    public <T> List<T> collectAsList(final Class<T> recordType, final Mapper mapper, final int maxRecords) {
378        return collect(recordType, mapper, maxRecords).collect(toList());
379    }
380
381    @Override
382    public Mapper createMapper(final Class<?> componentType, final Object configuration) {
383        return create(Mapper.class, componentType, configuration);
384    }
385
386    @Override
387    public Processor createProcessor(final Class<?> componentType, final Object configuration) {
388        return create(Processor.class, componentType, configuration);
389    }
390
391    private <C, T, A> A create(final Class<A> api, final Class<T> componentType, final C configuration) {
392        final ComponentFamilyMeta.BaseMeta<? extends Lifecycle> meta = findMeta(componentType);
393        return api
394                .cast(meta
395                        .getInstantiator()
396                        .apply(configuration == null || meta.getParameterMetas().get().isEmpty() ? emptyMap()
397                                : configurationByExample(configuration, meta
398                                        .getParameterMetas()
399                                        .get()
400                                        .stream()
401                                        .filter(p -> p.getName().equals(p.getPath()))
402                                        .findFirst()
403                                        .map(p -> p.getName() + '.')
404                                        .orElseThrow(() -> new IllegalArgumentException(
405                                                "Didn't find any option and therefore "
406                                                        + "can't convert the configuration instance to a configuration")))));
407    }
408
409    private <T> ComponentFamilyMeta.BaseMeta<? extends Lifecycle> findMeta(final Class<T> componentType) {
410        return asManager()
411                .find(c -> c.get(ContainerComponentRegistry.class).getComponents().values().stream())
412                .flatMap(f -> Stream
413                        .of(f.getProcessors().values().stream(), f.getPartitionMappers().values().stream(),
414                                f.getDriverRunners().values().stream())
415                        .flatMap(t -> t))
416                .filter(m -> m.getType().getName().equals(componentType.getName()))
417                .findFirst()
418                .orElseThrow(() -> new IllegalArgumentException("No component " + componentType));
419    }
420
421    @Override
422    public <T> List<T> collect(final Class<T> recordType, final String family, final String component,
423            final int version, final Map<String, String> configuration) {
424        Job
425                .components()
426                .component("in",
427                        family + "://" + component + "?__version=" + version
428                                + configuration
429                                        .entrySet()
430                                        .stream()
431                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
432                                        .collect(joining("&", "&", "")))
433                .component("collector", "test://collector")
434                .connections()
435                .from("in")
436                .to("collector")
437                .build()
438                .run();
439
440        return getCollectedData(recordType);
441    }
442
443    @Override
444    public <T> void process(final Iterable<T> inputs, final String family, final String component, final int version,
445            final Map<String, String> configuration) {
446        setInputData(inputs);
447
448        Job
449                .components()
450                .component("emitter", "test://emitter")
451                .component("out",
452                        family + "://" + component + "?__version=" + version
453                                + configuration
454                                        .entrySet()
455                                        .stream()
456                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
457                                        .collect(joining("&", "&", "")))
458                .connections()
459                .from("emitter")
460                .to("out")
461                .build()
462                .run();
463
464    }
465
466    @Override
467    public ComponentManager asManager() {
468        return STATE.get().manager;
469    }
470
471    @Override
472    public <T> T findService(final String plugin, final Class<T> serviceClass) {
473        return serviceClass
474                .cast(asManager()
475                        .findPlugin(plugin)
476                        .orElseThrow(() -> new IllegalArgumentException("cant find plugin '" + plugin + "'"))
477                        .get(ComponentManager.AllServices.class)
478                        .getServices()
479                        .get(serviceClass));
480    }
481
482    @Override
483    public <T> T findService(final Class<T> serviceClass) {
484        return findService(getSinglePlugin(), serviceClass);
485    }
486
487    public Set<String> getTestPlugins() {
488        return new HashSet<>(EmbeddedComponentManager.class.cast(asManager()).testPlugins);
489    }
490
491    @Override
492    public <T> void setInputData(final Iterable<T> data) {
493        final State state = STATE.get();
494        if (state == null) {
495            initState.get().emitter = data.iterator();
496        } else {
497            state.emitter = data.iterator();
498        }
499    }
500
501    @Override
502    public <T> List<T> getCollectedData(final Class<T> recordType) {
503        final State state = STATE.get();
504        return state.collector
505                .stream()
506                .filter(r -> recordType.isInstance(r) || JsonObject.class.isInstance(r) || Record.class.isInstance(r))
507                .map(r -> mapRecord(state, recordType, r))
508                .collect(toList());
509    }
510
511    public void resetState() {
512        final State state = STATE.get();
513        if (state == null) {
514            STATE.remove();
515        } else {
516            state.collector.clear();
517            state.emitter = emptyIterator();
518        }
519    }
520
521    private String getSinglePlugin() {
522        return Optional
523                .of(EmbeddedComponentManager.class.cast(asManager()).testPlugins/* sorted */)
524                .filter(c -> !c.isEmpty())
525                .map(c -> c.iterator().next())
526                .orElseThrow(() -> new IllegalStateException("No component plugin found"));
527    }
528
529    private <T> T mapRecord(final State state, final Class<T> recordType, final Object r) {
530        if (recordType.isInstance(r)) {
531            return recordType.cast(r);
532        }
533        if (Record.class == recordType) {
534            return recordType
535                    .cast(new RecordConverters()
536                            .toRecord(state.registry, r, state::jsonb, state::recordBuilderFactory));
537        }
538        return recordType
539                .cast(new RecordConverters()
540                        .toType(state.registry, r, recordType, state::jsonBuilderFactory, state::jsonProvider,
541                                state::jsonb, state::recordBuilderFactory));
542    }
543
544    static class PreState {
545
546        Iterator<?> emitter;
547    }
548
549    @AllArgsConstructor
550    protected static class State {
551
552        final ComponentManager manager;
553
554        final Collection<Object> collector;
555
556        final RecordConverters.MappingMetaRegistry registry = new RecordConverters.MappingMetaRegistry();
557
558        Iterator<?> emitter;
559
560        volatile Jsonb jsonb;
561
562        volatile JsonProvider jsonProvider;
563
564        volatile JsonBuilderFactory jsonBuilderFactory;
565
566        volatile RecordBuilderFactory recordBuilderFactory;
567
568        synchronized Jsonb jsonb() {
569            if (jsonb == null) {
570                jsonb = manager
571                        .getJsonbProvider()
572                        .create()
573                        .withProvider(new PreComputedJsonpProvider("test", manager.getJsonpProvider(),
574                                manager.getJsonpParserFactory(), manager.getJsonpWriterFactory(),
575                                manager.getJsonpBuilderFactory(), manager.getJsonpGeneratorFactory(),
576                                manager.getJsonpReaderFactory())) // reuses the same memory buffers
577                        .withConfig(new JsonbConfig().setProperty("johnzon.cdi.activated", false))
578                        .build();
579            }
580            return jsonb;
581        }
582
583        synchronized JsonProvider jsonProvider() {
584            if (jsonProvider == null) {
585                jsonProvider = manager.getJsonpProvider();
586            }
587            return jsonProvider;
588        }
589
590        synchronized JsonBuilderFactory jsonBuilderFactory() {
591            if (jsonBuilderFactory == null) {
592                jsonBuilderFactory = manager.getJsonpBuilderFactory();
593            }
594            return jsonBuilderFactory;
595        }
596
597        synchronized RecordBuilderFactory recordBuilderFactory() {
598            if (recordBuilderFactory == null) {
599                recordBuilderFactory = manager.getRecordBuilderFactoryProvider().apply("test");
600            }
601            return recordBuilderFactory;
602        }
603    }
604
605    public static class EmbeddedComponentManager extends ComponentManager {
606
607        private final ComponentManager oldInstance;
608
609        private final List<String> testPlugins;
610
611        private EmbeddedComponentManager(final String componentPackage) {
612            super(findM2(), "TALEND-INF/dependencies.txt", "org.talend.sdk.component:type=component,value=%s");
613            testPlugins = addJarContaining(Thread.currentThread().getContextClassLoader(),
614                    componentPackage.replace('.', '/'));
615            container
616                    .builder("component-runtime-junit.jar", jarLocation(SimpleCollector.class).getAbsolutePath())
617                    .create();
618            oldInstance = ComponentManager.contextualInstance().get();
619            ComponentManager.contextualInstance().set(this);
620        }
621
622        @Override
623        public void close() {
624            try {
625                super.close();
626            } finally {
627                ComponentManager.contextualInstance().compareAndSet(this, oldInstance);
628            }
629        }
630
631        @Override
632        protected boolean isContainerClass(final Filter filter, final String name) {
633            // embedded mode (no plugin structure) so just run with all classes in parent classloader
634            return true;
635        }
636    }
637
638    public static class Outputs {
639
640        private final Map<String, List<?>> data = new HashMap<>();
641
642        public int size() {
643            return data.size();
644        }
645
646        public Set<String> keys() {
647            return data.keySet();
648        }
649
650        public <T> List<T> get(final Class<T> type, final String name) {
651            return (List<T>) data.get(name);
652        }
653    }
654
655    interface Local<T> {
656
657        void set(T value);
658
659        T get();
660
661        void remove();
662
663        class StaticImpl<T> implements Local<T> {
664
665            private final AtomicReference<T> state = new AtomicReference<>();
666
667            @Override
668            public void set(final T value) {
669                state.set(value);
670            }
671
672            @Override
673            public T get() {
674                return state.get();
675            }
676
677            @Override
678            public void remove() {
679                state.set(null);
680            }
681        }
682
683        class ThreadLocalImpl<T> implements Local<T> {
684
685            private final ThreadLocal<T> threadLocal = new ThreadLocal<>();
686
687            @Override
688            public void set(final T value) {
689                threadLocal.set(value);
690            }
691
692            @Override
693            public T get() {
694                return threadLocal.get();
695            }
696
697            @Override
698            public void remove() {
699                threadLocal.remove();
700            }
701        }
702    }
703}