# voyager/data_engine/depth_align.py
import os, re
import json
import numpy as np
import cv2
import torch
import imageio
import pyexr
import trimesh
from PIL import Image
from create_input import render_from_cameras_videos
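# Overview of the alignment pipeline implemented below:
#   1. scale_moge_depth(): per frame, fit a scale/bias in inverse-depth space that maps
#      MoGe depth onto VGGT depth, masking sky (OneFormer) and invalid pixels.
#   2. align_metric_depth(): estimate one global metric scale from Metric3D depth via
#      20%-80% quantile spreads, apply it to the aligned depths and to the camera
#      translations in the VGGT camera JSON.
#   3. depth_to_pointmap(): back-project the metric depths into colored world-space
#      point clouds and point maps using the updated camera parameters.
#   4. render_from_cameras(): re-render the first frame's point map from all cameras
#      as a quick consistency check.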
class DepthAlignMetric:
"""
深度缩放与相机参数更新处理器
Attributes:
moge_depth_dir (str): MOGe待处理深度目录
vggt_depth_dir (str): VGGT待处理深度目录
vggt_camera_json_file (str): VGGT关联的JSON文件目录
output_root (str): 输出根目录
"""
def __init__(self,
input_rgb_dir: str,
moge_depth_dir: str,
vggt_depth_dir: str,
metric3d_depth_dir: str,
vggt_camera_json_file: str,
output_root: str):
"""
Args:
moge_depth_dir: MOGe原始深度路径
vggt_depth_dir: VGGT原始深度路径
vggt_camera_json_file: VGGT关联JSON路径
output_root: 输出根目录,默认为./processed
"""
self.device = "cuda" if torch.cuda.is_available() else "cpu"
# align depth and camera pose to metric level
self.moge_depth_dir = moge_depth_dir
self.vggt_depth_dir = vggt_depth_dir
self.metric3d_depth_dir = metric3d_depth_dir
self.vggt_camera_json_file = vggt_camera_json_file
self.output_root = output_root
# depth to pointmap
self.metric_intrinsic = None
self.metric_w2c = None
self.input_rgb_dir = input_rgb_dir
self.input_color_paths = []
# output depth / camera pose / pointmap
self.output_metric_depth_dir = os.path.join(output_root, "output_metric_depth_dir")
self.output_metric_camera_json = os.path.join(output_root, "output_metric_camera_json")
self.output_metric_pointmap_dir = os.path.join(output_root, "output_metric_pointmap_dir")
os.makedirs(self.output_metric_depth_dir, exist_ok=True)
os.makedirs(self.output_metric_camera_json, exist_ok=True)
os.makedirs(self.output_metric_pointmap_dir, exist_ok=True)
def align_depth_scale(self):
        # Align MoGe depth to VGGT depth (per-frame scale/bias in inverse-depth space)
moge_align_depth_list, valid_mask_list = self.scale_moge_depth()
        # Align the VGGT-aligned MoGe depth and camera poses to metric scale
self.align_metric_depth(moge_align_depth_list, valid_mask_list)
def segment_sky_with_oneformer(self, image_path, skyseg_processor, skyseg_model, SKY_CLASS_ID, save_path=None):
image = Image.open(image_path)
inputs = skyseg_processor(images=image, task_inputs=["semantic"], return_tensors="pt").to(skyseg_model.device)
with torch.no_grad():
outputs = skyseg_model(**inputs)
        # Post-process the outputs into a per-pixel semantic segmentation map
predicted_semantic_map = skyseg_processor.post_process_semantic_segmentation(outputs, \
target_sizes=[image.size[::-1]])[0]
        # Extract the sky region as a binary mask
sky_mask = (predicted_semantic_map == SKY_CLASS_ID).cpu().numpy().astype(np.uint8) * 255
        # Erode the sky mask slightly so it does not bleed over object boundaries
kernel = np.ones((3,3), np.uint8)
sky_mask = cv2.erode(sky_mask, kernel, iterations=1)
        # Optionally save the mask to disk
if save_path:
cv2.imwrite(save_path, sky_mask)
return sky_mask
def get_valid_depth(self, vggt_files, moge_files, input_rgb_files, skyseg_processor, skyseg_model, SKY_CLASS_ID):
moge_align_depth_list = []
valid_mask_list = []
all_valid_max_list = []
for vggt_file, moge_file, input_rgb_file in zip(vggt_files, moge_files, input_rgb_files):
            # Load the MoGe and VGGT depth maps and resize VGGT depth to the MoGe resolution
depth_moge = pyexr.read(os.path.join(self.moge_depth_dir, moge_file)).squeeze()
depth_vggt = pyexr.read(os.path.join(self.vggt_depth_dir, vggt_file)).squeeze()
depth_vggt = cv2.resize(depth_vggt, dsize=(depth_moge.shape[1], depth_moge.shape[0]), \
interpolation=cv2.INTER_LINEAR)
depth_vggt = torch.from_numpy(depth_vggt).float().to(self.device)
depth_moge = torch.from_numpy(depth_moge).float().to(self.device)
            # Segment the sky in the corresponding RGB frame
sky_ima_path = os.path.join(self.input_rgb_dir, input_rgb_file)
sky_mask = self.segment_sky_with_oneformer(sky_ima_path, skyseg_processor, skyseg_model, SKY_CLASS_ID)
sky_mask_tensor = torch.from_numpy(sky_mask).float().to(self.device)
            sky_mask = (sky_mask_tensor > 0)  # True where the pixel is sky
valid_masks = ( # (H, W)
torch.isfinite(depth_moge) &
(depth_moge > 0) &
torch.isfinite(depth_vggt) &
(depth_vggt > 0) &
                ~sky_mask  # exclude sky pixels
)
            # Fill invalid pixels of depth_moge with the valid-region maximum so the
            # aligned inverse depth (and hence final_align_depth) stays positive there
            depth_moge[~valid_masks] = depth_moge[valid_masks].max()
source_inv_depth = 1.0 / depth_moge
target_inv_depth = 1.0 / depth_vggt
            # print(f'inverse depth range: {source_inv_depth.min()} ~ {source_inv_depth.max()}')  # roughly 0.03 ~ 2.2
source_mask, target_mask = valid_masks, valid_masks
            # Remove outliers; the 20%/80% quantiles work best empirically
outlier_quantiles = torch.tensor([0.2, 0.8], device=self.device)
source_data_low, source_data_high = torch.quantile(
source_inv_depth[source_mask], outlier_quantiles
)
target_data_low, target_data_high = torch.quantile(
target_inv_depth[target_mask], outlier_quantiles
)
source_mask = (source_inv_depth > source_data_low) & (
source_inv_depth < source_data_high
)
target_mask = (target_inv_depth > target_data_low) & (
target_inv_depth < target_data_high
)
mask = torch.logical_and(source_mask, target_mask)
mask = torch.logical_and(mask, valid_masks)
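            # Fit an affine model in inverse-depth space, target_inv ≈ scale * source_inv + bias,
            # over the jointly valid, outlier-filtered pixels: stacking the source inverse depths
            # with a column of ones gives the least-squares system
            #   [source_inv, 1] @ [scale; bias] = target_inv
            # which is solved with torch.linalg.lstsq below.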
source_data = source_inv_depth[mask].view(-1, 1)
target_data = target_inv_depth[mask].view(-1, 1)
ones = torch.ones((source_data.shape[0], 1), device=self.device)
source_data_h = torch.cat([source_data, ones], dim=1)
transform_matrix = torch.linalg.lstsq(source_data_h, target_data).solution
scale, bias = transform_matrix[0, 0], transform_matrix[1, 0]
aligned_inv_depth = source_inv_depth * scale + bias
            valid_inv_depth = aligned_inv_depth > 0  # keep only pixels with positive aligned inverse depth
            valid_masks = valid_masks & valid_inv_depth  # merge into the existing valid mask
valid_mask_list.append(valid_masks)
final_align_depth = 1.0 / aligned_inv_depth
moge_align_depth_list.append(final_align_depth)
all_valid_max_list.append(final_align_depth[valid_masks].max().item())
return moge_align_depth_list, valid_mask_list, all_valid_max_list
def scale_moge_depth(self):
vggt_files = sorted(f for f in os.listdir(self.vggt_depth_dir) if f.endswith('.exr'))
moge_files = sorted(f for f in os.listdir(self.moge_depth_dir) if f.endswith('.exr'))
input_rgb_files = sorted(f for f in os.listdir(self.input_rgb_dir) if f.endswith('.png'))
if len(vggt_files) != len(moge_files):
            raise ValueError("Number of VGGT and MoGe depth files does not match")
from transformers import OneFormerProcessor, OneFormerForUniversalSegmentation
skyseg_processor = OneFormerProcessor.from_pretrained("shi-labs/oneformer_coco_swin_large")
skyseg_model = OneFormerForUniversalSegmentation.from_pretrained("shi-labs/oneformer_coco_swin_large")
skyseg_model.to(self.device)
        # Semantic class ID for "sky" in this OneFormer COCO model (119)
SKY_CLASS_ID = 119
moge_align_depth_list, valid_mask_list, all_valid_max_list = self.get_valid_depth(
vggt_files, moge_files, input_rgb_files, skyseg_processor, skyseg_model, SKY_CLASS_ID
)
        # Median of the per-frame valid maxima; keep only frames whose maximum is at or below it
        valid_max_array = np.array(all_valid_max_list)
        q50 = np.quantile(valid_max_array, 0.50)  # 50% quantile
        filtered_max = valid_max_array[valid_max_array <= q50]  # drop outlier frame maxima above the median
        # Largest "normal" valid depth across the remaining frames
        global_avg_max = np.max(filtered_max)
        max_sky_value = global_avg_max * 5
        max_sky_value = np.minimum(max_sky_value, 1000)  # relative depth is capped at 1000
        # Assign the same far value to the invalid (e.g. sky) regions of every frame
for i, (moge_depth, valid_mask) in enumerate(zip(moge_align_depth_list, valid_mask_list)):
moge_depth[~valid_mask] = max_sky_value
            # Fraction of pixels exceeding the cap (measured before clamping)
over_count = torch.sum(moge_depth > max_sky_value).item()
total_pixels = moge_depth.numel()
over_ratio = over_count / total_pixels * 100
moge_depth = torch.clamp(moge_depth, max=max_sky_value)
            moge_align_depth_list[i] = moge_depth  # store the processed depth map
return moge_align_depth_list, valid_mask_list
def align_metric_depth(self, moge_align_depth_list, valid_mask_list):
        # List the Metric3D depth files
metric_files = sorted(f for f in os.listdir(self.metric3d_depth_dir) if f.endswith('.exr'))
metric_scales_list = []
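        # Per-frame metric scale: the ratio of the 20%-80% quantile spreads of the Metric3D
        # depth and the VGGT-aligned MoGe depth over valid pixels; the mean over all frames
        # is later applied as a single global scale factor to depths and camera translations.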
        # Iterate over all (Metric3D, aligned MoGe) depth pairs
for idx, (metric_file, moge_depth) in enumerate(zip(metric_files, moge_align_depth_list)):
depth_metric3d = pyexr.read(os.path.join(self.metric3d_depth_dir, metric_file)).squeeze()
depth_metric3d = torch.from_numpy(depth_metric3d).float().to(self.device)
            # Valid mask of the corresponding frame
valid_mask = valid_mask_list[idx].to(self.device)
            # Depth values restricted to the valid region
valid_metric = depth_metric3d[valid_mask]
valid_moge = moge_depth[valid_mask]
            # 20%-80% inter-quantile spread of each depth map
metric_diff = torch.quantile(valid_metric, 0.8) - torch.quantile(valid_metric, 0.2)
moge_diff = torch.quantile(valid_moge, 0.8) - torch.quantile(valid_moge, 0.2)
metric_scale = metric_diff / moge_diff
metric_scales_list.append(metric_scale.cpu().numpy())
        # Global scale factor: mean of the per-frame scales
metric_scales_mean = np.mean(metric_scales_list)
        # Apply the global scale and save the metric depth maps
for idx, (metric_file, moge_depth) in enumerate(zip(metric_files, moge_align_depth_list)):
metric_moge_depth = (moge_depth * metric_scales_mean).cpu().numpy()
            # Write the metric depth as a single-channel EXR
output_path = os.path.join(
self.output_metric_depth_dir,
f"{os.path.splitext(metric_file)[0]}_metric.exr"
)
pyexr.write(output_path, metric_moge_depth, channel_names=["Y"])
        # Stage 3: update the camera parameters
with open(self.vggt_camera_json_file, 'r') as f:
camera_data = json.load(f)
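        # camera_data is assumed to map frame indices to dicts containing at least a 4x4
        # 'w2c' matrix (and, as read later in load_metric_camera_parameters, a 3x3 'intrinsic')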
        # Scale the translation component of every frame's w2c so the camera poses
        # stay consistent with the rescaled (metric) depths
for frame_info in camera_data.values():
w2c_matrix = np.array(frame_info['w2c'])
            w2c_matrix[:3, 3] *= metric_scales_mean  # reuse the global scale factor computed above
frame_info['w2c'] = w2c_matrix.tolist()
        # Save the updated camera parameters
output_json_path = os.path.join(
self.output_metric_camera_json,
os.path.basename(self.vggt_camera_json_file)
)
with open(output_json_path, 'w') as f:
json.dump(camera_data, f, indent=4)
    def load_metric_camera_parameters(self):
metric_camera_json = os.path.join(self.output_metric_camera_json, 'colmap_data.json')
with open(metric_camera_json, 'r') as f:
data = json.load(f)
# load metric camera parameters
        # Sort frames by numeric index so intrinsics and poses line up with the frame order
        sorted_frames = sorted(data.items(), key=lambda x: int(x[0]))
        self.metric_intrinsic = [np.array(frame['intrinsic']) for _, frame in sorted_frames]
        self.metric_w2c = [np.array(frame['w2c']) for _, frame in sorted_frames]
        # Collect the input RGB file paths (used for point-map colors), sorted by frame index
self.input_color_paths = sorted(
[os.path.join(self.input_rgb_dir, f) for f in os.listdir(self.input_rgb_dir) if f.endswith(".png")],
key=lambda x: int(os.path.basename(x).split("_")[1].split(".")[0])
)
def depth_to_pointmap(self):
num_frames = len(self.metric_w2c)
for frame_index in range(num_frames):
exr_path = os.path.join(self.output_metric_depth_dir, f"frame_{frame_index+1:05d}_metric.exr")
depth_data = pyexr.read(exr_path).squeeze()
depth_tensor = torch.from_numpy(depth_data).to(self.device, torch.float32)
            # Build the point cloud for this frame
height, width = depth_tensor.shape
K_tensor = torch.from_numpy(self.metric_intrinsic[frame_index]).to(device=self.device, dtype=torch.float32)
w2c = torch.from_numpy(self.metric_w2c[frame_index]).to(device=self.device, dtype=torch.float32)
camtoworld = torch.inverse(w2c)
            # Back-project pixels into camera-space coordinates
u = torch.arange(width, device=self.device).float()
v = torch.arange(height, device=self.device).float()
u_grid, v_grid = torch.meshgrid(u, v, indexing='xy')
fx, fy = K_tensor[0, 0], K_tensor[1, 1]
cx, cy = K_tensor[0, 2], K_tensor[1, 2]
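            # Pinhole back-projection: a pixel (u, v) with depth z maps to
            #   x_cam = (u - cx) * z / fx,  y_cam = (v - cy) * z / fy,  z_cam = z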
x_cam = (u_grid - cx) * depth_tensor / fx
y_cam = (v_grid - cy) * depth_tensor / fy
z_cam = depth_tensor
cam_coords_points = torch.stack([x_cam, y_cam, z_cam], dim=-1)
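            # Transform camera-space points to world space; points are row vectors,
            # so the rotation is applied as a right-multiplication by R^T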
R_cam_to_world = camtoworld[:3, :3]
t_cam_to_world = camtoworld[:3, 3]
world_coords_points = torch.matmul(cam_coords_points, R_cam_to_world.T) + t_cam_to_world
            # Save a colored point cloud for inspection
            color_numpy = np.array(Image.open(self.input_color_paths[frame_index]).convert("RGB"))  # (H, W, 3) uint8
            colors_rgb = color_numpy.reshape(-1, 3)  # flatten to (H*W, 3) per-point colors
vertices_3d = world_coords_points.reshape(-1, 3).cpu().numpy()
point_cloud_data = trimesh.PointCloud(vertices=vertices_3d, colors=colors_rgb)
point_cloud_data.export(f"{self.output_metric_pointmap_dir}/pcd_{frame_index+1:04d}.ply")
            # Save the point map as an .npy array of shape (H, W, 3)
pointmap_data = world_coords_points.cpu().numpy()
np.save(f"{self.output_metric_pointmap_dir}/pointmap_{frame_index+1:04d}.npy", pointmap_data)
def render_from_cameras(self):
render_output_dir = os.path.join(self.output_root, "rendered_views")
os.makedirs(render_output_dir, exist_ok=True)
select_frame = 0
npy_files = sorted(
[f for f in os.listdir(self.output_metric_pointmap_dir) if f.endswith(".npy")],
key=lambda x: int(re.findall(r'\d+', x)[0])
)
npy_path = os.path.join(self.output_metric_pointmap_dir, npy_files[select_frame])
        # Load the selected frame's point map
pointmap = np.load(npy_path)
points = pointmap.reshape(-1, 3)
        color_numpy = np.array(Image.open(self.input_color_paths[select_frame]).convert("RGB"))  # (H, W, 3) uint8
        colors = color_numpy.reshape(-1, 3)  # flatten to (H*W, 3) per-point colors
height, width = cv2.imread(self.input_color_paths[0]).shape[:2]
renders, masks, _ = render_from_cameras_videos(
points, colors, self.metric_w2c, self.metric_intrinsic, height, width
)
        # Save all rendered views and masks with imageio
for i, (render, mask) in enumerate(zip(renders, masks)):
            # Rendered view
render_path = os.path.join(render_output_dir, f"render_{i:04d}.png")
imageio.imwrite(render_path, render)
            # Visibility mask
mask_path = os.path.join(render_output_dir, f"mask_{i:04d}.png")
imageio.imwrite(mask_path, mask)
print(f"All results saved to: {render_output_dir}")
if __name__ == "__main__":
import argparse
parser = argparse.ArgumentParser(description="Depth alignment and metric processing.")
parser.add_argument('--image_dir', type=str, required=True, help='Input RGB directory')
parser.add_argument('--moge_depth_dir', type=str, required=True, help='MOGe depth directory')
parser.add_argument('--vggt_depth_dir', type=str, required=True, help='VGGT depth directory')
parser.add_argument('--metric3d_depth_dir', type=str, required=True, help='Metric3D depth directory')
parser.add_argument('--vggt_camera_json_file', type=str, required=True, help='VGGT camera JSON file')
parser.add_argument('--output_dir', type=str, required=True, help='Output root directory')
args = parser.parse_args()
depth_align_processor = DepthAlignMetric(
input_rgb_dir=args.image_dir,
moge_depth_dir=args.moge_depth_dir,
vggt_depth_dir=args.vggt_depth_dir,
metric3d_depth_dir=args.metric3d_depth_dir,
vggt_camera_json_file=args.vggt_camera_json_file,
output_root=args.output_dir
)
depth_align_processor.align_depth_scale()
    depth_align_processor.load_metric_camera_parameters()
depth_align_processor.depth_to_pointmap()
depth_align_processor.render_from_cameras()
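# Example invocation (directory and file names below are placeholders; adapt to your data layout):
#   python depth_align.py \
#       --image_dir ./frames \
#       --moge_depth_dir ./moge_depth \
#       --vggt_depth_dir ./vggt_depth \
#       --metric3d_depth_dir ./metric3d_depth \
#       --vggt_camera_json_file ./vggt/colmap_data.json \
#       --output_dir ./processed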