/*
 * Decompiled with CFR 0.152.
 */
package org.deeplearning4j.nn.modelimport.keras;

import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import org.deeplearning4j.nn.conf.InputPreProcessor;
import org.deeplearning4j.nn.conf.graph.GraphVertex;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.layers.FeedForwardLayer;
import org.deeplearning4j.nn.conf.layers.Layer;
import org.deeplearning4j.nn.conf.layers.samediff.SameDiffLambdaLayer;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfiguration;
import org.deeplearning4j.nn.modelimport.keras.config.KerasLayerConfigurationFactory;
import org.deeplearning4j.nn.modelimport.keras.exceptions.InvalidKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.exceptions.UnsupportedKerasConfigurationException;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasLayerUtils;
import org.deeplearning4j.nn.modelimport.keras.utils.KerasRegularizerUtils;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.util.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Base class for the DL4J-side representation of a single Keras layer.
 *
 * <p>An instance carries the metadata read from a Keras layer configuration (class name,
 * layer name, input shape, dimension ordering, inbound layer names, regularization and
 * dropout values) together with the DL4J {@link Layer} or {@link GraphVertex} that
 * subclasses build from it. Methods such as {@link #getNumParams()},
 * {@link #setWeights(Map)} and {@link #getOutputType(InputType...)} are placeholders
 * that subclasses are expected to override.
 */
public class KerasLayer {
    private static final Logger log = LoggerFactory.getLogger(KerasLayer.class);
    /** Key under which the Keras major version is looked up in a layer configuration map. */
    private static final String LAYER_FIELD_KERAS_VERSION = "keras_version";
    /** Registry of user-supplied layer implementations, keyed by the name they were registered under. */
    static final Map<String, Class<? extends KerasLayer>> customLayers = new HashMap<String, Class<? extends KerasLayer>>();
    /** Registry of user-supplied lambda layer implementations, keyed by the name they were registered under. */
    static final Map<String, SameDiffLambdaLayer> lambdaLayers = new HashMap<String, SameDiffLambdaLayer>();
    protected String className;
    protected String layerName;
    protected int[] inputShape;
    protected DimOrder dimOrder;
    protected List<String> inboundLayerNames;
    /** DL4J layer built by a subclass; {@code null} when this Keras layer does not map to a layer. */
    protected Layer layer;
    /** DL4J graph vertex built by a subclass; {@code null} when this Keras layer does not map to a vertex. */
    protected GraphVertex vertex;
    protected Map<String, INDArray> weights;
    protected double weightL1Regularization = 0.0;
    protected double weightL2Regularization = 0.0;
    /** Dropout value read from the configuration; values below 1.0 count as "regularization in use". */
    protected double dropout = 1.0;
    protected Integer kerasMajorVersion = 2;
    protected KerasLayerConfiguration conf;

    /**
     * Creates an empty layer for the given Keras major version.
     *
     * @param kerasVersion Keras major version used to select the {@link KerasLayerConfiguration}
     * @throws UnsupportedKerasConfigurationException if the configuration factory rejects the version
     */
    protected KerasLayer(Integer kerasVersion) throws UnsupportedKerasConfigurationException {
        // className, layerName, inputShape, layer, vertex and weights keep their default null value.
        this.dimOrder = DimOrder.NONE;
        this.inboundLayerNames = new ArrayList<String>();
        this.kerasMajorVersion = kerasVersion;
        this.conf = KerasLayerConfigurationFactory.get(this.kerasMajorVersion);
    }

    /**
     * Creates an empty layer using the default Keras major version held by
     * {@link #kerasMajorVersion}.
     *
     * @throws UnsupportedKerasConfigurationException if the configuration factory rejects the version
     */
    protected KerasLayer() throws UnsupportedKerasConfigurationException {
        // className, layerName, inputShape, layer, vertex and weights keep their default null value.
        this.dimOrder = DimOrder.NONE;
        this.inboundLayerNames = new ArrayList<String>();
        this.conf = KerasLayerConfigurationFactory.get(this.kerasMajorVersion);
    }

    /**
     * Creates a layer from a Keras layer configuration, enforcing training-related configuration.
     *
     * @param layerConfig Keras layer configuration map
     * @throws InvalidKerasConfigurationException     if the class name or layer name is missing
     * @throws UnsupportedKerasConfigurationException if the configuration is not supported
     */
    protected KerasLayer(Map<String, Object> layerConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this(layerConfig, true);
    }

    /**
     * Creates a layer from a Keras layer configuration.
     *
     * @param layerConfig           Keras layer configuration map; its {@code "keras_version"} entry
     *                              is read as an {@link Integer}
     * @param enforceTrainingConfig forwarded to
     *                              {@code KerasLayerUtils.checkForUnsupportedConfigurations}
     * @throws InvalidKerasConfigurationException     if the class name or layer name is missing
     * @throws UnsupportedKerasConfigurationException if the configuration is not supported
     */
    protected KerasLayer(Map<String, Object> layerConfig, boolean enforceTrainingConfig) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        this.kerasMajorVersion = (Integer)layerConfig.get(LAYER_FIELD_KERAS_VERSION);
        this.conf = KerasLayerConfigurationFactory.get(this.kerasMajorVersion);
        this.className = KerasLayerUtils.getClassNameFromConfig(layerConfig, this.conf);
        if (this.className == null) {
            throw new InvalidKerasConfigurationException("Keras layer class name is missing");
        }
        this.layerName = KerasLayerUtils.getLayerNameFromConfig(layerConfig, this.conf);
        if (this.layerName == null) {
            // Previously reported "class name is missing", which pointed at the wrong field.
            throw new InvalidKerasConfigurationException("Keras layer name is missing");
        }
        this.inputShape = KerasLayerUtils.getInputShapeFromConfig(layerConfig, this.conf);
        this.dimOrder = KerasLayerUtils.getDimOrderFromConfig(layerConfig, this.conf);
        this.inboundLayerNames = KerasLayerUtils.getInboundLayerNamesFromConfig(layerConfig, this.conf);
        this.layer = null;
        this.vertex = null;
        this.weights = null;
        this.weightL1Regularization = KerasRegularizerUtils.getWeightRegularizerFromConfig(layerConfig, this.conf, this.conf.getLAYER_FIELD_W_REGULARIZER(), this.conf.getREGULARIZATION_TYPE_L1());
        this.weightL2Regularization = KerasRegularizerUtils.getWeightRegularizerFromConfig(layerConfig, this.conf, this.conf.getLAYER_FIELD_W_REGULARIZER(), this.conf.getREGULARIZATION_TYPE_L2());
        this.dropout = KerasLayerUtils.getDropoutFromConfig(layerConfig, this.conf);
        KerasLayerUtils.checkForUnsupportedConfigurations(layerConfig, enforceTrainingConfig, this.conf);
    }

    /**
     * Registers a lambda layer implementation under the given name, replacing any previous entry.
     *
     * @param lambdaLayerName     name to register under
     * @param sameDiffLambdaLayer implementation to register
     */
    public static void registerLambdaLayer(String lambdaLayerName, SameDiffLambdaLayer sameDiffLambdaLayer) {
        lambdaLayers.put(lambdaLayerName, sameDiffLambdaLayer);
    }

    /** Removes every registered lambda layer. */
    public static void clearLambdaLayers() {
        lambdaLayers.clear();
    }

    /**
     * Registers a custom layer class under the given name, replacing any previous entry.
     *
     * @param layerName   name to register under
     * @param configClass {@link KerasLayer} subclass to register
     */
    public static void registerCustomLayer(String layerName, Class<? extends KerasLayer> configClass) {
        customLayers.put(layerName, configClass);
    }

    /** Removes every registered custom layer. */
    public static void clearCustomLayers() {
        customLayers.clear();
    }

    /** @return the Keras major version this layer was created with */
    public Integer getKerasMajorVersion() {
        return this.kerasMajorVersion;
    }

    /** @return the Keras class name, or {@code null} if the layer was not built from a configuration */
    public String getClassName() {
        return this.className;
    }

    /** @return the Keras layer name, or {@code null} if the layer was not built from a configuration */
    public String getLayerName() {
        return this.layerName;
    }

    /** @return a defensive copy of the input shape, or {@code null} if none is set */
    public int[] getInputShape() {
        if (this.inputShape == null) {
            return null;
        }
        return (int[])this.inputShape.clone();
    }

    /** @return the dimension ordering of this layer */
    protected DimOrder getDimOrder() {
        return this.dimOrder;
    }

    /** @param dimOrder the dimension ordering to use for this layer */
    void setDimOrder(DimOrder dimOrder) {
        this.dimOrder = dimOrder;
    }

    /**
     * Returns the names of the layers feeding into this one, lazily creating an empty list.
     *
     * @return the internal (mutable, never {@code null}) list of inbound layer names
     */
    public List<String> getInboundLayerNames() {
        if (this.inboundLayerNames == null) {
            this.inboundLayerNames = new ArrayList<String>();
        }
        return this.inboundLayerNames;
    }

    /** @param inboundLayerNames names to copy into this layer's inbound list */
    public void setInboundLayerNames(List<String> inboundLayerNames) {
        this.inboundLayerNames = new ArrayList<String>(inboundLayerNames);
    }

    /** @return the number of trainable parameter arrays; {@code 0} in this base class */
    public int getNumParams() {
        return 0;
    }

    /** @return {@code true} if L1 or L2 weight regularization is positive or dropout is below 1.0 */
    public boolean usesRegularization() {
        return this.weightL1Regularization > 0.0 || this.weightL2Regularization > 0.0 || this.dropout < 1.0;
    }

    /**
     * Sets the weights of this layer. The base implementation does nothing.
     *
     * @param weights map from parameter name to array
     * @throws InvalidKerasConfigurationException declared for overriding implementations
     */
    public void setWeights(Map<String, INDArray> weights) throws InvalidKerasConfigurationException {
    }

    /** @return the stored weights, or {@code null} if none were set */
    public Map<String, INDArray> getWeights() {
        return this.weights;
    }

    /**
     * Copies the stored weights into the given DL4J layer.
     *
     * <p>Nothing happens when {@link #getNumParams()} is not positive. Otherwise the parameter
     * names of the DL4J layer and the stored weights must match exactly, and each parameter
     * is set individually.
     *
     * @param layer DL4J layer that receives the weights
     * @throws InvalidKerasConfigurationException if weights are missing, the parameter names do
     *                                            not match, or setting a parameter fails (the
     *                                            underlying failure is attached as the cause)
     */
    public void copyWeightsToLayer(org.deeplearning4j.nn.api.Layer layer) throws InvalidKerasConfigurationException {
        if (this.getNumParams() > 0) {
            String dl4jLayerName = layer.conf().getLayer().getLayerName();
            String kerasLayerName = this.getLayerName();
            String msg = "Error when attempting to copy weights from Keras layer " + kerasLayerName + " to DL4J layer " + dl4jLayerName;
            if (this.getWeights() == null) {
                throw new InvalidKerasConfigurationException(msg + "(weights is null)");
            }
            // Parameters the DL4J layer expects but for which no weights are stored.
            HashSet<String> paramsInLayer = new HashSet<String>(layer.paramTable().keySet());
            HashSet<String> paramsInKerasLayer = new HashSet<String>(this.weights.keySet());
            paramsInLayer.removeAll(paramsInKerasLayer);
            if (!paramsInLayer.isEmpty()) {
                String joinedParamsInLayer = StringUtils.join((CharSequence)", ", paramsInLayer);
                throw new InvalidKerasConfigurationException(msg + "(no stored weights for parameters: " + joinedParamsInLayer + ")");
            }
            // Stored weights for which the DL4J layer has no parameter.
            paramsInKerasLayer.removeAll(layer.paramTable().keySet());
            if (!paramsInKerasLayer.isEmpty()) {
                String joinedParamsInKerasLayer = StringUtils.join((CharSequence)", ", paramsInKerasLayer);
                throw new InvalidKerasConfigurationException(msg + "(found no parameters named: " + joinedParamsInKerasLayer + ")");
            }
            for (String paramName : layer.paramTable().keySet()) {
                try {
                    layer.setParam(paramName, this.weights.get(paramName));
                }
                catch (Exception e) {
                    // Log with the stack trace and keep the original failure as the cause.
                    log.error(e.getMessage(), e);
                    InvalidKerasConfigurationException wrapped = new InvalidKerasConfigurationException(e.getMessage() + "\nTried to set weights for layer with name " + this.getLayerName() + ", of " + layer.conf().getLayer().getClass() + ".\nFailed to set weights for parameter " + paramName + "\nExpected shape for this parameter: " + layer.getParam(paramName).shapeInfoToString() + ", \ngot: " + this.weights.get(paramName).shapeInfoToString());
                    wrapped.initCause(e);
                    throw wrapped;
                }
            }
        }
    }

    /** @return {@code true} if a DL4J layer has been built for this Keras layer */
    public boolean isLayer() {
        return this.layer != null;
    }

    /** @return the DL4J layer, or {@code null} if none has been built */
    public Layer getLayer() {
        return this.layer;
    }

    /** @return {@code true} if a DL4J graph vertex has been built for this Keras layer */
    public boolean isVertex() {
        return this.vertex != null;
    }

    /** @return the DL4J graph vertex, or {@code null} if none has been built */
    public GraphVertex getVertex() {
        return this.vertex;
    }

    /** @return whether this layer acts as an input preprocessor; {@code false} in this base class */
    public boolean isInputPreProcessor() {
        return false;
    }

    /**
     * Determines the number of inputs by walking back along the first inbound layer of each
     * predecessor until a {@link FeedForwardLayer} with a positive {@code nOut} is found.
     *
     * <p>The walk is bounded by the number of entries in {@code previousLayers}, so it
     * terminates even when an inbound name is unknown or the inbound names form a cycle.
     *
     * @param previousLayers previously processed layers, keyed by layer name
     * @return the {@code nOut} of the nearest upstream feed-forward layer that reports a positive value
     * @throws UnsupportedKerasConfigurationException if no such layer can be reached
     */
    protected long getNInFromConfig(Map<String, ? extends KerasLayer> previousLayers) throws UnsupportedKerasConfigurationException {
        if (this.inboundLayerNames != null && !this.inboundLayerNames.isEmpty()) {
            String inboundLayerName = this.inboundLayerNames.get(0);
            // A chain of distinct layers cannot be longer than the map itself.
            int maxHops = previousLayers.size();
            for (int hop = 0; hop <= maxHops; ++hop) {
                KerasLayer inbound = previousLayers.get(inboundLayerName);
                if (inbound == null) {
                    // Unknown predecessor: the original loop would spin forever here.
                    break;
                }
                Layer candidate = inbound.getLayer();
                if (candidate instanceof FeedForwardLayer) {
                    long nIn = ((FeedForwardLayer)candidate).getNOut();
                    if (nIn > 0L) {
                        return nIn;
                    }
                }
                List<String> upstreamNames = inbound.getInboundLayerNames();
                if (upstreamNames == null || upstreamNames.isEmpty()) {
                    // Reached a source layer without finding a usable size.
                    break;
                }
                inboundLayerName = upstreamNames.get(0);
            }
        }
        throw new UnsupportedKerasConfigurationException("Could not determine number of input channels for depthwise convolution.");
    }

    /**
     * Returns the input preprocessor the underlying DL4J layer requires for the given input type.
     *
     * @param inputType input type(s); exactly one is expected when a DL4J layer is present
     * @return the preprocessor reported by the DL4J layer, or {@code null} if there is no DL4J layer
     * @throws InvalidKerasConfigurationException if a DL4J layer is present and the number of
     *                                            input types is not exactly one
     */
    public InputPreProcessor getInputPreprocessor(InputType ... inputType) throws InvalidKerasConfigurationException {
        InputPreProcessor preprocessor = null;
        if (this.layer != null) {
            if (inputType == null || inputType.length == 0) {
                // Fail with the declared exception instead of an unchecked index error.
                throw new InvalidKerasConfigurationException("Keras layer of type \"" + this.className + "\" requires exactly one input, but none was given");
            }
            if (inputType.length > 1) {
                throw new InvalidKerasConfigurationException("Keras layer of type \"" + this.className + "\" accepts only one input");
            }
            preprocessor = this.layer.getPreProcessorForInputType(inputType[0]);
        }
        return preprocessor;
    }

    /**
     * Computes the output type of this layer. The base implementation always throws.
     *
     * @param inputType input type(s)
     * @return never returns normally in this base class
     * @throws InvalidKerasConfigurationException     declared for overriding implementations
     * @throws UnsupportedKerasConfigurationException declared for overriding implementations
     * @throws UnsupportedOperationException          always, in this base class
     */
    public InputType getOutputType(InputType ... inputType) throws InvalidKerasConfigurationException, UnsupportedKerasConfigurationException {
        throw new UnsupportedOperationException("Cannot determine output type for Keras layer of type " + this.className);
    }

    /**
     * Indicates whether this layer can serve as an inbound layer: it has a DL4J layer, a vertex,
     * an input preprocessor, or its class name equals the configuration's input-layer class name.
     *
     * @return {@code true} if any of the above holds
     * @throws InvalidKerasConfigurationException if querying the input preprocessor fails
     */
    public boolean isValidInboundLayer() throws InvalidKerasConfigurationException {
        if (this.getLayer() != null || this.getVertex() != null || this.getInputPreprocessor(new InputType[0]) != null) {
            return true;
        }
        // className is null for layers created without a configuration map.
        return this.className != null && this.className.equals(this.conf.getLAYER_CLASS_NAME_INPUT());
    }

    /** Dimension ordering associated with a Keras layer. */
    public static enum DimOrder {
        NONE,
        THEANO,
        TENSORFLOW;

    }
}

