diff --git a/app.py b/app.py index fc61c09..9654270 100644 --- a/app.py +++ b/app.py @@ -81,8 +81,8 @@ "faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl" ) -# Gradio example row: (image_rel, tta_lr, tta_iters, det_thresh, kp_thresh, side_view, save_mesh) -ExampleRow = Tuple[str, float, int, float, float, bool, bool] +# Gradio example row: (image_rel, tta_lr, tta_iters, det_thresh, kp_thresh, side_view, render_depth, save_mesh) +ExampleRow = Tuple[str, float, int, float, float, bool, bool, bool] @dataclass(frozen=True) @@ -98,6 +98,7 @@ class DemoProfile: max_tta_iters: int default_save_mesh: bool default_side_view: bool + default_render_depth: bool preload_assets: bool example_rows: Tuple[ExampleRow, ...] description: str @@ -124,13 +125,14 @@ def resolve_detectron_device(self) -> str: max_tta_iters=100, default_save_mesh=True, default_side_view=False, + default_render_depth=False, preload_assets=False, example_rows=( - ("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, True), - ("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, True), - ("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, True), - ("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, True), - ("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, True), + ("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, False, True), + ("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, False, True), + ("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, False, True), + ("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, False, True), + ("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, False, True), ), description=( "**Local demo** — full pipeline on your machine (GPU when available).\n\n" @@ -153,13 +155,14 @@ def resolve_detectron_device(self) -> str: max_tta_iters=30, default_save_mesh=False, default_side_view=False, + default_render_depth=False, preload_assets=True, example_rows=( - ("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, False), - ("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, False), - ("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, False), - ("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, False), - ("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, False), + ("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, False, False), + ("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, False, False), + ("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, False, False), + ("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, False, False), + ("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, False, False), ), description=( "**Hugging Face Space (cpu-basic)** — lightweight demo: **CPU-only** PRIMA inference. " @@ -447,6 +450,7 @@ def _collect_animal_results( det_thresh: float, kp_conf_thresh: float, side_view: bool, + render_depth: bool, save_mesh: bool, boxes: Optional[np.ndarray] = None, progress_callback: Optional[Callable[[str], None]] = None, @@ -526,6 +530,7 @@ def report(message: str) -> None: suffix="before_tta", side_view=side_view, save_mesh=save_mesh, + render_depth=render_depth, ) before_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.png") @@ -552,6 +557,7 @@ def report(message: str) -> None: suffix="after_tta", side_view=side_view, save_mesh=save_mesh, + render_depth=render_depth, ) after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png") @@ -619,6 +625,7 @@ def report(message: str) -> None: suffix="after_tta", side_view=side_view, save_mesh=save_mesh, + render_depth=render_depth, ) after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png") @@ -670,6 +677,7 @@ def gradio_inference( det_thresh: float, kp_conf_thresh: float, side_view: bool, + render_depth: bool, save_mesh: bool, ): """Wrapper for Gradio. ``image`` is an RGB numpy array. @@ -756,6 +764,7 @@ def run_collect(progress_callback: Optional[Callable[[str], None]] = None): det_thresh=det_thresh, kp_conf_thresh=kp_conf_thresh, side_view=side_view, + render_depth=render_depth, save_mesh=save_mesh, boxes=boxes, progress_callback=progress_callback, @@ -855,6 +864,7 @@ def report_stage(message: str) -> None: step=0.05, ), gr.Checkbox(label="Render side view", value=profile.default_side_view), + gr.Checkbox(label="Render depth map", value=profile.default_render_depth), gr.Checkbox(label="Save meshes (.obj)", value=profile.default_save_mesh), ], outputs=[ diff --git a/demo.py b/demo.py index c5ac815..bdf16db 100644 --- a/demo.py +++ b/demo.py @@ -59,6 +59,8 @@ def main(): parser.add_argument('--out_folder', type=str, default='demo_out', help='Output folder to save rendered results') parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='If set, render side view also') + parser.add_argument('--render_depth', dest='render_depth', action='store_true', default=False, + help='If set, render depth map also') parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='If set, save meshes to disk also') parser.add_argument('--batch_size', type=int, default=1, help='Batch size for inference/fitting') @@ -162,6 +164,35 @@ def main(): scene_bg_color=(1, 1, 1), side_view=True) final_img = np.concatenate([final_img, side_img], axis=1) + + if args.render_depth: + depth_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(), + out['pred_cam_t'][n].detach().cpu().numpy(), + white_img, + mesh_base_color=GREEN, + scene_bg_color=(1, 1, 1), + depth=True) + + valid_mask = depth_img > 0 + if np.sum(valid_mask) == 0: + # no valid depth + depth_norm = np.zeros_like(depth_img) + else: + min_val = np.min(depth_img[valid_mask]) + max_val = np.max(depth_img[valid_mask]) + if min_val == max_val: + depth_norm = np.zeros_like(depth_img) + else: + depth_norm = (depth_img - min_val) / (max_val - min_val + 1e-8) + depth_norm[~valid_mask] = 0 + + depth_vis = (depth_norm * 255).astype(np.uint8) + depth_vis = cv2.applyColorMap(depth_vis, cv2.COLORMAP_VIRIDIS) + depth_vis = cv2.cvtColor(depth_vis, cv2.COLOR_BGR2RGB) + depth_vis = depth_vis.astype(np.float32) / 255.0 + depth_vis[~valid_mask] = 0 + final_img = np.concatenate([final_img, depth_vis], axis=1) + cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}_{animal_id}.png'), cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR)) diff --git a/demo_tta.py b/demo_tta.py index c8257c2..7c713bf 100644 --- a/demo_tta.py +++ b/demo_tta.py @@ -92,6 +92,27 @@ def save_keypoint_vis(patch_rgb: np.ndarray, kpts_xyc: np.ndarray, save_path: st cv2.imwrite(save_path, vis) +def depth_to_viridis_rgb(depth_img: np.ndarray) -> np.ndarray: + valid_mask = depth_img > 0 + if np.sum(valid_mask) == 0: + depth_norm = np.zeros_like(depth_img) + else: + min_val = np.min(depth_img[valid_mask]) + max_val = np.max(depth_img[valid_mask]) + if min_val == max_val: + depth_norm = np.zeros_like(depth_img) + else: + depth_norm = (depth_img - min_val) / (max_val - min_val + 1e-8) + depth_norm[~valid_mask] = 0 + + depth_vis = (depth_norm * 255).astype(np.uint8) + depth_vis = cv2.applyColorMap(depth_vis, cv2.COLORMAP_VIRIDIS) + depth_vis = cv2.cvtColor(depth_vis, cv2.COLOR_BGR2RGB) + depth_vis = depth_vis.astype(np.float32) / 255.0 + depth_vis[~valid_mask] = 0 + return depth_vis + + def resolve_sa_weights_path(local_path: str) -> str: """Return a local path to the fine-tuned SuperAnimal .pt snapshot. @@ -162,7 +183,19 @@ def run_superanimal_on_patch(patch_rgb: np.ndarray, args, tmp_dir: str): return bodyparts[best_idx].astype(np.float32) -def render_and_save(renderer, cam_crop_to_full_fn, out, batch, img_fn, animal_id, out_folder, suffix, side_view, save_mesh): +def render_and_save( + renderer, + cam_crop_to_full_fn, + out, + batch, + img_fn, + animal_id, + out_folder, + suffix, + side_view, + save_mesh, + render_depth=False, +): pred_cam = out['pred_cam'] box_center = batch['box_center'].float() box_size = batch['box_size'].float() @@ -195,6 +228,17 @@ def render_and_save(renderer, cam_crop_to_full_fn, out, batch, img_fn, animal_id ) final_img = np.concatenate([final_img, side_img], axis=1) + if render_depth: + depth_img = renderer( + out['pred_vertices'][0].detach().cpu().numpy(), + out['pred_cam_t'][0].detach().cpu().numpy(), + white_img, + mesh_base_color=GREEN, + scene_bg_color=(1, 1, 1), + depth=True, + ) + final_img = np.concatenate([final_img, depth_to_viridis_rgb(depth_img)], axis=1) + cv2.imwrite( os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.png'), cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR), @@ -252,6 +296,7 @@ def main(): parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images') parser.add_argument('--out_folder', type=str, default='demo_out_tta', help='Output folder') parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='Render side view') + parser.add_argument('--render_depth', dest='render_depth', action='store_true', default=False, help='Render depth map') parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='Save meshes') parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'], help='Image globs') parser.add_argument('--det_thresh', type=float, default=0.7, help='Detectron2 score threshold for animals') @@ -348,6 +393,7 @@ def main(): suffix='before_tta', side_view=args.side_view, save_mesh=args.save_mesh, + render_depth=args.render_depth, ) patch_rgb = denorm_patch_to_rgb(batch['img'][0]) @@ -392,6 +438,7 @@ def main(): suffix='after_tta', side_view=args.side_view, save_mesh=args.save_mesh, + render_depth=args.render_depth, ) diff --git a/demo_tta.sh b/demo_tta.sh index 58eef2d..6249a2d 100644 --- a/demo_tta.sh +++ b/demo_tta.sh @@ -12,4 +12,5 @@ python3 demo_tta.py \ --img_folder demo_data/ \ --out_folder demo_out_tta/ \ --tta_lr 1e-6 \ - --tta_num_iters 30 + --tta_num_iters 30 \ + --render_depth