Add `aten.mm` to linalg lowering.

This is our first op with error semantics, and stresses the system.

There are a few design notes of special interest:
- RefineTypes.cpp's note about shape inference in the presence of code
  that dynamically produces an error, even when the error is provable statically.
- ATenToLinalg.cpp's notes about future automation of the ATen->linalg
  path.
- The notes in Passes.td about using low-tech `std.assert` ops instead
  of `shape.assuming`.

Note: Doesn't work on IREE yet due to the `std.assert` op (needs to be
lowered to `vm.fail` on the IREE side).
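
Roughly, the lowering produced for a dynamically shaped `aten.mm` looks
like this (abridged from the new lit test in this patch):

  %eq = cmpi eq, %lhs_dim_1, %rhs_dim_0 : index
  assert %eq, "mismatching contracting dimension for aten.mm"
  %init = linalg.init_tensor [%lhs_dim_0, %rhs_dim_1] : tensor<?x?xf32>
  %fill = linalg.fill(%init, %cf0) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
  %mm = linalg.matmul ins(%lhs, %rhs : tensor<?x?xf32>, tensor<?x?xf32>)
                      outs(%fill : tensor<?x?xf32>) -> tensor<?x?xf32>
  %result = tensor.cast %mm : tensor<?x?xf32> to tensor<?x2xf32>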
pull/207/head
Sean Silva 2021-04-08 17:43:41 -07:00
parent 28a0f02746
commit f5dfa02523
13 changed files with 437 additions and 14 deletions

@ -0,0 +1,51 @@
# -*- Python -*-
# This file is licensed under a pytorch-style license
# See frontends/pytorch/LICENSE for license information.
import typing
import torch
import torch_mlir
import npcomp
from npcomp.compiler.pytorch.backend import refjit, frontend_lowering
from npcomp.compiler.utils import logging
import test_utils
logging.enable()
# RUN: %PYTHON %s | npcomp-opt | FileCheck %s
mb = torch_mlir.ModuleBuilder()
class TestModule(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, lhs, rhs):
        return torch.mm(lhs, rhs)
test_module = TestModule()
class_annotator = torch_mlir.ClassAnnotator()
recursivescriptmodule = torch.jit.script(test_module)
torch.jit.save(recursivescriptmodule, '/tmp/foo.pt')
class_annotator.exportNone(recursivescriptmodule._c._type())
class_annotator.exportPath(recursivescriptmodule._c._type(), ['forward'])
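# Annotate the shapes and dtypes of `forward`'s arguments: the leading None
# leaves the `self` argument unannotated, and each ([-1, -1], torch.float32)
# marks an argument as a rank-2 float32 tensor with unknown dimension sizes.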
class_annotator.annotateShapesAndDtypes(recursivescriptmodule._c._type(), ['forward'], [
    None,
    ([-1, -1], torch.float32),
    ([-1, -1], torch.float32),
])
# TODO: Automatically handle unpacking Python class RecursiveScriptModule into the underlying ScriptModule.
mb.import_module(recursivescriptmodule._c, class_annotator)
#mb.module.operation.print()
backend = refjit.CompilerBackend()
compiled = backend.compile(frontend_lowering.lower_object_graph(mb.module))
jit_module = backend.load(compiled)
torch.manual_seed(0)
lhs = torch.rand(2, 3)
rhs = torch.rand(3, 4)
test_utils.compare_outputs(test_module.forward, jit_module.forward, lhs, rhs)

@ -0,0 +1,21 @@
//===------------------------------------------------------------*- C++ -*-===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#ifndef NPCOMP_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#define NPCOMP_CONVERSION_ATENTOLINALG_ATENTOLINALG_H
#include "mlir/Pass/Pass.h"
#include <memory>
namespace mlir {
namespace NPCOMP {
std::unique_ptr<OperationPass<FuncOp>> createConvertATenToLinalgPass();
}
} // namespace mlir
#endif // NPCOMP_CONVERSION_ATENTOLINALG_ATENTOLINALG_H

@ -20,6 +20,85 @@ def ConvertATenToTCF : Pass<"convert-aten-to-tcf", "FuncOp"> {
let constructor = "mlir::NPCOMP::createConvertATenToTCFPass()";
}
def ConvertATenToLinalg : Pass<"convert-aten-to-linalg", "FuncOp"> {
let summary = "Convert recognized ATen to Linalg ops";
let description = [{
Convert ATen ops to linalg ops.
This pass's main responsibility is to bridge the world between ops
that safely terminate the program in case of operand shape mismatches
(ATen) and ops where such mismatches are undefined behavior (linalg).
To model the termination of the program for implementing error guards,
we use the `std.assert` op.
This is a design decision at variance with other passes in npcomp,
such as `convert-tcf-to-std` and `convert-tcf-to-linalg`, which use the
`shape` dialect's witness system (the `shape.cstr_*` family of ops feeding
into `shape.assuming` regions). Those passes will eventually be subsumed by
this one. The reasons for this change are heuristic (a rough IR sketch
contrasting the two styles appears at the end of this description), but
boil down to:
1. The modeling of `shape.assuming` is odd, as it uses a region, which is
not a good fit for modeling error guards. Regions mark a "start" and an
"end" (which is their nesting property). But
modeling assertions in the program doesn't fit into that. For assertions,
only the "start" matters (once tested, a predicate remains true "forever"
-- it doesn't end at the "yield" of the region).
Thus, having regions places arbitrary "end"s that just add IR structure
that has no semantic value for modeling this problem! (and to make things
worse the "end"s, which we don't need, are what require "yielding"
values, which interrupts use-def chains). Consider the different
structural properties of regions:
a. IsolatedFromAbove region:
- "start" interrupts use-def chains,
- "end" interrupts use-def chains
- structurally protects from intra-block upward and downward
code motion
b. Capturing region (like `shape.assuming`):
- "start" does not interrupt use-def chains,
- "end" interrupts use-def chains
- structurally protects from intra-block upward and downward
code motion
c. What we "ideally" want:
- "start" interrupts use-def chains (can be pruned though)
- no "end" IR structure!
- structurally protects from intra-block upward code motion
(but not downward code motion!)
- Observation: We probably can't get all of this, but overall this
problem is much better suited for a "MemorySSA"-like
abstraction, call it "EffectSSA", which is constructed on-demand
based on MLIR's effect modeling system (rather than on
`shape.assuming`, which only covers the effects that the IR creator
explicitly encoded with witnesses -- it is easy to forget
to handle effects other than those encoded in the
witness structure).
2. The presence of `shape.assuming` regions tends to create highly nested
IR structures, which don't interoperate well with any other IR
structures, and creates very bulky IR (and IR creation code). In general
if we are going to do anything with this IR (e.g. canonicalize), we
end up needing to either:
a. Flatten the `shape.assuming` IR (defeating the purpose of having
it).
b. Do some sort of shape.assuming "region merging".
c. Have special patterns that handle a subset of special cases (looking
through "yields" and such) and don't generalize.
3. Witnesses tend to encourage non-scalable peephole transformations, which
tend to make analyses/transformations non-robust to the presence of
control flow and side effecting ops (easy to forget to handle side
effects other than those modeled by the witness system).
4. All this code operates on ranked tensors, for which using individual
SSA values for sizes (rather than a "shape type") seems to
work really well at this level of abstraction based on prior experience
in IREE. (Unranked code tends to benefit from having a discrete
"shape type" to model shapes.)
We will see if we end up needing something like `shape.assuming`, but for
now, it seems likely we can do something simpler and just bypass it. The
design of having an EffectSSA that is constructed on-demand seems very
compelling for modeling effects more broadly.
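
For concreteness, here is a rough sketch (not the exact IR produced by any
existing pass) contrasting the two styles of guarding a computation:

  // Witness style (`shape.cstr_*` + `shape.assuming`):
  %ok = cmpi eq, %lhsDim1, %rhsDim0 : index
  %witness = shape.cstr_require %ok, "mismatching contracting dimension"
  %result = shape.assuming %witness -> (tensor<?x?xf32>) {
    ... compute the matmul ...
    shape.assuming_yield %matmul : tensor<?x?xf32>
  }

  // `std.assert` style used by this pass: the guard is a single op and the
  // computation stays on a plain use-def chain.
  %ok = cmpi eq, %lhsDim1, %rhsDim0 : index
  assert %ok, "mismatching contracting dimension"
  %result = ... compute the matmul ...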
}];
let constructor = "mlir::NPCOMP::createConvertATenToLinalgPass()";
}
//===----------------------------------------------------------------------===//
// Basicpy conversions
//===----------------------------------------------------------------------===//

@ -642,4 +642,3 @@ def aten_CopyInplaceOp: aten_Op<"copy.inplace", [DeclareOpInterfaceMethods<Torch
let results = (outs
);
}

@ -0,0 +1,142 @@
//===----------------------------------------------------------------------===//
//
// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
// See https://llvm.org/LICENSE.txt for license information.
// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
//
//===----------------------------------------------------------------------===//
#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
#include "../PassDetail.h"
#include "mlir/Dialect/Linalg/IR/LinalgOps.h"
#include "mlir/Dialect/MemRef/IR/MemRef.h" // TODO: For `memref.dim`.
#include "mlir/Dialect/Traits.h"
#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
#include "npcomp/Dialect/ATen/IR/ATenDialect.h"
using namespace mlir;
using namespace mlir::NPCOMP;
// -----------------------------------------------------------------------------
// Patterns (as this grows, it should be organized into multiple files)
// -----------------------------------------------------------------------------
// This is going to eventually be O(#aten ops), which is in the 100s.
//
// Most of these patterns consist of:
// 1. Checking that the operand/result types and other static properties are
// good-enough to create a valid linalg op (such as operands being of
// ranks/dtypes acceptable to the linalg op).
// 2. Creating dynamic error guards, usually checking a predicate on the
// compatibility of operand shapes.
// 3. Creating init tensors for the computation op. Usually this involves
// reifying IR for a shape transfer function based on the operand shapes.
// 4. Creating a named linalg op to replace the original op.
//
// TODO: Use linalg OpDSL to autogenerate at least 1)/2)/3) such
// that these patterns become mostly mechanical associations of
// "aten.foo -> linalg.foo".
static LogicalResult verifyLinalgCompatibleTypes(Operation *op, PatternRewriter &rewriter) {
// For now, use a small allowlist of types we don't reject.
// The main culprit in practice is that !numpy.any_dtype might be present
// if shape/dtype inference wasn't good enough.
auto isValidLinalgType = [](Type type) {
if (auto rankedTensor = type.dyn_cast<RankedTensorType>()) {
if (BaseMemRefType::isValidElementType(rankedTensor.getElementType()))
return true;
}
if (type.isa<FloatType, IntegerType, IndexType>())
return true;
return false;
};
bool valid = llvm::all_of(op->getOperandTypes(), isValidLinalgType) &&
llvm::all_of(op->getResultTypes(), isValidLinalgType);
if (!valid)
return rewriter.notifyMatchFailure(op, "type cannot be lowered to linalg");
return success();
}
LogicalResult convertMmOp(aten::MmOp op, PatternRewriter &rewriter) {
Location loc = op->getLoc();
Value lhs = op.getOperand(0);
Value rhs = op.getOperand(1);
// A user can write an erroneous program where `aten.mm` is in fact called
// with operands of invalid rank or dtype. We cannot convert to linalg in this
// case or we will get a verifier error, which corresponds to breaking of
// *internal* compiler invariants, and for a user manifests as a compiler
// crash in the worst case (such as if we try to canonicalize/fold/print the
// invalid op before the verifier gets to see it -- also, release builds of a
// mature compiler usually have the verifier turned off for compile-time
// reasons).
//
// The compiler cannot crash even if the user wrote an erroneous program!
if (failed(verifyLinalgCompatibleTypes(op, rewriter)))
return failure();
if (lhs.getType().cast<RankedTensorType>().getRank() != 2 ||
rhs.getType().cast<RankedTensorType>().getRank() != 2) {
return rewriter.notifyMatchFailure(
op, "expected both operands to aten.mm to be rank 2");
}
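// Dynamically guard the shapes: `aten.mm` safely aborts the program when
// the contracting dimensions (lhs dim 1 and rhs dim 0) do not match, so
// materialize that check as a `std.assert` on the operand sizes.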
Value lhsDim0 = rewriter.create<memref::DimOp>(loc, lhs, 0);
Value lhsDim1 = rewriter.create<memref::DimOp>(loc, lhs, 1);
Value rhsDim0 = rewriter.create<memref::DimOp>(loc, rhs, 0);
Value rhsDim1 = rewriter.create<memref::DimOp>(loc, rhs, 1);
Value contractingDimEqual =
rewriter.create<CmpIOp>(loc, CmpIPredicate::eq, lhsDim1, rhsDim0);
rewriter.create<AssertOp>(
loc, contractingDimEqual,
rewriter.getStringAttr("mismatching contracting dimension for aten.mm"));
Type elementType = op.getType().cast<TensorType>().getElementType();
Value initTensor = rewriter.create<linalg::InitTensorOp>(
loc, ValueRange{lhsDim0, rhsDim1}, elementType);
Value c0 = rewriter.create<ConstantOp>(loc, FloatAttr::get(elementType, 0.0));
Value zeroFill =
rewriter.create<linalg::FillOp>(loc, initTensor, c0).getResult(0);
Value matmul = rewriter
.create<linalg::MatmulOp>(loc, zeroFill.getType(),
ValueRange{lhs, rhs}, zeroFill)
.getResult(0);
// When constructed with just dynamic sizes, InitTensorOp will have a result
// type which has all `?`'s for dimensions, which might not be the result
// type of `op`. The constraints on later linalg ops means that the result of
// the MatmulOp will have this type too. So cast it to the desired type so
// that in the end we have the original result type.
rewriter.replaceOpWithNewOp<tensor::CastOp>(op, op.getType(), matmul);
return success();
}
// -----------------------------------------------------------------------------
// The pass
// -----------------------------------------------------------------------------
namespace {
class ConvertATenToLinalg
: public ConvertATenToLinalgBase<ConvertATenToLinalg> {
public:
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
}
void runOnOperation() override {
(void)applyPatternsAndFoldGreedily(getOperation(), getPatterns());
}
FrozenRewritePatternList getPatterns() {
MLIRContext *context = &getContext();
RewritePatternSet patterns(context);
patterns.add(convertMmOp);
return std::move(patterns);
}
};
} // namespace
std::unique_ptr<OperationPass<FuncOp>>
mlir::NPCOMP::createConvertATenToLinalgPass() {
return std::make_unique<ConvertATenToLinalg>();
}

@ -0,0 +1,18 @@
add_npcomp_conversion_library(NPCOMPATenToLinalg
ATenToLinalg.cpp
ADDITIONAL_HEADER_DIRS
${PROJECT_SOURCE_DIR}/include/npcomp/Conversion/ATenToLinalg
DEPENDS
NPCOMPConversionPassIncGen
LINK_COMPONENTS
Core
LINK_LIBS PUBLIC
MLIRIR
MLIRPass
NPCOMPATenDialect
MLIRLinalg
)

@ -173,6 +173,5 @@ void mlir::NPCOMP::populateCoreATenToTCFPatterns(RewritePatternSet &patterns) {
patterns.add<ConvertUnary<aten::TanhOp, tcf::TanhOp>>(context);
patterns.add<ConvertBinaryElementwise<aten::MulOp, tcf::MulOp>>(context);
patterns.add<ConvertBinaryElementwise<aten::MaximumOp, tcf::MaxOp>>(context);
patterns.add<ConvertBinaryElementwise<aten::MmOp, tcf::MatmulOp>>(context);
patterns.add<ConvertATenConv2d>(context);
}

@ -1,3 +1,4 @@
add_subdirectory(ATenToLinalg)
add_subdirectory(ATenToTCF)
add_subdirectory(BasicpyToStd)
add_subdirectory(NumpyToTCF)

@ -8,6 +8,7 @@
#include "npcomp/Conversion/Passes.h"
#include "npcomp/Conversion/ATenToLinalg/ATenToLinalg.h"
#include "npcomp/Conversion/ATenToTCF/Passes.h"
#include "npcomp/Conversion/BasicpyToStd/Passes.h"
#include "npcomp/Conversion/NumpyToTCF/Passes.h"

@ -75,7 +75,8 @@ bool operator==(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
std::make_tuple(rhs.hasRank, rhs.sizes, rhs.elementType);
}
// static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ValueKnowledge &knowledge) {
// static llvm::raw_ostream &operator<<(llvm::raw_ostream &os, ValueKnowledge
// &knowledge) {
// os << "hasRank = " << knowledge.hasRank << ", sizes = [";
// llvm::interleaveComma(knowledge.sizes, os);
// os << "]"
@ -83,6 +84,16 @@ bool operator==(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
// return os;
// }
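// Join the knowledge from two element types. !numpy.any_dtype acts as
// "unknown dtype": a known type refines an unknown one, and two distinct
// known types conservatively join back to unknown.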
Type joinElementTypes(Type lhs, Type rhs) {
if (lhs.isa<Numpy::AnyDtypeType>())
return rhs;
if (rhs.isa<Numpy::AnyDtypeType>())
return lhs;
if (lhs == rhs)
return lhs;
return Numpy::AnyDtypeType::get(lhs.getContext());
}
// Given two pieces of static knowledge, calculate conservatively the
// information we can be sure about.
ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
@ -116,13 +127,7 @@ ValueKnowledge join(const ValueKnowledge &lhs, const ValueKnowledge &rhs) {
}
}
if (!lhs.elementType || lhs.elementType.isa<Numpy::AnyDtypeType>()) {
result.elementType = rhs.elementType;
} else if (!rhs.elementType || rhs.elementType.isa<Numpy::AnyDtypeType>()) {
result.elementType = lhs.elementType;
} else if (lhs.elementType == rhs.elementType) {
result.elementType = lhs.elementType;
}
result.elementType = joinElementTypes(lhs.elementType, rhs.elementType);
return result;
}
@ -157,6 +162,52 @@ private:
DenseMap<Value, ValueKnowledge> facts;
};
// Return the knowledge for the results of an op, based on the knowledge of the
// operands and any information intrinsic to `op`.
static SmallVector<ValueKnowledge>
forwardKnowledgeTransferFunction(Operation *op,
ArrayRef<ValueKnowledge> operandKnowledge) {
if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
return {operandKnowledge[0]};
} else if (isa<aten::MmOp>(op)) {
auto &lhs = operandKnowledge[0];
auto &rhs = operandKnowledge[1];
auto knowledge =
ValueKnowledge::getMostConservativeKnowledge(op->getContext());
knowledge.hasRank = true;
// WARNING: We could be more precise here by calculating the output
// shape as "(lhs.shape[0], rhs.shape[1])". However, that is really tricky
// at this stage in the compiler, because we don't have many static
// guarantees about the input ranks: `aten` ops do dynamic error
// checking and safely abort the program. There is nothing preventing us
// from (correctly!) statically inferring the shapes of the operands to be
// shapes that are guaranteed to cause an error at runtime.
//
// Example: Suppose a user program calls `aten.mm` with two rank-0 operands.
// The program emits an error when invoked, but when running this pass,
// we will (correctly!) infer `lhs.hasRank && lhs.sizes.size() == 0 &&
// rhs.hasRank && rhs.sizes.size() == 0` -- it's not safe to access
// `lhs.sizes[0]` / `rhs.sizes[1]`! So when writing this transfer
// function, it's not as simple as taking `lhs.sizes[0]` and `rhs.sizes[1]`,
// as both of those might read out of bounds of the array. It would require
// more complicated logic.
//
// Just knowing dtypes and ranks is sufficient at this stage
// in the compiler. The precise per-dimension size propagation is best done
// lower in the stack, such as at the linalg level, where we have more
// static guarantees and more structure.
knowledge.sizes.resize(2, kUnknownSize);
// TODO: Investigate promotion rules if element types mismatch.
// This is conservatively correct, assuming that if both element types are
// the same, then the result is of that same element type.
knowledge.elementType = joinElementTypes(lhs.elementType, rhs.elementType);
return {knowledge};
}
return SmallVector<ValueKnowledge>(
op->getNumResults(),
ValueKnowledge::getMostConservativeKnowledge(op->getContext()));
}
void TypeAnalyzer::propagate(Region &region) {
bool changed;
do {
@ -168,10 +219,13 @@ void TypeAnalyzer::propagate(Region &region) {
for (Operation &op : block.getOperations()) {
for (Value v : op.getResults())
changed |= incorporateKnowledge(v, getKnowledgeFromType(v.getType()));
if (isa<Numpy::TensorStaticInfoCastOp, aten::TanhOp>(op)) {
changed |= incorporateKnowledge(op.getResult(0),
getKnowledge(op.getOperand(0)));
}
auto operandKnowledge = llvm::to_vector<6>(llvm::map_range(
op.getOperands(), [&](Value v) { return getKnowledge(v); }));
SmallVector<ValueKnowledge> resultKnowledge =
forwardKnowledgeTransferFunction(&op, operandKnowledge);
assert(resultKnowledge.size() == op.getNumResults());
for (auto t : llvm::zip(op.getResults(), resultKnowledge))
changed |= incorporateKnowledge(std::get<0>(t), std::get<1>(t));
}
};
} while (changed);

@ -61,6 +61,7 @@ TORCH_TO_TCP_PASSES = (
# Lower to TCP (+ guards) which is the input to codegen backends.
# Most of this should be subsumed by aten->linalg+guards conversions.
# (the guard generation will be automated from the linalg Op DSL)
"func(convert-aten-to-linalg)",
"func(convert-aten-to-tcf)",
"func(convert-tcf-to-std)",
"func(convert-elementwise-to-linalg)",

@ -0,0 +1,44 @@
// RUN: npcomp-opt <%s -convert-aten-to-linalg | FileCheck %s
// CHECK-LABEL: func @aten.mm$basic(
// CHECK-SAME: %[[LHS:.*]]: tensor<?x?xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<?x2xf32> {
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[CF0:.*]] = constant 0.000000e+00 : f32
// CHECK: %[[LHS_DIM_0:.*]] = memref.dim %[[LHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[LHS_DIM_1:.*]] = memref.dim %[[LHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[RHS_DIM_0:.*]] = memref.dim %[[RHS]], %[[C0]] : tensor<?x?xf32>
// CHECK: %[[RHS_DIM_1:.*]] = memref.dim %[[RHS]], %[[C1]] : tensor<?x?xf32>
// CHECK: %[[EQ:.*]] = cmpi eq, %[[LHS_DIM_1]], %[[RHS_DIM_0]] : index
// CHECK: assert %[[EQ]], "mismatching contracting dimension for aten.mm"
// CHECK: %[[INIT_TENSOR:.*]] = linalg.init_tensor [%[[LHS_DIM_0]], %[[RHS_DIM_1]]] : tensor<?x?xf32>
// CHECK: %[[ZEROFILL:.*]] = linalg.fill(%[[INIT_TENSOR]], %[[CF0]]) : tensor<?x?xf32>, f32 -> tensor<?x?xf32>
// CHECK: %[[MATMUL:.*]] = linalg.matmul ins(%[[LHS]], %[[RHS]] : tensor<?x?xf32>, tensor<?x?xf32>) outs(%[[ZEROFILL]] : tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[CASTED:.*]] = tensor.cast %[[MATMUL]] : tensor<?x?xf32> to tensor<?x2xf32>
// CHECK: return %[[CASTED]] : tensor<?x2xf32>
func @aten.mm$basic(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x2xf32> {
%0 = "aten.mm"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x2xf32>
return %0 : tensor<?x2xf32>
}
// CHECK-LABEL: func @aten.mm$no_convert$missing_dtype
func @aten.mm$no_convert$missing_dtype(%arg0: tensor<*x!numpy.any_dtype>, %arg1: tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype> {
// CHECK-NEXT: aten.mm
%0 = "aten.mm"(%arg0, %arg1) : (tensor<*x!numpy.any_dtype>, tensor<*x!numpy.any_dtype>) -> tensor<*x!numpy.any_dtype>
return %0 : tensor<*x!numpy.any_dtype>
}
// CHECK-LABEL: func @aten.mm$no_convert$wrong_rank
func @aten.mm$no_convert$wrong_rank(%arg0: tensor<?xf32>, %arg1: tensor<?xf32>) -> tensor<*x!numpy.any_dtype> {
// CHECK-NEXT: aten.mm
%0 = "aten.mm"(%arg0, %arg1) : (tensor<?xf32>, tensor<?xf32>) -> tensor<*x!numpy.any_dtype>
return %0 : tensor<*x!numpy.any_dtype>
}
// CHECK-LABEL: func @aten.mm$no_convert$result_missing_dtype
func @aten.mm$no_convert$result_missing_dtype(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<*xf32> {
// CHECK-NEXT: aten.mm
%0 = "aten.mm"(%arg0, %arg1) : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

@ -24,6 +24,19 @@ func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
// -----
// CHECK-LABEL: func @f(
// CHECK-SAME: %[[LHS:.*]]: tensor<2x?xf32>,
// CHECK-SAME: %[[RHS:.*]]: tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype> {
// CHECK: %[[MM:.*]] = "aten.mm"(%[[LHS]], %[[RHS]]) : (tensor<2x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
// CHECK: %[[SHAPE_ERASED:.*]] = numpy.tensor_static_info_cast %[[MM]] : tensor<?x?xf32> to tensor<*x!numpy.any_dtype>
// CHECK: return %[[SHAPE_ERASED]] : tensor<*x!numpy.any_dtype>
func @f(%arg0: tensor<2x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype> {
%1 = "aten.mm"(%arg0, %arg1) : (tensor<2x?xf32>, tensor<?x?xf32>) -> tensor<*x!numpy.any_dtype>
return %1 : tensor<*x!numpy.any_dtype>
}
// -----
// CHECK-LABEL: func @f
func @f(%arg0: tensor<2x3x?xf32>) -> tensor<*x!numpy.any_dtype> {
// Check propagation through multiple ops.