This commit is contained in:
ANLGBOY
2025-11-19 01:18:16 +09:00
commit d31536d9fc
74 changed files with 10681 additions and 0 deletions
+597
View File
@@ -0,0 +1,597 @@
import ai.onnxruntime.*;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.ObjectMapper;
import javax.sound.sampled.AudioFileFormat;
import javax.sound.sampled.AudioFormat;
import javax.sound.sampled.AudioInputStream;
import javax.sound.sampled.AudioSystem;
import java.io.*;
import java.nio.ByteBuffer;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.nio.LongBuffer;
import java.nio.file.Files;
import java.nio.file.Paths;
import java.text.Normalizer;
import java.util.*;
/**
* Configuration classes
*/
class Config {
static class AEConfig {
int sampleRate;
int baseChunkSize;
}
static class TTLConfig {
int chunkCompressFactor;
int latentDim;
}
AEConfig ae;
TTLConfig ttl;
}
/**
* Voice Style Data from JSON
*/
class VoiceStyleData {
static class StyleData {
float[][][] data;
long[] dims;
String type;
}
StyleData styleTtl;
StyleData styleDp;
}
/**
* Unicode text processor
*/
class UnicodeProcessor {
private long[] indexer;
public UnicodeProcessor(String unicodeIndexerJsonPath) throws IOException {
this.indexer = Helper.loadJsonLongArray(unicodeIndexerJsonPath);
}
public TextProcessResult call(List<String> textList) {
List<String> processedTexts = new ArrayList<>();
for (String text : textList) {
processedTexts.add(preprocessText(text));
}
int[] textIdsLengths = new int[processedTexts.size()];
int maxLen = 0;
for (int i = 0; i < processedTexts.size(); i++) {
textIdsLengths[i] = processedTexts.get(i).length();
maxLen = Math.max(maxLen, textIdsLengths[i]);
}
long[][] textIds = new long[processedTexts.size()][maxLen];
for (int i = 0; i < processedTexts.size(); i++) {
int[] unicodeVals = textToUnicodeValues(processedTexts.get(i));
for (int j = 0; j < unicodeVals.length; j++) {
textIds[i][j] = indexer[unicodeVals[j]];
}
}
float[][][] textMask = getTextMask(textIdsLengths);
return new TextProcessResult(textIds, textMask);
}
private String preprocessText(String text) {
return Normalizer.normalize(text, Normalizer.Form.NFKD);
}
private int[] textToUnicodeValues(String text) {
int[] values = new int[text.length()];
for (int i = 0; i < text.length(); i++) {
values[i] = text.codePointAt(i);
}
return values;
}
private float[][][] getTextMask(int[] lengths) {
int bsz = lengths.length;
int maxLen = 0;
for (int len : lengths) {
maxLen = Math.max(maxLen, len);
}
float[][][] mask = new float[bsz][1][maxLen];
for (int i = 0; i < bsz; i++) {
for (int j = 0; j < maxLen; j++) {
mask[i][0][j] = j < lengths[i] ? 1.0f : 0.0f;
}
}
return mask;
}
static class TextProcessResult {
long[][] textIds;
float[][][] textMask;
TextProcessResult(long[][] textIds, float[][][] textMask) {
this.textIds = textIds;
this.textMask = textMask;
}
}
}
/**
* Text-to-Speech inference class
*/
class TextToSpeech {
private Config config;
private UnicodeProcessor textProcessor;
private OrtSession dpSession;
private OrtSession textEncSession;
private OrtSession vectorEstSession;
private OrtSession vocoderSession;
public int sampleRate;
private int baseChunkSize;
private int chunkCompress;
private int ldim;
public TextToSpeech(Config config, UnicodeProcessor textProcessor,
OrtSession dpSession, OrtSession textEncSession,
OrtSession vectorEstSession, OrtSession vocoderSession) {
this.config = config;
this.textProcessor = textProcessor;
this.dpSession = dpSession;
this.textEncSession = textEncSession;
this.vectorEstSession = vectorEstSession;
this.vocoderSession = vocoderSession;
this.sampleRate = config.ae.sampleRate;
this.baseChunkSize = config.ae.baseChunkSize;
this.chunkCompress = config.ttl.chunkCompressFactor;
this.ldim = config.ttl.latentDim;
}
public TTSResult call(List<String> textList, Style style, int totalStep, OrtEnvironment env)
throws OrtException {
int bsz = textList.size();
// Process text
UnicodeProcessor.TextProcessResult textResult = textProcessor.call(textList);
long[][] textIds = textResult.textIds;
float[][][] textMask = textResult.textMask;
// Create tensors
OnnxTensor textIdsTensor = Helper.createLongTensor(textIds, env);
OnnxTensor textMaskTensor = Helper.createFloatTensor(textMask, env);
// Predict duration
Map<String, OnnxTensor> dpInputs = new HashMap<>();
dpInputs.put("text_ids", textIdsTensor);
dpInputs.put("style_dp", style.dpTensor);
dpInputs.put("text_mask", textMaskTensor);
OrtSession.Result dpResult = dpSession.run(dpInputs);
Object dpValue = dpResult.get(0).getValue();
float[] duration;
if (dpValue instanceof float[][]) {
duration = ((float[][]) dpValue)[0];
} else {
duration = (float[]) dpValue;
}
// Encode text
Map<String, OnnxTensor> textEncInputs = new HashMap<>();
textEncInputs.put("text_ids", textIdsTensor);
textEncInputs.put("style_ttl", style.ttlTensor);
textEncInputs.put("text_mask", textMaskTensor);
OrtSession.Result textEncResult = textEncSession.run(textEncInputs);
OnnxTensor textEmbTensor = (OnnxTensor) textEncResult.get(0);
// Sample noisy latent
NoisyLatentResult noisyLatentResult = sampleNoisyLatent(duration);
float[][][] xt = noisyLatentResult.noisyLatent;
float[][][] latentMask = noisyLatentResult.latentMask;
// Prepare constant tensors
float[] totalStepArray = new float[bsz];
Arrays.fill(totalStepArray, (float) totalStep);
OnnxTensor totalStepTensor = OnnxTensor.createTensor(env, totalStepArray);
// Denoising loop
for (int step = 0; step < totalStep; step++) {
float[] currentStepArray = new float[bsz];
Arrays.fill(currentStepArray, (float) step);
OnnxTensor currentStepTensor = OnnxTensor.createTensor(env, currentStepArray);
OnnxTensor noisyLatentTensor = Helper.createFloatTensor(xt, env);
OnnxTensor latentMaskTensor = Helper.createFloatTensor(latentMask, env);
OnnxTensor textMaskTensor2 = Helper.createFloatTensor(textMask, env);
Map<String, OnnxTensor> vectorEstInputs = new HashMap<>();
vectorEstInputs.put("noisy_latent", noisyLatentTensor);
vectorEstInputs.put("text_emb", textEmbTensor);
vectorEstInputs.put("style_ttl", style.ttlTensor);
vectorEstInputs.put("latent_mask", latentMaskTensor);
vectorEstInputs.put("text_mask", textMaskTensor2);
vectorEstInputs.put("current_step", currentStepTensor);
vectorEstInputs.put("total_step", totalStepTensor);
OrtSession.Result vectorEstResult = vectorEstSession.run(vectorEstInputs);
float[][][] denoised = (float[][][]) vectorEstResult.get(0).getValue();
// Update latent
xt = denoised;
// Clean up
currentStepTensor.close();
noisyLatentTensor.close();
latentMaskTensor.close();
textMaskTensor2.close();
vectorEstResult.close();
}
// Generate waveform
OnnxTensor finalLatentTensor = Helper.createFloatTensor(xt, env);
Map<String, OnnxTensor> vocoderInputs = new HashMap<>();
vocoderInputs.put("latent", finalLatentTensor);
OrtSession.Result vocoderResult = vocoderSession.run(vocoderInputs);
float[][] wavBatch = (float[][]) vocoderResult.get(0).getValue();
float[] wav = wavBatch[0];
// Clean up
textIdsTensor.close();
textMaskTensor.close();
dpResult.close();
textEncResult.close();
totalStepTensor.close();
finalLatentTensor.close();
vocoderResult.close();
return new TTSResult(wav, duration);
}
private NoisyLatentResult sampleNoisyLatent(float[] duration) {
int bsz = duration.length;
float maxDur = 0;
for (float d : duration) {
maxDur = Math.max(maxDur, d);
}
long wavLenMax = (long) (maxDur * sampleRate);
long[] wavLengths = new long[bsz];
for (int i = 0; i < bsz; i++) {
wavLengths[i] = (long) (duration[i] * sampleRate);
}
int chunkSize = baseChunkSize * chunkCompress;
int latentLen = (int) ((wavLenMax + chunkSize - 1) / chunkSize);
int latentDim = ldim * chunkCompress;
Random rng = new Random();
float[][][] noisyLatent = new float[bsz][latentDim][latentLen];
for (int b = 0; b < bsz; b++) {
for (int d = 0; d < latentDim; d++) {
for (int t = 0; t < latentLen; t++) {
// Box-Muller transform
double u1 = Math.max(1e-10, rng.nextDouble());
double u2 = rng.nextDouble();
noisyLatent[b][d][t] = (float) (Math.sqrt(-2.0 * Math.log(u1)) * Math.cos(2.0 * Math.PI * u2));
}
}
}
float[][][] latentMask = Helper.getLatentMask(wavLengths, config);
// Apply mask
for (int b = 0; b < bsz; b++) {
for (int d = 0; d < latentDim; d++) {
for (int t = 0; t < latentLen; t++) {
noisyLatent[b][d][t] *= latentMask[b][0][t];
}
}
}
return new NoisyLatentResult(noisyLatent, latentMask);
}
public void close() throws OrtException {
if (dpSession != null) dpSession.close();
if (textEncSession != null) textEncSession.close();
if (vectorEstSession != null) vectorEstSession.close();
if (vocoderSession != null) vocoderSession.close();
}
}
/**
* Style holder class
*/
class Style {
OnnxTensor ttlTensor;
OnnxTensor dpTensor;
Style(OnnxTensor ttlTensor, OnnxTensor dpTensor) {
this.ttlTensor = ttlTensor;
this.dpTensor = dpTensor;
}
public void close() throws OrtException {
if (ttlTensor != null) ttlTensor.close();
if (dpTensor != null) dpTensor.close();
}
}
/**
* TTS result holder
*/
class TTSResult {
float[] wav;
float[] duration;
TTSResult(float[] wav, float[] duration) {
this.wav = wav;
this.duration = duration;
}
}
/**
* Noisy latent result holder
*/
class NoisyLatentResult {
float[][][] noisyLatent;
float[][][] latentMask;
NoisyLatentResult(float[][][] noisyLatent, float[][][] latentMask) {
this.noisyLatent = noisyLatent;
this.latentMask = latentMask;
}
}
/**
* Helper utility class
*/
public class Helper {
/**
* Load voice style from JSON files
*/
public static Style loadVoiceStyle(List<String> voiceStylePaths, boolean verbose, OrtEnvironment env)
throws IOException, OrtException {
int bsz = voiceStylePaths.size();
// Read first file to get dimensions
ObjectMapper mapper = new ObjectMapper();
JsonNode firstRoot = mapper.readTree(new File(voiceStylePaths.get(0)));
long[] ttlDims = new long[3];
for (int i = 0; i < 3; i++) {
ttlDims[i] = firstRoot.get("style_ttl").get("dims").get(i).asLong();
}
long[] dpDims = new long[3];
for (int i = 0; i < 3; i++) {
dpDims[i] = firstRoot.get("style_dp").get("dims").get(i).asLong();
}
long ttlDim1 = ttlDims[1];
long ttlDim2 = ttlDims[2];
long dpDim1 = dpDims[1];
long dpDim2 = dpDims[2];
// Pre-allocate arrays with full batch size
int ttlSize = (int) (bsz * ttlDim1 * ttlDim2);
int dpSize = (int) (bsz * dpDim1 * dpDim2);
float[] ttlFlat = new float[ttlSize];
float[] dpFlat = new float[dpSize];
// Fill in the data
for (int i = 0; i < bsz; i++) {
JsonNode root = mapper.readTree(new File(voiceStylePaths.get(i)));
// Flatten TTL data
int ttlOffset = (int) (i * ttlDim1 * ttlDim2);
int idx = 0;
JsonNode ttlData = root.get("style_ttl").get("data");
for (JsonNode batch : ttlData) {
for (JsonNode row : batch) {
for (JsonNode val : row) {
ttlFlat[ttlOffset + idx++] = (float) val.asDouble();
}
}
}
// Flatten DP data
int dpOffset = (int) (i * dpDim1 * dpDim2);
idx = 0;
JsonNode dpData = root.get("style_dp").get("data");
for (JsonNode batch : dpData) {
for (JsonNode row : batch) {
for (JsonNode val : row) {
dpFlat[dpOffset + idx++] = (float) val.asDouble();
}
}
}
}
long[] ttlShape = {bsz, ttlDim1, ttlDim2};
long[] dpShape = {bsz, dpDim1, dpDim2};
OnnxTensor ttlTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(ttlFlat), ttlShape);
OnnxTensor dpTensor = OnnxTensor.createTensor(env, FloatBuffer.wrap(dpFlat), dpShape);
if (verbose) {
System.out.println("Loaded " + bsz + " voice styles\n");
}
return new Style(ttlTensor, dpTensor);
}
/**
* Load TTS components
*/
public static TextToSpeech loadTextToSpeech(String onnxDir, boolean useGpu, OrtEnvironment env)
throws IOException, OrtException {
if (useGpu) {
throw new RuntimeException("GPU mode is not supported yet");
}
System.out.println("Using CPU for inference\n");
// Load config
Config config = loadCfgs(onnxDir);
// Create session options
OrtSession.SessionOptions opts = new OrtSession.SessionOptions();
// Load models
OrtSession dpSession = env.createSession(onnxDir + "/duration_predictor.onnx", opts);
OrtSession textEncSession = env.createSession(onnxDir + "/text_encoder.onnx", opts);
OrtSession vectorEstSession = env.createSession(onnxDir + "/vector_estimator.onnx", opts);
OrtSession vocoderSession = env.createSession(onnxDir + "/vocoder.onnx", opts);
// Load text processor
UnicodeProcessor textProcessor = new UnicodeProcessor(onnxDir + "/unicode_indexer.json");
return new TextToSpeech(config, textProcessor, dpSession, textEncSession, vectorEstSession, vocoderSession);
}
/**
* Load configuration from JSON
*/
public static Config loadCfgs(String onnxDir) throws IOException {
ObjectMapper mapper = new ObjectMapper();
JsonNode root = mapper.readTree(new File(onnxDir + "/tts.json"));
Config config = new Config();
config.ae = new Config.AEConfig();
config.ae.sampleRate = root.get("ae").get("sample_rate").asInt();
config.ae.baseChunkSize = root.get("ae").get("base_chunk_size").asInt();
config.ttl = new Config.TTLConfig();
config.ttl.chunkCompressFactor = root.get("ttl").get("chunk_compress_factor").asInt();
config.ttl.latentDim = root.get("ttl").get("latent_dim").asInt();
return config;
}
/**
* Get latent mask from wav lengths
*/
public static float[][][] getLatentMask(long[] wavLengths, Config config) {
long baseChunkSize = config.ae.baseChunkSize;
long chunkCompressFactor = config.ttl.chunkCompressFactor;
long latentSize = baseChunkSize * chunkCompressFactor;
long[] latentLengths = new long[wavLengths.length];
long maxLen = 0;
for (int i = 0; i < wavLengths.length; i++) {
latentLengths[i] = (wavLengths[i] + latentSize - 1) / latentSize;
maxLen = Math.max(maxLen, latentLengths[i]);
}
float[][][] mask = new float[wavLengths.length][1][(int) maxLen];
for (int i = 0; i < wavLengths.length; i++) {
for (int j = 0; j < maxLen; j++) {
mask[i][0][j] = j < latentLengths[i] ? 1.0f : 0.0f;
}
}
return mask;
}
/**
* Write WAV file
*/
public static void writeWavFile(String filename, float[] audioData, int sampleRate) throws IOException {
// Convert float to byte array
byte[] bytes = new byte[audioData.length * 2];
ByteBuffer buffer = ByteBuffer.wrap(bytes);
buffer.order(ByteOrder.LITTLE_ENDIAN);
for (float sample : audioData) {
short val = (short) Math.max(-32768, Math.min(32767, sample * 32767));
buffer.putShort(val);
}
ByteArrayInputStream bais = new ByteArrayInputStream(bytes);
AudioFormat format = new AudioFormat(sampleRate, 16, 1, true, false);
AudioInputStream ais = new AudioInputStream(bais, format, audioData.length);
AudioSystem.write(ais, AudioFileFormat.Type.WAVE, new File(filename));
}
/**
* Sanitize filename
*/
public static String sanitizeFilename(String text, int maxLen) {
if (text.length() > maxLen) {
text = text.substring(0, maxLen);
}
return text.replaceAll("[^a-zA-Z0-9]", "_");
}
/**
* Timer utility
*/
public static <T> T timer(String name, java.util.function.Supplier<T> fn) {
long start = System.currentTimeMillis();
System.out.println(name + "...");
T result = fn.get();
long elapsed = System.currentTimeMillis() - start;
System.out.printf(" -> %s completed in %.2f sec\n", name, elapsed / 1000.0);
return result;
}
/**
* Create float tensor from 3D array
*/
public static OnnxTensor createFloatTensor(float[][][] array, OrtEnvironment env) throws OrtException {
int dim0 = array.length;
int dim1 = array[0].length;
int dim2 = array[0][0].length;
float[] flat = new float[dim0 * dim1 * dim2];
int idx = 0;
for (int i = 0; i < dim0; i++) {
for (int j = 0; j < dim1; j++) {
for (int k = 0; k < dim2; k++) {
flat[idx++] = array[i][j][k];
}
}
}
long[] shape = {dim0, dim1, dim2};
return OnnxTensor.createTensor(env, FloatBuffer.wrap(flat), shape);
}
/**
* Create long tensor from 2D array
*/
public static OnnxTensor createLongTensor(long[][] array, OrtEnvironment env) throws OrtException {
int dim0 = array.length;
int dim1 = array[0].length;
long[] flat = new long[dim0 * dim1];
int idx = 0;
for (int i = 0; i < dim0; i++) {
for (int j = 0; j < dim1; j++) {
flat[idx++] = array[i][j];
}
}
long[] shape = {dim0, dim1};
return OnnxTensor.createTensor(env, LongBuffer.wrap(flat), shape);
}
/**
* Load JSON long array
*/
public static long[] loadJsonLongArray(String filePath) throws IOException {
ObjectMapper mapper = new ObjectMapper();
JsonNode root = mapper.readTree(new File(filePath));
long[] result = new long[root.size()];
for (int i = 0; i < root.size(); i++) {
result[i] = root.get(i).asLong();
}
return result;
}
}