@@ -98,24 +98,9 @@ public static void createCocoDirectoryStructure(String baseDir, String type) thr
98
98
Files .createDirectories (Paths .get (baseDir + "annotations" ));
99
99
Files .createDirectories (Paths .get (baseDir + "images" ));
100
100
101
- Files .createDirectories (Paths .get (baseDir + "train" ));
102
- Files .createDirectories (Paths .get (baseDir + "val" ));
103
- Files .createDirectories (Paths .get (baseDir + "test" ));
104
-
105
- Files .createDirectories (Paths .get (baseDir , "annotations" , "train" ));
106
- Files .createDirectories (Paths .get (baseDir , "annotations" , "val" ));
107
- Files .createDirectories (Paths .get (baseDir , "annotations" , "test" ));
108
-
109
- Files .createDirectories (Paths .get (baseDir , "images" , "train" ));
110
- Files .createDirectories (Paths .get (baseDir , "images" , "val" ));
111
- Files .createDirectories (Paths .get (baseDir , "images" , "test" ));
112
-
113
101
// 根据类型创建特定目录 detection, classification, segmentation, keypoints, face_keypoints 使用标准结构
114
102
if (TaskType .OCR .getType ().equals (type ) || TaskType .ROTATED_DETECTION .getType ().equals (type )) {
115
- Files .createDirectories (Paths .get (baseDir , "labels" ));
116
- Files .createDirectories (Paths .get (baseDir , "labels" , "train" ));
117
- Files .createDirectories (Paths .get (baseDir , "labels" , "val" ));
118
- Files .createDirectories (Paths .get (baseDir , "labels" , "test" ));
103
+ Files .createDirectories (Paths .get (baseDir + "labels" ));
119
104
}
120
105
}
121
106
@@ -346,18 +331,52 @@ public static void writeToFile(CocoDataset dataset, String outputPath) throws IO
346
331
347
332
348
333
/**
349
- * 复制图片文件到指定目录,支持URL和base64两种格式
334
+ * 复制图片文件到指定目录,按类别分目录存放, 支持URL和base64两种格式
350
335
* @param images 图片信息列表
351
- * @param imageDir 目标图片目录
336
+ * @param imageDir 目标图片根目录
337
+ * @param categories 类别列表
338
+ * @param annotations 标注列表
352
339
* @throws IOException
353
340
*/
354
- public static void copyImagesToDirectory (List <ImageInfo > images , String imageDir ) throws IOException {
355
- // 创建图片目录
341
+ public static void copyImagesToDirectory (
342
+ List <ImageInfo > images , String imageDir , List <Category > categories , List <Annotation > annotations
343
+ ) throws IOException {
344
+ // 创建图片根目录
356
345
Path imageDirPath = Paths .get (imageDir );
357
346
if (!Files .exists (imageDirPath )) {
358
347
Files .createDirectories (imageDirPath );
359
348
}
360
349
350
+ // 创建类别ID到名称的映射
351
+ Map <Integer , String > categoryIdToName = new HashMap <>();
352
+ if (categories != null ) {
353
+ for (Category category : categories ) {
354
+ categoryIdToName .put (category .getId (), category .getName ());
355
+ }
356
+ }
357
+
358
+ // 创建类别目录
359
+ for (String categoryName : categoryIdToName .values ()) {
360
+ Path categoryDir = Paths .get (imageDir , categoryName );
361
+ if (!Files .exists (categoryDir )) {
362
+ Files .createDirectories (categoryDir );
363
+ }
364
+ }
365
+
366
+ // 创建图片ID到标注类别的映射
367
+ Map <Integer , Set <String >> imageIdToCategories = new HashMap <>();
368
+ if (annotations != null ) {
369
+ for (Annotation annotation : annotations ) {
370
+ int imageId = annotation .getImage_id ();
371
+ int categoryId = annotation .getCategory_id ();
372
+ String categoryName = categoryIdToName .get (categoryId );
373
+
374
+ if (categoryName != null ) {
375
+ imageIdToCategories .computeIfAbsent (imageId , k -> new HashSet <>()).add (categoryName );
376
+ }
377
+ }
378
+ }
379
+
361
380
for (ImageInfo image : images ) {
362
381
String imgSource = image .getImg ();
363
382
String fileName = image .getFile_name ();
@@ -367,27 +386,54 @@ public static void copyImagesToDirectory(List<ImageInfo> images, String imageDir
367
386
continue ;
368
387
}
369
388
370
- Path targetPath = Paths .get (imageDir , fileName );
371
-
372
- try {
373
- if (imgSource .startsWith ("data:image/" )) {
374
- // 处理base64格式
375
- copyBase64Image (imgSource , targetPath );
376
- } else if (imgSource .startsWith ("http://" ) || imgSource .startsWith ("https://" )) {
377
- // 处理URL格式
378
- copyUrlImage (imgSource , targetPath );
379
- } else {
380
- // 处理本地文件路径
381
- copyLocalImage (imgSource , targetPath );
382
- }
389
+ // 获取这张图片的所有相关类别
390
+ Set <String > imageCategories = imageIdToCategories .get (image .getId ());
391
+ if (imageCategories == null || imageCategories .isEmpty ()) {
392
+ // 如果没有标注信息,放到一个默认目录
393
+ Path targetPath = Paths .get (imageDir , fileName );
394
+ Files .createDirectories (targetPath .getParent ());
383
395
384
- System .out .println ("Successfully copied image: " + fileName );
385
- } catch (Exception e ) {
386
- System .err .println ("Failed to copy image " + fileName + ": " + e .getMessage ());
396
+ try {
397
+ copyImageToPath (imgSource , targetPath );
398
+ System .out .println ("Successfully copied unlabeled image: " + fileName );
399
+ } catch (Exception e ) {
400
+ System .err .println ("Failed to copy image " + fileName + ": " + e .getMessage ());
401
+ }
402
+ } else {
403
+ // 将图片复制到所有相关类别的目录
404
+ for (String categoryName : imageCategories ) {
405
+ Path targetPath = Paths .get (imageDir , categoryName , fileName );
406
+
407
+ try {
408
+ copyImageToPath (imgSource , targetPath );
409
+ System .out .println ("Successfully copied image " + fileName + " to category: " + categoryName );
410
+ } catch (Exception e ) {
411
+ System .err .println ("Failed to copy image " + fileName + " to category " + categoryName + ": " + e .getMessage ());
412
+ }
413
+ }
387
414
}
388
415
}
389
416
}
390
417
418
+ /**
419
+ * 复制图片到指定路径的辅助方法
420
+ */
421
+ private static void copyImageToPath (String imgSource , Path targetPath ) throws IOException {
422
+ // 确保目标目录存在
423
+ Files .createDirectories (targetPath .getParent ());
424
+
425
+ if (imgSource .startsWith ("data:image/" )) {
426
+ // 处理base64格式
427
+ copyBase64Image (imgSource , targetPath );
428
+ } else if (imgSource .startsWith ("http://" ) || imgSource .startsWith ("https://" )) {
429
+ // 处理URL格式
430
+ copyUrlImage (imgSource , targetPath );
431
+ } else {
432
+ // 处理本地文件路径
433
+ copyLocalImage (imgSource , targetPath );
434
+ }
435
+ }
436
+
391
437
/**
392
438
* 从base64数据复制图片
393
439
*/
@@ -520,7 +566,8 @@ public static void generate(String outputDir, Set<TaskType> tasks) throws IOExce
520
566
writeToFile (cocoDataset , outputJsonPath );
521
567
522
568
// 复制图片文件到指定目录 outputDir/images/
523
- copyImagesToDirectory (cocoDataset .getImages (), outputDir + "/images/" );
569
+ copyImagesToDirectory (cocoDataset .getImages (), outputDir + "/images/" , null , null );
570
+ copyImagesToDirectory (cocoDataset .getImages (), outputDir + "/images/" , cocoDataset .getCategories (), cocoDataset .getAnnotations ());
524
571
525
572
System .out .println ("Successfully generated dataset at: " + outputDir );
526
573
}
@@ -647,7 +694,10 @@ public static void generate(List<JSONObject> data, Set<TaskType> tasks, String o
647
694
writeToFile (cocoDataset , Paths .get (outputDir , taskName + ".json" ).toString ());
648
695
649
696
// 复制图片文件到指定目录 outputDir/images/ train/val
650
- copyImagesToDirectory (cocoDataset .getImages (), outputDir + taskName + "/" );
697
+ copyImagesToDirectory (cocoDataset .getImages (), outputDir + taskName + File .separator , cocoDataset .getCategories (), cocoDataset .getAnnotations ());
698
+ //if (Log.DEBUG || tasks.contains(TaskType.CLASSIFICATION)) {
699
+ copyImagesToDirectory (cocoDataset .getImages (), outputDir + "images" + File .separator + taskName + File .separator , null , null );
700
+ //}
651
701
652
702
System .out .println ("Successfully generated dataset from JSONObject data at: " + outputDir );
653
703
}
0 commit comments