/* * Copyright 2012 Phil Pratt-Szeliga and other contributors * http://chirrup.org/ * * See the file LICENSE for copying permission. */ package org.trifort.rootbeer.test; import java.io.ByteArrayOutputStream; import java.io.PrintWriter; import java.util.ArrayList; import java.util.List; import org.trifort.rootbeer.configuration.Configuration; import org.trifort.rootbeer.runtime.Context; import org.trifort.rootbeer.runtime.Kernel; import org.trifort.rootbeer.runtime.Rootbeer; import org.trifort.rootbeer.runtime.RootbeerGpu; import org.trifort.rootbeer.runtime.ThreadConfig; import org.trifort.rootbeer.runtime.util.Stopwatch; import org.trifort.rootbeer.util.ForceGC; public class RootbeerTestAgent { private long m_cpuTime; private long m_gpuTime; private boolean m_passed; private String m_message; private List<String> m_failedTests; public RootbeerTestAgent(){ m_failedTests = new ArrayList<String>(); } public void testOne(ClassLoader cls_loader, String test_case) throws Exception { Class test_case_cls = cls_loader.loadClass(test_case); Object test_case_obj = test_case_cls.newInstance(); if(test_case_obj instanceof TestSerialization){ TestSerialization test_ser = (TestSerialization) test_case_obj; System.out.println("[TEST 1/1] "+test_ser.toString()); if(test_case.equals("org.trifort.rootbeer.testcases.rootbeertest.gpurequired.ChangeThreadTest")){ testChangeThread(test_ser, true); } else { test(test_ser, true); } if(m_passed){ System.out.println(" PASSED"); System.out.println(" Cpu time: "+m_cpuTime+" ms"); System.out.println(" Gpu time: "+m_gpuTime+" ms"); } else { System.out.println(" FAILED"); System.out.println(" "+m_message); } } else if(test_case_obj instanceof TestException){ TestException test_ex = (TestException) test_case_obj; System.out.println("[TEST 1/1] "+test_ex.toString()); ex_test(test_ex, true); if(m_passed){ System.out.println(" PASSED"); } else { System.out.println(" FAILED"); System.out.println(" "+m_message); } } else if(test_case_obj instanceof TestKernelTemplate){ TestKernelTemplate test_kernel_template = (TestKernelTemplate) test_case_obj; System.out.println("[TEST 1/1] "+test_kernel_template.toString()); test(test_kernel_template, true); if(m_passed){ System.out.println(" PASSED"); System.out.println(" Cpu time: "+m_cpuTime+" ms"); System.out.println(" Gpu time: "+m_gpuTime+" ms"); } else { System.out.println(" FAILED"); System.out.println(" "+m_message); } } else if(test_case_obj instanceof TestApplication){ TestApplication test_application = (TestApplication) test_case_obj; System.out.println("[TEST 1/1] "+test_application.toString()); if(test_application.test()){ System.out.println(" PASSED"); } else { System.out.println(" FAILED"); System.out.println(" "+test_application.errorMessage()); } } else { throw new RuntimeException("unknown test case type"); } } public void test(ClassLoader cls_loader, boolean run_hard_tests) throws Exception { LoadTestSerialization loader = new LoadTestSerialization(); List<TestSerialization> creators = loader.load(cls_loader, "org.trifort.rootbeer.test.Main", run_hard_tests); List<TestException> ex_creators = loader.loadException(cls_loader, "org.trifort.rootbeer.test.ExMain"); List<TestSerialization> change_thread = loader.load(cls_loader, "org.trifort.rootbeer.test.ChangeThread", run_hard_tests); List<TestKernelTemplate> kernel_template_creators = loader.loadKernelTemplate(cls_loader, "org.trifort.rootbeer.test.KernelTemplateMain"); List<TestApplication> application_creators = loader.loadApplication(cls_loader, "org.trifort.rootbeer.test.ApplicationMain"); int num_tests = creators.size() + ex_creators.size() + change_thread.size() + kernel_template_creators.size() + application_creators.size(); int test_num = 1; for(TestSerialization creator : creators){ System.out.println("[TEST "+test_num+"/"+num_tests+"] "+creator.toString()); test(creator, false); ForceGC.gc(); if(m_passed){ System.out.println(" PASSED"); System.out.println(" Cpu time: "+m_cpuTime+" ms"); System.out.println(" Gpu time: "+m_gpuTime+" ms"); } else { System.out.println(" FAILED"); System.out.println(" "+m_message); m_failedTests.add(creator.toString()); } ++test_num; } for(TestException ex_creator : ex_creators){ System.out.println("[TEST "+test_num+"/"+num_tests+"] "+ex_creator.toString()); ex_test(ex_creator, false); if(m_passed){ System.out.println(" PASSED"); } else { System.out.println(" FAILED"); System.out.println(" "+m_message); m_failedTests.add(ex_creator.toString()); } ++test_num; } for(TestSerialization creator : change_thread){ System.out.println("[TEST "+test_num+"/"+num_tests+"] "+creator.toString()); testChangeThread(creator, false); if(m_passed){ System.out.println(" PASSED"); System.out.println(" Cpu time: "+m_cpuTime+" ms"); System.out.println(" Gpu time: "+m_gpuTime+" ms"); } else { System.out.println(" FAILED"); System.out.println(" "+m_message); m_failedTests.add(creator.toString()); } ++test_num; } for(TestKernelTemplate kernel_template : kernel_template_creators){ System.out.println("[TEST "+test_num+"/"+num_tests+"] "+kernel_template.toString()); test(kernel_template, false); ForceGC.gc(); if(m_passed){ System.out.println(" PASSED"); System.out.println(" Cpu time: "+m_cpuTime+" ms"); System.out.println(" Gpu time: "+m_gpuTime+" ms"); } else { System.out.println(" FAILED"); System.out.println(" "+m_message); m_failedTests.add(kernel_template.toString()); } ++test_num; } for(TestApplication application : application_creators){ System.out.println("[TEST "+test_num+"/"+num_tests+"] "+application.toString()); if(application.test()){ System.out.println(" PASSED"); } else { System.out.println(" FAILED"); System.out.println(" "+application.errorMessage()); m_failedTests.add(application.toString()); } ++test_num; } int test_passed = num_tests - m_failedTests.size(); System.out.println(test_passed+"/"+num_tests+" tests PASS"); if(test_passed == num_tests){ System.out.println("ALL TESTS PASS!"); } else { System.out.println("Failing tests:"); for(String failure : m_failedTests){ System.out.println(" "+failure); } } } private void test(TestSerialization creator, boolean print_mem) { int i = 0; try { Rootbeer rootbeer = new Rootbeer(); Configuration.setPrintMem(print_mem); List<Kernel> known_good_items = creator.create(); List<Kernel> testing_items = creator.create(); Stopwatch watch = new Stopwatch(); watch.start(); rootbeer.run(testing_items); m_passed = true; watch.stop(); m_gpuTime = watch.elapsedTimeMillis(); watch.start(); for(i = 0; i < known_good_items.size(); ++i){ Kernel known_good_item = known_good_items.get(i); known_good_item.gpuMethod(); } watch.stop(); m_cpuTime = watch.elapsedTimeMillis(); for(i = 0; i < known_good_items.size(); ++i){ Kernel known_good_item = known_good_items.get(i); Kernel testing_item = testing_items.get(i); if(!creator.compare(known_good_item, testing_item)){ m_message = "Compare failed at: "+i; m_passed = false; return; } } } catch(Throwable ex){ ex.printStackTrace(System.out); m_message = "Exception thrown at index: "+i; m_passed = false; } } private void test(TestKernelTemplate creator, boolean print_mem) { int i = 0; try { Rootbeer rootbeer = new Rootbeer(); Configuration.setPrintMem(print_mem); Kernel known_good_item = creator.create(); Kernel testing_item = creator.create(); ThreadConfig thread_config = creator.getThreadConfig(); Stopwatch watch = new Stopwatch(); watch.start(); Context context = rootbeer.createDefaultContext(); context.setKernel(testing_item); context.setThreadConfig(thread_config); context.buildState(); context.run(); context.close(); m_passed = true; watch.stop(); m_gpuTime = watch.elapsedTimeMillis(); watch.start(); RootbeerGpu.setBlockDimx(thread_config.getThreadCountX()); RootbeerGpu.setBlockDimy(thread_config.getThreadCountY()); RootbeerGpu.setBlockDimz(thread_config.getThreadCountZ()); RootbeerGpu.setGridDimx(thread_config.getBlockCountX()); RootbeerGpu.setGridDimy(thread_config.getBlockCountY()); for(int blockx = 0; blockx < thread_config.getBlockCountX(); ++blockx){ RootbeerGpu.setBlockIdxx(blockx); for(int blocky = 0; blocky < thread_config.getBlockCountY(); ++blocky){ RootbeerGpu.setBlockIdxy(blocky); for(int threadx = 0; threadx < thread_config.getThreadCountX(); ++threadx){ RootbeerGpu.setThreadIdxx(threadx); for(int thready = 0; thready < thread_config.getThreadCountY(); ++thready){ RootbeerGpu.setThreadIdxy(thready); for(int threadz = 0; threadz < thread_config.getThreadCountZ(); ++threadz){ RootbeerGpu.setThreadIdxz(threadz); known_good_item.gpuMethod(); } } } } } watch.stop(); m_cpuTime = watch.elapsedTimeMillis(); if(!creator.compare(known_good_item, testing_item)){ m_message = "Compare failed at: "+i; m_passed = false; return; } } catch(Throwable ex){ ex.printStackTrace(System.out); m_message = "Exception thrown at index: "+i; m_passed = false; } } private void ex_test(TestException creator, boolean print_mem) { Rootbeer rootbeer = new Rootbeer(); Configuration.setPrintMem(print_mem); List<Kernel> testing_items = creator.create(); try { rootbeer.run(testing_items); m_passed = false; m_message = "No exception thrown when expecting one."; } catch(Throwable ex){ m_passed = creator.catchException(ex); if(m_passed == false){ m_message = "Exception is: "+ex.toString(); } } } private void testChangeThread(TestSerialization creator, boolean print_mem) { Thread t = new Thread(new ChangeThread(creator, print_mem)); t.start(); try { t.join(); } catch(Exception ex){ ex.printStackTrace(); } } private class ChangeThread implements Runnable { private TestSerialization m_creator; private boolean m_printMem; private Rootbeer m_rootbeer; public ChangeThread(TestSerialization creator, boolean print_mem){ m_creator = creator; m_printMem = print_mem; m_rootbeer = new Rootbeer(); } public void run() { int i = 0; try { Configuration.setPrintMem(m_printMem); List<Kernel> known_good_items = m_creator.create(); List<Kernel> testing_items = m_creator.create(); Stopwatch watch = new Stopwatch(); watch.start(); m_rootbeer.run(testing_items); m_passed = true; watch.stop(); m_gpuTime = watch.elapsedTimeMillis(); watch.start(); for(i = 0; i < known_good_items.size(); ++i){ Kernel known_good_item = known_good_items.get(i); known_good_item.gpuMethod(); } watch.stop(); m_cpuTime = watch.elapsedTimeMillis(); for(i = 0; i < known_good_items.size(); ++i){ Kernel known_good_item = known_good_items.get(i); Kernel testing_item = testing_items.get(i); if(!m_creator.compare(known_good_item, testing_item)){ m_message = "Compare failed at: "+i; m_passed = false; return; } } } catch(Throwable ex){ ex.printStackTrace(System.out); m_message = "Exception thrown at index: "+i; m_passed = false; } } } }