# 项目六：模仿学习桌面操作 (PyBullet + ACT/Diffusion Policy)

## 项目概述

在PyBullet仿真中，用模仿学习训练Franka Panda机械臂完成桌面操作任务：
- 方块推送到目标位置
- 方块堆叠
- 零件插入
- 对比BC、ACT和Diffusion Policy三种方法

## 难度：★★★★☆ (4/5)
## 预估时间：4-6周

---

## 1. 项目结构

```
imitation_learning/
├── envs/
│   ├── push_env.py        # 推送任务环境
│   ├── stack_env.py       # 堆叠任务环境
│   └── insertion_env.py   # 插入任务环境
├── data_collection/
│   ├── keyboard_teleop.py # 键盘遥操作
│   └── record_episodes.py # 数据记录
├── policies/
│   ├── bc_policy.py       # 行为克隆
│   ├── act_policy.py      # ACT实现
│   └── diffusion_policy.py # Diffusion Policy
├── training/
│   ├── train.py           # 训练脚本
│   └── config.py          # 配置
├── evaluation/
│   ├── evaluate.py         # 评估脚本
│   └── compare_methods.py  # 方法对比
└── utils/
    ├── dataset.py          # 数据集加载
    └── visualization.py    # 可视化
```

---

## 2. PyBullet 操作环境

```python
import pybullet as p
import pybullet_data
import numpy as np

class PushEnv:
    """Block-pushing task environment (Franka Panda in PyBullet).

    Observation (float32, length ``OBS_DIM``)::

        ee position (3) | ee orientation quaternion (4) | arm joint
        positions (7) | arm joint velocities (7) | block position (3) |
        goal position (3) | goal - block (3)

    Action: end-effector displacement ``(dx, dy, dz)`` in the world frame.
    Each component is clipped to [-1, 1] and scaled by ``ACTION_SCALE``.

    ``step`` returns ``(obs, reward, done, truncated, info)``:

    - ``reward`` is the negative block-to-goal distance in the table (x, y) plane.
    - ``done`` is True once that distance is below ``SUCCESS_DIST``.

    Every PyBullet call passes ``physicsClientId=self.client``. This lets
    several environments coexist in one process, e.g. a GUI env for
    teleoperation and a DIRECT env for evaluation.
    """

    OBS_DIM = 30  # state dimension (see layout above)
    ACT_DIM = 3   # end-effector displacement (dx, dy, dz)

    ACTION_SCALE = 0.02                 # EE displacement (m) per unit action
    TABLE_Z = 0.63                      # z used when spawning goal and block
    WORKSPACE_LOW = (0.3, -0.3, 0.63)   # EE target clip box (x, y, z), metres
    WORKSPACE_HIGH = (0.7, 0.3, 1.0)
    SUCCESS_DIST = 0.03                 # planar block-goal distance for success
    SIM_SUBSTEPS = 20                   # physics steps per env step (~83 ms @ 240 Hz)
    DEFAULT_ARM_Q = (0, -0.5, 0, -2.0, 0, 1.5, 0.8)  # home pose of the 7 arm joints

    def __init__(self, render=False, max_steps=200):
        """Connect to PyBullet and load the static scene.

        Args:
            render: open a GUI window when True, otherwise run headless (DIRECT).
            max_steps: episode length after which ``step`` reports ``truncated``.
        """
        self.render_mode = render
        self.max_steps = max_steps
        self.block = None          # body id of the current episode's cube
        self._debug_items = []     # ids of the goal-marker debug lines
        self.step_count = 0

        self.client = p.connect(p.GUI if render else p.DIRECT)
        cid = self.client
        p.setAdditionalSearchPath(pybullet_data.getDataPath(), physicsClientId=cid)
        p.setGravity(0, 0, -9.81, physicsClientId=cid)
        p.setTimeStep(1 / 240, physicsClientId=cid)

        self._load_scene()

    def _load_scene(self):
        """Load the plane, table and fixed-base Panda, and index its joints."""
        cid = self.client
        self.plane = p.loadURDF("plane.urdf", physicsClientId=cid)
        self.table = p.loadURDF("table/table.urdf",
                                [0.5, 0, 0], [0, 0, 0, 1],
                                physicsClientId=cid)
        self.robot = p.loadURDF("franka_panda/panda.urdf",
                                [0, 0, 0.62], [0, 0, 0, 1],
                                useFixedBase=True,
                                physicsClientId=cid)

        # Split movable joints into finger joints and arm joints.
        self.arm_joints = []
        self.gripper_joints = []
        for i in range(p.getNumJoints(self.robot, physicsClientId=cid)):
            info = p.getJointInfo(self.robot, i, physicsClientId=cid)
            name = info[1].decode()
            if 'finger' in name:
                self.gripper_joints.append(i)
            elif info[2] != p.JOINT_FIXED:
                self.arm_joints.append(i)

        # Link right after the last arm joint. Presumably the Panda flange
        # link; TODO confirm this is the intended end-effector frame.
        self.ee_idx = self.arm_joints[-1] + 1
        self.n_arm = len(self.arm_joints)

    def reset(self, goal_pos=None):
        """Start a new episode and return its first observation.

        Args:
            goal_pos: optional (x, y, z) goal. When omitted, it is sampled
                uniformly inside the table workspace.
        """
        cid = self.client
        if goal_pos is None:
            self.goal_pos = np.array([
                np.random.uniform(0.3, 0.7),
                np.random.uniform(-0.3, 0.3),
                self.TABLE_Z,
            ])
        else:
            # Accept lists/tuples too; the arithmetic below needs an ndarray.
            self.goal_pos = np.asarray(goal_pos, dtype=float)

        # Sample a block start at least 15 cm away from the goal.
        while True:
            self.block_pos = np.array([
                np.random.uniform(0.3, 0.7),
                np.random.uniform(-0.3, 0.3),
                self.TABLE_Z,
            ])
            if np.linalg.norm(self.block_pos - self.goal_pos) > 0.15:
                break

        # Remove the previous episode's cube. Otherwise every reset leaves
        # another cube in the scene that the arm and new block can hit.
        if self.block is not None:
            p.removeBody(self.block, physicsClientId=cid)
        self.block = p.loadURDF("cube_small.urdf",
                                self.block_pos, [0, 0, 0, 1],
                                globalScaling=0.5,
                                physicsClientId=cid)

        self._draw_goal()
        self._reset_arm()

        self.step_count = 0
        return self._get_obs()

    def _draw_goal(self):
        """Mark the goal with a green cross, replacing any previous marker."""
        cid = self.client
        for item in self._debug_items:
            p.removeUserDebugItem(item, physicsClientId=cid)
        self._debug_items = []
        if not self.render_mode:
            return  # debug lines are only visible in the GUI
        for offset in (np.array([0.05, 0, 0]), np.array([0, 0.05, 0])):
            self._debug_items.append(p.addUserDebugLine(
                self.goal_pos - offset, self.goal_pos + offset,
                lineColorRGB=[0, 1, 0], lineWidth=3, lifeTime=0,
                physicsClientId=cid))

    def _reset_arm(self):
        """Put the arm back in ``DEFAULT_ARM_Q`` and let the scene settle."""
        cid = self.client
        for joint, q in zip(self.arm_joints, self.DEFAULT_ARM_Q):
            p.resetJointState(self.robot, joint, q, physicsClientId=cid)
            # Retarget the position motor too. Otherwise it still tracks the
            # previous episode's last IK target and pulls the arm away from
            # the home pose during the settling steps below.
            p.setJointMotorControl2(
                self.robot, joint, p.POSITION_CONTROL,
                targetPosition=q,
                force=500,
                maxVelocity=1.0,
                physicsClientId=cid,
            )
        for _ in range(100):
            p.stepSimulation(physicsClientId=cid)

    def _get_obs(self):
        """Return the flat float32 observation described in the class docstring."""
        cid = self.client
        # Index 0/1 of getLinkState: link centre-of-mass position and orientation.
        ee_state = p.getLinkState(self.robot, self.ee_idx, physicsClientId=cid)

        joint_states = p.getJointStates(self.robot, self.arm_joints,
                                        physicsClientId=cid)
        joint_pos = np.array([s[0] for s in joint_states])
        joint_vel = np.array([s[1] for s in joint_states])

        block_pos = np.array(
            p.getBasePositionAndOrientation(self.block, physicsClientId=cid)[0])

        obs = np.concatenate([
            ee_state[0],                  # ee position (3)
            ee_state[1],                  # ee orientation quaternion (4)
            joint_pos,                    # arm joint positions (7)
            joint_vel,                    # arm joint velocities (7)
            block_pos,                    # block position (3)
            self.goal_pos,                # goal position (3)
            self.goal_pos - block_pos,    # goal - block (3)
        ])
        return obs.astype(np.float32)

    def step(self, action):
        """Apply one end-effector displacement and advance the simulation.

        Args:
            action: 3 floats (dx, dy, dz) in the world frame. Each component
                is clipped to [-1, 1] and then scaled by ``ACTION_SCALE``.

        Returns:
            ``(obs, reward, done, truncated, info)``.

        Raises:
            ValueError: if ``action`` does not contain exactly ``ACT_DIM`` values.
        """
        self.step_count += 1
        cid = self.client
        delta = self._clip_action(action) * self.ACTION_SCALE

        # IK targets the URDF link frame, not the centre of mass. So read the
        # current pose from worldLinkFramePosition (index 4). Using the COM
        # position (index 0) would add the COM offset every step and drift.
        link_state = p.getLinkState(self.robot, self.ee_idx,
                                    computeForwardKinematics=True,
                                    physicsClientId=cid)
        ee_pos = np.array(link_state[4])
        target_pos = np.clip(ee_pos + delta, self.WORKSPACE_LOW,
                             self.WORKSPACE_HIGH)

        joint_poses = p.calculateInverseKinematics(
            self.robot, self.ee_idx, target_pos, physicsClientId=cid)

        # IK returns one value per movable joint. This assumes the arm joints
        # precede the finger joints in the URDF; TODO confirm for other robots.
        for joint, q in zip(self.arm_joints, joint_poses[:self.n_arm]):
            p.setJointMotorControl2(
                self.robot, joint, p.POSITION_CONTROL,
                targetPosition=q,
                force=500,
                maxVelocity=1.0,
                physicsClientId=cid,
            )

        for _ in range(self.SIM_SUBSTEPS):
            p.stepSimulation(physicsClientId=cid)

        block_pos, _ = p.getBasePositionAndOrientation(self.block,
                                                       physicsClientId=cid)
        # Pushing cannot change the block's height. Measuring in the table
        # plane keeps an offset between the goal z and the block's resting
        # z from eating the success margin.
        dist = self._planar_distance(block_pos, self.goal_pos)

        reward = -dist
        done = dist < self.SUCCESS_DIST
        truncated = self.step_count >= self.max_steps

        return self._get_obs(), reward, done, truncated, {}

    @classmethod
    def _clip_action(cls, action):
        """Return ``action`` as a float array of shape (ACT_DIM,) clipped to [-1, 1]."""
        arr = np.asarray(action, dtype=float).reshape(-1)
        if arr.shape != (cls.ACT_DIM,):
            raise ValueError(
                f"action must have {cls.ACT_DIM} elements, got shape {np.shape(action)}")
        return np.clip(arr, -1.0, 1.0)

    @staticmethod
    def _planar_distance(a, b):
        """Return the Euclidean distance between ``a`` and ``b`` using only x and y."""
        a = np.asarray(a, dtype=float)
        b = np.asarray(b, dtype=float)
        return float(np.linalg.norm(a[:2] - b[:2]))

    def close(self):
        """Disconnect from this environment's physics server."""
        p.disconnect(self.client)
```

---

## 3. 数据采集（键盘遥操作）

```python
import pygame
import numpy as np
import pickle

class KeyboardTeleop:
    """Collect demonstrations by tele-operating ``env`` from the keyboard.

    Controls (the small pygame window must have focus):
        W / S : +x / -x        A / D : +y / -y        Q / E : +z / -z
        SPACE : start / stop recording an episode
        ESC   : stop collecting; episodes completed so far are still saved

    NOTE(review): ``_start_episode``, ``_end_episode``, ``_record_step``,
    ``_check_done`` and ``_save_all`` are called below but not defined in
    this snippet. They must be provided for the class to run.
    """

    # Key -> direction of end-effector motion in the world frame.
    KEY_MAP = {
        pygame.K_w: np.array([1, 0, 0]),     # +x (away)
        pygame.K_s: np.array([-1, 0, 0]),    # -x (toward)
        pygame.K_a: np.array([0, 1, 0]),     # +y (left)
        pygame.K_d: np.array([0, -1, 0]),    # -y (right)
        pygame.K_q: np.array([0, 0, 1]),     # +z (up)
        pygame.K_e: np.array([0, 0, -1]),    # -z (down)
    }

    def __init__(self, env, save_dir='demos/'):
        """Store the environment to drive and the directory for saved demos."""
        self.env = env
        self.save_dir = save_dir
        self.episodes = []
        self.recording = False

    def collect_episodes(self, n_episodes=50):
        """Record up to ``n_episodes`` demonstrations, then save them.

        An episode ends when SPACE is pressed again, when ``_check_done()``
        is truthy, or when ``env.step`` reports ``done`` or ``truncated``.

        Quitting early (ESC or closing the window) stops collection but still
        calls ``_save_all()``. An episode being recorded at that moment is not
        finalised through ``_end_episode()``. pygame is shut down even if an
        exception escapes.
        """
        pygame.init()
        try:
            # A window is required for pygame to deliver keyboard events.
            pygame.display.set_mode((300, 200))
            pygame.display.set_caption("Press SPACE to start episode, ESC to quit")

            clock = pygame.time.Clock()
            episode_idx = 0
            quit_requested = False

            while episode_idx < n_episodes and not quit_requested:
                for event in pygame.event.get():
                    if event.type == pygame.QUIT or (
                        event.type == pygame.KEYDOWN
                        and event.key == pygame.K_ESCAPE
                    ):
                        quit_requested = True
                        break

                    if event.type == pygame.KEYDOWN and event.key == pygame.K_SPACE:
                        if not self.recording:
                            self._start_episode(episode_idx)
                            print(f"采集第 {episode_idx+1} 条演示...")
                        else:
                            self._end_episode()
                            episode_idx += 1
                            print(f"完成第 {episode_idx} 条演示")

                if quit_requested:
                    break

                if self.recording:
                    # Record (observation before the step, action taken).
                    obs = self.env._get_obs()
                    action = self._get_action()

                    _, _, env_done, env_truncated, _ = self.env.step(action)
                    self._record_step(obs, action)

                    # _check_done() is still always called first; the env's own
                    # success/time-limit flags now also end the episode.
                    if self._check_done() or env_done or env_truncated:
                        self._end_episode()
                        episode_idx += 1
                        print(f"完成第 {episode_idx} 条演示")

                clock.tick(10)  # 10 Hz collection

            # Save also after an early quit, so completed demos are not lost.
            self._save_all()
        finally:
            pygame.quit()

    def _get_action(self):
        """Return the unit-norm direction of the currently held movement keys.

        Opposite keys cancel out. With no key held, the zero vector is returned.
        """
        keys = pygame.key.get_pressed()
        action = np.zeros(3)
        for key, vec in self.KEY_MAP.items():
            if keys[key]:
                action += vec
        norm = np.linalg.norm(action)
        if norm > 0:
            action = action / norm  # diagonal moves are no faster than axis moves
        return action
```

---

## 4. 训练与对比

```python
import torch
import torch.nn as nn
import matplotlib.pyplot as plt
from dataloader import DemonstrationDataset
from bc_policy import BCPolicy
from act_policy import ACTPolicy
from diffusion_policy import DiffusionPolicy

def _train_supervised(model, loader, optimizer, n_epochs, loss_fn=None):
    """Regress ``model(states)`` onto ``actions``; return the mean batch loss per epoch.

    Args:
        model: module mapping a batch of states to a batch of actions.
        loader: iterable of ``(states, actions)`` tensor pairs.
        optimizer: optimizer over ``model``'s parameters.
        n_epochs: number of passes over ``loader``.
        loss_fn: loss module; defaults to ``nn.MSELoss()``.

    Returns:
        List of length ``n_epochs``. An epoch with no batches records ``nan``
        instead of raising ``ZeroDivisionError``.
    """
    if loss_fn is None:
        loss_fn = nn.MSELoss()  # built once, not once per batch
    model.train()
    history = []
    for _ in range(n_epochs):
        total, n_batches = 0.0, 0
        for states, actions in loader:
            optimizer.zero_grad()
            loss = loss_fn(model(states), actions)
            loss.backward()
            optimizer.step()
            total += loss.item()
            n_batches += 1
        history.append(total / n_batches if n_batches else float('nan'))
    return history


def _evaluate(policy, env, evaluate_fn, n_trials):
    """Return ``evaluate_fn``'s success rate for ``policy``, run in eval mode without autograd."""
    policy.eval()  # disable dropout / training-only behaviour during rollouts
    with torch.no_grad():
        return evaluate_fn(policy, env, n_trials=n_trials)


def _plot_comparison(results, n_trials, save_path):
    """Save a two-panel figure: training-loss curves and success-rate bars."""
    fig, axes = plt.subplots(1, 2, figsize=(12, 4))

    for name, res in results.items():
        if res['losses']:  # methods whose training loop is still TODO have no curve
            axes[0].plot(res['losses'], label=name)
    axes[0].set_xlabel('Epoch')
    axes[0].set_ylabel('Loss')
    axes[0].set_title('训练损失')
    axes[0].legend()

    methods = list(results)
    rates = [results[m]['success_rate'] for m in methods]
    axes[1].bar(methods, rates, color=['#667eea', '#764ba2', '#f093fb'])
    axes[1].set_ylabel('Success Rate')
    axes[1].set_title(f'成功率对比 ({n_trials} trials)')
    axes[1].set_ylim(0, 1.1)  # room for the percentage label above a 100% bar
    for i, v in enumerate(rates):
        axes[1].text(i, v + 0.02, f'{v:.0%}', ha='center')

    fig.tight_layout()
    fig.savefig(save_path, dpi=150)
    plt.close(fig)  # release the figure; repeated calls would otherwise pile up


def train_and_compare(env=None, evaluate_fn=None, data_dir='demos/push/',
                      n_epochs=100, n_trials=50, batch_size=32,
                      save_path='imitation_learning_comparison.png'):
    """Train BC, ACT and Diffusion Policy on the push demos and compare them.

    Args:
        env: environment the policies are rolled out in (e.g. ``PushEnv()``).
            Required.
        evaluate_fn: callable ``evaluate_fn(policy, env, n_trials=...)``
            returning a success rate in [0, 1]. Defaults to
            ``evaluate_policy``, which must be defined in this module's
            namespace (presumably from ``evaluation/evaluate.py``).
        data_dir: directory passed to ``DemonstrationDataset``.
        n_epochs: BC training epochs.
        n_trials: evaluation rollouts per method.
        batch_size: training batch size.
        save_path: where the comparison figure is written.

    Returns:
        ``{'BC' | 'ACT' | 'Diffusion': {'losses': list, 'success_rate': float}}``.
        ACT and Diffusion have empty ``losses`` until their training loops
        (TODO below) are implemented.

    Raises:
        ValueError: if ``env`` is not given. This is checked before any
            training, so a missing env does not waste a full BC run.
    """
    if env is None:
        raise ValueError(
            "train_and_compare() requires an evaluation env, e.g. env=PushEnv()")
    if evaluate_fn is None:
        evaluate_fn = evaluate_policy  # NOTE(review): not imported in this snippet

    dataset = DemonstrationDataset(data_dir)
    train_loader = torch.utils.data.DataLoader(
        dataset, batch_size=batch_size, shuffle=True
    )

    results = {}

    # ========== Behavior cloning ==========
    print("=== 训练 Behavior Cloning ===")
    bc = BCPolicy(state_dim=30, action_dim=3)
    bc_optimizer = torch.optim.Adam(bc.parameters(), lr=1e-3)
    bc_losses = _train_supervised(bc, train_loader, bc_optimizer, n_epochs)
    bc_success_rate = _evaluate(bc, env, evaluate_fn, n_trials)
    results['BC'] = {'losses': bc_losses, 'success_rate': bc_success_rate}

    # ========== ACT ==========
    print("=== 训练 ACT ===")
    act = ACTPolicy(state_dim=30, action_dim=3, chunk_size=20)
    act_optimizer = torch.optim.Adam(act.parameters(), lr=1e-4)  # for the TODO loop
    # TODO: ACT training loop. ACT predicts chunks of `chunk_size` actions, so
    # it needs a loader yielding action sequences rather than the single-step
    # (state, action) pairs used for BC. Until then ACT is evaluated untrained.
    act_losses = []
    act_success_rate = _evaluate(act, env, evaluate_fn, n_trials)
    results['ACT'] = {'losses': act_losses, 'success_rate': act_success_rate}

    # ========== Diffusion Policy ==========
    print("=== 训练 Diffusion Policy ===")
    diffusion = DiffusionPolicy(state_dim=30, action_dim=3)
    diffusion_optimizer = torch.optim.Adam(diffusion.parameters(), lr=1e-4)  # for the TODO loop
    # TODO: Diffusion Policy training loop. Presumably it uses a denoising loss
    # provided by DiffusionPolicy rather than direct action regression, so
    # _train_supervised does not apply as-is. Until then it is evaluated untrained.
    diff_losses = []
    diff_success_rate = _evaluate(diffusion, env, evaluate_fn, n_trials)
    results['Diffusion'] = {'losses': diff_losses, 'success_rate': diff_success_rate}

    # ========== Visual comparison ==========
    _plot_comparison(results, n_trials, save_path)
    print(f"\n结果: BC={bc_success_rate:.1%}, "
          f"ACT={act_success_rate:.1%}, "
          f"Diffusion={diff_success_rate:.1%}")

    return results
```

---

## 5. 验收标准

1. **数据采集**：能通过键盘遥操作采集50+条有效演示
2. **BC训练**：在推送任务上成功率 > 60%（50次随机初始化的评估试验）
3. **ACT改进**：ACT成功率比BC提升 > 10个百分点
4. **Diffusion Policy**：在多模态任务（演示中同时包含从方块左侧和右侧绕行推送的轨迹）上，Diffusion Policy的成功率明显高于BC（BC用均方误差回归会把两种模式平均成一条无效的中间轨迹）
5. **实验报告**：包含训练曲线、成功率、失败案例分析

---

## 参考资源

- [robomimic](https://robomimic.github.io/) - 模仿学习基准
- [ACT (Action Chunking Transformer)](https://arxiv.org/abs/2304.13705)
- [Diffusion Policy](https://diffusion-policy.cs.columbia.edu/)
