E2e implementation for `aten.cat`,`aten.gather`, `aten.bmm`

Also contains the following changes:
- Remove derefineOp canonicalizer because it's not safe.
- Support for optional tensor and list tensors in reduceOpVariant. This
only works for some special detected and easy to handle cases. For list,
it covers the case list is got from a `ListConstruct`. For optional, it
covers the case optional is constructed from a `DerefineOp`.
- Remove the `inferReturnTypes` for `FromBuiltinTensorOp` because it's
not safe to deduce types from the input. For example, a built-in tensor
of i8 could be converted to si8 or ui8. It's better to let the user
specify the return type explicitly.
pull/322/head
Yi Zhang 2021-09-13 20:57:59 -04:00
parent 3dc9b4ee2f
commit 603e068e45
18 changed files with 509 additions and 138 deletions

View File

@ -41,6 +41,28 @@ def MmModule_basic(module, tu: TestUtils):
# res = module.forward(tu.rand(4, 4), tu.rand(4, 4))
# module.forward(res, res)
# ==============================================================================
class BmmModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, lhs, rhs):
return torch.bmm(lhs, rhs)
@register_test_case(module_factory=lambda: BmmModule())
def BmmModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 5), tu.rand(3, 5, 4))
# ==============================================================================
@ -203,3 +225,41 @@ class TransposeIntModule(torch.nn.Module):
@register_test_case(module_factory=lambda: TransposeIntModule())
def TransposeIntModule_basic(module, tu: TestUtils):
module.forward(tu.rand(3, 4, 2))
class TensorsConcatModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.float32, True),
])
def forward(self, x, y, z):
return torch.cat([x, y, z], 1)
@register_test_case(module_factory=lambda: TensorsConcatModule())
def TensorsConcatModule_basic(module, tu: TestUtils):
module.forward(tu.rand(2, 2, 4), tu.rand(2, 1, 4), tu.rand(2, 3, 4))
class GatherModule(torch.nn.Module):
def __init__(self):
super().__init__()
@export
@annotate_args([
None,
([-1, -1, -1], torch.float32, True),
([-1, -1, -1], torch.int64, True),
])
def forward(self, tensor, indices):
return torch.gather(tensor, 2, indices)
#@register_test_case(module_factory=lambda: GatherModule())
#def GatherModule_basic(module, tu: TestUtils):
# module.forward(tu.rand(2, 3, 4), torch.tensor([[[1,2,3],[1,2,3]]]))

@ -1 +1 @@
Subproject commit 8dca953dd39c0cd8c80decbeb38753f58a4de580
Subproject commit 6e60bb6883178cf14e6fd47a6789495636e4322f

View File

@ -144,9 +144,9 @@ MLIR_CAPI_EXPORTED MlirType
torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
MlirContext context);
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
/// Gets the !torch.tensor type with the tensor attribute.
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type);
torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr);
//===----------------------------------------------------------------------===//
// torch.vtensor type.
@ -169,10 +169,6 @@ MLIR_CAPI_EXPORTED MlirType torchMlirTorchValueTensorTypeGet(
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(MlirContext context);
/// Gets a !torch.tensor type, taking shape/dtype from a ShapedType `type`.
MLIR_CAPI_EXPORTED MlirType
torchMlirTorchValueTensorTypeGetFromShaped(MlirType type);
//===----------------------------------------------------------------------===//
// !torch.none type.
//===----------------------------------------------------------------------===//

View File

@ -647,8 +647,6 @@ def Torch_DerefineOp : Torch_Op<"derefine", [
let assemblyFormat = [{
$operand attr-dict `:` type($operand) `to` type($result)
}];
let hasCanonicalizer = 1;
}
def Torch_OperatorOp : Torch_Op<"operator", [

View File

@ -189,8 +189,6 @@ def Torch_NonValueTensorType : AnyTorchTensorType<"NonValueTensor", "tensor"> {
ValueTensorType getWithValueSemantics() const;
// Get the !torch.tensor type with the least static information.
static NonValueTensorType getWithLeastStaticInformation(MLIRContext *context);
// Get a NonValueTensorType with shape/dtype matching `type`.
static NonValueTensorType getFromShaped(ShapedType type);
}];
}
@ -200,8 +198,6 @@ def Torch_ValueTensorType : AnyTorchTensorType<"ValueTensor", "vtensor"> {
NonValueTensorType getWithoutValueSemantics() const;
// Get the !torch.tensor type with the least static information.
static ValueTensorType getWithLeastStaticInformation(MLIRContext *context);
// Get a NonValueTensorType with shape/dtype matching `type`.
static ValueTensorType getFromShaped(ShapedType type);
// Get the builtin tensor type with the same static information as this one,
// or nullptr if that is not possible (i.e. when the dtype is unknown).
TensorType toBuiltinTensor() const;

View File

@ -168,9 +168,11 @@ MlirType torchMlirTorchNonValueTensorTypeGetWithLeastStaticInformation(
unwrap(context)));
}
MlirType torchMlirTorchNonValueTensorTypeGetFromShaped(MlirType type) {
return wrap(Torch::NonValueTensorType::getFromShaped(
unwrap(type).cast<ShapedType>()));
MlirType torchMlirTorchNonValueTensorTypeGetFromAttribute(MlirAttribute attr) {
auto attrTensorType = unwrap(attr).getType().cast<RankedTensorType>();
return wrap(Torch::NonValueTensorType::get(attrTensorType.getContext(),
attrTensorType.getShape(),
attrTensorType.getElementType()));
}
//===----------------------------------------------------------------------===//
@ -198,11 +200,6 @@ MlirType torchMlirTorchValueTensorTypeGetWithLeastStaticInformation(
Torch::ValueTensorType::getWithLeastStaticInformation(unwrap(context)));
}
MlirType torchMlirTorchValueTensorTypeGetFromShaped(MlirType type) {
return wrap(
Torch::ValueTensorType::getFromShaped(unwrap(type).cast<ShapedType>()));
}
//===----------------------------------------------------------------------===//
// torch.none type.
//===----------------------------------------------------------------------===//

View File

@ -384,26 +384,6 @@ bool DerefineOp::areCastCompatible(mlir::TypeRange inputs,
return isValidSubtype(inputs[0], outputs[0]);
}
void DerefineOp::getCanonicalizationPatterns(RewritePatternSet &patterns,
MLIRContext *context) {
patterns.add(+[](DerefineOp op, PatternRewriter &rewriter) {
// TODO: This pattern should be removed because type refine does a better
// job dealing with control flow. However, removing this would expose an
// issue with ReduceOpVariants. DerefineOp doesn't have value semantics and
// if not removed eagerly by canonicalizer would prevent ReduceOpVariants
// from converting certain tensors value semantics.
bool allAllowRefinement =
llvm::all_of(op.getResult().getUsers(), [](Operation *op) {
return op
->hasTrait<mlir::torch::Torch::OpTrait::AllowsTypeRefinement>();
});
if (!allAllowRefinement)
return failure();
rewriter.replaceOp(op, op.getOperand());
return success();
});
}
template <typename OpTy>
static OpFoldResult atenIsOrIsNotFoldHelper(OpTy op, bool equalIsTrue) {
Type lhsType = op.self().getType();
@ -613,8 +593,11 @@ LogicalResult NonValueTensorLiteralOp::inferReturnTypes(
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
if (!attr)
return failure();
auto tensorType = attr.getType().cast<RankedTensorType>();
inferredReturnTypes.push_back(NonValueTensorType::getFromShaped(tensorType));
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
NonValueTensorType returnType =
NonValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
tensorType.getElementType());
inferredReturnTypes.push_back(returnType);
return success();
}
@ -649,8 +632,11 @@ LogicalResult ValueTensorLiteralOp::inferReturnTypes(
auto attr = attributes.get("value").dyn_cast_or_null<ElementsAttr>();
if (!attr)
return failure();
auto tensorType = attr.getType().cast<RankedTensorType>();
inferredReturnTypes.push_back(ValueTensorType::getFromShaped(tensorType));
RankedTensorType tensorType = attr.getType().cast<RankedTensorType>();
ValueTensorType returnType =
ValueTensorType::get(tensorType.getContext(), tensorType.getShape(),
tensorType.getElementType());
inferredReturnTypes.push_back(returnType);
return success();
}

View File

@ -219,13 +219,6 @@ NonValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
/*optionalDtype=*/Type());
}
NonValueTensorType NonValueTensorType::getFromShaped(ShapedType type) {
return NonValueTensorType::get(type.getContext(),
type.hasRank() ? type.getShape()
: Optional<ArrayRef<int64_t>>(),
type.getElementType());
}
LogicalResult
NonValueTensorType::verify(function_ref<InFlightDiagnostic()> emitError,
Optional<ArrayRef<int64_t>> optionalSizes,
@ -263,11 +256,14 @@ ValueTensorType::getWithLeastStaticInformation(MLIRContext *context) {
/*optionalDtype=*/Type());
}
ValueTensorType ValueTensorType::getFromShaped(ShapedType type) {
return ValueTensorType::get(type.getContext(),
type.hasRank() ? type.getShape()
: Optional<ArrayRef<int64_t>>(),
type.getElementType());
static Type convertDtypeToBuiltinElementType(MLIRContext *context, Type dtype) {
if (auto floatType = dtype.dyn_cast<mlir::FloatType>()) {
return dtype;
} else if (auto integerType = dtype.dyn_cast<IntegerType>()) {
return IntegerType::get(context, integerType.getWidth(),
IntegerType::Signless);
}
assert(false && "Unsupported dtype to convert to builtin element type");
}
TensorType ValueTensorType::toBuiltinTensor() const {
@ -275,7 +271,8 @@ TensorType ValueTensorType::toBuiltinTensor() const {
return nullptr;
if (!hasSizes())
return UnrankedTensorType::get(getDtype());
return RankedTensorType::get(getSizes(), getDtype());
return RankedTensorType::get(
getSizes(), convertDtypeToBuiltinElementType(getContext(), getDtype()));
}
LogicalResult

View File

@ -29,29 +29,75 @@ public:
if (!op->hasTrait<Torch::OpTrait::HasValueSemantics>())
return rewriter.notifyMatchFailure(op, "does not have value semantics");
rewriter.updateRootInPlace(op, [&]() {
// Convert all operands.
SmallVector<Value> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
auto tensorType =
opOperand.get().getType().dyn_cast<NonValueTensorType>();
if (!tensorType)
continue;
rewriter.startRootUpdate(op);
// Convert all operands.
SmallVector<Value> newOperands;
for (OpOperand &opOperand : op->getOpOperands()) {
Type operandType = opOperand.get().getType();
if (operandType.isa<NonValueTensorType>()) {
opOperand.set(rewriter.create<CopyToValueTensorOp>(op->getLoc(),
opOperand.get()));
}
// Convert all results.
rewriter.setInsertionPointAfter(op);
for (Value result : op->getResults()) {
auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
if (!tensorType)
} else if (auto listType = operandType.dyn_cast<ListType>()) {
if (!listType.getContainedType().isa<NonValueTensorType>())
continue;
result.setType(tensorType.getWithValueSemantics());
auto nonValueTensor =
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
// Construct a new list whose elements are value tensors copied from
// the none value tensors of the original list.
auto listConstruct =
opOperand.get().getDefiningOp<PrimListConstructOp>();
if (!listConstruct) {
rewriter.cancelRootUpdate(op);
return rewriter.notifyMatchFailure(op,
"unimplemented: list of non vtensor type not constructed "
"from list construct");
}
if (listConstruct.elements().empty())
continue;
auto newListElements = llvm::to_vector<4>(llvm::map_range(
listConstruct.elements(), [&](Value tensor) -> Value {
return rewriter.create<CopyToValueTensorOp>(op->getLoc(), tensor);
}));
opOperand.set(rewriter.create<PrimListConstructOp>(
op->getLoc(),
Torch::ListType::get(newListElements.front().getType()),
newListElements));
} else if (auto optionalType = operandType.dyn_cast<OptionalType>()) {
// TODO: A more general way to handle the optional type is to
// introduce a `copy.to_optional_vtensor` op.
if (!optionalType.getContainedType().isa<NonValueTensorType>())
continue;
// Create a new optional value whose input is a value tensor copied
// from the non value tensor of the original optional value.
auto derefine = opOperand.get().getDefiningOp<DerefineOp>();
if (!derefine) {
rewriter.cancelRootUpdate(op);
return rewriter.notifyMatchFailure(op,
"unimplemented: optional of non vtensor type not from derefine");
}
if (!derefine.operand().getType().isa<NonValueTensorType>())
continue;
auto newOperand = rewriter.create<CopyToValueTensorOp>(
op->getLoc(), derefine.operand());
opOperand.set(rewriter.create<DerefineOp>(
op->getLoc(), Torch::OptionalType::get(newOperand.getType()),
newOperand));
}
});
}
// Convert all results.
rewriter.setInsertionPointAfter(op);
for (Value result : op->getResults()) {
auto tensorType = result.getType().dyn_cast<NonValueTensorType>();
if (!tensorType)
continue;
result.setType(tensorType.getWithValueSemantics());
auto nonValueTensor =
rewriter.create<CopyToNonValueTensorOp>(op->getLoc(), result);
result.replaceAllUsesExcept(nonValueTensor, nonValueTensor);
}
rewriter.finalizeRootUpdate(op);
return success();
}
};

View File

@ -1057,6 +1057,7 @@ ChangeResult TypeAnalyzer::visitAtenBmmOp(
auto mat2 = operands[1]->getValue();
knowledge.sizes.resize(3, kUnknownSize);
knowledge.dtype = joinElementTypes(self.dtype, mat2.dtype);
knowledge.hasSizes = true;
return getLatticeElement(op->getResult(0)).join(knowledge);
}

View File

@ -363,11 +363,11 @@ MlirValue IValueImporter::importTensor(c10::IValue ivalue) {
// Import the bulk tensor representation.
at::Tensor tensor = ivalue.toTensor().contiguous();
MlirAttribute denseElements = convertTensorToMlirElementsAttr(tensor, loc);
MlirOperation tensorOp =
createMlirOperationAtEnd(importBlock, "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromShaped(
mlirAttributeGetType(denseElements)),
toMlirNamedAttribute("value", denseElements));
MlirOperation tensorOp = createMlirOperationAtEnd(
importBlock, "torch.tensor.literal", loc,
torchMlirTorchNonValueTensorTypeGetFromAttribute(denseElements),
toMlirNamedAttribute("value", denseElements));
MlirValue tensorReprValue = mlirOperationGetResult(tensorOp, 0);
// Construct the complete tensor value. This is trivial for most tensors, but

View File

@ -0,0 +1,26 @@
// RUN: torch-mlir-opt -torch-reduce-op-variants -verify-diagnostics -split-input-file %s
// -----
func @convert_to_value_semantic_tensors_list( %list: !torch.list<!torch.tensor>) -> !torch.tensor {
%int1 = torch.constant.int 1
// expected-error@+1 {{failed to legalize operation 'torch.aten.cat' that was explicitly marked illegal}}
%ret = torch.aten.cat %list, %int1 : !torch.list<!torch.tensor>, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// -----
func @convert_to_value_semantic_tensors_optional(%tensor_optional: !torch.optional<!torch.tensor>,
%t: !torch.tensor,
%training: !torch.bool,
%cudnn_enable: !torch.bool,
%f : !torch.float) -> !torch.tensor {
// expected-error@+1 {{failed to legalize operation 'torch.aten.batch_norm' that was explicitly marked illegal}}
%ret = torch.aten.batch_norm %t, %tensor_optional, %tensor_optional, %tensor_optional,
%tensor_optional, %training, %f, %f, %cudnn_enable:
!torch.tensor, !torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>,
!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>,
!torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.tensor
return %ret: !torch.tensor
}

View File

@ -11,6 +11,77 @@ func @convert_to_value_semantic_tensors(%arg0: !torch.tensor<[],f32>) -> !torch.
return %0 : !torch.tensor<[],f32>
}
// CHECK-LABEL: func @convert_to_value_semantic_tensors_list(
// CHECK-SAME: %[[VT0:.*]]: !torch.vtensor, %[[VT1:.*]]: !torch.vtensor,
// CHECK-SAME: %[[VT2:.*]]: !torch.vtensor) -> !torch.tensor {
// CHECK: %[[T0:.*]] = torch.copy.to_tensor %[[VT0]] : !torch.tensor
// CHECK: %[[T1:.*]] = torch.copy.to_tensor %[[VT1]] : !torch.tensor
// CHECK: %[[T2:.*]] = torch.copy.to_tensor %[[VT2]] : !torch.tensor
// CHECK: %[[DIM:.*]] = torch.constant.int 1
// CHECK: %[[LIST_ORIG:.*]] = torch.prim.ListConstruct %[[T0]], %[[T1]], %[[T2]] :
// CHECK-SAME: (!torch.tensor, !torch.tensor, !torch.tensor) -> !torch.list<!torch.tensor>
// CHECK: %[[VT0_COPY:.*]] = torch.copy.to_vtensor %[[T0]] : !torch.vtensor
// CHECK: %[[VT1_COPY:.*]] = torch.copy.to_vtensor %[[T1]] : !torch.vtensor
// CHECK: %[[VT2_COPY:.*]] = torch.copy.to_vtensor %[[T2]] : !torch.vtensor
// CHECK: %[[LIST_NEW:.*]] = torch.prim.ListConstruct
// CHECK-SAME: %[[VT0_COPY]], %[[VT1_COPY]], %[[VT2_COPY]] :
// CHECK-SAME: (!torch.vtensor, !torch.vtensor, !torch.vtensor) -> !torch.list<!torch.vtensor>
// CHECK: %[[VRET:.*]] = torch.aten.cat %[[LIST_NEW]], %[[DIM]] :
// CHECK-SAME: !torch.list<!torch.vtensor>, !torch.int -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
func @convert_to_value_semantic_tensors_list(%vt0: !torch.vtensor, %vt1: !torch.vtensor, %vt2: !torch.vtensor) -> !torch.tensor {
%t0 = torch.copy.to_tensor %vt0 : !torch.tensor
%t1 = torch.copy.to_tensor %vt1 : !torch.tensor
%t2 = torch.copy.to_tensor %vt2 : !torch.tensor
%int1 = torch.constant.int 1
%list = torch.prim.ListConstruct %t0, %t1, %t2 : (!torch.tensor, !torch.tensor, !torch.tensor) -> !torch.list<!torch.tensor>
%ret = torch.aten.cat %list, %int1 : !torch.list<!torch.tensor>, !torch.int -> !torch.tensor
return %ret : !torch.tensor
}
// CHECK-LABEL: func @convert_to_value_semantic_tensors_optional(
// CHECK-SAME: %[[INPUT:.*]]: !torch.tensor, %[[FLOAT_TENSOR:.*]]: !torch.tensor<[4],f32>,
// CHECK-SAME: %[[TRAINING:.*]]: !torch.bool, %[[CUDNN_ENABLE:.*]]: !torch.bool,
// CHECK-SAME: %[[FLOAT:.*]]: !torch.float) -> !torch.tensor {
// CHECK: %[[NONE:.*]] = torch.constant.none
// CHECK: %[[FLOAT_TENSOR_OPTIONAL:.*]] = torch.derefine %[[FLOAT_TENSOR]] :
// CHECK-SAME: !torch.tensor<[4],f32> to !torch.optional<!torch.tensor>
// CHECK: %[[BIAS_NONE_OPTIONAL:.*]] = torch.derefine %[[NONE]] : !torch.none to !torch.optional<!torch.tensor>
// CHECK: %[[VINPUT:.*]] = torch.copy.to_vtensor %[[INPUT]] : !torch.vtensor
// CHECK: %[[FLOAT_VTENSOR:.*]] = torch.copy.to_vtensor %[[FLOAT_TENSOR]] : !torch.vtensor<[4],f32>
// CHECK: %[[WEIGHTS_TENSOR_OPTIONAL:.*]] = torch.derefine %[[FLOAT_VTENSOR]] :
// CHECK-SAME: !torch.vtensor<[4],f32> to !torch.optional<!torch.vtensor<[4],f32>>
// CHECK: %[[FLOAT_VTENSOR:.*]] = torch.copy.to_vtensor %[[FLOAT_TENSOR]] : !torch.vtensor<[4],f32>
// CHECK: %[[MEAN_VTENSOR_OPTIONAL:.*]] = torch.derefine %[[FLOAT_VTENSOR]] :
// CHECK-SAME: !torch.vtensor<[4],f32> to !torch.optional<!torch.vtensor<[4],f32>>
// CHECK: %[[FLOAT_VTENSOR:.*]] = torch.copy.to_vtensor %[[FLOAT_TENSOR]] : !torch.vtensor<[4],f32>
// CHECK: %[[VAR_VTENSOR_OPTIONAL:.*]] = torch.derefine %[[FLOAT_VTENSOR]] :
// CHECK-SAME: !torch.vtensor<[4],f32> to !torch.optional<!torch.vtensor<[4],f32>>
// CHECK: %[[VRET:.*]] = torch.aten.batch_norm %[[VINPUT]], %[[WEIGHTS_TENSOR_OPTIONAL]],
// CHECK-SAME: %[[BIAS_NONE_OPTIONAL]], %[[MEAN_VTENSOR_OPTIONAL]], %[[VAR_VTENSOR_OPTIONAL]],
// CHECK-SAME: %[[TRAINING]], %[[FLOAT]], %[[FLOAT]], %[[CUDNN_ENABLE]] :
// CHECK-SAME: !torch.vtensor, !torch.optional<!torch.vtensor<[4],f32>>, !torch.optional<!torch.tensor>,
// CHECK-SAME: !torch.optional<!torch.vtensor<[4],f32>>, !torch.optional<!torch.vtensor<[4],f32>>,
// CHECK-SAME: !torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.vtensor
// CHECK: %[[RET:.*]] = torch.copy.to_tensor %[[VRET]] : !torch.tensor
// CHECK: return %[[RET]] : !torch.tensor
// CHECK: }
func @convert_to_value_semantic_tensors_optional(%t: !torch.tensor,
%ft: !torch.tensor<[4],f32>,
%training: !torch.bool,
%cudnn_enable: !torch.bool,
%f : !torch.float) -> !torch.tensor {
%none = torch.constant.none
%tensor_optional = torch.derefine %ft: !torch.tensor<[4],f32> to !torch.optional<!torch.tensor>
%none_optional = torch.derefine %none : !torch.none to !torch.optional<!torch.tensor>
%ret = torch.aten.batch_norm %t, %tensor_optional, %none_optional, %tensor_optional,
%tensor_optional, %training, %f, %f, %cudnn_enable:
!torch.tensor, !torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>,
!torch.optional<!torch.tensor>, !torch.optional<!torch.tensor>,
!torch.bool, !torch.float, !torch.float, !torch.bool -> !torch.tensor
return %ret: !torch.tensor
}
// CHECK-LABEL: func @reduce_trailing_underscore_inplace_variant(
// CHECK-SAME: %[[ARG0:.*]]: !torch.tensor<[2,2],f32>,

View File

@ -43,9 +43,8 @@ def TorchConversion_ToBuiltinTensorOp : TorchConversion_Op<"to_builtin_tensor",
}];
}
def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tensor", [
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
def TorchConversion_FromBuiltinTensorOp : TorchConversion_Op<"from_builtin_tensor">
{
let summary = "Convert a `tensor` to a `!torch.vtensor`";
let description = [{
This op only operates on ValueTensorType, to avoid conflating conversions

View File

@ -66,8 +66,8 @@ static LogicalResult verifyLinalgCompatibleTypes(Operation *op,
// list values are not supported.
// TODO: loose this constraint when properly support list type
static bool isConstantIntListMatching(Value value,
llvm::SmallVectorImpl<int64_t> &expects) {
llvm::SmallVector<int64_t> intValues;
SmallVectorImpl<int64_t> &expects) {
SmallVector<int64_t> intValues;
if (!matchPattern(value, m_TorchConstantIntList(intValues)))
return false;
@ -81,10 +81,47 @@ static bool isConstantIntListMatching(Value value,
return true;
}
static Value castIntToIndex(OpBuilder &b, Location loc, Value v) {
assert(v.getType().isa<IntegerType>() && "must be called with integer type");
return b.create<IndexCastOp>(loc, b.getIndexType(), v);
}
static Value castIndexToInt(OpBuilder &b, Location loc, Value idx) {
assert(idx.getType().isa<IndexType>() && "must be called with integer type");
return b.create<IndexCastOp>(loc, b.getI64Type(), idx);
}
static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
return b.create<tensor::DimOp>(loc, v, dimension);
}
static void checkDimEqualHelper(OpBuilder &b, Location loc, Value lhsDimIndex,
Value rhsDimIndex) {
Value lhsDimInt = castIndexToInt(b, loc, lhsDimIndex);
Value rhsDimInt = castIndexToInt(b, loc, rhsDimIndex);
Value contractingDimEqual =
b.create<CmpIOp>(loc, CmpIPredicate::eq, lhsDimInt, rhsDimInt);
b.create<AssertOp>(loc, contractingDimEqual,
b.getStringAttr("mismatching contracting dimension"));
}
static SmallVector<Value> getTensorSizes(OpBuilder &b, Location loc,
Value tensor) {
RankedTensorType type = tensor.getType().cast<RankedTensorType>();
SmallVector<Value> sizes;
for (int i = 0; i < type.getRank(); i++)
sizes.push_back(getDimOp(b, loc, tensor, i));
return sizes;
}
static Value createZeroInitTensor(OpBuilder &b, Location loc, ValueRange sizes,
Type elemTy) {
Value initTensor = b.create<linalg::InitTensorOp>(loc, sizes, elemTy);
RankedTensorType type = initTensor.getType().cast<RankedTensorType>();
Value c0 = b.create<ConstantOp>(loc, b.getZeroAttr(type.getElementType()));
return b.create<linalg::FillOp>(loc, c0, initTensor).getResult(0);
}
// Helper function to caculate the output tensor dims for convolution-like ops.
// Along each dim:
// dim_out =
@ -92,21 +129,13 @@ static Value getDimOp(OpBuilder &b, Location loc, Value v, int dimension) {
static Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
Value paddingInt, Value dilationInt,
Value kernelSizeInt, Value strideInt) {
Type intType = b.getIntegerType(64);
Type indexType = b.getIndexType();
auto castIndexToInt = [&](Value v) {
return b.create<IndexCastOp>(loc, intType, v);
};
auto castIntToIndex = [&](Value v) {
return b.create<IndexCastOp>(loc, indexType, v);
};
Value c1 = b.create<ConstantOp>(loc, b.getI64IntegerAttr(1));
Value c2 = b.create<ConstantOp>(loc, b.getI64IntegerAttr(2));
Value doublePadding = b.create<MulIOp>(loc, paddingInt, c2);
// in + 2 * padding
Value inAddDoublePadding =
b.create<AddIOp>(loc, castIndexToInt(in), doublePadding);
b.create<AddIOp>(loc, castIndexToInt(b, loc, in), doublePadding);
// dilation * (kernelSize - 1)
Value kernelSizeSub1 = b.create<SubIOp>(loc, kernelSizeInt, c1);
@ -118,7 +147,7 @@ static Value getOutputDimForConvOps(OpBuilder &b, Location loc, Value in,
Value dividend = b.create<SubIOp>(loc, temp, c1);
Value division = b.create<SignedFloorDivIOp>(loc, dividend, strideInt);
Value out = b.create<AddIOp>(loc, division, c1);
return castIntToIndex(out);
return castIntToIndex(b, loc, out);
}
static SmallVector<Value>
@ -150,15 +179,15 @@ static Value getPaddedTensor(Operation *op, OpBuilder &b, Value &input,
assert(input.getType().isa<RankedTensorType>() &&
"input must be RankedTensorType");
Location loc = op->getLoc();
Value c0float = b.create<ConstantOp>(
loc, FloatAttr::get(
input.getType().cast<RankedTensorType>().getElementType(), 0.0));
Value c0 = b.create<ConstantOp>(
loc,
b.getZeroAttr(input.getType().cast<RankedTensorType>().getElementType()));
SmallVector<OpFoldResult> paddings = getAsOpFoldResult(b, loc, paddingInts);
Type ranked4DTensorType = linalg::PadTensorOp::inferResultType(
input.getType().cast<RankedTensorType>(), paddingInts, paddingInts);
Value paddedInput = linalg::PadTensorOp::createPadScalarOp(
ranked4DTensorType, input, c0float, /*low=*/paddings, /*high=*/paddings,
loc, b);
ranked4DTensorType, input, c0, /*low=*/paddings, /*high=*/paddings, loc,
b);
return paddedInput;
}
@ -168,7 +197,7 @@ class ConvertAtenAdaptiveAvgPool2dOp
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenAdaptiveAvgPool2dOp op, llvm::ArrayRef<Value> operands,
matchAndRewrite(AtenAdaptiveAvgPool2dOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
@ -177,7 +206,7 @@ public:
RankedTensorType inputType = input.getType().cast<RankedTensorType>();
Type elementType = inputType.getElementType();
if (!elementType.isa<mlir::FloatType>())
op.emitError("unimplemented: non-floating point type");
return op.emitError("unimplemented: non-floating point type");
auto inputRank = inputType.getRank();
if (inputRank != 4)
@ -266,7 +295,7 @@ class ConvertAtenConv2dOp : public OpConversionPattern<AtenConv2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenConv2dOp op, llvm::ArrayRef<Value> operands,
matchAndRewrite(AtenConv2dOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
MLIRContext *context = op->getContext();
@ -278,7 +307,7 @@ public:
Type elementType =
input.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>())
op.emitError("unimplemented: non-floating point type");
return op.emitError("unimplemented: non-floating point type");
Type intType = IntegerType::get(context, 64);
auto castIndexToInt = [&](Value v) {
@ -295,17 +324,17 @@ public:
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
llvm::SmallVector<int64_t> paddingInts;
SmallVector<int64_t> paddingInts;
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts))) {
return rewriter.notifyMatchFailure(
op, "only support constant padding values");
}
llvm::SmallVector<int64_t, 2> strideInts;
SmallVector<int64_t, 2> strideInts;
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
llvm::SmallVector<int64_t, 2> dilationInts;
SmallVector<int64_t, 2> dilationInts;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
@ -534,6 +563,60 @@ public:
};
} // namespace
namespace {
class ConvertAtenBmmOp : public OpConversionPattern<AtenBmmOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenBmmOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
Location loc = op->getLoc();
Value lhs = operands[0];
Value rhs = operands[1];
RankedTensorType lhsType = lhs.getType().cast<RankedTensorType>();
RankedTensorType rhsType = rhs.getType().cast<RankedTensorType>();
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
if (lhsType.getRank() != 3 || rhsType.getRank() != 3) {
return rewriter.notifyMatchFailure(
op, "expected both operands to aten.bmm to be rank 3");
}
if (!lhsType.getElementType().isa<mlir::FloatType>() ||
lhsType.getElementType() != rhsType.getElementType())
return op.emitError(
"unimplemented: non floating point operands or operands of "
"different types");
Value lhsDim0 = getDimOp(rewriter, loc, lhs, 0);
Value lhsDim1 = getDimOp(rewriter, loc, lhs, 1);
Value lhsDim2 = getDimOp(rewriter, loc, lhs, 2);
Value rhsDim0 = getDimOp(rewriter, loc, rhs, 0);
Value rhsDim1 = getDimOp(rewriter, loc, rhs, 1);
Value rhsDim2 = getDimOp(rewriter, loc, rhs, 2);
// Check the batch numbers are equal.
checkDimEqualHelper(rewriter, loc, lhsDim0, rhsDim0);
// Check the matrixs shapes are valid for mulplication.
checkDimEqualHelper(rewriter, loc, lhsDim2, rhsDim1);
Type newResultType = getTypeConverter()->convertType(op.getType());
Type elementType = newResultType.cast<TensorType>().getElementType();
Value initTensor0 = createZeroInitTensor(
rewriter, loc, ValueRange{lhsDim0, lhsDim1, rhsDim2}, elementType);
Value bmm =
rewriter
.create<linalg::BatchMatmulOp>(loc, initTensor0.getType(),
ValueRange{lhs, rhs}, initTensor0)
.getResult(0);
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, bmm);
return success();
}
};
} // namespace
namespace {
// See comments at in convertMmOp and the heading for this section for general
// considerations. This function needs to be auto-generated.
@ -1074,7 +1157,7 @@ class ConvertAtenMaxPool2dOp : public OpConversionPattern<AtenMaxPool2dOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenMaxPool2dOp op, llvm::ArrayRef<Value> operands,
matchAndRewrite(AtenMaxPool2dOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
@ -1085,24 +1168,24 @@ public:
Type elementType = self.getType().cast<RankedTensorType>().getElementType();
if (!elementType.isa<mlir::FloatType>())
op.emitError("unimplemented: non-floating point type");
return op.emitError("unimplemented: non-floating point type");
// Pattern match against the op's original operands, because otherwise we
// will get the lowered version of the operands which is harder to pattern
// match.
llvm::SmallVector<int64_t, 2> strideInts;
SmallVector<int64_t, 2> strideInts;
if (!matchPattern(op.stride(), m_TorchConstantIntList(strideInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int strides");
llvm::SmallVector<int64_t, 2> dilationInts;
SmallVector<int64_t, 2> dilationInts;
if (!matchPattern(op.dilation(), m_TorchConstantIntList(dilationInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int dilations");
llvm::SmallVector<int64_t, 2> paddingInts;
SmallVector<int64_t, 2> paddingInts;
if (!matchPattern(op.padding(), m_TorchConstantIntList(paddingInts)))
return rewriter.notifyMatchFailure(op,
"only support constant int paddings");
llvm::SmallVector<int64_t, 2> kernelSizeInts;
SmallVector<int64_t, 2> kernelSizeInts;
if (!matchPattern(op.kernel_size(), m_TorchConstantIntList(kernelSizeInts)))
return rewriter.notifyMatchFailure(op, "only support kernel size ints");
@ -1177,7 +1260,7 @@ class ConvertAtenFlattenUsingIntsOp
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenFlattenUsingIntsOp op, llvm::ArrayRef<Value> operands,
matchAndRewrite(AtenFlattenUsingIntsOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
@ -1281,7 +1364,7 @@ class ConvertAtenTransposeIntOp
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenTransposeIntOp op, llvm::ArrayRef<Value> operands,
matchAndRewrite(AtenTransposeIntOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
@ -1313,7 +1396,7 @@ public:
auto loc = op.getLoc();
llvm::SmallVector<Value> outputDims;
SmallVector<Value> outputDims;
for (auto i = 0; i < inputRank; i++)
outputDims.push_back(getDimOp(rewriter, loc, adaptor.self(), i));
std::swap(outputDims[dim0], outputDims[dim1]);
@ -1352,6 +1435,126 @@ public:
};
} // namespace
namespace {
class ConvertAtenCatOp : public OpConversionPattern<AtenCatOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenCatOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op.getLoc();
TypeConverter *typeConverter = getTypeConverter();
AtenCatOp::Adaptor adaptor(operands);
Value dimValue = op.dim();
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
// Collect all the tensors to be concatenated.
auto tensorList = op.tensors();
auto listConstruct = tensorList.getDefiningOp<PrimListConstructOp>();
if (!listConstruct)
return op.emitError(
"unimplemented: the tensor list is not from list construct");
auto tensors = llvm::to_vector<4>(
llvm::map_range(listConstruct.elements(), [&](Value tensor) -> Value {
return typeConverter->materializeTargetConversion(
rewriter, loc, getTypeConverter()->convertType(tensor.getType()),
tensor);
}));
RankedTensorType newResultType =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
int rank = newResultType.getRank();
SmallVector<Value> offsets, sizes, strides;
sizes.reserve(rank);
strides.resize(rank, rewriter.create<ConstantIndexOp>(loc, 1));
offsets.resize(rank, rewriter.create<ConstantIndexOp>(loc, 0));
for (int i = 0; i < rank; ++i)
sizes.push_back(rewriter.create<tensor::DimOp>(loc, tensors[0], i));
// Calculate the size of the `dim` result dimension by adding the dim size
// of each tensor together.
Value resultDimSize = sizes[dim];
Value dimIndex = rewriter.create<IndexCastOp>(loc, rewriter.getIndexType(),
adaptor.dim());
for (auto tensor : makeArrayRef(tensors).drop_front()) {
auto size = rewriter.create<tensor::DimOp>(loc, tensor, dimIndex);
resultDimSize = rewriter.create<AddIOp>(loc, resultDimSize, size);
}
sizes[dim] = resultDimSize;
Value result = rewriter.create<linalg::InitTensorOp>(
loc, sizes, newResultType.getElementType());
for (auto tensor : tensors) {
sizes[dim] = rewriter.create<tensor::DimOp>(loc, tensor, dimIndex);
result = rewriter.create<tensor::InsertSliceOp>(loc, tensor, result,
offsets, sizes, strides);
offsets[dim] = rewriter.create<AddIOp>(loc, offsets[dim], sizes[dim]);
}
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, newResultType, result);
return success();
}
};
} // namespace
namespace {
class ConvertAtenGatherOp : public OpConversionPattern<AtenGatherOp> {
public:
using OpConversionPattern::OpConversionPattern;
LogicalResult
matchAndRewrite(AtenGatherOp op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
Location loc = op->getLoc();
AtenGatherOp::Adaptor adaptor(operands);
Value dimValue = op.dim();
int64_t dim;
if (!matchPattern(dimValue, m_TorchConstantInt(&dim)))
return op.emitError("unimplemented: dim is not constant");
Value indices = adaptor.index();
Value self = adaptor.self();
RankedTensorType newResultTy =
getTypeConverter()->convertType(op.getType()).cast<RankedTensorType>();
int64_t rank = newResultTy.getRank();
SmallVector<Value> sizes = getTensorSizes(rewriter, loc, indices);
Value result = createZeroInitTensor(rewriter, loc, sizes,
newResultTy.getElementType());
SmallVector<AffineMap, 2> affineMaps(2,
rewriter.getMultiDimIdentityMap(rank));
SmallVector<StringRef> iteratorTypes(rank, getParallelIteratorTypeName());
auto genericOp = rewriter.create<linalg::GenericOp>(
loc, newResultTy, indices, result, affineMaps, iteratorTypes,
[&](OpBuilder &b, Location loc, ValueRange args) {
auto indexValue = args[0];
Value indexOfDim = rewriter.create<IndexCastOp>(
loc, rewriter.getIndexType(), indexValue);
SmallVector<Value> indices;
for (int i = 0; i < rank; i++) {
indices.push_back(i == dim
? indexOfDim
: rewriter.create<linalg::IndexOp>(loc, i));
}
Value extract =
rewriter.create<tensor::ExtractOp>(loc, self, indices);
rewriter.create<linalg::YieldOp>(loc, extract);
});
rewriter.replaceOp(op, genericOp.getResult(0));
return success();
}
};
} // namespace
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
@ -1381,6 +1584,8 @@ public:
RewritePatternSet patterns(context);
target.addIllegalOp<AtenMmOp>();
patterns.add<ConvertAtenMmOp>(typeConverter, context);
target.addIllegalOp<AtenBmmOp>();
patterns.add<ConvertAtenBmmOp>(typeConverter, context);
target.addIllegalOp<AtenLinearOp>();
patterns.add<ConvertAtenLinearOp>(typeConverter, context);
target.addIllegalOp<AtenBatchNormOp>();
@ -1404,6 +1609,10 @@ public:
patterns.add<ConvertReductionOp>(typeConverter, context);
target.addIllegalOp<AtenTransposeIntOp>();
patterns.add<ConvertAtenTransposeIntOp>(typeConverter, context);
target.addIllegalOp<AtenCatOp>();
patterns.add<ConvertAtenCatOp>(typeConverter, context);
target.addIllegalOp<AtenGatherOp>();
patterns.add<ConvertAtenGatherOp>(typeConverter, context);
if (failed(applyPartialConversion(getOperation(), target,
std::move(patterns))))

View File

@ -36,18 +36,5 @@ LogicalResult ToBuiltinTensorOp::inferReturnTypes(
return success();
}
//===----------------------------------------------------------------------===//
// FromBuiltinTensorOp
//===----------------------------------------------------------------------===//
LogicalResult FromBuiltinTensorOp::inferReturnTypes(
MLIRContext *context, Optional<Location> location, ValueRange operands,
DictionaryAttr attributes, RegionRange regions,
SmallVectorImpl<Type> &inferredReturnTypes) {
inferredReturnTypes.push_back(Torch::ValueTensorType::getFromShaped(
operands[0].getType().cast<TensorType>()));
return success();
}
#define GET_OP_CLASSES
#include "npcomp/Dialect/TorchConversion/IR/TorchConversionOps.cpp.inc"

View File

@ -53,7 +53,7 @@ setupValueTensorToBuiltinTensorConversion(ConversionTarget &target,
ValueRange inputs, Location loc) -> Value {
assert(inputs.size() == 1);
assert(inputs[0].getType().isa<TensorType>());
return builder.create<FromBuiltinTensorOp>(loc, inputs[0]);
return builder.create<FromBuiltinTensorOp>(loc, type, inputs[0]);
};
typeConverter.addSourceMaterialization(sourceMaterialization);
typeConverter.addArgumentMaterialization(sourceMaterialization);

View File

@ -1,14 +1,16 @@
// RUN: npcomp-opt %s | npcomp-opt | FileCheck %s
// CHECK-LABEL: func @builtin_tensor_interop(
func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xsi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
func @builtin_tensor_interop(%arg0: tensor<*xf32>, %arg1: tensor<3x?xi8>, %arg2: !torch.vtensor<*,f32>, %arg3: !torch.vtensor<[3,?],si8>) {
// CHECK: torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
%0 = torch_c.from_builtin_tensor %arg0 : tensor<*xf32> -> !torch.vtensor<*,f32>
// CHECK: torch_c.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
%1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xsi8> -> !torch.vtensor<[3,?],si8>
// CHECK: torch_c.from_builtin_tensor %arg1 : tensor<3x?xi8> -> !torch.vtensor<[3,?],si8>
%1 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xi8> -> !torch.vtensor<[3,?],si8>
// CHECK: torch_c.from_builtin_tensor %arg1 : tensor<3x?xi8> -> !torch.vtensor<[3,?],ui8>
%2 = torch_c.from_builtin_tensor %arg1 : tensor<3x?xi8> -> !torch.vtensor<[3,?],ui8>
// CHECK: torch_c.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
%2 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
// CHECK: torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
%3 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xsi8>
%3 = torch_c.to_builtin_tensor %arg2 : !torch.vtensor<*,f32> -> tensor<*xf32>
// CHECK: torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xi8>
%4 = torch_c.to_builtin_tensor %arg3 : !torch.vtensor<[3,?],si8> -> tensor<3x?xi8>
return
}