import ai.onnxruntime.*; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.ObjectMapper; import javax.sound.sampled.AudioFileFormat; import javax.sound.sampled.AudioFormat; import javax.sound.sampled.AudioInputStream; import javax.sound.sampled.AudioSystem; import java.io.*; import java.nio.ByteBuffer; import java.nio.ByteOrder; import java.nio.FloatBuffer; import java.nio.LongBuffer; import java.nio.file.Files; import java.nio.file.Paths; import java.text.Normalizer; import java.util.*; /** * Configuration classes */ class Config { static class AEConfig { int sampleRate; int baseChunkSize; } static class TTLConfig { int chunkCompressFactor; int latentDim; } AEConfig ae; TTLConfig ttl; } /** * Voice Style Data from JSON */ class VoiceStyleData { static class StyleData { float[][][] data; long[] dims; String type; } StyleData styleTtl; StyleData styleDp; } /** * Unicode text processor */ class UnicodeProcessor { private long[] indexer; public UnicodeProcessor(String unicodeIndexerJsonPath) throws IOException { this.indexer = Helper.loadJsonLongArray(unicodeIndexerJsonPath); } public TextProcessResult call(List textList) { List processedTexts = new ArrayList<>(); for (String text : textList) { processedTexts.add(preprocessText(text)); } int[] textIdsLengths = new int[processedTexts.size()]; int maxLen = 0; for (int i = 0; i < processedTexts.size(); i++) { textIdsLengths[i] = processedTexts.get(i).length(); maxLen = Math.max(maxLen, textIdsLengths[i]); } long[][] textIds = new long[processedTexts.size()][maxLen]; for (int i = 0; i < processedTexts.size(); i++) { int[] unicodeVals = textToUnicodeValues(processedTexts.get(i)); for (int j = 0; j < unicodeVals.length; j++) { textIds[i][j] = indexer[unicodeVals[j]]; } } float[][][] textMask = getTextMask(textIdsLengths); return new TextProcessResult(textIds, textMask); } private String preprocessText(String text) { return Normalizer.normalize(text, Normalizer.Form.NFKD); } private int[] textToUnicodeValues(String text) { int[] values = new int[text.length()]; for (int i = 0; i < text.length(); i++) { values[i] = text.codePointAt(i); } return values; } private float[][][] getTextMask(int[] lengths) { int bsz = lengths.length; int maxLen = 0; for (int len : lengths) { maxLen = Math.max(maxLen, len); } float[][][] mask = new float[bsz][1][maxLen]; for (int i = 0; i < bsz; i++) { for (int j = 0; j < maxLen; j++) { mask[i][0][j] = j < lengths[i] ? 1.0f : 0.0f; } } return mask; } static class TextProcessResult { long[][] textIds; float[][][] textMask; TextProcessResult(long[][] textIds, float[][][] textMask) { this.textIds = textIds; this.textMask = textMask; } } } /** * Text-to-Speech inference class */ class TextToSpeech { private Config config; private UnicodeProcessor textProcessor; private OrtSession dpSession; private OrtSession textEncSession; private OrtSession vectorEstSession; private OrtSession vocoderSession; public int sampleRate; private int baseChunkSize; private int chunkCompress; private int ldim; public TextToSpeech(Config config, UnicodeProcessor textProcessor, OrtSession dpSession, OrtSession textEncSession, OrtSession vectorEstSession, OrtSession vocoderSession) { this.config = config; this.textProcessor = textProcessor; this.dpSession = dpSession; this.textEncSession = textEncSession; this.vectorEstSession = vectorEstSession; this.vocoderSession = vocoderSession; this.sampleRate = config.ae.sampleRate; this.baseChunkSize = config.ae.baseChunkSize; this.chunkCompress = config.ttl.chunkCompressFactor; this.ldim = config.ttl.latentDim; } public TTSResult call(List textList, Style style, int totalStep, OrtEnvironment env) throws OrtException { int bsz = textList.size(); // Process text UnicodeProcessor.TextProcessResult textResult = textProcessor.call(textList); long[][] textIds = textResult.textIds; float[][][] textMask = textResult.textMask; // Create tensors OnnxTensor textIdsTensor = Helper.createLongTensor(textIds, env); OnnxTensor textMaskTensor = Helper.createFloatTensor(textMask, env); // Predict duration Map dpInputs = new HashMap<>(); dpInputs.put("text_ids", textIdsTensor); dpInputs.put("style_dp", style.dpTensor); dpInputs.put("text_mask", textMaskTensor); OrtSession.Result dpResult = dpSession.run(dpInputs); Object dpValue = dpResult.get(0).getValue(); float[] duration; if (dpValue instanceof float[][]) { duration = ((float[][]) dpValue)[0]; } else { duration = (float[]) dpValue; } // Encode text Map textEncInputs = new HashMap<>(); textEncInputs.put("text_ids", textIdsTensor); textEncInputs.put("style_ttl", style.ttlTensor); textEncInputs.put("text_mask", textMaskTensor); OrtSession.Result textEncResult = textEncSession.run(textEncInputs); OnnxTensor textEmbTensor = (OnnxTensor) textEncResult.get(0); // Sample noisy latent NoisyLatentResult noisyLatentResult = sampleNoisyLatent(duration); float[][][] xt = noisyLatentResult.noisyLatent; float[][][] latentMask = noisyLatentResult.latentMask; // Prepare constant tensors float[] totalStepArray = new float[bsz]; Arrays.fill(totalStepArray, (float) totalStep); OnnxTensor totalStepTensor = OnnxTensor.createTensor(env, totalStepArray); // Denoising loop for (int step = 0; step < totalStep; step++) { float[] currentStepArray = new float[bsz]; Arrays.fill(currentStepArray, (float) step); OnnxTensor currentStepTensor = OnnxTensor.createTensor(env, currentStepArray); OnnxTensor noisyLatentTensor = Helper.createFloatTensor(xt, env); OnnxTensor latentMaskTensor = Helper.createFloatTensor(latentMask, env); OnnxTensor textMaskTensor2 = Helper.createFloatTensor(textMask, env); Map vectorEstInputs = new HashMap<>(); vectorEstInputs.put("noisy_latent", noisyLatentTensor); vectorEstInputs.put("text_emb", textEmbTensor); vectorEstInputs.put("style_ttl", style.ttlTensor); vectorEstInputs.put("latent_mask", latentMaskTensor); vectorEstInputs.put("text_mask", textMaskTensor2); vectorEstInputs.put("current_step", currentStepTensor); vectorEstInputs.put("total_step", totalStepTensor); OrtSession.Result vectorEstResult = vectorEstSession.run(vectorEstInputs); float[][][] denoised = (float[][][]) vectorEstResult.get(0).getValue(); // Update latent xt = denoised; // Clean up currentStepTensor.close(); noisyLatentTensor.close(); latentMaskTensor.close(); textMaskTensor2.close(); vectorEstResult.close(); } // Generate waveform OnnxTensor finalLatentTensor = Helper.createFloatTensor(xt, env); Map vocoderInputs = new HashMap<>(); vocoderInputs.put("latent", finalLatentTensor); OrtSession.Result vocoderResult = vocoderSession.run(vocoderInputs); float[][] wavBatch = (float[][]) vocoderResult.get(0).getValue(); float[] wav = wavBatch[0]; // Clean up textIdsTensor.close(); textMaskTensor.close(); dpResult.close(); textEncResult.close(); totalStepTensor.close(); finalLatentTensor.close(); vocoderResult.close(); return new TTSResult(wav, duration); } private NoisyLatentResult sampleNoisyLatent(float[] duration) { int bsz = duration.length; float maxDur = 0; for (float d : duration) { maxDur = Math.max(maxDur, d); } long wavLenMax = (long) (maxDur * sampleRate); long[] wavLengths = new long[bsz]; for (int i = 0; i < bsz; i++) { wavLengths[i] = (long) (duration[i] * sampleRate); } int chunkSize = baseChunkSize * chunkCompress; int latentLen = (int) ((wavLenMax + chunkSize - 1) / chunkSize); int latentDim = ldim * chunkCompress; Random rng = new Random(); float[][][] noisyLatent = new float[bsz][latentDim][latentLen]; for (int b = 0; b < bsz; b++) { for (int d = 0; d < latentDim; d++) { for (int t = 0; t < latentLen; t++) { // Box-Muller transform double u1 = Math.max(1e-10, rng.nextDouble()); double u2 = rng.nextDouble(); noisyLatent[b][d][t] = (float) (Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2)); } } } float[][][] latentMask = Helper.getLatentMask(wavLengths, config); // Apply mask for (int b = 0; b < bsz; b++) { for (int d = 0; d < latentDim; d++) { for (int t = 0; t < latentLen; t++) { noisyLatent[b][d][t] *= latentMask[b][0][t]; } } } return new NoisyLatentResult(noisyLatent, latentMask); } public void close() throws OrtException { if (dpSession != null) dpSession.close(); if (textEncSession != null) textEncSession.close(); if (vectorEstSession != null) vectorEstSession.close(); if (vocoderSession != null) vocoderSession.close(); } } /** * Style holder class */ class Style { OnnxTensor ttlTensor; OnnxTensor dpTensor; Style(OnnxTensor ttlTensor, OnnxTensor dpTensor) { this.ttlTensor = ttlTensor; this.dpTensor = dpTensor; } public void close() throws OrtException { if (ttlTensor != null) ttlTensor.close(); if (dpTensor != null) dpTensor.close(); } } /** * TTS result holder */ class TTSResult { float[] wav; float[] duration; TTSResult(float[] wav, float[] duration) { this.wav = wav; this.duration = duration; } } /** * Noisy latent result holder */ class NoisyLatentResult { float[][][] noisyLatent; float[][][] latentMask; NoisyLatentResult(float[][][] noisyLatent, float[][][] latentMask) { this.noisyLatent = noisyLatent; this.latentMask = latentMask; } } /** * Helper utility class */ public class Helper { /** * Load voice style from JSON files */ public static Style loadVoiceStyle(List voiceStylePaths, boolean verbose, OrtEnvironment env) throws IOException, OrtException { int bsz = voiceStylePaths.size(); // Read first file to get dimensions ObjectMapper mapper = new ObjectMapper(); JsonNode firstRoot = mapper.readTree(new File(voiceStylePaths.get(0))); long[] ttlDims = new long[3]; for (int i = 0; i < 3; i++) { ttlDims[i] = firstRoot.get("style_ttl").get("dims").get(i).asLong(); } long[] dpDims = new long[3]; for (int i = 0; i < 3; i++) { dpDims[i] = firstRoot.get("style_dp").get("dims").get(i).asLong(); } long ttlDim1 = ttlDims[1]; long ttlDim2 = ttlDims[2]; long dpDim1 = dpDims[1]; long dpDim2 = dpDims[2]; // Pre-allocate arrays with full batch size int ttlSize = (int) (bsz * ttlDim1 * ttlDim2); int dpSize = (int) (bsz * dpDim1 * dpDim2); float[] ttlFlat = new float[ttlSize]; float[] dpFlat = new float[dpSize]; // Fill in the data for (int i = 0; i < bsz; i++) { JsonNode root = mapper.readTree(new File(voiceStylePaths.get(i))); // Flatten TTL data int ttlOffset = (int) (i * ttlDim1 * ttlDim2); int idx = 0; JsonNode ttlData = root.get("style_ttl").get("data"); for (JsonNode batch : ttlData) { for (JsonNode row : batch) { for (JsonNode val : row) { ttlFlat[ttlOffset + idx++] = (float) val.asDouble(); } } } // Flatten DP data int dpOffset = (int) (i * dpDim1 * dpDim2); idx = 0; JsonNode dpData = root.get("style_dp").get("data"); for (JsonNode batch : dpData) { for (JsonNode row : batch) { for (JsonNode val : row) { dpFlat[dpOffset + idx++] = (float) val.asDouble(); } } } } long[] ttlShape = {bsz, ttlDim1, ttlDim2}; long[] dpShape = {bsz, dpDim1, dpDim2}; OnnxTensor ttlTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(ttlFlat), ttlShape); OnnxTensor dpTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(dpFlat), dpShape); if (verbose) { System.out.println("Loaded " + bsz + " voice styles\n"); } return new Style(ttlTensor, dpTensor); } /** * Load TTS components */ public static TextToSpeech loadTextToSpeech(String onnxDir, boolean useGpu, OrtEnvironment env) throws IOException, OrtException { if (useGpu) { throw new RuntimeException("GPU mode is not supported yet"); } System.out.println("Using CPU for inference\n"); // Load config Config config = loadCfgs(onnxDir); // Create session options OrtSession.SessionOptions opts = new OrtSession.SessionOptions(); // Load models OrtSession dpSession = env.createSession(onnxDir + "/duration_predictor.onnx", opts); OrtSession textEncSession = env.createSession(onnxDir + "/text_encoder.onnx", opts); OrtSession vectorEstSession = env.createSession(onnxDir + "/vector_estimator.onnx", opts); OrtSession vocoderSession = env.createSession(onnxDir + "/vocoder.onnx", opts); // Load text processor UnicodeProcessor textProcessor = new UnicodeProcessor(onnxDir + "/unicode_indexer.json"); return new TextToSpeech(config, textProcessor, dpSession, textEncSession, vectorEstSession, vocoderSession); } /** * Load configuration from JSON */ public static Config loadCfgs(String onnxDir) throws IOException { ObjectMapper mapper = new ObjectMapper(); JsonNode root = mapper.readTree(new File(onnxDir + "/tts.json")); Config config = new Config(); config.ae = new Config.AEConfig(); config.ae.sampleRate = root.get("ae").get("sample_rate").asInt(); config.ae.baseChunkSize = root.get("ae").get("base_chunk_size").asInt(); config.ttl = new Config.TTLConfig(); config.ttl.chunkCompressFactor = root.get("ttl").get("chunk_compress_factor").asInt(); config.ttl.latentDim = root.get("ttl").get("latent_dim").asInt(); return config; } /** * Get latent mask from wav lengths */ public static float[][][] getLatentMask(long[] wavLengths, Config config) { long baseChunkSize = config.ae.baseChunkSize; long chunkCompressFactor = config.ttl.chunkCompressFactor; long latentSize = baseChunkSize * chunkCompressFactor; long[] latentLengths = new long[wavLengths.length]; long maxLen = 0; for (int i = 0; i < wavLengths.length; i++) { latentLengths[i] = (wavLengths[i] + latentSize - 1) / latentSize; maxLen = Math.max(maxLen, latentLengths[i]); } float[][][] mask = new float[wavLengths.length][1][(int) maxLen]; for (int i = 0; i < wavLengths.length; i++) { for (int j = 0; j < maxLen; j++) { mask[i][0][j] = j < latentLengths[i] ? 1.0f : 0.0f; } } return mask; } /** * Write WAV file */ public static void writeWavFile(String filename, float[] audioData, int sampleRate) throws IOException { // Convert float to byte array byte[] bytes = new byte[audioData.length * 2]; ByteBuffer buffer = ByteBuffer.wrap(bytes); buffer.order(ByteOrder.LITTLE_ENDIAN); for (float sample : audioData) { short val = (short) Math.max(-32768, Math.min(32767, sample * 32767)); buffer.putShort(val); } ByteArrayInputStream bais = new ByteArrayInputStream(bytes); AudioFormat format = new AudioFormat(sampleRate, 16, 1, true, false); AudioInputStream ais = new AudioInputStream(bais, format, audioData.length); AudioSystem.write(ais, AudioFileFormat.Type.WAVE, new File(filename)); } /** * Sanitize filename */ public static String sanitizeFilename(String text, int maxLen) { if (text.length() > maxLen) { text = text.substring(0, maxLen); } return text.replaceAll("[^a-zA-Z0-9]", "_"); } /** * Timer utility */ public static T timer(String name, java.util.function.Supplier fn) { long start = System.currentTimeMillis(); System.out.println(name + "..."); T result = fn.get(); long elapsed = System.currentTimeMillis() - start; System.out.printf(" -> %s completed in %.2f sec\n", name, elapsed / 1000.0); return result; } /** * Create float tensor from 3D array */ public static OnnxTensor createFloatTensor(float[][][] array, OrtEnvironment env) throws OrtException { int dim0 = array.length; int dim1 = array[0].length; int dim2 = array[0][0].length; float[] flat = new float[dim0 * dim1 * dim2]; int idx = 0; for (int i = 0; i < dim0; i++) { for (int j = 0; j < dim1; j++) { for (int k = 0; k < dim2; k++) { flat[idx++] = array[i][j][k]; } } } long[] shape = {dim0, dim1, dim2}; return OnnxTensor.createTensor(env, FloatBuffer.wrap(flat), shape); } /** * Create long tensor from 2D array */ public static OnnxTensor createLongTensor(long[][] array, OrtEnvironment env) throws OrtException { int dim0 = array.length; int dim1 = array[0].length; long[] flat = new long[dim0 * dim1]; int idx = 0; for (int i = 0; i < dim0; i++) { for (int j = 0; j < dim1; j++) { flat[idx++] = array[i][j]; } } long[] shape = {dim0, dim1}; return OnnxTensor.createTensor(env, LongBuffer.wrap(flat), shape); } /** * Load JSON long array */ public static long[] loadJsonLongArray(String filePath) throws IOException { ObjectMapper mapper = new ObjectMapper(); JsonNode root = mapper.readTree(new File(filePath)); long[] result = new long[root.size()]; for (int i = 0; i < root.size(); i++) { result[i] = root.get(i).asLong(); } return result; } }