Cast `number` to `float` when shape function takes Scalar arg (#1978)

To keep things simple in shape functions, `Scalar` inputs are
considered `float`s. This means that when inserting the shape
functions into the IR, we must cast any `!torch.number`s into `float`s
so that the operand type matches the expected type in the shape
function. This commit adds the cast from `Scalar` to `float`.
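For illustration, the adjustment inserts a `torch.aten.Float.Scalar` op just before the call into the shape function. Below is a minimal sketch of the resulting IR, using a hypothetical shape function `@my_shape_fn` in place of the much larger real `__torch_mlir_shape_fn.aten.arange` signature:

    // Stand-in for a real shape function; shape functions take Scalars as floats.
    func.func @my_shape_fn(%end: !torch.float) -> !torch.float {
      return %end : !torch.float
    }

    func.func @caller(%scalar: !torch.number) -> !torch.float {
      // The operand is a !torch.number but the callee expects !torch.float,
      // so adjustFunctionArg inserts this cast before the call:
      %float = torch.aten.Float.Scalar %scalar : !torch.number -> !torch.float
      %0 = func.call @my_shape_fn(%float) : (!torch.float) -> !torch.float
      return %0 : !torch.float
    }

The new test at the bottom of this commit checks exactly this pattern for `torch.aten.arange`.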
Ramiro Leal-Cavazos, 2023-03-28 09:30:31 -07:00 (committed by GitHub)
commit d803ab4eeb (parent 72bb902640)
2 changed files with 21 additions and 0 deletions


@@ -194,6 +194,16 @@ FailureOr<Value> Torch::adjustFunctionArg(
     return b.create<DerefineOp>(loc, desiredType, operand).getResult();
   }
+
+  // To keep things simple in shape functions, `Scalar` inputs are considered
+  // `float`s. This is safe since the output shape of torch ops never depends
+  // on the dtype of input scalars. However, this also means we sometimes have
+  // to manually turn `Scalar`s into `float`s when inserting the shape
+  // functions into the IR.
+  if (operandType.isa<Torch::NumberType>() &&
+      desiredType.isa<Torch::FloatType>()) {
+    return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
+  }
   // If the operand type is statically !torch.optional, then we need to do
   // different things for the None and non-None cases.
   // For the None case, we just need to derefine it to the desired type.
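A note on why the `float` convention is safe: shape functions only consume the numeric value of a `Scalar`, never its dtype. For example, `torch.aten.arange` with end = 2 produces a tensor of shape [2] whether the end argument is the int 2 or the float 2.0, so widening `!torch.number` to `!torch.float` cannot change a computed shape.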


@@ -231,3 +231,14 @@ func.func @adjust_shape_function_arg$list(%arg0: !torch.vtensor, %arg1: !torch.v
   %1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor, !torch.list<vtensor> -> !torch.vtensor
   return %1 : !torch.vtensor
 }
+
+// -----
+
+// CHECK-LABEL: func.func @adjust_shape_function_arg$number(
+// CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar {{.*}} : !torch.number -> !torch.float
+// CHECK: %[[VAL_9:.*]] = func.call @__torch_mlir_shape_fn.aten.arange(%[[FLOAT]], {{.*}}) : (!torch.float, {{.*}}
+func.func @adjust_shape_function_arg$number(%arg0: !torch.number) -> !torch.vtensor {
+  %none = torch.constant.none
+  %1 = torch.aten.arange %arg0, %none, %none, %none, %none : !torch.number, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
+  return %1 : !torch.vtensor
+}
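Assuming this test lands in the existing reify-shape-calculations FileCheck file, it would run under that file's RUN line, along these lines (the exact flag name is an assumption based on the upstream pass; `-split-input-file` is implied by the `// -----` separator):

    // RUN: torch-mlir-opt -torch-reify-shape-calculations -split-input-file %s | FileCheck %s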