Java source code examples: org.deeplearning4j.nn.api.Model

Example 1
@Override
public void onGradientCalculation(Model model) {
    int iterCount = getModelInfo(model).iterCount;
    if (calcFromGradients() && updateConfig.reportingFrequency() > 0
            && (iterCount == 0 || iterCount % updateConfig.reportingFrequency() == 0)) {
        Gradient g = model.gradient();
        if (updateConfig.collectHistograms(StatsType.Gradients)) {
            gradientHistograms = getHistograms(g.gradientForVariable(), updateConfig.numHistogramBins(StatsType.Gradients));
        }

        if (updateConfig.collectMean(StatsType.Gradients)) {
            meanGradients = calculateSummaryStats(g.gradientForVariable(), StatType.Mean);
        }
        if (updateConfig.collectStdev(StatsType.Gradients)) {
            stdevGradient = calculateSummaryStats(g.gradientForVariable(), StatType.Stdev);
        }
        if (updateConfig.collectMeanMagnitudes(StatsType.Gradients)) {
            meanMagGradients = calculateSummaryStats(g.gradientForVariable(), StatType.MeanMagnitude);
        }
    }
}
 
Example 2
@Override
public void onEpochEnd(Model model) {
  currentEpoch++;

  // Skip if this is not an evaluation epoch
  if (currentEpoch % n != 0) {
    return;
  }

  String s = "Epoch [" + currentEpoch + "/" + numEpochs + "]\n";

  if (isIntermediateEvaluationsEnabled) {
    s += "Train Set:      \n" + evaluateDataSetIterator(model, trainIterator, true);
    if (validationIterator != null) {
      s += "Validation Set: \n" + evaluateDataSetIterator(model, validationIterator, false);
    }
  }

  log(s);
}
 
Example 3
/**
 * This method does a forward pass and returns the output produced by the OutputAdapter.
 *
 * @param adapter     adapter that converts the raw network output into the desired result type
 * @param input       input arrays for the forward pass
 * @param inputMasks  input mask arrays (may be null)
 * @param labelsMasks label mask arrays (may be null)
 * @param <T>         result type produced by the adapter
 * @return the adapter's result for this forward pass
 */
public <T> T output(@NonNull ModelAdapter<T> adapter, INDArray[] input, INDArray[] inputMasks, INDArray[] labelsMasks) {
    val holder = selector.getModelForThisThread();
    Model model = null;
    boolean acquired = false;
    try {
        model = holder.acquireModel();
        acquired = true;
        return adapter.apply(model, input, inputMasks, labelsMasks);
    } catch (InterruptedException e) {
        throw new RuntimeException(e);
    } finally {
        if (model != null && acquired)
            holder.releaseModel(model);
    }
}
 
Example 4
protected void triggerEpochListeners(boolean epochStart, Model model, int epochNum){
    Collection<TrainingListener> listeners;
    if(model instanceof MultiLayerNetwork){
        MultiLayerNetwork n = ((MultiLayerNetwork) model);
        listeners = n.getListeners();
        n.setEpochCount(epochNum);
    } else if(model instanceof ComputationGraph){
        ComputationGraph cg = ((ComputationGraph) model);
        listeners = cg.getListeners();
        cg.getConfiguration().setEpochCount(epochNum);
    } else {
        return;
    }

    if(listeners != null && !listeners.isEmpty()){
        for (TrainingListener l : listeners) {
            if (epochStart) {
                l.onEpochStart(model);
            } else {
                l.onEpochEnd(model);
            }
        }
    }
}
 
Example 5
private static void doEval(Model m, IEvaluation[] e, Iterator<DataSet> ds, Iterator<MultiDataSet> mds, int evalBatchSize){
    if(m instanceof MultiLayerNetwork){
        MultiLayerNetwork mln = (MultiLayerNetwork)m;
        if(ds != null){
            mln.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
        } else {
            mln.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
        }
    } else {
        ComputationGraph cg = (ComputationGraph)m;
        if(ds != null){
            cg.doEvaluation(new IteratorDataSetIterator(ds, evalBatchSize), e);
        } else {
            cg.doEvaluation(new IteratorMultiDataSetIterator(mds, evalBatchSize), e);
        }
    }
}
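
A hedged usage sketch for the helper above (the network, the dataSetList collection, and the batch size are illustrative assumptions, not part of the original): evaluating a MultiLayerNetwork against an in-memory list of DataSet objects with a single Evaluation instance.

// Sketch only: 'network' and 'dataSetList' are assumed to exist in the calling code
Evaluation eval = new Evaluation();
doEval(network, new IEvaluation[]{eval}, dataSetList.iterator(), null, 32);
System.out.println(eval.stats());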
 
Example 6
@Test
public void test() throws Exception {
  int testsCount = 0;
  for (int numInputs = 1; numInputs <= 5; ++numInputs) {
    for (int numOutputs = 1; numOutputs <= 5; ++numOutputs) {

      for (Model model : new Model[]{
          buildMultiLayerNetworkModel(numInputs, numOutputs),
          buildComputationGraphModel(numInputs, numOutputs)
        }) {

        doTest(model, numInputs, numOutputs);
        ++testsCount;

      }
    }
  }
  assertEquals(50, testsCount);
}
 
Example 7
@Test
public void testNormalizerInPlace() throws Exception {
    MultiLayerNetwork net = getNetwork();

    File tempFile = testDir.newFile("testNormalizerInPlace.bin");

    NormalizerMinMaxScaler normalizer = new NormalizerMinMaxScaler(0, 1);
    normalizer.fit(new DataSet(Nd4j.rand(new int[] {2, 2}), Nd4j.rand(new int[] {2, 2})));
    ModelSerializer.writeModel(net, tempFile, true,normalizer);

    Model model = ModelGuesser.loadModelGuess(tempFile.getAbsolutePath());
    Normalizer<?> normalizer1 = ModelGuesser.loadNormalizer(tempFile.getAbsolutePath());
    assertEquals(model, net);
    assertEquals(normalizer, normalizer1);

}
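
As a hedged companion to the test above, the same file can also be read back with ModelSerializer directly instead of ModelGuesser. A minimal sketch, assuming the tempFile written above:

// Sketch: restore network plus updater state, then the normalizer stored alongside it
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(tempFile, true);
NormalizerMinMaxScaler restoredNormalizer = ModelSerializer.restoreNormalizerFromFile(tempFile);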
 
Example 8
private static Map<String,INDArray> getFrozenLayerParamCopies(Model m){
    Map<String,INDArray> out = new LinkedHashMap<>();
    org.deeplearning4j.nn.api.Layer[] layers;
    if (m instanceof MultiLayerNetwork) {
        layers = ((MultiLayerNetwork) m).getLayers();
    } else {
        layers = ((ComputationGraph) m).getLayers();
    }

    for(org.deeplearning4j.nn.api.Layer l : layers){
        if(l instanceof FrozenLayer){
            String paramPrefix;
            if(m instanceof MultiLayerNetwork){
                paramPrefix = l.getIndex() + "_";
            } else {
                paramPrefix = l.conf().getLayer().getLayerName() + "_";
            }
            Map<String,INDArray> paramTable = l.paramTable();
            for(Map.Entry<String,INDArray> e : paramTable.entrySet()){
                out.put(paramPrefix + e.getKey(), e.getValue().dup());
            }
        }
    }

    return out;
}
 
Example 9
/**
 * Loads a dl4j zip file (either computation graph or multi layer network)
 *
 * @param path the path to the file to load
 * @return a loaded dl4j model
 * @throws Exception if loading a dl4j model fails
 */
public static Model loadDl4jGuess(String path) throws Exception {
    if (isZipFile(new File(path))) {
        log.debug("Loading file " + path);
        boolean compGraph = false;
        try (ZipFile zipFile = new ZipFile(path)) {
            List<String> collect = zipFile.stream().map(ZipEntry::getName)
                    .collect(Collectors.toList());
            log.debug("Entries " + collect);
            if (collect.contains(ModelSerializer.COEFFICIENTS_BIN) && collect.contains(ModelSerializer.CONFIGURATION_JSON)) {
                ZipEntry entry = zipFile.getEntry(ModelSerializer.CONFIGURATION_JSON);
                log.debug("Loaded configuration");
                try (InputStream is = zipFile.getInputStream(entry)) {
                    String configJson = IOUtils.toString(is, StandardCharsets.UTF_8);
                    JSONObject jsonObject = new JSONObject(configJson);
                    if (jsonObject.has("vertexInputs")) {
                        log.debug("Loading computation graph.");
                        compGraph = true;
                    } else {
                        log.debug("Loading multi layer network.");
                    }

                }
            }
        }

        if (compGraph) {
            return ModelSerializer.restoreComputationGraph(new File(path));
        } else {
            return ModelSerializer.restoreMultiLayerNetwork(new File(path));
        }
    }

    return null;
}
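
A hedged usage sketch for the loader above (the path is hypothetical): since the returned Model is either a ComputationGraph or a MultiLayerNetwork, callers typically dispatch on the concrete type.

// Sketch only: the zip path is illustrative
Model loaded = loadDl4jGuess("/tmp/my-network.zip");
if (loaded instanceof ComputationGraph) {
    ComputationGraph cg = (ComputationGraph) loaded;
    // computation graph handling here
} else if (loaded instanceof MultiLayerNetwork) {
    MultiLayerNetwork mln = (MultiLayerNetwork) loaded;
    // multi layer network handling here
}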
 
Example 10
public static String saveModel(String name, Model model, int index, int accuracy) throws Exception {
    System.err.println("Saving model, don't shutdown...");
    try {
        String fn = name + "_idx_" + index + "_" + accuracy + ".zip";
        File locationToSave = new File(System.getProperty("user.dir") + "/model/" + fn);
        boolean saveUpdater = true; // Updater: i.e., the state for Momentum, RMSProp, Adagrad etc. Save this if you want to train your network more in the future
        ModelSerializer.writeModel(model, locationToSave, saveUpdater);
        System.err.println("Model saved");
        return fn;
    } catch (IOException e) {
        System.err.println("Save model failed");
        e.printStackTrace();
        throw e;
    }
}
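
A minimal companion sketch (the file name below is illustrative, not produced by the method above): a model written with ModelSerializer.writeModel can be read back with the matching restore call, passing true to also load the updater state.

// Sketch: restore a previously saved MultiLayerNetwork together with its updater
File saved = new File(System.getProperty("user.dir") + "/model/" + "mynet_idx_3_98.zip");
MultiLayerNetwork restored = ModelSerializer.restoreMultiLayerNetwork(saved, true);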
 
Example 11
@Override
public void onBackwardPass(Model model) {
    if(!printOnBackwardPass || printFileTarget == null)
        return;

    writeFileWithMessage("backward pass");
}
 
Example 12
private static Model buildModel() throws Exception {

    final int numInputs = 3;
    final int numOutputs = 2;

    final MultiLayerConfiguration conf = new NeuralNetConfiguration.Builder()
        .list(
            new OutputLayer.Builder()
                           .nIn(numInputs)
                           .nOut(numOutputs)
                           .activation(Activation.IDENTITY)
                           .lossFunction(LossFunctions.LossFunction.MSE)
                           .build()
            )
        .build();

    final MultiLayerNetwork model = new MultiLayerNetwork(conf);
    model.init();

    final float[] floats = new float[]{ +1, +1, +1, -1, -1, -1, 0, 0 };
    // positive weight for first output, negative weight for second output, no biases
    assertEquals((numInputs+1)*numOutputs, floats.length);

    final INDArray params = Nd4j.create(floats);
    model.setParams(params);

    return model;
}
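
A hedged usage sketch for the fixed-weight model above (the input values are illustrative): with the identity activation, all +1 weights feeding the first output, all -1 weights feeding the second, and zero biases, the output for an input x should be approximately (sum(x), -sum(x)).

// Sketch only: feed a single example through the hand-initialized network
final MultiLayerNetwork net = (MultiLayerNetwork) buildModel();
final INDArray out = net.output(Nd4j.create(new float[]{1, 2, 3}, new int[]{1, 3}));
// expected: approximately [6.0, -6.0]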
 
Example 13
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    sleep(lastIteration.get(), timerIteration);

    if (lastIteration.get() == null)
        lastIteration.set(new AtomicLong(System.currentTimeMillis()));
    else
        lastIteration.get().set(System.currentTimeMillis());
}
 
Example 14
@Override
public void iterationDone(Model model, int iteration, int epoch) {
    if (statusListeners == null) {
        return;
    }

    for (StatusListener sl : statusListeners) {
        sl.onCandidateIteration(candidateInfo, model, iteration);
    }
}
 
Example 15
@Override
public void onForwardPass(Model model, List<INDArray> activations) {
    if(!printOnForwardPass || printFileTarget == null)
        return;

    writeFileWithMessage("forward pass");

}
 
Example 16
private static void validateLayerIterCounts(Model m, int expEpoch, int expIter){
    //Check that the iteration and epoch counts - on the layers - are synced
    org.deeplearning4j.nn.api.Layer[] layers;
    if (m instanceof MultiLayerNetwork) {
        layers = ((MultiLayerNetwork) m).getLayers();
    } else {
        layers = ((ComputationGraph) m).getLayers();
    }

    for(org.deeplearning4j.nn.api.Layer l : layers){
        assertEquals("Epoch count", expEpoch, l.getEpochCount());
        assertEquals("Iteration count", expIter, l.getIterationCount());
    }
}
 
Example 17
public static org.deeplearning4j.nn.api.Updater getUpdater(Model layer) {
    if (layer instanceof MultiLayerNetwork) {
        return new MultiLayerUpdater((MultiLayerNetwork) layer);
    } else if (layer instanceof ComputationGraph) {
        return new ComputationGraphUpdater((ComputationGraph) layer);
    } else {
        return new LayerUpdater((Layer) layer);
    }
}
 
Example 18
/**
 * @param conf              network configuration for this optimizer
 * @param stepFunction      step function to use; a default for this optimizer class is chosen when null
 * @param trainingListeners listeners to notify during optimization (may be null)
 * @param model             the model to optimize
 */
public BaseOptimizer(NeuralNetConfiguration conf, StepFunction stepFunction,
                Collection<TrainingListener> trainingListeners, Model model) {
    this.conf = conf;
    this.stepFunction = (stepFunction != null ? stepFunction : getDefaultStepFunctionForOptimizer(this.getClass()));
    this.trainingListeners = trainingListeners != null ? trainingListeners : new ArrayList<TrainingListener>();
    this.model = model;
    lineMaximizer = new BackTrackLineSearch(model, this.stepFunction, this);
    lineMaximizer.setStepMax(stepMax);
    lineMaximizer.setMaxIterations(conf.getMaxNumLineSearchIterations());
}
 
Example 19
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    if(!printOnForwardPass)
        return;

    SystemInfo systemInfo = new SystemInfo();
    log.info(SYSTEM_INFO);
    log.info(systemInfo.toPrettyJSON());
}
 
Example 20
public static int getEpochCount(Model model){
    if (model instanceof MultiLayerNetwork) {
        return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount();
    } else if (model instanceof ComputationGraph) {
        return ((ComputationGraph) model).getConfiguration().getEpochCount();
    } else {
        return model.conf().getEpochCount();
    }
}
 
Example 21
@Override
public void onBackwardPass(Model model) {
    sleep(lastBP.get(), timerBP);

    if (lastBP.get() == null)
        lastBP.set(new AtomicLong(System.currentTimeMillis()));
    else
        lastBP.get().set(System.currentTimeMillis());
}
 
Example 22
protected static String getModelType(Model model){
    if(model.getClass() == MultiLayerNetwork.class){
        return "MultiLayerNetwork";
    } else if(model.getClass() == ComputationGraph.class){
        return "ComputationGraph";
    } else {
        return "Model";
    }
}
 
Example 23
/**
 * Uses the {@link ModelGuesser#loadModelGuess(InputStream)} method.
 */
protected Model restoreModel(InputStream inputStream) throws IOException {
  final File instanceDir = solrResourceLoader.getInstancePath().toFile();
  try {
    return ModelGuesser.loadModelGuess(inputStream, instanceDir);
  } catch (Exception e) {
    throw new IOException("Failed to restore model from given file (" + serializedModelFileName + ")", e);
  }
}
 
Example 24
@Override
public void onForwardPass(Model model, Map<String, INDArray> activations) {
    if(!printOnForwardPass || printFileTarget == null)
        return;

    writeFileWithMessage("forward pass");

}
 
Example 25
protected static int getEpoch(Model model) {
    if (model instanceof MultiLayerNetwork) {
        return ((MultiLayerNetwork) model).getLayerWiseConfigurations().getEpochCount();
    } else if (model instanceof ComputationGraph) {
        return ((ComputationGraph) model).getConfiguration().getEpochCount();
    } else {
        return model.conf().getEpochCount();
    }
}
 
Example 26
@Override
public void onGradientCalculation(Model model) {
    if(!printOnGradientCalculation)
        return;

    SystemInfo systemInfo = new SystemInfo();
    log.info(SYSTEM_INFO);
    log.info(systemInfo.toPrettyJSON());
}
 
Example 27
@Override
public void onEpochEnd(Model model) {
    int epochsDone = getEpoch(model) + 1;
    if(saveEveryNEpochs != null && epochsDone > 0 && epochsDone % saveEveryNEpochs == 0){
        //Save:
        saveCheckpoint(model);
    }
    //General saving conditions: don't need to check here - will check in iterationDone
}
 
Example 28
@Override
public boolean triggerFailure(CallType callType, int iteration, int epoch, Model model) {
    boolean b = false;
    for(FailureTrigger ft : triggers)
        b |= ft.triggerFailure(callType, iteration, epoch, model);
    return b;
}
 
Example 29
@Override
protected synchronized Model[] getCurrentModelsFromWorkers() {
    val models = new Model[holders.size()];
    int cnt = 0;
    for (val h:holders) {
        models[cnt++] = h.sourceModel;
    }

    return models;
}
 
Example 30
@Override
public Class<? extends Model> modelType() {
    return ComputationGraph.class;
}