001/**
002 * Copyright (C) 2006-2025 Talend Inc. - www.talend.com
003 *
004 * Licensed under the Apache License, Version 2.0 (the "License");
005 * you may not use this file except in compliance with the License.
006 * You may obtain a copy of the License at
007 *
008 * http://www.apache.org/licenses/LICENSE-2.0
009 *
010 * Unless required by applicable law or agreed to in writing, software
011 * distributed under the License is distributed on an "AS IS" BASIS,
012 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
013 * See the License for the specific language governing permissions and
014 * limitations under the License.
015 */
016package org.talend.sdk.component.junit;
017
018import static java.lang.Math.abs;
019import static java.util.Collections.emptyIterator;
020import static java.util.Collections.emptyMap;
021import static java.util.Locale.ROOT;
022import static java.util.concurrent.TimeUnit.MINUTES;
023import static java.util.concurrent.TimeUnit.SECONDS;
024import static java.util.stream.Collectors.joining;
025import static java.util.stream.Collectors.toList;
026import static org.apache.ziplock.JarLocation.jarLocation;
027import static org.junit.Assert.fail;
028import static org.talend.sdk.component.junit.SimpleFactory.configurationByExample;
029
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Objects;
import java.util.Optional;
import java.util.Queue;
import java.util.Set;
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.ConcurrentLinkedQueue;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.Semaphore;
import java.util.concurrent.TimeoutException;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicReference;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
056
057import javax.json.JsonBuilderFactory;
058import javax.json.JsonObject;
059import javax.json.bind.Jsonb;
060import javax.json.bind.JsonbConfig;
061import javax.json.spi.JsonProvider;
062
063import org.apache.xbean.finder.filter.Filter;
064import org.talend.sdk.component.api.record.Record;
065import org.talend.sdk.component.api.service.injector.Injector;
066import org.talend.sdk.component.api.service.record.RecordBuilderFactory;
067import org.talend.sdk.component.junit.lang.StreamDecorator;
068import org.talend.sdk.component.runtime.base.Lifecycle;
069import org.talend.sdk.component.runtime.input.Input;
070import org.talend.sdk.component.runtime.input.Mapper;
071import org.talend.sdk.component.runtime.manager.ComponentFamilyMeta;
072import org.talend.sdk.component.runtime.manager.ComponentManager;
073import org.talend.sdk.component.runtime.manager.ContainerComponentRegistry;
074import org.talend.sdk.component.runtime.manager.chain.AutoChunkProcessor;
075import org.talend.sdk.component.runtime.manager.chain.Job;
076import org.talend.sdk.component.runtime.manager.json.PreComputedJsonpProvider;
077import org.talend.sdk.component.runtime.output.OutputFactory;
078import org.talend.sdk.component.runtime.output.Processor;
079import org.talend.sdk.component.runtime.record.RecordConverters;
080
081import lombok.AllArgsConstructor;
082import lombok.extern.slf4j.Slf4j;
083
084@Slf4j
085public class BaseComponentsHandler implements ComponentsHandler {
086
087    protected static final Local<State> STATE = loadStateHolder();
088
089    private static Local<State> loadStateHolder() {
090        switch (System.getProperty("talend.component.junit.handler.state", "thread").toLowerCase(ROOT)) {
091            case "static":
092                return new Local.StaticImpl<>();
093            default:
094                return new Local.ThreadLocalImpl<>();
095        }
096    }
097
098    private final ThreadLocal<PreState> initState = ThreadLocal.withInitial(PreState::new);
099
100    protected String packageName;
101
102    protected Collection<String> isolatedPackages;
103
104    public <T> T injectServices(final T instance) {
105        if (instance == null) {
106            return null;
107        }
108        final String plugin = getSinglePlugin();
109        final Map<Class<?>, Object> services = asManager()
110                .findPlugin(plugin)
111                .orElseThrow(() -> new IllegalArgumentException("cant find plugin '" + plugin + "'"))
112                .get(ComponentManager.AllServices.class)
113                .getServices();
114        Injector.class.cast(services.get(Injector.class)).inject(instance);
115        return instance;
116    }
117
118    public BaseComponentsHandler withIsolatedPackage(final String packageName, final String... packages) {
119        isolatedPackages =
120                Stream.concat(Stream.of(packageName), Stream.of(packages)).filter(Objects::nonNull).collect(toList());
121        if (isolatedPackages.isEmpty()) {
122            isolatedPackages = null;
123        }
124        return this;
125    }
126
127    public EmbeddedComponentManager start() {
128        final EmbeddedComponentManager embeddedComponentManager = new EmbeddedComponentManager(packageName) {
129
130            @Override
131            protected boolean isContainerClass(final Filter filter, final String name) {
132                if (name == null) {
133                    return super.isContainerClass(filter, null);
134                }
135                return (isolatedPackages == null || isolatedPackages.stream().noneMatch(name::startsWith))
136                        && super.isContainerClass(filter, name);
137            }
138
139            @Override
140            public void close() {
141                try {
142                    final State state = STATE.get();
143                    if (state.jsonb != null) {
144                        try {
145                            state.jsonb.close();
146                        } catch (final Exception e) {
147                            // no-op: not important
148                        }
149                    }
150                    STATE.remove();
151                    initState.remove();
152                } finally {
153                    super.close();
154                }
155            }
156        };
157
158        STATE
159                .set(new State(embeddedComponentManager, new CopyOnWriteArrayList<>(), initState.get().emitter, null,
160                        null, null, null));
161        return embeddedComponentManager;
162    }
163
164    @Override
165    public Outputs collect(final Processor processor, final ControllableInputFactory inputs) {
166        return collect(processor, inputs, 10);
167    }
168
169    /**
170     * Collects all outputs of a processor.
171     *
172     * @param processor the processor to run while there are inputs.
173     * @param inputs the input factory, when an input will return null it will stop the
174     * processing.
175     * @param bundleSize the bundle size to use.
176     * @return a map where the key is the output name and the value a stream of the
177     * output values.
178     */
179    @Override
180    public Outputs collect(final Processor processor, final ControllableInputFactory inputs, final int bundleSize) {
181        final AutoChunkProcessor autoChunkProcessor = new AutoChunkProcessor(bundleSize, processor);
182        autoChunkProcessor.start();
183        final Outputs outputs = new Outputs();
184        final OutputFactory outputFactory = name -> value -> {
185            final List aggregator = outputs.data.computeIfAbsent(name, n -> new ArrayList<>());
186            aggregator.add(value);
187        };
188        try {
189            while (inputs.hasMoreData()) {
190                autoChunkProcessor.onElement(inputs, outputFactory);
191            }
192            autoChunkProcessor.flush(outputFactory);
193        } finally {
194            autoChunkProcessor.stop();
195        }
196        return outputs;
197    }
198
199    @Override
200    public <T> Stream<T> collect(final Class<T> recordType, final Mapper mapper, final int maxRecords) {
201        return collect(recordType, mapper, maxRecords, Runtime.getRuntime().availableProcessors());
202    }
203
204    /**
205     * Collects data emitted from this mapper. If the split creates more than one
206     * mapper, it will create as much threads as mappers otherwise it will use the
207     * caller thread.
208     *
209     * IMPORTANT: don't forget to consume all the stream to ensure the underlying
210     * { @see org.talend.sdk.component.runtime.input.Input} is closed.
211     *
212     * @param recordType the record type to use to type the returned type.
213     * @param mapper the mapper to go through.
214     * @param maxRecords maximum number of records, allows to stop the source when
215     * infinite.
216     * @param concurrency requested (1 can be used instead if &lt;= 0) concurrency for the reader execution.
217     * @param <T> the returned type of the records of the mapper.
218     * @return all the records emitted by the mapper.
219     */
220    @Override
221    public <T> Stream<T> collect(final Class<T> recordType, final Mapper mapper, final int maxRecords,
222            final int concurrency) {
223        mapper.start();
224
225        final State state = STATE.get();
226        final long assess = mapper.assess();
227        final int proc = Math.max(1, concurrency);
228        final List<Mapper> mappers = mapper.split(Math.max(assess / proc, 1));
229        switch (mappers.size()) {
230            case 0:
231                return Stream.empty();
232            case 1:
233                return StreamDecorator.decorate(
234                        asStream(asIterator(mappers.iterator().next().create(), new AtomicInteger(maxRecords))),
235                        collect -> {
236                            try {
237                                collect.run();
238                            } finally {
239                                mapper.stop();
240                            }
241                        });
242            default: // N producers-1 consumer pattern
243                final AtomicInteger threadCounter = new AtomicInteger(0);
244                final ExecutorService es = Executors.newFixedThreadPool(mappers.size(), r -> new Thread(r) {
245
246                    {
247                        setName(BaseComponentsHandler.this.getClass().getSimpleName() + "-pool-"
248                                + abs(mapper.hashCode())
249                                + "-" + threadCounter.incrementAndGet());
250                    }
251                });
252                final AtomicInteger recordCounter = new AtomicInteger(maxRecords);
253                final Semaphore permissions = new Semaphore(0);
254                final Queue<T> records = new ConcurrentLinkedQueue<>();
255                final CountDownLatch latch = new CountDownLatch(mappers.size());
256                final List<? extends Future<?>> tasks = mappers
257                        .stream()
258                        .map(Mapper::create)
259                        .map(input -> (Iterator<T>) asIterator(input, recordCounter))
260                        .map(it -> es.submit(() -> {
261                            try {
262                                while (it.hasNext()) {
263                                    final T next = it.next();
264                                    records.add(next);
265                                    permissions.release();
266                                }
267                            } finally {
268                                latch.countDown();
269                            }
270                        }))
271                        .collect(toList());
272                es.shutdown();
273
274                final int timeout = Integer.getInteger("talend.component.junit.timeout", 5);
275                new Thread() {
276
277                    {
278                        setName(BaseComponentsHandler.class.getSimpleName() + "-monitor_" + abs(mapper.hashCode()));
279                    }
280
281                    @Override
282                    public void run() {
283                        try {
284                            latch.await(timeout, MINUTES);
285                        } catch (final InterruptedException e) {
286                            Thread.currentThread().interrupt();
287                        } finally {
288                            permissions.release();
289                        }
290                    }
291                }.start();
292                return StreamDecorator.decorate(asStream(new Iterator<T>() {
293
294                    @Override
295                    public boolean hasNext() {
296                        try {
297                            permissions.acquire();
298                        } catch (final InterruptedException e) {
299                            Thread.currentThread().interrupt();
300                            fail(e.getMessage());
301                        }
302                        return !records.isEmpty();
303                    }
304
305                    @Override
306                    public T next() {
307                        final T poll = records.poll();
308                        if (poll != null) {
309                            return mapRecord(state, recordType, poll);
310                        }
311                        return null;
312                    }
313                }), task -> {
314                    try {
315                        task.run();
316                    } finally {
317                        tasks.forEach(f -> {
318                            try {
319                                f.get(5, SECONDS);
320                            } catch (final InterruptedException e) {
321                                Thread.currentThread().interrupt();
322                            } catch (final ExecutionException | TimeoutException e) {
323                                // no-op
324                            } finally {
325                                if (!f.isDone() && !f.isCancelled()) {
326                                    f.cancel(true);
327                                }
328                            }
329                        });
330                    }
331                });
332        }
333    }
334
335    private <T> Stream<T> asStream(final Iterator<T> iterator) {
336        return StreamSupport.stream(Spliterators.spliteratorUnknownSize(iterator, Spliterator.IMMUTABLE), false);
337    }
338
339    private <T> Iterator<T> asIterator(final Input input, final AtomicInteger counter) {
340        input.start();
341        return new Iterator<T>() {
342
343            private boolean closed;
344
345            private Object next;
346
347            @Override
348            public boolean hasNext() {
349                final int remaining = counter.get();
350                if (remaining <= 0) {
351                    return false;
352                }
353
354                final boolean hasNext = (next = input.next()) != null;
355                if (!hasNext && !closed) {
356                    closed = true;
357                    input.stop();
358                }
359                if (hasNext) {
360                    counter.decrementAndGet();
361                }
362                return hasNext;
363            }
364
365            @Override
366            public T next() {
367                return (T) next;
368            }
369        };
370    }
371
372    @Override
373    public <T> List<T> collectAsList(final Class<T> recordType, final Mapper mapper) {
374        return collectAsList(recordType, mapper, 1000);
375    }
376
377    @Override
378    public <T> List<T> collectAsList(final Class<T> recordType, final Mapper mapper, final int maxRecords) {
379        return collect(recordType, mapper, maxRecords).collect(toList());
380    }
381
382    @Override
383    public Mapper createMapper(final Class<?> componentType, final Object configuration) {
384        return create(Mapper.class, componentType, configuration);
385    }
386
387    @Override
388    public Processor createProcessor(final Class<?> componentType, final Object configuration) {
389        return create(Processor.class, componentType, configuration);
390    }
391
392    private <C, T, A> A create(final Class<A> api, final Class<T> componentType, final C configuration) {
393        final ComponentFamilyMeta.BaseMeta<? extends Lifecycle> meta = findMeta(componentType);
394        return api
395                .cast(meta
396                        .getInstantiator()
397                        .apply(configuration == null || meta.getParameterMetas().get().isEmpty() ? emptyMap()
398                                : configurationByExample(configuration, meta
399                                        .getParameterMetas()
400                                        .get()
401                                        .stream()
402                                        .filter(p -> p.getName().equals(p.getPath()))
403                                        .findFirst()
404                                        .map(p -> p.getName() + '.')
405                                        .orElseThrow(() -> new IllegalArgumentException(
406                                                "Didn't find any option and therefore "
407                                                        + "can't convert the configuration instance to a configuration")))));
408    }
409
410    private <T> ComponentFamilyMeta.BaseMeta<? extends Lifecycle> findMeta(final Class<T> componentType) {
411        return asManager()
412                .find(c -> c.get(ContainerComponentRegistry.class).getComponents().values().stream())
413                .flatMap(f -> Stream
414                        .of(f.getProcessors().values().stream(), f.getPartitionMappers().values().stream(),
415                                f.getDriverRunners().values().stream())
416                        .flatMap(t -> t))
417                .filter(m -> m.getType().getName().equals(componentType.getName()))
418                .findFirst()
419                .orElseThrow(() -> new IllegalArgumentException("No component " + componentType));
420    }
421
422    @Override
423    public <T> List<T> collect(final Class<T> recordType, final String family, final String component,
424            final int version, final Map<String, String> configuration) {
425        Job
426                .components()
427                .component("in",
428                        family + "://" + component + "?__version=" + version
429                                + configuration
430                                        .entrySet()
431                                        .stream()
432                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
433                                        .collect(joining("&", "&", "")))
434                .component("collector", "test://collector")
435                .connections()
436                .from("in")
437                .to("collector")
438                .build()
439                .run();
440
441        return getCollectedData(recordType);
442    }
443
444    @Override
445    public <T> void process(final Iterable<T> inputs, final String family, final String component, final int version,
446            final Map<String, String> configuration) {
447        setInputData(inputs);
448
449        Job
450                .components()
451                .component("emitter", "test://emitter")
452                .component("out",
453                        family + "://" + component + "?__version=" + version
454                                + configuration
455                                        .entrySet()
456                                        .stream()
457                                        .map(entry -> entry.getKey() + "=" + entry.getValue())
458                                        .collect(joining("&", "&", "")))
459                .connections()
460                .from("emitter")
461                .to("out")
462                .build()
463                .run();
464
465    }
466
467    @Override
468    public ComponentManager asManager() {
469        return STATE.get().manager;
470    }
471
472    @Override
473    public <T> T findService(final String plugin, final Class<T> serviceClass) {
474        return serviceClass
475                .cast(asManager()
476                        .findPlugin(plugin)
477                        .orElseThrow(() -> new IllegalArgumentException("cant find plugin '" + plugin + "'"))
478                        .get(ComponentManager.AllServices.class)
479                        .getServices()
480                        .get(serviceClass));
481    }
482
483    @Override
484    public <T> T findService(final Class<T> serviceClass) {
485        return findService(getSinglePlugin(), serviceClass);
486    }
487
488    public Set<String> getTestPlugins() {
489        return new HashSet<>(EmbeddedComponentManager.class.cast(asManager()).testPlugins);
490    }
491
492    @Override
493    public <T> void setInputData(final Iterable<T> data) {
494        final State state = STATE.get();
495        if (state == null) {
496            initState.get().emitter = data.iterator();
497        } else {
498            state.emitter = data.iterator();
499        }
500    }
501
502    @Override
503    public <T> List<T> getCollectedData(final Class<T> recordType) {
504        final State state = STATE.get();
505        return state.collector
506                .stream()
507                .filter(r -> recordType.isInstance(r) || JsonObject.class.isInstance(r) || Record.class.isInstance(r))
508                .map(r -> mapRecord(state, recordType, r))
509                .collect(toList());
510    }
511
512    public void resetState() {
513        final State state = STATE.get();
514        if (state == null) {
515            STATE.remove();
516        } else {
517            state.collector.clear();
518            state.emitter = emptyIterator();
519        }
520    }
521
522    private String getSinglePlugin() {
523        return Optional
524                .of(EmbeddedComponentManager.class.cast(asManager()).testPlugins/* sorted */)
525                .filter(c -> !c.isEmpty())
526                .map(c -> c.iterator().next())
527                .orElseThrow(() -> new IllegalStateException("No component plugin found"));
528    }
529
530    private <T> T mapRecord(final State state, final Class<T> recordType, final Object r) {
531        if (recordType.isInstance(r)) {
532            return recordType.cast(r);
533        }
534        if (Record.class == recordType) {
535            return recordType
536                    .cast(new RecordConverters()
537                            .toRecord(state.registry, r, state::jsonb, state::recordBuilderFactory));
538        }
539        return recordType
540                .cast(new RecordConverters()
541                        .toType(state.registry, r, recordType, state::jsonBuilderFactory, state::jsonProvider,
542                                state::jsonb, state::recordBuilderFactory));
543    }
544
545    static class PreState {
546
547        Iterator<?> emitter;
548    }
549
550    @AllArgsConstructor
551    protected static class State {
552
553        final ComponentManager manager;
554
555        final Collection<Object> collector;
556
557        final RecordConverters.MappingMetaRegistry registry = new RecordConverters.MappingMetaRegistry();
558
559        Iterator<?> emitter;
560
561        volatile Jsonb jsonb;
562
563        volatile JsonProvider jsonProvider;
564
565        volatile JsonBuilderFactory jsonBuilderFactory;
566
567        volatile RecordBuilderFactory recordBuilderFactory;
568
569        synchronized Jsonb jsonb() {
570            if (jsonb == null) {
571                jsonb = manager
572                        .getJsonbProvider()
573                        .create()
574                        .withProvider(new PreComputedJsonpProvider("test", manager.getJsonpProvider(),
575                                manager.getJsonpParserFactory(), manager.getJsonpWriterFactory(),
576                                manager.getJsonpBuilderFactory(), manager.getJsonpGeneratorFactory(),
577                                manager.getJsonpReaderFactory())) // reuses the same memory buffers
578                        .withConfig(new JsonbConfig().setProperty("johnzon.cdi.activated", false))
579                        .build();
580            }
581            return jsonb;
582        }
583
584        synchronized JsonProvider jsonProvider() {
585            if (jsonProvider == null) {
586                jsonProvider = manager.getJsonpProvider();
587            }
588            return jsonProvider;
589        }
590
591        synchronized JsonBuilderFactory jsonBuilderFactory() {
592            if (jsonBuilderFactory == null) {
593                jsonBuilderFactory = manager.getJsonpBuilderFactory();
594            }
595            return jsonBuilderFactory;
596        }
597
598        synchronized RecordBuilderFactory recordBuilderFactory() {
599            if (recordBuilderFactory == null) {
600                recordBuilderFactory = manager.getRecordBuilderFactoryProvider().apply("test");
601            }
602            return recordBuilderFactory;
603        }
604    }
605
606    public static class EmbeddedComponentManager extends ComponentManager {
607
608        private final ComponentManager oldInstance;
609
610        private final List<String> testPlugins;
611
612        private EmbeddedComponentManager(final String componentPackage) {
613            super(findM2(), "TALEND-INF/dependencies.txt", "org.talend.sdk.component:type=component,value=%s");
614            testPlugins = addJarContaining(Thread.currentThread().getContextClassLoader(),
615                    componentPackage.replace('.', '/'));
616            container
617                    .builder("component-runtime-junit.jar", jarLocation(SimpleCollector.class).getAbsolutePath())
618                    .create();
619            oldInstance = ComponentManager.contextualInstance().get();
620            ComponentManager.contextualInstance().set(this);
621        }
622
623        @Override
624        public void close() {
625            try {
626                super.close();
627            } finally {
628                ComponentManager.contextualInstance().compareAndSet(this, oldInstance);
629            }
630        }
631
632        @Override
633        protected boolean isContainerClass(final Filter filter, final String name) {
634            // embedded mode (no plugin structure) so just run with all classes in parent classloader
635            return true;
636        }
637    }
638
639    public static class Outputs {
640
641        private final Map<String, List<?>> data = new HashMap<>();
642
643        public int size() {
644            return data.size();
645        }
646
647        public Set<String> keys() {
648            return data.keySet();
649        }
650
651        public <T> List<T> get(final Class<T> type, final String name) {
652            return (List<T>) data.get(name);
653        }
654    }
655
656    interface Local<T> {
657
658        void set(T value);
659
660        T get();
661
662        void remove();
663
664        class StaticImpl<T> implements Local<T> {
665
666            private final AtomicReference<T> state = new AtomicReference<>();
667
668            @Override
669            public void set(final T value) {
670                state.set(value);
671            }
672
673            @Override
674            public T get() {
675                return state.get();
676            }
677
678            @Override
679            public void remove() {
680                state.set(null);
681            }
682        }
683
684        class ThreadLocalImpl<T> implements Local<T> {
685
686            private final ThreadLocal<T> threadLocal = new ThreadLocal<>();
687
688            @Override
689            public void set(final T value) {
690                threadLocal.set(value);
691            }
692
693            @Override
694            public T get() {
695                return threadLocal.get();
696            }
697
698            @Override
699            public void remove() {
700                threadLocal.remove();
701            }
702        }
703    }
704}