@@ -347,21 +347,6 @@ std::vector<std::vector<int64_t>> CreateReduceGroups(const py::list& groups) {
347347 return replica_groups;
348348}
349349
350- std::vector<at::Tensor> TpuCustomCall (
351- const std::vector<at::Tensor>& inputs, const std::string& payload,
352- const std::vector<std::vector<int64_t >>& output_shapes,
353- const std::vector<py::object>& output_dtypes) {
354- std::vector<at::ScalarType> dtypes;
355- dtypes.reserve (output_dtypes.size ());
356- for (auto & dtype : output_dtypes) {
357- dtypes.push_back (reinterpret_cast <THPDtype*>(dtype.ptr ())->scalar_type );
358- }
359- XLA_ASSIGN_OR_THROW (std::vector<absl_nonnull XLATensorPtr> xla_inputs,
360- bridge::GetXlaTensors (inputs));
361- return bridge::AtenFromXlaTensors (tensor_methods::tpu_custom_call (
362- xla_inputs, payload, output_shapes, dtypes));
363- }
364-
365350std::vector<std::vector<int >> ExtractXlaDotGeneralDimVectors (
366351 const py::tuple& dimension_numbers) {
367352 // Expect Python arg `dimension_numbers` to be
@@ -3116,30 +3101,33 @@ void InitXlaModuleBindings(py::module m) {
31163101 " _xla_custom_call" ,
31173102 [](const std::vector<at::Tensor>& inputs, const std::string& target,
31183103 const std::vector<std::vector<int64_t >>& output_shapes,
3119- const std::vector<py::object >& output_dtypes, bool has_side_effect,
3104+ const std::vector<at::ScalarType >& output_dtypes, bool has_side_effect,
31203105 const std::string& backend_config, const int api_version,
31213106 const std::unordered_map<std::string, std::string>&
31223107 frontend_attributes) -> std::vector<at::Tensor> {
3123- std::vector<at::ScalarType> dtypes;
3124- dtypes.reserve (output_dtypes.size ());
3125- for (auto & dtype : output_dtypes) {
3126- dtypes.push_back (
3127- reinterpret_cast <THPDtype*>(dtype.ptr ())->scalar_type );
3128- }
31293108
3130- XLA_ASSIGN_OR_THROW (std::vector<absl_nonnull XLATensorPtr> xla_inputs, bridge::GetXlaTensors (inputs));
3131- auto xtensors = tensor_methods::custom_call (
3132- xla_inputs, target,
3133- output_shapes, dtypes, has_side_effect, backend_config,
3134- api_version, frontend_attributes);
3135- return bridge::AtenFromXlaTensors (std::move (xtensors));
3109+ XLA_ASSIGN_OR_THROW (std::vector<absl_nonnull XLATensorPtr> xla_inputs,
3110+ bridge::GetXlaTensors (inputs));
3111+ XLA_ASSIGN_OR_THROW (std::vector<absl_nonnull XLATensorPtr> xla_outputs,
3112+ tensor_methods::custom_call (
3113+ xla_inputs, target, output_shapes, output_dtypes,
3114+ has_side_effect, backend_config, api_version,
3115+ frontend_attributes));
3116+
3117+ return bridge::AtenFromXlaTensors (std::move (xla_outputs));
31363118 })
31373119 .def (" _xla_tpu_custom_call" ,
31383120 [](const std::vector<at::Tensor>& inputs, const std::string& payload,
31393121 const std::vector<std::vector<int64_t >>& output_shapes,
3140- const std::vector<py::object >& output_dtypes)
3122+ const std::vector<at::ScalarType >& output_dtypes)
31413123 -> std::vector<at::Tensor> {
3142- return TpuCustomCall (inputs, payload, output_shapes, output_dtypes);
3124+
3125+ XLA_ASSIGN_OR_THROW (std::vector<absl_nonnull XLATensorPtr> xla_inputs,
3126+ bridge::GetXlaTensors (inputs));
3127+ XLA_ASSIGN_OR_THROW (std::vector<absl_nonnull XLATensorPtr> xla_outputs,
3128+ tensor_methods::tpu_custom_call (xla_inputs, payload, output_shapes, output_dtypes));
3129+
3130+ return bridge::AtenFromXlaTensors (std::move (xla_outputs));
31433131 })
31443132 .def (" _xla_register_custom_call_target" ,
31453133 [](const std::string& fn_name, const py::capsule& function_ptr,
0 commit comments