File size: 2,480 Bytes
899c526
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
#ifndef DISPATCH_H
#define DISPATCH_H

#include <torch/extension.h>

#include "so3.h"
#include "rxso3.h"
#include "se3.h"
#include "sim3.h"


// One `case` label of the scalar-type switch built by
// DISPATCH_GROUP_AND_FLOATING_TYPES.
//
//   group_index : runtime integer selecting the Lie group
//   enum_type   : the at::ScalarType value this case matches
//   type        : the C++ scalar type aliased as `scalar_t` inside the case
//   ...         : a callable invoked with `scalar_t` and `group_t` in scope;
//                 its result is returned from the enclosing function/lambda
//
// Group selection:
//   1 -> SO3<type>, 2 -> RxSO3<type>, 3 -> SE3<type>, 4 -> Sim3<type>
//
// Any other group_index raises an error via TORCH_CHECK. Without that default,
// control would leave this case block and fall through into the next case
// label of the enclosing switch (a different scalar type) and eventually off
// the end of the enclosing lambda.
#define PRIVATE_CASE_TYPE(group_index, enum_type, type, ...)            \
  case enum_type: {                                                     \
    using scalar_t = type;                                              \
    switch (group_index) {                                              \
      case 1: {                                                         \
        using group_t = SO3<type>;                                      \
        return __VA_ARGS__();                                           \
      }                                                                 \
      case 2: {                                                         \
        using group_t = RxSO3<type>;                                    \
        return __VA_ARGS__();                                           \
      }                                                                 \
      case 3: {                                                         \
        using group_t = SE3<type>;                                      \
        return __VA_ARGS__();                                           \
      }                                                                 \
      case 4: {                                                         \
        using group_t = Sim3<type>;                                     \
        return __VA_ARGS__();                                           \
      }                                                                 \
      default:                                                          \
        /* reject unknown ids instead of falling into the next case */  \
        TORCH_CHECK(false, "unsupported group index ", group_index);    \
    }                                                                   \
  }

// Dispatch a callable over a floating scalar type and a Lie group at once.
//
//   GROUP_INDEX : runtime integer selecting the group
//                 (1 = SO3, 2 = RxSO3, 3 = SE3, 4 = Sim3; see PRIVATE_CASE_TYPE)
//   TYPE        : expression whose scalar type is obtained through
//                 ::detail::scalar_type (e.g. a tensor's dtype); it is
//                 evaluated exactly once
//   NAME        : name of the calling operation, used in the error message
//                 when the scalar type is not supported
//   ...         : callable invoked with `scalar_t` and `group_t` in scope;
//                 its result is the value of the whole expression
//
// Only at::ScalarType::Double and at::ScalarType::Float are dispatched. Any
// other scalar type raises an error via TORCH_CHECK rather than letting the
// lambda fall off its end, which would be undefined behaviour for a non-void
// callable and a silent no-op for a void one.
#define DISPATCH_GROUP_AND_FLOATING_TYPES(GROUP_INDEX, TYPE, NAME, ...)              \
  [&] {                                                                              \
    const auto& the_type = TYPE;                                                     \
    /* don't use TYPE again in case it is an expensive or side-effect op */          \
    at::ScalarType _st = ::detail::scalar_type(the_type);                            \
    switch (_st) {                                                                   \
      PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Double, double, __VA_ARGS__)    \
      PRIVATE_CASE_TYPE(GROUP_INDEX, at::ScalarType::Float, float, __VA_ARGS__)      \
      default:                                                                       \
        /* fail loudly instead of silently skipping the callable */                  \
        TORCH_CHECK(false, NAME, " not implemented for '", toString(_st), "'");      \
    }                                                                                \
  }()

#endif