Skip to content

Commit

Permalink
Leos comments (#44)
Browse files Browse the repository at this point in the history
* Address comments in #42
  • Loading branch information
alanocallaghan authored Aug 6, 2024
1 parent 3b30495 commit ac62573
Show file tree
Hide file tree
Showing 12 changed files with 181 additions and 146 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ public class DetectionMeasurer {
private final Collection<ObjectMeasurements.ShapeFeatures> shapeFeatures;
private final double pixelSize;


private DetectionMeasurer(Collection<ObjectMeasurements.Compartments> compartments,
Collection<ObjectMeasurements.Measurements> measurements,
Collection<ObjectMeasurements.ShapeFeatures> shapeFeatures,
Expand Down
1 change: 0 additions & 1 deletion src/main/java/qupath/ext/instanseg/core/InstanSeg.java
Original file line number Diff line number Diff line change
Expand Up @@ -165,7 +165,6 @@ public Builder currentImageData() {
return this;
}


/**
* Set the channels to be used in inference
* @param channels A collection of channels to be used in inference
Expand Down
131 changes: 70 additions & 61 deletions src/main/java/qupath/ext/instanseg/core/InstanSegModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,6 @@ private InstanSegModel(BioimageIoSpec.BioimageIoModel bioimageIoModel) {
this.name = model.getName();
}


private InstanSegModel(URL modelURL, String name) {
this.modelURL = modelURL;
this.name = name;
Expand All @@ -75,26 +74,80 @@ public static InstanSegModel fromName(String name) {
throw new UnsupportedOperationException("Fetching models by name is not yet implemented!");
}

public BioimageIoSpec.BioimageIoModel getModel() {
if (model == null) {
try {
fetchModel();
} catch (IOException e) {
// todo: exception handling here, or...?
throw new RuntimeException(e);
}
}
return model;
}

/**
* Get the pixel size in the X dimension.
* @return the pixel size in the X dimension.
*/
public Double getPixelSizeX() {
return getPixelSize().get("x");
}

/**
* Get the pixel size in the Y dimension.
* @return the pixel size in the Y dimension.
*/
public Double getPixelSizeY() {
return getPixelSize().get("y");
}

/**
* Get the path where the model is stored on disk.
* @return A path on disk, or an exception if it can't be found.
*/
public Path getPath() {
if (path == null) {
fetchModel();
}
return path;
}

@Override
public String toString() {
return getName();
}

/**
* Check if a path is (likely) a valid InstanSeg model.
* @param path The path to a folder.
* @return True if the folder contains an instanseg.pt file and an accompanying rdf.yaml.
* Does not currently validate the contents of either, but may in future check
* the yaml contents and the checksum of the pt file.
*/
public static boolean isValidModel(Path path) {
// return path.toString().endsWith(".pt"); // if just looking at pt files
if (Files.isDirectory(path)) {
return Files.exists(path.resolve("instanseg.pt")) && Files.exists(path.resolve("rdf.yaml"));
}
return false;
}

/**
* The number of tiles that failed during processing.
* @return The count of the number of failed tiles.
*/
public int nFailed() {
return nFailed;
}

/**
* Get the model name
* @return A string
*/
String getName() {
return name;
}

/**
* Retrieve the BioImage model spec.
* @return The BioImageIO model spec for this InstanSeg model.
*/
BioimageIoSpec.BioimageIoModel getModel() {
if (model == null) {
fetchModel();
}
return model;
}

private Map<String, Double> getPixelSize() {
// todo: this code is horrendous
var map = new HashMap<String, Double>();
Expand All @@ -105,45 +158,23 @@ private Map<String, Double> getPixelSize() {
return map;
}

private void fetchModel() throws IOException {
private void fetchModel() {
if (modelURL == null) {
throw new NullPointerException("Model URL should not be null for a local model!");
}
downloadAndUnzip(modelURL, getUserDir().resolve("instanseg"));
}

private static void downloadAndUnzip(URL url, Path localDirectory) throws IOException {
private static void downloadAndUnzip(URL url, Path localDirectory) {
// todo: implement
}


private static Path getUserDir() {
Path userPath = UserDirectoryManager.getInstance().getUserPath();
Path cachePath = Paths.get(System.getProperty("user.dir"), ".cache", "QuPath");
return userPath == null || userPath.toString().isEmpty() ? cachePath : userPath;
}

public String getName() {
return name;
}

public Path getPath() {
if (path == null) {
try {
fetchModel();
} catch (IOException e) {
// todo: handle here, or...?
throw new RuntimeException(e);
}
}
return path;
}

@Override
public String toString() {
return getName();
}

void runInstanSeg(
ImageData<BufferedImage> imageData,
Collection<PathObject> pathObjects,
Expand All @@ -158,7 +189,8 @@ void runInstanSeg(
TaskRunner taskRunner) {

nFailed = 0;
Path modelPath = getPath().resolve("instanseg.pt");
Path modelPath;
modelPath = getPath().resolve("instanseg.pt");
int nPredictors = 1; // todo: change me?


Expand Down Expand Up @@ -227,27 +259,4 @@ private static void printResourceCount(String title, BaseNDManager manager) {
manager.debugDump(2);
}

/**
* Check if a path is (likely) a valid InstanSeg model.
* @param path The path to a folder.
* @return True if the folder contains an instanseg.pt file and an accompanying rdf.yaml.
* Does not currently validate the contents of either, but may in future check
* the yaml contents and the checksum of the pt file.
*/
public static boolean isValidModel(Path path) {
// return path.toString().endsWith(".pt"); // if just looking at pt files
if (Files.isDirectory(path)) {
return Files.exists(path.resolve("instanseg.pt")) && Files.exists(path.resolve("rdf.yaml"));
}
return false;
}

/**
* The number of tiles that failed during processing.
* @return The count of the number of failed tiles.
*/
public int nFailed() {
return nFailed;
}

}
17 changes: 11 additions & 6 deletions src/main/java/qupath/ext/instanseg/core/MatTranslator.java
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,18 @@ class MatTranslator implements Translator<Mat, Mat> {

private final String inputLayoutNd;
private final String outputLayoutNd;
private final boolean nucleiOnly;
private final boolean firstChannelOnly;

public MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean nucleiOnly) {
/**
* Create a translator from InstanSeg input to output.
* @param inputLayoutNd N-dimensional output specification
* @param outputLayoutNd N-dimensional output specification
* @param firstChannelOnly Should the model only be concerned with the first output channel?
*/
MatTranslator(String inputLayoutNd, String outputLayoutNd, boolean firstChannelOnly) {
this.inputLayoutNd = inputLayoutNd;
this.outputLayoutNd = outputLayoutNd;
this.nucleiOnly = nucleiOnly;
this.firstChannelOnly = firstChannelOnly;
}

/**
Expand All @@ -31,9 +37,8 @@ public NDList processInput(TranslatorContext ctx, Mat input) {
var manager = ctx.getNDManager();
var ndarray = DjlTools.matToNDArray(manager, input, inputLayoutNd);
var out = new NDList(ndarray);
if (nucleiOnly) {
var inds = new int[]{1, 1};
inds[1] = 0;
if (firstChannelOnly) {
var inds = new int[]{1, 0};
var array = manager.create(inds, new Shape(2));
var arrayCPU = array.toDevice(Device.cpu(), false);
out.add(arrayCPU);
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
package qupath.ext.instanseg.ui;
package qupath.ext.instanseg.core;

import ai.djl.engine.Engine;
import ai.djl.engine.EngineException;
Expand All @@ -15,15 +15,28 @@
/**
* Helper class to manage access to PyTorch via Deep Java Library.
*/
class PytorchManager {
public class PytorchManager {

private static final Logger logger = LoggerFactory.getLogger(PytorchManager.class);

/**
* Get the PyTorch engine, downloading if necessary.
* @return the engine if available, or null if this failed
*/
public static Engine getEngineOnline() {
try {
return callOnline(() -> Engine.getEngine("PyTorch"));
} catch (Exception e) {
logger.error(e.getMessage(), e);
return null;
}
}

/**
* Get the available devices for PyTorch, including MPS if Apple Silicon.
* @return Only "cpu" if no local engine is found.
*/
static Collection<String> getAvailableDevices() {
public static Collection<String> getAvailableDevices() {
try {
Set<String> availableDevices = new LinkedHashSet<>();

Expand Down Expand Up @@ -54,7 +67,7 @@ static Collection<String> getAvailableDevices() {
* Query if the PyTorch engine is already available, without a need to download.
* @return
*/
static boolean hasPyTorchEngine() {
public static boolean hasPyTorchEngine() {
return getEngineOffline() != null;
}

Expand All @@ -71,19 +84,6 @@ static Engine getEngineOffline() {
}
}

/**
* Get the PyTorch engine, downloading if necessary.
* @return the engine if available, or null if this failed
*/
static Engine getEngineOnline() {
try {
return callOnline(() -> Engine.getEngine("PyTorch"));
} catch (Exception e) {
logger.error(e.getMessage(), e);
return null;
}
}

/**
* Call a function with the "offline" property set to true (to block automatic downloads).
* @param callable
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,14 @@ class TilePredictionProcessor implements Processor<Mat, Mat, Mat> {
this.doPadding = doPadding;
}

/**
* The number of tiles that failed during processing.
* @return The count of the number of failed tiles.
*/
public int nFailed() {
return nFailed;
}

@Override
public Mat process(Parameters<Mat, Mat> params) throws IOException {

Expand Down Expand Up @@ -169,11 +177,5 @@ private static ImageOp getNormalization(ImageData<BufferedImage> imageData, Path
return defaults;
}

/**
* The number of tiles that failed during processing.
* @return The count of the number of failed tiles.
*/
public int nFailed() {
return nFailed;
}

}
Loading

0 comments on commit ac62573

Please sign in to comment.