Skip to content

Commit 2f85990

Browse files
committed
Java: MultiDataSource 导出 CVAuto 数据集在 train, val 里按 category 分目录存放图片,在 images 里不分
1 parent 59522c8 commit 2f85990

File tree

1 file changed

+88
-38
lines changed
  • APIJSON-Java-Server/APIJSONBoot-MultiDataSource/src/main/java/apijson

1 file changed

+88
-38
lines changed

APIJSON-Java-Server/APIJSONBoot-MultiDataSource/src/main/java/apijson/DatasetUtil.java

Lines changed: 88 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -98,24 +98,9 @@ public static void createCocoDirectoryStructure(String baseDir, String type) thr
9898
Files.createDirectories(Paths.get(baseDir + "annotations"));
9999
Files.createDirectories(Paths.get(baseDir + "images"));
100100

101-
Files.createDirectories(Paths.get(baseDir + "train"));
102-
Files.createDirectories(Paths.get(baseDir + "val"));
103-
Files.createDirectories(Paths.get(baseDir + "test"));
104-
105-
Files.createDirectories(Paths.get(baseDir, "annotations", "train"));
106-
Files.createDirectories(Paths.get(baseDir, "annotations", "val"));
107-
Files.createDirectories(Paths.get(baseDir, "annotations", "test"));
108-
109-
Files.createDirectories(Paths.get(baseDir, "images", "train"));
110-
Files.createDirectories(Paths.get(baseDir, "images", "val"));
111-
Files.createDirectories(Paths.get(baseDir, "images", "test"));
112-
113101
// 根据类型创建特定目录 detection, classification, segmentation, keypoints, face_keypoints 使用标准结构
114102
if (TaskType.OCR.getType().equals(type) || TaskType.ROTATED_DETECTION.getType().equals(type)) {
115-
Files.createDirectories(Paths.get(baseDir, "labels"));
116-
Files.createDirectories(Paths.get(baseDir, "labels", "train"));
117-
Files.createDirectories(Paths.get(baseDir, "labels", "val"));
118-
Files.createDirectories(Paths.get(baseDir, "labels", "test"));
103+
Files.createDirectories(Paths.get(baseDir + "labels"));
119104
}
120105
}
121106

@@ -346,18 +331,52 @@ public static void writeToFile(CocoDataset dataset, String outputPath) throws IO
346331

347332

348333
/**
349-
* 复制图片文件到指定目录,支持URL和base64两种格式
334+
* 复制图片文件到指定目录,按类别分目录存放,支持URL和base64两种格式
350335
* @param images 图片信息列表
351-
* @param imageDir 目标图片目录
336+
* @param imageDir 目标图片根目录
337+
* @param categories 类别列表
338+
* @param annotations 标注列表
352339
* @throws IOException
353340
*/
354-
public static void copyImagesToDirectory(List<ImageInfo> images, String imageDir) throws IOException {
355-
// 创建图片目录
341+
public static void copyImagesToDirectory(
342+
List<ImageInfo> images, String imageDir, List<Category> categories, List<Annotation> annotations
343+
) throws IOException {
344+
// 创建图片根目录
356345
Path imageDirPath = Paths.get(imageDir);
357346
if (!Files.exists(imageDirPath)) {
358347
Files.createDirectories(imageDirPath);
359348
}
360349

350+
// 创建类别ID到名称的映射
351+
Map<Integer, String> categoryIdToName = new HashMap<>();
352+
if (categories != null) {
353+
for (Category category : categories) {
354+
categoryIdToName.put(category.getId(), category.getName());
355+
}
356+
}
357+
358+
// 创建类别目录
359+
for (String categoryName : categoryIdToName.values()) {
360+
Path categoryDir = Paths.get(imageDir, categoryName);
361+
if (!Files.exists(categoryDir)) {
362+
Files.createDirectories(categoryDir);
363+
}
364+
}
365+
366+
// 创建图片ID到标注类别的映射
367+
Map<Integer, Set<String>> imageIdToCategories = new HashMap<>();
368+
if (annotations != null) {
369+
for (Annotation annotation : annotations) {
370+
int imageId = annotation.getImage_id();
371+
int categoryId = annotation.getCategory_id();
372+
String categoryName = categoryIdToName.get(categoryId);
373+
374+
if (categoryName != null) {
375+
imageIdToCategories.computeIfAbsent(imageId, k -> new HashSet<>()).add(categoryName);
376+
}
377+
}
378+
}
379+
361380
for (ImageInfo image : images) {
362381
String imgSource = image.getImg();
363382
String fileName = image.getFile_name();
@@ -367,27 +386,54 @@ public static void copyImagesToDirectory(List<ImageInfo> images, String imageDir
367386
continue;
368387
}
369388

370-
Path targetPath = Paths.get(imageDir, fileName);
371-
372-
try {
373-
if (imgSource.startsWith("data:image/")) {
374-
// 处理base64格式
375-
copyBase64Image(imgSource, targetPath);
376-
} else if (imgSource.startsWith("http://") || imgSource.startsWith("https://")) {
377-
// 处理URL格式
378-
copyUrlImage(imgSource, targetPath);
379-
} else {
380-
// 处理本地文件路径
381-
copyLocalImage(imgSource, targetPath);
382-
}
389+
// 获取这张图片的所有相关类别
390+
Set<String> imageCategories = imageIdToCategories.get(image.getId());
391+
if (imageCategories == null || imageCategories.isEmpty()) {
392+
// 如果没有标注信息,放到一个默认目录
393+
Path targetPath = Paths.get(imageDir, fileName);
394+
Files.createDirectories(targetPath.getParent());
383395

384-
System.out.println("Successfully copied image: " + fileName);
385-
} catch (Exception e) {
386-
System.err.println("Failed to copy image " + fileName + ": " + e.getMessage());
396+
try {
397+
copyImageToPath(imgSource, targetPath);
398+
System.out.println("Successfully copied unlabeled image: " + fileName);
399+
} catch (Exception e) {
400+
System.err.println("Failed to copy image " + fileName + ": " + e.getMessage());
401+
}
402+
} else {
403+
// 将图片复制到所有相关类别的目录
404+
for (String categoryName : imageCategories) {
405+
Path targetPath = Paths.get(imageDir, categoryName, fileName);
406+
407+
try {
408+
copyImageToPath(imgSource, targetPath);
409+
System.out.println("Successfully copied image " + fileName + " to category: " + categoryName);
410+
} catch (Exception e) {
411+
System.err.println("Failed to copy image " + fileName + " to category " + categoryName + ": " + e.getMessage());
412+
}
413+
}
387414
}
388415
}
389416
}
390417

418+
/**
419+
* 复制图片到指定路径的辅助方法
420+
*/
421+
private static void copyImageToPath(String imgSource, Path targetPath) throws IOException {
422+
// 确保目标目录存在
423+
Files.createDirectories(targetPath.getParent());
424+
425+
if (imgSource.startsWith("data:image/")) {
426+
// 处理base64格式
427+
copyBase64Image(imgSource, targetPath);
428+
} else if (imgSource.startsWith("http://") || imgSource.startsWith("https://")) {
429+
// 处理URL格式
430+
copyUrlImage(imgSource, targetPath);
431+
} else {
432+
// 处理本地文件路径
433+
copyLocalImage(imgSource, targetPath);
434+
}
435+
}
436+
391437
/**
392438
* 从base64数据复制图片
393439
*/
@@ -520,7 +566,8 @@ public static void generate(String outputDir, Set<TaskType> tasks) throws IOExce
520566
writeToFile(cocoDataset, outputJsonPath);
521567

522568
// 复制图片文件到指定目录 outputDir/images/
523-
copyImagesToDirectory(cocoDataset.getImages(), outputDir + "/images/");
569+
copyImagesToDirectory(cocoDataset.getImages(), outputDir + "/images/", null, null);
570+
copyImagesToDirectory(cocoDataset.getImages(), outputDir + "/images/", cocoDataset.getCategories(), cocoDataset.getAnnotations());
524571

525572
System.out.println("Successfully generated dataset at: " + outputDir);
526573
}
@@ -647,7 +694,10 @@ public static void generate(List<JSONObject> data, Set<TaskType> tasks, String o
647694
writeToFile(cocoDataset, Paths.get(outputDir, taskName + ".json").toString());
648695

649696
// 复制图片文件到指定目录 outputDir/images/ train/val
650-
copyImagesToDirectory(cocoDataset.getImages(), outputDir + taskName + "/");
697+
copyImagesToDirectory(cocoDataset.getImages(), outputDir + taskName + File.separator, cocoDataset.getCategories(), cocoDataset.getAnnotations());
698+
//if (Log.DEBUG || tasks.contains(TaskType.CLASSIFICATION)) {
699+
copyImagesToDirectory(cocoDataset.getImages(), outputDir + "images" + File.separator + taskName + File.separator, null, null);
700+
//}
651701

652702
System.out.println("Successfully generated dataset from JSONObject data at: " + outputDir);
653703
}

0 commit comments

Comments
 (0)