Undo shape lib changes + update function signature of sum + zero (#1035)

This commit does three things:
  1. Reverts some of the shape lib changes merged in
  https://github.com/llvm/torch-mlir/pull/844
  2. Updates the signature of `aten.sum.dim_IntList` to match the recent
  upstream change in
  23bdb570cf
  3. Replaces `aten.zero.functional` with `aten.zero`, following the change in 960758b0b7
Ramiro Leal-Cavazos, 2022-07-11 12:56:12 -05:00 (committed via GitHub)
commit 11148e60d6 (parent 2d75654b2c)
8 changed files with 403 additions and 400 deletions
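
For reference, a minimal eager-mode sketch of the two user-visible semantics involved (assuming a PyTorch build that already includes the upstream dim=None change from 23bdb570cf):

    import torch

    t = torch.ones(2, 3)

    # aten::sum.dim_IntList now takes int[]?: dim may be a list of ints or None.
    print(torch.sum(t, dim=[0]).shape)   # torch.Size([3])
    print(torch.sum(t, dim=None).shape)  # reduces over all dims: torch.Size([])

    # aten::zero is the value-semantic counterpart of the in-place aten::zero_.
    z = t.clone()
    z.zero_()                            # in-place variant; z is now all zeros
    print(torch.equal(z, torch.zeros(2, 3)))  # True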

==== file 1 of 8 ====

@@ -2346,12 +2346,12 @@ def Torch_AtenUnsqueeze_Op : Torch_Op<"aten.unsqueeze_", [
   }];
 }
 
-def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
+def Torch_AtenZeroOp : Torch_Op<"aten.zero", [
     AllowsTypeRefinement,
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::zero.functional : (Tensor) -> (Tensor)`";
+  let summary = "Generated op for `aten::zero : (Tensor) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self
   );
@@ -2360,16 +2360,17 @@ def Torch_AtenZeroFunctionalOp : Torch_Op<"aten.zero.functional", [
   );
   let hasCustomAssemblyFormat = 1;
   let extraClassDefinition = [{
-    ParseResult AtenZeroFunctionalOp::parse(OpAsmParser &parser, OperationState &result) {
+    ParseResult AtenZeroOp::parse(OpAsmParser &parser, OperationState &result) {
       return parseDefaultTorchOp(parser, result, 1, 1);
     }
-    void AtenZeroFunctionalOp::print(OpAsmPrinter &printer) {
+    void AtenZeroOp::print(OpAsmPrinter &printer) {
       printDefaultTorchOp(printer, *this, 1, 1);
     }
   }];
 }
 
 def Torch_AtenZero_Op : Torch_Op<"aten.zero_", [
+    IsTrailingUnderscoreInplaceVariant,
     AllowsTypeRefinement
   ]> {
   let summary = "Generated op for `aten::zero_ : (Tensor) -> (Tensor)`";
@@ -5396,10 +5397,10 @@ def Torch_AtenSumDimIntListOp : Torch_Op<"aten.sum.dim_IntList", [
     HasValueSemantics,
     ReadOnly
   ]> {
-  let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)`";
+  let summary = "Generated op for `aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)`";
   let arguments = (ins
     AnyTorchTensorType:$self,
-    AnyTorchListOfTorchIntType:$dim,
+    AnyTorchOptionalListOfTorchIntType:$dim,
     Torch_BoolType:$keepdim,
     AnyTorchOptionalIntType:$dtype
   );
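
The schema string and the ODS argument list correspond one-to-one; the mapping visible in this hunk, written out as an illustrative Python dict (exposition only, not code from the generator):

    # JIT schema type -> ODS argument type, as seen in the hunk above.
    ODS_TYPE_FOR = {
        "Tensor": "AnyTorchTensorType",
        "int[]":  "AnyTorchListOfTorchIntType",          # old dim type
        "int[]?": "AnyTorchOptionalListOfTorchIntType",  # new optional dim type
        "bool":   "Torch_BoolType",
        "int?":   "AnyTorchOptionalIntType",
    }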

==== file 2 of 8 ====

@@ -224,11 +224,11 @@ public:
 } // namespace
 
 namespace {
-class DecomposeAtenZeroFunctionalOp
-    : public OpRewritePattern<AtenZeroFunctionalOp> {
+class DecomposeAtenZeroOp
+    : public OpRewritePattern<AtenZeroOp> {
 public:
   using OpRewritePattern::OpRewritePattern;
-  LogicalResult matchAndRewrite(AtenZeroFunctionalOp op,
+  LogicalResult matchAndRewrite(AtenZeroOp op,
                                 PatternRewriter &rewriter) const override {
     Value zero = rewriter.create<ConstantIntOp>(op.getLoc(),
                                                 rewriter.getI64IntegerAttr(0));
@@ -2272,8 +2272,8 @@ class DecomposeComplexOpsPass
     target.addIllegalOp<ValsemVariantAtenBernoulliFloatOp>();
     patterns.add<DecomposeValsemVariantAtenBernoulliTensorOp>(context);
     target.addIllegalOp<ValsemVariantAtenBernoulliTensorOp>();
-    patterns.add<DecomposeAtenZeroFunctionalOp>(context);
-    target.addIllegalOp<AtenZeroFunctionalOp>();
+    patterns.add<DecomposeAtenZeroOp>(context);
+    target.addIllegalOp<AtenZeroOp>();
     patterns.add<DecomposeAtenRandLikeOp>(context);
     target.addIllegalOp<AtenRandLikeOp>();
     patterns.add<DecomposeAtenHardsigmoidOp>(context);
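
The renamed pattern still rewrites the op into a fill with the constant 0, as the FileCheck test at the end of this commit shows. A rough eager-mode analogue of what the decomposition computes (assuming torch.ops.aten.fill resolves to its Scalar overload here):

    import torch

    t = torch.randn(2, 3)
    # aten.zero(t) behaves like fill.Scalar(t, 0): a tensor of zeros with
    # t's shape and dtype; t itself is untouched (value semantics).
    out = torch.ops.aten.fill(t, 0.0)
    assert torch.equal(out, torch.zeros_like(t))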

==== file 3 of 8 ====

@@ -183,9 +183,6 @@ public:
     } else if (isa<AtenBernoulli_TensorOp>(op)) {
       newOp = rewriter.create<ValsemVariantAtenBernoulliTensorOp>(
           loc, op->getResultTypes(), op->getOperands());
-    } else if (isa<AtenZero_Op>(op)) {
-      newOp = rewriter.create<AtenZeroFunctionalOp>(
-          loc, op->getResultTypes(), op->getOperands());
     } else if (isa<AtenFill_ScalarOp>(op)) {
       newOp = rewriter.create<ValsemVariantAtenFillScalarOp>(
           loc, op->getResultTypes(), op->getOperands());
@@ -273,7 +270,6 @@ class ReduceOpVariantsPass : public ReduceOpVariantsBase<ReduceOpVariantsPass> {
     target.addIllegalOp<AtenUniform_Op>();
     target.addIllegalOp<AtenBernoulli_FloatOp>();
     target.addIllegalOp<AtenBernoulli_TensorOp>();
-    target.addIllegalOp<AtenZero_Op>();
     target.addIllegalOp<AtenFill_ScalarOp>();
     target.addIllegalOp<Aten_IndexPutImpl_Op>();
     target.addIllegalOp<AtenCopy_Op>();

==== file 4 of 8 ====

@@ -649,7 +649,7 @@ ChangeResult TypeAnalyzer::visitOperation(
           AtenGatherOp, AtenExpandOp, AtenExpandAsOp, AtenBroadcastToOp,
           AtenRepeatOp, AtenConstantPadNdOp, AtenPadOp, AtenZero_Op,
           AtenIndexTensorOp, ValsemVariantAtenIndexPutImplOp, AtenIndexPutOp,
-          ValsemVariantAtenCopyOp, AtenZeroFunctionalOp,
+          ValsemVariantAtenCopyOp, AtenZeroOp,
           AtenIndexPutHackedTwinOp, AtenMaskedFillScalarOp, AtenFlipOp,
           PrimAbsScalarOp, AtenNumpyTOp, AtenTriuOp>(op)) {
     return incorporateKnowledge(op->getResult(0), operands[0]->getValue());
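
Here `AtenZeroOp` simply takes the slot `AtenZeroFunctionalOp` held in this transfer function: the result inherits operand 0's knowledge. The tensor-level fact being encoded is that zeroing preserves shape and dtype, e.g.:

    import torch

    t = torch.ones(2, 3, dtype=torch.int64)
    # Zeroing preserves dtype and shape, so the analysis can forward
    # everything it knows about the operand to the result.
    assert t.clone().zero_().dtype == torch.int64
    assert t.clone().zero_().shape == (2, 3)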

==== file 5 of 8 (diff suppressed because it is too large) ====

==== file 6 of 8 ====

@@ -519,9 +519,11 @@ def aten〇max〇dim(self: List[int], dim: int, keepdim: bool = False) -> Tuple[
 def aten〇mean〇dim(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
     return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
 
-def aten〇sum〇dim_IntList(self: List[int], dim: List[int], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
-    return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
+def aten〇sum〇dim_IntList(self: List[int], dim: Optional[List[int]], keepdim: bool = False, dtype: Optional[int] = None) -> List[int]:
+    if dim is None:
+        return upstream_shape_functions.mean_dim(self, [], keepdim, dtype)
+    else:
+        return upstream_shape_functions.mean_dim(self, dim, keepdim, dtype)
 
 def aten〇permute(self: List[int], dims: List[int]) -> List[int]:
     return upstream_shape_functions.permute(self, dims)
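
For the non-None branch, mean_dim drops each listed dimension (or keeps it with size 1 when keepdim is set); concretely, at the tensor level:

    import torch

    t = torch.ones(2, 3, 4)
    print(torch.sum(t, dim=[1]).shape)                # torch.Size([2, 4])
    print(torch.sum(t, dim=[1], keepdim=True).shape)  # torch.Size([2, 1, 4])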
@@ -717,13 +719,9 @@ def aten〇_to_copy(self: List[int], dtype: Optional[int] = None, layout: Option
 def aten〇masked_fill〇Scalar(self: List[int], mask: List[int], value: float) -> List[int]:
     return upstream_shape_functions.unary(self)
 
-@not_present_in_registry
 def aten〇zero(self: List[int]) -> List[int]:
     return self
 
-def aten〇zero〇functional(self: List[int]) -> List[int]:
-    return self
-
 @not_present_in_registry
 def aten〇fill〇Scalar(self: List[int], value: float) -> List[int]:
     return self

==== file 7 of 8 ====

@@ -288,7 +288,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
         "aten::threshold : (Tensor, Scalar, Scalar) -> (Tensor)",
         "aten::square : (Tensor) -> (Tensor)",
         "aten::unsqueeze : (Tensor, int) -> (Tensor)",
-        "aten::zero.functional : (Tensor) -> (Tensor)",
+        "aten::zero : (Tensor) -> (Tensor)",
     ]:
         emit_with_mutating_variants(key)
     # Elementwise tensor compute ops that don't have the standard mutating
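
emit_with_mutating_variants generates both the value-semantic op and its trailing-underscore in-place twin from a single registry key, which is how Torch_AtenZero_Op above acquired IsTrailingUnderscoreInplaceVariant. A sketch of the naming fan-out only (illustrative, not the real generator logic):

    # Illustrative only: one registry key yields two ops.
    key = "aten::zero : (Tensor) -> (Tensor)"
    op_name = key.split(" : ")[0]     # "aten::zero"  -> Torch_AtenZeroOp
    inplace_name = op_name + "_"      # "aten::zero_" -> Torch_AtenZero_Op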
@@ -441,7 +441,7 @@ def emit_ops(emitter_td: TextEmitter, registry: Registry):
     emit("aten::size.int : (Tensor, int) -> (int)", has_folder=True)
     emit("aten::stack : (Tensor[], int) -> (Tensor)")
     emit("aten::sum : (Tensor, int?) -> (Tensor)")
-    emit("aten::sum.dim_IntList : (Tensor, int[], bool, int?) -> (Tensor)")
+    emit("aten::sum.dim_IntList : (Tensor, int[]?, bool, int?) -> (Tensor)")
     emit("aten::max : (Tensor) -> (Tensor)")
     emit("aten::max.dim : (Tensor, int, bool) -> (Tensor, Tensor)")
     emit("aten::to.dtype : (Tensor, int, bool, bool, int?) -> (Tensor)", has_folder=True)

==== file 8 of 8 ====

@@ -829,13 +829,13 @@ func.func @torch.aten.dropout$train(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.
 }
 
 // -----
 
-// CHECK-LABEL: func.func @torch.aten.zero.functional(
+// CHECK-LABEL: func.func @torch.aten.zero(
 // CHECK-SAME: %[[INP:.*]]: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
 // CHECK: %[[ZERO:.*]] = torch.constant.int 0
 // CHECK: %[[OUT:.*]] = torch.valsem.aten.fill.Scalar %[[INP]], %[[ZERO]] : !torch.vtensor<[?,?],f32>, !torch.int -> !torch.vtensor<[?,?],f32>
 // CHECK: return %[[OUT]] : !torch.vtensor<[?,?],f32>
-func.func @torch.aten.zero.functional(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
-  %0 = torch.aten.zero.functional %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
+func.func @torch.aten.zero(%arg0: !torch.vtensor<[?,?],f32>) -> !torch.vtensor<[?,?],f32> {
+  %0 = torch.aten.zero %arg0 : !torch.vtensor<[?,?],f32> -> !torch.vtensor<[?,?],f32>
   return %0 : !torch.vtensor<[?,?],f32>
 }