/*
 * Decompiled with CFR 0.152.
 */
package org.apache.nifi.stateless.bootstrap;

import java.lang.reflect.Method;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * A {@link ClassLoader} that exposes only a restricted subset of the classes its parent can load.
 * <p>
 * A class is visible through this loader when either:
 * <ul>
 *     <li>its fully-qualified name is in the allow-list supplied to the constructor, or</li>
 *     <li>it belongs to a named module whose name starts with {@code java.} or {@code jdk.}.</li>
 * </ul>
 * Any other class is rejected with a {@link ClassNotFoundException}, even though the parent could load it.
 * Module membership is determined reflectively (via {@code Class#getModule()} / {@code Module#getName()}).
 * If those methods are unavailable, only explicitly allowed class names are visible.
 * </p>
 */
public class AllowListClassLoader
extends ClassLoader {
    private static final Logger logger = LoggerFactory.getLogger(AllowListClassLoader.class);

    /** Module-name prefixes whose classes are implicitly allowed. */
    private static final List<String> ALLOWED_MODULE_PREFIXES = Collections.unmodifiableList(Arrays.asList("java.", "jdk."));

    /** {@code Class#getModule()}, or {@code null} if it cannot be resolved on this runtime. */
    private static final Method GET_MODULE_METHOD = findPublicMethod(Class.class, "getModule");

    /** {@code Module#getName()}, or {@code null} if it cannot be resolved on this runtime. */
    private static final Method GET_MODULE_NAME_METHOD = GET_MODULE_METHOD == null
        ? null
        : findPublicMethod(GET_MODULE_METHOD.getReturnType(), "getName");

    private final Set<String> allowedClassNames;

    /**
     * Creates a loader that delegates to {@code parent} but filters what it returns.
     *
     * @param parent  the class loader to delegate to
     * @param allowed fully-qualified names of classes that may always be loaded. The set is referenced, not
     *                copied, so later changes to it are visible to this loader.
     */
    public AllowListClassLoader(ClassLoader parent, Set<String> allowed) {
        super(parent);
        this.allowedClassNames = allowed;
    }

    /**
     * Returns the explicitly allowed class names.
     *
     * @return a read-only view of the allow-list passed to the constructor
     */
    public Set<String> getClassesAllowed() {
        return Collections.unmodifiableSet(this.allowedClassNames);
    }

    /**
     * Loads a class through the parent and returns it only if it is allowed.
     *
     * @param name    fully-qualified binary name of the class
     * @param resolve whether to link the class once it is accepted
     * @return the loaded class
     * @throws ClassNotFoundException if the parent cannot find the class, or the class is not allowed
     */
    @Override
    protected Class<?> loadClass(String name, boolean resolve) throws ClassNotFoundException {
        if (this.allowedClassNames.contains(name)) {
            return super.loadClass(name, resolve);
        }
        try {
            // Load without resolving first: resolution is pointless if the class is going to be rejected.
            Class<?> found = super.loadClass(name, false);
            if (this.isClassAllowed(name, found)) {
                if (resolve) {
                    this.resolveClass(found);
                }
                return found;
            }
        } catch (NoClassDefFoundError e) {
            // Report it like a blocked class, as before, but keep the error as the cause so the
            // underlying linkage problem is not lost.
            throw new ClassNotFoundException(name + " was blocked by AllowListClassLoader", e);
        }
        throw new ClassNotFoundException(name + " was blocked by AllowListClassLoader");
    }

    /**
     * Finds a class via the superclass and applies the same allow-list check as {@link #loadClass(String, boolean)}.
     * <p>
     * The direct superclass is {@link ClassLoader}, whose {@code findClass} always throws
     * {@link ClassNotFoundException}. The check therefore only matters if that behavior is changed.
     * </p>
     *
     * @param name fully-qualified binary name of the class
     * @return the found class
     * @throws ClassNotFoundException if the class cannot be found or is not allowed
     */
    @Override
    protected Class<?> findClass(String name) throws ClassNotFoundException {
        Class<?> found = super.findClass(name);
        if (this.isClassAllowed(name, found)) {
            return found;
        }
        throw new ClassNotFoundException(name + " was blocked by AllowListClassLoader");
    }

    /**
     * Decides whether {@code clazz} may be exposed.
     *
     * @param name  the requested class name
     * @param clazz the class the parent produced for {@code name}
     * @return {@code true} if the name is explicitly allowed or the class lives in an allowed named module
     */
    private boolean isClassAllowed(String name, Class<?> clazz) {
        if (this.allowedClassNames.contains(name)) {
            return true;
        }
        if (GET_MODULE_METHOD == null || GET_MODULE_NAME_METHOD == null) {
            // Module information is unavailable, so only the explicit allow-list applies.
            return false;
        }
        try {
            Object module = GET_MODULE_METHOD.invoke(clazz);
            if (module == null) {
                return false;
            }
            String moduleName = (String) GET_MODULE_NAME_METHOD.invoke(module);
            if (isModuleAllowed(moduleName)) {
                logger.debug("Allowing Class {} because its module is {}", name, moduleName);
                return true;
            }
            return false;
        } catch (ReflectiveOperationException | RuntimeException e) {
            logger.debug("Failed to determine if class {} is part of the implicitly allowed modules", name, e);
            return false;
        }
    }

    /**
     * Checks a module name against the implicitly allowed prefixes.
     *
     * @param moduleName the module name. {@code Module#getName()} returns {@code null} for unnamed modules.
     * @return {@code true} if the name is non-null and starts with an allowed prefix
     */
    private static boolean isModuleAllowed(String moduleName) {
        if (moduleName == null) {
            // Unnamed module (e.g. classpath classes): never implicitly allowed.
            return false;
        }
        for (String prefix : ALLOWED_MODULE_PREFIXES) {
            if (moduleName.startsWith(prefix)) {
                return true;
            }
        }
        return false;
    }

    /**
     * Looks up a public no-argument method reflectively.
     *
     * @param type       the class declaring the method
     * @param methodName the method name
     * @return the method, or {@code null} if it does not exist or cannot be accessed
     */
    private static Method findPublicMethod(Class<?> type, String methodName) {
        try {
            return type.getMethod(methodName);
        } catch (NoSuchMethodException | SecurityException e) {
            logger.debug("Method {}#{} is unavailable; only explicitly allowed classes will be loadable", type.getName(), methodName, e);
            return null;
        }
    }
}

