Fix signature of unboxed aten::arange for torch HEAD

pull/190/head
Sean Silva 2021-03-18 15:43:45 -07:00
parent 19b9398aee
commit a53ed850bd
2 changed files with 6 additions and 7 deletions

View File

@ -364,7 +364,7 @@ at::Tensor &AcapController::copyUnderKernel(at::Tensor &self,
} }
at::Tensor AcapController::arangeBackendSelectKernel( at::Tensor AcapController::arangeBackendSelectKernel(
at::Scalar end, c10::optional<at::ScalarType> dtype, const at::Scalar &end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory) { c10::optional<bool> pin_memory) {
static c10::OperatorName opName{"aten::arange", ""}; static c10::OperatorName opName{"aten::arange", ""};
@ -380,7 +380,7 @@ at::Tensor AcapController::arangeBackendSelectKernel(
// built-in handlers dispatch to BackendSelect kernels. // built-in handlers dispatch to BackendSelect kernels.
auto targetDk = c10::computeDispatchKey(dtype, layout, device); auto targetDk = c10::computeDispatchKey(dtype, layout, device);
auto opTyped = opHandle->typed<at::Tensor( auto opTyped = opHandle->typed<at::Tensor(
at::Scalar end, c10::optional<at::ScalarType> dtype, const at::Scalar &end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Device> device, c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<bool> pin_memory)>(); c10::optional<bool> pin_memory)>();
return opTyped.redispatch(c10::DispatchKeySet({targetDk}), end, dtype, layout, device, return opTyped.redispatch(c10::DispatchKeySet({targetDk}), end, dtype, layout, device,

View File

@ -72,11 +72,10 @@ public:
bool non_blocking); bool non_blocking);
// Backend select kernel for arange factory function. // Backend select kernel for arange factory function.
static at::Tensor static at::Tensor arangeBackendSelectKernel(
arangeBackendSelectKernel(at::Scalar end, c10::optional<at::ScalarType> dtype, const at::Scalar &end, c10::optional<at::ScalarType> dtype,
c10::optional<at::Layout> layout, c10::optional<at::Layout> layout, c10::optional<at::Device> device,
c10::optional<at::Device> device, c10::optional<bool> pin_memory);
c10::optional<bool> pin_memory);
private: private:
/// Builds a kernel call step by step. /// Builds a kernel call step by step.