Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 22 additions & 12 deletions app.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@
"faster_rcnn_X_101_32x8d_FPN_3x/139173657/model_final_68b088.pkl"
)

# Gradio example row: (image_rel, tta_lr, tta_iters, det_thresh, kp_thresh, side_view, save_mesh)
ExampleRow = Tuple[str, float, int, float, float, bool, bool]
# Gradio example row: (image_rel, tta_lr, tta_iters, det_thresh, kp_thresh, side_view, render_depth, save_mesh)
ExampleRow = Tuple[str, float, int, float, float, bool, bool, bool]


@dataclass(frozen=True)
Expand All @@ -98,6 +98,7 @@ class DemoProfile:
max_tta_iters: int
default_save_mesh: bool
default_side_view: bool
default_render_depth: bool
preload_assets: bool
example_rows: Tuple[ExampleRow, ...]
description: str
Expand All @@ -124,13 +125,14 @@ def resolve_detectron_device(self) -> str:
max_tta_iters=100,
default_save_mesh=True,
default_side_view=False,
default_render_depth=False,
preload_assets=False,
example_rows=(
("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, True),
("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, True),
("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, True),
("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, True),
("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, True),
("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, False, True),
("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, False, True),
("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, False, True),
("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, False, True),
("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, False, True),
),
description=(
"**Local demo** — full pipeline on your machine (GPU when available).\n\n"
Expand All @@ -153,13 +155,14 @@ def resolve_detectron_device(self) -> str:
max_tta_iters=30,
default_save_mesh=False,
default_side_view=False,
default_render_depth=False,
preload_assets=True,
example_rows=(
("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, False),
("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, False),
("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, False),
("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, False),
("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, False),
("demo_data/000000015956_horse.png", 1e-6, 30, 0.7, 0.1, False, False, False),
("demo_data/n02412080_12159.png", 1e-6, 30, 0.7, 0.1, False, False, False),
("demo_data/000000315905_zebra.jpg", 1e-6, 30, 0.7, 0.1, False, False, False),
("demo_data/beagle.jpg", 1e-6, 30, 0.7, 0.1, False, False, False),
("demo_data/shepherd_hati.jpg", 1e-6, 30, 0.7, 0.1, False, False, False),
),
description=(
"**Hugging Face Space (cpu-basic)** — lightweight demo: **CPU-only** PRIMA inference. "
Expand Down Expand Up @@ -447,6 +450,7 @@ def _collect_animal_results(
det_thresh: float,
kp_conf_thresh: float,
side_view: bool,
render_depth: bool,
save_mesh: bool,
boxes: Optional[np.ndarray] = None,
progress_callback: Optional[Callable[[str], None]] = None,
Expand Down Expand Up @@ -526,6 +530,7 @@ def report(message: str) -> None:
suffix="before_tta",
side_view=side_view,
save_mesh=save_mesh,
render_depth=render_depth,
)

before_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_before_tta.png")
Expand All @@ -552,6 +557,7 @@ def report(message: str) -> None:
suffix="after_tta",
side_view=side_view,
save_mesh=save_mesh,
render_depth=render_depth,
)

after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
Expand Down Expand Up @@ -619,6 +625,7 @@ def report(message: str) -> None:
suffix="after_tta",
side_view=side_view,
save_mesh=save_mesh,
render_depth=render_depth,
)

after_png_path = os.path.join(out_folder, f"{img_fn}_{animal_id}_after_tta.png")
Expand Down Expand Up @@ -670,6 +677,7 @@ def gradio_inference(
det_thresh: float,
kp_conf_thresh: float,
side_view: bool,
render_depth: bool,
save_mesh: bool,
):
"""Wrapper for Gradio. ``image`` is an RGB numpy array.
Expand Down Expand Up @@ -756,6 +764,7 @@ def run_collect(progress_callback: Optional[Callable[[str], None]] = None):
det_thresh=det_thresh,
kp_conf_thresh=kp_conf_thresh,
side_view=side_view,
render_depth=render_depth,
save_mesh=save_mesh,
boxes=boxes,
progress_callback=progress_callback,
Expand Down Expand Up @@ -855,6 +864,7 @@ def report_stage(message: str) -> None:
step=0.05,
),
gr.Checkbox(label="Render side view", value=profile.default_side_view),
gr.Checkbox(label="Render depth map", value=profile.default_render_depth),
gr.Checkbox(label="Save meshes (.obj)", value=profile.default_save_mesh),
],
outputs=[
Expand Down
31 changes: 31 additions & 0 deletions demo.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ def main():
parser.add_argument('--out_folder', type=str, default='demo_out', help='Output folder to save rendered results')
parser.add_argument('--side_view', dest='side_view', action='store_true', default=False,
help='If set, render side view also')
parser.add_argument('--render_depth', dest='render_depth', action='store_true', default=False,
help='If set, render depth map also')
parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False,
help='If set, save meshes to disk also')
parser.add_argument('--batch_size', type=int, default=1, help='Batch size for inference/fitting')
Expand Down Expand Up @@ -162,6 +164,35 @@ def main():
scene_bg_color=(1, 1, 1),
side_view=True)
final_img = np.concatenate([final_img, side_img], axis=1)

if args.render_depth:
depth_img = renderer(out['pred_vertices'][n].detach().cpu().numpy(),
out['pred_cam_t'][n].detach().cpu().numpy(),
white_img,
mesh_base_color=GREEN,
scene_bg_color=(1, 1, 1),
depth=True)

valid_mask = depth_img > 0
if np.sum(valid_mask) == 0:
# no valid depth
depth_norm = np.zeros_like(depth_img)
else:
min_val = np.min(depth_img[valid_mask])
max_val = np.max(depth_img[valid_mask])
if min_val == max_val:
depth_norm = np.zeros_like(depth_img)
else:
depth_norm = (depth_img - min_val) / (max_val - min_val + 1e-8)
depth_norm[~valid_mask] = 0

depth_vis = (depth_norm * 255).astype(np.uint8)
depth_vis = cv2.applyColorMap(depth_vis, cv2.COLORMAP_VIRIDIS)
depth_vis = cv2.cvtColor(depth_vis, cv2.COLOR_BGR2RGB)
depth_vis = depth_vis.astype(np.float32) / 255.0
depth_vis[~valid_mask] = 0
final_img = np.concatenate([final_img, depth_vis], axis=1)


cv2.imwrite(os.path.join(args.out_folder, f'{img_fn}_{animal_id}.png'),
cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR))
Expand Down
49 changes: 48 additions & 1 deletion demo_tta.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,6 +92,27 @@ def save_keypoint_vis(patch_rgb: np.ndarray, kpts_xyc: np.ndarray, save_path: st
cv2.imwrite(save_path, vis)


def depth_to_viridis_rgb(depth_img: np.ndarray) -> np.ndarray:
valid_mask = depth_img > 0
if np.sum(valid_mask) == 0:
depth_norm = np.zeros_like(depth_img)
else:
min_val = np.min(depth_img[valid_mask])
max_val = np.max(depth_img[valid_mask])
if min_val == max_val:
depth_norm = np.zeros_like(depth_img)
else:
depth_norm = (depth_img - min_val) / (max_val - min_val + 1e-8)
depth_norm[~valid_mask] = 0

depth_vis = (depth_norm * 255).astype(np.uint8)
depth_vis = cv2.applyColorMap(depth_vis, cv2.COLORMAP_VIRIDIS)
depth_vis = cv2.cvtColor(depth_vis, cv2.COLOR_BGR2RGB)
depth_vis = depth_vis.astype(np.float32) / 255.0
depth_vis[~valid_mask] = 0
return depth_vis


def resolve_sa_weights_path(local_path: str) -> str:
"""Return a local path to the fine-tuned SuperAnimal .pt snapshot.

Expand Down Expand Up @@ -162,7 +183,19 @@ def run_superanimal_on_patch(patch_rgb: np.ndarray, args, tmp_dir: str):
return bodyparts[best_idx].astype(np.float32)


def render_and_save(renderer, cam_crop_to_full_fn, out, batch, img_fn, animal_id, out_folder, suffix, side_view, save_mesh):
def render_and_save(
renderer,
cam_crop_to_full_fn,
out,
batch,
img_fn,
animal_id,
out_folder,
suffix,
side_view,
save_mesh,
render_depth=False,
):
pred_cam = out['pred_cam']
box_center = batch['box_center'].float()
box_size = batch['box_size'].float()
Expand Down Expand Up @@ -195,6 +228,17 @@ def render_and_save(renderer, cam_crop_to_full_fn, out, batch, img_fn, animal_id
)
final_img = np.concatenate([final_img, side_img], axis=1)

if render_depth:
depth_img = renderer(
out['pred_vertices'][0].detach().cpu().numpy(),
out['pred_cam_t'][0].detach().cpu().numpy(),
white_img,
mesh_base_color=GREEN,
scene_bg_color=(1, 1, 1),
depth=True,
)
final_img = np.concatenate([final_img, depth_to_viridis_rgb(depth_img)], axis=1)

cv2.imwrite(
os.path.join(out_folder, f'{img_fn}_{animal_id}_{suffix}.png'),
cv2.cvtColor((255 * final_img).astype(np.uint8), cv2.COLOR_RGB2BGR),
Expand Down Expand Up @@ -252,6 +296,7 @@ def main():
parser.add_argument('--img_folder', type=str, default='demo_data/', help='Folder with input images')
parser.add_argument('--out_folder', type=str, default='demo_out_tta', help='Output folder')
parser.add_argument('--side_view', dest='side_view', action='store_true', default=False, help='Render side view')
parser.add_argument('--render_depth', dest='render_depth', action='store_true', default=False, help='Render depth map')
parser.add_argument('--save_mesh', dest='save_mesh', action='store_true', default=False, help='Save meshes')
parser.add_argument('--file_type', nargs='+', default=['*.jpg', '*.png', '*.jpeg', '*.JPEG'], help='Image globs')
parser.add_argument('--det_thresh', type=float, default=0.7, help='Detectron2 score threshold for animals')
Expand Down Expand Up @@ -348,6 +393,7 @@ def main():
suffix='before_tta',
side_view=args.side_view,
save_mesh=args.save_mesh,
render_depth=args.render_depth,
)

patch_rgb = denorm_patch_to_rgb(batch['img'][0])
Expand Down Expand Up @@ -392,6 +438,7 @@ def main():
suffix='after_tta',
side_view=args.side_view,
save_mesh=args.save_mesh,
render_depth=args.render_depth,
)


Expand Down
3 changes: 2 additions & 1 deletion demo_tta.sh
Original file line number Diff line number Diff line change
Expand Up @@ -12,4 +12,5 @@ python3 demo_tta.py \
--img_folder demo_data/ \
--out_folder demo_out_tta/ \
--tta_lr 1e-6 \
--tta_num_iters 30
--tta_num_iters 30 \
--render_depth
Loading