/*
 * Decompiled with CFR 0.152.
 */
package org.jpmml.rexp;

import java.io.ByteArrayInputStream;
import java.io.File;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Label;
import org.jpmml.converter.PMMLEncoder;
import org.jpmml.converter.Schema;
import org.jpmml.converter.ValueUtil;
import org.jpmml.rexp.DecorationUtil;
import org.jpmml.rexp.ModelConverter;
import org.jpmml.rexp.RExpEncoder;
import org.jpmml.rexp.RGenericVector;
import org.jpmml.rexp.RIntegerVector;
import org.jpmml.rexp.RNumberVector;
import org.jpmml.rexp.RRaw;
import org.jpmml.rexp.RStringVector;
import org.jpmml.rexp.RVector;
import org.jpmml.xgboost.FeatureMap;
import org.jpmml.xgboost.Learner;
import org.jpmml.xgboost.ObjFunction;
import org.jpmml.xgboost.XGBoostUtil;

/**
 * Converts an XGBoost booster, held as an R generic vector, into a PMML {@link MiningModel}.
 *
 * <p>The following elements of the booster vector are read:
 * <ul>
 *   <li>{@code raw}: bytes that are parsed into an XGBoost {@link Learner};</li>
 *   <li>{@code fmap}: the feature map, either a string vector holding a file path,
 *       or a generic vector whose first three columns are id, name and type;</li>
 *   <li>{@code schema} (optional): may carry {@code missing}, {@code response_name}
 *       and {@code response_levels};</li>
 *   <li>{@code ntreelimit} (optional): forwarded to the encoder as the
 *       {@code ntree_limit} option.</li>
 * </ul>
 */
public class XGBoostConverter
extends ModelConverter<RGenericVector> {
    /** Learner parsed from the booster's {@code raw} element; populated lazily by {@link #ensureLearner()}. */
    private Learner learner = null;
    /** Value of the {@code compact} conversion option; defaults to {@code true}. */
    private final boolean compact = this.getOption("compact", Boolean.TRUE);

    public XGBoostConverter(RGenericVector booster) {
        super(booster);
    }

    /**
     * Registers the target label and the input features with the encoder.
     *
     * <p>The feature map is loaded from the booster's {@code fmap} element. If the optional
     * {@code schema} element carries {@code missing}, its scalar value is added to the feature
     * map as a missing value. The target field is named {@code _target} unless
     * {@code schema$response_name} is present; target categories come from
     * {@code schema$response_levels} when present. The label itself is encoded by the
     * learner's objective function.
     *
     * @param encoder the encoder that receives the label and the features
     * @throws IllegalArgumentException if the feature map or the learner cannot be loaded
     */
    @Override
    public void encodeSchema(RExpEncoder encoder) {
        RGenericVector booster = this.getObject();
        RVector<?> fmap = DecorationUtil.getVectorElement(booster, "fmap");
        RGenericVector schema = booster.getGenericElement("schema", false);

        FeatureMap featureMap;
        try {
            featureMap = XGBoostConverter.loadFeatureMap(fmap);
        } catch (IOException ioe) {
            throw new IllegalArgumentException(ioe);
        }

        if (schema != null) {
            RVector<?> missing = schema.getVectorElement("missing", false);
            if (missing != null) {
                featureMap.addMissingValue(org.jpmml.model.ValueUtil.toString(missing.asScalar()));
            }
        }

        Learner learner = this.ensureLearner();
        ObjFunction obj = learner.obj();

        FieldName targetField = FieldName.create("_target");
        List<String> targetCategories = null;
        if (schema != null) {
            RStringVector responseName = schema.getStringElement("response_name", false);
            RStringVector responseLevels = schema.getStringElement("response_levels", false);
            if (responseName != null) {
                targetField = FieldName.create(responseName.asScalar());
            }
            if (responseLevels != null) {
                targetCategories = responseLevels.getValues();
            }
        }

        Label label = obj.encodeLabel(targetField, targetCategories, encoder);
        encoder.setLabel(label);

        List<? extends Feature> features = featureMap.encodeFeatures(encoder);
        for (Feature feature : features) {
            encoder.addFeature(feature);
        }
    }

    /**
     * Encodes the learner as a PMML mining model.
     *
     * <p>Two options are passed to the learner: {@code compact} (from the converter option
     * of the same name) and {@code ntree_limit} (from the booster's {@code ntreelimit}
     * element as an {@link Integer}, or {@code null} when that element is absent).
     *
     * @param schema the schema produced from {@link #encodeSchema(RExpEncoder)}
     * @return the encoded mining model
     * @throws IllegalArgumentException if the learner cannot be loaded
     */
    @Override
    public MiningModel encodeModel(Schema schema) {
        RGenericVector booster = this.getObject();
        RNumberVector<?> ntreeLimit = booster.getNumericElement("ntreelimit", false);

        Learner learner = this.ensureLearner();

        // Option values are heterogeneous (Boolean and Integer), so the map is typed on Object.
        Map<String, Object> options = new LinkedHashMap<>();
        options.put("compact", this.compact);
        options.put("ntree_limit", ntreeLimit != null ? ValueUtil.asInteger(ntreeLimit.asScalar()) : null);

        Schema xgbSchema = learner.toXGBoostSchema(schema);

        return learner.encodeMiningModel(options, xgbSchema);
    }

    /**
     * Returns the learner, parsing it from the booster on first use and caching it afterwards.
     */
    private Learner ensureLearner() {
        if (this.learner == null) {
            this.learner = this.loadLearner();
        }
        return this.learner;
    }

    /**
     * Parses the booster's {@code raw} element into a learner.
     *
     * @throws IllegalArgumentException wrapping the underlying {@link IOException}
     */
    private Learner loadLearner() {
        RGenericVector booster = this.getObject();
        RRaw raw = (RRaw)booster.getElement("raw");
        try {
            return XGBoostConverter.loadLearner(raw);
        } catch (IOException ioe) {
            throw new IllegalArgumentException(ioe);
        }
    }

    /**
     * Loads a feature map from either a file path (string vector) or an in-memory table
     * (generic vector).
     *
     * @throws IllegalArgumentException if {@code fmap} is neither of the supported vector types
     * @throws IOException if reading the feature map file fails
     */
    private static FeatureMap loadFeatureMap(RVector<?> fmap) throws IOException {
        if (fmap instanceof RStringVector) {
            return XGBoostConverter.loadFeatureMap((RStringVector)fmap);
        }
        if (fmap instanceof RGenericVector) {
            return XGBoostConverter.loadFeatureMap((RGenericVector)fmap);
        }
        throw new IllegalArgumentException("Expected a string vector or a generic vector as the feature map, got "
            + (fmap == null ? "null" : fmap.getClass().getName()));
    }

    /**
     * Loads a feature map from the file whose path is the scalar value of {@code fmap}.
     */
    private static FeatureMap loadFeatureMap(RStringVector fmap) throws IOException {
        File file = new File(fmap.asScalar());
        try (InputStream is = new FileInputStream(file)) {
            return XGBoostUtil.loadFeatureMap(is);
        }
    }

    /**
     * Builds a feature map from a table whose first three columns are id (integer),
     * name (factor) and type (factor).
     *
     * <p>Ids must equal their zero-based row position.
     *
     * @throws IllegalArgumentException if the name or type column is not a factor, if the
     *     columns differ in length, or if an id does not match its row position
     */
    private static FeatureMap loadFeatureMap(RGenericVector fmap) {
        RIntegerVector id = (RIntegerVector)fmap.getValue(0);
        RIntegerVector name = (RIntegerVector)fmap.getValue(1);
        RIntegerVector type = (RIntegerVector)fmap.getValue(2);
        if (!name.isFactor() || !type.isFactor()) {
            throw new IllegalArgumentException("Feature map name and type columns must be factors");
        }

        int size = id.size();
        if (name.size() != size || type.size() != size) {
            throw new IllegalArgumentException("Feature map columns have different lengths: id=" + size
                + ", name=" + name.size() + ", type=" + type.size());
        }

        FeatureMap featureMap = new FeatureMap();
        for (int i = 0; i < size; i++) {
            if (i != id.getValue(i)) {
                throw new IllegalArgumentException("Feature map id at row " + i + " is " + id.getValue(i)
                    + ", expected " + i);
            }
            featureMap.addEntry(name.getFactorValue(i), type.getFactorValue(i));
        }
        return featureMap;
    }

    /**
     * Parses the raw byte content into an XGBoost learner.
     */
    private static Learner loadLearner(RRaw raw) throws IOException {
        byte[] value = raw.getValue();
        try (InputStream is = new ByteArrayInputStream(value)) {
            return XGBoostUtil.loadLearner(is);
        }
    }
}

