mirror of https://github.com/llvm/torch-mlir
Add `torch.vtensor.literal` op.
This op is much better behaved than the `torch.tensor.literal` op (which is the new name of the `torch.tensor` op). In particular `torch.tensor.literal`: - always has a maximally refined type. - always has value semantics. - can be constant folded / CSE'd. ReduceOpVariants is changed to perform the transformation from `torch.tensor.literal` to `torch.vtensor.literal` (which in general involves static information casts and copies. This new op also allowed tightening up `torch.tensor.literal` to only accept NonValueTensorType (instead of any tensor type). This new ".literal" name is more descriptive. It was getting too confusing seeing an op called just `torch.tensor` (we originally called it that because that's the name of the similar function in the Torch Python API, but it just doesn't fit here).pull/232/head
parent
4a0eb44d17
commit
333e07a74e
|
@ -528,7 +528,7 @@ MlirValue AcapController::importTensorByValue(at::Tensor tensor) {
|
|||
auto loc = getCurrentLocation();
|
||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||
MlirOperation tensorOp = createMlirOperationAtEnd(
|
||||
funcBuilder->getEntryBlock(), "torch.tensor", loc,
|
||||
funcBuilder->getEntryBlock(), "torch.tensor.literal", loc,
|
||||
npcompTorchNonValueTensorTypeGetFromShaped(
|
||||
mlirAttributeGetType(denseElements)),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
|
|
|
@ -349,7 +349,7 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
|
|||
at::Tensor tensor = ivalue.toTensor().contiguous();
|
||||
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
|
||||
MlirOperation tensorOp =
|
||||
createMlirOperationAtEnd(importBlock, "torch.tensor", loc,
|
||||
createMlirOperationAtEnd(importBlock, "torch.tensor.literal", loc,
|
||||
npcompTorchNonValueTensorTypeGetFromShaped(
|
||||
mlirAttributeGetType(denseElements)),
|
||||
toMlirNamedAttribute("value", denseElements));
|
||||
|
|
|
@ -13,6 +13,6 @@ with mb.capture_function("arange_test", []) as f:
|
|||
x = torch.arange(10)
|
||||
f.returns([x])
|
||||
|
||||
# CHECK: %[[T:.*]] = torch.tensor(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi64>) : !torch.tensor<[10],si64>
|
||||
# CHECK: %[[T:.*]] = torch.tensor.literal(dense<[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]> : tensor<10xsi64>) : !torch.tensor<[10],si64>
|
||||
# CHECK: return %[[T]]
|
||||
mb.module.operation.print()
|
||||
|
|
|
@ -23,9 +23,9 @@ with mb.capture_function("add3", [t0, t1, t2]) as f:
|
|||
# CHECK-SAME: %[[VAL_2:.*]]: !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32> {
|
||||
# CHECK: %[[VAL_3:.*]] = torch.constant.int 1
|
||||
# CHECK: %[[VAL_4:.*]] = torch.constant.int 1
|
||||
# CHECK: %[[VAL_5:.*]] = torch.tensor(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: %[[VAL_5:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: %[[VAL_6:.*]] = torch.operator "aten.add.out"(%[[VAL_0]], %[[VAL_1]], %[[VAL_3]], %[[VAL_5]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, !torch.int, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: %[[VAL_7:.*]] = torch.tensor(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: %[[VAL_7:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1x2x3x4xf32>) : !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: %[[VAL_8:.*]] = torch.operator "aten.add.out"(%[[VAL_6]], %[[VAL_2]], %[[VAL_4]], %[[VAL_7]]) : (!torch.tensor<[1,2,3,4],f32>, !torch.tensor<[1,2,3,4],f32>, !torch.int, !torch.tensor<[1,2,3,4],f32>) -> !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: return %[[VAL_8]] : !torch.tensor<[1,2,3,4],f32>
|
||||
# CHECK: }
|
||||
|
|
|
@ -43,8 +43,8 @@ with mb.capture_function("conv2d_fwd", [tensor]) as f:
|
|||
# CHECK: %[[VAL_8:.*]] = torch.constant.int 0
|
||||
# CHECK: %[[VAL_9:.*]] = torch.constant.int 0
|
||||
# CHECK: %[[VAL_10:.*]] = torch.constant.int 1
|
||||
# CHECK: %[[VAL_11:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
|
||||
# CHECK: %[[VAL_12:.*]] = torch.tensor(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
|
||||
# CHECK: %[[VAL_11:.*]] = torch.tensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<4x16x3x3xf32>) : !torch.tensor<[4,16,3,3],f32>
|
||||
# CHECK: %[[VAL_12:.*]] = torch.tensor.literal(opaque<"_", "0xDEADBEEF"> : tensor<4xf32>) : !torch.tensor<[4],f32>
|
||||
# CHECK: %[[VAL_13:.*]] = torch.prim.ListConstruct %[[VAL_1]], %[[VAL_2]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
# CHECK: %[[VAL_14:.*]] = torch.prim.ListConstruct %[[VAL_3]], %[[VAL_4]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
# CHECK: %[[VAL_15:.*]] = torch.prim.ListConstruct %[[VAL_5]], %[[VAL_6]] : (!torch.int, !torch.int) -> !torch.list<!torch.int>
|
||||
|
|
|
@ -14,7 +14,7 @@ mb = torch_mlir.ModuleBuilder()
|
|||
class TestModule(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
# CHECK: %[[T:.*]] = torch.tensor
|
||||
# CHECK: %[[T:.*]] = torch.tensor.literal
|
||||
# CHECK: torch.nn_module {
|
||||
# CHECK: torch.slot "t1", %[[T]]
|
||||
# CHECK: torch.slot "t2", %[[T]]
|
||||
|
|
|
@ -21,9 +21,9 @@ class TestModule(torch.nn.Module):
|
|||
dtype=torch.qint8)
|
||||
# CHECK: %[[SCALE:.*]] = torch.constant.float
|
||||
# CHECK: %[[ZERO_POINT:.*]] = torch.constant.int 0
|
||||
# CHECK: %[[INT_REPR:.*]] = torch.tensor({{.*}}) : !torch.tensor<[2,5],si8>
|
||||
# CHECK: %[[INT_REPR:.*]] = torch.tensor.literal({{.*}}) : !torch.tensor<[2,5],si8>
|
||||
# CHECK: %[[WEIGHTS:.*]] = torch.per_tensor_affine.create %[[INT_REPR]], %[[SCALE]], %[[ZERO_POINT]] : !torch.tensor<[2,5],si8>, !torch.float, !torch.int -> !torch.tensor<[2,5],!torch.qint8>
|
||||
# CHECK: %[[BIAS:.*]] = torch.tensor({{.*}}) : !torch.tensor<[2],f32>
|
||||
# CHECK: %[[BIAS:.*]] = torch.tensor.literal({{.*}}) : !torch.tensor<[2],f32>
|
||||
# CHECK: %[[LINEAR_PARAMS:.*]] = torch.linear_params.create %[[WEIGHTS]], %[[BIAS]] : !torch.tensor<[2,5],!torch.qint8>, !torch.tensor<[2],f32>
|
||||
@torch.jit.export
|
||||
def test_linear(self, t):
|
||||
|
|
|
@ -18,8 +18,8 @@ class TestModule(torch.nn.Module):
|
|||
self.ones = torch.ones(1)
|
||||
self.arange = torch.nn.Parameter(torch.arange(3.0))
|
||||
|
||||
# CHECK: %[[ARANGE:.*]] = torch.tensor(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.tensor<[3],f32>
|
||||
# CHECK: %[[ONES:.*]] = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor<[1],f32>
|
||||
# CHECK: %[[ARANGE:.*]] = torch.tensor.literal(dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]> : tensor<3xf32>) : !torch.tensor<[3],f32>
|
||||
# CHECK: %[[ONES:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor<[1],f32>
|
||||
# CHECK: %[[ROOT:.*]] = torch.nn_module {
|
||||
# CHECK: torch.slot "arange", %[[ARANGE]] : !torch.tensor<[3],f32>
|
||||
# CHECK: torch.slot "ones", %[[ONES]] : !torch.tensor<[1],f32>
|
||||
|
|
|
@ -832,32 +832,27 @@ def Torch_PerTensorAffineCreateOp : Torch_Op<"per_tensor_affine.create", [
|
|||
}];
|
||||
}
|
||||
|
||||
// TODO: Disaggregate this op into a value-semantic constant + val->nonval
|
||||
// conversion if needed.
|
||||
// Currently, this op can effectively hide val->nonval conversion, which makes
|
||||
// it an edge case for passes that care about that such as
|
||||
// torch-maximize-value-semantics.
|
||||
// So the suggestion would be to lower this to a `torch.vtensor` op
|
||||
// (+`torch.copy.tensor` if needed).
|
||||
// In particular, currently we end up relying on convert-torch-to-std
|
||||
// to effectively expose this (as part of lowering to `std.constant`) +
|
||||
// hoping that some canonicalization cleans it up.
|
||||
// The `torch-maximize-value-semantics` pass should be doing this
|
||||
// before we convert to std at all.
|
||||
def Torch_TensorOp : Torch_Op<"tensor", [
|
||||
def Torch_NonValueTensorLiteralOp : Torch_Op<"tensor.literal", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface, ["isCompatibleReturnTypes"]>,
|
||||
AllowsTypeRefinement
|
||||
AllowsTypeRefinement,
|
||||
]> {
|
||||
let summary = "Create a value of !torch.tensor type from a literal";
|
||||
let description = [{
|
||||
Example:
|
||||
```
|
||||
%0 = torch.tensor(dense<0.0> : tensor<3x5xf32>) : !torch.tensor
|
||||
%1 = torch.tensor(dense<0.0> : tensor<3xf32>) : !torch.vtensor<[3],f32>
|
||||
%0 = torch.tensor.literal(dense<0.0> : tensor<3x5xf32>) : !torch.tensor
|
||||
%1 = torch.tensor.literal(dense<0.0> : tensor<3xf32>) : !torch.tensor<[3],f32>
|
||||
```
|
||||
|
||||
This op covers a typical frontend use case of creating a type-erased
|
||||
`!torch.tensor`. Inside the compiler, we decompose it into
|
||||
`torch.vtensor.literal` which is easier to analyze and transform.
|
||||
|
||||
Note: This op is not called "constant" because the created tensor is not
|
||||
"constant" in any meaning of that word.
|
||||
}];
|
||||
let arguments = (ins ElementsAttr:$value);
|
||||
let results = (outs AnyTorchTensorType:$result);
|
||||
let results = (outs Torch_NonValueTensorType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $value `)` attr-dict `:` type($result)
|
||||
|
@ -869,6 +864,35 @@ def Torch_TensorOp : Torch_Op<"tensor", [
|
|||
}];
|
||||
}
|
||||
|
||||
def Torch_ValueTensorLiteralOp : Torch_Op<"vtensor.literal", [
|
||||
DeclareOpInterfaceMethods<InferTypeOpInterface>,
|
||||
ConstantLike,
|
||||
NoSideEffect,
|
||||
]> {
|
||||
let summary = "Create a value of !torch.vtensor type from a literal";
|
||||
let description = [{
|
||||
Example:
|
||||
```
|
||||
%0 = torch.vtensor.literal(dense<0.0> : tensor<3x5xf32>) : !torch.vtensor<[3,5],f32>
|
||||
%1 = torch.vtensor.literal(dense<0.0> : tensor<3xf32>) : !torch.vtensor<[3],f32>
|
||||
```
|
||||
|
||||
Unlike `torch.tensor.literal`, which covers a typical frontend use case
|
||||
and allows type refinement, this op always has a maximally resolved type
|
||||
(which is always possible, because it is created from a literal). This
|
||||
has a stronger set of invariants that better fit the needs of the
|
||||
compiler internals.
|
||||
}];
|
||||
let arguments = (ins ElementsAttr:$value);
|
||||
let results = (outs Torch_ValueTensorType:$result);
|
||||
|
||||
let assemblyFormat = [{
|
||||
`(` $value `)` attr-dict `:` type($result)
|
||||
}];
|
||||
|
||||
let hasFolder = 1;
|
||||
}
|
||||
|
||||
def Torch_TensorStaticInfoCastOp : Torch_Op<"tensor_static_info_cast", [
|
||||
DeclareOpInterfaceMethods<CastOpInterface>,
|
||||
AllowsTypeRefinement,
|
||||
|
|
|
@ -70,14 +70,19 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
LogicalResult convertTensorOp(TensorOp op, PatternRewriter &rewriter) {
|
||||
auto constant = rewriter.create<ConstantOp>(op->getLoc(), op.value());
|
||||
auto vtensor = rewriter.create<FromBuiltinTensorOp>(op->getLoc(), constant);
|
||||
Value result = copyTensorToType(rewriter, op->getLoc(),
|
||||
op.getType().cast<BaseTensorType>(), vtensor);
|
||||
rewriter.replaceOp(op, {result});
|
||||
namespace {
|
||||
class ConvertValueTensorLiteralOp
|
||||
: public OpConversionPattern<ValueTensorLiteralOp> {
|
||||
public:
|
||||
using OpConversionPattern::OpConversionPattern;
|
||||
LogicalResult
|
||||
matchAndRewrite(ValueTensorLiteralOp op, ArrayRef<Value> operands,
|
||||
ConversionPatternRewriter &rewriter) const override {
|
||||
rewriter.replaceOpWithNewOp<mlir::ConstantOp>(op, op.value());
|
||||
return success();
|
||||
}
|
||||
};
|
||||
} // namespace
|
||||
|
||||
// -----------------------------------------------------------------------------
|
||||
// The pass
|
||||
|
@ -106,8 +111,8 @@ public:
|
|||
patterns.add<ConvertAtenNeIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<AtenGtIntOp>();
|
||||
patterns.add<ConvertAtenGtIntOp>(typeConverter, context);
|
||||
target.addIllegalOp<TensorOp>();
|
||||
patterns.add(convertTensorOp);
|
||||
target.addIllegalOp<ValueTensorLiteralOp>();
|
||||
patterns.add<ConvertValueTensorLiteralOp>(typeConverter, context);
|
||||
if (failed(applyPartialConversion(getOperation(), target,
|
||||
std::move(patterns))))
|
||||
return signalPassFailure();
|
||||
|
|
|
@ -136,5 +136,12 @@ Operation *TorchDialect::materializeConstant(OpBuilder &builder,
|
|||
if (auto stringAttr = value.dyn_cast<StringAttr>())
|
||||
return builder.create<ConstantStrOp>(loc, stringAttr);
|
||||
|
||||
if (auto elementsAttr = value.dyn_cast<ElementsAttr>()) {
|
||||
// Only !torch.vtensor can be constant folded. !torch.tensor has
|
||||
// non-trivial aliasing semantics which prevent deduplicating it.
|
||||
assert(type.isa<ValueTensorType>() && "should be a vtensor type!");
|
||||
return builder.create<ValueTensorLiteralOp>(loc, elementsAttr);
|
||||
}
|
||||
|
||||
return nullptr;
|
||||
}
|
||||
|
|
|
@ -438,13 +438,12 @@ OpFoldResult AtenNeIntOp::fold(ArrayRef<Attribute> operands) {
|
|||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// TensorOp
|
||||
// NonValueTensorLiteralOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult
|
||||
TensorOp::inferReturnTypes(MLIRContext *context, Optional<Location> location,
|
||||
ValueRange operands, DictionaryAttr attributes,
|
||||
RegionRange regions,
|
||||
LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
|
||||
if (!attr)
|
||||
|
@ -466,13 +465,34 @@ static bool areSizesAndDtypesCompatible(BaseTensorType a, BaseTensorType b) {
|
|||
return true;
|
||||
}
|
||||
|
||||
bool TensorOp::isCompatibleReturnTypes(TypeRange inferred, TypeRange actual) {
|
||||
bool NonValueTensorLiteralOp::isCompatibleReturnTypes(TypeRange inferred,
|
||||
TypeRange actual) {
|
||||
if (!actual[0].isa<BaseTensorType>())
|
||||
return false;
|
||||
return areSizesAndDtypesCompatible(inferred[0].cast<BaseTensorType>(),
|
||||
actual[0].cast<BaseTensorType>());
|
||||
}
|
||||
|
||||
//===----------------------------------------------------------------------===//
|
||||
// ValueTensorLiteralOp
|
||||
//===----------------------------------------------------------------------===//
|
||||
|
||||
LogicalResult ValueTensorLiteralOp::inferReturnTypes(
|
||||
MLIRContext *context, Optional<Location> location, ValueRange operands,
|
||||
DictionaryAttr attributes, RegionRange regions,
|
||||
SmallVectorImpl<Type> &inferredReturnTypes) {
|
||||
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
|
||||
if (!attr)
|
||||
return failure();
|
||||
auto tensorType = attr.getType().cast<RankedTensorType>();
|
||||
inferredReturnTypes.push_back(ValueTensorType::getFromShaped(tensorType));
|
||||
return success();
|
||||
}
|
||||
|
||||
OpFoldResult ValueTensorLiteralOp::fold(ArrayRef<Attribute> operands) {
|
||||
return valueAttr();
|
||||
}
|
||||
|
||||
//----------------------------------------------------------------------------//
|
||||
// TensorStaticInfoCast
|
||||
//----------------------------------------------------------------------------//
|
||||
|
|
|
@ -99,6 +99,17 @@ public:
|
|||
};
|
||||
} // namespace
|
||||
|
||||
static LogicalResult
|
||||
reduceNonValueTensorLiteralOpToValueTensorLiteralOp(NonValueTensorLiteralOp op,
|
||||
PatternRewriter &rewriter) {
|
||||
Value valueTensor =
|
||||
rewriter.create<ValueTensorLiteralOp>(op->getLoc(), op.value());
|
||||
Value tensor =
|
||||
copyTensorToType(rewriter, op->getLoc(), op.getType(), valueTensor);
|
||||
rewriter.replaceOp(op, {tensor});
|
||||
return success();
|
||||
}
|
||||
|
||||
namespace {
|
||||
class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
||||
void runOnOperation() override {
|
||||
|
@ -106,8 +117,10 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
|
|||
RewritePatternSet patterns(context);
|
||||
patterns.add<ConvertToImmutableTensors>(context);
|
||||
patterns.add<ReduceTrailingUnderscoreInplaceVariant>(context);
|
||||
patterns.add(reduceNonValueTensorLiteralOpToValueTensorLiteralOp);
|
||||
|
||||
ConversionTarget target(*context);
|
||||
target.addIllegalOp<NonValueTensorLiteralOp>();
|
||||
target.markUnknownOpDynamicallyLegal([](Operation *op) {
|
||||
if (op->hasTrait<Torch::OpTrait::HasValueSemantics>()) {
|
||||
auto hasValueSemantics = [](Type t) {
|
||||
|
|
|
@ -39,21 +39,11 @@ func @torch.aten.gt.int(%arg0: !torch.int, %arg1: !torch.int) -> !torch.bool {
|
|||
return %0 : !torch.bool
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.tensor$value() -> !torch.vtensor<[],f32> {
|
||||
// CHECK-LABEL: func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
||||
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VTENSOR:.*]] = torch.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: return %[[VTENSOR]] : !torch.vtensor<[],f32>
|
||||
func @torch.tensor$value() -> !torch.vtensor<[],f32> {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
|
||||
func @torch.vtensor.literal() -> !torch.vtensor<[],f32> {
|
||||
%0 = torch.vtensor.literal(dense<0.0> : tensor<f32>) : !torch.vtensor<[],f32>
|
||||
return %0 : !torch.vtensor<[],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.tensor$nonval() -> !torch.tensor<[],f32> {
|
||||
// CHECK: %[[CST:.*]] = constant dense<0.000000e+00> : tensor<f32>
|
||||
// CHECK: %[[VTENSOR:.*]] = torch.from_builtin_tensor %[[CST]] : tensor<f32> -> !torch.vtensor<[],f32>
|
||||
// CHECK: %[[NONVAL:.*]] = torch.copy.tensor %[[VTENSOR]] : !torch.vtensor<[],f32> -> !torch.tensor<[],f32>
|
||||
// CHECK: return %[[NONVAL]] : !torch.tensor<[],f32>
|
||||
func @torch.tensor$nonval() -> !torch.tensor<[],f32> {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<f32>) : !torch.tensor<[],f32>
|
||||
return %0 : !torch.tensor<[],f32>
|
||||
}
|
||||
|
|
|
@ -18,7 +18,7 @@
|
|||
// CHECK: }
|
||||
|
||||
// CHECK-LABEL: torch.global_slot @t : !torch.tensor {
|
||||
// CHECK: %[[T:.*]] = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
// CHECK: %[[T:.*]] = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
// CHECK: torch.global_slot.init %[[T]] : !torch.tensor
|
||||
// CHECK: }
|
||||
|
||||
|
@ -32,7 +32,7 @@ torch.class_type @c {
|
|||
%bool_true = torch.constant.bool true
|
||||
%i = torch.constant.int 3
|
||||
%f = torch.constant.float 4.250000e+01
|
||||
%t = torch.tensor(dense<1.0> : tensor<1xf32>) : !torch.tensor
|
||||
%t = torch.tensor.literal(dense<1.0> : tensor<1xf32>) : !torch.tensor
|
||||
torch.nn_module {
|
||||
torch.slot "b", %bool_true : !torch.bool
|
||||
torch.slot "i", %i : !torch.int
|
||||
|
|
|
@ -37,7 +37,7 @@ torch.class_type @c {
|
|||
}
|
||||
|
||||
// expected-error @+1 {{potentially-aliased value used to initialize multiple slots}}
|
||||
%t = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
%t = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
torch.nn_module {
|
||||
torch.slot "t1", %t : !torch.tensor
|
||||
torch.slot "t2", %t : !torch.tensor
|
||||
|
|
|
@ -2,24 +2,24 @@
|
|||
|
||||
// CHECK-NOT: @readonly
|
||||
torch.global_slot "private" @readonly : !torch.tensor {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<1xf32>) : !torch.tensor
|
||||
%0 = torch.tensor.literal(dense<0.0> : tensor<1xf32>) : !torch.tensor
|
||||
torch.global_slot.init %0 : !torch.tensor
|
||||
}
|
||||
// CHECK-LABEL: torch.global_slot @public
|
||||
torch.global_slot @public : !torch.tensor {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<2xf32>) : !torch.tensor
|
||||
%0 = torch.tensor.literal(dense<0.0> : tensor<2xf32>) : !torch.tensor
|
||||
torch.global_slot.init %0 : !torch.tensor
|
||||
}
|
||||
// CHECK-LABEL: torch.global_slot "private" @mutated
|
||||
torch.global_slot "private" @mutated : !torch.tensor {
|
||||
%0 = torch.tensor(dense<0.0> : tensor<3xf32>) : !torch.tensor
|
||||
%0 = torch.tensor.literal(dense<0.0> : tensor<3xf32>) : !torch.tensor
|
||||
torch.global_slot.init %0 : !torch.tensor
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
|
||||
func @forward() -> (!torch.tensor, !torch.tensor, !torch.tensor) {
|
||||
// Inlined.
|
||||
// CHECK: %[[READONLY:.*]] = torch.tensor(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
// CHECK: %[[READONLY:.*]] = torch.tensor.literal(dense<0.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
%0 = torch.global_slot.get @readonly : !torch.tensor
|
||||
|
||||
// Not inlined: potentially mutated by externals.
|
||||
|
|
|
@ -131,7 +131,7 @@ func private @tensor.invalid_dtype() -> !torch.tensor<*,tuple<>>
|
|||
func @torch.tensor() {
|
||||
// Incompatible shape.
|
||||
// expected-error@+1 {{incompatible}}
|
||||
%0 = torch.tensor(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
|
||||
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[],f32>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -140,7 +140,7 @@ func @torch.tensor() {
|
|||
func @torch.tensor() {
|
||||
// Incompatible dtype.
|
||||
// expected-error@+1 {{incompatible}}
|
||||
%0 = torch.tensor(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
|
||||
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : !torch.vtensor<[],f64>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -149,7 +149,7 @@ func @torch.tensor() {
|
|||
func @torch.tensor() {
|
||||
// Incompatible type.
|
||||
// expected-error@+1 {{incompatible}}
|
||||
%0 = torch.tensor(dense<42.0> : tensor<f32>) : i1
|
||||
%0 = torch.tensor.literal(dense<42.0> : tensor<f32>) : i1
|
||||
return
|
||||
}
|
||||
|
||||
|
|
|
@ -48,12 +48,19 @@ func private @tuple.one_element() -> !torch.tuple<!torch.tensor>
|
|||
// CHECK: @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
|
||||
func private @tuple.two_elements() -> !torch.tuple<!torch.tensor, !torch.tensor>
|
||||
|
||||
// CHECK-LABEL: func @torch.tensor() {
|
||||
func @torch.tensor() {
|
||||
// CHECK: torch.tensor(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
|
||||
%0 = torch.tensor(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
|
||||
// CHECK: torch.tensor(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
|
||||
%1 = torch.tensor(dense<42.0> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
|
||||
// CHECK-LABEL: func @torch.tensor.literal() {
|
||||
func @torch.tensor.literal() {
|
||||
// CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor
|
||||
%0 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.tensor
|
||||
// CHECK: torch.tensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
|
||||
%1 = torch.tensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.tensor<[3,2],f32>
|
||||
return
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.vtensor.literal() {
|
||||
func @torch.vtensor.literal() {
|
||||
// CHECK: torch.vtensor.literal(dense<4.200000e+01> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
|
||||
%0 = torch.vtensor.literal(dense<42.0> : tensor<3x2xf32>) : !torch.vtensor<[3,2],f32>
|
||||
return
|
||||
}
|
||||
|
||||
|
@ -81,7 +88,7 @@ func @torch.prim.If(%arg0: !torch.bool, %arg1: !torch.int) -> !torch.int {
|
|||
%int3 = torch.constant.int 3
|
||||
// CHECK: %float = torch.constant.float 4.250000e+01
|
||||
%float = torch.constant.float 4.250000e+01
|
||||
%tensor = torch.tensor(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
%tensor = torch.tensor.literal(dense<1.000000e+00> : tensor<1xf32>) : !torch.tensor
|
||||
// CHECK: %none = torch.constant.none
|
||||
%none = torch.constant.none
|
||||
// CHECK: %str = torch.constant.str "some str"
|
||||
|
|
|
@ -31,3 +31,13 @@ func @reduce_trailing_underscore_inplace_variant(%arg0: !torch.tensor<[2,2],f32>
|
|||
%0 = torch.aten.add_.Tensor %arg0, %arg1, %c1 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>, !torch.int -> !torch.tensor<[2,2],f32>
|
||||
return %0, %arg0 : !torch.tensor<[2,2],f32>, !torch.tensor<[2,2],f32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @torch.tensor.literal() -> !torch.tensor {
|
||||
// CHECK: %[[VTENSOR:.*]] = torch.vtensor.literal(dense<0.000000e+00> : tensor<7xf32>) : !torch.vtensor<[7],f32>
|
||||
// CHECK: %[[SIZES_ERASED:.*]] = torch.tensor_static_info_cast %[[VTENSOR]] : !torch.vtensor<[7],f32> to !torch.vtensor
|
||||
// CHECK: %[[TENSOR:.*]] = torch.copy.tensor %[[SIZES_ERASED]] : !torch.vtensor -> !torch.tensor
|
||||
// CHECK: return %[[TENSOR]] : !torch.tensor
|
||||
func @torch.tensor.literal() -> !torch.tensor {
|
||||
%0 = torch.tensor.literal(dense<0.0> : tensor<7xf32>) : !torch.tensor
|
||||
return %0 : !torch.tensor
|
||||
}
|
||||
|
|
Loading…
Reference in New Issue