mirror of https://github.com/llvm/torch-mlir
Cast `number` to `float` when shape function takes Scalar arg (#1978)
To keep things simple in shape functions, `Scalar` inputs are considered `float`s. This means that when inserting the shape functions into the IR, we must cast any `!torch.number`s into `float`s so that the operand type matches the expected type in the shape function. This commit adds the cast from `Scalar` to `float`. (branch: pull/1986/head)
parent
72bb902640
commit
d803ab4eeb
|
@@ -194,6 +194,16 @@ FailureOr<Value> Torch::adjustFunctionArg(
|
|||
return b.create<DerefineOp>(loc, desiredType, operand).getResult();
|
||||
}
|
||||
|
||||
// To keep things simple in shape functions, `Scalar` inputs are considered
|
||||
// `float`s. This is safe since output shape of torch ops never depends on the
|
||||
// dtype of input scalars. However, this also means we sometimes have to
|
||||
// manually turn `Scalar`s into `float`s when inserting the shape functions
|
||||
// into the IR.
|
||||
if (operandType.isa<Torch::NumberType>() &&
|
||||
desiredType.isa<Torch::FloatType>()) {
|
||||
return b.create<AtenFloatScalarOp>(loc, desiredType, operand).getResult();
|
||||
}
|
||||
|
||||
// If the operand type is statically !torch.optional, then we need to do
|
||||
// different things for the None and non-None cases.
|
||||
// For the None case, we just need to derefine it to the desired type.
|
||||
|
|
|
@@ -231,3 +231,14 @@ func.func @adjust_shape_function_arg$list(%arg0: !torch.vtensor, %arg1: !torch.v
|
|||
%1 = torch.aten.index.Tensor %arg0, %0 : !torch.vtensor, !torch.list<vtensor> -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
||||
// -----
|
||||
|
||||
// CHECK-LABEL: func.func @adjust_shape_function_arg$number(
|
||||
// CHECK: %[[FLOAT:.*]] = torch.aten.Float.Scalar {{.*}} : !torch.number -> !torch.float
|
||||
// CHECK: %[[VAL_9:.*]] = func.call @__torch_mlir_shape_fn.aten.arange(%[[FLOAT]], {{.*}}) : (!torch.float, {{.*}}
|
||||
func.func @adjust_shape_function_arg$number(%arg0: !torch.number) -> !torch.vtensor {
|
||||
%none = torch.constant.none
|
||||
%1 = torch.aten.arange %arg0, %none, %none, %none, %none : !torch.number, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor
|
||||
return %1 : !torch.vtensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue