/*
 * Decompiled with CFR 0.152.
 */
package org.biojava.bio.dp;

import java.io.PrintStream;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import org.biojava.bio.Annotation;
import org.biojava.bio.BioError;
import org.biojava.bio.BioException;
import org.biojava.bio.dist.Distribution;
import org.biojava.bio.dist.DistributionFactory;
import org.biojava.bio.dist.OrderNDistributionFactory;
import org.biojava.bio.dp.DP;
import org.biojava.bio.dp.EmissionState;
import org.biojava.bio.dp.MagicalState;
import org.biojava.bio.dp.MarkovModel;
import org.biojava.bio.dp.SimpleEmissionState;
import org.biojava.bio.dp.SimpleMarkovModel;
import org.biojava.bio.dp.SimpleWeightMatrix;
import org.biojava.bio.dp.State;
import org.biojava.bio.dp.WMAsMM;
import org.biojava.bio.dp.WeightMatrix;
import org.biojava.bio.seq.io.SymbolTokenization;
import org.biojava.bio.symbol.Alphabet;
import org.biojava.bio.symbol.AlphabetManager;
import org.biojava.bio.symbol.FiniteAlphabet;
import org.biojava.bio.symbol.IllegalAlphabetException;
import org.biojava.bio.symbol.IllegalSymbolException;
import org.biojava.bio.symbol.Symbol;
import org.biojava.utils.ChangeVetoException;
import org.w3c.dom.Element;
import org.w3c.dom.NodeList;

/**
 * Static helpers that read and write {@link MarkovModel}s and {@link WeightMatrix}es
 * using a small XML vocabulary.
 *
 * <p>Readers consume a DOM {@link Element}; writers print XML text to a {@link PrintStream}.
 * The elements involved are:
 * <ul>
 *   <li>{@code <alphabet name="..."/>} - the emission alphabet, resolved with
 *       {@link AlphabetManager#alphabetForName}.</li>
 *   <li>{@code <col indx="n">} - (weight matrices) the weights of column {@code n}, 1-based.</li>
 *   <li>{@code <state name="...">} - (models) an emission state and its weights.</li>
 *   <li>{@code <weight res="..." prob="..."/>} or {@code <weight sym="..." prob="..."/>} -
 *       the weight of one symbol; single-character values are parsed with the alphabet's
 *       "token" tokenization, longer ones with its "name" tokenization.</li>
 *   <li>{@code <transition from="..." to="..." prob="..."/>} - (models) one transition; the
 *       names {@code _start_}, {@code _end_}, {@code _START_} and {@code _END_} all denote the
 *       model's magical state.</li>
 * </ul>
 */
public class XmlMarkovModel {

    /**
     * Builds a {@link WeightMatrix} from an XML element.
     *
     * <p>The matrix gets as many columns as the largest {@code indx} attribute among the
     * {@code <col>} descendants of {@code root}. Columns with no {@code <col>} element are left
     * as created by {@link DistributionFactory#DEFAULT}. Each {@code <weight>} names its symbol
     * with a {@code res} attribute or, when that is absent/empty, a {@code sym} attribute (the
     * form {@link #writeMatrix} emits).
     *
     * @param root element containing an {@code <alphabet>} element and {@code <col>} elements
     * @return the parsed weight matrix
     * @throws IllegalAlphabetException if the named alphabet is not a {@link FiniteAlphabet}
     * @throws IllegalSymbolException if a weight names a symbol the alphabet cannot parse
     * @throws BioException if the {@code <alphabet>} element is missing, the alphabet has
     *         neither a "token" nor a "name" tokenization, or a column index is below 1
     * @throws NumberFormatException if an {@code indx} or {@code prob} attribute is not a number
     */
    public static WeightMatrix readMatrix(Element root) throws IllegalSymbolException, IllegalAlphabetException, BioException {
        Alphabet sa = readAlphabet(root);
        if (!(sa instanceof FiniteAlphabet)) {
            throw new IllegalAlphabetException("Can't read WeightMatrix over infinite alphabet " + sa.getName() + " of type " + sa.getClass());
        }
        FiniteAlphabet seqAlpha = (FiniteAlphabet) sa;
        SymbolTokenization tokenParser = tokenizationOrNull(seqAlpha, "token");
        SymbolTokenization nameParser = tokenizationOrNull(seqAlpha, "name");
        if (tokenParser == null && nameParser == null) {
            throw new BioException("Couldn't find a parser for alphabet " + seqAlpha.getName());
        }

        NodeList colL = root.getElementsByTagName("col");
        // Column indices are 1-based, so the largest index is also the number of columns.
        int columns = 0;
        for (int i = 0; i < colL.getLength(); i++) {
            columns = Math.max(columns, columnIndex((Element) colL.item(i)));
        }
        SimpleWeightMatrix wm = new SimpleWeightMatrix(seqAlpha, columns, DistributionFactory.DEFAULT);

        for (int i = 0; i < colL.getLength(); i++) {
            Element colE = (Element) colL.item(i);
            Distribution column = wm.getColumn(columnIndex(colE) - 1);
            NodeList weights = colE.getElementsByTagName("weight");
            for (int j = 0; j < weights.getLength(); j++) {
                Element weightE = (Element) weights.item(j);
                Symbol sym = parseSymbol(weightSymbolName(weightE), tokenParser, nameParser);
                try {
                    column.setWeight(sym, Double.parseDouble(weightE.getAttribute("prob")));
                } catch (ChangeVetoException cve) {
                    throw new BioError(cve, "Assertion failure: Should be able to set the weights");
                }
            }
        }
        return wm;
    }

    /**
     * Builds a {@link MarkovModel} from an XML element.
     *
     * <p>A root tagged {@code WeightMatrix} is read with {@link #readMatrix} and wrapped in a
     * {@link WMAsMM}. Otherwise the root must carry a {@code heads} attribute; every
     * {@code <state>} becomes a {@link SimpleEmissionState} that advances every head by one,
     * and every {@code <transition>} is created and weighted.
     *
     * <p>NOTE(review): every {@code <state>} is read back as an emission state, including states
     * that {@link #writeModel} wrote without weights because they were not emission states.
     *
     * @param root the model element
     * @return the parsed model
     * @throws IllegalSymbolException if a single-character weight symbol cannot be parsed
     * @throws IllegalAlphabetException if the weight-matrix path rejects the alphabet
     * @throws BioException if the {@code <alphabet>} element is missing, the alphabet has no
     *         usable tokenization, a multi-character weight symbol cannot be parsed, or a
     *         transition refers to an undeclared state
     * @throws NumberFormatException if {@code heads} or a {@code prob} attribute is not a number
     */
    public static MarkovModel readModel(Element root) throws BioException, IllegalSymbolException, IllegalAlphabetException {
        if (root.getTagName().equals("WeightMatrix")) {
            return new WMAsMM(readMatrix(root));
        }
        int heads = Integer.parseInt(root.getAttribute("heads"));
        Alphabet seqAlpha = readAlphabet(root);
        SimpleMarkovModel model = new SimpleMarkovModel(heads, seqAlpha);
        // Every state advances each head by exactly one symbol; the array is shared by all states.
        int[] advance = new int[heads];
        Arrays.fill(advance, 1);

        SymbolTokenization nameParser = tokenizationOrNull(seqAlpha, "name");
        SymbolTokenization symbolParser = tokenizationOrNull(seqAlpha, "token");
        if (nameParser == null && symbolParser == null) {
            throw new BioException("Couldn't find a parser for alphabet " + seqAlpha.getName());
        }

        Map<String, State> nameToState = new HashMap<String, State>();
        nameToState.put("_start_", model.magicalState());
        nameToState.put("_end_", model.magicalState());
        nameToState.put("_START_", model.magicalState());
        nameToState.put("_END_", model.magicalState());

        // Alphabets that are a product of several copies of the same alphabet get order-N
        // distributions; everything else uses the default factory.
        List<?> alphas = seqAlpha.getAlphabets();
        boolean homogeneousProduct = alphas.size() > 1
                && alphas.equals(Collections.nCopies(alphas.size(), alphas.get(0)));
        DistributionFactory dFact = homogeneousProduct ? OrderNDistributionFactory.DEFAULT : DistributionFactory.DEFAULT;

        NodeList states = root.getElementsByTagName("state");
        for (int i = 0; i < states.getLength(); i++) {
            Element stateE = (Element) states.item(i);
            String name = stateE.getAttribute("name");
            Distribution dis = dFact.createDistribution(seqAlpha);
            SimpleEmissionState state = new SimpleEmissionState(name, Annotation.EMPTY_ANNOTATION, advance, dis);
            nameToState.put(name, state);
            NodeList weights = stateE.getElementsByTagName("weight");
            for (int j = 0; j < weights.getLength(); j++) {
                Element weightE = (Element) weights.item(j);
                String symName = weightSymbolName(weightE);
                Symbol sym;
                try {
                    sym = parseSymbol(symName, symbolParser, nameParser);
                } catch (IllegalSymbolException ise) {
                    // Single-character tokens propagate unchanged (as before); longer names
                    // are wrapped with the offending elements for context.
                    if (symName.length() == 1) {
                        throw ise;
                    }
                    throw new BioException(ise, "Can't extract symbol from " + weightE + " in " + stateE);
                }
                try {
                    dis.setWeight(sym, Double.parseDouble(weightE.getAttribute("prob")));
                } catch (ChangeVetoException cve) {
                    throw new BioError(cve, "Assertion failure: Should be able to edit distribution");
                }
            }
            try {
                model.addState(state);
            } catch (ChangeVetoException cve) {
                throw new BioError(cve, "Assertion failure: Should be able to add states to model");
            }
        }

        NodeList transitions = root.getElementsByTagName("transition");
        int nTrans = transitions.getLength();
        State[] froms = new State[nTrans];
        State[] tos = new State[nTrans];
        double[] probs = new double[nTrans];
        // Pass 1: resolve and create every transition before any weight is set. The original
        // code also used two passes; presumably creating a transition can reset the source
        // state's transition distribution. NOTE(review): confirm against SimpleMarkovModel.
        for (int t = 0; t < nTrans; t++) {
            Element transitionE = (Element) transitions.item(t);
            froms[t] = lookupState(nameToState, transitionE.getAttribute("from"));
            tos[t] = lookupState(nameToState, transitionE.getAttribute("to"));
            probs[t] = Double.parseDouble(transitionE.getAttribute("prob"));
            try {
                model.createTransition(froms[t], tos[t]);
            } catch (IllegalSymbolException ite) {
                throw new BioError(ite, "We should have unlimited write-access to this model. Something is very wrong.");
            } catch (ChangeVetoException cve) {
                throw new BioError(cve, "We should have unlimited write-access to this model. Something is very wrong.");
            }
        }
        // Pass 2: assign the transition weights.
        for (int t = 0; t < nTrans; t++) {
            try {
                model.getWeights(froms[t]).setWeight(tos[t], probs[t]);
            } catch (IllegalSymbolException ite) {
                throw new BioError(ite, "We should have unlimited write-access to this model. Something is very wrong.");
            } catch (ChangeVetoException cve) {
                throw new BioError(cve, "We should have unlimited write-access to this model. Something is very wrong.");
            }
        }
        return model;
    }

    /**
     * Writes a weight matrix as XML: an {@code <alphabet>} element followed by one
     * {@code <col indx="n">} per column (1-based) listing a {@code <weight sym=... prob=...>}
     * for every symbol of the matrix's alphabet. Names are XML-escaped.
     *
     * <p>NOTE(review): the root tag is {@code <MarkovModel>}, but {@link #readModel} only routes
     * {@code <WeightMatrix>} roots to {@link #readMatrix}; read this output back with
     * {@link #readMatrix} directly. The tag is left unchanged to keep the emitted format stable.
     *
     * @param matrix the matrix to write; its alphabet must be a {@link FiniteAlphabet}
     * @param out destination stream
     * @throws Exception if reading a weight from the matrix fails
     */
    public static void writeMatrix(WeightMatrix matrix, PrintStream out) throws Exception {
        FiniteAlphabet symA = (FiniteAlphabet) matrix.getAlphabet();
        out.println("<MarkovModel>\n  <alphabet name=\"" + escapeXml(symA.getName()) + "\"/>");
        for (int col = 0; col < matrix.columns(); col++) {
            out.println("  <col indx=\"" + (col + 1) + "\">");
            Distribution column = matrix.getColumn(col);
            for (Iterator<?> si = symA.iterator(); si.hasNext();) {
                Symbol s = (Symbol) si.next();
                out.println("    <weight sym=\"" + escapeXml(s.getName()) + "\" prob=\"" + column.getWeight(s) + "\"/>");
            }
            out.println("  </col>");
        }
        out.println("</MarkovModel>");
    }

    /**
     * Writes a Markov model as XML after flattening it with {@link DP#flatView}.
     *
     * <p>Emits the {@code heads} count, the emission alphabet, one {@code <state>} per
     * non-magical state (with a {@code <weight>} for every symbol when the state is an
     * {@link EmissionState}), and every transition. The magical state appears only in
     * transitions, as {@code _start_} (source) or {@code _end_} (target). Names are XML-escaped.
     *
     * @param model the model to write; its emission alphabet must be a {@link FiniteAlphabet}
     * @param out destination stream
     * @throws Exception if flattening the model or reading its weights fails
     */
    public static void writeModel(MarkovModel model, PrintStream out) throws Exception {
        MarkovModel flat = DP.flatView(model);
        FiniteAlphabet stateA = flat.stateAlphabet();
        FiniteAlphabet symA = (FiniteAlphabet) flat.emissionAlphabet();
        out.println("<MarkovModel heads=\"" + flat.heads() + "\">");
        out.println("<alphabet name=\"" + escapeXml(symA.getName()) + "\"/>");
        for (Iterator<?> stateI = stateA.iterator(); stateI.hasNext();) {
            State s = (State) stateI.next();
            if (s instanceof MagicalState) {
                continue; // only referenced by name in transitions
            }
            out.println("  <state name=\"" + escapeXml(s.getName()) + "\">");
            if (s instanceof EmissionState) {
                Distribution dis = ((EmissionState) s).getDistribution();
                for (Iterator<?> symI = symA.iterator(); symI.hasNext();) {
                    Symbol sym = (Symbol) symI.next();
                    out.println("    <weight sym=\"" + escapeXml(sym.getName()) + "\" prob=\"" + dis.getWeight(sym) + "\"/>");
                }
            }
            out.println("  </state>");
        }
        for (Iterator<?> fromI = stateA.iterator(); fromI.hasNext();) {
            printTransitions(flat, (State) fromI.next(), out);
        }
        out.println("</MarkovModel>");
    }

    /**
     * Prints one {@code <transition>} element per transition leaving {@code from}.
     * A magical source is written as {@code _start_}, a magical target as {@code _end_}.
     */
    private static void printTransitions(MarkovModel model, State from, PrintStream out) throws IllegalSymbolException {
        String fromName = from instanceof MagicalState ? "_start_" : escapeXml(from.getName());
        for (Iterator<?> toI = model.transitionsFrom(from).iterator(); toI.hasNext();) {
            State to = (State) toI.next();
            String toName = to instanceof MagicalState ? "_end_" : escapeXml(to.getName());
            try {
                out.println("  <transition from=\"" + fromName + "\" to=\"" + toName + "\" prob=\"" + model.getWeights(from).getWeight(to) + "\"/>");
            } catch (IllegalSymbolException ite) {
                throw new BioError(ite, "Transition listed in transitionsFrom(" + from.getName() + ") has dissapeared");
            }
        }
    }

    /** Resolves the first {@code <alphabet>} descendant of {@code root} to an alphabet. */
    private static Alphabet readAlphabet(Element root) throws BioException {
        Element alphaE = (Element) root.getElementsByTagName("alphabet").item(0);
        if (alphaE == null) {
            throw new BioException("No <alphabet> element found under <" + root.getTagName() + ">");
        }
        return AlphabetManager.alphabetForName(alphaE.getAttribute("name"));
    }

    /** Returns the named tokenization, or null when the alphabet reports it as absent. */
    private static SymbolTokenization tokenizationOrNull(Alphabet alpha, String name) throws BioException {
        try {
            return alpha.getTokenization(name);
        } catch (NoSuchElementException nsee) {
            // Absence is tolerated; callers fall back to the other tokenization.
            return null;
        }
    }

    /** Returns the {@code res} attribute of a weight element, or its {@code sym} attribute if {@code res} is empty. */
    private static String weightSymbolName(Element weightE) {
        String symName = weightE.getAttribute("res");
        if (symName == null || symName.length() == 0) {
            symName = weightE.getAttribute("sym");
        }
        return symName;
    }

    /**
     * Parses a symbol, preferring the token parser for single characters and the name parser
     * otherwise, falling back to the other when the preferred one is null. At least one parser
     * must be non-null.
     */
    private static Symbol parseSymbol(String symName, SymbolTokenization tokenParser, SymbolTokenization nameParser) throws IllegalSymbolException {
        boolean singleChar = symName.length() == 1;
        SymbolTokenization preferred = singleChar ? tokenParser : nameParser;
        SymbolTokenization fallback = singleChar ? nameParser : tokenParser;
        return (preferred != null ? preferred : fallback).parseToken(symName);
    }

    /** Parses and validates the 1-based {@code indx} attribute of a {@code <col>} element. */
    private static int columnIndex(Element colE) throws BioException {
        int indx = Integer.parseInt(colE.getAttribute("indx"));
        if (indx < 1) {
            throw new BioException("Column index must be at least 1 but was " + indx);
        }
        return indx;
    }

    /** Looks up a state declared in the model XML, failing with context if it is unknown. */
    private static State lookupState(Map<String, State> nameToState, String name) throws BioException {
        State state = nameToState.get(name);
        if (state == null) {
            throw new BioException("Transition refers to undeclared state '" + name + "'");
        }
        return state;
    }

    /**
     * Escapes the five XML special characters so {@code text} can sit inside a double-quoted
     * attribute. A null input becomes {@code "null"}, matching plain string concatenation.
     */
    private static String escapeXml(String text) {
        String s = String.valueOf(text);
        StringBuilder sb = new StringBuilder(s.length());
        for (int i = 0; i < s.length(); i++) {
            char c = s.charAt(i);
            switch (c) {
                case '&':
                    sb.append("&amp;");
                    break;
                case '<':
                    sb.append("&lt;");
                    break;
                case '>':
                    sb.append("&gt;");
                    break;
                case '"':
                    sb.append("&quot;");
                    break;
                case '\'':
                    sb.append("&apos;");
                    break;
                default:
                    sb.append(c);
            }
        }
        return sb.toString();
    }
}

