java/LLMBenchmarkTester.java

731 lines
31 KiB
Java
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import java.io.File;
import java.io.IOException;
import java.io.UncheckedIOException;
import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;
import java.lang.reflect.Field;
import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;
import java.nio.charset.StandardCharsets;
import java.nio.file.Files;
import java.nio.file.Path;
import java.nio.file.StandardOpenOption;
import java.time.Duration;
import java.time.LocalDate;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.*;
import java.util.concurrent.*;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
import java.util.stream.Collectors;
/**
* <p>
* JDK版本必须大于或等于21, 直接运行将生成一份bat脚本或shell脚本, 下载JDK可以在浏览器打开链接按需下载:
* https://www.azul.com/downloads/?version=java-21-lts&package=jdk#zulu
* </p>
*/
public class LLMBenchmarkTester {
// Separator line printed around each parameter by printScriptParam.
public static final String SEP = "=============================================================================================";
// Reflection view of the parameter holder; used to read, validate, print and script the settings.
public static final Field[] PARAM_FIELD = ScriptParameter.class.getDeclaredFields();
// Extracts the "content" string value from one streamed response line. The value may contain JSON
// escape pairs such as \" or \\, so the group consumes backslash-escaped characters as a unit instead
// of stopping at the first (escaped) quote and truncating the text.
public static final Pattern CONTENT_PATTERN = Pattern.compile("\"content\"\\s*:\\s*\"((?:[^\"\\\\]|\\\\.)*)\"");
/**
 * Entry point.
 * <p>
 * Without arguments a launcher script for the current OS is generated. With {@code -p env} or
 * {@code -p input} the benchmark parameters are collected, each configured chat and/or VL model is
 * benchmarked at each configured concurrency level, and one summary per run is appended to
 * {@code llm_bench_yyyyMMdd.log} in the working directory.
 *
 * @param args command-line arguments, see {@code readScriptParameter}
 * @throws Exception on invalid parameters, unreadable datasets, log I/O failures or interruption
 */
public static void main(String[] args) throws Exception {
    if (args == null || args.length == 0) {
        createRunScript();
        return;
    }
    ScriptParameter param = readScriptParameter(args);
    printScriptParam(param);
    List<ExecuteContext> executeContexts = new ArrayList<>();
    try {
        if (param.isTestChatModel()) {
            TextQuestion textQuestion = new TextQuestion(param);
            executeContexts.addAll(submit(
                    param.modelName.split(","),
                    param.threadSize.split(","),
                    param,
                    textQuestion,
                    null
            ));
        }
        if (param.isTestVlModel()) {
            ImageQuestion imageQuestion = new ImageQuestion(param);
            executeContexts.addAll(submit(
                    param.vlModelName.split(","),
                    param.threadSize.split(","),
                    param,
                    null,
                    imageQuestion
            ));
        }
        if (executeContexts.isEmpty()) {
            System.out.println("没有需要测试的模型, 请检查模型名称和数据集/图片目录配置");
            return;
        }
        String today = LocalDate.now().format(DateTimeFormatter.ofPattern("yyyyMMdd"));
        String logName = String.format("llm_bench_%s.log", today);
        Path logPath = Path.of(System.getProperty("user.dir"), logName);
        for (ExecuteContext executeContext : executeContexts) {
            executeContext.latch().await();
            writeLog(logPath, executeContext);
            executeContext.executor().shutdownNow();
            executeContext.sessionMap().clear();
        }
        System.out.printf("测试结果已写入 %s%n", logPath);
    } finally {
        // Pool threads are non-daemon: if anything above throws, an executor that is never shut down
        // keeps the JVM alive indefinitely. shutdownNow() is idempotent, so repeating it is harmless.
        for (ExecuteContext executeContext : executeContexts) {
            executeContext.executor().shutdownNow();
        }
    }
}
/**
 * Appends a plain-text summary of one finished run to {@code logPath}, creating the file if needed (UTF-8).
 * <p>
 * Reported figures: the model, pool size, number of recorded requests and successes, max and mean
 * first-response time (the mean ignores requests that report 0), total characters and chunks of streamed
 * content, and the sum of every request's completion time.
 *
 * @param logPath        file the summary is appended to
 * @param executeContext the finished run to summarize
 * @throws IOException if the log file cannot be written
 */
private static void writeLog(Path logPath, ExecuteContext executeContext) throws IOException {
    Map<Long, HttpContext> sessions = executeContext.sessionMap();
    Collection<HttpContext> contexts = sessions.values();

    LongSummaryStatistics firstResponse = contexts.stream()
            .mapToLong(HttpContext::toEndMillis)
            .summaryStatistics();
    long maxResponse = firstResponse.getCount() == 0 ? 0L : firstResponse.getMax();
    // Requests without a first-response timestamp report 0 and are left out of the mean.
    double avgResponse = contexts.stream()
            .mapToLong(HttpContext::toEndMillis)
            .filter(millis -> millis > 0L)
            .average()
            .orElse(0D);
    long totalTime = contexts.stream().mapToLong(HttpContext::toFinishMillis).sum();
    long successNum = contexts.stream().filter(ctx -> ctx.success).count();
    int outTextLength = contexts.stream()
            .flatMap(ctx -> ctx.outTexts.stream())
            .mapToInt(text -> text == null ? 0 : text.length())
            .sum();
    int outTextCount = contexts.stream()
            .mapToInt(ctx -> ctx.outTexts == null ? 0 : ctx.outTexts.size())
            .sum();

    String report = """
            ------------------------------------------
            模型: %s
            并发量: %d
            问题数量: %d
            成功: %d
            首次响应最长耗时: %d毫秒
            首次响应平均耗时: %f毫秒
            一共输出: %d字, 共输出%d次, 共计耗时:%d毫秒
            ------------------------------------------
            %s
            """.formatted(
            executeContext.model(),
            executeContext.threadSize(),
            sessions.size(),
            successNum,
            maxResponse,
            avgResponse,
            outTextLength,
            outTextCount,
            totalTime,
            System.lineSeparator()
    );
    Files.writeString(logPath, report, StandardCharsets.UTF_8, StandardOpenOption.CREATE, StandardOpenOption.APPEND);
}
/**
 * Runs the benchmark for every (concurrency level, model) combination, one combination at a time.
 * <p>
 * Each run is awaited before the next one starts, so requests from different concurrency levels or
 * models never overlap and the concurrency recorded for a run is the load that was actually applied.
 * Entries are stripped and blank entries are skipped. This means {@code "qwen2.5, qwen3"} does not send
 * a model name with a leading space, and {@code "10,,50"} does not fail on the empty item.
 *
 * @param models        model names, possibly padded with whitespace
 * @param threadSizeStr concurrency levels as decimal strings
 * @param param         script parameters (API host and key)
 * @param textQuestion  source of chat request bodies, or {@code null}
 * @param imageQuestion source of VL request bodies; used only when {@code textQuestion} is {@code null}
 * @return contexts of all runs that were started, in execution order; if the calling thread is
 *         interrupted, the list ends with the run that was being awaited
 * @throws NumberFormatException    if a concurrency entry is not an integer
 * @throws IllegalArgumentException if a concurrency entry is not positive
 */
private static List<ExecuteContext> submit(String[] models, String[] threadSizeStr, ScriptParameter param, TextQuestion textQuestion, ImageQuestion imageQuestion) {
    List<Integer> threadSizeList = Arrays.stream(threadSizeStr)
            .map(String::strip)
            .filter(s -> !s.isEmpty())
            .map(Integer::parseInt)
            .toList();
    // Validate every level up front so a bad entry fails before any load is generated.
    for (int threadSize : threadSizeList) {
        if (threadSize <= 0) {
            throw new IllegalArgumentException("并发量必须大于0, 当前值: " + threadSize);
        }
    }
    List<String> modelList = Arrays.stream(models)
            .map(String::strip)
            .filter(s -> !s.isEmpty())
            .toList();
    List<ExecuteContext> executeContexts = new ArrayList<>();
    if (textQuestion == null && imageQuestion == null) {
        return executeContexts;
    }
    for (int threadSize : threadSizeList) {
        for (String model : modelList) {
            List<String> requestParams = textQuestion != null
                    ? textQuestion.getRequestParams(model)
                    : imageQuestion.getRequestParams(model);
            ExecuteContext context = execute(threadSize, model, param, requestParams);
            executeContexts.add(context);
            try {
                context.latch().await();
            } catch (InterruptedException e) {
                Thread.currentThread().interrupt();
                return executeContexts;
            }
            // Every request of this run has finished; release its idle pool threads right away.
            context.executor().shutdown();
        }
    }
    return executeContexts;
}
/**
 * Starts one benchmark run: each request body becomes its own task on a fixed pool of
 * {@code threadSize} threads.
 * <p>
 * {@code startHttp} blocks its worker while its HttpClient closes, and closing waits for the exchange
 * to finish. So the pool size bounds how many requests are in flight at once. Previously all bodies were
 * looped over inside a single task, which serialized every request regardless of {@code threadSize}.
 *
 * @param threadSize    number of pool threads (the intended concurrency)
 * @param model         model name, recorded in the returned context
 * @param param         script parameters supplying the API URL and key
 * @param requestParams JSON request bodies to send, one request each
 * @return a context whose latch reaches zero once every request has finished
 * @throws IllegalArgumentException if the API URL is malformed or {@code threadSize <= 0}
 */
private static ExecuteContext execute(int threadSize, String model, ScriptParameter param, List<String> requestParams) {
    URI uri = URI.create(param.openAiApiHost);
    CountDownLatch latch = new CountDownLatch(requestParams.size());
    ConcurrentHashMap<Long, HttpContext> sessionMap = new ConcurrentHashMap<>(requestParams.size());
    ExecutorService executorService = Executors.newFixedThreadPool(threadSize);
    for (String requestBody : requestParams) {
        executorService.execute(() -> startHttp(uri, param.apiKey, requestBody, latch, sessionMap));
    }
    return new ExecuteContext(model, threadSize, sessionMap, latch, executorService);
}
/**
 * Handle to one benchmark run (one model at one concurrency level).
 *
 * @param model      model name sent in the run's request bodies
 * @param threadSize size of the fixed thread pool used for the run
 * @param sessionMap per-request contexts keyed by session id, filled in as requests finish
 * @param latch      counted down once per finished request
 * @param executor   pool executing the run's requests; the caller shuts it down
 */
record ExecuteContext(String model, int threadSize, ConcurrentHashMap<Long, HttpContext> sessionMap,
                      CountDownLatch latch, ExecutorService executor) {
}
/**
 * Sends one streaming chat-completion request and records its outcome.
 * <p>
 * A fresh HttpClient is used per request and closed with try-with-resources. {@code HttpClient.close()}
 * (JDK 21) waits for submitted exchanges to complete, so this call blocks its worker until the response
 * has been fully consumed.
 * <p>
 * The latch is always counted down exactly once per call. A synchronous failure (invalid header
 * value, client creation, {@code sendAsync} rejecting the request) is turned into a failed future.
 * Before this change such a failure skipped the countdown and left {@code main} waiting forever.
 *
 * @param uri         chat-completions endpoint
 * @param apiKey      bearer token sent in the Authorization header
 * @param requestBody JSON request body
 * @param latch       run latch, counted down when this request finishes
 * @param sessionMap  receives this request's {@link HttpContext}
 */
private static void startHttp(URI uri, String apiKey, String requestBody, CountDownLatch latch, Map<Long, HttpContext> sessionMap) {
    HttpContext context = new HttpContext();
    context.outTexts = new ArrayList<>();
    CompletableFuture<HttpResponse<Void>> future;
    try (HttpClient client = HttpClient.newHttpClient()) {
        HttpRequest httpRequest = HttpRequest.newBuilder()
                .uri(uri)
                .header("Content-Type", "application/json")
                .header("Authorization", "Bearer " + apiKey)
                .timeout(Duration.ofSeconds(15))
                .POST(HttpRequest.BodyPublishers.ofString(requestBody))
                .build();
        context.start = LocalDateTime.now();
        Flow.Subscriber<String> subscriber = createResponseFluxHandler(context);
        future = client.sendAsync(httpRequest, HttpResponse.BodyHandlers.fromLineSubscriber(subscriber));
    } catch (RuntimeException e) {
        if (context.start == null) {
            context.start = LocalDateTime.now();
        }
        future = CompletableFuture.failedFuture(e);
    }
    // Wiring after the client is closed is fine: callbacks on an already-completed future run immediately.
    handleHttpResponseFuture(future, latch, sessionMap, context);
}
/**
 * Mutable record of a single benchmark request: timestamps, outcome and streamed content chunks.
 */
private static class HttpContext {
    // Process-wide sequence for session ids. System.nanoTime() is not unique across threads (clocks with
    // coarse resolution can return the same value to concurrent callers), and a duplicate key made
    // sessionMap.putIfAbsent silently drop a request's result.
    private static final AtomicLong SESSION_SEQUENCE = new AtomicLong();
    // Unique key of this request in its run's session map.
    long sessionId = SESSION_SEQUENCE.incrementAndGet();
    // Set just before the request is sent.
    LocalDateTime start;
    // First-response timestamp recorded by the response subscriber; null if never recorded.
    LocalDateTime end;
    // Set when the response body has been fully consumed; null if it never completed.
    LocalDateTime completed;
    // True only when the exchange completed with HTTP 200 (as decided by the completion handler).
    boolean success;
    // Non-empty "content" values extracted from the streamed response lines.
    List<String> outTexts;

    /** Milliseconds from start to first response, or 0 when no first response was recorded. */
    public long toEndMillis() {
        return this.end != null ? Duration.between(this.start, this.end).toMillis() : 0L;
    }

    /** Milliseconds from start to body completion, or 0 when the body never completed. */
    public long toFinishMillis() {
        return this.completed != null ? Duration.between(this.start, this.completed).toMillis() : 0L;
    }
}
/**
 * Records the outcome of one exchange when its future completes.
 * <p>
 * Exactly one callback is registered. It marks the request successful only for a normal completion
 * with HTTP 200, publishes the context into {@code sessionMap}, and counts the latch down in a
 * {@code finally} block, so the run can never stall on this request. The former chain started with a
 * {@code whenComplete} that only reset {@code success} and had no purpose.
 *
 * @param future     the exchange future returned by {@code sendAsync} (or a failed placeholder)
 * @param latch      run latch to count down once
 * @param sessionMap map the context is published into
 * @param context    the request's context
 */
private static void handleHttpResponseFuture(CompletableFuture<HttpResponse<Void>> future,
                                             CountDownLatch latch,
                                             Map<Long, HttpContext> sessionMap,
                                             HttpContext context) {
    future.whenComplete((response, error) -> {
        try {
            context.success = error == null && response != null && response.statusCode() == 200;
            sessionMap.putIfAbsent(context.sessionId, context);
        } finally {
            latch.countDown();
        }
    });
}
/**
 * Builds the line subscriber that consumes one streamed chat-completion response.
 * <p>
 * Each line is scanned with {@code CONTENT_PATTERN}, and every non-empty {@code content} value is
 * appended to {@code context.outTexts}. {@code context.end} is stamped when the first such chunk
 * arrives, so it measures time-to-first-token.
 * <p>
 * It is deliberately not stamped in {@code onSubscribe}. That callback fires as soon as the response
 * headers arrive, and streaming servers typically send headers before generating anything. A request
 * that never yields content leaves {@code end} unset. {@code context.completed} is stamped when the
 * body has been fully consumed.
 *
 * @param context per-request state to fill in
 * @return a subscriber that requests all lines up front
 */
private static Flow.Subscriber<String> createResponseFluxHandler(HttpContext context) {
    return new Flow.Subscriber<>() {
        @Override
        public void onSubscribe(Flow.Subscription subscription) {
            subscription.request(Long.MAX_VALUE);
        }

        @Override
        public void onNext(String item) {
            if (item == null || item.isEmpty()) {
                return;
            }
            Matcher matcher = CONTENT_PATTERN.matcher(item);
            if (!matcher.find()) {
                return;
            }
            String group = matcher.group(1);
            if (group != null && !group.isEmpty()) {
                if (context.end == null) {
                    context.end = LocalDateTime.now();
                }
                context.outTexts.add(group);
            }
        }

        @Override
        public void onError(Throwable throwable) {
            context.success = false;
        }

        @Override
        public void onComplete() {
            context.completed = LocalDateTime.now();
        }
    };
}
/**
 * Builds the request bodies for the vision-language (VL) benchmark.
 * <p>
 * The PNG/JPEG images in the configured folder are read as base64 data URLs. They are repeated in
 * (sorted) order until the requested number of images is reached, and each one is wrapped in a
 * streaming chat-completion request asking what the picture shows.
 */
private static class ImageQuestion {
    // Request bodies per model name, shared by all instances.
    private static final Map<String, List<String>> cache = new ConcurrentHashMap<>();
    private static final String template = """
            {
              "model": "${model}",
              "messages": [
                {
                  "role": "user",
                  "content": [
                    {
                      "type": "text",
                      "text": "这张图片里有什么?"
                    },
                    {
                      "type": "image_url",
                      "image_url": {
                        "url": "${imageBase64}"
                      }
                    }
                  ]
                }
              ],
              "stream": true
            }
            """.strip();
    // Data URLs of the images to send, exactly imgSize entries.
    List<String> list;

    /**
     * Returns the MIME type for a supported image file, or {@code null} if the file is not one.
     * The extension is matched case-insensitively and must include the dot ({@code .png}, {@code .jpg},
     * {@code .jpeg}).
     */
    public String getImgHead(File file) {
        String name = file.getName().toLowerCase(Locale.ROOT);
        if (name.endsWith(".png")) {
            return "image/png";
        }
        if (name.endsWith(".jpg") || name.endsWith(".jpeg")) {
            return "image/jpeg";
        }
        return null;
    }

    /**
     * Encodes a supported image as a {@code data:<mime>;base64,...} URL.
     *
     * @return the data URL, or {@code null} when the file type is not supported
     * @throws UncheckedIOException if the file cannot be read
     */
    public String tryEncodeBase64(File file, Path path) {
        String imgHead = getImgHead(file);
        if (imgHead == null || imgHead.isBlank()) {
            return null;
        }
        try {
            return "data:" + imgHead + ";base64," + Base64.getEncoder().encodeToString(Files.readAllBytes(path));
        } catch (IOException e) {
            throw new UncheckedIOException(e);
        }
    }

    /**
     * Encodes all supported images directly inside {@code parameter.vlImgFolder}.
     * Only regular files are considered (a directory named {@code x.png} is skipped), and paths are
     * sorted so the request order is reproducible across runs.
     */
    public List<String> base64Image(ScriptParameter parameter) throws IOException {
        try (var files = Files.list(Path.of(parameter.vlImgFolder))) {
            return files.filter(Files::isRegularFile)
                    .sorted()
                    .map(path -> tryEncodeBase64(path.toFile(), path))
                    .filter(Objects::nonNull)
                    .toList();
        }
    }

    /**
     * Loads the images and cycles through them until {@code parameter.imgSize} entries are collected.
     *
     * @throws IllegalArgumentException if the requested image count is negative
     * @throws IllegalStateException    if images are requested but the folder contains none.
     *                                  Previously this case looped forever.
     */
    public ImageQuestion(ScriptParameter parameter) throws IOException {
        List<String> base64List = base64Image(parameter);
        int imgNum = Integer.parseInt(parameter.imgSize.strip());
        if (imgNum < 0) {
            throw new IllegalArgumentException(String.format("图片数量不能小于0, 当前值: %d", imgNum));
        }
        if (imgNum > 0 && base64List.isEmpty()) {
            throw new IllegalStateException(String.format("图片目录 %s 下没有可用的png/jpg/jpeg图片", parameter.vlImgFolder));
        }
        this.list = new ArrayList<>(imgNum);
        for (int i = 0; i < imgNum; i++) {
            this.list.add(base64List.get(i % base64List.size()));
        }
    }

    /** Returns (and caches) one request body per image for the given model. */
    public List<String> getRequestParams(String model) {
        List<String> cacheParams = cache.get(model);
        if (cacheParams != null && !cacheParams.isEmpty()) {
            return cacheParams;
        }
        List<String> params = this.list.stream().map(s -> this.toJsonParam(model, s)).toList();
        cache.put(model, params);
        return params;
    }

    /** Fills the request template with the model name and image data URL. */
    public String toJsonParam(String model, String imageBase64) {
        return template.replace("${model}", model).replace("${imageBase64}", imageBase64);
    }
}
/**
 * Builds the request bodies for the chat benchmark: one streaming chat-completion request per
 * non-blank line of the dataset file.
 */
private static class TextQuestion {
    // Request bodies per model name, shared by all instances.
    private static final Map<String, List<String>> cache = new ConcurrentHashMap<>();
    private static final String template = """
            {
              "model": "${model}",
              "messages": [
                {
                  "role": "user",
                  "content": "${prompt}"
                }
              ],
              "stream": true
            }
            """.strip();
    // Questions parsed from the dataset file (UTF-8, one question per line, blank lines skipped).
    List<String> list;

    public TextQuestion(ScriptParameter parameter) throws IOException {
        this.list = Files.readAllLines(Path.of(parameter.chatDatasetsPath)).stream()
                .filter(line -> !line.isBlank())
                .toList();
    }

    /**
     * Returns (and caches) one request body per question for the given model.
     * The cache is reused whenever it holds a non-empty list. The check used to be inverted, so the
     * cache was never hit.
     */
    public List<String> getRequestParams(String model) {
        List<String> cacheParams = cache.get(model);
        if (cacheParams != null && !cacheParams.isEmpty()) {
            return cacheParams;
        }
        List<String> params = this.list.stream().map(s -> this.toJsonParam(model, s)).toList();
        cache.put(model, params);
        return params;
    }

    /**
     * Fills the request template. Both values are JSON-escaped, so prompts containing quotes,
     * backslashes or control characters still produce a valid JSON body.
     */
    public String toJsonParam(String model, String prompt) {
        return template.replace("${model}", escapeJson(model)).replace("${prompt}", escapeJson(prompt));
    }

    // Escapes text for use inside a JSON string literal (RFC 8259 section 7).
    private static String escapeJson(String text) {
        StringBuilder sb = new StringBuilder(text.length() + 16);
        for (int i = 0; i < text.length(); i++) {
            char c = text.charAt(i);
            switch (c) {
                case '"' -> sb.append("\\\"");
                case '\\' -> sb.append("\\\\");
                case '\n' -> sb.append("\\n");
                case '\r' -> sb.append("\\r");
                case '\t' -> sb.append("\\t");
                case '\b' -> sb.append("\\b");
                case '\f' -> sb.append("\\f");
                default -> {
                    if (c < 0x20) {
                        sb.append(String.format("\\u%04x", (int) c));
                    } else {
                        sb.append(c);
                    }
                }
            }
        }
        return sb.toString();
    }
}
/**
 * Prints every {@code @EnvName} parameter and its effective value, framed by {@code SEP} lines.
 * The API key is masked so it does not end up in terminal scroll-back, CI logs or screen shares.
 */
private static void printScriptParam(ScriptParameter param) throws IllegalAccessException {
    System.out.println("本次执行脚本的参数如下:");
    for (Field field : PARAM_FIELD) {
        if (field.isAnnotationPresent(EnvName.class)) {
            System.out.println(SEP);
            EnvName annotation = field.getAnnotation(EnvName.class);
            field.setAccessible(true);
            Object value = field.get(param);
            if ("apiKey".equals(field.getName())) {
                value = maskSecret(value);
            }
            System.out.printf("参数: %s 数值: %s%n", annotation.value(), value);
        }
    }
    System.out.println(SEP);
}

// Masks a secret for display: null/empty pass through, values longer than 8 chars keep their last 4.
private static Object maskSecret(Object value) {
    if (value == null) {
        return null;
    }
    String text = value.toString();
    if (text.isEmpty()) {
        return text;
    }
    return text.length() > 8 ? "****" + text.substring(text.length() - 4) : "****";
}
/**
 * Generates a launcher script for the host OS: a .bat file when {@code os.name} contains "win",
 * otherwise a shell script.
 */
private static void createRunScript() throws IOException {
    boolean windows = System.getProperty("os.name").toLowerCase().contains("win");
    if (!windows) {
        generateShellScript();
        return;
    }
    generateWindowsBat();
}
/**
 * Creates an empty {@code llm_benchmark_tester_yyyyMMdd.<extName>} file in the working directory and
 * prints usage hints.
 *
 * @param extName script file extension without the dot
 * @return the newly created file
 * @throws RuntimeException if today's script already exists (the message points the user to it),
 *                          or if the file cannot be created
 */
private static File createScripeFile(String extName) {
    String stamp = DateTimeFormatter.ofPattern("yyyyMMdd").format(LocalDateTime.now());
    File scriptFile = new File(String.format("llm_benchmark_tester_%s.%s", stamp, extName));
    String absolutePath = scriptFile.getAbsolutePath();
    if (scriptFile.exists()) {
        throw new RuntimeException(String.format("您可以通过 %s 脚本直接运行", absolutePath));
    }
    final boolean created;
    try {
        created = scriptFile.createNewFile();
    } catch (Exception e) {
        throw new RuntimeException(String.format("创建 %s 脚本文件异常: %s", absolutePath, e.getMessage()), e);
    }
    if (!created) {
        throw new RuntimeException(String.format("创建 %s 脚本文件失败", absolutePath));
    }
    System.out.println("已为你生成一份脚本, 请修改脚本中的环境变量, 使用脚本运行");
    System.out.printf("脚本的存储路径 %s%n", absolutePath);
    System.out.println("运行脚本之前, 请确保脚本文件的换行符与系统相匹配, 否则会无法运行");
    return scriptFile;
}
/**
 * Fills the {@code ${ENV_LINES}} placeholder of a script template and writes the script (UTF-8).
 * <p>
 * {@code eachFunc} is called once per {@code @EnvName} parameter to contribute its lines. Text-block
 * templates always use '\n', so the assembled script is normalized to the platform line separator.
 * Previously the env lines used the platform separator and the rest used '\n', which mixed LF and
 * CRLF in the generated .bat. The script is always written, even when no env lines were produced.
 *
 * @param file     destination file (overwritten)
 * @param template script template containing {@code ${ENV_LINES}}
 * @param eachFunc appends the lines for one environment variable
 * @throws IOException if the file cannot be written
 */
private static void writeScriptFile(File file, String template, BiConsumer<EnvName, List<String>> eachFunc) throws IOException {
    List<String> envLines = new ArrayList<>();
    for (Field field : PARAM_FIELD) {
        if (field.isAnnotationPresent(EnvName.class)) {
            EnvName annotation = field.getAnnotation(EnvName.class);
            eachFunc.accept(annotation, envLines);
        }
    }
    String script = template.replace("${ENV_LINES}", String.join("\n", envLines)).strip();
    String normalized = script.lines()
            .collect(Collectors.joining(System.lineSeparator(), "", System.lineSeparator()));
    Files.writeString(file.toPath(), normalized, StandardCharsets.UTF_8);
}
/**
 * Generates the Windows launcher (.bat): one {@code set NAME=} line per parameter, then the command
 * that runs this file in {@code -p env} mode.
 * <p>
 * The java and script paths are expanded inside double quotes. Typical locations such as
 * {@code C:\Program Files\...} contain spaces, and unquoted expansion split them into several arguments.
 */
private static void generateWindowsBat() throws IOException {
    File file = createScripeFile("bat");
    String batTemplate = """
            @echo off
            :: java可执行文件路径, 不是JAVA_HOME, 是完整的java可执行文件路径, 例如: D:\\jdk-2108\\bin\\java (无需加引号)
            set JAVA_BIN=
            :: 脚本存放路径, 例如: E:\\JExample\\src\\LLMBenchmarkTester.java (无需加引号)
            set SCRIPT_PATH=
            ${ENV_LINES}
            :: 基于环境变量的方式执行, 交互式命令行执行用这个命令: "%JAVA_BIN%" "%SCRIPT_PATH%" -p input
            "%JAVA_BIN%" "%SCRIPT_PATH%" -p env
            pause
            """;
    writeScriptFile(file, batTemplate, (envName, envLines) -> {
        envLines.add(":: " + envName.desc());
        envLines.add("set " + envName.value() + "=");
    });
}
/**
 * Generates the POSIX launcher (.sh): one {@code export NAME=""} line per parameter, then the command
 * that runs this file in {@code -p env} mode.
 * <p>
 * The variables must be exported. A plain {@code NAME=value} assignment only sets a shell variable,
 * which is not passed to the child {@code java} process, so {@code System.getenv} would see nothing.
 */
private static void generateShellScript() throws IOException {
    File file = createScripeFile("sh");
    String bashTemplate = """
            #!/bin/bash
            # java可执行文件路径, 不是JAVA_HOME, 是完整的java可执行文件路径, 例如: /opt/jdk-2108/bin/java
            JAVA_BIN=""
            # 脚本存放路径, 例如: /home/user/JExample/src/LLMBenchmarkTester.java
            SCRIPT_PATH=""
            ${ENV_LINES}
            # 基于环境变量的方式执行, 交互式命令行执行用这个命令: $JAVA_BIN $SCRIPT_PATH -p input
            "$JAVA_BIN" "$SCRIPT_PATH" -p env
            """;
    writeScriptFile(file, bashTemplate, (envName, envLines) -> {
        envLines.add("# " + envName.desc());
        envLines.add("export " + envName.value() + "=\"\"");
    });
}
/**
 * Chooses how the benchmark parameters are collected.
 * <p>
 * When {@code -p} is present, {@code env} reads them from environment variables and otherwise
 * {@code input} asks for them interactively. Flags are matched case-insensitively anywhere in
 * {@code args}.
 *
 * @throws RuntimeException if no supported mode is given
 */
private static ScriptParameter readScriptParameter(String[] args) throws IllegalAccessException {
    if (args == null || args.length == 0) {
        throw new RuntimeException("命令错误, 请检查参数是否正确");
    }
    boolean hasModeFlag = Arrays.stream(args).anyMatch(arg -> arg.equalsIgnoreCase("-p"));
    if (hasModeFlag) {
        if (Arrays.stream(args).anyMatch(arg -> arg.equalsIgnoreCase("env"))) {
            return initScriptParamFromEnv();
        }
        if (Arrays.stream(args).anyMatch(arg -> arg.equalsIgnoreCase("input"))) {
            return initScriptParamFromAsk();
        }
    }
    throw new RuntimeException("命令错误, 请检查参数是否正确");
}
/**
 * Reads every {@code @EnvName} parameter from the environment, then formats, checks and stores it.
 * <p>
 * The formatted value is what gets validated, so it is also what gets stored. Previously the raw
 * environment value was stored, which meant STRIP and the Chinese-comma conversion never took effect
 * in env mode.
 *
 * @return populated parameters with {@code channel = 1}
 * @throws RuntimeException if a required value is blank or a value fails validation
 */
private static ScriptParameter initScriptParamFromEnv() throws IllegalAccessException {
    ScriptParameter param = new ScriptParameter();
    param.channel = 1;
    for (Field field : PARAM_FIELD) {
        if (!field.isAnnotationPresent(EnvName.class)) {
            continue;
        }
        EnvName envName = field.getAnnotation(EnvName.class);
        String formatValue = formatValue(System.getenv(envName.value()), field);
        if (field.isAnnotationPresent(NotBlank.class) && (formatValue == null || formatValue.isBlank())) {
            throw new RuntimeException(String.format("环境变量[%s]不能为空或空白字符", envName.value()));
        }
        if (!isValidValue(formatValue, field)) {
            throw new RuntimeException(String.format("环境变量[%s]数值不合法, 当前值:[%s]", envName.value(), formatValue));
        }
        field.setAccessible(true);
        field.set(param, formatValue);
    }
    return param;
}
/**
 * Prompts on stdin for every {@code @AskUser} parameter until an acceptable answer is given.
 * <p>
 * An empty answer is accepted only for optional fields and leaves the field unset. A non-empty answer
 * is stored once it passes validation. Previously a non-empty answer to an optional field (model
 * names, folders, ...) matched neither branch and was rejected forever. The Scanner is intentionally
 * not closed, because closing it would close {@code System.in}.
 *
 * @return populated parameters with {@code channel = 2}
 * @throws java.util.NoSuchElementException if stdin is exhausted
 */
private static ScriptParameter initScriptParamFromAsk() throws IllegalAccessException {
    ScriptParameter param = new ScriptParameter();
    param.channel = 2;
    Scanner scanner = new Scanner(System.in);
    for (Field field : PARAM_FIELD) {
        if (!field.isAnnotationPresent(AskUser.class)) {
            continue;
        }
        AskUser askUser = field.getAnnotation(AskUser.class);
        System.out.println(askUser.value() + ":");
        boolean isNotBlank = field.isAnnotationPresent(NotBlank.class);
        for (; ; ) {
            String userInput = scanner.nextLine().trim();
            String formatValue = formatValue(userInput, field);
            boolean blank = formatValue == null || formatValue.isBlank();
            // Optional field left empty: skip it.
            if (blank && !isNotBlank) {
                break;
            }
            // Any non-empty answer that passes validation is accepted, required or not.
            if (!blank && isValidValue(formatValue, field)) {
                field.setAccessible(true);
                field.set(param, formatValue);
                break;
            }
            System.out.print("请重新输入:");
        }
    }
    return param;
}
// Runs the field's validators in declaration order; stops at the first one that fails.
// Fields without @Validator are always valid.
private static Boolean isValidValue(String formatValue, Field field) {
    Validator anno = field.getAnnotation(Validator.class);
    if (anno == null) {
        return Boolean.TRUE;
    }
    for (TextValidator validator : anno.value()) {
        Boolean passed = Constants.TEXT_VALIDATOR.get(validator.name()).apply(formatValue);
        if (!Boolean.TRUE.equals(passed)) {
            return Boolean.FALSE;
        }
    }
    return Boolean.TRUE;
}
// Applies the field's formatters in declaration order, each receiving the previous one's output.
// Fields without @Formatter get the value back unchanged.
private static String formatValue(String fieldValue, Field field) {
    Formatter anno = field.getAnnotation(Formatter.class);
    if (anno == null) {
        return fieldValue;
    }
    String current = fieldValue;
    for (TextFormater formatter : anno.value()) {
        current = Constants.TEXT_FORMATER.get(formatter.name()).apply(current);
    }
    return current;
}
/**
 * Marks a parameter field that is asked for in interactive ({@code -p input}) mode.
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface AskUser {
    /** Prompt printed (followed by a colon) before reading the answer. */
    String value();
}
/**
 * Binds a parameter field to an environment variable (read in {@code -p env} mode and listed in the
 * generated launcher scripts).
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface EnvName {
    /** Environment variable name. */
    String value() default "";

    /** Description written as a comment above the variable in generated scripts. */
    String desc();
}
/**
 * Marks a parameter whose value must not be empty or whitespace-only.
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface NotBlank {
}
/**
 * Lists the validators a parameter value must pass. They are checked in declaration order, and all
 * of them must succeed.
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Validator {
    /** Validators to apply, in order. */
    TextValidator[] value();
}
/**
 * Lists the formatters applied to a parameter value, in declaration order, before it is validated.
 */
@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Formatter {
    /** Formatters to apply, in order. */
    TextFormater[] value();
}
/**
 * Validators that can be attached to a parameter field. Each constant's {@code name()} is the key of
 * its implementation in {@code Constants.TEXT_VALIDATOR}.
 */
public enum TextValidator {
    /** Must start with http:// or https://. */
    MUST_URL,
    /** When non-blank, must name an existing directory. */
    MUST_FOLDER,
    /** When non-blank, must name an existing regular file ending in .txt. */
    MUST_TXT,
    /** When non-blank, must parse as an int. */
    MUST_NUM
}
/**
 * Formatters that can be attached to a parameter field. Each constant's {@code name()} is the key of
 * its implementation in {@code Constants.TEXT_FORMATER}.
 */
public enum TextFormater {
    /** Strip surrounding whitespace; null becomes "". */
    STRIP,
    /** Intended to turn Chinese (full-width) commas into ASCII commas; null becomes "". */
    COMMA_CN_2_EN
}
/**
 * Implementations of the text formatters and validators referenced by {@code @Formatter} and
 * {@code @Validator}, registered under their enum constant names.
 */
private static class Constants {
    // Strips leading/trailing whitespace (Unicode-aware String.strip); null becomes "".
    public static final Function<String, String> STRIP_FORMATTER =
            str -> Optional.ofNullable(str).map(String::strip).orElse("");
    // Replaces every full-width Chinese comma (U+FF0C) with an ASCII comma; null becomes "".
    // The character is written as a Unicode escape so it cannot silently vanish from the source. An
    // empty pattern here (replaceAll("", ",")) inserts a comma between every character, e.g. "10" -> ",1,0,".
    public static final Function<String, String> COMMA_CN_2_EN =
            str -> Optional.ofNullable(str).map(d -> d.replace('\uFF0C', ',')).orElse("");
    // Must be an http(s) URL (prefix check only).
    public static final Function<String, Boolean> URL_VALIDATOR =
            str -> str != null && (str.startsWith("http://") || str.startsWith("https://"));
    // When non-blank, must name an existing directory; blank/null is accepted (optional field).
    public static final Function<String, Boolean> FOLDER_VALIDATOR = str -> {
        if (str != null && !str.isBlank()) {
            try {
                File file = new File(str);
                return file.exists() && file.isDirectory();
            } catch (Exception e) {
                return false;
            }
        }
        return true;
    };
    // When non-blank, must name an existing regular file ending in ".txt"; blank/null is accepted.
    public static final Function<String, Boolean> TXT_FILE_VALIDATOR = str -> {
        if (str != null && !str.isBlank()) {
            try {
                File file = new File(str);
                return file.exists() && file.isFile() && file.getName().endsWith(".txt");
            } catch (Exception e) {
                return false;
            }
        }
        return true;
    };
    // When non-blank, must parse with Integer.parseInt; blank/null is accepted.
    public static final Function<String, Boolean> NUMBER_VALIDATOR = str -> {
        if (str != null && !str.isBlank()) {
            try {
                Integer.parseInt(str);
            } catch (Exception e) {
                return false;
            }
        }
        return true;
    };
    // Formatter registry keyed by TextFormater name.
    public static final Map<String, Function<String, String>> TEXT_FORMATER =
            Map.of(
                    TextFormater.STRIP.name(), Constants.STRIP_FORMATTER,
                    TextFormater.COMMA_CN_2_EN.name(), Constants.COMMA_CN_2_EN
            );
    // Validator registry keyed by TextValidator name.
    public static final Map<String, Function<String, Boolean>> TEXT_VALIDATOR =
            Map.of(
                    TextValidator.MUST_URL.name(), URL_VALIDATOR,
                    TextValidator.MUST_FOLDER.name(), FOLDER_VALIDATOR,
                    TextValidator.MUST_TXT.name(), TXT_FILE_VALIDATOR,
                    TextValidator.MUST_NUM.name(), NUMBER_VALIDATOR
            );
}
/**
 * Holder for all benchmark settings. Fields are discovered by reflection, and their annotations drive
 * how each value is read (env var / prompt), formatted, validated, printed and scripted.
 * <p>
 * {@code chatDatasetsPath} now carries the STRIP formatter like every other path/text field, so a
 * stray space around the path no longer makes the .txt check fail.
 */
public static class ScriptParameter {
    // Where the values came from: 1 = environment variables, 2 = interactive prompt.
    int channel;
    @NotBlank
    @EnvName(value = "BENCH_LLM_API_HOST", desc = "OpenAI API 的访问地址, 例如: http://localhost:8080/v1/chat/completions")
    @Validator(value = TextValidator.MUST_URL)
    @Formatter(value = TextFormater.STRIP)
    @AskUser(value = "请输入 OpenAI API 的访问地址 (例如: http://localhost:8080/v1/chat/completions)")
    String openAiApiHost;
    @NotBlank
    @EnvName(value = "BENCH_LLM_API_KEY", desc = "ApiKey或者叫API令牌")
    @Formatter(value = TextFormater.STRIP)
    @AskUser(value = "请输入ApiKey或者叫API令牌")
    String apiKey;
    @NotBlank
    @EnvName(value = "BENCH_THREAD_SIZE_ARRAY", desc = "请输入线程池配置, 示例值: 10,50,100")
    @Formatter(value = {TextFormater.STRIP, TextFormater.COMMA_CN_2_EN})
    @AskUser(value = "请输入线程池配置 (示例值: 10,50,100)")
    String threadSize;
    @EnvName(value = "BENCH_LLM_MODEL_NAME", desc = "文本模型名称, 多个使用英文逗号隔开, 如果不测试文生文模型可以不设置, 示例值: qwen2.5,qwen3")
    @Formatter(value = {TextFormater.STRIP, TextFormater.COMMA_CN_2_EN})
    @AskUser(value = "请输入文本模型名称, 多个使用英文逗号隔开, 如果不测试文生文模型可以直接回车 (示例值: qwen2.5,qwen3)")
    String modelName;
    @EnvName(value = "BENCH_LLM_VL_MODEL_NAME", desc = "VL模型名称, 多个用英文逗号隔开, 如果不测试VL模型可以不设置")
    @Formatter(value = {TextFormater.STRIP, TextFormater.COMMA_CN_2_EN})
    @AskUser(value = "请输入VL模型名称, 多个用英文逗号隔开, 如果不测试VL模型可以直接回车")
    String vlModelName;
    @EnvName(value = "BENCH_LLM_VL_IMG_FOLDER", desc = "调用VL模型的图片存储目录, 如果不测试VL模型可以不设置, 示例值: /home/image")
    @Validator(value = TextValidator.MUST_FOLDER)
    @Formatter(value = {TextFormater.STRIP})
    @AskUser(value = "请输入调用VL模型的图片存储目录, 如果不测试VL模型可以直接回车 (示例值: /home/image)")
    String vlImgFolder;
    @EnvName(value = "BENCH_LLM_CHAT_MODEL_DATASETS", desc = "文生文测试数据集的文件路径, 如果不测试文生文模型可以不设置, 必须是一个.txt文件 (示例值: /home/datasets.txt)")
    @Validator(value = TextValidator.MUST_TXT)
    @Formatter(value = TextFormater.STRIP)
    @AskUser(value = "请输入文生文测试数据集的文件路径, 必须是一个.txt文件, 如果不测试文生文模型可以直接回车 (示例值: /home/datasets.txt)")
    String chatDatasetsPath;
    @EnvName(value = "BENCH_LLM_VL_IMG_SIZE", desc = "调用VL模型的测试图片数量, 如果文件夹下的图片数量不够, 会复制直到到足够数量, 如果不测试VL模型可以不设置 (示例值: 300)")
    @Validator(value = TextValidator.MUST_NUM)
    @Formatter(value = TextFormater.STRIP)
    @AskUser("请输入调用VL模型的测试图片数量, 如果文件夹下的图片数量不够, 会复制直到到足够数量, 如果不测试VL模型可以直接回车 (示例值: 300)")
    String imgSize;

    /** True when both chat model names and a dataset path are configured. */
    public boolean isTestChatModel() {
        return this.modelName != null && !this.modelName.isBlank()
                && this.chatDatasetsPath != null && !this.chatDatasetsPath.isBlank();
    }

    /** True when VL model names, an image folder and an image count are all configured. */
    public boolean isTestVlModel() {
        return this.vlModelName != null && !this.vlModelName.isBlank()
                && this.vlImgFolder != null && !this.vlImgFolder.isBlank()
                && this.imgSize != null && !this.imgSize.isBlank();
    }
}
}