mirror of https://github.com/llvm/torch-mlir
Add shape info to `rand_like` + support for `dtype` flag (#851)
The op `aten.rand_like` was missing a shape function, unit tests, and the `dtype` argument was being ignored in its decomposition. This commit fixes all three things.pull/845/merge
parent
e7f306ec2f
commit
96f90efd16
|
@ -1116,42 +1116,26 @@ public:
|
||||||
PatternRewriter &rewriter) const override {
|
PatternRewriter &rewriter) const override {
|
||||||
Location loc = op.getLoc();
|
Location loc = op.getLoc();
|
||||||
Value input = op.self();
|
Value input = op.self();
|
||||||
|
Type resultType = op.getType();
|
||||||
auto inputType = input.getType().cast<BaseTensorType>();
|
auto inputType = input.getType().cast<BaseTensorType>();
|
||||||
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>())
|
if (!inputType.hasDtype() || !inputType.getDtype().isa<mlir::FloatType>()) {
|
||||||
return rewriter.notifyMatchFailure(op,
|
return rewriter.notifyMatchFailure(op,
|
||||||
"only support floating-point type");
|
"only support floating-point type");
|
||||||
|
|
||||||
// TODO: Add support for layout, pin_memory and memory_format features.
|
|
||||||
// Only `none` layout is supported.
|
|
||||||
if (!op.layout().getType().isa<Torch::NoneType>())
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "unimplemented: only default layout is supported");
|
|
||||||
|
|
||||||
// The pin_memory should be either `none` or constant `False`.
|
|
||||||
if (!op.pin_memory().getType().isa<Torch::NoneType>()) {
|
|
||||||
bool pinMemory;
|
|
||||||
if (!matchPattern(op.pin_memory(), m_TorchConstantBool(&pinMemory)))
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "unimplemented: pin_memory must be a constant");
|
|
||||||
else if (pinMemory)
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "unimplemented: pin_memory is expected to be false");
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only `none` memory_format is supported.
|
// Create a uniform random op with low and high set to 0.0 and 1.0,
|
||||||
if (!op.memory_format().getType().isa<Torch::NoneType>())
|
|
||||||
return rewriter.notifyMatchFailure(
|
|
||||||
op, "unimplemented: only default memory format is supported");
|
|
||||||
|
|
||||||
// Create a uniform random op with low and high set to 0.0 and 1.0
|
|
||||||
// respectively.
|
// respectively.
|
||||||
Value none = rewriter.create<ConstantNoneOp>(loc);
|
Value none = rewriter.create<ConstantNoneOp>(loc);
|
||||||
Value lb =
|
Value zero =
|
||||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(0.0));
|
||||||
Value ub =
|
Value one =
|
||||||
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
rewriter.create<ConstantFloatOp>(loc, rewriter.getF64FloatAttr(1.0));
|
||||||
|
Value emptyTensor = rewriter.create<AtenEmptyLikeOp>(
|
||||||
|
loc, resultType, input, op.dtype(), op.layout(), op.device(),
|
||||||
|
op.pin_memory(), op.memory_format());
|
||||||
rewriter.replaceOpWithNewOp<ValsemVariantAtenUniformOp>(
|
rewriter.replaceOpWithNewOp<ValsemVariantAtenUniformOp>(
|
||||||
op, op.getType(), input, lb, ub, /*generator=*/none);
|
op, resultType, emptyTensor, /*from=*/zero, /*to=*/one,
|
||||||
|
/*generator=*/none);
|
||||||
return success();
|
return success();
|
||||||
}
|
}
|
||||||
};
|
};
|
||||||
|
|
|
@ -2018,6 +2018,9 @@ module {
|
||||||
func @"__torch_mlir_shape_fn.aten.bernoulli"(%arg0: !torch.list<int>, %arg1: !torch.any) -> !torch.list<int> {
|
func @"__torch_mlir_shape_fn.aten.bernoulli"(%arg0: !torch.list<int>, %arg1: !torch.any) -> !torch.list<int> {
|
||||||
return %arg0 : !torch.list<int>
|
return %arg0 : !torch.list<int>
|
||||||
}
|
}
|
||||||
|
func @"__torch_mlir_shape_fn.aten.rand_like"(%arg0: !torch.list<int>, %arg1: !torch.optional<int>, %arg2: !torch.optional<int>, %arg3: !torch.optional<Device>, %arg4: !torch.optional<bool>, %arg5: !torch.optional<int>) -> !torch.list<int> {
|
||||||
|
return %arg0 : !torch.list<int>
|
||||||
|
}
|
||||||
func @"__torch_mlir_shape_fn.aten.arange.start_step"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {
|
func @"__torch_mlir_shape_fn.aten.arange.start_step"(%arg0: !torch.float, %arg1: !torch.float, %arg2: !torch.float, %arg3: !torch.optional<int>, %arg4: !torch.optional<int>, %arg5: !torch.optional<Device>, %arg6: !torch.optional<bool>) -> !torch.list<int> {
|
||||||
%0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
%0 = torch.derefine %arg3 : !torch.optional<int> to !torch.any
|
||||||
%1 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
|
%1 = torch.derefine %arg4 : !torch.optional<int> to !torch.any
|
||||||
|
|
|
@ -665,6 +665,9 @@ def aten〇index_put_impl(self: List[int], indices: List[Optional[List[int]]], v
|
||||||
def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]:
|
def aten〇bernoulli(self: List[int], generator: Any = None) -> List[int]:
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
def aten〇rand_like(self: List[int], dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None, memory_format: Optional[int] = None) -> List[int]:
|
||||||
|
return self
|
||||||
|
|
||||||
def aten〇arange〇start_step(start: float, end: float, step: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
def aten〇arange〇start_step(start: float, end: float, step: float, dtype: Optional[int] = None, layout: Optional[int] = None, device: Optional[device] = None, pin_memory: Optional[bool] = None) -> List[int]:
|
||||||
return upstream_shape_helpers.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
|
return upstream_shape_helpers.arange_start_step(start, end, step, dtype, layout, device, pin_memory)
|
||||||
|
|
||||||
|
|
|
@ -234,3 +234,45 @@ def BernoulliTensorModule_basic(module, tu: TestUtils):
|
||||||
tu.rand(1024, 2048, 8).double(),
|
tu.rand(1024, 2048, 8).double(),
|
||||||
tu.rand(1024, 512, 8).double(),
|
tu.rand(1024, 512, 8).double(),
|
||||||
tu.rand(1024, 512, 8).double())
|
tu.rand(1024, 512, 8).double())
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class RandLikeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
a = torch.ops.aten.rand_like(x)
|
||||||
|
mean = torch.mean(a)
|
||||||
|
return mean
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: RandLikeModule())
|
||||||
|
def RandLikeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1024, 1024).double())
|
||||||
|
|
||||||
|
# ==============================================================================
|
||||||
|
|
||||||
|
class RandLikeDtypeModule(torch.nn.Module):
|
||||||
|
def __init__(self):
|
||||||
|
super().__init__()
|
||||||
|
|
||||||
|
@export
|
||||||
|
@annotate_args([
|
||||||
|
None,
|
||||||
|
([-1, -1], torch.float64, True),
|
||||||
|
])
|
||||||
|
def forward(self, x):
|
||||||
|
a = torch.ops.aten.rand_like(x, dtype=torch.float32)
|
||||||
|
mean = torch.mean(a)
|
||||||
|
return mean
|
||||||
|
|
||||||
|
|
||||||
|
@register_test_case(module_factory=lambda: RandLikeDtypeModule())
|
||||||
|
def RandLikeDtypeModule_basic(module, tu: TestUtils):
|
||||||
|
module.forward(tu.rand(1024, 1024).double())
|
||||||
|
|
|
@ -406,7 +406,16 @@ func @torch.aten._log_softmax(%arg0: !torch.vtensor<[?,?,?],f32> loc(unknown)) -
|
||||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||||
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
||||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||||
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||||
|
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||||
|
|
||||||
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[INP]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
|
||||||
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
|
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
|
||||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||||
|
@ -437,7 +446,15 @@ func @torch.aten.bernoulli(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor
|
||||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||||
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
||||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||||
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||||
|
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||||
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1>
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?,?],i1>
|
||||||
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
|
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
|
||||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||||
|
@ -468,7 +485,15 @@ func @torch.valsem.aten.bernoulli.float(%arg0: !torch.vtensor<[?,?,?],f64>) -> !
|
||||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||||
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
||||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||||
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||||
|
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f64>
|
||||||
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB]] : !torch.vtensor<[?,?,?],f64>, !torch.vtensor<[?,?,?],f64> -> !torch.vtensor<[?,?,?],i1>
|
||||||
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
|
// CHECK: %[[INT7_2:.*]] = torch.constant.int 7
|
||||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||||
|
@ -484,6 +509,33 @@ func @torch.valsem.aten.bernoulli.Tensor(%arg0: !torch.vtensor<[?,?,?],f64>, %ar
|
||||||
return %1 : !torch.vtensor
|
return %1 : !torch.vtensor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// -----
|
||||||
|
// CHECK-LABEL: func @torch.aten.rand_like(
|
||||||
|
// CHECK-SAME: %[[INPUT:.*]]: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
|
||||||
|
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||||
|
// CHECK: %[[NONE_0:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[NONE_1:.*]] = torch.constant.none
|
||||||
|
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
||||||
|
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||||
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[INPUT]], %[[INT0]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[INPUT]], %[[INT1]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[INT2:.*]] = torch.constant.int 2
|
||||||
|
// CHECK: %[[DIM2:.*]] = torch.aten.size.int %[[INPUT]], %[[INT2]] : !torch.vtensor<[?,?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]], %[[DIM2]] : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[INT6]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]], %[[NONE_0]] : !torch.list<int>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
|
||||||
|
// CHECK: %[[UNIFORM:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_1]] : !torch.vtensor<[?,?,?],f32>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?,?],f32>
|
||||||
|
// CHECK: %[[CAST:.*]] = torch.tensor_static_info_cast %[[UNIFORM]] : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
|
||||||
|
// CHECK: return %[[CAST]] : !torch.vtensor
|
||||||
|
func @torch.aten.rand_like(%arg0: !torch.vtensor<[?,?,?],f64>) -> !torch.vtensor {
|
||||||
|
%int6 = torch.constant.int 6
|
||||||
|
%none = torch.constant.none
|
||||||
|
%0 = torch.aten.rand_like %arg0, %int6, %none, %none, %none, %none : !torch.vtensor<[?,?,?],f64>, !torch.int, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?,?],f32>
|
||||||
|
%1 = torch.tensor_static_info_cast %0 : !torch.vtensor<[?,?,?],f32> to !torch.vtensor
|
||||||
|
return %1 : !torch.vtensor
|
||||||
|
}
|
||||||
|
|
||||||
// -----
|
// -----
|
||||||
// CHECK-LABEL: func @torch.aten.select.int(
|
// CHECK-LABEL: func @torch.aten.select.int(
|
||||||
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
|
// CHECK-SAME: %[[T:.*]]: !torch.vtensor<[?,?],si64>) -> !torch.vtensor<[?],si64> {
|
||||||
|
@ -753,7 +805,13 @@ func @torch.aten.dropout$eval(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtenso
|
||||||
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
// CHECK: %[[NONE_2:.*]] = torch.constant.none
|
||||||
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
// CHECK: %[[FLOAT0:.*]] = torch.constant.float 0.000000e+00
|
||||||
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
// CHECK: %[[FLOAT1:.*]] = torch.constant.float 1.000000e+00
|
||||||
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[CON2FLOAT]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?],f64>
|
// CHECK: %[[INT0:.*]] = torch.constant.int 0
|
||||||
|
// CHECK: %[[DIM0:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT0]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[INT1:.*]] = torch.constant.int 1
|
||||||
|
// CHECK: %[[DIM1:.*]] = torch.aten.size.int %[[CON2FLOAT]], %[[INT1]] : !torch.vtensor<[?,?],f64>, !torch.int -> !torch.int
|
||||||
|
// CHECK: %[[TENSOR_SIZE:.*]] = torch.prim.ListConstruct %[[DIM0]], %[[DIM1]] : (!torch.int, !torch.int) -> !torch.list<int>
|
||||||
|
// CHECK: %[[EMPTY:.*]] = torch.aten.empty.memory_format %[[TENSOR_SIZE]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]], %[[NONE_1]] : !torch.list<int>, !torch.none, !torch.none, !torch.none, !torch.none, !torch.none -> !torch.vtensor<[?,?],f64>
|
||||||
|
// CHECK: %[[UNF:.*]] = torch.valsem.aten.uniform %[[EMPTY]], %[[FLOAT0]], %[[FLOAT1]], %[[NONE_2]] : !torch.vtensor<[?,?],f64>, !torch.float, !torch.float, !torch.none -> !torch.vtensor<[?,?],f64>
|
||||||
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?],i1>
|
// CHECK: %[[CMP:.*]] = torch.aten.lt.Tensor %[[UNF]], %[[PROB_TENSOR]] : !torch.vtensor<[?,?],f64>, !torch.vtensor<[],f64> -> !torch.vtensor<[?,?],i1>
|
||||||
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
// CHECK: %[[INT6:.*]] = torch.constant.int 6
|
||||||
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
// CHECK: %[[FALSE_2:.*]] = torch.constant.bool false
|
||||||
|
|
Loading…
Reference in New Issue