mirror of
https://github.com/supertone-inc/supertonic.git
synced 2026-07-03 14:08:32 +02:00
init
This commit is contained in:
@@ -0,0 +1,597 @@
|
||||
import ai.onnxruntime.*;
|
||||
import com.fasterxml.jackson.databind.JsonNode;
|
||||
import com.fasterxml.jackson.databind.ObjectMapper;
|
||||
|
||||
import javax.sound.sampled.AudioFileFormat;
|
||||
import javax.sound.sampled.AudioFormat;
|
||||
import javax.sound.sampled.AudioInputStream;
|
||||
import javax.sound.sampled.AudioSystem;
|
||||
import java.io.*;
|
||||
import java.nio.ByteBuffer;
|
||||
import java.nio.ByteOrder;
|
||||
import java.nio.FloatBuffer;
|
||||
import java.nio.LongBuffer;
|
||||
import java.nio.file.Files;
|
||||
import java.nio.file.Paths;
|
||||
import java.text.Normalizer;
|
||||
import java.util.*;
|
||||
|
||||
/**
|
||||
* Configuration classes
|
||||
*/
|
||||
class Config {
|
||||
static class AEConfig {
|
||||
int sampleRate;
|
||||
int baseChunkSize;
|
||||
}
|
||||
|
||||
static class TTLConfig {
|
||||
int chunkCompressFactor;
|
||||
int latentDim;
|
||||
}
|
||||
|
||||
AEConfig ae;
|
||||
TTLConfig ttl;
|
||||
}
|
||||
|
||||
/**
|
||||
* Voice Style Data from JSON
|
||||
*/
|
||||
class VoiceStyleData {
|
||||
static class StyleData {
|
||||
float[][][] data;
|
||||
long[] dims;
|
||||
String type;
|
||||
}
|
||||
|
||||
StyleData styleTtl;
|
||||
StyleData styleDp;
|
||||
}
|
||||
|
||||
/**
|
||||
* Unicode text processor
|
||||
*/
|
||||
class UnicodeProcessor {
|
||||
private long[] indexer;
|
||||
|
||||
public UnicodeProcessor(String unicodeIndexerJsonPath) throws IOException {
|
||||
this.indexer = Helper.loadJsonLongArray(unicodeIndexerJsonPath);
|
||||
}
|
||||
|
||||
public TextProcessResult call(List<String> textList) {
|
||||
List<String> processedTexts = new ArrayList<>();
|
||||
for (String text : textList) {
|
||||
processedTexts.add(preprocessText(text));
|
||||
}
|
||||
|
||||
int[] textIdsLengths = new int[processedTexts.size()];
|
||||
int maxLen = 0;
|
||||
for (int i = 0; i < processedTexts.size(); i++) {
|
||||
textIdsLengths[i] = processedTexts.get(i).length();
|
||||
maxLen = Math.max(maxLen, textIdsLengths[i]);
|
||||
}
|
||||
|
||||
long[][] textIds = new long[processedTexts.size()][maxLen];
|
||||
for (int i = 0; i < processedTexts.size(); i++) {
|
||||
int[] unicodeVals = textToUnicodeValues(processedTexts.get(i));
|
||||
for (int j = 0; j < unicodeVals.length; j++) {
|
||||
textIds[i][j] = indexer[unicodeVals[j]];
|
||||
}
|
||||
}
|
||||
|
||||
float[][][] textMask = getTextMask(textIdsLengths);
|
||||
return new TextProcessResult(textIds, textMask);
|
||||
}
|
||||
|
||||
private String preprocessText(String text) {
|
||||
return Normalizer.normalize(text, Normalizer.Form.NFKD);
|
||||
}
|
||||
|
||||
private int[] textToUnicodeValues(String text) {
|
||||
int[] values = new int[text.length()];
|
||||
for (int i = 0; i < text.length(); i++) {
|
||||
values[i] = text.codePointAt(i);
|
||||
}
|
||||
return values;
|
||||
}
|
||||
|
||||
private float[][][] getTextMask(int[] lengths) {
|
||||
int bsz = lengths.length;
|
||||
int maxLen = 0;
|
||||
for (int len : lengths) {
|
||||
maxLen = Math.max(maxLen, len);
|
||||
}
|
||||
|
||||
float[][][] mask = new float[bsz][1][maxLen];
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
for (int j = 0; j < maxLen; j++) {
|
||||
mask[i][0][j] = j < lengths[i] ? 1.0f : 0.0f;
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
static class TextProcessResult {
|
||||
long[][] textIds;
|
||||
float[][][] textMask;
|
||||
|
||||
TextProcessResult(long[][] textIds, float[][][] textMask) {
|
||||
this.textIds = textIds;
|
||||
this.textMask = textMask;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Text-to-Speech inference class
|
||||
*/
|
||||
class TextToSpeech {
|
||||
private Config config;
|
||||
private UnicodeProcessor textProcessor;
|
||||
private OrtSession dpSession;
|
||||
private OrtSession textEncSession;
|
||||
private OrtSession vectorEstSession;
|
||||
private OrtSession vocoderSession;
|
||||
public int sampleRate;
|
||||
private int baseChunkSize;
|
||||
private int chunkCompress;
|
||||
private int ldim;
|
||||
|
||||
public TextToSpeech(Config config, UnicodeProcessor textProcessor,
|
||||
OrtSession dpSession, OrtSession textEncSession,
|
||||
OrtSession vectorEstSession, OrtSession vocoderSession) {
|
||||
this.config = config;
|
||||
this.textProcessor = textProcessor;
|
||||
this.dpSession = dpSession;
|
||||
this.textEncSession = textEncSession;
|
||||
this.vectorEstSession = vectorEstSession;
|
||||
this.vocoderSession = vocoderSession;
|
||||
this.sampleRate = config.ae.sampleRate;
|
||||
this.baseChunkSize = config.ae.baseChunkSize;
|
||||
this.chunkCompress = config.ttl.chunkCompressFactor;
|
||||
this.ldim = config.ttl.latentDim;
|
||||
}
|
||||
|
||||
public TTSResult call(List<String> textList, Style style, int totalStep, OrtEnvironment env)
|
||||
throws OrtException {
|
||||
int bsz = textList.size();
|
||||
|
||||
// Process text
|
||||
UnicodeProcessor.TextProcessResult textResult = textProcessor.call(textList);
|
||||
long[][] textIds = textResult.textIds;
|
||||
float[][][] textMask = textResult.textMask;
|
||||
|
||||
// Create tensors
|
||||
OnnxTensor textIdsTensor = Helper.createLongTensor(textIds, env);
|
||||
OnnxTensor textMaskTensor = Helper.createFloatTensor(textMask, env);
|
||||
|
||||
// Predict duration
|
||||
Map<String, OnnxTensor> dpInputs = new HashMap<>();
|
||||
dpInputs.put("text_ids", textIdsTensor);
|
||||
dpInputs.put("style_dp", style.dpTensor);
|
||||
dpInputs.put("text_mask", textMaskTensor);
|
||||
|
||||
OrtSession.Result dpResult = dpSession.run(dpInputs);
|
||||
Object dpValue = dpResult.get(0).getValue();
|
||||
float[] duration;
|
||||
if (dpValue instanceof float[][]) {
|
||||
duration = ((float[][]) dpValue)[0];
|
||||
} else {
|
||||
duration = (float[]) dpValue;
|
||||
}
|
||||
|
||||
// Encode text
|
||||
Map<String, OnnxTensor> textEncInputs = new HashMap<>();
|
||||
textEncInputs.put("text_ids", textIdsTensor);
|
||||
textEncInputs.put("style_ttl", style.ttlTensor);
|
||||
textEncInputs.put("text_mask", textMaskTensor);
|
||||
|
||||
OrtSession.Result textEncResult = textEncSession.run(textEncInputs);
|
||||
OnnxTensor textEmbTensor = (OnnxTensor) textEncResult.get(0);
|
||||
|
||||
// Sample noisy latent
|
||||
NoisyLatentResult noisyLatentResult = sampleNoisyLatent(duration);
|
||||
float[][][] xt = noisyLatentResult.noisyLatent;
|
||||
float[][][] latentMask = noisyLatentResult.latentMask;
|
||||
|
||||
// Prepare constant tensors
|
||||
float[] totalStepArray = new float[bsz];
|
||||
Arrays.fill(totalStepArray, (float) totalStep);
|
||||
OnnxTensor totalStepTensor = OnnxTensor.createTensor(env, totalStepArray);
|
||||
|
||||
// Denoising loop
|
||||
for (int step = 0; step < totalStep; step++) {
|
||||
float[] currentStepArray = new float[bsz];
|
||||
Arrays.fill(currentStepArray, (float) step);
|
||||
OnnxTensor currentStepTensor = OnnxTensor.createTensor(env, currentStepArray);
|
||||
OnnxTensor noisyLatentTensor = Helper.createFloatTensor(xt, env);
|
||||
OnnxTensor latentMaskTensor = Helper.createFloatTensor(latentMask, env);
|
||||
OnnxTensor textMaskTensor2 = Helper.createFloatTensor(textMask, env);
|
||||
|
||||
Map<String, OnnxTensor> vectorEstInputs = new HashMap<>();
|
||||
vectorEstInputs.put("noisy_latent", noisyLatentTensor);
|
||||
vectorEstInputs.put("text_emb", textEmbTensor);
|
||||
vectorEstInputs.put("style_ttl", style.ttlTensor);
|
||||
vectorEstInputs.put("latent_mask", latentMaskTensor);
|
||||
vectorEstInputs.put("text_mask", textMaskTensor2);
|
||||
vectorEstInputs.put("current_step", currentStepTensor);
|
||||
vectorEstInputs.put("total_step", totalStepTensor);
|
||||
|
||||
OrtSession.Result vectorEstResult = vectorEstSession.run(vectorEstInputs);
|
||||
float[][][] denoised = (float[][][]) vectorEstResult.get(0).getValue();
|
||||
|
||||
// Update latent
|
||||
xt = denoised;
|
||||
|
||||
// Clean up
|
||||
currentStepTensor.close();
|
||||
noisyLatentTensor.close();
|
||||
latentMaskTensor.close();
|
||||
textMaskTensor2.close();
|
||||
vectorEstResult.close();
|
||||
}
|
||||
|
||||
// Generate waveform
|
||||
OnnxTensor finalLatentTensor = Helper.createFloatTensor(xt, env);
|
||||
Map<String, OnnxTensor> vocoderInputs = new HashMap<>();
|
||||
vocoderInputs.put("latent", finalLatentTensor);
|
||||
|
||||
OrtSession.Result vocoderResult = vocoderSession.run(vocoderInputs);
|
||||
float[][] wavBatch = (float[][]) vocoderResult.get(0).getValue();
|
||||
float[] wav = wavBatch[0];
|
||||
|
||||
// Clean up
|
||||
textIdsTensor.close();
|
||||
textMaskTensor.close();
|
||||
dpResult.close();
|
||||
textEncResult.close();
|
||||
totalStepTensor.close();
|
||||
finalLatentTensor.close();
|
||||
vocoderResult.close();
|
||||
|
||||
return new TTSResult(wav, duration);
|
||||
}
|
||||
|
||||
private NoisyLatentResult sampleNoisyLatent(float[] duration) {
|
||||
int bsz = duration.length;
|
||||
float maxDur = 0;
|
||||
for (float d : duration) {
|
||||
maxDur = Math.max(maxDur, d);
|
||||
}
|
||||
|
||||
long wavLenMax = (long) (maxDur * sampleRate);
|
||||
long[] wavLengths = new long[bsz];
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
wavLengths[i] = (long) (duration[i] * sampleRate);
|
||||
}
|
||||
|
||||
int chunkSize = baseChunkSize * chunkCompress;
|
||||
int latentLen = (int) ((wavLenMax + chunkSize - 1) / chunkSize);
|
||||
int latentDim = ldim * chunkCompress;
|
||||
|
||||
Random rng = new Random();
|
||||
float[][][] noisyLatent = new float[bsz][latentDim][latentLen];
|
||||
for (int b = 0; b < bsz; b++) {
|
||||
for (int d = 0; d < latentDim; d++) {
|
||||
for (int t = 0; t < latentLen; t++) {
|
||||
// Box-Muller transform
|
||||
double u1 = Math.max(1e-10, rng.nextDouble());
|
||||
double u2 = rng.nextDouble();
|
||||
noisyLatent[b][d][t] = (float) (Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float[][][] latentMask = Helper.getLatentMask(wavLengths, config);
|
||||
|
||||
// Apply mask
|
||||
for (int b = 0; b < bsz; b++) {
|
||||
for (int d = 0; d < latentDim; d++) {
|
||||
for (int t = 0; t < latentLen; t++) {
|
||||
noisyLatent[b][d][t] *= latentMask[b][0][t];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return new NoisyLatentResult(noisyLatent, latentMask);
|
||||
}
|
||||
|
||||
public void close() throws OrtException {
|
||||
if (dpSession != null) dpSession.close();
|
||||
if (textEncSession != null) textEncSession.close();
|
||||
if (vectorEstSession != null) vectorEstSession.close();
|
||||
if (vocoderSession != null) vocoderSession.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Style holder class
|
||||
*/
|
||||
class Style {
|
||||
OnnxTensor ttlTensor;
|
||||
OnnxTensor dpTensor;
|
||||
|
||||
Style(OnnxTensor ttlTensor, OnnxTensor dpTensor) {
|
||||
this.ttlTensor = ttlTensor;
|
||||
this.dpTensor = dpTensor;
|
||||
}
|
||||
|
||||
public void close() throws OrtException {
|
||||
if (ttlTensor != null) ttlTensor.close();
|
||||
if (dpTensor != null) dpTensor.close();
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* TTS result holder
|
||||
*/
|
||||
class TTSResult {
|
||||
float[] wav;
|
||||
float[] duration;
|
||||
|
||||
TTSResult(float[] wav, float[] duration) {
|
||||
this.wav = wav;
|
||||
this.duration = duration;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Noisy latent result holder
|
||||
*/
|
||||
class NoisyLatentResult {
|
||||
float[][][] noisyLatent;
|
||||
float[][][] latentMask;
|
||||
|
||||
NoisyLatentResult(float[][][] noisyLatent, float[][][] latentMask) {
|
||||
this.noisyLatent = noisyLatent;
|
||||
this.latentMask = latentMask;
|
||||
}
|
||||
}
|
||||
|
||||
/**
|
||||
* Helper utility class
|
||||
*/
|
||||
public class Helper {
|
||||
|
||||
/**
|
||||
* Load voice style from JSON files
|
||||
*/
|
||||
public static Style loadVoiceStyle(List<String> voiceStylePaths, boolean verbose, OrtEnvironment env)
|
||||
throws IOException, OrtException {
|
||||
int bsz = voiceStylePaths.size();
|
||||
|
||||
// Read first file to get dimensions
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
JsonNode firstRoot = mapper.readTree(new File(voiceStylePaths.get(0)));
|
||||
|
||||
long[] ttlDims = new long[3];
|
||||
for (int i = 0; i < 3; i++) {
|
||||
ttlDims[i] = firstRoot.get("style_ttl").get("dims").get(i).asLong();
|
||||
}
|
||||
long[] dpDims = new long[3];
|
||||
for (int i = 0; i < 3; i++) {
|
||||
dpDims[i] = firstRoot.get("style_dp").get("dims").get(i).asLong();
|
||||
}
|
||||
|
||||
long ttlDim1 = ttlDims[1];
|
||||
long ttlDim2 = ttlDims[2];
|
||||
long dpDim1 = dpDims[1];
|
||||
long dpDim2 = dpDims[2];
|
||||
|
||||
// Pre-allocate arrays with full batch size
|
||||
int ttlSize = (int) (bsz * ttlDim1 * ttlDim2);
|
||||
int dpSize = (int) (bsz * dpDim1 * dpDim2);
|
||||
float[] ttlFlat = new float[ttlSize];
|
||||
float[] dpFlat = new float[dpSize];
|
||||
|
||||
// Fill in the data
|
||||
for (int i = 0; i < bsz; i++) {
|
||||
JsonNode root = mapper.readTree(new File(voiceStylePaths.get(i)));
|
||||
|
||||
// Flatten TTL data
|
||||
int ttlOffset = (int) (i * ttlDim1 * ttlDim2);
|
||||
int idx = 0;
|
||||
JsonNode ttlData = root.get("style_ttl").get("data");
|
||||
for (JsonNode batch : ttlData) {
|
||||
for (JsonNode row : batch) {
|
||||
for (JsonNode val : row) {
|
||||
ttlFlat[ttlOffset + idx++] = (float) val.asDouble();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Flatten DP data
|
||||
int dpOffset = (int) (i * dpDim1 * dpDim2);
|
||||
idx = 0;
|
||||
JsonNode dpData = root.get("style_dp").get("data");
|
||||
for (JsonNode batch : dpData) {
|
||||
for (JsonNode row : batch) {
|
||||
for (JsonNode val : row) {
|
||||
dpFlat[dpOffset + idx++] = (float) val.asDouble();
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
long[] ttlShape = {bsz, ttlDim1, ttlDim2};
|
||||
long[] dpShape = {bsz, dpDim1, dpDim2};
|
||||
|
||||
OnnxTensor ttlTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(ttlFlat), ttlShape);
|
||||
OnnxTensor dpTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(dpFlat), dpShape);
|
||||
|
||||
if (verbose) {
|
||||
System.out.println("Loaded " + bsz + " voice styles\n");
|
||||
}
|
||||
|
||||
return new Style(ttlTensor, dpTensor);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load TTS components
|
||||
*/
|
||||
public static TextToSpeech loadTextToSpeech(String onnxDir, boolean useGpu, OrtEnvironment env)
|
||||
throws IOException, OrtException {
|
||||
if (useGpu) {
|
||||
throw new RuntimeException("GPU mode is not supported yet");
|
||||
}
|
||||
System.out.println("Using CPU for inference\n");
|
||||
|
||||
// Load config
|
||||
Config config = loadCfgs(onnxDir);
|
||||
|
||||
// Create session options
|
||||
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
|
||||
|
||||
// Load models
|
||||
OrtSession dpSession = env.createSession(onnxDir + "/duration_predictor.onnx", opts);
|
||||
OrtSession textEncSession = env.createSession(onnxDir + "/text_encoder.onnx", opts);
|
||||
OrtSession vectorEstSession = env.createSession(onnxDir + "/vector_estimator.onnx", opts);
|
||||
OrtSession vocoderSession = env.createSession(onnxDir + "/vocoder.onnx", opts);
|
||||
|
||||
// Load text processor
|
||||
UnicodeProcessor textProcessor = new UnicodeProcessor(onnxDir + "/unicode_indexer.json");
|
||||
|
||||
return new TextToSpeech(config, textProcessor, dpSession, textEncSession, vectorEstSession, vocoderSession);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load configuration from JSON
|
||||
*/
|
||||
public static Config loadCfgs(String onnxDir) throws IOException {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
JsonNode root = mapper.readTree(new File(onnxDir + "/tts.json"));
|
||||
|
||||
Config config = new Config();
|
||||
config.ae = new Config.AEConfig();
|
||||
config.ae.sampleRate = root.get("ae").get("sample_rate").asInt();
|
||||
config.ae.baseChunkSize = root.get("ae").get("base_chunk_size").asInt();
|
||||
|
||||
config.ttl = new Config.TTLConfig();
|
||||
config.ttl.chunkCompressFactor = root.get("ttl").get("chunk_compress_factor").asInt();
|
||||
config.ttl.latentDim = root.get("ttl").get("latent_dim").asInt();
|
||||
|
||||
return config;
|
||||
}
|
||||
|
||||
/**
|
||||
* Get latent mask from wav lengths
|
||||
*/
|
||||
public static float[][][] getLatentMask(long[] wavLengths, Config config) {
|
||||
long baseChunkSize = config.ae.baseChunkSize;
|
||||
long chunkCompressFactor = config.ttl.chunkCompressFactor;
|
||||
long latentSize = baseChunkSize * chunkCompressFactor;
|
||||
|
||||
long[] latentLengths = new long[wavLengths.length];
|
||||
long maxLen = 0;
|
||||
for (int i = 0; i < wavLengths.length; i++) {
|
||||
latentLengths[i] = (wavLengths[i] + latentSize - 1) / latentSize;
|
||||
maxLen = Math.max(maxLen, latentLengths[i]);
|
||||
}
|
||||
|
||||
float[][][] mask = new float[wavLengths.length][1][(int) maxLen];
|
||||
for (int i = 0; i < wavLengths.length; i++) {
|
||||
for (int j = 0; j < maxLen; j++) {
|
||||
mask[i][0][j] = j < latentLengths[i] ? 1.0f : 0.0f;
|
||||
}
|
||||
}
|
||||
return mask;
|
||||
}
|
||||
|
||||
/**
|
||||
* Write WAV file
|
||||
*/
|
||||
public static void writeWavFile(String filename, float[] audioData, int sampleRate) throws IOException {
|
||||
// Convert float to byte array
|
||||
byte[] bytes = new byte[audioData.length * 2];
|
||||
ByteBuffer buffer = ByteBuffer.wrap(bytes);
|
||||
buffer.order(ByteOrder.LITTLE_ENDIAN);
|
||||
|
||||
for (float sample : audioData) {
|
||||
short val = (short) Math.max(-32768, Math.min(32767, sample * 32767));
|
||||
buffer.putShort(val);
|
||||
}
|
||||
|
||||
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
|
||||
AudioFormat format = new AudioFormat(sampleRate, 16, 1, true, false);
|
||||
AudioInputStream ais = new AudioInputStream(bais, format, audioData.length);
|
||||
AudioSystem.write(ais, AudioFileFormat.Type.WAVE, new File(filename));
|
||||
}
|
||||
|
||||
/**
|
||||
* Sanitize filename
|
||||
*/
|
||||
public static String sanitizeFilename(String text, int maxLen) {
|
||||
if (text.length() > maxLen) {
|
||||
text = text.substring(0, maxLen);
|
||||
}
|
||||
return text.replaceAll("[^a-zA-Z0-9]", "_");
|
||||
}
|
||||
|
||||
/**
|
||||
* Timer utility
|
||||
*/
|
||||
public static <T> T timer(String name, java.util.function.Supplier<T> fn) {
|
||||
long start = System.currentTimeMillis();
|
||||
System.out.println(name + "...");
|
||||
T result = fn.get();
|
||||
long elapsed = System.currentTimeMillis() - start;
|
||||
System.out.printf(" -> %s completed in %.2f sec\n", name, elapsed / 1000.0);
|
||||
return result;
|
||||
}
|
||||
|
||||
/**
|
||||
* Create float tensor from 3D array
|
||||
*/
|
||||
public static OnnxTensor createFloatTensor(float[][][] array, OrtEnvironment env) throws OrtException {
|
||||
int dim0 = array.length;
|
||||
int dim1 = array[0].length;
|
||||
int dim2 = array[0][0].length;
|
||||
|
||||
float[] flat = new float[dim0 * dim1 * dim2];
|
||||
int idx = 0;
|
||||
for (int i = 0; i < dim0; i++) {
|
||||
for (int j = 0; j < dim1; j++) {
|
||||
for (int k = 0; k < dim2; k++) {
|
||||
flat[idx++] = array[i][j][k];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
long[] shape = {dim0, dim1, dim2};
|
||||
return OnnxTensor.createTensor(env, FloatBuffer.wrap(flat), shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Create long tensor from 2D array
|
||||
*/
|
||||
public static OnnxTensor createLongTensor(long[][] array, OrtEnvironment env) throws OrtException {
|
||||
int dim0 = array.length;
|
||||
int dim1 = array[0].length;
|
||||
|
||||
long[] flat = new long[dim0 * dim1];
|
||||
int idx = 0;
|
||||
for (int i = 0; i < dim0; i++) {
|
||||
for (int j = 0; j < dim1; j++) {
|
||||
flat[idx++] = array[i][j];
|
||||
}
|
||||
}
|
||||
|
||||
long[] shape = {dim0, dim1};
|
||||
return OnnxTensor.createTensor(env, LongBuffer.wrap(flat), shape);
|
||||
}
|
||||
|
||||
/**
|
||||
* Load JSON long array
|
||||
*/
|
||||
public static long[] loadJsonLongArray(String filePath) throws IOException {
|
||||
ObjectMapper mapper = new ObjectMapper();
|
||||
JsonNode root = mapper.readTree(new File(filePath));
|
||||
|
||||
long[] result = new long[root.size()];
|
||||
for (int i = 0; i < root.size(); i++) {
|
||||
result[i] = root.get(i).asLong();
|
||||
}
|
||||
return result;
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user