/*
 * Decompiled with CFR 0.152.
 */
package org.nd4j.linalg.api.ops.impl.shape.tensorops;

import java.util.Map;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.imports.graphmapper.tf.TFGraphMapper;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.Op;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.BaseTensorOp;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcatV3;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGatherV3;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayReadV3;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayScatterV3;
import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayWriteV3;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.util.ArrayUtil;
import org.nd4j.list.compat.TensorList;
import org.tensorflow.framework.AttrValue;
import org.tensorflow.framework.GraphDef;
import org.tensorflow.framework.NodeDef;

/**
 * SameDiff op mirroring TensorFlow's {@code TensorArrayV3}.
 *
 * <p>Each instance holds a {@link TensorList}. The helper methods ({@link #read}, {@link #gather},
 * {@link #stack}, {@link #concat}, {@link #write}, {@link #scatter}, {@link #unstack}) add the
 * corresponding {@code TensorArray*V3} op to the active SameDiff graph. That op's first input is
 * the variable registered for this list.
 */
public class TensorArrayV3
extends BaseTensorOp {
    /** Backing list. Copies made through the copy constructors share the same instance. */
    TensorList list;

    @Override
    public String tensorflowName() {
        return "TensorArrayV3";
    }

    /**
     * Creates a named op with no inputs, backed by a new list named after this op.
     *
     * @param name     op name
     * @param sameDiff graph the op belongs to
     */
    public TensorArrayV3(String name, SameDiff sameDiff) {
        super(name, sameDiff, new SDVariable[0]);
        this.list = new TensorList(this.getOwnName());
    }

    /**
     * Creates an op with no inputs, backed by a new list named after this op.
     *
     * @param sameDiff graph the op belongs to
     */
    public TensorArrayV3(SameDiff sameDiff) {
        super(sameDiff, new SDVariable[0]);
        this.list = new TensorList(this.getOwnName());
    }

    /**
     * Creates an op with no inputs that shares {@code ta}'s list and graph.
     *
     * @param ta op whose list and SameDiff are reused
     */
    public TensorArrayV3(TensorArrayV3 ta) {
        super(ta.sameDiff, new SDVariable[0]);
        this.list = ta.list;
    }

    /**
     * Creates an op with the given inputs that shares {@code ta}'s list and graph.
     *
     * @param ta     op whose list and SameDiff are reused
     * @param inputs op inputs
     */
    public TensorArrayV3(TensorArrayV3 ta, SDVariable[] inputs) {
        super(ta.sameDiff, inputs);
        this.list = ta.list;
    }

    /**
     * Initializes this op from a TensorFlow node.
     *
     * <p>The node's last input is resolved by name in {@code graph}. If that node yields a tensor
     * {@code "value"}, the tensor's first element is added as an integer argument. The list is
     * then replaced with a new one named after this op.
     */
    @Override
    public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map<String, AttrValue> attributesForNode, GraphDef graph) {
        // NOTE(review): presumably the last input is TF's `size` input; confirm against the TF op def.
        if (nodeDef.getInputCount() > 0) {
            String idd = nodeDef.getInput(nodeDef.getInputCount() - 1);
            NodeDef iddNode = null;
            for (int i = 0; i < graph.getNodeCount(); ++i) {
                NodeDef candidate = graph.getNode(i);
                if (candidate.getName().equals(idd)) {
                    iddNode = candidate;
                    break; // node names are unique within a GraphDef
                }
            }
            // An unresolved input is treated like a non-constant one: no integer argument is added,
            // instead of passing null through to the tensor extraction.
            if (iddNode != null) {
                INDArray arr = TFGraphMapper.getInstance().getNDArrayFromTensor("value", iddNode, graph);
                if (arr != null) {
                    int idx = arr.getInt(0);
                    this.addIArgument(idx);
                }
            }
        }
        this.list = new TensorList(this.getOwnName());
    }

    /** Creates an op without a graph, backed by a new list named after this op. */
    public TensorArrayV3() {
        this.list = new TensorList(this.getOwnName());
    }

    /** Returns this op's backing list; {@code sameDiff} is not used. */
    @Override
    public TensorList execute(SameDiff sameDiff) {
        return this.list;
    }

    @Override
    public String toString() {
        return this.opName();
    }

    @Override
    public String opName() {
        return "tensorarrayv3";
    }

    @Override
    public Op.Type opType() {
        return Op.Type.CUSTOM;
    }

    /**
     * Registers the backing list with the active graph and returns the variable named after it.
     * If no such variable exists yet, one is created via {@code var(name, 1L)}.
     */
    private SDVariable getVar() {
        SameDiff sd = this.getSameDiff();
        String name = this.list.getName();
        sd.putListByName(name, this.list);
        if (sd.variableMap().containsKey(name)) {
            return sd.variableMap().get(name);
        }
        return sd.var(name, 1L);
    }

    /** Returns the child SameDiff of this op's graph if one is set, otherwise the graph itself. */
    @Override
    public SameDiff getSameDiff() {
        SameDiff sd = this.sameDiff;
        if (sd.getChild() != null) {
            return sd.getChild();
        }
        return sd;
    }

    /**
     * Wraps integer indices in a new variable, converting them to doubles first.
     * NOTE(review): the variable is created on {@code this.sameDiff}, not on {@link #getSameDiff()}.
     */
    private SDVariable intToVar(int ... index) {
        return this.sameDiff.var(Nd4j.create(ArrayUtil.toDouble((int[])index)));
    }

    /** Adds a read of element {@code index} and returns the op's output variable. */
    public SDVariable read(int index) {
        return new TensorArrayReadV3(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(index)}).outputVariable();
    }

    /** Adds a read at the index held by {@code index} and returns the op's output variable. */
    public SDVariable read(SDVariable index) {
        return new TensorArrayReadV3(this.getSameDiff(), new SDVariable[]{this.getVar(), index}).outputVariable();
    }

    /** Adds a gather of the given element indices and returns the op's output variable. */
    public SDVariable gather(int ... indices) {
        return new TensorArrayGatherV3(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(indices)}).outputVariable();
    }

    /** Adds a gather of the indices held by {@code indices} and returns the op's output variable. */
    public SDVariable gather(SDVariable indices) {
        return new TensorArrayGatherV3(this.getSameDiff(), new SDVariable[]{this.getVar(), indices}).outputVariable();
    }

    /**
     * Adds a gather with index {@code -1} and returns the op's output variable.
     * NOTE(review): presumably {@code -1} means "all elements"; confirm in TensorArrayGatherV3.
     */
    public SDVariable stack() {
        return new TensorArrayGatherV3(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(-1)}).outputVariable();
    }

    /** Adds a concat over the list and returns the op's output variable. */
    public SDVariable concat() {
        return new TensorArrayConcatV3(this.getSameDiff(), new SDVariable[]{this.getVar()}).outputVariable();
    }

    /** Adds a write of {@code value} at element {@code index}. */
    public void write(int index, SDVariable value) {
        new TensorArrayWriteV3(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(index), value}).outputVariables();
    }

    /** Adds a write of {@code value} at the index held by {@code index}. */
    public void write(SDVariable index, SDVariable value) {
        // A leftover debug println that dumped the whole SameDiff on every write was removed here.
        new TensorArrayWriteV3(this.getSameDiff(), new SDVariable[]{this.getVar(), index, value}).outputVariables();
    }

    /** Adds a scatter of {@code value} into the given element indices. */
    public void scatter(SDVariable value, int ... indices) {
        new TensorArrayScatterV3(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(indices), value}).outputVariables();
    }

    /** Adds a scatter of {@code value} into the indices held by {@code indices}. */
    public void scatter(SDVariable value, SDVariable indices) {
        new TensorArrayScatterV3(this.getSameDiff(), new SDVariable[]{this.getVar(), indices, value}).outputVariables();
    }

    /**
     * Adds a scatter of {@code value} with index {@code -1}.
     * NOTE(review): presumably {@code -1} means "split along all elements"; confirm in TensorArrayScatterV3.
     */
    public void unstack(SDVariable value) {
        new TensorArrayScatterV3(this.getSameDiff(), new SDVariable[]{this.getVar(), this.intToVar(-1), value}).outputVariables();
    }
}

