diff --git a/python/src/pywy/tests/test_dl.py b/python/src/pywy/tests/test_dl.py
index a8f542469..c424bad84 100644
--- a/python/src/pywy/tests/test_dl.py
+++ b/python/src/pywy/tests/test_dl.py
@@ -26,53 +26,55 @@
 from pywy.basic.model.optimizer import GradientDescent
 from pywy.basic.model.option import Option
 from pywy.basic.model.models import DLModel
+from importlib import resources
+from pywy.tests import resources as resources_folder
 
 
-# TODO: fix this test by giving it proper test resources & fixing some type issues with lists.
-@pytest.mark.skip(reason="no way of currently testing this, since we are missing implementations for proper test resources & types in types.py")
 def test_dl_tensorflow():
-    l1 = Linear(4, 64, True)
-    s1 = Sigmoid()
-    l2 = Linear(64, 3, True)
+    with resources.path(resources_folder, "sample_data.md") as resource_path, \
+            resources.path(resources_folder, "wordcount_out_python.txt") as output_path:
+        l1 = Linear(4, 64, True)
+        s1 = Sigmoid()
+        l2 = Linear(64, 3, True)
 
-    s1.with_ops(l1.with_ops(Input(Input.Type.FEATURES)))
-    l2.with_ops(s1)
+        s1.with_ops(l1.with_ops(Input(Input.Type.FEATURES)))
+        l2.with_ops(s1)
 
-    model = DLModel(l2)
+        model = DLModel(l2)
 
-    criterion = CrossEntropyLoss(3)
-    criterion.with_ops(
-        Input(Input.Type.PREDICTED),
-        Input(Input.Type.LABEL, Op.DType.INT32)
-    )
-    acc = Mean(0)
-    acc.with_ops(
-        Cast(Op.DType.FLOAT32).with_ops(
-            Eq().with_ops(
-                ArgMax(1).with_ops(
-                    Input(Input.Type.PREDICTED)
-                ),
-                Input(Input.Type.LABEL, Op.DType.INT32)
+        criterion = CrossEntropyLoss(3)
+        criterion.with_ops(
+            Input(Input.Type.PREDICTED),
+            Input(Input.Type.LABEL, Op.DType.INT32)
+        )
+        acc = Mean(0)
+        acc.with_ops(
+            Cast(Op.DType.FLOAT32).with_ops(
+                Eq().with_ops(
+                    ArgMax(1).with_ops(
+                        Input(Input.Type.PREDICTED)
+                    ),
+                    Input(Input.Type.LABEL, Op.DType.INT32)
+                )
             )
         )
-    )
 
-    optimizer = GradientDescent(0.02)
-    option = Option(criterion, optimizer, 6, 100)
+        optimizer = GradientDescent(0.02)
+        option = Option(criterion, optimizer, 6, 100)
 
-    floats: List[List[int]] = [[5.1, 3.5, 1.4, 0.2]]
+        floats: List[List[float]] = [[5.1, 3.5, 1.4, 0.2]]
 
-    ints: List[List[int]] = [[0, 0, 1, 1, 2, 2]]
+        ints: List[List[int]] = [[0, 0, 1, 1, 2, 2]]
 
-    ctx = WayangContext() \
-        .register({JavaPlugin, SparkPlugin, TensorflowPlugin})
-    trainXSource = ctx.textfile("file:///var/www/html/README.md").map(lambda x: floats, str, List[List[float]])
-    trainYSource = ctx.textfile("file:///var/www/html/README.md").map(lambda x: floats, str, List[List[float]])
-    testXSource = ctx.textfile("file:///var/www/html/README.md").map(lambda x: floats, str, List[List[float]])
+        ctx = WayangContext() \
+            .register({JavaPlugin, SparkPlugin, TensorflowPlugin})
+        trainXSource = ctx.textfile(f"file://{resource_path}").map(lambda x: floats, str, List[List[float]])
+        trainYSource = ctx.textfile(f"file://{resource_path}").map(lambda x: floats, str, List[List[float]])
+        testXSource = ctx.textfile(f"file://{resource_path}").map(lambda x: floats, str, List[List[float]])
 
-    data_quanta = trainXSource.dlTraining(model, option, trainYSource, List[List[float]], List[List[float]]) \
-        .predict(testXSource, List[List[float]], List[List[float]]) \
-        .map(lambda x: "Test", List[List[float]], str) \
-        .store_textfile("file:///var/www/html/data/wordcount-out-python.txt", List[float])
-
-    assert data_quanta is not None
\ No newline at end of file
+        data_quanta = trainXSource.dlTraining(model, option, trainYSource, List[List[float]], List[List[float]]) \
+            .predict(testXSource, List[List[float]], List[List[float]]) \
+            .map(lambda x: "Test", List[List[float]], str) \
+            .store_textfile(f"file://{output_path}", str)
+
+        assert data_quanta is not None
\ No newline at end of file
diff --git a/python/src/pywy/types.py b/python/src/pywy/types.py
index cbfa0ee5c..41d825ca3 100644
--- a/python/src/pywy/types.py
+++ b/python/src/pywy/types.py
@@ -194,9 +194,9 @@ def typecheck(input_type: Type[ConstrainedOperatorType]):
     origin = get_origin(input_type)
     args = get_args(input_type)
 
-    if isinstance(input_type, List) and args:
+    if origin is list and args:
         typecheck(args[0])
-    elif isinstance(input_type, Tuple):
+    elif origin is tuple:
         if all(arg in allowed_types for arg in args):
             return
         else: