Add `torch.vtensor.literal` op.

This op is much better behaved than the `torch.tensor.literal` op
(which is the new name of the `torch.tensor` op). In particular
`torch.tensor.literal`:
- always has a maximally refined type.
- always has value semantics.
- can be constant folded / CSE'd.

ReduceOpVariants is changed to perform the transformation from
`torch.tensor.literal` to `torch.vtensor.literal` (which in general
involves static information casts and copies.

This new op also allowed tightening up `torch.tensor.literal` to only
accept NonValueTensorType (instead of any tensor type).

This new ".literal" name is more descriptive. It was getting too
confusing seeing an op called just `torch.tensor` (we originally called
it that because that's the name of the similar function in the Torch
Python API, but it just doesn't fit here).
pull/232/head
Sean Silva 2021-06-17 08:52:13 -07:00
parent 4a0eb44d17
commit 333e07a74e
20 changed files with 152 additions and 76 deletions

View File

@ -528,7 +528,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
auto loc = getCurrentLocation();
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp = createMlirOperationAtEnd(
funcBuilder->getEntryBlock(), "torch.tensor", loc,
funcBuilder->getEntryBlock(), "torch.tensor.literal", loc,
npcompTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements));

View File

@ -349,7 +349,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp =
createMlirOperationAtEnd(importBlock, "torch.tensor", loc,
createMlirOperationAtEnd(importBlock, "torch.tensor.literal", loc,
npcompTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements));

View File

@ -13,6 +13,6 @@ with mb.capture_function("arange_test", []) as f:
x = torch.arange(10)
f.returns([x])
# CHECK: %[[T:.*]] = torch.tensor(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi64>) : !torch.tensor<[10],si64>
# CHECK: %[[T:.*]] = torch.tensor.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi64>) : !torch.tensor<[10],si64>
# CHECK: return %[[T]]
mb.module.operation.print()

View File

@ -23,9 +23,9 @@ with mb.capture_function("add3", [t0, t1, t2]) as f:
# CHECK-SAME: %[[VAL_2:.*]]: !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32> {
# CHECK: %[[VAL_3:.*]] = torch.constant.int 1
# CHECK: %[[VAL_4:.*]] = torch.constant.int 1
# CHECK: %[[VAL_5:.*]] = torch.tensor(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_5:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_6:.*]] = torch.operator "aten.add.out"(%[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_5]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, !torch.int, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_7:.*]] = torch.tensor(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_7:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
# CHECK: %[[VAL_8:.*]] = torch.operator "aten.add.out"(%[[VAL_6]], %[[VAL_2]], %[[VAL_4]], %[[VAL_7]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, !torch.int, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
# CHECK: return %[[VAL_8]] : !torch.tensor<[1,2,3,4],f32>
# CHECK: }

View File

@ -43,8 +43,8 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
# CHECK: %[[VAL_8:.*]] = torch.constant.int 0
# CHECK: %[[VAL_9:.*]] = torch.constant.int 0
# CHECK: %[[VAL_10:.*]] = torch.constant.int 1
# CHECK: %[[VAL_11:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
# CHECK: %[[VAL_12:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
# CHECK: %[[VAL_11:.*]] = torch.tensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
# CHECK: %[[VAL_12:.*]] = torch.tensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
# CHECK: %[[VAL_13:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_14:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
# CHECK: %[[VAL_15:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>

View File

@ -14,7 +14,7 @@ mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
def __init__(self):
super().__init__()
# CHECK: %[[T:.*]] = torch.tensor
# CHECK: %[[T:.*]] = torch.tensor.literal
# CHECK: torch.nn_module {
# CHECK: torch.slot "t1", %[[T]]
# CHECK: torch.slot "t2", %[[T]]

View File

@ -21,9 +21,9 @@ class TestModule(torch.nn.Module):
dtype=torch.qint8)
# CHECK: %[[SCALE:.*]] = torch.constant.float
# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
# CHECK: %[[INT_REPR:.*]] = torch.tensor({{.*}}) : !torch.tensor<[2,5],si8>
# CHECK: %[[INT_REPR:.*]] = torch.tensor.literal({{.*}}) : !torch.tensor<[2,5],si8>
# CHECK: %[[WEIGHTS:.*]] = torch.per_tensor_affine.create %[[INT_REPR]], %[[SCALE]], %[[ZERO_POINT]] : !torch.tensor<[2,5],si8>, !torch.float, !torch.int -> !torch.tensor<[2,5],!torch.qint8>
# CHECK: %[[BIAS:.*]] = torch.tensor({{.*}}) : !torch.tensor<[2],f32>
# CHECK: %[[BIAS:.*]] = torch.tensor.literal({{.*}}) : !torch.tensor<[2],f32>
# CHECK: %[[LINEAR_PARAMS:.*]] = torch.linear_params.create %[[WEIGHTS]], %[[BIAS]] : !torch.tensor<[2,5],!torch.qint8>, !torch.tensor<[2],f32>
@torch.jit.export
def test_linear(self, t):

View File

@ -18,8 +18,8 @@ class TestModule(torch.nn.Module):
self.ones = torch.ones(1)
self.arange = torch.nn.Parameter(torch.arange(3.0))
# CHECK: %[[ARANGE:.*]] = torch.tensor(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.tensor<[3],f32>
# CHECK: %[[ONES:.*]] = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor<[1],f32>
# CHECK: %[[ARANGE:.*]] = torch.tensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.tensor<[3],f32>
# CHECK: %[[ONES:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor<[1],f32>
# CHECK: %[[ROOT:.*]] = torch.nn_module {
# CHECK: torch.slot "arange", %[[ARANGE]] : !torch.tensor<[3],f32>
# CHECK: torch.slot "ones", %[[ONES]] : !torch.tensor<[1],f32>

View File

@ -832,32 +832,27 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
}];
}
// TODO: Disaggregate this op into a value-semantic constant + val->nonval
// conversion if needed.
// Currently, this op can effectively hide val->nonval conversion, which makes
// it an edge case for passes that care about that such as
// torch-maximize-value-semantics.
// So the suggestion would be to lower this to a `torch.vtensor` op
// (+`torch.copy.tensor` if needed).
// In particular, currently we end up relying on convert-torch-to-std
// to effectively expose this (as part of lowering to `std.constant`) +
// hoping that some canonicalization cleans it up.
// The `torch-maximize-value-semantics` pass should be doing this
// before we convert to std at all.
def Torch_TensorOp : Torch_Op<"tensor", [
def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [
DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>,
AllowsTypeRefinement
AllowsTypeRefinement,
]> {
let summary = "Create a value of !torch.tensor type from a literal";
let description = [{
Example:
```
%0 = torch.tensor(dense<0.0> : tensor<3x5xf32>) : !torch.tensor
%1 = torch.tensor(dense<0.0> : tensor<3xf32>) : !torch.vtensor<[3],f32>
%0 = torch.tensor.literal(dense<0.0> : tensor<3x5xf32>) : !torch.tensor
%1 = torch.tensor.literal(dense<0.0> : tensor<3xf32>) : !torch.tensor<[3],f32>
```
This op covers a typical frontend use case of creating a type-erased
`!torch.tensor`. Inside the compiler, we decompose it into
`torch.vtensor.literal` which is easier to analyze and transform.
Note: This op is not called "constant" because the created tensor is not
"constant" in any meaning of that word.
}];
let arguments = (ins ElementsAttr:$value);
let results = (outs AnyTorchTensorType:$result);
let results = (outs Torch_NonValueTensorType:$result);
let assemblyFormat = [{
`(` $value `)` attr-dict `:` type($result)
@ -869,6 +864,35 @@ def Torch_TensorOp : Torch_Op<"tensor", [
}];
}
def Torch_ValueTensorLiteralOp : Torch_Op<"vtensor.literal", [
DeclareOpInterfaceMethods<InferTypeOpInterface>,
ConstantLike,
NoSideEffect,
]> {
let summary = "Create a value of !torch.vtensor type from a literal";
let description = [{
Example:
```
%0 = torch.vtensor.literal(dense<0.0> : tensor<3x5xf32>) : !torch.vtensor<[3,5],f32>
%1 = torch.vtensor.literal(dense<0.0> : tensor<3xf32>) : !torch.vtensor<[3],f32>
```
Unlike `torch.tensor.literal`, which covers a typical frontend use case
and allows type refinement, this op always has a maximally resolved type
(which is always possible, because it is created from a literal). This
has a stronger set of invariants that better fit the needs of the
compiler internals.
}];
let arguments = (ins ElementsAttr:$value);
let results = (outs Torch_ValueTensorType:$result);
let assemblyFormat = [{
`(` $value `)` attr-dict `:` type($result)
}];
let hasFolder = 1;
}
def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
DeclareOpInterfaceMethods<CastOpInterface>,
AllowsTypeRefinement,

View File

@ -70,14 +70,19 @@ public:
};
} // namespace
LogicalResult convertTensorOp(TensorOp op, PatternRewriter &rewriter) {
auto constant = rewriter.create<ConstantOp>(op->getLoc(), op.value());
auto vtensor = rewriter.create<FromBuiltinTensorOp>(op->getLoc(), constant);
Value result = copyTensorToType(rewriter, op->getLoc(),
op.getType().cast<BaseTensorType>(), vtensor);
rewriter.replaceOp(op, {result});
namespace {
class ConvertValueTensorLiteralOp
: public OpConversionPattern<ValueTensorLiteralOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(ValueTensorLiteralOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.value());
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
@ -106,8 +111,8 @@ public:
patterns.add<ConvertAtenNeIntOp>(typeConverter, context);
target.addIllegalOp<AtenGtIntOp>();
patterns.add<ConvertAtenGtIntOp>(typeConverter, context);
target.addIllegalOp<TensorOp>();
patterns.add(convertTensorOp);
target.addIllegalOp<ValueTensorLiteralOp>();
patterns.add<ConvertValueTensorLiteralOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))
return signalPassFailure();

View File

@ -136,5 +136,12 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
if (auto stringAttr = value.dyn_cast<StringAttr>())
return builder.create<ConstantStrOp>(loc, stringAttr);
if (auto elementsAttr = value.dyn_cast<ElementsAttr>()) {
// Only !torch.vtensor can be constant folded. !torch.tensor has
// non-trivial aliasing semantics which prevent deduplicating it.
assert(type.isa<ValueTensorType>() && "should be a vtensor type!");
return builder.create<ValueTensorLiteralOp>(loc, elementsAttr);
}
return nullptr;
}

View File

@ -438,13 +438,12 @@ OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
}
//===----------------------------------------------------------------------===//
// TensorOp
// NonValueTensorLiteralOp
//===----------------------------------------------------------------------===//
LogicalResult
TensorOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
ValueRange operands, DictionaryAttr attributes,
RegionRange regions,
LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
if (!attr)
@ -466,13 +465,34 @@ static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
return true;
}
bool TensorOp::isCompatibleReturnTypes(TypeRange inferred, TypeRange actual) {
bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
TypeRange actual) {
if (!actual[0].isa<BaseTensorType>())
return false;
return areSizesAndDtypesCompatible(inferred[0].cast<BaseTensorType>(),
actual[0].cast<BaseTensorType>());
}
//===----------------------------------------------------------------------===//
// ValueTensorLiteralOp
//===----------------------------------------------------------------------===//
LogicalResult ValueTensorLiteralOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
if (!attr)
return failure();
auto tensorType = attr.getType().cast<RankedTensorType>();
inferredReturnTypes.push_back(ValueTensorType::getFromShaped(tensorType));
return success();
}
OpFoldResult ValueTensorLiteralOp::fold(ArrayRef<Attribute> operands) {
return valueAttr();
}
//----------------------------------------------------------------------------//
// TensorStaticInfoCast
//----------------------------------------------------------------------------//

View File

@ -99,6 +99,17 @@ public:
};
} // namespace
static LogicalResult
reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
PatternRewriter &rewriter) {
Value valueTensor =
rewriter.create<ValueTensorLiteralOp>(op->getLoc(), op.value());
Value tensor =
copyTensorToType(rewriter, op->getLoc(), op.getType(), valueTensor);
rewriter.replaceOp(op, {tensor});
return success();
}
namespace {
class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
void runOnOperation() override {
@ -106,8 +117,10 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
RewritePatternSet patterns(context);
patterns.add<ConvertToImmutableTensors>(context);
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
ConversionTarget target(*context);
target.addIllegalOp<NonValueTensorLiteralOp>();
target.markUnknownOpDynamicallyLegal([](Operation *op) {
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
auto hasValueSemantics = [](Type t) {

View File

@ -39,21 +39,11 @@ func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
return %0 : !torch.bool
}
// CHECK-LABEL: func @torch.tensor$value() -> !torch.vtensor<[],f32> {
// CHECK-LABEL: func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VTENSOR:.*]] = torch.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32>
func @torch.tensor$value() -> !torch.vtensor<[],f32> {
%0 = torch.tensor(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
return %0 : !torch.vtensor<[],f32>
}
// CHECK-LABEL: func @torch.tensor$nonval() -> !torch.tensor<[],f32> {
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32>
// CHECK: %[[VTENSOR:.*]] = torch.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
// CHECK: %[[NONVAL:.*]] = torch.copy.tensor %[[VTENSOR]] : !torch.vtensor<[],f32> -> !torch.tensor<[],f32>
// CHECK: return %[[NONVAL]] : !torch.tensor<[],f32>
func @torch.tensor$nonval() -> !torch.tensor<[],f32> {
%0 = torch.tensor(dense<0.0> : tensor<f32>) : !torch.tensor<[],f32>
return %0 : !torch.tensor<[],f32>
}

View File

@ -18,7 +18,7 @@
// CHECK: }
// CHECK-LABEL: torch.global_slot @t : !torch.tensor {
// CHECK: %[[T:.*]] = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
// CHECK: torch.global_slot.init %[[T]] : !torch.tensor
// CHECK: }
@ -32,7 +32,7 @@ torch.class_type @c {
%bool_true = torch.constant.bool true
%i = torch.constant.int 3
%f = torch.constant.float 4.250000e+01
%t = torch.tensor(dense<1.0> : tensor<1xf32>) : !torch.tensor
%t = torch.tensor.literal(dense<1.0> : tensor<1xf32>) : !torch.tensor
torch.nn_module {
torch.slot "b", %bool_true : !torch.bool
torch.slot "i", %i : !torch.int

View File

@ -37,7 +37,7 @@ torch.class_type @c {
}
// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}}
%t = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
torch.nn_module {
torch.slot "t1", %t : !torch.tensor
torch.slot "t2", %t : !torch.tensor

View File

@ -2,24 +2,24 @@
// CHECK-NOT: @readonly
torch.global_slot "private" @readonly : !torch.tensor {
%0 = torch.tensor(dense<0.0> : tensor<1xf32>) : !torch.tensor
%0 = torch.tensor.literal(dense<0.0> : tensor<1xf32>) : !torch.tensor
torch.global_slot.init %0 : !torch.tensor
}
// CHECK-LABEL: torch.global_slot @public
torch.global_slot @public : !torch.tensor {
%0 = torch.tensor(dense<0.0> : tensor<2xf32>) : !torch.tensor
%0 = torch.tensor.literal(dense<0.0> : tensor<2xf32>) : !torch.tensor
torch.global_slot.init %0 : !torch.tensor
}
// CHECK-LABEL: torch.global_slot "private" @mutated
torch.global_slot "private" @mutated : !torch.tensor {
%0 = torch.tensor(dense<0.0> : tensor<3xf32>) : !torch.tensor
%0 = torch.tensor.literal(dense<0.0> : tensor<3xf32>) : !torch.tensor
torch.global_slot.init %0 : !torch.tensor
}
// CHECK-LABEL: func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
// Inlined.
// CHECK: %[[READONLY:.*]] = torch.tensor(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor
// CHECK: %[[READONLY:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor
%0 = torch.global_slot.get @readonly : !torch.tensor
// Not inlined: potentially mutated by externals.

View File

@ -131,7 +131,7 @@ func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
func @torch.tensor() {
// Incompatible shape.
// expected-error@+1 {{incompatible}}
%0 = torch.tensor(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
return
}
@ -140,7 +140,7 @@ func @torch.tensor() {
func @torch.tensor() {
// Incompatible dtype.
// expected-error@+1 {{incompatible}}
%0 = torch.tensor(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
return
}
@ -149,7 +149,7 @@ func @torch.tensor() {
func @torch.tensor() {
// Incompatible type.
// expected-error@+1 {{incompatible}}
%0 = torch.tensor(dense<42.0> : tensor<f32>) : i1
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1
return
}

View File

@ -48,12 +48,19 @@ func private @tuple.one_element() -> !torch.tuple<!torch.tensor>
// CHECK: @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
func private @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
// CHECK-LABEL: func @torch.tensor() {
func @torch.tensor() {
// CHECK: torch.tensor(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
%0 = torch.tensor(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
// CHECK: torch.tensor(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
%1 = torch.tensor(dense<42.0> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
// CHECK-LABEL: func @torch.tensor.literal() {
func @torch.tensor.literal() {
// CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.tensor
// CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
%1 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
return
}
// CHECK-LABEL: func @torch.vtensor.literal() {
func @torch.vtensor.literal() {
// CHECK: torch.vtensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
%0 = torch.vtensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
return
}
@ -81,7 +88,7 @@ func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
%int3 = torch.constant.int 3
// CHECK: %float = torch.constant.float 4.250000e+01
%float = torch.constant.float 4.250000e+01
%tensor = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
%tensor = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
// CHECK: %none = torch.constant.none
%none = torch.constant.none
// CHECK: %str = torch.constant.str "some str"

View File

@ -31,3 +31,13 @@ func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32>
return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
}
// CHECK-LABEL: func @torch.tensor.literal() -> !torch.tensor {
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
// CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor
// CHECK: %[[TENSOR:.*]] = torch.copy.tensor %[[SIZES_ERASED]] : !torch.vtensor -> !torch.tensor
// CHECK: return %[[TENSOR]] : !torch.tensor
func @torch.tensor.literal() -> !torch.tensor {
%0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor
return %0 : !torch.tensor
}