Commit 6ddda77a authored by Bruno López Trigo

O sistema permite subir un dataset de test e xerar a matriz de confusión do mesmo

parent 7d65731c
......@@ -204,13 +204,11 @@ public class ClassifierManagerImpl implements ClassifierManager {
m = this.matrixBuilder.readMatrix();
this.mapper.writeJSON(m, new File(matrixLocation));
return m;
} else if(type.equals("train")){
} else {
this.matrixBuilder = new MatrixBuilder(input);
m = this.matrixBuilder.buildMatrixInstances(token, dataset, algorithm);
m = this.matrixBuilder.buildMatrixInstances(token, dataset, algorithm, type);
this.mapper.writeJSON(m, new File(matrixLocation));
return m;
} else {
return m;
}
}
......
......@@ -75,7 +75,7 @@ public class MatrixBuilder {
}
public Matrix buildMatrixInstances(String token, String name, String algorithm) throws NotFoundEx, IOException, FormatEx{
public Matrix buildMatrixInstances(String token, String name, String algorithm, String type) throws NotFoundEx, IOException, FormatEx{
File configFile;
MapperJSON mapper = new MapperJSON();
......@@ -87,7 +87,13 @@ public class MatrixBuilder {
}
DatasetConfig config = mapper.readConfigJSON(configFile);
Dataset dataset = this.datasetManager.getDataset(token, name);
Dataset dataset;
if(type.equals("train"))
dataset = this.datasetManager.getDataset(token, name);
else
dataset = this.datasetManager.getTestDataset(token, name);
Instance instance;
Classification classification;
Matrix m = new Matrix(config.getConsequents().size());
......
......@@ -108,8 +108,15 @@ public class TreeInterpreter {
}
} else if (path.get(path.size() - 1) instanceof CategoricNode) {
att = ((CategoricNode) path.get(path.size() - 1)).getAttribute();
path.add(((CategoricNode) path.get(path.size() - 1)).getChild(instance.getCategoricValue(att.getId())));
classification.addAntecedent(new Antecedent(att.getId() + " is " + instance.getCategoricValue(att.getId())));
Node nodeAux = ((CategoricNode) path.get(path.size() - 1)).getChild(instance.getCategoricValue(att.getId()));
if (nodeAux != null) {
classification.addAntecedent(new Antecedent(att.getId() + " is " + instance.getCategoricValue(att.getId())));
path.add(nodeAux);
} else {
nodeAux = ((CategoricNode) path.get(path.size() - 1)).getNotChild(att.getId());
classification.addAntecedent(new Antecedent(att.getId() + " not " + ((CategoricNode) path.get(path.size() - 1)).getNotChildValue(att.getId())));
path.add(nodeAux);
}
}
if (path.get(path.size() - 1) instanceof ConsequentNode) {
if (instance.getSolution() != null) {
......@@ -262,8 +269,15 @@ public class TreeInterpreter {
}
} else if (path.get(path.size() - 1) instanceof CategoricNode) {
att = ((CategoricNode) path.get(path.size() - 1)).getAttribute();
path.add(((CategoricNode) path.get(path.size() - 1)).getChild(alternatives.get(0).getCategoricValue(att.getId())));
classification.addAntecedent(new Antecedent(att.getId() + " is " + alternatives.get(0).getCategoricValue(att.getId())));
Node nodeAux = ((CategoricNode) path.get(path.size() - 1)).getChild(instance.getCategoricValue(att.getId()));
if (nodeAux != null) {
classification.addAntecedent(new Antecedent(att.getId() + " is " + instance.getCategoricValue(att.getId())));
path.add(nodeAux);
} else {
nodeAux = ((CategoricNode) path.get(path.size() - 1)).getNotChild(att.getId());
classification.addAntecedent(new Antecedent(att.getId() + " not " + ((CategoricNode) path.get(path.size() - 1)).getNotChildValue(att.getId())));
path.add(nodeAux);
}
}
if (path.get(path.size() - 1) instanceof ConsequentNode && !solutions.contains(((ConsequentNode) path.get(path.size() - 1)).getConsequent().getId())) {
if (alternatives.get(0).getSolution() != null) {
......
......@@ -15,7 +15,9 @@ public interface DatasetManager {
public ArrayList<Dataset> listDatasets(String token) throws NotFoundEx;
public ArrayList<Dataset> deleteDatasets(String token) throws NotFoundEx;
public Dataset getDataset(String token, String name) throws NotFoundEx, IOException;
public Dataset getTestDataset(String token, String name) throws NotFoundEx, IOException;
public Dataset deleteDataset(String token, String name) throws NotFoundEx;
public ModifiedDataset uploadDataset(String token, String name, InputStream file) throws FormatEx, ConflictEx, IOException;
public ModifiedDataset uploadTestDataset(String token, String name, InputStream file) throws FormatEx, IOException;
public Line getDatasetLine(String token, String name, int number) throws FormatEx, NotFoundEx, IOException;
}
......@@ -85,6 +85,20 @@ public class DatasetManagerImpl implements DatasetManager {
return dataset;
}
/**
 * Builds the test dataset associated with a dataset name.
 *
 * The test file is obtained from the file manager for the given token and
 * name, and then parsed with {@code buildDataset}.
 *
 * @param token user token forwarded to the file manager lookup
 * @param name  dataset name; also used as the name of the resulting dataset
 * @return the dataset built from the test file
 * @throws NotFoundEx  declared by this method; presumably raised by the file lookup (confirm in FileManager)
 * @throws IOException if building the dataset from the file fails
 */
@Override
public Dataset getTestDataset(String token, String name) throws NotFoundEx, IOException {
    final File source = this.fmanager.getTest(token, name);
    return buildDataset(name, source);
}
private Dataset buildDataset(String name, File datasetFile) throws IOException {
Dataset dataset = new Dataset(name);
......@@ -191,6 +205,25 @@ public class DatasetManagerImpl implements DatasetManager {
return dataset;
}
/**
 * Stores an uploaded test dataset and returns it together with the list
 * reported by the ARFF format check.
 *
 * Steps, in order: resolve the final test-file location, make sure the
 * dataset directory exists, write the upload to a {@code -test.arff.tmp}
 * file, run the format checker on that temporary file, and build the
 * dataset from the final test file.
 *
 * NOTE(review): the stream is written only to the temporary path, while the
 * dataset is built from the final test location. Presumably
 * {@code FormatChecker.checkArffFormat} produces the final file from the
 * temporary one; confirm against FormatChecker.
 * NOTE(review): the temporary path is based on {@code getBASELocation} while
 * the created directory is based on {@code getBASE}; verify both resolve to
 * the same place.
 *
 * @param token user token used to resolve storage locations
 * @param name  dataset name
 * @param file  uploaded content
 * @return the built dataset plus the list returned by the format checker
 * @throws FormatEx    declared by this method; presumably raised by the format check
 * @throws IOException if writing or reading the files fails
 */
@Override
public ModifiedDataset uploadTestDataset(String token, String name, InputStream file) throws FormatEx, IOException {
    final File target = new File(this.fmanager.getTestLocation(token, name));
    final String uploadPath = this.fmanager.getBASELocation(token) + "/" + name + "/" + name + "-test.arff.tmp";

    // The directory must exist before the upload is written into it.
    new File(this.fmanager.getBASE(token) + "/" + name).mkdirs();
    writeToFile(file, uploadPath);

    final ArrayList<String> changes = FormatChecker.checkArffFormat(new File(uploadPath));
    return new ModifiedDataset(buildDataset(name, target), changes);
}
private void writeToFile(InputStream uploadedInputStream,
String uploadedFileLocation) throws IOException {
......
......@@ -2,11 +2,13 @@ package brunolopez.expliclas.explainer;
import brunolopez.expliclas.classifiers.MatrixBuilder;
import brunolopez.expliclas.models.Matrix;
import brunolopez.expliclas.models.Position;
import de.normalisiert.utils.graphs.ElementaryCyclesSearch;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Scanner;
public class ConfusionAnalyzer {
......@@ -19,6 +21,10 @@ public class ConfusionAnalyzer {
this.sc = new Scanner(input);
}
/**
 * Creates an analyzer backed by an already built matrix.
 *
 * Only the {@code matrix} field is assigned; the scanner field set by the
 * other constructor is left untouched here.
 * NOTE(review): methods that read from the scanner (e.g. getGlobalPercentage)
 * presumably must not be called on an instance created this way; confirm.
 * NOTE(review): nothing in this body can throw FileNotFoundException; the
 * clause is kept because callers may already declare or handle it.
 *
 * @param matrix matrix to analyze
 */
public ConfusionAnalyzer(Matrix matrix) throws FileNotFoundException {
this.matrix = matrix;
}
public double getGlobalPercentage(){
String line = this.sc.nextLine();
......@@ -33,6 +39,12 @@ public class ConfusionAnalyzer {
}
/**
 * Returns the entry that the matrix's "confused" collection holds for the
 * cell at row {@code i}, column {@code j}.
 *
 * @param i row index, used to build the lookup {@code Position}
 * @param j column index, used to build the lookup {@code Position}
 * @return the stored list for that position; presumably {@code null} when
 *         nothing is stored for it (TODO confirm the type returned by getConfused)
 */
public ArrayList<Integer> getWrongInstances(int i, int j){
return this.matrix.getConfused().get(new Position(i, j));
}
public List getLongestCycle() {
Integer[] nodes;
......
......@@ -12,6 +12,7 @@ public interface ExplainerManager {
Explanation getGlobalExplanation(String token, String dataset, String algorithm, String language) throws IOException, NotFoundEx, FormatEx;
Explanation getConfusion(String token, String dataset, String algorithm, ArrayList<Consequent> consequents, String language) throws FormatEx, IOException, NotFoundEx;
Explanation getConfusionInstances(String token, String dataset, String algorithm, ArrayList<Consequent> consequents, String language, String type) throws FormatEx, IOException, NotFoundEx;
Explanation getLocalExplanation(String token, String dataset, String algorithm, ArrayList<Classification> classifications, String language) throws FormatEx, IOException, NotFoundEx;
}
package brunolopez.expliclas.explainer;
import brunolopez.expliclas.classifiers.ClassifierManager;
import brunolopez.expliclas.classifiers.ClassifierManagerImpl;
import brunolopez.expliclas.exceptions.FormatEx;
import brunolopez.expliclas.exceptions.NotFoundEx;
import brunolopez.expliclas.models.Classification;
......@@ -7,6 +9,7 @@ import brunolopez.expliclas.models.Consequent;
import brunolopez.expliclas.utils.MapperJSON;
import brunolopez.expliclas.models.Explanation;
import brunolopez.expliclas.models.GlobalConfig;
import brunolopez.expliclas.models.Matrix;
import brunolopez.expliclas.utils.FileManager;
import java.io.File;
import java.io.IOException;
......@@ -26,8 +29,10 @@ public class ExplainerManagerImpl implements ExplainerManager {
private InfoExtractor extractor;
private GlobalConfig globalConfig;
private final FileManager fmanager;
private final ClassifierManager cmanager;
/**
 * Creates the manager with its own collaborators: a classifier manager,
 * a file manager and a JSON mapper, each instantiated directly here.
 */
public ExplainerManagerImpl() {
this.cmanager = new ClassifierManagerImpl();
this.fmanager = new FileManager();
this.mapper = new MapperJSON();
}
......@@ -434,7 +439,7 @@ public class ExplainerManagerImpl implements ExplainerManager {
this.generator = new ClauseGeneratorEn();
if (confusion != 0) {
phrase = this.generator.generateClause(
consequents.get(0).getName(), "be", false, "confused with " + consequents.get(1).getName() + " " + df.format(confusion) + "%");
consequents.get(0).getName(), "be", false, "confused with " + consequents.get(1).getName() + " by " + df.format(confusion) + "%");
} else {
phrase = this.generator.generateClause(
consequents.get(0).getName(), "be", false, "never confused with " + consequents.get(1).getName());
......@@ -467,6 +472,61 @@ public class ExplainerManagerImpl implements ExplainerManager {
return explanation;
}
/**
 * Builds an explanation listing the instances confused between two consequents.
 *
 * The dataset configuration is first looked up without the token; if a
 * {@code NotFoundEx} is raised anywhere in that attempt, the whole
 * preparation is retried with the token-specific configuration.
 *
 * For the languages "en", "es" and "gl" the explanation receives a single
 * clause: either the realised coordination of the instance labels, or a
 * fixed "not confused" text when there are none. Any other language value
 * matches no case, so the explanation is returned without clauses.
 *
 * @param token       user token
 * @param dataset     dataset name
 * @param algorithm   algorithm name, forwarded to the classifier manager
 * @param consequents exactly two consequents
 * @param language    explanation language code
 * @param type        matrix type, forwarded to the classifier manager
 * @return the explanation
 * @throws FormatEx   if {@code consequents} is null or does not contain exactly two elements
 * @throws NotFoundEx if the fallback attempt also fails with it
 * @throws IOException if reading configuration or matrix data fails
 */
@Override
public Explanation getConfusionInstances(String token, String dataset, String algorithm, ArrayList<Consequent> consequents, String language, String type) throws FormatEx, NotFoundEx, IOException {
    if (consequents == null || consequents.size() != 2) {
        throw new FormatEx("You must specify two consequents");
    }

    try {
        prepareMatrixExtractor(this.fmanager.getConfig(dataset, language), token, dataset, algorithm, language, type);
    } catch (NotFoundEx ex) {
        // Retry with the configuration stored under the user's token.
        prepareMatrixExtractor(this.fmanager.getConfig(token, dataset, language), token, dataset, algorithm, language, type);
    }

    Explanation explanation = new Explanation();
    ArrayList<String> instances = this.extractor.getWrongInstances(consequents);

    switch (language) {
        case "en":
            this.generator = new ClauseGeneratorEn();
            if (!instances.isEmpty()) {
                explanation.addClause(((ClauseGeneratorEn) this.generator).getRealisation(this.generator.generateNoumsCoordinate(instances)));
            } else {
                explanation.addClause("Not confused");
            }
            break;
        case "es":
            this.generator = new ClauseGeneratorEs();
            if (!instances.isEmpty()) {
                explanation.addClause(((ClauseGeneratorEs) this.generator).getRealisation(this.generator.generateNoumsCoordinate(instances)));
            } else {
                explanation.addClause("No confundidos");
            }
            break;
        case "gl":
            this.generator = new ClauseGeneratorGl();
            if (!instances.isEmpty()) {
                explanation.addClause(((ClauseGeneratorGl) this.generator).getRealisation(this.generator.generateNoumsCoordinate(instances)));
            } else {
                explanation.addClause("Non confundidos");
            }
            break;
    }

    return explanation;
}

/**
 * Loads the global configuration and the requested matrix, then installs an
 * extractor built from the given dataset configuration file.
 * Statement order matches the original inline code so the first failing
 * step is the same.
 */
private void prepareMatrixExtractor(File config, String token, String dataset, String algorithm, String language, String type) throws FormatEx, NotFoundEx, IOException {
    File global = this.fmanager.getGlobalConfig(language);
    this.globalConfig = this.mapper.readGlobalConfigJSON(global);
    Matrix m = this.cmanager.getMatrix(token, dataset, algorithm, type);
    this.extractor = new InfoExtractor(this.mapper.readConfigJSON(config), m, this.globalConfig);
}
@Override
public Explanation getLocalExplanation(String token, String dataset, String algorithm, ArrayList<Classification> classifications, String language) throws FormatEx, IOException, NotFoundEx {
......
......@@ -7,6 +7,7 @@ import brunolopez.expliclas.models.Classification;
import brunolopez.expliclas.models.Consequent;
import brunolopez.expliclas.models.DatasetConfig;
import brunolopez.expliclas.models.GlobalConfig;
import brunolopez.expliclas.models.Matrix;
import java.io.File;
import java.io.FileNotFoundException;
import java.util.ArrayList;
......@@ -33,6 +34,12 @@ public class InfoExtractor {
this.classifications = classifications;
}
/**
 * Creates an extractor that works from an already built matrix.
 *
 * Stores the dataset and global configurations and wraps the matrix in a
 * {@code ConfusionAnalyzer}. The {@code classifications} field is not
 * assigned by this constructor.
 *
 * @param config dataset configuration
 * @param matrix matrix handed to the confusion analyzer
 * @param global global configuration
 * @throws FileNotFoundException declared by the analyzer constructor used here
 */
public InfoExtractor(DatasetConfig config, Matrix matrix, GlobalConfig global) throws FileNotFoundException {
this.config = config;
this.analyzer = new ConfusionAnalyzer(matrix);
this.global = global;
}
public String getDatasetName() {
return this.config.getDataset();
}
......@@ -122,6 +129,20 @@ public class InfoExtractor {
return this.analyzer.getConfusionBetween(consequents.get(0).getMatrixPosition() - 1, consequents.get(1).getMatrixPosition() - 1);
}
/**
 * Returns display labels ("Instance n") for the instances the analyzer
 * reports for the pair formed by the first two consequents.
 *
 * Each consequent's matrix position is decremented by one before the lookup
 * (presumably converting 1-based positions to 0-based indices; confirm
 * against the dataset configuration).
 *
 * @param consequents list whose first two elements select the row and column
 * @return the labels, or an empty list when the analyzer returns {@code null}
 */
public ArrayList<String> getWrongInstances(ArrayList<Consequent> consequents) {
    ArrayList<String> instances = new ArrayList<>();

    int row = consequents.get(0).getMatrixPosition() - 1;
    int column = consequents.get(1).getMatrixPosition() - 1;
    ArrayList<Integer> wrongInstances = this.analyzer.getWrongInstances(row, column);
    if (wrongInstances == null) {
        return instances;
    }

    for (Integer index : wrongInstances) {
        instances.add("Instance " + index);
    }
    return instances;
}
public int getNumAlternatives() {
return this.classifications.size();
}
......@@ -137,15 +158,27 @@ public class InfoExtractor {
antecedent = this.classifications.get(alternative).getAntecedenById(a.getId());
if (antecedent != null) {
if (antecedent.isCategoric()) {
categoricId = antecedent.getDecision().split(" is ")[1];
attributes = list.get(((CategoricProperty) this.config.getAttributeById(a.getId()).getPropertyByName(categoricId)).getValue());
if (attributes == null) {
attributes = new ArrayList();
attributes.add(a.getName());
if (antecedent.getDecision().contains(" is ")) {
categoricId = antecedent.getDecision().split(" is ")[1];
attributes = list.get(((CategoricProperty) this.config.getAttributeById(a.getId()).getPropertyByName(categoricId)).getValue());
if (attributes == null) {
attributes = new ArrayList();
attributes.add(a.getName());
} else {
attributes.add(a.getName());
}
list.put(((CategoricProperty) this.config.getAttributeById(a.getId()).getPropertyByName(categoricId)).getValue(), attributes);
} else {
attributes.add(a.getName());
categoricId = antecedent.getDecision().split(" not ")[1];
attributes = list.get("not " + ((CategoricProperty) this.config.getAttributeById(a.getId()).getPropertyByName(categoricId)).getValue());
if (attributes == null) {
attributes = new ArrayList();
attributes.add(a.getName());
} else {
attributes.add(a.getName());
}
list.put("not " + ((CategoricProperty) this.config.getAttributeById(a.getId()).getPropertyByName(categoricId)).getValue(), attributes);
}
list.put(((CategoricProperty) this.config.getAttributeById(a.getId()).getPropertyByName(categoricId)).getValue(), attributes);
} else {
attributes = list.get(this.config.getAttributeById(a.getId()).getPropertyById(antecedent.getProperty()).getName());
if (attributes == null) {
......@@ -211,9 +244,9 @@ public class InfoExtractor {
consequents.add(this.config.getConsequentById(this.classifications.get(i).getConsequent().getId()).getName());
globalProbs.put(label, consequents);
}
return globalProbs;
}
public HashMap<String, Double> getSplitValues(int alternative) {
......@@ -223,7 +256,7 @@ public class InfoExtractor {
for (Antecedent a : this.classifications.get(alternative).getAntecedents()) {
if (!a.isCategoric()) {
Double split = this.classifications.get(alternative).getSplitValue(a.getDecision().split(" > | <= ")[0]);
if (split != null) {
values.put(this.config.getAttributeById(a.getDecision().split(" > | <= ")[0]).getName(), split);
}
......
......@@ -47,6 +47,6 @@ public class Antecedent {
@JsonIgnore
public boolean isCategoric(){
return this.decision.contains("is");
return this.decision.contains(" is ") || this.decision.contains(" not ");
}
}
......@@ -34,7 +34,33 @@ public class CategoricNode extends Node {
}
public Node getChild(String key) {
return children.get(key);
return children.get("=:" + key);
}
/**
 * Returns the child node of the first "!=" entry whose value part differs
 * from {@code key}.
 *
 * @param key value to compare against the part of the child key after the first ':'
 * @return the matching child node, or {@code null} when there is none
 */
public Node getNotChild(String key) {
    Map.Entry<String, Node> match = firstNotEntry(key);
    return match == null ? null : match.getValue();
}

/**
 * Returns the value part of the same entry that {@link #getNotChild(String)}
 * selects, i.e. the child key with its leading "!=:" removed.
 *
 * @param key value to compare against the part of the child key after the first ':'
 * @return the value of the matching entry, or {@code null} when there is none
 */
public String getNotChildValue(String key) {
    Map.Entry<String, Node> match = firstNotEntry(key);
    return match == null ? null : match.getKey().replaceFirst("!=:", "");
}

/**
 * Finds the first child entry whose key starts with "!=" and whose value
 * part is not equal to {@code key}. Both public lookups use this so they
 * always agree on the selected entry.
 */
private Map.Entry<String, Node> firstNotEntry(String key) {
    for (Map.Entry<String, Node> entry : children.entrySet()) {
        String childKey = entry.getKey();
        if (!childKey.startsWith("!=")) {
            continue;
        }
        // Limit 2 keeps any ':' inside the value and never drops a trailing
        // empty value; a key without ':' is treated as having an empty value
        // instead of raising ArrayIndexOutOfBoundsException.
        String[] parts = childKey.split(":", 2);
        String value = parts.length > 1 ? parts[1] : "";
        if (!value.equals(key)) {
            return entry;
        }
    }
    return null;
}
public void addChild(String value, Node node){
......
......@@ -46,11 +46,10 @@ public class Position {
@Override
public int hashCode() {
int hash = 5;
hash = 29 * hash + this.row;
hash = 29 * hash + this.column;
int hash = 7;
hash = 79 * hash + this.row;
hash = 79 * hash + this.column;
return hash;
}
}
......@@ -173,7 +173,59 @@ public class DatasetService {
Dataset dataset = this.manager.getDataset(token, name);
return Response.status(Response.Status.OK).entity(dataset).build();
} catch (NotFoundEx ex) {
return Response.status(Response.Status.NOT_FOUND).entity(ex.getMessage()).build();
return Response.status(Response.Status.NOT_FOUND).entity(new SimpleMessage(ex.getMessage())).build();
} catch (IOException ex) {
return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(new SimpleMessage("Could not retrieve dataset")).build();
}
}
@Operation(
summary = "Get a test dataset",
description = "Get a specific test dataset",
responses = {
@ApiResponse(
responseCode = "200",
description = "Dataset successfuly obtained",
content = @Content(mediaType = "application/json",
schema = @Schema(implementation = Dataset.class))
),
@ApiResponse(
responseCode = "404",
description = "Dataset not found",
content = @Content(mediaType = "application/json",
schema = @Schema(implementation = SimpleMessage.class))
),
@ApiResponse(
responseCode = "500",
description = "Error retrieving dataset",
content = @Content(mediaType = "application/json",
schema = @Schema(implementation = SimpleMessage.class))
)
},
tags = "dataset"
)
@SecurityRequirement(name = "token")
@GET
@Path("/{name}/test")
@Produces(MediaType.APPLICATION_JSON)
public Response getTestDataset(@Context HttpHeaders httpheaders,
@Parameter(description = "Dataset name")
@PathParam("name") String name) {
String header = httpheaders.getHeaderString(HttpHeaders.AUTHORIZATION);
String token = "";
if (header != null) {
token = header.substring("Bearer".length()).trim();
}
try {
Dataset dataset = this.manager.getTestDataset(token, name);
return Response.status(Response.Status.OK).entity(dataset).build();
} catch (NotFoundEx ex) {
return Response.status(Response.Status.NOT_FOUND).entity(new SimpleMessage(ex.getMessage())).build();
} catch (IOException ex) {
return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(new SimpleMessage("Could not retrieve dataset")).build();
}
......@@ -248,6 +300,67 @@ public class DatasetService {
}
}
/**
 * POST /{name}/test: uploads a test dataset file for the named dataset.
 *
 * The bearer token is taken from the Authorization header and passed, with
 * the dataset name and the uploaded stream, to the dataset manager.
 * Responds 201 with the manager's result, 400 with the message of a
 * FormatEx, or 500 with a fixed message on IOException.
 *
 * NOTE(review): the 201 response is documented with the Dataset schema while
 * the entity returned is a ModifiedDataset; confirm which one is intended.
 */
@Operation(
summary = "Upload a test dataset",
description = "Upload a new test dataset",
responses = {
@ApiResponse(
responseCode = "201",
description = "Dataset successfuly uploaded",
content = @Content(mediaType = "application/json",
schema = @Schema(implementation = Dataset.class))
),
@ApiResponse(
responseCode = "400",
description = "Dataset format is wrong",
content = @Content(mediaType = "application/json",
schema = @Schema(implementation = SimpleMessage.class))
),
@ApiResponse(
responseCode = "500",
description = "Error uploading dataset",
content = @Content(mediaType = "application/json",
schema = @Schema(implementation = SimpleMessage.class))
),
@ApiResponse(
responseCode = "401",
description = "Unauthorized",
content = @Content(mediaType = "application/json",
schema = @Schema(implementation = SimpleMessage.class))
)
},
tags = "dataset"
)
@SecurityRequirement(name = "token")
@POST
@Path("/{name}/test")
@Consumes(MediaType.MULTIPART_FORM_DATA)
@Produces(MediaType.APPLICATION_JSON)
@TokenNeeded
public Response uploadTestDataset(@Context HttpHeaders httpheaders,
@Parameter(description = "Dataset name")
@PathParam("name") String name,
@Parameter(schema = @Schema(type = "string", format = "binary", description = "Dataset file in arff format"))
@FormDataParam("file") InputStream uploadedInputStream) {
String header = httpheaders.getHeaderString(HttpHeaders.AUTHORIZATION);
// NOTE(review): no null check on the header here, unlike the GET endpoint for the
// test dataset; presumably @TokenNeeded rejects requests without it, confirm.
String token = header.substring("Bearer".length()).trim();
try {
ModifiedDataset dataset = this.manager.uploadTestDataset(token, name, uploadedInputStream);
return Response.status(Response.Status.CREATED).entity(dataset).build();
} catch (FormatEx ex) {
// The format error message is returned to the client as-is.
return Response.status(Response.Status.BAD_REQUEST).entity(new SimpleMessage(ex.getMessage())).build();
} catch (IOException ex) {
return Response.status(Response.Status.INTERNAL_SERVER_ERROR).entity(new SimpleMessage("Error uploading file")).build();
}
}
@Operation(
......
......@@ -169,7 +169,9 @@ public class ExplainerService {
@Parameter(description = "Consequents")
ArrayList<Consequent> consequents,
@Parameter(description = "Explanation language {en: english, es: spanish, gl: galician}")
@DefaultValue("en") @QueryParam("lang") String lang) {
@DefaultValue("en") @QueryParam("lang") String lang,
@Parameter(description = "Matrix {cv: cross-validation matrix, train: train matrix, test: test matrix}")
@DefaultValue("cv") @QueryParam("type") String type) {
String header = httpheaders.getHeaderString(HttpHeaders.AUTHORIZATION);
String token = "";
......@@ -180,7 +182,12 @@ public class ExplainerService {
Explanation explanation;
try{
explanation = this.manager.getConfusion(token, dataset, algorithm, consequents, lang);
if(type.equals("cv"))
explanation = this.manager.getConfusion(token, dataset, algorithm, consequents, lang);
else if(type.equals("train") || type.equals("test"))
explanation = this.manager.getConfusionInstances(token, dataset, algorithm, consequents